Unpacking the Power of Context Distillation

Ali Issa
5 min read · Oct 18, 2023


A few weeks ago, I was reading the Llama 2 paper and noticed a technique called “Ghost Attention,” which is based on “context distillation”.

The idea behind “ghost attention” is that models tend to forget the first instruction given to them when the conversation becomes lengthy. Therefore, this method was introduced to retain that initial instruction throughout the conversation. Please refer to the image below for a better understanding.

Source: Llama 2 paper

These notes represent my findings following my research on the “Context distillation” technique.

Context Distillation

Knowledge Distillation vs. Context Distillation

Knowledge Distillation: A student model learns from a teacher model, and the two may have different parameter counts. The technique is primarily used for model compression, so that a smaller model can be used in place of a larger one. The same input is provided to both models.

Context Distillation: The teacher and the student have the same number of parameters, but the teacher is given a prompt and a scratchpad, whereas the student is not.
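To make the contrast concrete, here is a minimal sketch (not from any paper) of how the inputs differ in the two setups. `teacher` and `student` are hypothetical callables standing in for language models.

```python
# A minimal sketch of the two distillation setups. `teacher` and `student`
# are hypothetical callables that map a prompt string to generated text.

def knowledge_distillation_batch(teacher, student, x):
    # Same input goes to both models; the student (usually smaller)
    # is trained to match the teacher's output.
    target = teacher(x)
    prediction = student(x)
    return target, prediction

def context_distillation_batch(teacher, student, x, context):
    # Same architecture and size for both models, but only the teacher
    # sees the context (prompt and/or scratchpad) prepended to the input.
    target = teacher(context + "\n" + x)
    prediction = student(x)          # bare input, no context
    return target, prediction
```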

Generated by Bing Image Creator, just for fun 🙂

Prompts and extended reasoning are crucial ingredients for improving model performance. When a user provides a query without a prompt, the model’s performance tends to drop.

However, with a well-constructed prompt, accuracy increases. This may seem obvious, but researchers explored this observation to find a way to avoid needing a prompt for every query or question posed to the LLM.

Costs and Limitations

1) A longer prompt requires more computational resources, and each additional reasoning step demands more computation.
2) There is a finite context window size.
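As a rough illustration of both points, here is a toy calculation with made-up numbers, using whitespace-separated words as a crude stand-in for a real tokenizer, showing how much of a finite window a long prompt consumes before the model even starts answering.

```python
# Toy illustration of the cost argument. The window size and prompt are
# made up; word counts are a crude proxy for real token counts.

CONTEXT_WINDOW = 2048  # assumed finite context size

prompt = ("You are a careful maths tutor. Always explain your reasoning "
          "step by step before giving the final answer. ") * 10
query = "What is 17 + 25?"

prompt_tokens = len(prompt.split())
query_tokens = len(query.split())

print(f"prompt: {prompt_tokens} tokens, query: {query_tokens} tokens")
print(f"share of the window consumed before answering: "
      f"{(prompt_tokens + query_tokens) / CONTEXT_WINDOW:.1%}")
```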

How LLMs Mimic Human Behaviour

The more we work on something or practice it, the stronger our knowledge becomes, making it easier to solve complex problems. With repeated practice, a process becomes more “distilled” and memorable in our brains: we no longer need to recall every single step each time. In other words, after completing a complex task several times, you won’t need to remember each step; you will automatically know what to do from the “general title” alone. For a human, that is the action he or she will take without any detailed steps on how to complete it; for the LLM, it is a query without a prompt.

This is where the concept of “context distillation” comes in.

How it works

In the context distillation process, we have two models: the Teacher and the Student.

The context tokens could be a simple prompt or a reasoning chain.

1) Input Generation

Depending on the task (an addition problem, a movie review, or anything else), the input data can be synthesised: randomly generated, produced by another LLM, or sourced from elsewhere.

They generate samples from the teacher, creating a training dataset of queries and answers, and then train the student model on it. The reason for this approach is that they cannot fit both models in memory simultaneously, so they generate data from the teacher first and train the student afterwards.
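Here is a minimal sketch of what step 1 might look like for an addition task; the helper name and the random-generation strategy are illustrative assumptions, since the inputs could equally come from another LLM or an existing dataset.

```python
import random

# Sketch of step 1 (input generation) for an addition task.
# Random generation is just the simplest case.

def generate_addition_inputs(n, low=0, high=999, seed=0):
    rng = random.Random(seed)
    inputs = []
    for _ in range(n):
        a, b = rng.randint(low, high), rng.randint(low, high)
        inputs.append(f"What is {a} + {b}?")
    return inputs

raw_inputs = generate_addition_inputs(5)  # five addition questions
```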

2) Inserting Input into the Teacher and Student Prompts
- The Teacher has access to a prompt, which is added to the input. The Teacher then performs a kind of reasoning, called a scratchpad, before producing the final completion.

The Scratchpad process involves reasoning before generating the final answer. The sequence of tokens generated before the final answer is referred to as the scratchpad.

- The student has no additional information added to the input. It takes the input and attempts to predict the same output as the teacher without performing any reasoning or adding extra information about the input.
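A small sketch of step 2, assuming an arithmetic task; the prompt wording and the “Scratchpad”/“Answer:” markers are my own placeholders, not any paper’s exact format.

```python
# Sketch of step 2: the same raw input is wrapped differently for the
# two models. The wording of the context is illustrative.

TEACHER_CONTEXT = (
    "You are solving arithmetic problems. "
    "Work through the problem step by step in a scratchpad, "
    "then give the final answer after 'Answer:'."
)

def teacher_prompt(raw_input):
    # Teacher: context + input; it will emit scratchpad tokens
    # (intermediate reasoning) before the final answer.
    return f"{TEACHER_CONTEXT}\n\nQuestion: {raw_input}\nScratchpad:"

def student_prompt(raw_input):
    # Student: the bare input only, no context and no scratchpad.
    return f"Question: {raw_input}\nAnswer:"
```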

3) Answer Extraction
The final answer is extracted from the teacher’s completion, i.e. from what the teacher produced after the prompt was added to the generated input and the reasoning was performed.
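A minimal sketch of step 3, assuming the teacher was asked to put its result after an “Answer:” marker (my assumption, matching the placeholder format above).

```python
# Sketch of step 3: pull the final answer out of the teacher's completion,
# discarding the scratchpad. Assumes an 'Answer:' marker in the output.

def extract_answer(teacher_completion: str) -> str:
    marker = "Answer:"
    if marker in teacher_completion:
        return teacher_completion.split(marker)[-1].strip()
    # Fall back to the last line if the marker is missing.
    return teacher_completion.strip().splitlines()[-1]

completion = "23 + 45: add the ones (8), then the tens (60).\nAnswer: 68"
print(extract_answer(completion))   # -> "68"
```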

4) Training

The student is trained until it is likely to predict an output similar to the teacher’s output. The student learns to produce this output directly, without going through the reasoning process.
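A sketch of what the training step could look like in PyTorch, assuming the teacher’s answers have already been extracted and tokenized. The names `student_logits` and `target_ids` are stand-ins, and hard targets are used here because, as described above, the data is sampled from the teacher before the student is trained.

```python
import torch
import torch.nn.functional as F

# Sketch of step 4: fine-tune the student on (bare input, teacher answer)
# pairs so it reproduces the teacher's final output without the context
# or the scratchpad.

def distillation_step(student_logits, target_ids, optimizer):
    # student_logits: (batch, seq_len, vocab_size), produced by the student
    #                 given only the bare input.
    # target_ids:     (batch, seq_len) token ids of the teacher's extracted
    #                 answer, used as hard targets.
    loss = F.cross_entropy(
        student_logits.reshape(-1, student_logits.size(-1)),
        target_ids.reshape(-1),
    )
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```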

More details about training and hypotheses

They used a T5-11B model instruction fine-tuned on the V2 dataset, which includes instructions, positive and negative examples, and explanations. They conducted 10 explanation tasks: 5 where the teacher benefited from explanations and 5 where it did not. Several hypotheses were explored, such as:

A) Learning from Abstract Instructions and Explanations

- Hypothesis 1: Context distillation can internalize much of the knowledge within the task instructions.

What I’ve gathered from this concept is that, since we typically need a well-crafted prompt and explanations for good results, this technique enables strong results with just an input and no additional explanation. This is achieved by training the student model to generate an output as if the input had been accompanied by an explanation.

- Hypothesis 2: Context distillation can learn from natural language explanations when they benefit the teacher.

They provided examples to both the teacher and student, but added explanations only to the teacher model. The goal was to see if the student could internalize the explanations for the given examples. The conclusion was that when the teacher learns more from explanations, the student can effectively learn from the teacher.

- Hypothesis 3: Recursive distillation to override past updates.

That is, replacing old information with current information, for example when a country elects a new president.

Recursive distillation essentially involves overriding or adding additional information to prior knowledge of a given prompt. This process employs the student as a new teacher and utilises a recursive technique on a new student model to overwrite the prompt.
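A tiny sketch of the recursive idea, where `distill` is a hypothetical helper standing in for the whole pipeline above (steps 1 to 4) and returns a newly trained student:

```python
# Sketch of recursive distillation: after one round, the trained student
# becomes the teacher for the next round, with a new context that can
# override what was internalised previously.

def recursive_distillation(base_model, contexts, distill):
    model = base_model
    for context in contexts:
        # Each round bakes the current context into the weights; later
        # contexts can overwrite earlier ones (e.g. "the president is
        # now X" replacing an older fact).
        model = distill(teacher=model, context=context)
    return model
```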

B) Learning from Concrete Examples
- Hypothesis 4: Prompting the Teacher model with 4–8 training examples to aid in task completion, then distilling with those examples to facilitate student learning.

They concluded that distilling 4–8 in-context examples allowed the student to outperform fine-tuning on those same examples with gradient descent.
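Here is a small sketch of that setup, with an illustrative arithmetic task and made-up demonstrations; only the teacher’s prompt contains the in-context examples, while the student answers zero-shot.

```python
# Sketch for hypothesis 4: the teacher's context is a handful of
# in-context examples; the student is distilled to answer without them.

few_shot_examples = [
    ("What is 2 + 3?", "5"),
    ("What is 10 + 7?", "17"),
    ("What is 31 + 11?", "42"),
    ("What is 9 + 16?", "25"),
]

def few_shot_teacher_prompt(raw_input):
    demos = "\n".join(f"Q: {q}\nA: {a}" for q, a in few_shot_examples)
    return f"{demos}\nQ: {raw_input}\nA:"

def zero_shot_student_prompt(raw_input):
    return f"Q: {raw_input}\nA:"   # no demonstrations for the student
```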

C) Learning from Step-by-Step Reasoning
By distilling the CoT (chain-of-thought) reasoning, the student could achieve impressive results after context distillation, surpassing even the teacher.
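A final sketch for the chain-of-thought case, with an assumed trigger phrase and “Answer:” marker; the student’s training target keeps only the final answer, so the reasoning is internalised rather than reproduced.

```python
# Sketch of CoT distillation: the teacher reasons step by step, and the
# student is trained on (question, final answer) pairs. The trigger
# phrase and marker are assumptions of this sketch.

COT_CONTEXT = "Let's think step by step, then write the result after 'Answer:'."

def cot_teacher_prompt(question):
    return f"{COT_CONTEXT}\nQ: {question}\n"

def student_training_pair(question, teacher_completion):
    # Keep only the final answer as the student's target; the chain of
    # thought is discarded after shaping the teacher's output.
    answer = teacher_completion.split("Answer:")[-1].strip()
    return (f"Q: {question}\nA:", answer)
```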

For more detailed information, please refer to the following resources: the official paper and the accompanying video.

Don’t forget to follow for similar content 🙂😄
