
KV Caching in Language Models
This blog post provides a stupidly simple demonstration of KV (Key-Value) caching in transformer language models, using the TinyLlama model as an example. We’ll explore:
- How attention mechanisms work in transformer models
- What Q, K, V projections are and how they’re computed
- How KV caching optimizes inference
- The impact of KV caching on attention patterns and computation efficiency
What is KV Caching?
KV caching is an optimization technique used during autoregressive text generation to avoid redundant computations. In a typical transformer model:
- For each token, we compute Query (Q), Key (K), and Value (V) matrices
- During generation, we only add one new token at a time
- Previous tokens’ K and V values don’t change when generating new tokens
- KV caching stores these K and V values to avoid recomputing them (see the sketch below)
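To make the idea concrete before we touch TinyLlama, here is a minimal, self-contained sketch of single-head attention with a KV cache. The dimensions and random weights are made up purely for illustration; this is not the code used in the rest of this post:
import torch

torch.manual_seed(0)
d = 8  # toy head dimension, made up purely for illustration
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))

k_cache, v_cache = [], []  # the "KV cache": one K row and one V row per past token

def decode_step(x_new):
    """Attend from one new token embedding x_new (shape [d]) over everything so far."""
    q = x_new @ Wq                 # Q is only needed for the current token
    k_cache.append(x_new @ Wk)     # K and V for this token are computed once...
    v_cache.append(x_new @ Wv)     # ...and then reused for every later token
    K = torch.stack(k_cache)       # [tokens_so_far, d]
    V = torch.stack(v_cache)
    scores = (q @ K.T) / d ** 0.5  # attend over all cached positions
    probs = torch.softmax(scores, dim=-1)
    return probs @ V               # attention output for the new token

# Generating 5 tokens never recomputes K or V for earlier positions
for x in torch.randn(5, d):
    out = decode_step(x)
Each call to decode_step computes Q, K, and V only for the newest token; everything already in k_cache and v_cache is reused as-is, which is exactly what a real KV cache does per layer and per head.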
Setup and Model Initialization
First, let’s import the necessary libraries and load the TinyLlama model.
import torch
import pickle
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys

sys.path.append("../src")
# Import custom utilities
from attention_helpers.qkvo_hooks import capture_model_attention_internals
from attention_helpers.gqa import reshape_llama_attention, compute_multihead_attention
from plot_helpers.plotter import (
    plot_single_matrix,
    visualize_gqa_attention,
    plot_attention_matrices,
    get_axis_limits,
    plot_kv_cache_verification,
    plot_hybrid_verification,
)

# Create data directory if it doesn't exist
Path("../data").mkdir(exist_ok=True)
Load the TinyLlama Model
We’ll use the TinyLlama-1.1B-Chat model for our experiments. This is a smaller model that demonstrates the same principles as larger LLMs.
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", padding=False
)
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    device_map="cpu",
    torch_dtype=torch.float16,
)

# Display model information
print(f"Model Parameters: {model.num_parameters():,}")
print(f"Model Architecture: {model.config.architectures[0]}")
print(f"Model Context Length: {model.config.max_position_embeddings}")
Output:
Model Parameters: 1,100,048,384
Model Architecture: LlamaForCausalLM
Model Context Length: 2048
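As a quick aside, the same config object can be used to estimate how much memory the KV cache itself will take. The snippet below is a rough back-of-the-envelope sketch (it assumes head_dim = hidden_size / num_attention_heads and float16 storage, which holds for Llama-style models like this one; exact numbers depend on your setup):
# Rough per-token KV-cache size estimate (back-of-the-envelope sketch, not exact accounting)
cfg = model.config
head_dim = cfg.hidden_size // cfg.num_attention_heads  # assumed Llama-style head size
kv_width = cfg.num_key_value_heads * head_dim          # width of K (and of V) per layer
bytes_per_value = 2                                    # float16
kv_bytes_per_token = 2 * cfg.num_hidden_layers * kv_width * bytes_per_value
print(f"Approx. KV cache per token: {kv_bytes_per_token / 1024:.1f} KiB")
print(
    f"Approx. KV cache for a full {cfg.max_position_embeddings}-token context: "
    f"{kv_bytes_per_token * cfg.max_position_embeddings / 1024**2:.1f} MiB"
)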
Define Helper Functions for Text Generation
Let’s define a function to generate text using the TinyLlama model with a chat template.
def generate_text(
    messages,
    model,
    tokenizer,
    max_tokens=100,
    temperature=0,
    verbose=True,
    padding=True,
    truncation=True,
):
    """
    Generate text using TinyLlama model with chat template

    Args:
        messages (list): List of message dictionaries with 'role' and 'content'
        model: The language model to use
        tokenizer: The tokenizer to use
        max_tokens (int): Maximum number of tokens to generate
        temperature (float): Sampling temperature (0.0 = deterministic)
        verbose (bool): Whether to print tokenization and generation details
        padding (bool): Whether to pad the tokenized prompt
        truncation (bool): Whether to truncate the tokenized prompt

    Returns:
        str: Generated text
    """
    # Apply chat template to format messages
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        padding=padding,
        truncation=truncation,
    )

    # Encode the prompt with attention mask
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=padding,
        truncation=truncation,
        return_attention_mask=True,
    )

    if verbose:
        # Print tokenization information
        print("\n\nTokenization Information:")
        text = tokenizer.apply_chat_template(messages, tokenize=False, padding=False)
        tokens = tokenizer.encode(
            text, padding=False, truncation=True, return_tensors="pt"
        )
        print(f"Input sequence length: {tokens.shape[1]} tokens")

        # Create DataFrame for token visualization
        token_data = {
            "token_index": range(len(tokens[0])),
            "token": tokens[0].tolist(),
            "decoded_token": [tokenizer.decode([t]) for t in tokens[0]],
        }
        df = pd.DataFrame(token_data)
        print("\nToken Details:")
        print(df.to_string(index=False))

    # Generate with attention mask
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_tokens,
        do_sample=(temperature > 0),
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode and return the generated text, keeping special tokens
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=False)
    if verbose:
        print(
            f"\nGenerated tokens: {len(outputs[0]) - len(inputs.input_ids[0])} tokens"
        )
        print(f"Total tokens in final sequence: {len(outputs[0])}")
        print(f"\nGenerated Text:\n{decoded_output}")
        print("--------------------------------")
    return decoded_output
Define Example Prompts
Let’s create two different prompts to demonstrate KV caching. Note that the two prompts share a common prefix, which will be important for the caching demonstration.
SYSTEM_PROMPT = "You are a helpful assistant."

# First prompt: About the capital of India
msg1 = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {
        "role": "user",
        "content": "What is the capital of India?",
    },
]

# Second prompt: About the capital of France
msg2 = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {
        "role": "user",
        "content": "What is the capital of France? Answer the question in just one word not more than that.",
    },
]
Run Example Generation
Let’s generate a response for the first prompt to see how the model works.
# Run example text generation
generated_text = generate_text(msg1, model, tokenizer, verbose=True)
Output:
Tokenization Information:
Input sequence length: 32 tokens

Token Details:
 token_index  token decoded_token
           0      1           <s>
           1    529             <
           2  29989             |
           3   5205        system
           4  29989             |
           5  29958             >
           6     13            \n
           7   3492           You
           8    526           are
           9    263             a
          10   8444       helpful
          11  20255     assistant
          12  29889             .
          13      2          </s>
          14  29871
          15     13            \n
          16  29966             <
          17  29989             |
          18   1792          user
          19  29989             |
          20  29958             >
          21     13            \n
          22   5618          What
          23    338            is
          24    278           the
          25   7483       capital
          26    310            of
          27   7513         India
          28  29973             ?
          29      2          </s>
          30  29871
          31     13            \n

Generated tokens: 17 tokens
Total tokens in final sequence: 49

Generated Text:
<s> <|system|>
You are a helpful assistant.</s>
<|user|>
What is the capital of India?</s>
<|assistant|>
The capital of India is New Delhi.</s>
--------------------------------
Capture Attention Internals (Q, K, V Projections)
Now let’s capture the Query, Key, and Value projections for both of our prompts. These are the internal matrices used by the attention mechanism in transformers.
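capture_model_attention_internals is a small custom helper from this repo, and its implementation isn’t reproduced here. Conceptually, though, it can be built by registering PyTorch forward hooks on each layer’s q_proj / k_proj / v_proj / o_proj modules and recording their outputs during a single forward pass. A rough sketch of that idea (my own simplification, assuming a Llama-style model; the hypothetical capture_qkvo below is not the actual helper):
def capture_qkvo(model, input_ids, attention_mask=None):
    """Record the Q/K/V/O projection outputs of every layer for one forward pass."""
    captured = {"q": {}, "k": {}, "v": {}, "o": {}}
    handles = []

    def make_hook(kind, layer_idx):
        def hook(module, inputs, output):
            captured[kind][layer_idx] = output.detach()
        return hook

    # Llama-style decoder layers expose q_proj / k_proj / v_proj / o_proj linear modules
    for i, layer in enumerate(model.model.layers):
        for kind, proj in [
            ("q", layer.self_attn.q_proj),
            ("k", layer.self_attn.k_proj),
            ("v", layer.self_attn.v_proj),
            ("o", layer.self_attn.o_proj),
        ]:
            handles.append(proj.register_forward_hook(make_hook(kind, i)))

    with torch.no_grad():
        model(input_ids=input_ids, attention_mask=attention_mask)

    for h in handles:
        h.remove()
    return captured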
# Capture attention internals for both inputs
data_obj1 = capture_model_attention_internals(
    messages=msg1,
    model=model,
    tokenizer=tokenizer,
    padding=True,
    truncation=True,
    return_attention_mask=True,
    verbose=True,
)
data_obj2 = capture_model_attention_internals(
    messages=msg2,
    model=model,
    tokenizer=tokenizer,
    padding=True,
    truncation=True,
    return_attention_mask=True,
    verbose=True,
)

# Save the data for later use or analysis
with open("../data/msg1-qkvo.pkl", "wb") as f:
    pickle.dump(data_obj1, f)

with open("../data/msg2-qkvo.pkl", "wb") as f:
    pickle.dump(data_obj2, f)
Output:
Attention layer shapes:
Q projection: torch.Size([1, 32, 2048])
K projection: torch.Size([1, 32, 256])
V projection: torch.Size([1, 32, 256])
O projection: torch.Size([1, 32, 2048])
Attention layer shapes:
Q projection: torch.Size([1, 44, 2048])
K projection: torch.Size([1, 44, 256])
V projection: torch.Size([1, 44, 256])
O projection: torch.Size([1, 44, 2048])
Visualize Raw Projections (KV Cache Components)
Let’s visualize what’s actually stored in the KV cache: the raw K and V matrices for a specific layer. These are exactly the tensors a KV cache holds during generation. Notice how much narrower K and V (256 dimensions) are than Q (2048 dimensions): TinyLlama uses grouped-query attention, so there are far fewer key/value heads to cache.
# Select a layer for visualization
layer_to_visualize = 20

# Extract raw Q, K, V matrices for the selected layer
q1 = data_obj1["attention_matrices"]["q"][layer_to_visualize]
k1 = data_obj1["attention_matrices"]["k"][layer_to_visualize]
v1 = data_obj1["attention_matrices"]["v"][layer_to_visualize]

# Convert to 2D format for visualization
q1_2d = torch.squeeze(q1)  # [seq_len, q_dim]
k1_2d = torch.squeeze(k1)  # [seq_len, k_dim]
v1_2d = torch.squeeze(v1)  # [seq_len, v_dim]

print(f"Q1 shape after squeeze: {q1_2d.shape}")
print(f"K1 shape after squeeze: {k1_2d.shape}")
print(f"V1 shape after squeeze: {v1_2d.shape}")

# Visualize raw K matrix (what would be cached)
plot_single_matrix(
    k1_2d,
    matrix_type="K",
    plot_title=f"Raw Key Matrix (Cached in KV Cache) for Layer {layer_to_visualize}",
    cmap="coolwarm",
    tokens=data_obj1["decoded_input_tokens"],
)

# Visualize raw V matrix (what would be cached)
plot_single_matrix(
    v1_2d,
    matrix_type="V",
    plot_title=f"Raw Value Matrix (Cached in KV Cache) for Layer {layer_to_visualize}",
    cmap="coolwarm",
    tokens=data_obj1["decoded_input_tokens"],
)
Output:
Q1 shape after squeeze: torch.Size([32, 2048])
K1 shape after squeeze: torch.Size([32, 256])
V1 shape after squeeze: torch.Size([32, 256])
[Figures: raw Key and Value matrices for layer 20, one row per token; these are the tensors that would sit in the KV cache]
Computing Attention Patterns
Now let’s compute the attention patterns for both prompts. First, we reshape the Q, K, V matrices to the multi-head format, then compute the attention scores, probabilities, and outputs.
[batch_size, seq_len, hidden_dim]
-> [batch_size, num_heads, seq_len, head_dim]
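reshape_llama_attention and compute_multihead_attention are custom helpers from this repo, so their internals aren’t reproduced here. For intuition, a generic grouped-query attention computation over tensors laid out as [num_heads, seq_len, head_dim] might look roughly like the sketch below. This is my own simplification, not the repo’s code, and it assumes the query and key/value heads share the same head_dim (the repo’s reshape helper splits TinyLlama’s projections a little differently, as the shapes printed in the output further down show):
def gqa_attention(q_mh, k_mh, v_mh):
    """Causal grouped-query attention (generic sketch).

    q_mh: [num_q_heads, seq_len, head_dim]
    k_mh, v_mh: [num_kv_heads, seq_len, head_dim], with num_q_heads divisible by num_kv_heads
    """
    num_q_heads, seq_len, head_dim = q_mh.shape
    group_size = num_q_heads // k_mh.shape[0]

    # Each key/value head serves a whole group of query heads
    k_exp = k_mh.repeat_interleave(group_size, dim=0)  # [num_q_heads, seq_len, head_dim]
    v_exp = v_mh.repeat_interleave(group_size, dim=0)

    scores = q_mh @ k_exp.transpose(-2, -1) / head_dim**0.5  # [num_q_heads, seq_len, seq_len]
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))       # no attending to future tokens
    probs = torch.softmax(scores, dim=-1)
    return probs @ v_exp                                      # [num_q_heads, seq_len, head_dim]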
# Extract Q, K, V matrices for a specific layer
layer_to_analyze = 20
q1, k1, v1 = [
    data_obj1["attention_matrices"][key][layer_to_analyze] for key in ["q", "k", "v"]
]
q2, k2, v2 = [
    data_obj2["attention_matrices"][key][layer_to_analyze] for key in ["q", "k", "v"]
]

# Reshape into multi-head format
q1_mh, k1_mh, v1_mh = reshape_llama_attention(q1, k1, v1, verbose=True)
q2_mh, k2_mh, v2_mh = reshape_llama_attention(q2, k2, v2, verbose=False)

# Compute attention separately for each message
attention_msg1 = compute_multihead_attention(q1_mh, k1_mh, v1_mh)
attention_msg2 = compute_multihead_attention(q2_mh, k2_mh, v2_mh)

# Print information about the attention heads
print(f"Grouping of queries:")
for i in list(attention_msg1["heads"].keys())[:5]:  # Just show the first 5
    print(i)
print("... (more heads)")
Output:
After squeeze:
Q shape: torch.Size([32, 2048])
K shape: torch.Size([32, 256])
V shape: torch.Size([32, 256])

Head dimensions:
Q head dim: 64
KV head dim: 32
Num Q heads: 32
Num KV heads: 8

After reshape:
Q shape: torch.Size([32, 32, 64])
K shape: torch.Size([8, 32, 32])
V shape: torch.Size([8, 32, 32])
Grouping of queries:
q_head_0_kv_head_0
q_head_1_kv_head_0
q_head_2_kv_head_0
q_head_3_kv_head_0
q_head_4_kv_head_1
... (more heads)
Visualizing Attention Patterns
Let’s visualize the attention patterns for each prompt. This shows which tokens are attending to which other tokens.
# Visualize attention patterns
visualize_gqa_attention(attention_msg1, title_prefix="Message 1 -")
visualize_gqa_attention(attention_msg2, title_prefix="Message 2 -")
Output:
[Figures: grouped-query attention heatmaps for Message 1 and Message 2]
Implementing KV Caching
Now we’ll demonstrate how KV caching works. We’ll:
- Find the common prefix between our two prompts
- Create a hybrid KV cache that reuses cached values for the common prefix
- Compare attention patterns with and without caching
def find_common_prefix_length(tokens1, tokens2):
    """Find the length of common prefix between two token sequences."""
    common_prefix_length = 0
    for i in range(min(len(tokens1), len(tokens2))):
        if tokens1[i] == tokens2[i]:
            common_prefix_length += 1
        else:
            break
    return common_prefix_length


def create_hybrid_kv_cache(q2_mh, k1_mh, k2_mh, v1_mh, v2_mh, common_prefix_length):
    """Create hybrid K,V matrices using cached values for common prefix."""
    if common_prefix_length == 0:
        return k2_mh, v2_mh

    hybrid_k = k2_mh.clone()
    hybrid_v = v2_mh.clone()

    # Only copy the common prefix!
    hybrid_k[:, :common_prefix_length, :] = k1_mh[:, :common_prefix_length, :]
    hybrid_v[:, :common_prefix_length, :] = v1_mh[:, :common_prefix_length, :]

    # Let's add some verification
    print("Verifying cache creation:")
    print(f"Common prefix length: {common_prefix_length}")
    print(f"Total sequence length: {k2_mh.shape[1]}")
    print("For prefix tokens: hybrid_k should match k1_mh")
    print("For non-prefix tokens: hybrid_k should match k2_mh")

    return hybrid_k, hybrid_v
Finding the Common Prefix
Let’s identify the common prefix between our two prompts, which will be the part we can cache.
# Extract token sequences from both prompts
tokens1 = data_obj1["input_tokens"][0].tolist()
tokens2 = data_obj2["input_tokens"][0].tolist()

# Find common prefix length
common_prefix_length = find_common_prefix_length(tokens1, tokens2)
print(f"Common prefix length: {common_prefix_length} tokens")

# Display the common prefix tokens
print("\nCommon Prefix Tokens:")
for i in range(common_prefix_length):
    token = tokens1[i]
    decoded = tokenizer.decode([token])
    print(f"Token {i}: {token} - '{decoded}'")
Output:
Common prefix length: 27 tokens

Common Prefix Tokens:
Token 0: 1 - '<s>'
Token 1: 529 - '<'
Token 2: 29989 - '|'
Token 3: 5205 - 'system'
Token 4: 29989 - '|'
Token 5: 29958 - '>'
Token 6: 13 - '
'
Token 7: 3492 - 'You'
Token 8: 526 - 'are'
Token 9: 263 - 'a'
Token 10: 8444 - 'helpful'
Token 11: 20255 - 'assistant'
Token 12: 29889 - '.'
Token 13: 2 - '</s>'
Token 14: 29871 - ''
Token 15: 13 - '
'
Token 16: 29966 - '<'
Token 17: 29989 - '|'
Token 18: 1792 - 'user'
Token 19: 29989 - '|'
Token 20: 29958 - '>'
Token 21: 13 - '
'
Token 22: 5618 - 'What'
Token 23: 338 - 'is'
Token 24: 278 - 'the'
Token 25: 7483 - 'capital'
Token 26: 310 - 'of'
Creating and Evaluating the Hybrid KV Cache
Now let’s create a hybrid KV cache and compare the attention patterns
with and without caching to verify that the results are identical.
This is the secret sauce: watch closely.
# Create hybrid cached version using common prefix
hybrid_k, hybrid_v = create_hybrid_kv_cache(
    q2_mh, k1_mh, k2_mh, v1_mh, v2_mh, common_prefix_length
)

# Compute attention with and without caching
original_attention = compute_multihead_attention(q2_mh, k2_mh, v2_mh)
cached_attention = compute_multihead_attention(q2_mh, hybrid_k, hybrid_v)

# Visualize the difference in attention patterns
max_tokens = min(15, len(tokens2))
diff_attn = plot_attention_matrices(
    original_attention, cached_attention, common_prefix_length, max_tokens
)

# Print statistics and computational savings
print(f"Maximum difference in attention values: {np.max(diff_attn):.8f}")
print(f"Mean difference: {np.mean(diff_attn):.8f}")

if common_prefix_length > 0:
    token_savings = common_prefix_length / len(tokens2)
    compute_savings = (common_prefix_length * len(tokens2)) / (
        len(tokens2) * len(tokens2)
    )
    print(f"By caching KV for the common prefix, you save:")
    print(f" - {token_savings:.1%} of KV computations")
    print(f" - {compute_savings:.1%} of attention computations")
Output:
Verifying cache creation:
Common prefix length: 27
Total sequence length: 44
For prefix tokens: hybrid_k should match k1_mh
For non-prefix tokens: hybrid_k should match k2_mh
Maximum difference in attention values: 0.00000000
Mean difference: 0.00000000
By caching KV for the common prefix, you save:
 - 61.4% of KV computations
 - 61.4% of attention computations
[Figure: attention patterns computed with the original K/V versus the hybrid cached K/V, plus their difference]
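Where does the 61.4% figure come from? It is simply the fraction of prefill positions covered by the shared prefix: common_prefix_length / len(tokens2) = 27 / 44 ≈ 0.614, so roughly 61% of the K and V projections for the second prompt are reused from the first instead of being recomputed.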
Visualizing Individual KV Vectors
Finally, let’s visualize the actual K and V vectors for specific tokens to confirm that our caching mechanism is working correctly.
if common_prefix_length > 0:
    # Get common axis limits for consistent visualization
    head_idx = 5
    k_limits, v_limits = get_axis_limits(
        k1_mh, k2_mh, v1_mh, v2_mh, hybrid_k, hybrid_v, head_idx, k1_mh.shape[-1] // 2
    )

    # Calculate maximum sequence length
    max_seq_len = max(k1_mh.shape[1], k2_mh.shape[1])

    # Plot input1's KV values (what gets cached)
    fig1 = plot_kv_cache_verification(
        k1_mh,
        v1_mh,
        common_prefix_length,
        k_limits,
        v_limits,
        max_seq_len,
        input_name="input1",
    )
    plt.show()

    # Plot input2's verification with hybrid cache
    fig2 = plot_hybrid_verification(
        k2_mh,
        v2_mh,
        hybrid_k,
        hybrid_v,
        common_prefix_length,
        k_limits,
        v_limits,
        max_seq_len,
    )
    plt.show()
Output:
[Figures: per-token K and V values for input1 (the cached prefix) and the hybrid-cache verification for input2]
Conclusion
In this notebook, we’ve demonstrated a stripped-down version of how KV caching works in transformer models:
- We loaded a model and generated text for two different prompts
- We extracted Q, K, V projections from the model’s attention layers
- We found the common prefix between the prompts
- We created a hybrid KV cache that reuses cached values for the common prefix
- We verified that attention patterns are identical with and without caching
- We quantified the computational savings from KV caching
The key insight is that K and V values for tokens that remain the same across prompts don’t need to be recomputed. This optimization is crucial for efficient text generation in large language models, especially for long contexts.
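In everyday use you rarely wire any of this up by hand: Hugging Face’s generate reuses past K and V automatically through its use_cache flag. If you want to see the effect end to end, a quick (and unscientific) timing comparison along the lines of the sketch below should show caching pull further ahead as more tokens are generated; exact numbers will depend on your hardware:
import time

prompt = tokenizer.apply_chat_template(msg1, tokenize=False)
inputs = tokenizer(prompt, return_tensors="pt")

for use_cache in (False, True):
    start = time.perf_counter()
    model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=50,
        do_sample=False,
        use_cache=use_cache,  # toggle KV caching on/off
        pad_token_id=tokenizer.eos_token_id,
    )
    print(f"use_cache={use_cache}: {time.perf_counter() - start:.2f}s")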
Codebase
The entire codebase is available on GitHub.