KV Caching in LLMs: A Visual Demonstration


This blog post provides a stupidly simple demonstration of KV (Key-Value) caching in transformer language models, using the TinyLlama model as an example. We’ll explore:

  1. How attention mechanisms work in transformer models
  2. What Q, K, V projections are and how they’re computed
  3. How KV caching optimizes inference
  4. The impact of KV caching on attention patterns and computation efficiency

What is KV Caching?

KV caching is an optimization technique used during autoregressive text generation to avoid redundant computations. In a typical transformer model:

  • For each token, we compute Query (Q), Key (K), and Value (V) matrices
  • During generation, we only add one new token at a time
  • Previous tokens’ K and V values don’t change when generating new tokens
  • KV caching stores these K and V values to avoid recomputing them (see the sketch below)
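
To make this concrete, here is a minimal, self-contained sketch of the idea (plain PyTorch, a single attention head, toy dimensions; this is not the TinyLlama code, just an illustration of the caching pattern):

# Toy single-head decoder step with a KV cache: each new token computes its
# own Q/K/V once, K and V are appended to the cache, and attention runs over
# all cached positions.
import torch

d_model = 8
W_q = torch.randn(d_model, d_model)
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

k_cache, v_cache = [], []

def decode_step(x_new):
    """x_new: [1, d_model] embedding of the newest token only."""
    q = x_new @ W_q                    # Q is only needed for the new token
    k_cache.append(x_new @ W_k)        # K and V are computed once and cached
    v_cache.append(x_new @ W_v)
    K = torch.cat(k_cache, dim=0)      # [tokens_so_far, d_model]
    V = torch.cat(v_cache, dim=0)
    scores = (q @ K.T) / d_model**0.5  # attend over every cached position
    probs = torch.softmax(scores, dim=-1)
    return probs @ V                   # attention output for the new token

for _ in range(5):                     # pretend we decode 5 tokens
    out = decode_step(torch.randn(1, d_model))

Without the cache, every step would recompute K and V for all previous tokens; with it, each step adds only one new K/V pair and one row of attention.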

Setup and Model Initialization

First, let’s import the necessary libraries and load the TinyLlama model.

import torch
import pickle
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys

sys.path.append("../src")
# Import custom utilities
from attention_helpers.qkvo_hooks import capture_model_attention_internals
from attention_helpers.gqa import reshape_llama_attention, compute_multihead_attention
from plot_helpers.plotter import (
    plot_single_matrix,
    visualize_gqa_attention,
    plot_attention_matrices,
    get_axis_limits,
    plot_kv_cache_verification,
    plot_hybrid_verification,
)

# Create data directory if it doesn't exist
Path("../data").mkdir(exist_ok=True)

Load the TinyLlama Model

We’ll use the TinyLlama-1.1B-Chat model for our experiments. This is a smaller model that demonstrates the same principles as larger LLMs.

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", padding=False
)
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    device_map="cpu",
    torch_dtype=torch.float16,
)

# Display model information
print(f"Model Parameters: {model.num_parameters():,}")
print(f"Model Architecture: {model.config.architectures[0]}")
print(f"Model Context Length: {model.config.max_position_embeddings}")

Output:

Model Parameters: 1,100,048,384
Model Architecture: LlamaForCausalLM
Model Context Length: 2048

Define Helper Functions for Text Generation

Let’s define a function to generate text using the TinyLlama model with a chat template.

def generate_text(
    messages,
    model,
    tokenizer,
    max_tokens=100,
    temperature=0,
    verbose=True,
    padding=True,
    truncation=True,
):
    """
    Generate text using TinyLlama model with chat template

    Args:
        messages (list): List of message dictionaries with 'role' and 'content'
        model: The language model to use
        tokenizer: The tokenizer to use
        max_tokens (int): Maximum number of tokens to generate
        temperature (float): Sampling temperature (0.0 = deterministic)
        verbose (bool): Whether to print tokenization and generation details
        padding (bool): Whether to pad the encoded prompt
        truncation (bool): Whether to truncate the encoded prompt

    Returns:
        str: Generated text
    """
    # Apply chat template to format messages
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        padding=padding,
        truncation=truncation,
    )

    # Encode the prompt with attention mask
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=padding,
        truncation=truncation,
        return_attention_mask=True,
    )

    if verbose:
        # Print tokenization information
        print("\n\nTokenization Information:")
        text = tokenizer.apply_chat_template(messages, tokenize=False, padding=False)
        tokens = tokenizer.encode(
            text, padding=False, truncation=True, return_tensors="pt"
        )
        print(f"Input sequence length: {tokens.shape[1]} tokens")

        # Create DataFrame for token visualization
        token_data = {
            "token_index": range(len(tokens[0])),
            "token": tokens[0].tolist(),
            "decoded_token": [tokenizer.decode([t]) for t in tokens[0]],
        }
        df = pd.DataFrame(token_data)
        print("\nToken Details:")
        print(df.to_string(index=False))

    # Generate with attention mask
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_tokens,
        do_sample=(temperature > 0),
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode and return the generated text, keeping special tokens
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=False)
    if verbose:
        print(
            f"\nGenerated tokens: {len(outputs[0]) - len(inputs.input_ids[0])} tokens"
        )
        print(f"Total tokens in final sequence: {len(outputs[0])}")
        print(f"\nGenerated Text:\n{decoded_output}")
        print("--------------------------------")
    return decoded_output

Define Example Prompts

Let’s create two different prompts that we’ll use to demonstrate KV caching. Note that the prompts share a common prefix, which will be important for our caching demonstration.

SYSTEM_PROMPT = "You are a helpful assistant."

# First prompt: About the capital of India
msg1 = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {
        "role": "user",
        "content": "What is the capital of India?",
    },
]

# Second prompt: About the capital of France
msg2 = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {
        "role": "user",
        "content": "What is the capital of France? Answer the question in just one word not more than that.",
    },
]

Run Example Generation

Let’s generate a response for the first prompt to see how the model works.

# Run example text generation
generated_text = generate_text(msg1, model, tokenizer, verbose=True)

Output:

Tokenization Information:
Input sequence length: 32 tokens

Token Details:
 token_index  token decoded_token
           0      1           <s>
           1    529             <
           2  29989             |
           3   5205        system
           4  29989             |
           5  29958             >
           6     13            \n
           7   3492           You
           8    526           are
           9    263             a
          10   8444       helpful
          11  20255     assistant
          12  29889             .
          13      2          </s>
          14  29871
          15     13            \n
          16  29966             <
          17  29989             |
          18   1792          user
          19  29989             |
          20  29958             >
          21     13            \n
          22   5618          What
          23    338            is
          24    278           the
          25   7483       capital
          26    310            of
          27   7513         India
          28  29973             ?
          29      2          </s>
          30  29871
          31     13            \n

Generated tokens: 17 tokens
Total tokens in final sequence: 49

Generated Text:
<s> <|system|>
You are a helpful assistant.</s>
<|user|>
What is the capital of India?</s>
<|assistant|>
The capital of India is New Delhi.</s>
--------------------------------

Capture Attention Internals (Q, K, V Projections)

Now let’s capture the Query, Key, and Value projections for both of our prompts. These are the internal matrices used by the attention mechanism in transformers.

# Capture attention internals for both inputs
data_obj1 = capture_model_attention_internals(
    messages=msg1,
    model=model,
    tokenizer=tokenizer,
    padding=True,
    truncation=True,
    return_attention_mask=True,
    verbose=True,
)
data_obj2 = capture_model_attention_internals(
    messages=msg2,
    model=model,
    tokenizer=tokenizer,
    padding=True,
    truncation=True,
    return_attention_mask=True,
    verbose=True,
)

# Save the data for later use or analysis
with open("../data/msg1-qkvo.pkl", "wb") as f:
    pickle.dump(data_obj1, f)

with open("../data/msg2-qkvo.pkl", "wb") as f:
    pickle.dump(data_obj2, f)

Output:

Attention layer shapes:
Q projection: torch.Size([1, 32, 2048])
K projection: torch.Size([1, 32, 256])
V projection: torch.Size([1, 32, 256])
O projection: torch.Size([1, 32, 2048])
Attention layer shapes:
Q projection: torch.Size([1, 44, 2048])
K projection: torch.Size([1, 44, 256])
V projection: torch.Size([1, 44, 256])
O projection: torch.Size([1, 44, 2048])
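
For reference, capture_model_attention_internals (from this repo) works by hooking the attention projection layers. A minimal sketch of how such a capture can be done with PyTorch forward hooks, assuming the standard LLaMA module layout (model.model.layers[i].self_attn.q_proj / k_proj / v_proj / o_proj); the repo's helper adds more bookkeeping:

# Sketch: record the Q/K/V/O projection outputs for every decoder layer
# during a single forward pass (uses the model and tokenizer loaded above).
captured = {"q": [], "k": [], "v": [], "o": []}

def make_hook(name):
    def hook(module, inputs, output):
        captured[name].append(output.detach())
    return hook

handles = []
for layer in model.model.layers:  # LLaMA-style decoder layers
    attn = layer.self_attn
    handles.append(attn.q_proj.register_forward_hook(make_hook("q")))
    handles.append(attn.k_proj.register_forward_hook(make_hook("k")))
    handles.append(attn.v_proj.register_forward_hook(make_hook("v")))
    handles.append(attn.o_proj.register_forward_hook(make_hook("o")))

with torch.no_grad():
    encoded = tokenizer("What is the capital of India?", return_tensors="pt")
    model(**encoded)  # one forward pass populates `captured`

for h in handles:  # always remove hooks when done
    h.remove()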

Visualize Raw Projections (KV Cache Components)

Let’s visualize what’s actually stored in the KV cache: the raw K and V matrices for a specific layer. These are exactly the tensors a model keeps around during generation.

# Select a layer for visualization
layer_to_visualize = 20

# Extract raw Q, K, V matrices for the selected layer
q1 = data_obj1["attention_matrices"]["q"][layer_to_visualize]
k1 = data_obj1["attention_matrices"]["k"][layer_to_visualize]
v1 = data_obj1["attention_matrices"]["v"][layer_to_visualize]

# Convert to 2D format for visualization
q1_2d = torch.squeeze(q1)  # [seq_len, q_dim]
k1_2d = torch.squeeze(k1)  # [seq_len, k_dim]
v1_2d = torch.squeeze(v1)  # [seq_len, v_dim]

print(f"Q1 shape after squeeze: {q1_2d.shape}")
print(f"K1 shape after squeeze: {k1_2d.shape}")
print(f"V1 shape after squeeze: {v1_2d.shape}")

# Visualize raw K matrix (what would be cached)
plot_single_matrix(
    k1_2d,
    matrix_type="K",
    plot_title=f"Raw Key Matrix (Cached in KV Cache) for Layer {layer_to_visualize}",
    cmap="coolwarm",
    tokens=data_obj1["decoded_input_tokens"],
)

# Visualize raw V matrix (what would be cached)
plot_single_matrix(
    v1_2d,
    matrix_type="V",
    plot_title=f"Raw Value Matrix (Cached in KV Cache) for Layer {layer_to_visualize}",
    cmap="coolwarm",
    tokens=data_obj1["decoded_input_tokens"],
)

Output:

Q1 shape after squeeze: torch.Size([32, 2048])
K1 shape after squeeze: torch.Size([32, 256])
V1 shape after squeeze: torch.Size([32, 256])

Computing Attention Patterns

Now let’s compute the attention patterns for both prompts. First, we reshape the Q, K, V matrices to the multi-head format, then compute the attention scores, probabilities, and outputs.

[batch_size, seq_len, hidden_dim] -> [batch_size, num_heads, seq_len, head_dim]

# Extract Q, K, V matrices for a specific layer
layer_to_analyze = 20
q1, k1, v1 = [
    data_obj1["attention_matrices"][key][layer_to_analyze] for key in ["q", "k", "v"]
]
q2, k2, v2 = [
    data_obj2["attention_matrices"][key][layer_to_analyze] for key in ["q", "k", "v"]
]

# Reshape into multi-head format
q1_mh, k1_mh, v1_mh = reshape_llama_attention(q1, k1, v1, verbose=True)
q2_mh, k2_mh, v2_mh = reshape_llama_attention(q2, k2, v2, verbose=False)

# Compute attention separately for each message
attention_msg1 = compute_multihead_attention(q1_mh, k1_mh, v1_mh)
attention_msg2 = compute_multihead_attention(q2_mh, k2_mh, v2_mh)

# Print information about the attention heads
print(f"Grouping of queries:")
for i in list(attention_msg1["heads"].keys())[:5]:  # Just show the first 5
    print(i)
print("... (more heads)")

Output:

After squeeze:
Q shape: torch.Size([32, 2048])
K shape: torch.Size([32, 256])
V shape: torch.Size([32, 256])

Head dimensions:
Q head dim: 64
KV head dim: 32
Num Q heads: 32
Num KV heads: 8

After reshape:
Q shape: torch.Size([32, 32, 64])
K shape: torch.Size([8, 32, 32])
V shape: torch.Size([8, 32, 32])
Grouping of queries:
q_head_0_kv_head_0
q_head_1_kv_head_0
q_head_2_kv_head_0
q_head_3_kv_head_0
q_head_4_kv_head_1
... (more heads)
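
For context, in grouped-query attention (GQA) each key/value head is shared by a group of query heads, which is exactly the q_head_i_kv_head_j grouping printed above. Here is a generic sketch of that computation (standard GQA with equal head dimensions for Q and K/V; not the repo's compute_multihead_attention, whose reshaping differs slightly):

# Generic grouped-query attention: repeat each KV head across its group of
# query heads, then apply causally masked scaled dot-product attention.
import math
import torch

def gqa_attention(q, k, v):
    """q: [n_q_heads, seq, head_dim]; k, v: [n_kv_heads, seq, head_dim]."""
    n_q_heads, seq_len, head_dim = q.shape
    group_size = n_q_heads // k.shape[0]           # query heads per KV head
    k = k.repeat_interleave(group_size, dim=0)     # share KV heads within a group
    v = v.repeat_interleave(group_size, dim=0)
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))  # no attending to the future
    return torch.softmax(scores, dim=-1) @ v             # [n_q_heads, seq, head_dim]

With 32 query heads and 8 KV heads, group_size is 4, which matches the grouping printed above (q_head_0 through q_head_3 all map to kv_head_0).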

Visualizing Attention Patterns

Let’s visualize the attention patterns for each prompt. This shows which tokens are attending to which other tokens.

# Visualize attention patterns
visualize_gqa_attention(attention_msg1, title_prefix="Message 1 -")
visualize_gqa_attention(attention_msg2, title_prefix="Message 2 -")

Output: GQA attention heatmaps for Message 1 and Message 2 (figures not shown here).

Implementing KV Caching

Now we’ll demonstrate how KV caching works. We’ll:

  1. Find the common prefix between our two prompts
  2. Create a hybrid KV cache that reuses cached values for the common prefix
  3. Compare attention patterns with and without caching

The two helper functions below implement steps 1 and 2:

def find_common_prefix_length(tokens1, tokens2):
    """Find the length of common prefix between two token sequences."""
    common_prefix_length = 0
    for i in range(min(len(tokens1), len(tokens2))):
        if tokens1[i] == tokens2[i]:
            common_prefix_length += 1
        else:
            break
    return common_prefix_length


def create_hybrid_kv_cache(q2_mh, k1_mh, k2_mh, v1_mh, v2_mh, common_prefix_length):
    """Create hybrid K,V matrices using cached values for common prefix."""
    if common_prefix_length == 0:
        return k2_mh, v2_mh

    hybrid_k = k2_mh.clone()
    hybrid_v = v2_mh.clone()

    # Only copy the common prefix!
    hybrid_k[:, :common_prefix_length, :] = k1_mh[:, :common_prefix_length, :]
    hybrid_v[:, :common_prefix_length, :] = v1_mh[:, :common_prefix_length, :]

    # Let's add some verification
    print("Verifying cache creation:")
    print(f"Common prefix length: {common_prefix_length}")
    print(f"Total sequence length: {k2_mh.shape[1]}")
    print("For prefix tokens: hybrid_k should match k1_mh")
    print("For non-prefix tokens: hybrid_k should match k2_mh")

    return hybrid_k, hybrid_v

Finding the Common Prefix

Let’s identify the common prefix between our two prompts, which will be the part we can cache.

# Extract token sequences from both prompts
tokens1 = data_obj1["input_tokens"][0].tolist()
tokens2 = data_obj2["input_tokens"][0].tolist()

# Find common prefix length
common_prefix_length = find_common_prefix_length(tokens1, tokens2)
print(f"Common prefix length: {common_prefix_length} tokens")

# Display the common prefix tokens
print("\nCommon Prefix Tokens:")
for i in range(common_prefix_length):
    token = tokens1[i]
    decoded = tokenizer.decode([token])
    print(f"Token {i}: {token} - '{decoded}'")

Output:

Common prefix length: 27 tokens

Common Prefix Tokens:
Token 0: 1 - '<s>'
Token 1: 529 - '<'
Token 2: 29989 - '|'
Token 3: 5205 - 'system'
Token 4: 29989 - '|'
Token 5: 29958 - '>'
Token 6: 13 - '
'
Token 7: 3492 - 'You'
Token 8: 526 - 'are'
Token 9: 263 - 'a'
Token 10: 8444 - 'helpful'
Token 11: 20255 - 'assistant'
Token 12: 29889 - '.'
Token 13: 2 - '</s>'
Token 14: 29871 - ''
Token 15: 13 - '
'
Token 16: 29966 - '<'
Token 17: 29989 - '|'
Token 18: 1792 - 'user'
Token 19: 29989 - '|'
Token 20: 29958 - '>'
Token 21: 13 - '
'
Token 22: 5618 - 'What'
Token 23: 338 - 'is'
Token 24: 278 - 'the'
Token 25: 7483 - 'capital'
Token 26: 310 - 'of'

Creating and Evaluating the Hybrid KV Cache

Now let’s create a hybrid KV cache and compare the attention patterns with and without caching to verify that the results are identical. This is the secret sauce, so watch closely.

# Create hybrid cached version using common prefix
hybrid_k, hybrid_v = create_hybrid_kv_cache(
    q2_mh, k1_mh, k2_mh, v1_mh, v2_mh, common_prefix_length
)

# Compute attention with and without caching
original_attention = compute_multihead_attention(q2_mh, k2_mh, v2_mh)
cached_attention = compute_multihead_attention(q2_mh, hybrid_k, hybrid_v)

# Visualize the difference in attention patterns
max_tokens = min(15, len(tokens2))
diff_attn = plot_attention_matrices(
    original_attention, cached_attention, common_prefix_length, max_tokens
)

# Print statistics and computational savings
print(f"Maximum difference in attention values: {np.max(diff_attn):.8f}")
print(f"Mean difference: {np.mean(diff_attn):.8f}")

if common_prefix_length > 0:
    token_savings = common_prefix_length / len(tokens2)
    compute_savings = (common_prefix_length * len(tokens2)) / (
        len(tokens2) * len(tokens2)
    )
    print(f"By caching KV for the common prefix, you save:")
    print(f" - {token_savings:.1%} of KV computations")
    print(f" - {compute_savings:.1%} of attention computations")

Output:

Verifying cache creation:
Common prefix length: 27
Total sequence length: 44
For prefix tokens: hybrid_k should match k1_mh
For non-prefix tokens: hybrid_k should match k2_mh
Maximum difference in attention values: 0.00000000
Mean difference: 0.00000000
By caching KV for the common prefix, you save:
 - 61.4% of KV computations
 - 61.4% of attention computations
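
(Sanity check: 27 of the 44 prompt tokens are shared, and 27/44 ≈ 61.4%, which matches the printed savings.)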

Visualizing Individual KV Vectors

Finally, let’s visualize the actual K and V vectors for specific tokens to confirm that our caching mechanism is working correctly.

if common_prefix_length > 0:
    # Get common axis limits for consistent visualization
    head_idx = 5
    k_limits, v_limits = get_axis_limits(
        k1_mh, k2_mh, v1_mh, v2_mh, hybrid_k, hybrid_v, head_idx, k1_mh.shape[-1] // 2
    )

    # Calculate maximum sequence length
    max_seq_len = max(k1_mh.shape[1], k2_mh.shape[1])

    # Plot input1's KV values (what gets cached)
    fig1 = plot_kv_cache_verification(
        k1_mh,
        v1_mh,
        common_prefix_length,
        k_limits,
        v_limits,
        max_seq_len,
        input_name="input1",
    )
    plt.show()

    # Plot input2's verification with hybrid cache
    fig2 = plot_hybrid_verification(
        k2_mh,
        v2_mh,
        hybrid_k,
        hybrid_v,
        common_prefix_length,
        k_limits,
        v_limits,
        max_seq_len,
    )
    plt.show()

Output: KV cache verification plots for input1 and the hybrid cache (figures not shown here).

Conclusion

In this notebook, we’ve demonstrated a stripped-down version of how KV caching works in transformer models:

  1. We loaded a model and generated text for two different prompts
  2. We extracted Q, K, V projections from the model’s attention layers
  3. We found the common prefix between the prompts
  4. We created a hybrid KV cache that reuses cached values for the common prefix
  5. We verified that attention patterns are identical with and without caching
  6. We quantified the computational savings from KV caching

The key insight is that K and V values for tokens that remain the same across prompts don’t need to be recomputed. This optimization is crucial for efficient text generation in large language models, especially for long contexts.
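
In practice you rarely build this cache by hand: Hugging Face transformers exposes it via past_key_values, and inference servers implement the same idea as prefix (prompt) caching. As a rough illustration (variable names are illustrative, and the exact cache object returned depends on the transformers version), reusing a shared prefix's cache could look like this:

# Rough sketch: run the shared prefix once, keep its KV cache, then feed only
# the remaining tokens of the second prompt together with that cache.
prompt2 = tokenizer.apply_chat_template(msg2, tokenize=False)
ids2 = tokenizer(prompt2, return_tensors="pt").input_ids

with torch.no_grad():
    prefix_out = model(ids2[:, :common_prefix_length], use_cache=True)
    prefix_cache = prefix_out.past_key_values      # K/V for the shared prefix

    suffix_out = model(
        ids2[:, common_prefix_length:],            # only the new tokens
        past_key_values=prefix_cache,
        use_cache=True,
    )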

Codebase

The entire codebase is available on GitHub.


Written By

Sagar Sarkale