Question answering with fine-tuned BERT
In this task, the input is a question along with a paragraph containing the answer. Our objective is to extract the answer to the question from the given paragraph.
For instance:
Question: "What is immune system?"
Paragraph: "The immune system is a system of many biological structures and processes within an organism that protects against disease. To function properly, an immune system must detect a wide variety of agents, known as pathogens, from viruses to parasitic worms, and distinguish them from the organism's own healthy tissue."
The model's job is to extract the answer from the given paragraph.
Answer: "a system of many biological structures and processes within an organism that protects against disease"
How does BERT know where the answer starts and where it ends? In the paragraph below, the answer span it must locate is highlighted:
Paragraph: "The immune system is a system of many biological structures and processes within an organism that protects against disease
. To function properly, an immune system must detect a wide variety of agents, known as pathogens, from viruses to parasitic worms, and distinguish them from the organism's own healthy tissue."
To find the starting and ending indices, we use two vectors, a start vector $S$ and an end vector $E$, which are learned during the training phase.
Steps:
Compute the probability of each token being the start/end of the answer:
- Compute the dot product between the BERT representation $R_i$ of each token $i$ and $S$ (for the start) or $E$ (for the end).
- Apply the softmax function over all tokens.
- Select the token with the highest probability as the start/end token (see the sketch after this list).
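Concretely, the probability that token $i$ starts the answer is $P_i = \frac{e^{S \cdot R_i}}{\sum_j e^{S \cdot R_j}}$, and analogously with $E$ for the end. Below is a minimal sketch of this selection step; the names R, S, E and their shapes are illustrative, not taken from any library.

import torch

def pick_span(R, S, E):
    # R: (seq_len, hidden) token representations; S, E: (hidden,) learned vectors
    start_probs = torch.softmax(R @ S, dim=0)  # P_i = exp(S.R_i) / sum_j exp(S.R_j)
    end_probs = torch.softmax(R @ E, dim=0)
    # return the most likely start and end positions
    return start_probs.argmax().item(), end_probs.argmax().item()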
The full pipeline:
- Tokenize the question-paragraph pair.
- Use a pre-trained BERT to extract token representations.
- Compute the start/end tokens (using the steps above).
- Select the text span containing the answer using the start/end indices.
Let's see it in action.
from transformers import BertForQuestionAnswering, BertTokenizer
import numpy as np
import torch
Let's download and load the model. We will use a BERT model fine-tuned on SQuAD (the Stanford Question Answering Dataset): bert-large-uncased-whole-word-masking-finetuned-squad.
model = BertForQuestionAnswering.from_pretrained("google-bert/bert-large-uncased-whole-word-masking-finetuned-squad")
Some weights of the model checkpoint at google-bert/bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
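Under the hood, BertForQuestionAnswering adds a single linear layer on top of BERT that maps each token's hidden state to two scores, playing the role of the $S$ and $E$ vectors described above. Assuming the attribute name used by the transformers implementation (qa_outputs), we can inspect it:

print(model.qa_outputs)
# expected: Linear(in_features=1024, out_features=2, bias=True)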
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased-whole-word-masking-finetuned-squad")
Let's preprocess the input. Note that we add the special tokens [CLS] (start of input) and [SEP] (segment separator) to the strings manually.
Q = "[CLS] What is the immune system? [SEP]"
P = "The immune system is a system of many biological structures and processes within an organism that protects against disease. To function properly, an immune system must detect a wide variety of agents, known as pathogens, from viruses to parasitic worms, and distinguish them from the organism's own healthy tissue. [SEP]"
Tokenize the question and the paragraph.
q_tokens = tokenizer.tokenize(Q)
p_tokens = tokenizer.tokenize(P)
Combine the question and paragraph tokens and get their input IDs.
tokens = q_tokens + p_tokens
input_ids = tokenizer.convert_tokens_to_ids(tokens)
Create the segment IDs: assign 0 to the question tokens and 1 to the paragraph tokens.
segment_ids = [0] * len(q_tokens) + [1] * len(p_tokens)
input_ids = torch.tensor([input_ids])      # add a batch dimension
segment_ids = torch.tensor([segment_ids])
print(f"input: {input_ids.shape}, segment: {segment_ids.shape}")
input: torch.Size([1, 67]), segment: torch.Size([1, 67])
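As an aside, the tokenizer can build all of these inputs in one call when given a question/paragraph pair: it inserts [CLS] and [SEP] itself and returns token_type_ids (our segment IDs). A sketch, equivalent to the manual steps above:

# strip our manual special tokens; the tokenizer adds its own
question_text = Q.replace("[CLS]", "").replace("[SEP]", "").strip()
paragraph_text = P.replace("[SEP]", "").strip()
encoding = tokenizer(question_text, paragraph_text, return_tensors="pt")
# encoding["input_ids"] and encoding["token_type_ids"] should match input_ids and segment_ids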
Now, let's feed the input_ids and segment_ids to the model, which returns a start score and an end score for every token.
output = model(input_ids, token_type_ids=segment_ids)
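Since we are only running inference, it is good practice (though not required here) to disable gradient tracking; an equivalent forward pass as a sketch:

with torch.no_grad():
    output = model(input_ids, token_type_ids=segment_ids)

# one start score and one end score per input token
print(output.start_logits.shape, output.end_logits.shape)
# expected: torch.Size([1, 67]) torch.Size([1, 67])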
Let's get the start and end indices from the logits.
s_index = torch.argmax(output.start_logits)  # most likely start token
e_index = torch.argmax(output.end_logits)    # most likely end token
Now we extract the tokens (words) from the start index to the end index and print them.
print(' '.join(tokens[s_index: e_index+1]))
a system of many biological structures and processes within an organism that protects against disease
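One caveat: WordPiece can split rarer words into subword tokens prefixed with ##, in which case joining with spaces leaves artifacts like "im ##mun ##e". A simple cleanup, as a sketch:

answer = ' '.join(tokens[s_index: e_index + 1])
# stitch subword pieces back together: "im ##mun ##e" -> "immune"
print(answer.replace(' ##', ''))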