This is a basic example to explain the core concepts.
from transformers import BertModel, BertTokenizer
import torch
Let's download and load the bert-base-uncased model.
model = BertModel.from_pretrained("bert-base-uncased", output_hidden_states=True)
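Optionally, we can peek at the model configuration; the 12 encoder layers and the hidden size of 768 reappear later in the output shapes (a small check, not required for the rest of the notebook).
# bert-base-uncased has 12 encoder layers and a hidden size of 768
print(model.config.num_hidden_layers, model.config.hidden_size)  # 12 768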
Download and load the tokenizer that was used to pre-train the bert-base-uncased model.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
Preprocessing the input
example = "i love kathmandu"
tokens = tokenizer.tokenize(example)
tokens
['i', 'love', 'kathmandu']
Let's add the special tokens [CLS] and [SEP].
tokens = ["[CLS]"] + tokens + ["[SEP]"]
tokens
['[CLS]', 'i', 'love', 'kathmandu', '[SEP]']
We need to add padding tokens because we want a sequence length of 7, but the actual token count is only 5.
tokens = tokens + ["[PAD]", "[PAD]"]
tokens
['[CLS]', 'i', 'love', 'kathmandu', '[SEP]', '[PAD]', '[PAD]']
Now create an attention mask: 1 for actual tokens and 0 for [PAD] tokens.
attn_mask = [0 if token == "[PAD]" else 1 for token in tokens]
attn_mask
[1, 1, 1, 1, 1, 0, 0]
Let's convert the tokens to their IDs.
token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_ids
[101, 1045, 2293, 28045, 102, 0, 0]
token_ids = torch.tensor(token_ids).unsqueeze(0)
attn_mask = torch.tensor(attn_mask).unsqueeze(0)
print(f"token IDs: {token_ids.shape}, attention mask: {attn_mask.shape}")
token IDs: torch.Size([1, 7]), attention mask: torch.Size([1, 7])
The Embedding
last_hidden_state: the representation of all the tokens obtained from the final encoder layer (layer 12).
pooler_output: the representation of the [CLS] token.
hidden_states: the outputs of the embedding layer plus all 12 encoder layers (13 tensors in total).
output = model(token_ids, attention_mask=attn_mask)
last_hidden_state = output.last_hidden_state
pooler_output = output.pooler_output
hidden_states = output.hidden_states
print(f"last hidden:{last_hidden_state.shape} hidden layers:{len(hidden_states)}, CLS: {pooler_output.shape}")
last hidden:torch.Size([1, 7, 768]) hidden layers:13, CLS: torch.Size([1, 768])
The size [1, 7, 768] indicates [batch_size, sequence_length, hidden_size]. The size [1, 768] indicates [batch_size, hidden_size].
print(f"last layer, representation of 'love'\n{last_hidden_state[0][1]}")
last layer, representation of 'love' tensor([ 3.1862e-01, 5.0867e-01, 1.4940e-01, -2.8184e-01, 1.0760e-02, 4.1292e-01, -1.1322e-01, 1.3862e+00, -2.2813e-01, -4.7011e-01, -2.5818e-01, -6.8372e-01, 1.7890e-01, 5.9333e-01, 2.3110e-01, 6.0013e-01, 6.1079e-01, 1.1835e-01, 3.0312e-01, 4.7956e-01, -9.0474e-02, -4.7122e-01, -1.1157e+00, 2.8013e-01, 5.1831e-01, 3.5709e-01, -4.4087e-01, -1.6675e-01, 2.4361e-01, -1.4690e-01, 2.2097e-01, -1.4364e-01, -1.5772e-01, 2.3723e-01, -7.6925e-01, -1.7788e-01, -8.3212e-02, -5.7498e-02, -6.6200e-01, -8.1028e-02, 6.1719e-02, -7.3170e-01, -1.1594e-01, -6.9784e-01, -2.7962e-02, -9.3233e-01, -7.4641e-02, 3.9156e-02, 6.4064e-01, -1.0957e+00, 1.2172e-02, 1.2453e-01, 2.2704e-01, 3.5148e-01, -2.6128e-01, 1.3611e+00, -3.4164e-02, -6.2567e-01, -1.7320e-01, 4.1086e-01, -2.2863e-01, 2.2597e-01, 8.0072e-01, -8.1334e-01, -4.6740e-01, 8.9285e-01, -4.3471e-01, -1.6304e-01, -1.6450e-01, -9.1366e-01, 9.6539e-01, 4.1627e-02, 1.2581e-01, -1.0572e-01, -3.9535e-01, -1.5666e-01, -1.0576e+00, 8.1594e-01, 1.3039e-01, -1.2229e-01, -1.0086e-01, 1.4630e+00, -1.2300e+00, 5.4424e-01, -5.4054e-01, -2.3735e-02, -2.3277e-01, 4.4434e-01, -4.1168e-01, -5.0603e-01, -1.7757e-01, -4.4975e-01, -8.5226e-02, 6.0133e-01, -2.1632e-01, -1.1666e+00, -6.5839e-01, -2.4942e-01, -9.7081e-02, -1.9084e-01, -2.3861e-01, -2.9358e-01, 7.2426e-02, -1.2594e-01, -9.6252e-01, 7.0934e-02, 2.4492e-02, -7.5538e-02, 1.1916e+00, -8.4068e-01, 1.6002e-01, -6.8600e-01, -2.4405e-01, -1.2449e-01, -1.6900e-02, 5.4088e-01, 3.4330e-01, -2.4779e-01, -2.6710e-01, -5.1212e-02, -1.4777e-01, 1.5025e-01, 6.1416e-01, 4.2979e-01, 3.6097e-01, -7.3596e-01, -1.7493e-01, 4.3301e-01, -4.4997e-01, -7.4958e-01, -3.9143e-01, 4.8851e-01, -2.0442e-01, -9.0376e-01, 5.4466e-01, 1.3276e-01, 3.0045e-01, 3.0290e-01, 1.9902e-01, -6.8191e-01, 2.9952e-01, 9.4701e-02, 2.6000e-01, 3.8404e-01, 4.4243e-01, 6.2018e-02, -4.0736e-01, 5.5149e-01, 8.0305e-01, -6.0269e-02, 6.1119e-01, -7.7568e-02, -2.6159e-01, -2.1779e-01, -6.7564e-01, -4.5183e-01, -5.3268e-01, 6.3423e-01, 7.3586e-01, 8.1137e-02, -7.7695e-02, -2.4448e-01, -1.6764e-01, -1.6911e-01, -2.7556e-01, -3.1078e-01, 3.0063e-01, 4.0683e-01, 2.1039e-01, -6.1488e-01, 3.6530e-02, -7.5048e-01, 1.9977e-01, 7.7794e-01, 7.7518e-01, -6.7273e-01, 4.3956e-01, 7.7076e-01, 1.5638e-01, 3.8009e-01, -3.4522e-01, 4.0633e-01, 5.0821e-02, -7.9294e-02, -5.8968e-01, -1.6253e-01, 4.3255e-01, -4.1972e-01, 3.7828e-01, 9.3417e-01, 1.9317e-01, 4.2900e-01, 4.6161e-02, 5.3920e-01, 1.2338e-01, -4.3318e-01, 2.2601e-01, -3.0456e-01, 3.3208e-01, 8.1517e-01, 4.9641e-01, -2.8837e-01, 1.9259e-01, -1.9196e-01, -2.8835e-01, 8.2389e-01, 2.1382e-01, -5.0230e-01, -2.8033e-01, -1.4386e-01, -2.4066e-01, 4.5183e-01, 2.9018e-01, -3.3238e-01, 7.6613e-01, -2.2858e-01, 2.1562e-01, -1.5033e-02, -4.3519e-02, 4.6809e-01, 3.5235e-02, -5.7315e-01, -1.2247e+00, 8.7446e-01, -5.3191e-01, 3.9934e-01, -1.5189e-01, 1.5814e-01, -5.2017e-01, 2.7776e-02, -1.0298e+00, -1.1225e+00, 2.3432e-01, 5.0779e-01, 7.4528e-01, 9.1145e-01, -7.1726e-01, -4.3346e-01, -4.0445e-01, 1.6764e-01, 7.5148e-01, 4.5240e-01, 3.9794e-01, 9.7296e-01, -9.3493e-02, 5.5920e-02, 2.2008e-01, -4.7061e-01, -5.5885e-01, 2.7293e-01, -2.2021e-01, -1.1468e+00, -6.9024e-01, -2.8018e-01, -5.0438e-01, -1.0352e-01, -1.6996e-01, 7.0182e-01, -1.7532e-01, 4.2186e-01, -4.0852e-01, -1.6920e-01, 1.0813e+00, -2.5174e-01, -1.1453e-01, -4.3818e-01, -9.9075e-01, -1.3550e-01, 6.5167e-02, 7.6289e-01, -6.1525e-01, -2.2155e-01, 9.1627e-02, 5.9597e-01, -3.9134e-01, -4.5271e-01, 9.7105e-01, 1.5274e-02, -1.6970e-01, 
-7.3324e-01, 1.1416e-01, 8.7804e-01, -1.1139e-01, 4.6074e-01, 4.5102e-01, -4.3742e-01, -2.0925e-01, -2.7499e-01, -3.3725e-01, -5.6249e-01, -2.4529e-01, 3.3172e-01, 5.0531e-02, -4.5348e-01, 4.3045e-01, -9.9580e-02, 4.9946e-01, 2.0121e-01, 1.4416e-01, -2.9800e-01, -1.1358e+00, -4.7961e-02, -3.1552e-01, 8.3810e-01, 4.7140e-01, -1.1135e-01, -6.4007e-03, -9.7005e-02, -4.6666e+00, 8.6160e-02, -1.0469e-01, -4.7429e-01, 7.5138e-01, -5.5269e-01, -2.6248e-01, 3.0869e-01, -4.2183e-01, -5.9321e-01, 4.2071e-01, -6.1134e-01, 1.3827e-01, -2.2363e-01, 4.1900e-01, -1.0203e+00, 3.2149e-01, -5.9846e-01, -7.0980e-01, 1.3477e+00, -3.7652e-01, 2.1615e-01, 2.8045e-01, -9.4238e-01, -2.8517e-01, 5.9093e-01, -5.8326e-02, 4.0520e-01, -8.4271e-01, 6.3854e-01, -6.4629e-01, 3.2763e-02, 3.2824e-01, -5.2670e-02, 3.8074e-01, 2.7017e-01, -3.2644e-01, 1.6558e-01, 9.8959e-02, -9.4526e-01, -1.1275e-02, -5.0865e-01, -1.6080e-01, -2.5150e-01, 3.0774e-01, 2.4781e-01, -6.9735e-01, -6.5072e-01, -3.5622e-01, 2.3935e-01, -5.4576e-01, -4.7295e-01, 8.8449e-01, -2.4570e-01, 8.7697e-02, 6.0601e-01, 3.5812e-01, 1.1248e-01, 2.6865e-02, -2.0416e-01, 5.7046e-01, 1.1026e-01, -9.1923e-01, 3.8365e-02, 7.4815e-01, -6.5813e-02, -6.3228e-01, -1.2311e-01, 8.0985e-02, 9.5490e-01, 2.1679e-01, -3.1574e-01, 1.9691e-01, -1.0650e+00, -5.7394e-01, -3.4685e-02, -7.5367e-02, -5.7165e-01, 2.4795e-01, -4.0776e-01, -6.5695e-01, -2.1715e-01, 1.6921e-01, -3.1418e-01, -2.6184e-01, -6.7987e-01, 1.0747e-01, 4.5096e-02, 5.1449e-01, 2.7659e-01, 1.5968e-01, -1.7073e-03, 5.6223e-01, 1.2407e-01, 2.6819e-01, 1.7530e-01, 8.8831e-01, -6.7690e-01, 3.0574e-01, -8.6536e-02, 5.6855e-01, -2.1340e-01, -7.0454e-02, 4.3164e-01, 1.9373e-01, -2.2988e-01, -1.1271e+00, 3.8047e-01, 1.9446e-01, -1.4082e-01, -6.3198e-02, -2.7196e-01, 8.4955e-02, -5.7896e-01, -4.0740e-01, -6.7326e-01, -4.5778e-02, 5.8528e-01, -2.4662e-01, -1.8599e-01, -1.1088e-01, 7.8608e-01, -3.2339e-01, -4.9305e-01, -4.5717e-01, 3.3706e-01, 4.5208e-01, -6.0603e-01, -7.2381e-01, -3.3287e-01, 1.8222e-01, 1.8102e-01, 1.6536e-01, 1.3156e-01, -1.2076e-01, 1.9233e-01, -7.9349e-01, -4.6311e-01, 1.6354e-01, -1.0870e-01, 5.8225e-02, 3.6691e-01, -4.2800e-01, 5.8694e-01, -1.8585e-01, 8.8113e-01, -7.4489e-01, -6.3407e-01, -3.9415e-01, 3.3960e-01, 3.1307e-01, -1.8274e-01, 7.1520e-02, -2.7720e-01, 5.3749e-01, -1.1966e-01, 3.8869e-01, -1.1869e+00, -3.3095e-02, -1.9177e-01, -4.8460e-01, 1.8028e-01, -1.2386e-01, 5.0279e-01, 8.8191e-01, 6.9002e-01, 5.0999e-01, 3.2240e-01, -2.6183e-01, -5.8290e-01, -3.8619e-01, -9.2705e-02, 2.3294e-01, -5.7509e-01, -2.6123e-01, 3.5139e-01, -4.1708e-01, -4.2239e-01, -2.3083e-01, 8.1954e-03, -4.9662e-01, -4.7840e-01, 3.8759e-02, 5.5232e-01, 5.2939e-01, 4.6774e-01, -3.8046e-01, -6.9474e-02, 4.2584e-01, 8.8188e-01, -5.4262e-01, -9.4521e-03, -2.4057e-01, -5.1243e-01, -9.7442e-01, -9.7313e-01, 9.6955e-01, -8.8934e-01, 2.8327e-01, 7.2861e-01, 5.0486e-01, 2.0997e-01, -4.0938e-01, 1.1308e-01, -4.7488e-01, -1.7587e-01, -7.8031e-02, -3.8377e-02, -3.4457e-01, 5.3594e-01, -3.1236e-01, -1.3837e-01, -2.6020e-01, 3.5784e-01, 1.8997e-01, -8.5632e-01, 6.9272e-01, -6.6168e-02, -2.8632e-01, 1.8989e-02, -1.9501e-01, 2.7093e-01, 2.5045e-01, -1.7063e-02, -1.4374e-01, 9.4995e-01, -3.9204e-01, -2.9093e-01, -5.4969e-01, -3.4450e-01, 5.8438e-01, -1.5057e-01, 7.9860e-01, 5.3157e-01, 6.3988e-01, -4.0382e-01, 3.7626e-03, -3.5091e-01, 1.4385e-01, -6.3540e-02, 2.1004e-01, -6.5894e-01, 3.1622e-01, 6.5557e-02, -2.5855e-01, -7.5161e-01, -5.2992e-01, 1.9083e-01, -2.0577e-01, 6.7220e-01, -4.1980e-01, -9.5139e-02, -5.5717e-03, -5.1454e-01, 
1.3873e-01, -1.9702e-01, -7.2888e-01, 2.6505e-01, 9.2675e-03, 3.8930e-01, -1.3924e-01, 1.4359e-01, 8.9520e-02, 5.1103e-01, 2.1713e-01, 7.2420e-01, 9.3646e-01, 6.6479e-01, -3.2557e-01, 1.0497e+00, -8.8188e-02, 1.6001e-01, 4.2804e-01, -5.9942e-01, -1.1932e-01, 2.2498e-01, 2.8211e-04, -1.8749e-04, 2.2671e-01, -2.8624e-01, -6.0135e-02, 6.0868e-01, 3.7792e-02, -2.4935e-01, 9.5642e-01, -3.7062e-01, -4.2931e-01, 4.2356e-01, 2.6393e-01, 7.7339e-01, -4.0946e-01, 4.6467e-01, -1.3405e-01, 3.1842e-01, 1.4601e-01, 3.4192e-01, -3.8468e-01, 3.4263e-01, -3.5902e-01, 5.7246e-01, -1.0331e-01, 2.7986e-01, 7.5572e-01, 3.1356e-01, -6.7991e-01, 3.3968e-01, -2.7787e-03, -3.2683e-01, -1.4089e-01, -3.4801e-01, -1.7667e-01, -5.0184e-02, 8.7150e-01, 6.1842e-01, 9.5388e-02, -4.5136e-01, 8.1731e-01, -5.1404e-01, -3.3745e-01, 3.7875e-02, -3.3433e-01, -4.3818e-01, 7.9624e-01, 3.9827e-01, -9.9076e-02, 2.7977e-01, -8.6621e-01, 4.1664e-01, 4.5073e-01, 7.8351e-01, -8.4463e-02, -3.5700e-01, -1.0421e-01, 8.3446e-01, -1.4272e-01, 3.2673e-01, -1.4472e-01, -9.9517e-01, -1.1686e-01, 7.0524e-01, 6.5207e-02, 6.3838e-02, 8.0976e-01, -2.6855e-02, -4.8215e-01, 1.4215e-02, -1.7119e-01, -3.8744e-01, 9.5679e-01, -3.2997e-01, -2.6818e-01, 4.7255e-01, -2.3883e-01, -8.2873e-01, -2.6071e-02, 6.4115e-01, 5.3056e-01, 5.1954e-01, 6.1046e-02, -7.2385e-01, -3.0921e-02, 5.1194e-02, 8.0109e-01, -1.2863e-01, -3.4957e-02, 3.9325e-01, 3.5836e-01, -5.4677e-02, -2.6197e-01, -5.4597e-01, -2.4615e-01, -1.9462e-01, 2.7137e-01, 8.7353e-01, -2.9035e-01, 7.8191e-01, 4.2285e-02, -5.7310e-01, -3.7404e-02, -1.7615e-01, 2.8678e-01, 5.3713e-01, -7.9005e-01, -4.1035e-01, -3.9232e-01, -1.0510e+00, -3.2326e-01, -1.0810e+00, 9.1602e-01, 3.6374e-01, 1.3233e+00, -1.1201e-01, 3.5219e-01, -5.5904e-01, 2.5476e-01, -3.9654e-01, 1.2394e-01, 3.1380e-01, 1.5336e-01, 8.5311e-02, 3.8395e-01, 5.4818e-01, -3.2665e-01, 2.4916e-01, -3.1327e-01, 6.4903e-01, -4.6011e-01, 5.3281e-01, 2.8596e-02, -9.2876e-01, 8.4156e-01, 1.1013e+00, -4.7625e-01, 3.1146e-01, -6.6659e-02, 1.1244e-01, 4.9979e-01, -1.5115e-01, -1.0703e-01, -7.9332e-01, 1.7420e-01, -4.4306e-01, 1.9569e-01, 6.6357e-01, -1.4642e-01, -2.5401e-01, -2.2001e-01, 9.5454e-01, 1.6118e-01, -4.9574e-01, -5.2578e-01, 1.4362e-01, 7.2309e-02, 1.8952e-01, 4.8230e-01, -8.4709e-01, 2.3703e-01, -3.2139e-01, -2.9955e-02, -4.5533e-01, 3.6851e-01, 3.6113e-01, -8.6004e-02, -6.7998e-01, -4.8985e-01, 3.6725e-01, 1.3295e-02, -6.1268e-01, 6.8871e-01, -5.3720e-01, 7.2303e-01, -5.6271e-01, 6.4017e-01, -1.4560e-01, 2.0167e-01, -2.3267e-02, -5.5859e-04, -4.8885e-01, 5.9847e-01, 3.0279e-01], grad_fn=<SelectBackward0>)
pooler_output is the representation of the entire input sequence. Specifically, it is the output of the "pooler" layer, a fully connected layer with a tanh activation applied to the final hidden state of the [CLS] token. It holds an aggregate representation of the sentence, so we can use pooler_output as the representation of the sentence i love kathmandu.
print(f"pooler output.i.e representation of 'I love kathmandu'. \n{pooler_output}")
pooler output.i.e representation of 'I love kathmandu'. tensor([[-0.8357, -0.1822, 0.4747, 0.6436, -0.2796, -0.1092, 0.8457, 0.1583, 0.1907, -0.9996, 0.1853, 0.1747, 0.9669, -0.1852, 0.9110, -0.4437, -0.0987, -0.4950, 0.3837, -0.7624, 0.5009, 0.7871, 0.5610, 0.1600, 0.3615, 0.3753, -0.4926, 0.8958, 0.9266, 0.6399, -0.6181, 0.0950, -0.9717, -0.1272, 0.3411, -0.9715, 0.1475, -0.7362, 0.0306, 0.0822, -0.8589, 0.2027, 0.9935, -0.0741, -0.0652, -0.2863, -0.9996, 0.1983, -0.8285, -0.3280, -0.2946, -0.5097, 0.1277, 0.3611, 0.3224, 0.2838, -0.1187, 0.1056, -0.0781, -0.4620, -0.4998, 0.2000, 0.0673, -0.8527, -0.2595, -0.5242, -0.0076, -0.1295, -0.0038, -0.0472, 0.7797, 0.1425, 0.3412, -0.7356, -0.3774, 0.1124, -0.3790, 1.0000, -0.3603, -0.9575, -0.4016, -0.3286, 0.2985, 0.5956, -0.4848, -1.0000, 0.2178, -0.0367, -0.9782, 0.1734, 0.2324, -0.1238, -0.6653, 0.3045, -0.0755, -0.0893, -0.1990, 0.3857, -0.1061, 0.0060, 0.0086, -0.1714, 0.1197, -0.2372, 0.1262, -0.2163, -0.4496, 0.0401, -0.2700, 0.5711, 0.2220, -0.1947, 0.2412, -0.9343, 0.5700, -0.1493, -0.9637, -0.3257, -0.9756, 0.6032, 0.1781, -0.0780, 0.9453, 0.5571, 0.1407, 0.0362, 0.3646, -1.0000, -0.3394, 0.0233, 0.2274, -0.0698, -0.9545, -0.9000, 0.4713, 0.9317, 0.0066, 0.9901, -0.1645, 0.8779, 0.1500, -0.0391, -0.2480, -0.2947, 0.1984, 0.2943, -0.6335, 0.1586, 0.1640, 0.0066, 0.1490, -0.2045, 0.3069, -0.8929, -0.3562, 0.9205, 0.3154, 0.3515, 0.6472, -0.1485, -0.2634, 0.7564, 0.1758, 0.2387, 0.0654, 0.2956, -0.2752, 0.4048, -0.7578, 0.1234, 0.3163, -0.1411, 0.5326, -0.9559, -0.2317, 0.3995, 0.9758, 0.6521, 0.1369, 0.0850, -0.1517, 0.2689, -0.9010, 0.9546, -0.1763, 0.2290, 0.5736, -0.2094, -0.8271, -0.4318, 0.7565, -0.1056, -0.8077, 0.0939, -0.4023, -0.3229, 0.3441, 0.4487, -0.2178, -0.3261, 0.0241, 0.8773, 0.9459, 0.8034, -0.5759, 0.4446, -0.8731, -0.3450, 0.1232, 0.1715, 0.1602, 0.9859, 0.2193, -0.1262, -0.9025, -0.9735, -0.0249, -0.8609, 0.0312, -0.5592, 0.1535, 0.6513, -0.1058, 0.3437, -0.9785, -0.7176, 0.2701, -0.1294, 0.3077, -0.1899, 0.2584, -0.2188, -0.4242, 0.7890, 0.7613, 0.6384, -0.6211, 0.7924, -0.1569, 0.8204, -0.4833, 0.9555, -0.1928, 0.4865, -0.8961, 0.5040, -0.8705, 0.2597, -0.0485, -0.6772, -0.2114, 0.3194, 0.2365, 0.8293, -0.4225, 0.9941, -0.2830, -0.9169, 0.4675, -0.0909, -0.9690, -0.2601, 0.1455, -0.6810, -0.2796, -0.2128, -0.9236, 0.8716, 0.0920, 0.9778, 0.1281, -0.8849, -0.2400, -0.8457, -0.2042, -0.0414, 0.6000, -0.2036, -0.9297, 0.3387, 0.4034, 0.3005, 0.5991, 0.9923, 0.9932, 0.9581, 0.8348, 0.8566, -0.5704, 0.0602, 0.9998, -0.3528, -0.9998, -0.9052, -0.4421, 0.2460, -1.0000, -0.0200, 0.0821, -0.8786, -0.4253, 0.9661, 0.9825, -1.0000, 0.8029, 0.9152, -0.4052, 0.1657, -0.1109, 0.9520, 0.2111, 0.2441, -0.1008, 0.2060, 0.2395, -0.7713, 0.4306, 0.4774, 0.1620, 0.1018, -0.5417, -0.9116, -0.1347, -0.0955, -0.2447, -0.9207, -0.0161, -0.2207, 0.5460, 0.0637, 0.1019, -0.7694, 0.1332, -0.7393, 0.3748, 0.4445, -0.9078, -0.5733, -0.1475, -0.4459, 0.3976, -0.9046, 0.9533, -0.1851, -0.0272, 1.0000, -0.2259, -0.8401, 0.1765, 0.1236, 0.0120, 1.0000, 0.3590, -0.9559, -0.2993, 0.1456, -0.2983, -0.2820, 0.9947, -0.1810, 0.3798, 0.5133, 0.9366, -0.9749, -0.1349, -0.8788, -0.9393, 0.9284, 0.8910, -0.0110, -0.5099, 0.0026, 0.1924, 0.1650, -0.9494, 0.5135, 0.3702, -0.1057, 0.8645, -0.8382, -0.3240, 0.3474, 0.1986, 0.3930, -0.3834, 0.3965, -0.1664, 0.1177, -0.1960, 0.0754, -0.9521, -0.0400, 1.0000, 0.1457, -0.4794, -0.1106, -0.0228, -0.3202, 0.2443, 0.2820, -0.2128, -0.7773, -0.1090, -0.9080, -0.9657, 0.6889, 0.0778, -0.1594, 0.9968, 0.2527, 0.1082, 
-0.1482, 0.0633, -0.0750, 0.4864, -0.5328, 0.9487, -0.1227, 0.3059, 0.7986, 0.3730, -0.2166, -0.5505, 0.0077, -0.8476, -0.0030, -0.9217, 0.9377, -0.3854, 0.2321, 0.0428, -0.2700, 1.0000, 0.4203, 0.4848, -0.5653, 0.8460, -0.5559, -0.6931, -0.2780, 0.0789, 0.4914, -0.1549, 0.1516, -0.9496, -0.3756, -0.1492, -0.9686, -0.9818, 0.4957, 0.6610, 0.0145, -0.0433, -0.5876, -0.4846, 0.1688, -0.1218, -0.9168, 0.5548, -0.1388, 0.3953, -0.1332, 0.3244, -0.5027, 0.7807, 0.6314, 0.2511, 0.0779, -0.7110, 0.7183, -0.7626, 0.3074, -0.0645, 1.0000, -0.3013, -0.2689, 0.6639, 0.6914, 0.0063, 0.1015, -0.3817, 0.0796, 0.4610, 0.4566, -0.8046, -0.2481, 0.4441, -0.5994, -0.5116, 0.6993, -0.1261, 0.0455, 0.1239, -0.0202, 0.9981, -0.1739, -0.0759, -0.3944, 0.0818, -0.2071, -0.5883, 0.9999, 0.2985, -0.1575, -0.9798, 0.3835, -0.8701, 0.9841, 0.7024, -0.7812, 0.3578, 0.2949, -0.0303, 0.7199, -0.1156, -0.1210, 0.0991, 0.0938, 0.9400, -0.3120, -0.9297, -0.4143, 0.2790, -0.9366, 0.6300, -0.4139, -0.0782, -0.1882, 0.3217, 0.8082, -0.0626, -0.9611, -0.1016, -0.0329, 0.9454, 0.0739, -0.3345, -0.8701, -0.4491, -0.2535, 0.4966, -0.9060, 0.9429, -0.9734, 0.3720, 0.9999, 0.1690, -0.7277, 0.1247, -0.2867, 0.1213, 0.3556, 0.4100, -0.9293, -0.1903, -0.1115, 0.1854, -0.1383, 0.3850, 0.5818, 0.0977, -0.2679, -0.4341, -0.0303, 0.3145, 0.6554, -0.2225, -0.0471, 0.0594, -0.0933, -0.8966, -0.1226, -0.1466, -0.9069, 0.5453, -1.0000, -0.3925, -0.5733, -0.1923, 0.7422, -0.0759, -0.1191, -0.6316, 0.4116, 0.8261, 0.6735, -0.1739, 0.0073, -0.6623, 0.1041, -0.0090, 0.2303, 0.1850, 0.6571, -0.1187, 1.0000, 0.0425, -0.3779, -0.9520, 0.1834, -0.1619, 0.9974, -0.8624, -0.8984, 0.1848, -0.2696, -0.7284, 0.1267, -0.0695, -0.5209, 0.1273, 0.9360, 0.8511, -0.3208, 0.2976, -0.2469, -0.3425, 0.0311, -0.5138, 0.9711, 0.1189, 0.8520, 0.4906, 0.1186, 0.9322, 0.1485, 0.6336, 0.0614, 0.9999, 0.1679, -0.8800, 0.4211, -0.9740, -0.1092, -0.9395, 0.1832, 0.0115, 0.8338, -0.1513, 0.9407, 0.5031, 0.1108, 0.0315, 0.6071, 0.2934, -0.8703, -0.9707, -0.9766, 0.2818, -0.3850, -0.0065, 0.2401, 0.1317, 0.2421, 0.2649, -0.9998, 0.8918, 0.3107, -0.4011, 0.9412, 0.1921, 0.1823, 0.1558, -0.9718, -0.9556, -0.2409, -0.2092, 0.7035, 0.5164, 0.7444, 0.3120, -0.4264, -0.0232, 0.5036, -0.2250, -0.9836, 0.2923, 0.2831, -0.9512, 0.9228, -0.5638, -0.1392, 0.5769, 0.3294, 0.9146, 0.6901, 0.4188, 0.1131, 0.3818, 0.8430, 0.9227, 0.9757, 0.3712, 0.7213, 0.2551, 0.2250, 0.2202, -0.8909, 0.0254, -0.1250, 0.0541, 0.1357, -0.1238, -0.9539, 0.2963, -0.0436, 0.2701, -0.2931, 0.1724, -0.3304, -0.1554, -0.6303, -0.3313, 0.4272, 0.2015, 0.8716, 0.0547, -0.0267, -0.5496, -0.0929, 0.4459, -0.8612, 0.8963, 0.0137, 0.4834, -0.4487, -0.0741, 0.3976, -0.3935, -0.2915, -0.1870, -0.6522, 0.8120, 0.0012, -0.3690, -0.3210, 0.5532, 0.2246, 0.8582, 0.3838, 0.1747, 0.0179, -0.1084, 0.1621, -0.0988, -0.9998, 0.3584, 0.2679, -0.3447, 0.1248, -0.5345, -0.0385, -0.9605, -0.0186, -0.2772, -0.4289, -0.4703, -0.3578, 0.3284, 0.2124, -0.0556, 0.8367, 0.1933, 0.6517, 0.3827, 0.4744, -0.5819, 0.8273]], grad_fn=<TanhBackward0>)
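As a sanity check, we can recompute pooler_output ourselves; this is a minimal sketch assuming the pooler submodule that Hugging Face's BertModel exposes, which applies a dense layer plus tanh to the final hidden state of the [CLS] token.
# recompute the pooled [CLS] representation from the last hidden state
with torch.no_grad():
    manual_pooled = model.pooler(last_hidden_state)  # dense + tanh over the first ([CLS]) token
print(torch.allclose(manual_pooled, pooler_output, atol=1e-6))  # expected: True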
These embeddings are taken from the topmost encoder layer of BERT, i.e. layer 12. Can we extract the embeddings from all the encoder layers? Yes.
output.hidden_states
(tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01, ..., -2.7571e-02, 3.8253e-02, 1.6400e-01], [-3.4027e-04, 5.3974e-01, -2.8805e-01, ..., 7.5731e-01, 8.9008e-01, 1.6575e-01], [ 1.1558e+00, 8.5331e-02, -1.1208e-01, ..., 4.3965e-01, 8.5903e-01, -3.2685e-01], ..., [-3.6430e-01, -1.6172e-01, 9.0174e-02, ..., -1.7849e-01, 1.2818e-01, -4.5116e-02], [ 1.6776e-01, -8.9038e-01, -3.1798e-01, ..., 3.1737e-02, 6.4863e-02, 1.8418e-01], [ 3.5675e-01, -8.7076e-01, -3.7621e-01, ..., -7.9391e-02, -1.0115e-01, 1.9312e-01]]], grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.1616, 0.0459, -0.0789, ..., -0.0204, 0.1398, 0.0676], [ 0.7336, 0.9462, -0.2094, ..., 0.4195, 0.7503, -0.0092], [ 1.4193, 0.6091, 0.3613, ..., 0.4949, 0.7861, -0.0039], ..., [-0.0169, 0.2284, 0.0774, ..., -0.2815, 0.5997, 0.1019], [-0.1388, -0.8165, 0.3362, ..., 0.3381, 0.3591, -0.3433], [ 0.0269, -0.7895, 0.2525, ..., 0.2607, 0.2141, -0.3796]]], grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.0480, -0.1574, -0.1225, ..., 0.0428, 0.0960, 0.0974], [ 0.6310, 1.0092, 0.3655, ..., 0.4078, 0.2656, -0.1203], [ 2.0291, 0.8538, 0.9834, ..., 0.9003, 0.4365, -0.1713], ..., [-0.1134, 0.0539, 0.2443, ..., -0.0801, 0.4949, 0.0048], [-0.1568, -0.3104, 0.2979, ..., 0.8824, -0.0709, -0.2684], [-0.1532, -0.3178, 0.2342, ..., 0.7318, -0.1784, -0.3049]]], grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.0042, -0.2541, -0.0267, ..., 0.2045, 0.0852, 0.2579], [ 0.7227, 0.6391, 0.6960, ..., 0.6194, -0.0594, -0.0120], [ 2.5173, 0.1945, 1.1035, ..., 0.5088, 0.2563, -0.2985], ..., [-0.0731, -0.0839, 0.1313, ..., 0.0441, 0.0764, 0.0078], [-0.3059, -0.3007, 0.6740, ..., 0.8598, 0.0677, -0.4079], [-0.2730, -0.2796, 0.6555, ..., 0.6897, -0.0085, -0.3978]]], grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.1170, -0.4117, -0.5989, ..., 0.4627, 0.0961, 0.5678], [ 1.1077, 0.6526, 0.6694, ..., 0.5790, -0.1841, -0.4510], [ 2.3473, -0.0616, 0.5843, ..., 0.5617, -0.3385, 0.0799], ..., [-0.0380, -0.0509, 0.0145, ..., 0.0044, 0.0548, -0.0264], [-0.2247, -0.6068, 0.4737, ..., 0.7821, -0.1056, -0.4011], [-0.3516, -0.7095, 0.6233, ..., 0.5555, -0.3046, -0.2072]]], grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.0552, -0.3304, -0.7737, ..., 0.0372, 0.2087, 0.7120], [ 0.8668, -0.1016, 0.3839, ..., 0.3220, -0.0550, -0.0919], [ 1.9763, 0.1825, 0.6982, ..., 0.4447, -0.5356, 0.0066], ..., [-0.0213, -0.0395, 0.0210, ..., 0.0227, -0.0033, -0.0352], [ 0.0780, -0.4096, 0.2830, ..., 0.3291, -0.1267, -0.4650], [-0.1007, -0.5328, 0.5132, ..., 0.1434, -0.1889, -0.2615]]], grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.0397, -0.5367, -0.6432, ..., -0.1023, 0.2667, 0.6516], [ 0.7390, 0.5079, 0.3047, ..., -0.2435, 0.6199, -0.4579], [ 1.7817, 0.3081, 0.9233, ..., 0.0095, -0.3492, -0.0181], ..., [ 0.0113, -0.0383, -0.0030, ..., 0.0082, -0.0225, -0.0510], [ 0.1866, -0.1390, 0.2746, ..., 0.4320, -0.1927, -0.6178], [-0.1228, -0.3369, 0.4414, ..., 0.0941, -0.3949, -0.2522]]], grad_fn=<NativeLayerNormBackward0>), tensor([[[-0.1469, -0.3336, -0.5482, ..., 0.1259, 0.1292, 0.8209], [ 0.4745, 0.6440, 0.1888, ..., -0.1612, 0.5875, -0.6186], [ 1.6383, 0.6971, 0.9196, ..., 0.1550, -0.1455, -0.2603], ..., [-0.0220, -0.0541, -0.0105, ..., -0.0130, 0.0223, -0.0645], [ 0.0159, -0.0823, 0.2857, ..., 0.5241, -0.0141, -0.5355], [-0.3188, -0.4194, 0.1176, ..., 0.3143, -0.3524, -0.0331]]], grad_fn=<NativeLayerNormBackward0>), tensor([[[ 3.2213e-02, 4.8598e-02, -8.8377e-01, ..., -6.5647e-01, 2.1170e-01, 7.8293e-01], [ 2.2881e-01, 5.4502e-01, -2.6024e-01, ..., -2.5965e-01, 2.7225e-01, 1.2774e-01], [ 1.4240e+00, 
5.4923e-01, 1.1938e+00, ..., -5.5137e-01, -1.1682e-01, -1.0709e-01], ..., [-1.2949e-03, -4.6647e-02, 3.8625e-02, ..., -4.0555e-02, -3.2852e-02, -9.3598e-02], [ 5.1433e-02, 1.1127e-01, 1.6794e-01, ..., 3.7711e-01, -2.3262e-01, -4.7220e-01], [-2.8241e-01, -1.7677e-01, -1.0904e-01, ..., 7.8062e-02, -6.4763e-01, -6.7554e-02]]], grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.0034, 0.6391, -0.8313, ..., -0.4380, -0.1068, 0.2339], [ 0.3149, 0.9895, 0.0229, ..., -0.3240, 0.2127, 0.2870], [ 1.1894, 0.7824, 1.1637, ..., -0.4862, 0.0501, -0.2613], ..., [-0.0035, -0.0096, 0.0600, ..., -0.0635, -0.0905, -0.0571], [ 0.1742, 0.5720, -0.0192, ..., 0.5380, -0.3071, -0.5322], [-0.1722, 0.2455, -0.3569, ..., 0.3873, -0.7113, -0.1223]]], grad_fn=<NativeLayerNormBackward0>), tensor([[[-0.1033, 0.6074, -0.6408, ..., -0.4095, -0.7298, 0.0868], [ 0.4959, 0.3527, 0.4731, ..., -0.2954, 0.0267, 0.0942], [ 0.8277, 0.8742, 1.0342, ..., -0.6788, 0.2928, -0.4686], ..., [ 0.0081, 0.0195, 0.0409, ..., 0.0983, -0.1053, -0.0237], [-0.0699, 0.6632, 0.1258, ..., 0.4723, -0.3174, -0.2322], [-0.3441, 0.3234, -0.0623, ..., 0.5664, -0.6259, 0.0931]]], grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.0429, 0.5304, 0.0735, ..., -0.2732, -0.2552, 0.1550], [ 0.4396, 0.3368, 0.3334, ..., -0.4953, 0.4322, 0.4333], [ 0.5574, 0.8172, 0.7482, ..., -0.7575, 0.4598, -0.2706], ..., [ 0.0034, -0.0063, -0.0107, ..., 0.0116, -0.0336, 0.0155], [ 0.2764, 0.5989, 0.5468, ..., 0.1767, -0.3204, -0.3011], [ 0.1069, 0.1951, 0.5959, ..., 0.1325, -0.5346, -0.3428]]], grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.0528, 0.3700, 0.0808, ..., -0.2243, 0.0355, 0.0747], [ 0.3186, 0.5087, 0.1494, ..., -0.4888, 0.5985, 0.3028], [ 0.5457, 0.8315, 0.8822, ..., -0.6653, 0.3154, -0.3502], ..., [ 0.6960, 0.2177, -0.0880, ..., -0.4141, -0.7644, -0.2708], [ 0.0960, 0.3780, 0.3556, ..., 0.1857, 0.0318, -0.0084], [-0.0137, 0.2465, 0.3321, ..., 0.2353, -0.0009, -0.0647]]], grad_fn=<NativeLayerNormBackward0>))
for i, layer_hidden_states in enumerate(hidden_states):
    print(f"Layer {i + 1} hidden states shape: {layer_hidden_states.shape}")
Layer 1 hidden states shape: torch.Size([1, 7, 768]) Layer 2 hidden states shape: torch.Size([1, 7, 768]) Layer 3 hidden states shape: torch.Size([1, 7, 768]) Layer 4 hidden states shape: torch.Size([1, 7, 768]) Layer 5 hidden states shape: torch.Size([1, 7, 768]) Layer 6 hidden states shape: torch.Size([1, 7, 768]) Layer 7 hidden states shape: torch.Size([1, 7, 768]) Layer 8 hidden states shape: torch.Size([1, 7, 768]) Layer 9 hidden states shape: torch.Size([1, 7, 768]) Layer 10 hidden states shape: torch.Size([1, 7, 768]) Layer 11 hidden states shape: torch.Size([1, 7, 768]) Layer 12 hidden states shape: torch.Size([1, 7, 768]) Layer 13 hidden states shape: torch.Size([1, 7, 768])
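Since every layer gives a [1, 7, 768] tensor, the intermediate layers can be combined instead of using only the last one. One common option (a sketch, not used in the rest of this notebook) is to sum the last four hidden layers to build a token representation, in the spirit of the feature-based experiments in the BERT paper.
# stack the last four hidden layers and sum them token-wise
last_four = torch.stack(hidden_states[-4:])   # shape: [4, 1, 7, 768]
summed = last_four.sum(dim=0)                 # shape: [1, 7, 768]
love_vector = summed[0][1]                    # combined representation of 'love'
print(love_vector.shape)                      # torch.Size([768])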
Fine-tuning BERT for a downstream task: text classification
During fine-tuning, we can adjust the weights of the model in the following two ways.
- Update the weights of the pre-trained BERT model along with the classification layer.
- Update only the weights of the classification layer, not the pre-trained BERT model, i.e. use BERT as a feature extractor (see the sketch below).
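Here is a minimal sketch of the second option: freeze the encoder so that only the classification head trains. It uses a separate, hypothetical model named frozen so it does not interfere with the full fine-tuning done below.
from transformers import BertForSequenceClassification

# load a fresh model and freeze every parameter of the BERT encoder
frozen = BertForSequenceClassification.from_pretrained("bert-base-uncased")
for param in frozen.bert.parameters():
    param.requires_grad = False
# only classifier.weight and classifier.bias remain trainable
print([name for name, p in frozen.named_parameters() if p.requires_grad])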
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import torch
import numpy as np
ds = load_dataset("imdb")
ds
DatasetDict({ train: Dataset({ features: ['text', 'label'], num_rows: 25000 }) test: Dataset({ features: ['text', 'label'], num_rows: 25000 }) unsupervised: Dataset({ features: ['text', 'label'], num_rows: 50000 }) })
train = ds["train"]
test = ds["test"]
val = ds["unsupervised"]
print(f"Train: {len(train)}, Test: {len(test)}, Validation: {len(val)}")
Train: 25000, Test: 25000, Validation: 50000
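Each example contains a text field (the review) and a label field; for this dataset 0 means negative and 1 means positive. A quick look at the first training example:
# print the label and the first 100 characters of the review text
print(train[0]["label"], train[0]["text"][:100])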
Let's initialize the pre-trained BERT model.
classifier = BertForSequenceClassification.from_pretrained("bert-base-uncased")
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
input_ids = tokenizer.convert_tokens_to_ids(tokens[:-2])
input_ids
[101, 1045, 2293, 28045, 102]
Segment IDs (token type IDs):
Suppose we have two sentences in the input. In that case, segment IDs are used to distinguish one sentence from the other. All the tokens from the first sentence will be mapped to 0 and all the tokens from the second sentence will be mapped to 1. Since here we have only one sentence, all the tokens will be mapped to 0, as shown below.
token_type_ids = np.zeros(5)
token_type_ids
array([0., 0., 0., 0., 0.])
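For contrast, here is a sketch of what the tokenizer returns for a sentence pair: the tokens of the first sentence (including [CLS] and the first [SEP]) get type 0, and the tokens of the second sentence get type 1. The second sentence here is just an illustrative example.
# segment IDs for a sentence pair: first segment -> 0, second segment -> 1
tokenizer("i love kathmandu", "nepal is beautiful")["token_type_ids"]
# expected: [0, 0, 0, 0, 0, 1, 1, 1, 1]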
Attention mask: 1s for actual tokens and 0s for [PAD] tokens.
attention_mask = np.ones(5)
attention_mask
array([1., 1., 1., 1., 1.])
tokenizer(example)
{'input_ids': [101, 1045, 2293, 28045, 102], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}
Let's pass three sentences and set the maximum sequence length to 5 with padding set to True.
tokenizer(["I love Kathmandu", "Beautiful Nepal", "Mount Everest"], padding=True, max_length=5)
/home/whiskey/miniconda3/envs/nlp/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2706: UserWarning: `max_length` is ignored when `padding`=`True` and there is no truncation strategy. To pad to max length, use `padding='max_length'`. warnings.warn(
{'input_ids': [[101, 1045, 2293, 28045, 102], [101, 3376, 8222, 102, 0], [101, 4057, 23914, 102, 0]], 'token_type_ids': [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1], [1, 1, 1, 1, 0], [1, 1, 1, 1, 0]]}
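As the warning suggests, to actually cap every sequence at length 5 we can combine truncation with padding='max_length' (a small sketch):
# pad/truncate every sequence to exactly max_length tokens
tokenizer(["I love Kathmandu", "Beautiful Nepal", "Mount Everest"],
          padding="max_length", truncation=True, max_length=5)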
With all this in mind, let's create a preprocessing function.
def preprocess(data):
    return tokenizer(data["text"], padding=True, truncation=True)
Let's preprocess each dataset split using map.
train = train.map(preprocess, batched=True, batch_size=len(train))
test = test.map(preprocess, batched=True, batch_size=len(test))
val = val.map(preprocess, batched=True, batch_size=len(val))
Set the dataset format to PyTorch tensors.
train.set_format("torch", columns=["input_ids", "attention_mask", "label"])
test.set_format("torch", columns=["input_ids", "attention_mask", "label"])
val.set_format("torch", columns=["input_ids", "attention_mask", "label"])
Training the model
batch_size = 4
epochs = 3
warmup_steps = 500
weight_decay = 0.01
training_args = TrainingArguments(
    num_train_epochs=epochs,
    output_dir="../../results",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_steps=warmup_steps,
    weight_decay=weight_decay,
    logging_dir="../../logs",
)
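Optionally, a compute_metrics function could be passed to the Trainer below (compute_metrics=compute_metrics) to report accuracy during evaluation; it is not used in this run, so the evaluation output only shows the loss. A minimal sketch:
import numpy as np

# convert logits to predicted class IDs and compare with the gold labels
def compute_metrics(eval_pred):
    preds = np.argmax(eval_pred.predictions, axis=-1)
    return {"accuracy": float((preds == eval_pred.label_ids).mean())}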
trainer = Trainer(
    model=classifier,
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
)
trainer.train()
0%| | 0/18750 [00:04<?, ?it/s] 3%|▎ | 500/18750 [01:50<1:07:59, 4.47it/s]
{'loss': 0.5396, 'grad_norm': 57.49171447753906, 'learning_rate': 5e-05, 'epoch': 0.08}
5%|▌ | 1000/18750 [03:40<1:03:58, 4.62it/s]
{'loss': 0.5416, 'grad_norm': 9.111489295959473, 'learning_rate': 4.863013698630137e-05, 'epoch': 0.16}
8%|▊ | 1500/18750 [05:30<1:03:36, 4.52it/s]
{'loss': 0.4923, 'grad_norm': 0.3694137632846832, 'learning_rate': 4.726027397260274e-05, 'epoch': 0.24}
11%|█ | 2000/18750 [07:16<56:13, 4.97it/s]
{'loss': 0.4673, 'grad_norm': 25.703596115112305, 'learning_rate': 4.589041095890411e-05, 'epoch': 0.32}
13%|█▎ | 2500/18750 [08:58<54:22, 4.98it/s]
{'loss': 0.4554, 'grad_norm': 8.654304504394531, 'learning_rate': 4.452054794520548e-05, 'epoch': 0.4}
16%|█▌ | 3000/18750 [10:39<52:57, 4.96it/s]
{'loss': 0.4689, 'grad_norm': 17.73918342590332, 'learning_rate': 4.3150684931506855e-05, 'epoch': 0.48}
19%|█▊ | 3500/18750 [12:21<51:02, 4.98it/s]
{'loss': 0.4482, 'grad_norm': 0.19607552886009216, 'learning_rate': 4.1780821917808224e-05, 'epoch': 0.56}
21%|██▏ | 4000/18750 [14:02<49:28, 4.97it/s]
{'loss': 0.4312, 'grad_norm': 9.35656452178955, 'learning_rate': 4.041095890410959e-05, 'epoch': 0.64}
24%|██▍ | 4500/18750 [15:44<47:57, 4.95it/s]
{'loss': 0.4376, 'grad_norm': 1.4899685382843018, 'learning_rate': 3.904109589041096e-05, 'epoch': 0.72}
27%|██▋ | 5000/18750 [17:25<46:03, 4.98it/s]
{'loss': 0.4632, 'grad_norm': 0.2304084300994873, 'learning_rate': 3.767123287671233e-05, 'epoch': 0.8}
29%|██▉ | 5500/18750 [19:07<44:23, 4.97it/s]
{'loss': 0.3996, 'grad_norm': 25.21976089477539, 'learning_rate': 3.63013698630137e-05, 'epoch': 0.88}
32%|███▏ | 6000/18750 [20:51<46:02, 4.62it/s]
{'loss': 0.4196, 'grad_norm': 10.193493843078613, 'learning_rate': 3.493150684931507e-05, 'epoch': 0.96}
35%|███▍ | 6500/18750 [22:42<44:30, 4.59it/s]
{'loss': 0.3272, 'grad_norm': 0.05335897579789162, 'learning_rate': 3.356164383561644e-05, 'epoch': 1.04}
37%|███▋ | 7000/18750 [24:32<42:49, 4.57it/s]
{'loss': 0.2588, 'grad_norm': 0.7967578768730164, 'learning_rate': 3.219178082191781e-05, 'epoch': 1.12}
40%|████ | 7500/18750 [26:22<41:17, 4.54it/s]
{'loss': 0.2978, 'grad_norm': 48.885040283203125, 'learning_rate': 3.082191780821918e-05, 'epoch': 1.2}
43%|████▎ | 8000/18750 [28:13<38:53, 4.61it/s]
{'loss': 0.2567, 'grad_norm': 0.3560728430747986, 'learning_rate': 2.945205479452055e-05, 'epoch': 1.28}
45%|████▌ | 8500/18750 [29:57<34:15, 4.99it/s]
{'loss': 0.2406, 'grad_norm': 0.06602785736322403, 'learning_rate': 2.808219178082192e-05, 'epoch': 1.36}
48%|████▊ | 9000/18750 [31:38<32:40, 4.97it/s]
{'loss': 0.2412, 'grad_norm': 29.276275634765625, 'learning_rate': 2.671232876712329e-05, 'epoch': 1.44}
51%|█████ | 9500/18750 [33:20<30:54, 4.99it/s]
{'loss': 0.2416, 'grad_norm': 0.06330031901597977, 'learning_rate': 2.534246575342466e-05, 'epoch': 1.52}
53%|█████▎ | 10000/18750 [35:01<29:14, 4.99it/s]
{'loss': 0.2592, 'grad_norm': 0.05070902034640312, 'learning_rate': 2.3972602739726026e-05, 'epoch': 1.6}
56%|█████▌ | 10500/18750 [36:42<27:36, 4.98it/s]
{'loss': 0.2729, 'grad_norm': 81.3860855102539, 'learning_rate': 2.2602739726027396e-05, 'epoch': 1.68}
59%|█████▊ | 11000/18750 [38:24<25:58, 4.97it/s]
{'loss': 0.2288, 'grad_norm': 0.10501725971698761, 'learning_rate': 2.1232876712328768e-05, 'epoch': 1.76}
61%|██████▏ | 11500/18750 [40:05<24:20, 4.97it/s]
{'loss': 0.2385, 'grad_norm': 0.019870679825544357, 'learning_rate': 1.9863013698630137e-05, 'epoch': 1.84}
64%|██████▍ | 12000/18750 [41:47<22:35, 4.98it/s]
{'loss': 0.272, 'grad_norm': 0.07312991470098495, 'learning_rate': 1.8493150684931506e-05, 'epoch': 1.92}
67%|██████▋ | 12500/18750 [43:28<21:03, 4.95it/s]
{'loss': 0.2471, 'grad_norm': 0.13753291964530945, 'learning_rate': 1.7123287671232875e-05, 'epoch': 2.0}
69%|██████▉ | 13000/18750 [45:10<19:16, 4.97it/s]
{'loss': 0.084, 'grad_norm': 1083.43701171875, 'learning_rate': 1.5753424657534248e-05, 'epoch': 2.08}
72%|███████▏ | 13500/18750 [46:51<17:37, 4.97it/s]
{'loss': 0.123, 'grad_norm': 0.01554066687822342, 'learning_rate': 1.4383561643835617e-05, 'epoch': 2.16}
75%|███████▍ | 14000/18750 [48:32<15:54, 4.97it/s]
{'loss': 0.1462, 'grad_norm': 0.028229428455233574, 'learning_rate': 1.3013698630136986e-05, 'epoch': 2.24}
77%|███████▋ | 14500/18750 [50:14<14:17, 4.95it/s]
{'loss': 0.1217, 'grad_norm': 1.8434722423553467, 'learning_rate': 1.1643835616438355e-05, 'epoch': 2.32}
80%|████████ | 15000/18750 [51:55<12:34, 4.97it/s]
{'loss': 0.0791, 'grad_norm': 0.011949531733989716, 'learning_rate': 1.0273972602739726e-05, 'epoch': 2.4}
83%|████████▎ | 15500/18750 [53:37<10:54, 4.97it/s]
{'loss': 0.1024, 'grad_norm': 0.1625107228755951, 'learning_rate': 8.904109589041095e-06, 'epoch': 2.48}
85%|████████▌ | 16000/18750 [55:18<09:11, 4.98it/s]
{'loss': 0.1172, 'grad_norm': 0.01604519970715046, 'learning_rate': 7.5342465753424655e-06, 'epoch': 2.56}
88%|████████▊ | 16500/18750 [57:00<07:32, 4.97it/s]
{'loss': 0.0917, 'grad_norm': 0.09570044279098511, 'learning_rate': 6.1643835616438354e-06, 'epoch': 2.64}
91%|█████████ | 17000/18750 [58:41<05:52, 4.96it/s]
{'loss': 0.0847, 'grad_norm': 0.99578857421875, 'learning_rate': 4.7945205479452054e-06, 'epoch': 2.72}
93%|█████████▎| 17500/18750 [1:00:22<04:11, 4.98it/s]
{'loss': 0.0987, 'grad_norm': 0.01294002402573824, 'learning_rate': 3.4246575342465754e-06, 'epoch': 2.8}
96%|█████████▌| 18000/18750 [1:02:04<02:30, 4.98it/s]
{'loss': 0.1108, 'grad_norm': 0.026291929185390472, 'learning_rate': 2.054794520547945e-06, 'epoch': 2.88}
99%|█████████▊| 18500/18750 [1:03:45<00:50, 4.98it/s]
{'loss': 0.102, 'grad_norm': 0.02042091079056263, 'learning_rate': 6.849315068493151e-07, 'epoch': 2.96}
100%|██████████| 18750/18750 [1:04:36<00:00, 4.84it/s]
{'train_runtime': 3876.8226, 'train_samples_per_second': 19.346, 'train_steps_per_second': 4.836, 'train_loss': 0.2728967039489746, 'epoch': 3.0}
TrainOutput(global_step=18750, training_loss=0.2728967039489746, metrics={'train_runtime': 3876.8226, 'train_samples_per_second': 19.346, 'train_steps_per_second': 4.836, 'train_loss': 0.2728967039489746, 'epoch': 3.0})
trainer.evaluate()
100%|██████████| 6250/6250 [06:44<00:00, 15.44it/s]
{'eval_loss': 0.4087878167629242, 'eval_runtime': 404.6947, 'eval_samples_per_second': 61.775, 'eval_steps_per_second': 15.444, 'epoch': 3.0}
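As a quick sanity check after fine-tuning, we can classify a new review with the trained model (a sketch; for this IMDB setup label 1 is positive and 0 is negative).
# run a single review through the fine-tuned classifier
classifier.eval()
inputs = tokenizer("This movie was absolutely wonderful!", return_tensors="pt").to(classifier.device)
with torch.no_grad():
    logits = classifier(**inputs).logits
print(logits.argmax(dim=-1).item())  # 1 -> positive, 0 -> negative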