LLM inference

https://www.youtube.com/watch?v=fcgPYo3OtV0&t=2336s

Here are my study notes from the inference lecture of Stanford CS336, covering techniques for improving inference efficiency.


Multi-Head Attention (MHA)
Grouped-Query Attention (GQA): use fewer key/value heads, each shared by a group of query heads (sketch below)
Multi-Query Attention (MQA): a single key/value head shared by all query heads
Multi-Head Latent Attention (MLA): compress keys and values into a low-dimensional latent; DeepSeek reduces n_h * d_h from 16,384 to 512
Naive MLA is not compatible with RoPE; DeepSeek works around this with decoupled RoPE dimensions
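A minimal GQA sketch in PyTorch; the shapes and head counts below are made up for illustration (and the causal mask is omitted). MQA is the special case with one KV head, MHA the case where the KV head count equals the query head count.

```python
import torch
import torch.nn.functional as F

def gqa(q, k, v):
    # q: (batch, n_q_heads, seq, d_head); k, v: (batch, n_kv_heads, seq, d_head)
    n_q_heads, n_kv_heads = q.shape[1], k.shape[1]
    group = n_q_heads // n_kv_heads
    # Each KV head is shared by `group` query heads.
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v

b, s, d = 1, 16, 64
q = torch.randn(b, 8, s, d)   # 8 query heads
k = torch.randn(b, 2, s, d)   # only 2 KV heads -> 4x smaller KV cache
v = torch.randn(b, 2, s, d)
out = gqa(q, k, v)            # (1, 8, 16, 64)
```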

Cross-Layer Attention (CLA): reuse the keys and values computed at one layer in nearby layers, so the KV cache is shared across layers

Local Attention
Attention patterns: full n^2 attention, sliding-window attention, dilated sliding window, global + sliding window
Sliding-window attention: the KV cache size stays constant as the sequence length grows, but each layer only sees a local window
Solution: interleave local attention with global attention (hybrid layers), e.g. one global-attention layer for every 6 layers (sliding-window mask sketch below)
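A small sketch of a causal sliding-window mask; the window size below is arbitrary and just for illustration.

```python
import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # Position i may attend only to positions j with i - window < j <= i.
    i = torch.arange(seq_len).unsqueeze(1)   # query positions
    j = torch.arange(seq_len).unsqueeze(0)   # key positions
    causal = j <= i
    local = (i - j) < window
    return causal & local                    # True where attention is allowed

print(sliding_window_mask(6, 3).int())
# Each row has at most 3 ones, so the KV cache only needs the last 3 tokens.
```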

Inference is memory(-bandwidth) limited; ways to shrink the KV cache:
– lower-dimensional KV cache (GQA, MLA, shared KV cache)
– local attention on some of the layers
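A back-of-envelope KV cache calculation makes the memory pressure concrete. The config below is a Llama-2-7B-like example I picked for illustration, not a number from the lecture.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch, bytes_per_elem=2):
    # 2x for keys and values; bf16 = 2 bytes per element.
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_elem

full = kv_cache_bytes(n_layers=32, n_kv_heads=32, d_head=128, seq_len=4096, batch=1)
gqa  = kv_cache_bytes(n_layers=32, n_kv_heads=8,  d_head=128, seq_len=4096, batch=1)
print(f"MHA: {full / 2**30:.2f} GiB, GQA (8 KV heads): {gqa / 2**30:.2f} GiB")
# MHA: 2.00 GiB, GQA (8 KV heads): 0.50 GiB
```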

Alternatives to transformers
– state space models: continuous state-space formulation with fast discretized representations
– diffusion models
– Mamba
– Jamba: interleaves Transformer and Mamba layers at a 1:7 ratio (one attention layer per 7 Mamba layers)
– BASED: linear attention + local attention (linear-attention sketch below)
– MiniMax: linear attention + full attention once in a while (a 456B-parameter MoE)
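A minimal sketch of linear attention in its recurrent form, which is the idea behind the BASED/MiniMax hybrids: the running state has a fixed size, so the "cache" does not grow with sequence length. The elu+1 feature map is just one common choice and not necessarily what those models use.

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v):
    # q, k: (seq, d_key); v: (seq, d_val)
    phi_q, phi_k = F.elu(q) + 1, F.elu(k) + 1
    d_key, d_val = q.shape[-1], v.shape[-1]
    S = torch.zeros(d_key, d_val)   # running sum of phi(k_t) v_t^T
    z = torch.zeros(d_key)          # running sum of phi(k_t) for normalization
    out = []
    for t in range(q.shape[0]):
        S = S + torch.outer(phi_k[t], v[t])
        z = z + phi_k[t]
        out.append((phi_q[t] @ S) / (phi_q[t] @ z + 1e-6))
    return torch.stack(out)

y = linear_attention(torch.randn(16, 32), torch.randn(16, 32), torch.randn(16, 64))
```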

Quantization: reduce the precision of numbers; need to worry about accuracy
fp32 (4 bytes): needed for parameters and optimizer states during training
bf16: the default for inference
int8: for inference only
LLM.int8(): keeps outlier feature dimensions in fp16 and quantizes the rest to int8 (toy int8 sketch below)
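A toy per-row absmax int8 quantization of a weight matrix; LLM.int8()'s outlier handling is noted in a comment but not implemented here.

```python
import torch

def quantize_int8(w: torch.Tensor):
    # Symmetric absmax quantization with one scale per output row.
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    q = torch.clamp((w / scale).round(), -127, 127).to(torch.int8)
    return q, scale

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale

# LLM.int8() additionally splits out outlier columns and keeps them in fp16;
# this toy version quantizes everything.
w = torch.randn(4, 8)
q, scale = quantize_int8(w)
print((w - dequantize(q, scale)).abs().max())   # small quantization error
```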

Model pruning: remove less important weights or structures from the model (often followed by fine-tuning or distillation to recover accuracy)
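A toy magnitude-pruning sketch just to make the idea concrete; real pruning pipelines (e.g. structured pruning followed by fine-tuning or distillation) are more involved than this.

```python
import torch

def magnitude_prune(w: torch.Tensor, sparsity: float) -> torch.Tensor:
    # Zero out the `sparsity` fraction of weights with the smallest magnitude.
    k = int(w.numel() * sparsity)
    if k == 0:
        return w
    threshold = w.abs().flatten().kthvalue(k).values
    return torch.where(w.abs() > threshold, w, torch.zeros_like(w))

w = torch.randn(256, 256)
w_pruned = magnitude_prune(w, sparsity=0.5)
print((w_pruned == 0).float().mean())   # ~0.5
```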

Speculative decoding / speculative sampling: a small draft model generates candidate tokens and the big target model verifies them in parallel, which is faster than having the big model generate every token itself; this gives roughly a 2x speedup. Try to make the draft model as close to the target model as possible (e.g., via model distillation). A sketch of the accept/reject loop is below.
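A sketch of the speculative sampling accept/reject loop. `draft_probs` and `target_probs` are hypothetical callables returning next-token distributions (not a real API), and a real implementation would score all draft positions in one batched forward pass of the target model.

```python
import torch

def speculative_step(prefix, draft_probs, target_probs, k=4):
    # 1) Draft model proposes k tokens autoregressively (cheap).
    proposed, q_list = [], []
    ctx = list(prefix)
    for _ in range(k):
        q = draft_probs(ctx)                        # (vocab,) distribution
        tok = torch.multinomial(q, 1).item()
        proposed.append(tok); q_list.append(q); ctx.append(tok)
    # 2) Target model scores all k+1 positions (here a loop for clarity;
    #    in practice this is a single parallel forward pass).
    p_list = [target_probs(list(prefix) + proposed[:i]) for i in range(k + 1)]
    # 3) Accept/reject each proposed token.
    accepted = []
    for i, tok in enumerate(proposed):
        p, q = p_list[i], q_list[i]
        if torch.rand(()) < torch.clamp(p[tok] / q[tok], max=1.0):
            accepted.append(tok)
        else:
            # On rejection, resample from the residual distribution and stop.
            resid = torch.clamp(p - q, min=0.0)
            accepted.append(torch.multinomial(resid / resid.sum(), 1).item())
            return accepted
    # All k accepted: sample one bonus token from the target's last distribution.
    accepted.append(torch.multinomial(p_list[k], 1).item())
    return accepted
```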

PagedAttention: manage the KV cache in fixed-size blocks, like virtual-memory pages, to avoid fragmentation and over-reserving memory for the maximum sequence length (toy sketch below)
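A toy sketch of the PagedAttention idea: each sequence's KV cache lives in fixed-size blocks tracked by a per-sequence block table, so memory is allocated on demand. The class below is illustrative only, not vLLM's actual implementation.

```python
class PagedKVCache:
    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))   # physical block ids
        self.block_tables = {}                       # seq_id -> [block ids]
        self.lengths = {}                            # seq_id -> tokens stored

    def append_token(self, seq_id: int) -> tuple[int, int]:
        """Return (physical_block, offset) where this token's K/V should go."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.block_size == 0:            # current block is full
            table.append(self.free_blocks.pop())     # allocate a new block
        self.lengths[seq_id] = length + 1
        return table[-1], length % self.block_size

    def free(self, seq_id: int):
        # Return a finished sequence's blocks to the free pool.
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

cache = PagedKVCache(num_blocks=8, block_size=16)
for _ in range(20):                                  # 20 tokens -> 2 blocks used
    block, offset = cache.append_token(seq_id=0)
```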

