How to Optimize LLMs for Efficient Serving

Ali Issa
3 min read · May 8, 2024


Image generated by DALL·E 3

We all know that serving LLMs in production is complicated. It takes real research and careful choices of architecture and serving techniques to keep cost and latency down when calling these models.

Here are some key notes I took from the short course by DeepLearning.AI and Predibase (a company that lets you fine-tune and serve open-source large language models), taught by Travis Addair:

The course starts by briefly explaining the difference between encoder, decoder, and encoder-decoder models, and how decoder-based models work: they generate each new token based on the previously generated tokens (auto-regressive models).
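To make the auto-regressive loop concrete, here is a minimal Python sketch; `model` and `tokenizer` are hypothetical stand-ins for any decoder-only LLM, not the course's code:

```python
# Minimal sketch of auto-regressive (token-by-token) decoding.
# `model` and `tokenizer` are hypothetical stand-ins for any decoder-only LLM.

def generate(model, tokenizer, prompt, max_new_tokens=32, eos_id=0):
    token_ids = tokenizer.encode(prompt)      # prompt -> list of token ids
    for _ in range(max_new_tokens):
        logits = model(token_ids)             # forward pass over the full sequence so far
        next_id = int(logits[-1].argmax())    # greedy pick of the most likely next token
        token_ids.append(next_id)             # feed it back in on the next step
        if next_id == eos_id:                 # stop once the end-of-sequence token appears
            break
    return tokenizer.decode(token_ids)
```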

KV cache

When the model receives a text as input for the first time, it performs a set of matrix multiplications to produce the query, key, and value projections used by each attention head in every layer of the transformer. When you generate several tokens, each newly generated token is fed back to the LLM to produce the next one, until the final token is reached. After a few steps, you will notice that the K and V values for earlier tokens were already computed in previous steps, yet the model keeps recomputing them. We only need to perform new computations for the latest token, so storing the "old" K and V values reduces the time required to generate each token. That is why we cache them to speed up inference.
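A toy numerical sketch of the idea (one attention head, NumPy, simplified shapes; this is an illustration, not the course's implementation):

```python
import numpy as np

# Toy KV cache for a single attention head (shapes simplified; real models
# keep one cache per layer and per head).

d = 64                                    # head dimension (illustrative)
k_cache, v_cache = [], []                 # grow by one row per generated token

def attend(q_new, k_new, v_new):
    """Cache the new token's K/V and attend over everything cached so far."""
    k_cache.append(k_new)                 # old keys/values are reused, not recomputed
    v_cache.append(v_new)
    K, V = np.stack(k_cache), np.stack(v_cache)   # (seq_len, d)
    scores = K @ q_new / np.sqrt(d)       # attention scores for the new token only
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()              # softmax over all cached positions
    return weights @ V                    # attention output for the new token
```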

Increasing throughput means handling more requests simultaneously, but it can increase latency.

How to increase throughput

Use continuous batching. The idea is that while the model is handling multiple requests, a new request can take the slot of one that has just finished, which is much faster than waiting for a whole batch of requests to end before sending the new ones; a simplified scheduler sketch follows below.

How do we know if a request in a batch has finished?

By applying batch filtering. The filter checks which sequences have reached the end-of-sequence token, or have generated the required number of tokens, and removes them from the cached batch.
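The sketch below combines the two ideas, continuous batching and batch filtering. It is a toy scheduler loop, not LoRAX or vLLM internals; `step(batch)` is a hypothetical function that decodes one token per active sequence and returns the sequences that just finished:

```python
from collections import deque

# Toy continuous-batching loop with batch filtering (illustration only).
# `step(batch)` is a hypothetical decode step that returns finished sequences.

MAX_BATCH = 8

def serve(incoming_requests, step):
    waiting = deque(incoming_requests)    # requests not yet scheduled
    active = []                           # sequences currently being decoded

    while waiting or active:
        # Continuous batching: fill freed slots immediately instead of
        # waiting for the whole batch to drain.
        while waiting and len(active) < MAX_BATCH:
            active.append(waiting.popleft())

        finished = step(active)           # one decode step per active sequence

        # Filtering: drop sequences that hit EOS or their token limit,
        # freeing their slots (and KV-cache entries) for new requests.
        active = [seq for seq in active if seq not in finished]
```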

Quantization

A floating-point number is usually represented by three parts: the sign (indicating positive or negative), the exponent (representing the range of values), and the mantissa (the fractional component, which determines the precision of the number). Quantization stores weights in a lower-precision format to save memory. When a model is quantized for inference in production, it is more efficient to dequantize each layer just as it is needed, rather than dequantizing the entire model at once. This approach avoids unnecessary computation and enhances performance.
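A minimal sketch of the idea, using a simple symmetric int8 scheme with just-in-time dequantization per layer (illustrative only; real quantization schemes are more involved than this):

```python
import numpy as np

# Toy symmetric int8 quantization of weight matrices.

def quantize(w):
    scale = np.abs(w).max() / 127.0           # map the largest weight onto the int8 range
    return np.round(w / scale).astype(np.int8), scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale       # approximate reconstruction of the weights

def forward(x, quantized_layers):
    # Dequantize one layer at a time, right before it is used, instead of
    # materializing the whole model in full precision at once.
    for q, scale in quantized_layers:
        w = dequantize(q, scale)              # just-in-time dequantization for this layer
        x = x @ w                             # (activations/nonlinearity omitted for brevity)
    return x
```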

LoRAX

LoRAX is a framework that allows users to serve many fine-tuned models on a single GPU, dramatically reducing the cost of serving without compromising throughput or latency. In the course, LoRAX was used to serve multiple LoRA adapters, jumping from one adapter to another or calling several adapters at the same time.
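As a rough illustration of the multi-adapter workflow (the endpoint URL and adapter IDs are placeholders, and the call pattern follows my reading of the lorax-client examples, so treat it as an assumption rather than verified code):

```python
# Hedged sketch: querying one LoRAX deployment with different LoRA adapters per request.
# The endpoint URL and adapter IDs below are placeholders, not real deployments.
from lorax import Client  # assumes the `lorax-client` package is installed

client = Client("http://127.0.0.1:8080")

# Same base model, two different fine-tuned adapters, swapped per request.
for adapter_id in ["my-org/adapter-a", "my-org/adapter-b"]:
    response = client.generate(
        "Summarize: serving LLMs in production is complicated.",
        adapter_id=adapter_id,        # tells LoRAX which LoRA weights to apply
        max_new_tokens=64,
    )
    print(adapter_id, "->", response.generated_text)
```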

If you like what you see, hit the follow button! You can also find me on LinkedIn, and we can follow each other there too. 😊
