TinyBERT
Can we also transfer knowledge from the other layers of the teacher BERT, and not just from its output logits as in DistilBERT? Yes! In TinyBERT, knowledge is also transferred from the embedding layer and the encoder layers.
We do the following:
- As in DistilBERT, the student BERT is trained on the logits generated by the output layer of the teacher BERT, aiming to replicate them.
- The student BERT is trained to mimic the hidden states and attention matrices produced by the teacher BERT.
- The student BERT is trained to reproduce the output of the embedding layer of the teacher BERT, ensuring consistency between the produced embeddings.
We use a two-stage learning framework in which distillation is applied in both the pre-training and fine-tuning stages.
The Teacher
The teacher BERT comprises encoder layers: an input sentence is first fed into an embedding layer to obtain input embeddings, which are then passed through the encoder layers. These layers use self-attention to learn the contextual relations within the input sentence and produce a representation. This representation is forwarded to the prediction layer, typically a feedforward network. In tasks like masked language modeling, the prediction layer returns logits over all words in the vocabulary for the masked position. We use the pre-trained BERT-Base model as the teacher; it consists of 12 encoder layers with 12 attention heads each, produces 768-dimensional representations, and has a total of 110 million parameters.
The Student
The architecture of the student BERT mirrors that of the teacher BERT, consisting of encoder layers. However, the student BERT has fewer encoder layers, denoted $M$, than the teacher BERT's $N$. Specifically, the student BERT employs 4 encoder layers, and its representation size (hidden state dimension) is set to 312. Notably, the number of parameters in the student BERT is substantially lower, totaling 14.5 million.
Layers of Distillation:
- Transformer Layer
- Embedding Layer
- Prediction Layer
Teacher index: 0: embedding layer, 1 to $N$: encoder layers, $N+1$: prediction (FFN) layer
Student index: 0: embedding layer, 1 to $M$: encoder layers, $M+1$: prediction (FFN) layer
Knowledge transfer:
$$ n = g(m) $$
In distillation, knowledge is transferred from corresponding layers of the teacher to the student using a mapping function $g$: student layer $m$ learns from teacher layer $n = g(m)$. For example, knowledge from the 0th layer (embedding layer) of the teacher is transferred to the 0th layer of the student. Similarly, knowledge from layer $N+1$ (the prediction layer) of the teacher, with $g(M+1) = N+1$, is transferred to layer $M+1$ (the prediction layer) of the student. This mapping is applied at each layer, facilitating knowledge transfer throughout the model.
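As a concrete illustration, here is a minimal Python sketch of a uniform mapping function $g$, assuming a 12-layer teacher and a 4-layer student as described above. The constant names and the `g` helper are hypothetical, not taken from the TinyBERT codebase.

```python
# A minimal sketch of a uniform layer-mapping function g, assuming a
# 12-layer teacher (N = 12) and a 4-layer student (M = 4).

N_TEACHER = 12   # teacher encoder layers (BERT-Base)
M_STUDENT = 4    # student encoder layers (TinyBERT)

def g(m: int) -> int:
    """Map student layer index m to a teacher layer index n = g(m)."""
    if m == 0:                      # embedding layer -> embedding layer
        return 0
    if m == M_STUDENT + 1:          # prediction layer -> prediction layer
        return N_TEACHER + 1
    # uniform strategy: each student encoder layer learns from every
    # (N/M)-th teacher encoder layer, e.g. g(1)=3, g(2)=6, g(3)=9, g(4)=12
    return m * (N_TEACHER // M_STUDENT)

print([g(m) for m in range(M_STUDENT + 2)])   # [0, 3, 6, 9, 12, 13]
```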
Transformer Layer
- Attention-based distillation
- Hidden state-based distillation
Attention Based Distillation
To perform attention-based distillation, we train the student network by minimizing the mean squared error between the attention matrix of the student and the teacher BERT.
The attention-based distillation loss in TinyBERT is often formulated as the mean squared error (MSE) between the attention matrices of the teacher and student models. The MSE loss measures the average squared difference between corresponding elements of the attention matrices.
Mathematically, the attention-based distillation loss $ L_{\text{attention}} $ can be expressed as:
$$ L_{\text{attention}} = \frac{1}{h} \sum_{i=1}^{h} \text{MSE}(A^T_{i}, A^S_{i}) $$
Where:
- $ h $ is the number of attention heads.
- $ A^T_{i} $ and $ A^S_{i} $ are the attention matrices of the $ i^{th} $ head, for the teacher and student models respectively.
This formulation computes the squared difference between corresponding elements of the attention matrices and averages it over all elements and heads. Minimizing this loss encourages the student model to mimic the attention patterns of the teacher model, facilitating knowledge distillation.
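To make the formula concrete, here is a minimal PyTorch sketch of this loss, assuming the teacher and student use the same number of attention heads and the same sequence length. The function name and the random tensors in the usage example are illustrative only.

```python
import torch
import torch.nn.functional as F

def attention_distillation_loss(attn_teacher, attn_student):
    """
    Sketch of the attention-based distillation loss: the MSE between the
    teacher's and student's attention matrices, averaged over the h heads.
    Both inputs have shape (batch, h, seq_len, seq_len) and come from one
    aligned teacher/student layer pair.
    """
    assert attn_teacher.shape == attn_student.shape
    h = attn_student.size(1)
    loss = 0.0
    for i in range(h):
        loss = loss + F.mse_loss(attn_student[:, i], attn_teacher[:, i])
    return loss / h

# Hypothetical usage with random tensors standing in for real attention maps.
a_t = torch.rand(2, 12, 16, 16)   # teacher: 12 heads, 16 tokens
a_s = torch.rand(2, 12, 16, 16)   # student: same number of heads
print(attention_distillation_loss(a_t, a_s))
```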
Hidden state-based distillation
In hidden state-based distillation, the hidden states produced by the teacher and student models are compared to encourage the student model to mimic the representations learned by the teacher model. Let's denote the hidden states produced by the teacher model as $ H^T $ and the hidden states produced by the student model as $ H^S $.
One common approach to compute the hidden state-based distillation loss is by using the mean squared error (MSE) between the hidden states of the teacher and student models. The MSE loss measures the average squared difference between corresponding elements of the hidden state matrices.
Mathematically, the hidden state-based distillation loss $ L_{\text{hidden}} $ can be expressed as:
$$ L_{\text{hidden}} = MSE( H^T, H^S) $$
Where:
- $ H^T $ and $ H^S $ are the hidden-state matrices produced by the corresponding layers of the teacher and student models respectively.
This formulation computes the squared difference between each element of the hidden state matrices and averages them over all elements. Minimizing this loss encourages the student model to learn representations that are similar to those learned by the teacher model, facilitating knowledge transfer.
But aren't the dimensions of the teacher and student hidden states different?
- Yes!
One approach to handle the dimensionality difference between the teacher and student hidden states is to transform the student's hidden states to match the dimensionality of the teacher's hidden states before computing the mean squared error (MSE) loss.
$$ L_{\text{hidden}} = \text{MSE}( H^S W_h, H^T) $$
- $W_h$ is a learnable linear transformation matrix that maps $H^S$ from dimension $d_s$ to dimension $d_t$ (the teacher's hidden size).
This approach effectively aligns the representations of the teacher and student models, enabling the computation of the MSE loss without requiring additional steps to project both hidden states into a shared latent space.
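A minimal PyTorch sketch of this idea is shown below, assuming the student and teacher hidden sizes of 312 and 768 mentioned earlier. The module and parameter names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HiddenStateDistillation(nn.Module):
    """
    Sketch of hidden state-based distillation with a learnable projection
    W_h that lifts the student's hidden states (d_s = 312) to the teacher's
    dimension (d_t = 768) before the MSE is computed.
    """
    def __init__(self, d_student: int = 312, d_teacher: int = 768):
        super().__init__()
        self.W_h = nn.Linear(d_student, d_teacher, bias=False)

    def forward(self, h_student: torch.Tensor, h_teacher: torch.Tensor):
        # h_student: (batch, seq_len, d_s), h_teacher: (batch, seq_len, d_t)
        return F.mse_loss(self.W_h(h_student), h_teacher)

# Hypothetical usage with random tensors standing in for real hidden states.
distill = HiddenStateDistillation()
h_s = torch.rand(2, 16, 312)
h_t = torch.rand(2, 16, 768)
print(distill(h_s, h_t))
```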
Embedding Layer
When distilling knowledge from the embedding layer of the teacher model to the student model, the Mean Squared Error (MSE) loss can be used to measure the discrepancy between their embeddings. Let's denote the embedding vectors produced by the teacher and student models as $ E^T $ and $ E^S $ respectively.
The MSE loss for the embedding layer distillation can be computed as follows:
$$ L_{\text{embedding}} = \frac{1}{N} \sum_{i=1}^{N} || E^T_i - E^S_i ||^2 $$
Where:
- $ N $ is the number of tokens or elements in the embedding vectors.
- $ E^T_i $ and $ E^S_i $ are the embedding vectors for the $ i $th token in the input sequence, produced by the teacher and student models respectively.
This formula calculates the squared difference between each element of the embedding vectors and averages them over all elements.
In case of a size mismatch, we can employ the same learnable linear transformation as used in hidden state distillation.
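The embedding-layer loss follows the same pattern as the hidden-state loss above; the helper below is a hypothetical sketch, with `W_e` standing in for the optional learnable projection when the embedding sizes differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def embedding_distillation_loss(e_teacher, e_student, W_e=None):
    """
    Sketch of the embedding-layer distillation loss. e_teacher and e_student
    are (batch, seq_len, dim) embedding-layer outputs; if their dimensions
    differ, W_e is a learnable linear projection mapping the student
    embeddings to the teacher's dimension, mirroring the hidden-state case.
    """
    if W_e is not None:
        e_student = W_e(e_student)
    return F.mse_loss(e_student, e_teacher)

# Hypothetical usage with a 312 -> 768 projection.
W_e = nn.Linear(312, 768, bias=False)
print(embedding_distillation_loss(torch.rand(2, 16, 768), torch.rand(2, 16, 312), W_e))
```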
For the prediction layer distillation, we can use the DistilBERT approach, i.e., train the student on the soft targets produced by the teacher's output logits.
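As a sketch of what such a DistilBERT-style prediction-layer loss might look like, the function below computes a temperature-softened soft-target cross-entropy between teacher and student logits. The temperature value and the function name are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def prediction_distillation_loss(logits_teacher, logits_student, temperature=2.0):
    """
    Sketch of prediction-layer distillation via soft targets: the student's
    softened log-probabilities are pushed toward the teacher's softened
    probabilities.
    """
    t = temperature
    p_teacher = F.softmax(logits_teacher / t, dim=-1)
    log_p_student = F.log_softmax(logits_student / t, dim=-1)
    # soft cross-entropy, scaled by t^2 as is customary with temperature scaling
    return -(p_teacher * log_p_student).sum(dim=-1).mean() * (t ** 2)

# Hypothetical usage with random logits over a small vocabulary.
z_t = torch.randn(2, 16, 1000)
z_s = torch.randn(2, 16, 1000)
print(prediction_distillation_loss(z_t, z_s))
```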
The final loss function is as follows:
$$ L = \sum_{m=0}^{M+1} \lambda_m L_{\text{layer}} (S_m, T_{g(m)})$$
- $L_{\text{layer}}$ is the loss function of layer $m$ (embedding, Transformer, or prediction layer distillation, depending on $m$).
- $\lambda_m$ is a hyperparameter controlling the weight of the $m^{th}$ layer's loss.
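Putting it together, the following sketch shows how the weighted sum over per-layer losses might be computed; the function and argument names are hypothetical.

```python
import torch

def tinybert_total_loss(layer_losses, lambdas):
    """
    Sketch of the overall objective: a weighted sum of the per-layer
    distillation losses L_layer(S_m, T_g(m)) for m = 0 .. M+1, with one
    hyperparameter lambda_m per student layer. Both arguments are lists
    ordered from the embedding layer (m = 0) to the prediction layer (m = M + 1).
    """
    assert len(layer_losses) == len(lambdas)
    total = torch.zeros(())
    for lam, loss in zip(lambdas, layer_losses):
        total = total + lam * loss
    return total

# Hypothetical usage: embedding, 4 encoder layers, prediction layer (M = 4).
losses = [torch.rand(()) for _ in range(6)]
lambdas = [1.0] * 6
print(tinybert_total_loss(losses, lambdas))
```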
In TinyBERT, distillation is applied across all layers and in both the pre-training and fine-tuning phases. Compared to BERT-Base, TinyBERT retains over 96% of its performance while being 7.5 times smaller and 9.4 times faster at inference.