Knowledge distillation
Here, we discuss the challenges associated with using pre-trained BERT models for downstream tasks due to their computational expense and high resource requirements. We highlight the difficulty of deploying these models on edge devices such as mobile phones. To address this issue, we explore knowledge distillation as a method to transfer knowledge from large pre-trained BERT models to smaller variants.
Knowledge distillation:
Knowledge distillation in BERT involves transferring knowledge from a large, pre-trained BERT model to a smaller variant. The process aims to distill the knowledge learned by the large model into a more compact form that retains most of its performance. This allows for reducing the computational resources required for inference, making the model more suitable for deployment on resource-constrained devices like mobile phones. During knowledge distillation, the smaller model learns from the soft targets or logits produced by the larger model, effectively learning to mimic its behavior. This process typically involves training the smaller model on a combination of the original training data and the soft targets produced by the larger model. Overall, knowledge distillation enables the creation of efficient BERT models that maintain high performance while being more lightweight and computationally efficient.
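As a rough sketch of how one distillation training step might look in practice (assuming PyTorch-style teacher and student modules and a hypothetical `loss_fn` that combines the hard-label and soft-target terms defined later in this section):

```python
import torch

def distillation_step(teacher, student, optimizer, batch, loss_fn):
    """One knowledge-distillation training step (hypothetical helper).

    `teacher` and `student` are assumed to be PyTorch modules mapping
    input ids to logits; `loss_fn` combines the hard-label and
    soft-target terms described later in this section.
    """
    teacher.eval()
    with torch.no_grad():                            # the teacher is frozen
        teacher_logits = teacher(batch["input_ids"])
    student_logits = student(batch["input_ids"])     # the student is trainable
    loss = loss_fn(student_logits, teacher_logits, batch["labels"])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```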
Example: "I completed writing my ________"

Model prediction: Homework: 0.995, Cake: 0.000, Assignment: 0.001, Car: 0.000, Book: 0.002
The predicted word here is "Homework"; apart from that, "Book" and "Assignment" are more relevant than the other two words. This relative information carried by the low-probability words is known as dark knowledge, and the smaller BERT (student) learns this dark knowledge from the large BERT (teacher).
How to extract dark knowledge:
In the model's probability distribution, Homework has the highest score, and the scores for the rest of the words are close to zero. In such a scenario, how does the student learn the dark knowledge?
Softmax Temperature
The softmax function with temperature scaling, also known as softmax temperature, introduces a temperature parameter ($T$) to the softmax function. This parameter controls the "softness" or "sharpness" of the resulting probability distribution. A higher temperature leads to a softer distribution where the probabilities are more spread out, while a lower temperature results in a sharper distribution with higher confidence in the predicted class.
The softmax function with temperature scaling is defined as:
$$ \text{softmax}_T(z)_i = \frac{e^{z_i / T}}{\sum_{j=1}^N e^{z_j / T}} $$
Where:
- $z$ is the input vector (logits).
- $N$ is the number of elements in the input vector.
- $T$ is the temperature parameter.
This formula calculates the softmax probabilities ($\text{softmax}_T(z)_i$) for each element $i$ in the input vector $z$ using the temperature $T$.
When $T = 1$, the softmax function with temperature scaling reduces to the standard softmax function. As $T$ increases, the resulting probability distribution becomes softer, and as $T$ decreases, the distribution becomes sharper.
Adjusting the temperature parameter allows for controlling the exploration-exploitation trade-off in reinforcement learning and fine-tuning the behavior of neural network models in various tasks.
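A minimal NumPy sketch of the temperature-scaled softmax (the logit values below are made up purely for illustration):

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    """Temperature-scaled softmax: higher T gives a softer distribution."""
    scaled = np.asarray(logits, dtype=np.float64) / T
    scaled -= scaled.max()                 # subtract max for numerical stability
    exps = np.exp(scaled)
    return exps / exps.sum()

# Hypothetical logits for: Homework, Cake, Assignment, Car, Book
logits = [9.0, 2.0, 3.0, 1.5, 4.0]
for T in (1, 2, 5):
    print(f"T={T}:", np.round(softmax_with_temperature(logits, T), 3))
```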
For example:

$T=1$: Homework: 0.995, Cake: 0.000, Assignment: 0.001, Car: 0.000, Book: 0.002

$T=2$: Homework: 0.935, Cake: 0.001, Assignment: 0.017, Car: 0.001, Book: 0.046

$T=5$: Homework: 0.637, Cake: 0.019, Assignment: 0.018, Car: 0.020, Book: 0.191
How to train the student:
In this process, both the teacher (pre-trained) and student (untrained) networks receive the input sentence and produce probability distributions as outputs. The teacher network, being pre-trained, provides the target probability distribution, termed the soft target. Meanwhile, the student network generates its own prediction, termed the soft prediction, aiming to mimic the teacher's behavior.
Example: "I completed writing my _____"

Teacher (soft target): Homework: 0.637, Cake: 0.019, Assignment: 0.018, Car: 0.020, Book: 0.191

Student (soft prediction): Homework: 0.363, Cake: 0.134, Assignment: 0.242, Car: 0.059, Book: 0.200
The next step involves computing the cross-entropy loss between the soft target and the soft prediction, known as the distillation loss. Through backpropagation, the student network is trained to minimize this loss, aligning its predictions with the teacher's targets. Both networks use the same softmax temperature, set greater than 1. Hence, we train the student network by minimizing the distillation loss together with the student loss, defined below.
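In code, the distillation loss could be sketched as the cross-entropy between the teacher's softened distribution and the student's softened log-probabilities, assuming PyTorch and raw logits from both networks:

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, T=5.0):
    """Cross-entropy between the teacher's soft targets and the student's
    soft predictions, both computed at temperature T."""
    soft_target = F.softmax(teacher_logits / T, dim=-1)        # teacher: soft target
    log_soft_pred = F.log_softmax(student_logits / T, dim=-1)  # student: soft prediction
    # -sum(p_teacher * log p_student), averaged over the batch
    return -(soft_target * log_soft_pred).sum(dim=-1).mean()
```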
Objective
The distinction between soft and hard predictions lies in the temperature setting. A soft prediction is obtained from the student network with a temperature greater than 1, producing a softened probability distribution. Conversely, a hard prediction is obtained with the temperature set to 1, i.e., the standard softmax output. The student loss is the cross-entropy loss between the hard target (the one-hot ground-truth label) and the hard prediction.
Example: "I completed writing my _____"

Teacher (soft target, $T=5$): Homework: 0.637, Cake: 0.019, Assignment: 0.018, Car: 0.020, Book: 0.191

Student (soft prediction, $T=5$): Homework: 0.363, Cake: 0.134, Assignment: 0.242, Car: 0.059, Book: 0.200

Student (hard prediction, $T=1$): Homework: 0.838, Cake: 0.005, Assignment: 0.115, Car: 0.000, Book: 0.038

Hard target: Homework: 1, Cake: 0, Assignment: 0, Car: 0, Book: 0
Distillation loss: teacher's soft target vs. student's soft prediction (both at $T=5$)

Student loss: student's hard prediction ($T=1$) vs. hard target
The final loss is a weighted sum of the student loss and the distillation loss.
$$ L = \alpha \times S_{loss} + \beta \times D_{loss} $$
$\alpha$ and $\beta$ are hyperparameters.
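A possible sketch of this combined objective, again assuming PyTorch logits and integer class labels (the default values of $\alpha$, $\beta$, and $T$ here are arbitrary placeholders):

```python
import torch.nn.functional as F

def total_loss(student_logits, teacher_logits, hard_labels, T=5.0, alpha=0.5, beta=0.5):
    # Student loss: standard cross-entropy against the hard (ground-truth) labels, i.e. T = 1
    s_loss = F.cross_entropy(student_logits, hard_labels)
    # Distillation loss: teacher vs. student soft distributions at temperature T
    soft_target = F.softmax(teacher_logits / T, dim=-1)
    log_soft_pred = F.log_softmax(student_logits / T, dim=-1)
    d_loss = -(soft_target * log_soft_pred).sum(dim=-1).mean()
    # Weighted sum of the two terms
    return alpha * s_loss + beta * d_loss
```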
DistilBERT:
DistilBERT, introduced by Hugging Face, is a smaller, faster, cheaper, and lighter version of BERT.
Teacher-student architecture
The teacher BERT serves as a large pre-trained model, specifically utilizing the BERT-Base variant. BERT-Base undergoes pre-training through tasks like masked language modeling and next sentence prediction. Leveraging its proficiency in masked language modeling, the pre-trained BERT model can effectively predict masked words.
The student BERT differs from the teacher BERT in that it lacks pre-training and learns from the teacher. It's a smaller model with fewer layers, containing 66 million parameters compared to the teacher's 110 million. Due to its reduced complexity, the student BERT trains faster than the teacher variant.
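For a concrete comparison, here is a small sketch using the Hugging Face transformers library: the teacher is the pre-trained BERT-Base checkpoint, while the student is a randomly initialised DistilBERT (the default DistilBertConfig uses 6 layers):

```python
from transformers import BertForMaskedLM, DistilBertConfig, DistilBertForMaskedLM

# Teacher: pre-trained BERT-Base (12 layers, ~110M parameters)
teacher = BertForMaskedLM.from_pretrained("bert-base-uncased")

# Student: a freshly initialised DistilBERT (6 layers, ~66M parameters),
# not pre-trained -- it learns from the teacher during distillation
student = DistilBertForMaskedLM(DistilBertConfig())

count = lambda m: sum(p.numel() for p in m.parameters())
print(f"teacher: {count(teacher) / 1e6:.0f}M parameters")
print(f"student: {count(student) / 1e6:.0f}M parameters")
```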
Training DistilBERT
Let's follow the RoBERTa training style: we train DistilBERT on masked language modeling (MLM) with dynamic masking and large batch sizes.
In addition to distillation and student loss, cosine embedding loss is computed, serving as a distance measure between the representations learned by the teacher and student BERT models. Minimizing this loss enhances the accuracy of the student's representation, aligning it more closely with the teacher's embedding.
$$ L = S_{loss} + D_{loss} + E_{loss} $$
$S_{loss}$: student loss
$D_{loss}$: distillation loss
$E_{loss}$: cosine embedding loss
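A rough sketch of this triple loss, assuming PyTorch, raw logits from both models, masked-token labels, and per-example hidden-state vectors from teacher and student (the temperature value is an arbitrary placeholder):

```python
import torch
import torch.nn.functional as F

def distilbert_loss(student_logits, teacher_logits, labels,
                    student_hidden, teacher_hidden, T=2.0):
    """Sketch of the triple loss: student (MLM) loss + distillation loss
    + cosine embedding loss between the two models' hidden states."""
    # S_loss: masked-language-modeling cross-entropy against the true token ids
    s_loss = F.cross_entropy(student_logits, labels)
    # D_loss: teacher vs. student soft distributions at temperature T
    soft_target = F.softmax(teacher_logits / T, dim=-1)
    log_soft_pred = F.log_softmax(student_logits / T, dim=-1)
    d_loss = -(soft_target * log_soft_pred).sum(dim=-1).mean()
    # E_loss: cosine embedding loss pulls the student's hidden states
    # toward the teacher's (target = 1 means "make them similar")
    target = torch.ones(student_hidden.size(0), device=student_hidden.device)
    e_loss = F.cosine_embedding_loss(student_hidden, teacher_hidden, target)
    return s_loss + d_loss + e_loss
```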
DistilBERT retains nearly 97% of the performance of the original BERT-Base model while being much lighter, enabling easy deployment on edge devices, and it provides about a 60% speedup at inference. It was trained on eight 16 GB V100 GPUs for approximately 90 hours, and the pre-trained DistilBERT is publicly available through Hugging Face. It can be downloaded and fine-tuned for downstream tasks just like the original BERT model.
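For instance, the released checkpoint can be used directly for masked-word prediction with the transformers `fill-mask` pipeline:

```python
from transformers import pipeline

# Load the pre-trained DistilBERT checkpoint from the Hugging Face Hub
fill = pipeline("fill-mask", model="distilbert-base-uncased")

# Predict the masked word and print the top candidates with their scores
for pred in fill("I completed writing my [MASK]."):
    print(f"{pred['token_str']:>12}  {pred['score']:.3f}")
```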