ALBERT:
ALBERT is based on a type of model called a "transformer," which is designed to handle sequences of words and capture their meaning. The special thing about ALBERT is that it's a lighter and more efficient version of another popular model called BERT.
ALBERT achieves its efficiency by using techniques like sharing parameters across layers and reducing the number of parameters it needs to learn. This means that ALBERT can be smaller and faster while still being very good at understanding language.
Despite being smaller, ALBERT can still do many of the same things as bigger models like BERT. It can understand the meaning of words in context, answer questions, summarize text, and much more.
Techniques used in ALBERT:
- Cross-layer parameter sharing:
  - All shared: share the parameters of all the sublayers across the encoder layers.
  - Shared FFN: share only the feed-forward network (FFN) parameters of the first encoder layer with the FFNs of the remaining layers.
  - Shared attention: share only the multi-head attention parameters of the first encoder layer with the remaining encoder layers.
- Factorized Embedding Parameterization: Instead of storing a unique embedding vector for each word in the vocabulary, Factorized Embedding Parameterization factorizes the embedding matrix into two smaller matrices. These matrices are typically called the embedding matrix and the projection matrix.
The embedding matrix contains the embeddings for each word in the vocabulary but with reduced dimensions compared to traditional methods. The projection matrix maps these reduced-dimensional embeddings to the original embedding size.
Let $V$ be the vocabulary size (e.g. 30000) and $H$ the hidden size (e.g. 768). Without factorization, the embedding matrix has size $V \times H$ = 30000 x 768. To shrink it, ALBERT uses a smaller WordPiece embedding size $E$ (e.g. 128): tokens are first looked up in a low-dimensional embedding matrix of size $V \times E$, and the result is then projected into the hidden space with a matrix of size $E \times H$.
That is, 30000 x 128 followed by 128 x 768, which gives the same 30000 x 768 mapping with far fewer parameters: 30000 x 128 + 128 x 768 = 3,938,304 instead of 30000 x 768 = 23,040,000.
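The arithmetic above can be checked with a few lines of Python. This is only a sketch: `embedding_params` is a hypothetical helper written for this illustration, not part of any library, and it counts weights only (no biases).

```python
from typing import Optional


def embedding_params(vocab_size: int, hidden_size: int,
                     embed_size: Optional[int] = None) -> int:
    """Count embedding weights, with or without factorization."""
    if embed_size is None:
        # One V x H matrix
        return vocab_size * hidden_size
    # A V x E lookup matrix followed by an E x H projection matrix
    return vocab_size * embed_size + embed_size * hidden_size


V, H, E = 30000, 768, 128
full = embedding_params(V, H)
factorized = embedding_params(V, H, E)

print(full)        # 23040000
print(factorized)  # 3938304
```

The saving grows with the gap between $E$ and $H$, which is why the factorization matters most for the larger ALBERT configurations.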
Sentence Order Prediction (SOP): It's a binary classification task. For a given pair of sentences, the model predicts whether their order has been swapped or not.
Positive: S1: She cooked momo. S2: It was delicious. Class: Positive (the pair is in order).
Negative: S1: It was delicious. S2: She cooked momo. Class: Negative (the pair is swapped).
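As a rough illustration of how such training pairs could be built from two consecutive sentences, here is a minimal sketch. The helper name `sop_examples` and the label encoding (1 = in order, 0 = swapped) are assumptions made for this example, not ALBERT's actual preprocessing code.

```python
def sop_examples(first: str, second: str):
    """Return one positive and one negative SOP example.

    Each example is (sentence_1, sentence_2, label), where
    label 1 means "in order" and label 0 means "swapped".
    """
    positive = (first, second, 1)
    negative = (second, first, 0)
    return positive, negative


pos, neg = sop_examples("She cooked momo.", "It was delicious.")
print(pos)  # ('She cooked momo.', 'It was delicious.', 1)
print(neg)  # ('It was delicious.', 'She cooked momo.', 0)
```

Both examples contain exactly the same two sentences, so the model cannot rely on topic cues and has to look at how the sentences relate to each other.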
from transformers import AlbertTokenizer, AlbertModel
import torch
model = AlbertModel.from_pretrained("albert/albert-base-v2")
tokenizer = AlbertTokenizer.from_pretrained("albert/albert-base-v2")
example = "Kathmandu is city of Temples and Gods."
inputs = tokenizer(example, return_tensors="pt")
print(inputs)
{'input_ids': tensor([[ 2, 28823, 25, 136, 16, 9111, 17, 4769, 9, 3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
outtput = model(**inputs)
outtput.last_hidden_state
tensor([[[ 0.1889, -0.0525, 0.1712, ..., -0.4810, -0.3654, 0.5008], [-1.4933, -0.6441, 1.0500, ..., -0.1704, 1.0213, -0.8828], [ 0.4297, -1.0495, 1.1510, ..., 0.5195, 0.2841, -0.8472], ..., [-1.4946, 0.5316, 0.4043, ..., 0.0406, 0.3832, -0.4737], [-0.2297, -0.1473, 0.0360, ..., -0.5322, 1.1365, 0.1591], [ 0.0754, 0.1233, -0.0653, ..., -0.1315, 0.1112, 0.1950]]], grad_fn=<NativeLayerNormBackward0>)
print(f'[CLS] aggregate repr:\n{outtput.pooler_output}')
[CLS] aggregate repr: tensor([[ 0.7480, -0.7183, 0.7389, ..., -0.3594, -0.9924, 0.7081]], grad_fn=<TanhBackward0>)
RoBERTa:
RoBERTa (Robustly optimized BERT approach) is an enhanced version of BERT. Its main changes are:
- Dynamic Masking
- No NSP task
- Large batch Size
- Byte-Level BPE (BBPE) as a tokenizer.
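To make the first item concrete, the sketch below contrasts static masking (the mask is drawn once and reused) with dynamic masking (a fresh mask is drawn each time the sentence is seen). It is a simplified illustration: `mask_tokens` is a hypothetical helper, the masking probability of 0.3 is chosen only so that a short sentence shows some masks, and details of the real recipe are left out.

```python
import random

MASK = "<mask>"


def mask_tokens(tokens, rng, mask_prob=0.15):
    """Return a copy of tokens where each position is masked with probability mask_prob."""
    return [MASK if rng.random() < mask_prob else tok for tok in tokens]


tokens = "you are a beautiful person and a kind soul".split()

# Static masking: draw the mask once and reuse it in every epoch.
static_rng = random.Random(0)
static_view = mask_tokens(tokens, static_rng, mask_prob=0.3)
static_epochs = [static_view for _ in range(3)]

# Dynamic masking: draw a new mask every time the sentence is fed to the model.
dynamic_rng = random.Random(0)
dynamic_epochs = [mask_tokens(tokens, dynamic_rng, mask_prob=0.3) for _ in range(3)]

for epoch, view in enumerate(dynamic_epochs):
    print(epoch, " ".join(view))
```

With dynamic masking the model generally sees different masked positions for the same sentence across epochs, which acts as a cheap form of data augmentation.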
from transformers import RobertaConfig, RobertaTokenizer, RobertaModel, pipeline
import torch
model = RobertaModel.from_pretrained("FacebookAI/roberta-base")
Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
model.config
RobertaConfig {
  "_name_or_path": "FacebookAI/roberta-base",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.39.1",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}
tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-base")
r_example = "You are a beautiful person."
tokenizer.tokenize(r_example)
['You', 'Ġare', 'Ġa', 'Ġbeautiful', 'Ġperson', '.']
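The Ġ character stands for a whitespace in RoBERTa's byte-level BPE vocabulary: the space before a word is encoded as part of the token, so Ġare is the token for " are".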
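A quick way to see this is to glue the tokens back together by hand. The snippet below is a sketch that works for this plain ASCII sentence only. In practice the tokenizer's own `convert_tokens_to_string` method should be preferred, since byte-level BPE also remaps other bytes.

```python
tokens = ['You', 'Ġare', 'Ġa', 'Ġbeautiful', 'Ġperson', '.']

# Concatenate the tokens and turn every Ġ back into a space.
text = "".join(tokens).replace("Ġ", " ")
print(text)  # You are a beautiful person.

# Tokens that start with Ġ begin a new word; the others continue the previous text.
starts_new_word = [tok.startswith("Ġ") for tok in tokens]
print(starts_new_word)  # [False, True, True, True, True, False]
```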
inputs = tokenizer(r_example, return_tensors="pt")
output = model(**inputs)
print(f'token repr\nshape:{output.last_hidden_state.shape}\n {output.last_hidden_state}')
print(f'aggregate repr\nshape:{output.pooler_output.shape}\n {output.pooler_output}')
The sentence has 6 tokens plus the <s> and </s> special tokens, so the token representation has shape torch.Size([1, 8, 768]) and the aggregate representation has shape torch.Size([1, 768]). Note that the pooler weights of this checkpoint are newly initialized (see the warning printed when loading the model), so the aggregate representation is not meaningful before fine-tuning.
Let's see a masking example.
masked_example = "you are a beautiful <mask>."
filler = pipeline("fill-mask", model="FacebookAI/roberta-base")
res = filler(masked_example)
for item in res:
    print(item)
{'score': 0.23962363600730896, 'token': 693, 'token_str': ' woman', 'sequence': 'you are a beautiful woman.'}
{'score': 0.18328218162059784, 'token': 621, 'token_str': ' person', 'sequence': 'you are a beautiful person.'}
{'score': 0.18187165260314941, 'token': 1816, 'token_str': ' girl', 'sequence': 'you are a beautiful girl.'}
{'score': 0.04936479032039642, 'token': 7047, 'token_str': ' soul', 'sequence': 'you are a beautiful soul.'}
{'score': 0.037081778049468994, 'token': 6429, 'token_str': ' lady', 'sequence': 'you are a beautiful lady.'}
ELECTRA (Efficiently Learning an Encoder that Classifies Token Replacements Accurately)
ELECTRA is a language model that learns by distinguishing between real and fake words in a sentence, rather than predicting missing words directly. This makes training more efficient while still producing high-quality language representations.
Replaced token detection is similar to MLM, but instead of masking a token we replace it with a different token and train the model to classify whether each token is the original or a replacement.
How does it work?
cls_example = ["The", "teacher", "taught", "the", "math", "subject"]
replaced_example = ["a", "teacher", "learnt", "the", "math", "subject"]
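We can use a BERT-style classifier to label each token as original or replaced. This classifier is known as the discriminator. The replaced tokens come from a second, smaller model called the generator. First, we take the example tokens $T_1$ = ["The", "teacher", "taught", "the", "math", "subject"] and mask some of them to feed into the generator: $T_2$ = ["[MASK]", "teacher", "[MASK]", "the", "math", "subject"].
The generator fills in the masked positions with its own predictions, giving for example $T_3$ = ["a", "teacher", "learnt", "the", "math", "subject"]. Now we feed $T_3$ to the discriminator to find out whether each token is original or replaced.
Note: The discriminator is ELECTRA. The generator is only used during pre-training.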
One major advantage of ELECTRA over BERT: BERT uses MLM as its training objective, where only 15% of the tokens are masked. The model therefore receives a training signal from only 15% of the tokens, since it predicts only the masked ones. In ELECTRA, all tokens provide a training signal, since the model classifies every token as original or replaced.
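The difference in training signal can be made concrete with the toy sentence from above. The sketch below derives the discriminator's target labels by comparing the original and the replaced tokens, and counts how many positions contribute to each objective. The label encoding (1 = replaced, 0 = original) and the variable names are assumptions chosen for this illustration.

```python
cls_example = ["The", "teacher", "taught", "the", "math", "subject"]
masked_example = ["[MASK]", "teacher", "[MASK]", "the", "math", "subject"]
replaced_example = ["a", "teacher", "learnt", "the", "math", "subject"]

# Discriminator targets: 1 where the token differs from the original, else 0.
# If the generator happens to produce the original token, it counts as original.
labels = [int(orig != repl) for orig, repl in zip(cls_example, replaced_example)]
print(labels)  # [1, 0, 1, 0, 0, 0]

# MLM only gets a loss term at the masked positions ...
mlm_positions = sum(tok == "[MASK]" for tok in masked_example)
# ... while replaced token detection gets one at every position.
rtd_positions = len(labels)
print(mlm_positions, rtd_positions)  # 2 6
```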
from transformers import ElectraTokenizer, ElectraForPreTraining
import torch
# ElectraForPreTraining is the encoder plus the per-token "original vs. replaced" head
discriminator = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
original_text = "The quick brown fox jumps over the lazy dog"
replaced_text = "The quick brown duck jumps over the lazy dog"
inputs = tokenizer(replaced_text, return_tensors="pt")
with torch.no_grad():
    output = discriminator(**inputs)
logits = output.logits  # shape (batch, seq_len): one score per token
probs = torch.sigmoid(logits)  # probability that each token was replaced
ElectraModel alone returns only hidden states. The classifier that decides whether a token is original or replaced is a small FFN on top of each token representation, and it is included in ElectraForPreTraining.
SpanBERT:
It is mostly used for tasks such as question answering, where we predict a span of text. Let's see it in action.
span_example = "You are expected to know the laws of your country"
tokens = span_example.split(" ")
tokens
['You', 'are', 'expected', 'to', 'know', 'the', 'laws', 'of', 'your', 'country']
Let's mask the tokens. Instead of masking random individual tokens, we mask a random contiguous span of tokens.
for i in range(5,9,1):
    tokens[i] = "[MASK]"
tokens
['You', 'are', 'expected', 'to', 'know', '[MASK]', '[MASK]', '[MASK]', '[MASK]', 'country']
Span boundary objective (SBO):
We train SpanBERT with the MLM objective along with the SBO. The SBO uses only the tokens at the span boundary, i.e. the tokens immediately before the start of the span and immediately after its end. In our example, know and country are the span boundary tokens.
How does the model distinguish token $X_6$ from $X_7$ when both share the same span boundary?
It uses the position embedding of the masked token. To predict $X_7$, we use the representations of the boundary tokens, $R_5$ (know) and $R_{10}$ (country), together with the position embedding of the masked token relative to the start of the span, i.e. $P_2$.
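The bookkeeping behind the SBO can be sketched in plain Python: for every masked position we collect the two boundary tokens and the position of the token inside the span. The helper `sbo_inputs` is hypothetical and only gathers the inputs; the actual prediction network on top of them is not shown. It also assumes the span does not touch either end of the sentence.

```python
tokens = ['You', 'are', 'expected', 'to', 'know', 'the', 'laws', 'of', 'your', 'country']
span_start, span_end = 5, 8  # 0-based, inclusive: the four masked positions


def sbo_inputs(tokens, start, end):
    """For each masked position return (1-based index, left boundary, right boundary, relative position)."""
    left = tokens[start - 1]   # token just before the span
    right = tokens[end + 1]    # token just after the span
    return [(i + 1, left, right, i - start + 1) for i in range(start, end + 1)]


inputs = sbo_inputs(tokens, span_start, span_end)
for entry in inputs:
    print(entry)
# (6, 'know', 'country', 1)
# (7, 'know', 'country', 2)
# (8, 'know', 'country', 3)
# (9, 'know', 'country', 4)
```

All four masked tokens share the same boundary pair, so the relative position in the last column is the only thing that tells them apart.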
QnA with pretrained SpanBERT
qa_pipeline = pipeline(
    "question-answering",
    model="mrm8488/spanbert-finetuned-squadv2",
    tokenizer="mrm8488/spanbert-finetuned-squadv2"
)
Some weights of the model checkpoint at mrm8488/spanbert-finetuned-squadv2 were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight'] - This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
results = qa_pipeline({
    'question': "What is machine learning?",
    'context': "Machine learning is a subset of artificial intelligence. It is widely for creating a variety of applications such as email filtering and computer vision"
})
results["answer"]
'a subset of artificial intelligence'
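Under the hood, an extractive question-answering model scores every token as a possible start and as a possible end of the answer, and the answer is the span with the best combined score. The sketch below illustrates only that selection step. The scores are made-up numbers for illustration (not output of the model above), `best_span` is a hypothetical helper, and the real pipeline works on subword tokens and applies further checks.

```python
def best_span(start_scores, end_scores, max_len=10):
    """Return (start, end) maximizing start_scores[start] + end_scores[end] with start <= end."""
    best_score, best_start, best_end = float("-inf"), 0, 0
    for s, s_score in enumerate(start_scores):
        for e in range(s, min(s + max_len, len(end_scores))):
            score = s_score + end_scores[e]
            if score > best_score:
                best_score, best_start, best_end = score, s, e
    return best_start, best_end


context_tokens = ["Machine", "learning", "is", "a", "subset", "of",
                  "artificial", "intelligence", "."]
# Hypothetical per-token scores, invented for this example.
start_scores = [0.1, 0.0, 0.2, 3.0, 1.0, 0.1, 0.5, 0.2, 0.0]
end_scores = [0.0, 0.3, 0.1, 0.2, 0.6, 0.1, 0.4, 3.5, 0.2]

start, end = best_span(start_scores, end_scores)
answer = " ".join(context_tokens[start:end + 1])
print(answer)  # a subset of artificial intelligence
```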
Summary:
ALBERT: a lite version of BERT. It uses cross-layer parameter sharing and factorized embedding parameterization to reduce the number of model parameters. It also uses Sentence Order Prediction (SOP) as a classification task.
RoBERTa: a variant of BERT that is trained on the MLM task only (no NSP) with dynamic masking. It uses a large batch size for speed and performance. In addition, it uses BBPE as its tokenizer with a vocabulary size of about 50K.
ELECTRA: instead of the MLM task, it detects fake/replaced tokens in a sentence.
SpanBERT: it uses the MLM objective and the SBO to predict the masked tokens. The SBO predicts each masked token using only the immediate boundary tokens and the token's position embedding.