This is a basic example to explain the concepts.

In [1]:
from transformers import BertModel, BertTokenizer
import torch

Let's download and load the bert-base-uncased model.

In [2]:
model = BertModel.from_pretrained("bert-base-uncased",  output_hidden_states=True)

Download and load the tokenizer that was used to pre-train the bert-base-uncased model.

In [3]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

Preprocessing the input

In [4]:
example = "i love kathmandu"
In [5]:
tokens = tokenizer.tokenize(example)
tokens
Out[5]:
['i', 'love', 'kathmandu']

Let's add the special tokens [CLS] and [SEP].

In [6]:
tokens = ["[CLS]"] + tokens + ["[SEP]"]
tokens
Out[6]:
['[CLS]', 'i', 'love', 'kathmandu', '[SEP]']

We need to add padding tokens because we want a sequence length of 7, but we only have 5 actual tokens.

In [7]:
tokens = tokens + ["[PAD]", "[PAD]"]
tokens
Out[7]:
['[CLS]', 'i', 'love', 'kathmandu', '[SEP]', '[PAD]', '[PAD]']

Now create an attention mask: 1 for actual tokens and 0 for padding tokens.

In [8]:
attn_mask = [0 if token == "[PAD]" else 1 for token in tokens]
attn_mask
Out[8]:
[1, 1, 1, 1, 1, 0, 0]

Let's convert the tokens to their IDs.

In [9]:
token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_ids
Out[9]:
[101, 1045, 2293, 28045, 102, 0, 0]
In [10]:
token_ids = torch.tensor(token_ids).unsqueeze(0)
attn_mask = torch.tensor(attn_mask).unsqueeze(0)
print(f"token IDs: {token_ids.shape}, attention mask: {attn_mask.shape}")
token IDs: torch.Size([1, 7]), attention mask: torch.Size([1, 7])
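
For reference, the tokenizer can do all of the manual steps above in a single call: it adds [CLS] and [SEP], pads to a fixed length, builds the attention mask, and returns PyTorch tensors. A minimal sketch (not run in this notebook):

In [ ]:
# Equivalent one-liner (sketch): padding="max_length" pads to exactly max_length,
# and return_tensors="pt" returns batched PyTorch tensors directly.
encoded = tokenizer(example, padding="max_length", max_length=7, return_tensors="pt")
print(encoded["input_ids"].shape, encoded["attention_mask"].shape)  # both should be torch.Size([1, 7])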

The Embedding

last_hidden_state: the representation of every token from the final (12th) encoder layer.

pooler_output: the representation of the [CLS] token after the pooler layer.

hidden_states: the outputs of the embedding layer and all 12 encoder layers (13 tensors), returned because we loaded the model with output_hidden_states=True.

In [11]:
output = model(token_ids, attention_mask= attn_mask)
In [12]:
last_hidden_state= output.last_hidden_state
pooler_output = output.pooler_output
hidden_states = output.hidden_states
In [13]:
print(f"last hidden:{last_hidden_state.shape} hidden layers:{len(hidden_states)}, CLS: {pooler_output.shape}")
last hidden:torch.Size([1, 7, 768]) hidden layers:13, CLS: torch.Size([1, 768])

The size [1,7,768] indicates [batch_size, sequence_length, hidden_size]. The size [1,768] indicates [batch_size, hidden_size].
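
As a quick sanity check, hidden_states should contain 13 tensors, the output of the embedding layer plus one per encoder layer, and its last entry should be the same tensor as last_hidden_state. A sketch:

In [ ]:
# hidden_states = (embedding output, layer 1, ..., layer 12);
# the final entry should match last_hidden_state exactly.
print(len(hidden_states), torch.equal(hidden_states[-1], last_hidden_state))  # expected: 13 True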

In [14]:
print(f"last layer, representation of 'love'\n{last_hidden_state[0][1]}")
last layer, representation of 'love'
tensor([ 3.1862e-01,  5.0867e-01,  1.4940e-01, -2.8184e-01,  1.0760e-02,
         4.1292e-01, -1.1322e-01,  1.3862e+00, -2.2813e-01, -4.7011e-01,
        -2.5818e-01, -6.8372e-01,  1.7890e-01,  5.9333e-01,  2.3110e-01,
         6.0013e-01,  6.1079e-01,  1.1835e-01,  3.0312e-01,  4.7956e-01,
        -9.0474e-02, -4.7122e-01, -1.1157e+00,  2.8013e-01,  5.1831e-01,
         3.5709e-01, -4.4087e-01, -1.6675e-01,  2.4361e-01, -1.4690e-01,
         2.2097e-01, -1.4364e-01, -1.5772e-01,  2.3723e-01, -7.6925e-01,
        -1.7788e-01, -8.3212e-02, -5.7498e-02, -6.6200e-01, -8.1028e-02,
         6.1719e-02, -7.3170e-01, -1.1594e-01, -6.9784e-01, -2.7962e-02,
        -9.3233e-01, -7.4641e-02,  3.9156e-02,  6.4064e-01, -1.0957e+00,
         1.2172e-02,  1.2453e-01,  2.2704e-01,  3.5148e-01, -2.6128e-01,
         1.3611e+00, -3.4164e-02, -6.2567e-01, -1.7320e-01,  4.1086e-01,
        -2.2863e-01,  2.2597e-01,  8.0072e-01, -8.1334e-01, -4.6740e-01,
         8.9285e-01, -4.3471e-01, -1.6304e-01, -1.6450e-01, -9.1366e-01,
         9.6539e-01,  4.1627e-02,  1.2581e-01, -1.0572e-01, -3.9535e-01,
        -1.5666e-01, -1.0576e+00,  8.1594e-01,  1.3039e-01, -1.2229e-01,
        -1.0086e-01,  1.4630e+00, -1.2300e+00,  5.4424e-01, -5.4054e-01,
        -2.3735e-02, -2.3277e-01,  4.4434e-01, -4.1168e-01, -5.0603e-01,
        -1.7757e-01, -4.4975e-01, -8.5226e-02,  6.0133e-01, -2.1632e-01,
        -1.1666e+00, -6.5839e-01, -2.4942e-01, -9.7081e-02, -1.9084e-01,
        -2.3861e-01, -2.9358e-01,  7.2426e-02, -1.2594e-01, -9.6252e-01,
         7.0934e-02,  2.4492e-02, -7.5538e-02,  1.1916e+00, -8.4068e-01,
         1.6002e-01, -6.8600e-01, -2.4405e-01, -1.2449e-01, -1.6900e-02,
         5.4088e-01,  3.4330e-01, -2.4779e-01, -2.6710e-01, -5.1212e-02,
        -1.4777e-01,  1.5025e-01,  6.1416e-01,  4.2979e-01,  3.6097e-01,
        -7.3596e-01, -1.7493e-01,  4.3301e-01, -4.4997e-01, -7.4958e-01,
        -3.9143e-01,  4.8851e-01, -2.0442e-01, -9.0376e-01,  5.4466e-01,
         1.3276e-01,  3.0045e-01,  3.0290e-01,  1.9902e-01, -6.8191e-01,
         2.9952e-01,  9.4701e-02,  2.6000e-01,  3.8404e-01,  4.4243e-01,
         6.2018e-02, -4.0736e-01,  5.5149e-01,  8.0305e-01, -6.0269e-02,
         6.1119e-01, -7.7568e-02, -2.6159e-01, -2.1779e-01, -6.7564e-01,
        -4.5183e-01, -5.3268e-01,  6.3423e-01,  7.3586e-01,  8.1137e-02,
        -7.7695e-02, -2.4448e-01, -1.6764e-01, -1.6911e-01, -2.7556e-01,
        -3.1078e-01,  3.0063e-01,  4.0683e-01,  2.1039e-01, -6.1488e-01,
         3.6530e-02, -7.5048e-01,  1.9977e-01,  7.7794e-01,  7.7518e-01,
        -6.7273e-01,  4.3956e-01,  7.7076e-01,  1.5638e-01,  3.8009e-01,
        -3.4522e-01,  4.0633e-01,  5.0821e-02, -7.9294e-02, -5.8968e-01,
        -1.6253e-01,  4.3255e-01, -4.1972e-01,  3.7828e-01,  9.3417e-01,
         1.9317e-01,  4.2900e-01,  4.6161e-02,  5.3920e-01,  1.2338e-01,
        -4.3318e-01,  2.2601e-01, -3.0456e-01,  3.3208e-01,  8.1517e-01,
         4.9641e-01, -2.8837e-01,  1.9259e-01, -1.9196e-01, -2.8835e-01,
         8.2389e-01,  2.1382e-01, -5.0230e-01, -2.8033e-01, -1.4386e-01,
        -2.4066e-01,  4.5183e-01,  2.9018e-01, -3.3238e-01,  7.6613e-01,
        -2.2858e-01,  2.1562e-01, -1.5033e-02, -4.3519e-02,  4.6809e-01,
         3.5235e-02, -5.7315e-01, -1.2247e+00,  8.7446e-01, -5.3191e-01,
         3.9934e-01, -1.5189e-01,  1.5814e-01, -5.2017e-01,  2.7776e-02,
        -1.0298e+00, -1.1225e+00,  2.3432e-01,  5.0779e-01,  7.4528e-01,
         9.1145e-01, -7.1726e-01, -4.3346e-01, -4.0445e-01,  1.6764e-01,
         7.5148e-01,  4.5240e-01,  3.9794e-01,  9.7296e-01, -9.3493e-02,
         5.5920e-02,  2.2008e-01, -4.7061e-01, -5.5885e-01,  2.7293e-01,
        -2.2021e-01, -1.1468e+00, -6.9024e-01, -2.8018e-01, -5.0438e-01,
        -1.0352e-01, -1.6996e-01,  7.0182e-01, -1.7532e-01,  4.2186e-01,
        -4.0852e-01, -1.6920e-01,  1.0813e+00, -2.5174e-01, -1.1453e-01,
        -4.3818e-01, -9.9075e-01, -1.3550e-01,  6.5167e-02,  7.6289e-01,
        -6.1525e-01, -2.2155e-01,  9.1627e-02,  5.9597e-01, -3.9134e-01,
        -4.5271e-01,  9.7105e-01,  1.5274e-02, -1.6970e-01, -7.3324e-01,
         1.1416e-01,  8.7804e-01, -1.1139e-01,  4.6074e-01,  4.5102e-01,
        -4.3742e-01, -2.0925e-01, -2.7499e-01, -3.3725e-01, -5.6249e-01,
        -2.4529e-01,  3.3172e-01,  5.0531e-02, -4.5348e-01,  4.3045e-01,
        -9.9580e-02,  4.9946e-01,  2.0121e-01,  1.4416e-01, -2.9800e-01,
        -1.1358e+00, -4.7961e-02, -3.1552e-01,  8.3810e-01,  4.7140e-01,
        -1.1135e-01, -6.4007e-03, -9.7005e-02, -4.6666e+00,  8.6160e-02,
        -1.0469e-01, -4.7429e-01,  7.5138e-01, -5.5269e-01, -2.6248e-01,
         3.0869e-01, -4.2183e-01, -5.9321e-01,  4.2071e-01, -6.1134e-01,
         1.3827e-01, -2.2363e-01,  4.1900e-01, -1.0203e+00,  3.2149e-01,
        -5.9846e-01, -7.0980e-01,  1.3477e+00, -3.7652e-01,  2.1615e-01,
         2.8045e-01, -9.4238e-01, -2.8517e-01,  5.9093e-01, -5.8326e-02,
         4.0520e-01, -8.4271e-01,  6.3854e-01, -6.4629e-01,  3.2763e-02,
         3.2824e-01, -5.2670e-02,  3.8074e-01,  2.7017e-01, -3.2644e-01,
         1.6558e-01,  9.8959e-02, -9.4526e-01, -1.1275e-02, -5.0865e-01,
        -1.6080e-01, -2.5150e-01,  3.0774e-01,  2.4781e-01, -6.9735e-01,
        -6.5072e-01, -3.5622e-01,  2.3935e-01, -5.4576e-01, -4.7295e-01,
         8.8449e-01, -2.4570e-01,  8.7697e-02,  6.0601e-01,  3.5812e-01,
         1.1248e-01,  2.6865e-02, -2.0416e-01,  5.7046e-01,  1.1026e-01,
        -9.1923e-01,  3.8365e-02,  7.4815e-01, -6.5813e-02, -6.3228e-01,
        -1.2311e-01,  8.0985e-02,  9.5490e-01,  2.1679e-01, -3.1574e-01,
         1.9691e-01, -1.0650e+00, -5.7394e-01, -3.4685e-02, -7.5367e-02,
        -5.7165e-01,  2.4795e-01, -4.0776e-01, -6.5695e-01, -2.1715e-01,
         1.6921e-01, -3.1418e-01, -2.6184e-01, -6.7987e-01,  1.0747e-01,
         4.5096e-02,  5.1449e-01,  2.7659e-01,  1.5968e-01, -1.7073e-03,
         5.6223e-01,  1.2407e-01,  2.6819e-01,  1.7530e-01,  8.8831e-01,
        -6.7690e-01,  3.0574e-01, -8.6536e-02,  5.6855e-01, -2.1340e-01,
        -7.0454e-02,  4.3164e-01,  1.9373e-01, -2.2988e-01, -1.1271e+00,
         3.8047e-01,  1.9446e-01, -1.4082e-01, -6.3198e-02, -2.7196e-01,
         8.4955e-02, -5.7896e-01, -4.0740e-01, -6.7326e-01, -4.5778e-02,
         5.8528e-01, -2.4662e-01, -1.8599e-01, -1.1088e-01,  7.8608e-01,
        -3.2339e-01, -4.9305e-01, -4.5717e-01,  3.3706e-01,  4.5208e-01,
        -6.0603e-01, -7.2381e-01, -3.3287e-01,  1.8222e-01,  1.8102e-01,
         1.6536e-01,  1.3156e-01, -1.2076e-01,  1.9233e-01, -7.9349e-01,
        -4.6311e-01,  1.6354e-01, -1.0870e-01,  5.8225e-02,  3.6691e-01,
        -4.2800e-01,  5.8694e-01, -1.8585e-01,  8.8113e-01, -7.4489e-01,
        -6.3407e-01, -3.9415e-01,  3.3960e-01,  3.1307e-01, -1.8274e-01,
         7.1520e-02, -2.7720e-01,  5.3749e-01, -1.1966e-01,  3.8869e-01,
        -1.1869e+00, -3.3095e-02, -1.9177e-01, -4.8460e-01,  1.8028e-01,
        -1.2386e-01,  5.0279e-01,  8.8191e-01,  6.9002e-01,  5.0999e-01,
         3.2240e-01, -2.6183e-01, -5.8290e-01, -3.8619e-01, -9.2705e-02,
         2.3294e-01, -5.7509e-01, -2.6123e-01,  3.5139e-01, -4.1708e-01,
        -4.2239e-01, -2.3083e-01,  8.1954e-03, -4.9662e-01, -4.7840e-01,
         3.8759e-02,  5.5232e-01,  5.2939e-01,  4.6774e-01, -3.8046e-01,
        -6.9474e-02,  4.2584e-01,  8.8188e-01, -5.4262e-01, -9.4521e-03,
        -2.4057e-01, -5.1243e-01, -9.7442e-01, -9.7313e-01,  9.6955e-01,
        -8.8934e-01,  2.8327e-01,  7.2861e-01,  5.0486e-01,  2.0997e-01,
        -4.0938e-01,  1.1308e-01, -4.7488e-01, -1.7587e-01, -7.8031e-02,
        -3.8377e-02, -3.4457e-01,  5.3594e-01, -3.1236e-01, -1.3837e-01,
        -2.6020e-01,  3.5784e-01,  1.8997e-01, -8.5632e-01,  6.9272e-01,
        -6.6168e-02, -2.8632e-01,  1.8989e-02, -1.9501e-01,  2.7093e-01,
         2.5045e-01, -1.7063e-02, -1.4374e-01,  9.4995e-01, -3.9204e-01,
        -2.9093e-01, -5.4969e-01, -3.4450e-01,  5.8438e-01, -1.5057e-01,
         7.9860e-01,  5.3157e-01,  6.3988e-01, -4.0382e-01,  3.7626e-03,
        -3.5091e-01,  1.4385e-01, -6.3540e-02,  2.1004e-01, -6.5894e-01,
         3.1622e-01,  6.5557e-02, -2.5855e-01, -7.5161e-01, -5.2992e-01,
         1.9083e-01, -2.0577e-01,  6.7220e-01, -4.1980e-01, -9.5139e-02,
        -5.5717e-03, -5.1454e-01,  1.3873e-01, -1.9702e-01, -7.2888e-01,
         2.6505e-01,  9.2675e-03,  3.8930e-01, -1.3924e-01,  1.4359e-01,
         8.9520e-02,  5.1103e-01,  2.1713e-01,  7.2420e-01,  9.3646e-01,
         6.6479e-01, -3.2557e-01,  1.0497e+00, -8.8188e-02,  1.6001e-01,
         4.2804e-01, -5.9942e-01, -1.1932e-01,  2.2498e-01,  2.8211e-04,
        -1.8749e-04,  2.2671e-01, -2.8624e-01, -6.0135e-02,  6.0868e-01,
         3.7792e-02, -2.4935e-01,  9.5642e-01, -3.7062e-01, -4.2931e-01,
         4.2356e-01,  2.6393e-01,  7.7339e-01, -4.0946e-01,  4.6467e-01,
        -1.3405e-01,  3.1842e-01,  1.4601e-01,  3.4192e-01, -3.8468e-01,
         3.4263e-01, -3.5902e-01,  5.7246e-01, -1.0331e-01,  2.7986e-01,
         7.5572e-01,  3.1356e-01, -6.7991e-01,  3.3968e-01, -2.7787e-03,
        -3.2683e-01, -1.4089e-01, -3.4801e-01, -1.7667e-01, -5.0184e-02,
         8.7150e-01,  6.1842e-01,  9.5388e-02, -4.5136e-01,  8.1731e-01,
        -5.1404e-01, -3.3745e-01,  3.7875e-02, -3.3433e-01, -4.3818e-01,
         7.9624e-01,  3.9827e-01, -9.9076e-02,  2.7977e-01, -8.6621e-01,
         4.1664e-01,  4.5073e-01,  7.8351e-01, -8.4463e-02, -3.5700e-01,
        -1.0421e-01,  8.3446e-01, -1.4272e-01,  3.2673e-01, -1.4472e-01,
        -9.9517e-01, -1.1686e-01,  7.0524e-01,  6.5207e-02,  6.3838e-02,
         8.0976e-01, -2.6855e-02, -4.8215e-01,  1.4215e-02, -1.7119e-01,
        -3.8744e-01,  9.5679e-01, -3.2997e-01, -2.6818e-01,  4.7255e-01,
        -2.3883e-01, -8.2873e-01, -2.6071e-02,  6.4115e-01,  5.3056e-01,
         5.1954e-01,  6.1046e-02, -7.2385e-01, -3.0921e-02,  5.1194e-02,
         8.0109e-01, -1.2863e-01, -3.4957e-02,  3.9325e-01,  3.5836e-01,
        -5.4677e-02, -2.6197e-01, -5.4597e-01, -2.4615e-01, -1.9462e-01,
         2.7137e-01,  8.7353e-01, -2.9035e-01,  7.8191e-01,  4.2285e-02,
        -5.7310e-01, -3.7404e-02, -1.7615e-01,  2.8678e-01,  5.3713e-01,
        -7.9005e-01, -4.1035e-01, -3.9232e-01, -1.0510e+00, -3.2326e-01,
        -1.0810e+00,  9.1602e-01,  3.6374e-01,  1.3233e+00, -1.1201e-01,
         3.5219e-01, -5.5904e-01,  2.5476e-01, -3.9654e-01,  1.2394e-01,
         3.1380e-01,  1.5336e-01,  8.5311e-02,  3.8395e-01,  5.4818e-01,
        -3.2665e-01,  2.4916e-01, -3.1327e-01,  6.4903e-01, -4.6011e-01,
         5.3281e-01,  2.8596e-02, -9.2876e-01,  8.4156e-01,  1.1013e+00,
        -4.7625e-01,  3.1146e-01, -6.6659e-02,  1.1244e-01,  4.9979e-01,
        -1.5115e-01, -1.0703e-01, -7.9332e-01,  1.7420e-01, -4.4306e-01,
         1.9569e-01,  6.6357e-01, -1.4642e-01, -2.5401e-01, -2.2001e-01,
         9.5454e-01,  1.6118e-01, -4.9574e-01, -5.2578e-01,  1.4362e-01,
         7.2309e-02,  1.8952e-01,  4.8230e-01, -8.4709e-01,  2.3703e-01,
        -3.2139e-01, -2.9955e-02, -4.5533e-01,  3.6851e-01,  3.6113e-01,
        -8.6004e-02, -6.7998e-01, -4.8985e-01,  3.6725e-01,  1.3295e-02,
        -6.1268e-01,  6.8871e-01, -5.3720e-01,  7.2303e-01, -5.6271e-01,
         6.4017e-01, -1.4560e-01,  2.0167e-01, -2.3267e-02, -5.5859e-04,
        -4.8885e-01,  5.9847e-01,  3.0279e-01], grad_fn=<SelectBackward0>)

pooler_output is not a per-token representation; it is the output of the "pooler" layer, a fully connected layer with a tanh activation applied to the final hidden state of the [CLS] token. It serves as an aggregate representation of the entire input sequence, so we can use pooler_output as the representation of the sentence "i love kathmandu".

In [15]:
print(f"pooler output.i.e representation of 'I love kathmandu'. \n{pooler_output}")
pooler output.i.e representation of 'I love kathmandu'. 
tensor([[-0.8357, -0.1822,  0.4747,  0.6436, -0.2796, -0.1092,  0.8457,  0.1583,
          0.1907, -0.9996,  0.1853,  0.1747,  0.9669, -0.1852,  0.9110, -0.4437,
         -0.0987, -0.4950,  0.3837, -0.7624,  0.5009,  0.7871,  0.5610,  0.1600,
          0.3615,  0.3753, -0.4926,  0.8958,  0.9266,  0.6399, -0.6181,  0.0950,
         -0.9717, -0.1272,  0.3411, -0.9715,  0.1475, -0.7362,  0.0306,  0.0822,
         -0.8589,  0.2027,  0.9935, -0.0741, -0.0652, -0.2863, -0.9996,  0.1983,
         -0.8285, -0.3280, -0.2946, -0.5097,  0.1277,  0.3611,  0.3224,  0.2838,
         -0.1187,  0.1056, -0.0781, -0.4620, -0.4998,  0.2000,  0.0673, -0.8527,
         -0.2595, -0.5242, -0.0076, -0.1295, -0.0038, -0.0472,  0.7797,  0.1425,
          0.3412, -0.7356, -0.3774,  0.1124, -0.3790,  1.0000, -0.3603, -0.9575,
         -0.4016, -0.3286,  0.2985,  0.5956, -0.4848, -1.0000,  0.2178, -0.0367,
         -0.9782,  0.1734,  0.2324, -0.1238, -0.6653,  0.3045, -0.0755, -0.0893,
         -0.1990,  0.3857, -0.1061,  0.0060,  0.0086, -0.1714,  0.1197, -0.2372,
          0.1262, -0.2163, -0.4496,  0.0401, -0.2700,  0.5711,  0.2220, -0.1947,
          0.2412, -0.9343,  0.5700, -0.1493, -0.9637, -0.3257, -0.9756,  0.6032,
          0.1781, -0.0780,  0.9453,  0.5571,  0.1407,  0.0362,  0.3646, -1.0000,
         -0.3394,  0.0233,  0.2274, -0.0698, -0.9545, -0.9000,  0.4713,  0.9317,
          0.0066,  0.9901, -0.1645,  0.8779,  0.1500, -0.0391, -0.2480, -0.2947,
          0.1984,  0.2943, -0.6335,  0.1586,  0.1640,  0.0066,  0.1490, -0.2045,
          0.3069, -0.8929, -0.3562,  0.9205,  0.3154,  0.3515,  0.6472, -0.1485,
         -0.2634,  0.7564,  0.1758,  0.2387,  0.0654,  0.2956, -0.2752,  0.4048,
         -0.7578,  0.1234,  0.3163, -0.1411,  0.5326, -0.9559, -0.2317,  0.3995,
          0.9758,  0.6521,  0.1369,  0.0850, -0.1517,  0.2689, -0.9010,  0.9546,
         -0.1763,  0.2290,  0.5736, -0.2094, -0.8271, -0.4318,  0.7565, -0.1056,
         -0.8077,  0.0939, -0.4023, -0.3229,  0.3441,  0.4487, -0.2178, -0.3261,
          0.0241,  0.8773,  0.9459,  0.8034, -0.5759,  0.4446, -0.8731, -0.3450,
          0.1232,  0.1715,  0.1602,  0.9859,  0.2193, -0.1262, -0.9025, -0.9735,
         -0.0249, -0.8609,  0.0312, -0.5592,  0.1535,  0.6513, -0.1058,  0.3437,
         -0.9785, -0.7176,  0.2701, -0.1294,  0.3077, -0.1899,  0.2584, -0.2188,
         -0.4242,  0.7890,  0.7613,  0.6384, -0.6211,  0.7924, -0.1569,  0.8204,
         -0.4833,  0.9555, -0.1928,  0.4865, -0.8961,  0.5040, -0.8705,  0.2597,
         -0.0485, -0.6772, -0.2114,  0.3194,  0.2365,  0.8293, -0.4225,  0.9941,
         -0.2830, -0.9169,  0.4675, -0.0909, -0.9690, -0.2601,  0.1455, -0.6810,
         -0.2796, -0.2128, -0.9236,  0.8716,  0.0920,  0.9778,  0.1281, -0.8849,
         -0.2400, -0.8457, -0.2042, -0.0414,  0.6000, -0.2036, -0.9297,  0.3387,
          0.4034,  0.3005,  0.5991,  0.9923,  0.9932,  0.9581,  0.8348,  0.8566,
         -0.5704,  0.0602,  0.9998, -0.3528, -0.9998, -0.9052, -0.4421,  0.2460,
         -1.0000, -0.0200,  0.0821, -0.8786, -0.4253,  0.9661,  0.9825, -1.0000,
          0.8029,  0.9152, -0.4052,  0.1657, -0.1109,  0.9520,  0.2111,  0.2441,
         -0.1008,  0.2060,  0.2395, -0.7713,  0.4306,  0.4774,  0.1620,  0.1018,
         -0.5417, -0.9116, -0.1347, -0.0955, -0.2447, -0.9207, -0.0161, -0.2207,
          0.5460,  0.0637,  0.1019, -0.7694,  0.1332, -0.7393,  0.3748,  0.4445,
         -0.9078, -0.5733, -0.1475, -0.4459,  0.3976, -0.9046,  0.9533, -0.1851,
         -0.0272,  1.0000, -0.2259, -0.8401,  0.1765,  0.1236,  0.0120,  1.0000,
          0.3590, -0.9559, -0.2993,  0.1456, -0.2983, -0.2820,  0.9947, -0.1810,
          0.3798,  0.5133,  0.9366, -0.9749, -0.1349, -0.8788, -0.9393,  0.9284,
          0.8910, -0.0110, -0.5099,  0.0026,  0.1924,  0.1650, -0.9494,  0.5135,
          0.3702, -0.1057,  0.8645, -0.8382, -0.3240,  0.3474,  0.1986,  0.3930,
         -0.3834,  0.3965, -0.1664,  0.1177, -0.1960,  0.0754, -0.9521, -0.0400,
          1.0000,  0.1457, -0.4794, -0.1106, -0.0228, -0.3202,  0.2443,  0.2820,
         -0.2128, -0.7773, -0.1090, -0.9080, -0.9657,  0.6889,  0.0778, -0.1594,
          0.9968,  0.2527,  0.1082, -0.1482,  0.0633, -0.0750,  0.4864, -0.5328,
          0.9487, -0.1227,  0.3059,  0.7986,  0.3730, -0.2166, -0.5505,  0.0077,
         -0.8476, -0.0030, -0.9217,  0.9377, -0.3854,  0.2321,  0.0428, -0.2700,
          1.0000,  0.4203,  0.4848, -0.5653,  0.8460, -0.5559, -0.6931, -0.2780,
          0.0789,  0.4914, -0.1549,  0.1516, -0.9496, -0.3756, -0.1492, -0.9686,
         -0.9818,  0.4957,  0.6610,  0.0145, -0.0433, -0.5876, -0.4846,  0.1688,
         -0.1218, -0.9168,  0.5548, -0.1388,  0.3953, -0.1332,  0.3244, -0.5027,
          0.7807,  0.6314,  0.2511,  0.0779, -0.7110,  0.7183, -0.7626,  0.3074,
         -0.0645,  1.0000, -0.3013, -0.2689,  0.6639,  0.6914,  0.0063,  0.1015,
         -0.3817,  0.0796,  0.4610,  0.4566, -0.8046, -0.2481,  0.4441, -0.5994,
         -0.5116,  0.6993, -0.1261,  0.0455,  0.1239, -0.0202,  0.9981, -0.1739,
         -0.0759, -0.3944,  0.0818, -0.2071, -0.5883,  0.9999,  0.2985, -0.1575,
         -0.9798,  0.3835, -0.8701,  0.9841,  0.7024, -0.7812,  0.3578,  0.2949,
         -0.0303,  0.7199, -0.1156, -0.1210,  0.0991,  0.0938,  0.9400, -0.3120,
         -0.9297, -0.4143,  0.2790, -0.9366,  0.6300, -0.4139, -0.0782, -0.1882,
          0.3217,  0.8082, -0.0626, -0.9611, -0.1016, -0.0329,  0.9454,  0.0739,
         -0.3345, -0.8701, -0.4491, -0.2535,  0.4966, -0.9060,  0.9429, -0.9734,
          0.3720,  0.9999,  0.1690, -0.7277,  0.1247, -0.2867,  0.1213,  0.3556,
          0.4100, -0.9293, -0.1903, -0.1115,  0.1854, -0.1383,  0.3850,  0.5818,
          0.0977, -0.2679, -0.4341, -0.0303,  0.3145,  0.6554, -0.2225, -0.0471,
          0.0594, -0.0933, -0.8966, -0.1226, -0.1466, -0.9069,  0.5453, -1.0000,
         -0.3925, -0.5733, -0.1923,  0.7422, -0.0759, -0.1191, -0.6316,  0.4116,
          0.8261,  0.6735, -0.1739,  0.0073, -0.6623,  0.1041, -0.0090,  0.2303,
          0.1850,  0.6571, -0.1187,  1.0000,  0.0425, -0.3779, -0.9520,  0.1834,
         -0.1619,  0.9974, -0.8624, -0.8984,  0.1848, -0.2696, -0.7284,  0.1267,
         -0.0695, -0.5209,  0.1273,  0.9360,  0.8511, -0.3208,  0.2976, -0.2469,
         -0.3425,  0.0311, -0.5138,  0.9711,  0.1189,  0.8520,  0.4906,  0.1186,
          0.9322,  0.1485,  0.6336,  0.0614,  0.9999,  0.1679, -0.8800,  0.4211,
         -0.9740, -0.1092, -0.9395,  0.1832,  0.0115,  0.8338, -0.1513,  0.9407,
          0.5031,  0.1108,  0.0315,  0.6071,  0.2934, -0.8703, -0.9707, -0.9766,
          0.2818, -0.3850, -0.0065,  0.2401,  0.1317,  0.2421,  0.2649, -0.9998,
          0.8918,  0.3107, -0.4011,  0.9412,  0.1921,  0.1823,  0.1558, -0.9718,
         -0.9556, -0.2409, -0.2092,  0.7035,  0.5164,  0.7444,  0.3120, -0.4264,
         -0.0232,  0.5036, -0.2250, -0.9836,  0.2923,  0.2831, -0.9512,  0.9228,
         -0.5638, -0.1392,  0.5769,  0.3294,  0.9146,  0.6901,  0.4188,  0.1131,
          0.3818,  0.8430,  0.9227,  0.9757,  0.3712,  0.7213,  0.2551,  0.2250,
          0.2202, -0.8909,  0.0254, -0.1250,  0.0541,  0.1357, -0.1238, -0.9539,
          0.2963, -0.0436,  0.2701, -0.2931,  0.1724, -0.3304, -0.1554, -0.6303,
         -0.3313,  0.4272,  0.2015,  0.8716,  0.0547, -0.0267, -0.5496, -0.0929,
          0.4459, -0.8612,  0.8963,  0.0137,  0.4834, -0.4487, -0.0741,  0.3976,
         -0.3935, -0.2915, -0.1870, -0.6522,  0.8120,  0.0012, -0.3690, -0.3210,
          0.5532,  0.2246,  0.8582,  0.3838,  0.1747,  0.0179, -0.1084,  0.1621,
         -0.0988, -0.9998,  0.3584,  0.2679, -0.3447,  0.1248, -0.5345, -0.0385,
         -0.9605, -0.0186, -0.2772, -0.4289, -0.4703, -0.3578,  0.3284,  0.2124,
         -0.0556,  0.8367,  0.1933,  0.6517,  0.3827,  0.4744, -0.5819,  0.8273]],
       grad_fn=<TanhBackward0>)

These embeddings are taken from the topmost encoder layer of BERT, i.e. layer 12. Can we extract the embeddings from all the encoder layers? Yes.

In [16]:
output.hidden_states
Out[16]:
(tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
            3.8253e-02,  1.6400e-01],
          [-3.4027e-04,  5.3974e-01, -2.8805e-01,  ...,  7.5731e-01,
            8.9008e-01,  1.6575e-01],
          [ 1.1558e+00,  8.5331e-02, -1.1208e-01,  ...,  4.3965e-01,
            8.5903e-01, -3.2685e-01],
          ...,
          [-3.6430e-01, -1.6172e-01,  9.0174e-02,  ..., -1.7849e-01,
            1.2818e-01, -4.5116e-02],
          [ 1.6776e-01, -8.9038e-01, -3.1798e-01,  ...,  3.1737e-02,
            6.4863e-02,  1.8418e-01],
          [ 3.5675e-01, -8.7076e-01, -3.7621e-01,  ..., -7.9391e-02,
           -1.0115e-01,  1.9312e-01]]], grad_fn=<NativeLayerNormBackward0>),
 tensor([[[ 0.1616,  0.0459, -0.0789,  ..., -0.0204,  0.1398,  0.0676],
          [ 0.7336,  0.9462, -0.2094,  ...,  0.4195,  0.7503, -0.0092],
          [ 1.4193,  0.6091,  0.3613,  ...,  0.4949,  0.7861, -0.0039],
          ...,
          [-0.0169,  0.2284,  0.0774,  ..., -0.2815,  0.5997,  0.1019],
          [-0.1388, -0.8165,  0.3362,  ...,  0.3381,  0.3591, -0.3433],
          [ 0.0269, -0.7895,  0.2525,  ...,  0.2607,  0.2141, -0.3796]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[ 0.0480, -0.1574, -0.1225,  ...,  0.0428,  0.0960,  0.0974],
          [ 0.6310,  1.0092,  0.3655,  ...,  0.4078,  0.2656, -0.1203],
          [ 2.0291,  0.8538,  0.9834,  ...,  0.9003,  0.4365, -0.1713],
          ...,
          [-0.1134,  0.0539,  0.2443,  ..., -0.0801,  0.4949,  0.0048],
          [-0.1568, -0.3104,  0.2979,  ...,  0.8824, -0.0709, -0.2684],
          [-0.1532, -0.3178,  0.2342,  ...,  0.7318, -0.1784, -0.3049]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[ 0.0042, -0.2541, -0.0267,  ...,  0.2045,  0.0852,  0.2579],
          [ 0.7227,  0.6391,  0.6960,  ...,  0.6194, -0.0594, -0.0120],
          [ 2.5173,  0.1945,  1.1035,  ...,  0.5088,  0.2563, -0.2985],
          ...,
          [-0.0731, -0.0839,  0.1313,  ...,  0.0441,  0.0764,  0.0078],
          [-0.3059, -0.3007,  0.6740,  ...,  0.8598,  0.0677, -0.4079],
          [-0.2730, -0.2796,  0.6555,  ...,  0.6897, -0.0085, -0.3978]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[ 0.1170, -0.4117, -0.5989,  ...,  0.4627,  0.0961,  0.5678],
          [ 1.1077,  0.6526,  0.6694,  ...,  0.5790, -0.1841, -0.4510],
          [ 2.3473, -0.0616,  0.5843,  ...,  0.5617, -0.3385,  0.0799],
          ...,
          [-0.0380, -0.0509,  0.0145,  ...,  0.0044,  0.0548, -0.0264],
          [-0.2247, -0.6068,  0.4737,  ...,  0.7821, -0.1056, -0.4011],
          [-0.3516, -0.7095,  0.6233,  ...,  0.5555, -0.3046, -0.2072]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[ 0.0552, -0.3304, -0.7737,  ...,  0.0372,  0.2087,  0.7120],
          [ 0.8668, -0.1016,  0.3839,  ...,  0.3220, -0.0550, -0.0919],
          [ 1.9763,  0.1825,  0.6982,  ...,  0.4447, -0.5356,  0.0066],
          ...,
          [-0.0213, -0.0395,  0.0210,  ...,  0.0227, -0.0033, -0.0352],
          [ 0.0780, -0.4096,  0.2830,  ...,  0.3291, -0.1267, -0.4650],
          [-0.1007, -0.5328,  0.5132,  ...,  0.1434, -0.1889, -0.2615]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[ 0.0397, -0.5367, -0.6432,  ..., -0.1023,  0.2667,  0.6516],
          [ 0.7390,  0.5079,  0.3047,  ..., -0.2435,  0.6199, -0.4579],
          [ 1.7817,  0.3081,  0.9233,  ...,  0.0095, -0.3492, -0.0181],
          ...,
          [ 0.0113, -0.0383, -0.0030,  ...,  0.0082, -0.0225, -0.0510],
          [ 0.1866, -0.1390,  0.2746,  ...,  0.4320, -0.1927, -0.6178],
          [-0.1228, -0.3369,  0.4414,  ...,  0.0941, -0.3949, -0.2522]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[-0.1469, -0.3336, -0.5482,  ...,  0.1259,  0.1292,  0.8209],
          [ 0.4745,  0.6440,  0.1888,  ..., -0.1612,  0.5875, -0.6186],
          [ 1.6383,  0.6971,  0.9196,  ...,  0.1550, -0.1455, -0.2603],
          ...,
          [-0.0220, -0.0541, -0.0105,  ..., -0.0130,  0.0223, -0.0645],
          [ 0.0159, -0.0823,  0.2857,  ...,  0.5241, -0.0141, -0.5355],
          [-0.3188, -0.4194,  0.1176,  ...,  0.3143, -0.3524, -0.0331]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[ 3.2213e-02,  4.8598e-02, -8.8377e-01,  ..., -6.5647e-01,
            2.1170e-01,  7.8293e-01],
          [ 2.2881e-01,  5.4502e-01, -2.6024e-01,  ..., -2.5965e-01,
            2.7225e-01,  1.2774e-01],
          [ 1.4240e+00,  5.4923e-01,  1.1938e+00,  ..., -5.5137e-01,
           -1.1682e-01, -1.0709e-01],
          ...,
          [-1.2949e-03, -4.6647e-02,  3.8625e-02,  ..., -4.0555e-02,
           -3.2852e-02, -9.3598e-02],
          [ 5.1433e-02,  1.1127e-01,  1.6794e-01,  ...,  3.7711e-01,
           -2.3262e-01, -4.7220e-01],
          [-2.8241e-01, -1.7677e-01, -1.0904e-01,  ...,  7.8062e-02,
           -6.4763e-01, -6.7554e-02]]], grad_fn=<NativeLayerNormBackward0>),
 tensor([[[ 0.0034,  0.6391, -0.8313,  ..., -0.4380, -0.1068,  0.2339],
          [ 0.3149,  0.9895,  0.0229,  ..., -0.3240,  0.2127,  0.2870],
          [ 1.1894,  0.7824,  1.1637,  ..., -0.4862,  0.0501, -0.2613],
          ...,
          [-0.0035, -0.0096,  0.0600,  ..., -0.0635, -0.0905, -0.0571],
          [ 0.1742,  0.5720, -0.0192,  ...,  0.5380, -0.3071, -0.5322],
          [-0.1722,  0.2455, -0.3569,  ...,  0.3873, -0.7113, -0.1223]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[-0.1033,  0.6074, -0.6408,  ..., -0.4095, -0.7298,  0.0868],
          [ 0.4959,  0.3527,  0.4731,  ..., -0.2954,  0.0267,  0.0942],
          [ 0.8277,  0.8742,  1.0342,  ..., -0.6788,  0.2928, -0.4686],
          ...,
          [ 0.0081,  0.0195,  0.0409,  ...,  0.0983, -0.1053, -0.0237],
          [-0.0699,  0.6632,  0.1258,  ...,  0.4723, -0.3174, -0.2322],
          [-0.3441,  0.3234, -0.0623,  ...,  0.5664, -0.6259,  0.0931]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[ 0.0429,  0.5304,  0.0735,  ..., -0.2732, -0.2552,  0.1550],
          [ 0.4396,  0.3368,  0.3334,  ..., -0.4953,  0.4322,  0.4333],
          [ 0.5574,  0.8172,  0.7482,  ..., -0.7575,  0.4598, -0.2706],
          ...,
          [ 0.0034, -0.0063, -0.0107,  ...,  0.0116, -0.0336,  0.0155],
          [ 0.2764,  0.5989,  0.5468,  ...,  0.1767, -0.3204, -0.3011],
          [ 0.1069,  0.1951,  0.5959,  ...,  0.1325, -0.5346, -0.3428]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[[ 0.0528,  0.3700,  0.0808,  ..., -0.2243,  0.0355,  0.0747],
          [ 0.3186,  0.5087,  0.1494,  ..., -0.4888,  0.5985,  0.3028],
          [ 0.5457,  0.8315,  0.8822,  ..., -0.6653,  0.3154, -0.3502],
          ...,
          [ 0.6960,  0.2177, -0.0880,  ..., -0.4141, -0.7644, -0.2708],
          [ 0.0960,  0.3780,  0.3556,  ...,  0.1857,  0.0318, -0.0084],
          [-0.0137,  0.2465,  0.3321,  ...,  0.2353, -0.0009, -0.0647]]],
        grad_fn=<NativeLayerNormBackward0>))
In [17]:
# hidden_states[0] is the embedding-layer output; entries 1-12 are the encoder layers
for i, layer_hidden_states in enumerate(hidden_states):
    print(f"Layer {i + 1} hidden states shape: {layer_hidden_states.shape}")
Layer 1 hidden states shape: torch.Size([1, 7, 768])
Layer 2 hidden states shape: torch.Size([1, 7, 768])
Layer 3 hidden states shape: torch.Size([1, 7, 768])
Layer 4 hidden states shape: torch.Size([1, 7, 768])
Layer 5 hidden states shape: torch.Size([1, 7, 768])
Layer 6 hidden states shape: torch.Size([1, 7, 768])
Layer 7 hidden states shape: torch.Size([1, 7, 768])
Layer 8 hidden states shape: torch.Size([1, 7, 768])
Layer 9 hidden states shape: torch.Size([1, 7, 768])
Layer 10 hidden states shape: torch.Size([1, 7, 768])
Layer 11 hidden states shape: torch.Size([1, 7, 768])
Layer 12 hidden states shape: torch.Size([1, 7, 768])
Layer 13 hidden states shape: torch.Size([1, 7, 768])
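
One common way to use these intermediate layers (a heuristic from the feature-based experiments in the original BERT paper, not something this notebook relies on) is to represent a token by combining its vectors from the last few encoder layers. A sketch for the token 'love':

In [ ]:
# Sum the representations of 'love' (position 1) across the last four encoder layers.
love_last_four = torch.stack([layer[0][1] for layer in hidden_states[-4:]])  # (4, 768)
love_embedding = love_last_four.sum(dim=0)                                   # (768,)
print(love_embedding.shape)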

Fine-tuning BERT for a downstream task: Text classification

During fine-tuning, we can adjust the weights of the model in one of the following two ways.

  • Update the weights of the pre-trained BERT model along with the classification layer.
  • Update only the weights of the classification layer, not the pre-trained BERT model; BERT then acts as a feature extractor (see the sketch below).
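
A minimal sketch of the second option (feature-extractor mode), assuming the BertForSequenceClassification model named classifier that is loaded below; this notebook itself fine-tunes all the weights:

In [ ]:
# Freeze the pre-trained BERT encoder so only the classification head is trained.
for param in classifier.bert.parameters():
    param.requires_grad = False
# The randomly initialized classification head (classifier.classifier) stays trainable.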
In [18]:
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import torch
import numpy as np
In [19]:
ds = load_dataset("imdb")
ds
Out[19]:
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})
In [20]:
train = ds["train"]
test = ds["test"]
val = ds["unsupervised"]

print(f"Train: {len(train)}, Test: {len(test)}, Validation: {len(val)}")
Train: 25000, Test: 25000, Validation: 50000

Let's initialize the pre-trained BERT model with a sequence-classification head.

In [21]:
classifier = BertForSequenceClassification.from_pretrained("bert-base-uncased")
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
In [22]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
In [23]:
input_ids = tokenizer.convert_tokens_to_ids(tokens[:-2])  # drop the two trailing [PAD] tokens added earlier
input_ids
Out[23]:
[101, 1045, 2293, 28045, 102]

Segment IDs (token type IDs):

Suppose we have two sentences in the input. In that case, segment IDs are used to distinguish one sentence from the other: all the tokens from the first sentence are mapped to 0, and all the tokens from the second sentence are mapped to 1. Since we have only one sentence here, all the tokens are mapped to 0, as shown below.

In [24]:
token_type_ids = np.zeros(5)
token_type_ids
Out[24]:
array([0., 0., 0., 0., 0.])
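
For comparison, passing a sentence pair to the tokenizer should yield token type IDs of 0 for the first sentence and 1 for the second. A sketch (not run in this notebook):

In [ ]:
# With two sentences, the tokenizer builds [CLS] A [SEP] B [SEP] and marks
# the tokens of the second segment with token_type_ids of 1.
tokenizer("I love Kathmandu", "Beautiful Nepal")
# token_type_ids should look like [0, 0, 0, 0, 0, 1, 1, 1]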

Attention mask: 1s for actual tokens, 0s for [PAD] tokens.

In [25]:
attention_mask = np.ones(5)
attention_mask
Out[25]:
array([1., 1., 1., 1., 1.])
In [26]:
tokenizer(example)
Out[26]:
{'input_ids': [101, 1045, 2293, 28045, 102], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}

Let's pass three sentences and set the maximum sequence length to 5 with padding=True.

In [27]:
tokenizer(["I love Kathmandu", "Beautiful Nepal", "Mount Everest"], padding=True, max_length=5)
/home/whiskey/miniconda3/envs/nlp/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2706: UserWarning: `max_length` is ignored when `padding`=`True` and there is no truncation strategy. To pad to max length, use `padding='max_length'`.
  warnings.warn(
Out[27]:
{'input_ids': [[101, 1045, 2293, 28045, 102], [101, 3376, 8222, 102, 0], [101, 4057, 23914, 102, 0]], 'token_type_ids': [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1], [1, 1, 1, 1, 0], [1, 1, 1, 1, 0]]}
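
As the warning says, max_length is ignored when padding=True and no truncation strategy is set. To actually pad and truncate every sentence to length 5, a sketch:

In [ ]:
# padding="max_length" pads to exactly max_length; truncation=True cuts longer inputs.
tokenizer(["I love Kathmandu", "Beautiful Nepal", "Mount Everest"],
          padding="max_length", truncation=True, max_length=5)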

With all this in mind, let's create a preprocessing function.

In [28]:
def preprocess(data):
    return tokenizer(data["text"], padding=True, truncation=True)

Let's preprocess each dataset split using map.

In [29]:
train = train.map(preprocess, batched=True, batch_size=len(train))
test = test.map(preprocess, batched=True, batch_size=len(test))
val = val.map(preprocess, batched=True, batch_size=len(val))

Set the dataset format to PyTorch tensors for the columns the model needs.

In [30]:
train.set_format("torch", columns=["input_ids", "attention_mask", "label"])
test.set_format("torch", columns=["input_ids", "attention_mask", "label"])
val.set_format("torch", columns=["input_ids", "attention_mask", "label"])

Training the model

In [36]:
batch_size = 4
epochs = 3

warmup_steps = 500
weight_decay = 0.01
In [37]:
training_args = TrainingArguments(
    num_train_epochs=epochs,
    output_dir=f"../../results",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_steps=warmup_steps,
    weight_decay=weight_decay,
    logging_dir=f"../../logs"
)
In [38]:
trainer = Trainer(
    model=classifier,
    args=training_args,
    train_dataset=train,
    eval_dataset=test
)
In [40]:
trainer.train()
  0%|          | 0/18750 [00:04<?, ?it/s]
  3%|▎         | 500/18750 [01:50<1:07:59,  4.47it/s]
{'loss': 0.5396, 'grad_norm': 57.49171447753906, 'learning_rate': 5e-05, 'epoch': 0.08}
  5%|▌         | 1000/18750 [03:40<1:03:58,  4.62it/s]
{'loss': 0.5416, 'grad_norm': 9.111489295959473, 'learning_rate': 4.863013698630137e-05, 'epoch': 0.16}
  8%|▊         | 1500/18750 [05:30<1:03:36,  4.52it/s]
{'loss': 0.4923, 'grad_norm': 0.3694137632846832, 'learning_rate': 4.726027397260274e-05, 'epoch': 0.24}
 11%|█         | 2000/18750 [07:16<56:13,  4.97it/s]  
{'loss': 0.4673, 'grad_norm': 25.703596115112305, 'learning_rate': 4.589041095890411e-05, 'epoch': 0.32}
 13%|█▎        | 2500/18750 [08:58<54:22,  4.98it/s]  
{'loss': 0.4554, 'grad_norm': 8.654304504394531, 'learning_rate': 4.452054794520548e-05, 'epoch': 0.4}
 16%|█▌        | 3000/18750 [10:39<52:57,  4.96it/s]  
{'loss': 0.4689, 'grad_norm': 17.73918342590332, 'learning_rate': 4.3150684931506855e-05, 'epoch': 0.48}
 19%|█▊        | 3500/18750 [12:21<51:02,  4.98it/s]  
{'loss': 0.4482, 'grad_norm': 0.19607552886009216, 'learning_rate': 4.1780821917808224e-05, 'epoch': 0.56}
 21%|██▏       | 4000/18750 [14:02<49:28,  4.97it/s]  
{'loss': 0.4312, 'grad_norm': 9.35656452178955, 'learning_rate': 4.041095890410959e-05, 'epoch': 0.64}
 24%|██▍       | 4500/18750 [15:44<47:57,  4.95it/s]  
{'loss': 0.4376, 'grad_norm': 1.4899685382843018, 'learning_rate': 3.904109589041096e-05, 'epoch': 0.72}
 27%|██▋       | 5000/18750 [17:25<46:03,  4.98it/s]  
{'loss': 0.4632, 'grad_norm': 0.2304084300994873, 'learning_rate': 3.767123287671233e-05, 'epoch': 0.8}
 29%|██▉       | 5500/18750 [19:07<44:23,  4.97it/s]  
{'loss': 0.3996, 'grad_norm': 25.21976089477539, 'learning_rate': 3.63013698630137e-05, 'epoch': 0.88}
 32%|███▏      | 6000/18750 [20:51<46:02,  4.62it/s]  
{'loss': 0.4196, 'grad_norm': 10.193493843078613, 'learning_rate': 3.493150684931507e-05, 'epoch': 0.96}
 35%|███▍      | 6500/18750 [22:42<44:30,  4.59it/s]  
{'loss': 0.3272, 'grad_norm': 0.05335897579789162, 'learning_rate': 3.356164383561644e-05, 'epoch': 1.04}
 37%|███▋      | 7000/18750 [24:32<42:49,  4.57it/s]  
{'loss': 0.2588, 'grad_norm': 0.7967578768730164, 'learning_rate': 3.219178082191781e-05, 'epoch': 1.12}
 40%|████      | 7500/18750 [26:22<41:17,  4.54it/s]  
{'loss': 0.2978, 'grad_norm': 48.885040283203125, 'learning_rate': 3.082191780821918e-05, 'epoch': 1.2}
 43%|████▎     | 8000/18750 [28:13<38:53,  4.61it/s]  
{'loss': 0.2567, 'grad_norm': 0.3560728430747986, 'learning_rate': 2.945205479452055e-05, 'epoch': 1.28}
 45%|████▌     | 8500/18750 [29:57<34:15,  4.99it/s]  
{'loss': 0.2406, 'grad_norm': 0.06602785736322403, 'learning_rate': 2.808219178082192e-05, 'epoch': 1.36}
 48%|████▊     | 9000/18750 [31:38<32:40,  4.97it/s]  
{'loss': 0.2412, 'grad_norm': 29.276275634765625, 'learning_rate': 2.671232876712329e-05, 'epoch': 1.44}
 51%|█████     | 9500/18750 [33:20<30:54,  4.99it/s]  
{'loss': 0.2416, 'grad_norm': 0.06330031901597977, 'learning_rate': 2.534246575342466e-05, 'epoch': 1.52}
 53%|█████▎    | 10000/18750 [35:01<29:14,  4.99it/s] 
{'loss': 0.2592, 'grad_norm': 0.05070902034640312, 'learning_rate': 2.3972602739726026e-05, 'epoch': 1.6}
 56%|█████▌    | 10500/18750 [36:42<27:36,  4.98it/s]  
{'loss': 0.2729, 'grad_norm': 81.3860855102539, 'learning_rate': 2.2602739726027396e-05, 'epoch': 1.68}
 59%|█████▊    | 11000/18750 [38:24<25:58,  4.97it/s]  
{'loss': 0.2288, 'grad_norm': 0.10501725971698761, 'learning_rate': 2.1232876712328768e-05, 'epoch': 1.76}
 61%|██████▏   | 11500/18750 [40:05<24:20,  4.97it/s]  
{'loss': 0.2385, 'grad_norm': 0.019870679825544357, 'learning_rate': 1.9863013698630137e-05, 'epoch': 1.84}
 64%|██████▍   | 12000/18750 [41:47<22:35,  4.98it/s]
{'loss': 0.272, 'grad_norm': 0.07312991470098495, 'learning_rate': 1.8493150684931506e-05, 'epoch': 1.92}
 67%|██████▋   | 12500/18750 [43:28<21:03,  4.95it/s]
{'loss': 0.2471, 'grad_norm': 0.13753291964530945, 'learning_rate': 1.7123287671232875e-05, 'epoch': 2.0}
 69%|██████▉   | 13000/18750 [45:10<19:16,  4.97it/s]
{'loss': 0.084, 'grad_norm': 1083.43701171875, 'learning_rate': 1.5753424657534248e-05, 'epoch': 2.08}
 72%|███████▏  | 13500/18750 [46:51<17:37,  4.97it/s]
{'loss': 0.123, 'grad_norm': 0.01554066687822342, 'learning_rate': 1.4383561643835617e-05, 'epoch': 2.16}
 75%|███████▍  | 14000/18750 [48:32<15:54,  4.97it/s]
{'loss': 0.1462, 'grad_norm': 0.028229428455233574, 'learning_rate': 1.3013698630136986e-05, 'epoch': 2.24}
 77%|███████▋  | 14500/18750 [50:14<14:17,  4.95it/s]
{'loss': 0.1217, 'grad_norm': 1.8434722423553467, 'learning_rate': 1.1643835616438355e-05, 'epoch': 2.32}
 80%|████████  | 15000/18750 [51:55<12:34,  4.97it/s]
{'loss': 0.0791, 'grad_norm': 0.011949531733989716, 'learning_rate': 1.0273972602739726e-05, 'epoch': 2.4}
 83%|████████▎ | 15500/18750 [53:37<10:54,  4.97it/s]
{'loss': 0.1024, 'grad_norm': 0.1625107228755951, 'learning_rate': 8.904109589041095e-06, 'epoch': 2.48}
 85%|████████▌ | 16000/18750 [55:18<09:11,  4.98it/s]
{'loss': 0.1172, 'grad_norm': 0.01604519970715046, 'learning_rate': 7.5342465753424655e-06, 'epoch': 2.56}
 88%|████████▊ | 16500/18750 [57:00<07:32,  4.97it/s]
{'loss': 0.0917, 'grad_norm': 0.09570044279098511, 'learning_rate': 6.1643835616438354e-06, 'epoch': 2.64}
 91%|█████████ | 17000/18750 [58:41<05:52,  4.96it/s]
{'loss': 0.0847, 'grad_norm': 0.99578857421875, 'learning_rate': 4.7945205479452054e-06, 'epoch': 2.72}
 93%|█████████▎| 17500/18750 [1:00:22<04:11,  4.98it/s]
{'loss': 0.0987, 'grad_norm': 0.01294002402573824, 'learning_rate': 3.4246575342465754e-06, 'epoch': 2.8}
 96%|█████████▌| 18000/18750 [1:02:04<02:30,  4.98it/s]
{'loss': 0.1108, 'grad_norm': 0.026291929185390472, 'learning_rate': 2.054794520547945e-06, 'epoch': 2.88}
 99%|█████████▊| 18500/18750 [1:03:45<00:50,  4.98it/s]
{'loss': 0.102, 'grad_norm': 0.02042091079056263, 'learning_rate': 6.849315068493151e-07, 'epoch': 2.96}
100%|██████████| 18750/18750 [1:04:36<00:00,  4.84it/s]
{'train_runtime': 3876.8226, 'train_samples_per_second': 19.346, 'train_steps_per_second': 4.836, 'train_loss': 0.2728967039489746, 'epoch': 3.0}
Out[40]:
TrainOutput(global_step=18750, training_loss=0.2728967039489746, metrics={'train_runtime': 3876.8226, 'train_samples_per_second': 19.346, 'train_steps_per_second': 4.836, 'train_loss': 0.2728967039489746, 'epoch': 3.0})
In [41]:
trainer.evaluate()
100%|██████████| 6250/6250 [06:44<00:00, 15.44it/s]
Out[41]:
{'eval_loss': 0.4087878167629242,
 'eval_runtime': 404.6947,
 'eval_samples_per_second': 61.775,
 'eval_steps_per_second': 15.444,
 'epoch': 3.0}
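
The evaluation above only reports the loss. To also get accuracy, one option (not used in this run) is to pass a compute_metrics function when building the Trainer. A minimal sketch using numpy:

In [ ]:
# Hypothetical variant of the Trainer above that also reports accuracy at eval time.
def compute_metrics(eval_pred):
    logits, labels = eval_pred            # Trainer passes model predictions and labels
    preds = np.argmax(logits, axis=-1)    # higher-scoring class per example
    return {"accuracy": float((preds == labels).mean())}

trainer = Trainer(
    model=classifier,
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    compute_metrics=compute_metrics,
)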
In [ ]: