Last updated: Apr 1, 2026 · ~40 min read · Intermediate

Attention Is All You Need to Implement

Part 1 of 4: Scaled Dot-Product & Multi-Head Attention

TL;DR: Attention is differentiable retrieval — every token computes a weighted combination of all other tokens, with weights learned from the data. This article derives scaled dot-product attention from first principles (including the variance proof for why we scale by $\sqrt{d_k}$), builds multi-head attention with explicit shape annotations at every step, implements a causal mask for autoregressive decoding, and adds KV-cache for efficient inference. Full tested implementation at rlvr-from-scratch.

Prerequisites: Basic PyTorch (tensors, nn.Module, nn.Linear). Linear algebra (matrix multiplication, transpose).


The Problem With Sequences

If you’re building a sequence model and your sequences are long, you have a latency problem.

Recurrent networks process tokens one at a time. Information from the first token has to travel through every intermediate hidden state to reach the last token — $O(n)$ sequential operations. Double the sequence length, double the latency. For a 4,096-token context, that’s 4,096 serial steps before the last token knows anything about the first.

This is not just slow. It’s a fundamental architectural bottleneck. The path length between any two tokens grows linearly with distance. Long-range dependencies have to survive compression through hundreds of hidden states, each one lossy.

Attention solves both problems at once:

| Property | RNN | Attention |
| --- | --- | --- |
| Path length between any two tokens | $O(n)$ | $O(1)$ |
| Sequential operations | $O(n)$ | $O(1)$ |
| Computation per layer | $O(n \cdot d^2)$ | $O(n^2 \cdot d)$ |
| Parallelizable | No | Yes |

The tradeoff is quadratic computation in sequence length ($n^2$) versus linear in the RNN case. But $n^2$ parallel operations on a GPU are faster than $n$ sequential operations. For modern hardware, attention wins.

That’s the engineering reason attention became the dominant sequence modeling primitive. Not elegance, not novelty — parallelism and direct information flow.

This is Part 1 of the Transformer Internals series, where I build a complete transformer from scratch in PyTorch, equation by equation, with tests at every layer. The complete implementation lives at rlvr-from-scratch.


Queries, Keys, and Values

The analogy everyone uses is a database. It’s imperfect, but useful as a starting point.

Imagine a key-value store. You have a query — the thing you’re searching for. Each entry has a key — a descriptor of what it contains. And each entry has a value — the actual content returned when matched.

In attention:

  • Query (Q): “What am I looking for?” — derived from the current token
  • Key (K): “What do I contain?” — derived from every token in the sequence
  • Value (V): “What information do I return if selected?” — also derived from every token

Now break the analogy. A database lookup is hard — you match one key exactly and get one value back. Attention is soft — every key contributes to the output, weighted by how well it matches the query. There’s no binary match/no-match. You get a weighted combination of all values, where the weights reflect relevance.

This is the fundamental insight: attention is differentiable retrieval.
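To make “differentiable retrieval” concrete, here’s a minimal numeric sketch (toy values of my own, not from the article’s implementation): one query scored against three key/value pairs. A hard lookup would return only the best match’s value; attention returns a relevance-weighted blend of all of them.

```python
import torch

# Toy setup (illustrative values): one query, three key/value pairs
q = torch.tensor([1.0, 0.0])                # what we're looking for
K = torch.tensor([[0.9, 0.1],               # key 0: close match
                  [0.0, 1.0],               # key 1: orthogonal
                  [0.5, 0.5]])              # key 2: partial match
V = torch.tensor([[10.0], [20.0], [30.0]])  # values to retrieve

scores = K @ q                              # dot-product relevance: (3,)
weights = torch.softmax(scores, dim=0)      # soft "match strength", sums to 1
output = weights @ V                        # weighted blend of ALL values

# A hard lookup would return V[scores.argmax()] = 10.0 exactly;
# the soft output is pulled toward 10.0 but mixes in the other values.
```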

The Linear Projections

We start with an input $X \in \mathbb{R}^{B \times T \times d_\text{model}}$ — a batch of sequences, where each token is a $d_\text{model}$-dimensional vector. We learn three separate projection matrices:

$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$

where $W^Q, W^K \in \mathbb{R}^{d_\text{model} \times d_k}$ and $W^V \in \mathbb{R}^{d_\text{model} \times d_v}$.

Why three separate projections? Because what makes a token a good search target (its key) is not the same as what information it should contribute when found (its value), and neither is the same as what the current token is searching for (its query). The model learns to decouple these three roles.

In practice, $d_k = d_v = d_\text{model} / H$ where $H$ is the number of attention heads. More on that later.

import torch
import torch.nn as nn

d_model = 512
d_k = 64  # query/key dimension (typically d_model / n_heads)
d_v = 64  # value dimension

W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_v, bias=False)

X = torch.randn(2, 10, d_model)  # (B, T, d_model) — example input batch
Q = W_Q(X)  # (B, T, d_model) @ (d_model, d_k) -> (B, T, d_k)
K = W_K(X)  # (B, T, d_model) @ (d_model, d_k) -> (B, T, d_k)
V = W_V(X)  # (B, T, d_model) @ (d_model, d_v) -> (B, T, d_v)

Each token in the sequence now has three representations — one for each role in the retrieval process.

Key Insight: Q, K, V are not three different inputs. They are three learned views of the same input, each optimized for a different role. The model learns what to search for, what to advertise, and what to return — independently.


Scaled Dot-Product Attention

This is the core operation. Five steps, each with a clear mathematical purpose.

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Let me break this apart.

Step 1: Score Every Pair of Tokens

Compute the dot product between every query and every key:

$$\text{scores} = QK^T$$

$$\text{Shape: } (B, T_q, d_k) \times (B, d_k, T_k) \rightarrow (B, T_q, T_k)$$

The result is a $T_q \times T_k$ matrix where entry $(i, j)$ measures how much token $i$'s query aligns with token $j$'s key. High value means high relevance. For self-attention, $T_q = T_k = T$, and you get a $T \times T$ attention matrix — every token scored against every other token.

# Q: (B, T_q, d_k), K: (B, T_k, d_k)
scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, T_q, T_k)

This is where the $O(n^2)$ cost comes from. For a 4,096-token sequence, this matrix has ~16.8 million entries per batch element. That’s the price of letting every token see every other token directly.
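A quick back-of-envelope check on that number (my own arithmetic, matching the claim above), including what one such matrix costs in float32:

```python
T = 4096
entries = T * T            # score-matrix entries per batch element
bytes_fp32 = entries * 4   # float32 storage for one attention matrix

print(entries)             # 16777216
print(bytes_fp32 / 2**20)  # 64.0 (MiB per batch element)
```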

Step 2: Scale by $\sqrt{d_k}$

$$\text{scores} = \frac{QK^T}{\sqrt{d_k}}$$

This is not cosmetic. Let me derive why it’s necessary.

Assume $q$ and $k$ are random vectors in $\mathbb{R}^{d_k}$ with entries independently drawn from $\mathcal{N}(0, 1)$. Their dot product is:

$$q \cdot k = \sum_{i=1}^{d_k} q_i k_i$$

Each term $q_i k_i$ is the product of two independent standard normals. The product of two independent $\mathcal{N}(0, 1)$ variables has:

$$\mathbb{E}[q_i k_i] = \mathbb{E}[q_i]\mathbb{E}[k_i] = 0 \times 0 = 0$$

$$\text{Var}(q_i k_i) = \mathbb{E}[q_i^2 k_i^2] - (\mathbb{E}[q_i k_i])^2 = \mathbb{E}[q_i^2]\mathbb{E}[k_i^2] - 0 = 1 \cdot 1 = 1$$

Since the $d_k$ terms are independent:

$$\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_k$$

So $q \cdot k$ has mean $0$ and standard deviation $\sqrt{d_k}$.

What this means in practice: When $d_k = 512$, dot products have standard deviation $\approx 22.6$. Feed values this large into softmax and you get outputs that are essentially one-hot — one position gets weight $\approx 1.0$, everything else $\approx 0.0$.

Why is this a problem? The gradient of softmax at saturation is near zero. If $\text{softmax}(z)_i \approx 1$, then $\frac{\partial\, \text{softmax}(z)_i}{\partial z_j} \approx 0$ for all $j$. The model can’t learn which tokens to attend to because the gradients vanish.

Dividing by $\sqrt{d_k}$ normalizes the variance back to $1$:

$$\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{\text{Var}(q \cdot k)}{d_k} = \frac{d_k}{d_k} = 1$$

Now softmax operates in a regime where it produces meaningful, non-degenerate distributions with healthy gradients.
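The derivation is easy to verify empirically. The following sketch (my own check, not from the repo) samples random standard-normal vectors, confirms the dot-product standard deviation is roughly $\sqrt{d_k}$, and shows that unscaled scores push softmax toward one-hot outputs:

```python
import torch

torch.manual_seed(0)
d_k, N = 512, 10_000

q = torch.randn(N, d_k)
k = torch.randn(N, d_k)
dots = (q * k).sum(dim=-1)             # N raw dot products

print(dots.std().item())               # ≈ 22.6 = sqrt(512)
print((dots / d_k**0.5).std().item())  # ≈ 1.0 after scaling

# Saturation: softmax over scores with std ~22.6 concentrates far more
# than softmax over the scaled scores
raw = torch.randn(8) * d_k**0.5
print(torch.softmax(raw, dim=0).max().item())             # typically ≈ 1.0
print(torch.softmax(raw / d_k**0.5, dim=0).max().item())  # much smaller
```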

import math

d_k = Q.size(-1)
scores = scores / math.sqrt(d_k)  # (B, T_q, T_k)

Key Insight: The $\sqrt{d_k}$ scaling is not a hyperparameter to tune — it’s derived directly from the variance of dot products. Without it, softmax saturates, gradients vanish, and the model cannot learn attention patterns. This is one of those cases where the math isn’t optional.

Step 3: Apply Mask (Optional)

For autoregressive models, we add a causal mask. For now, I’ll treat this as a simple addition — the next section covers masking in depth.

if mask is not None:
    scores = scores + mask  # additive: 0.0 = allowed, -inf = blocked

Step 4: Softmax

$$\text{weights} = \text{softmax}(\text{scores}, \text{dim}{=}{-1})$$

$$\text{Shape: } (B, T_q, T_k) \rightarrow (B, T_q, T_k)$$

Each row is now a probability distribution over the key positions. Weights are non-negative and sum to 1 along the last dimension. Token $i$'s row tells you exactly how much attention it pays to every other token.

weights = torch.softmax(scores, dim=-1)  # (B, T_q, T_k) — each row sums to 1

Step 5: Weighted Sum of Values

$$\text{output} = \text{weights} \times V$$

$$\text{Shape: } (B, T_q, T_k) \times (B, T_k, d_v) \rightarrow (B, T_q, d_v)$$

Each token’s output is a weighted combination of all value vectors. If token $i$ attends strongly to token $j$, then token $j$'s value contributes heavily to token $i$'s output.

output = torch.matmul(weights, V)  # (B, T_q, d_v)

The Complete Function

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled dot-product attention.

    Args:
        Q: Query tensor  (B, H, T_q, d_k)
        K: Key tensor    (B, H, T_k, d_k)
        V: Value tensor  (B, H, T_k, d_v)
        mask: Additive mask (B|1, 1|H, T_q, T_k)
              0.0 = allowed, -inf = blocked

    Returns:
        output:  (B, H, T_q, d_v)
        weights: (B, H, T_q, T_k)
    """
    d_k = Q.size(-1)

    # =========================================
    # 1. Score: how much does each query match each key?
    # =========================================
    # (B, H, T_q, d_k) @ (B, H, d_k, T_k) -> (B, H, T_q, T_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # =========================================
    # 2. Mask: block positions that shouldn't be attended to
    # =========================================
    if mask is not None:
        scores = scores + mask

    # =========================================
    # 3. Normalize: convert scores to probabilities
    # =========================================
    # (B, H, T_q, T_k) — each row sums to 1
    weights = torch.softmax(scores, dim=-1)

    # =========================================
    # 4. Aggregate: weighted sum of values
    # =========================================
    # (B, H, T_q, T_k) @ (B, H, T_k, d_v) -> (B, H, T_q, d_v)
    output = torch.matmul(weights, V)

    return output, weights

Note the H dimension — in practice, attention always runs inside multi-head attention, batched over both the batch dimension and heads.

Key Insight: Attention is three matrix multiplies and a softmax. That’s it. $QK^T$ computes relevance, $\sqrt{d_k}$ keeps gradients alive, softmax normalizes, and the result retrieves from $V$. Everything else — masking, multiple heads, caching — is engineering on top of this core.


The Causal Mask

Why Masking Matters

In autoregressive (decoder) models, token $i$ must only attend to tokens $\leq i$. During generation, future tokens don’t exist yet — looking at them would be cheating.

Without masking, the model sees the answer while trying to predict it. Training would optimize for a trivial copy operation rather than learning to predict.

Additive Masking

Two conventions exist:

| Convention | Operation | Properties |
| --- | --- | --- |
| Boolean | `scores[mask] = -inf` | In-place mutation, requires boolean tensor |
| Additive | `scores = scores + mask` | Pure addition, composable, broadcastable |

We use additive masking. The mask tensor contains 0.0 for allowed positions and -inf for blocked positions.

After softmax, $e^{-\infty} = 0$ — blocked positions get exactly zero attention weight.

position →  0     1     2     3
token 0  [ 0.0, -inf, -inf, -inf ]   ← can only see itself
token 1  [ 0.0,  0.0, -inf, -inf ]   ← sees token 0 and itself
token 2  [ 0.0,  0.0,  0.0, -inf ]   ← sees 0, 1, and itself
token 3  [ 0.0,  0.0,  0.0,  0.0 ]   ← sees everything up to itself
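A quick numeric check of the $e^{-\infty} = 0$ claim, using token 1’s row from the layout above (the raw scores are made-up values):

```python
import torch

scores = torch.tensor([2.0, 1.0, 3.0, 0.5])                      # raw scores for token 1
mask   = torch.tensor([0.0, 0.0, float("-inf"), float("-inf")])  # block positions 2, 3

weights = torch.softmax(scores + mask, dim=-1)
# Blocked positions get exactly 0.0; the allowed ones renormalize among
# themselves — even though position 2 had the highest raw score.
print(weights)  # tensor([0.7311, 0.2689, 0.0000, 0.0000])
```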

Why additive over boolean?

1. Pure operation — no in-place mutation, cleaner for autograd

2. Composable — multiple masks can be summed together (e.g., causal + padding)

3. Broadcastable — shape (1, 1, T, T) works across any batch and head count

def causal_mask(T: int, device=None):
    """
    Create additive causal mask.

    Returns:
        mask: (1, 1, T, T) — 0.0 for allowed, -inf for blocked.
               Broadcastable over batch and heads.
    """
    # Upper triangle (above diagonal) = True = blocked
    mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
    return mask.float().masked_fill(mask, float("-inf")).unsqueeze(0).unsqueeze(0)

The unsqueeze(0).unsqueeze(0) adds batch and head dimensions for broadcasting: (T, T) → (1, 1, T, T).

Key Insight: The causal mask is not a separate mechanism from attention — it’s just an additive bias on the score matrix. Future positions get $-\infty$, softmax converts that to $0$, and those tokens contribute nothing. Masking and attention are the same computation.


Multi-Head Attention

Why Multiple Heads?

A single attention head computes one set of weights — one notion of “relevance” between tokens. But tokens relate to each other in multiple ways simultaneously.

Consider: “The cat sat on the mat because it was tired.”

  • One head might learn coreference: “it” attends to “cat”
  • Another might learn local context: “it” attends to nearby tokens
  • Another might learn semantic roles: “tired” attends to “sat”

A single head can only learn one of these patterns per layer. Multiple heads learn them in parallel.

The Math

With $H$ heads and model dimension $d_\text{model}$, each head operates on dimension $d_k = d_\text{model} / H$:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_H) W^O$$

where each head is:

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

Projection matrices:

  • $W_i^Q \in \mathbb{R}^{d_\text{model} \times d_k}$ — projects into query space for head $i$
  • $W_i^K \in \mathbb{R}^{d_\text{model} \times d_k}$ — projects into key space for head $i$
  • $W_i^V \in \mathbb{R}^{d_\text{model} \times d_k}$ — projects into value space for head $i$
  • $W^O \in \mathbb{R}^{d_\text{model} \times d_\text{model}}$ — projects concatenated heads back

The total parameter count is the same as single-head attention at dimension $d_\text{model}$. You’re partitioning the same capacity into parallel subspaces.

Implementation: Reshape, Don’t Loop

The naive approach loops over heads. The efficient approach reshapes.

[Diagram: multi-head attention reshape path. Input (B, T, d_model) is projected by W_Q, W_K, W_V; reshaped and transposed to (B, H, T, d_k) (a view, no data copied); run through scaled dot-product attention batched over B and H; then transposed and reshaped back to (B, T, d_model) (.contiguous() required) and projected by W_O.]

The key insight: the “split into heads” is just a reshape. No data is copied. Each head sees a different $d_k$-dimensional slice of the projected representation.

Let me trace the shapes explicitly:

# Concrete example: B=2, T=10, d_model=512, H=8, d_k=64
query = torch.randn(2, 10, 512)      # (B, T, d_model)

# =========================================
# 1. Project to full d_model dimension
# =========================================
Q = W_Q(query)                        # (2, 10, 512) — one Linear layer

# =========================================
# 2. Split into heads: view + transpose
# =========================================
Q = Q.view(2, 10, 8, 64)             # (B, T, H, d_k)
Q = Q.transpose(1, 2)                # (B, H, T, d_k) = (2, 8, 10, 64)

# =========================================
# 3. Attention (batched over B=2 and H=8)
# =========================================
attn_output, weights = scaled_dot_product_attention(Q, K, V, mask)
# attn_output: (2, 8, 10, 64)
# weights:     (2, 8, 10, 10)

# =========================================
# 4. Merge heads: transpose + contiguous + view
# =========================================
attn_output = attn_output.transpose(1, 2)     # (2, 10, 8, 64)
attn_output = attn_output.contiguous()        # required for view()
attn_output = attn_output.view(2, 10, 512)    # (2, 10, 512) = (B, T, d_model)

# =========================================
# 5. Output projection
# =========================================
output = W_O(attn_output)                      # (2, 10, 512)

Why .contiguous()? After transpose(), the tensor’s memory layout is non-contiguous — the strides don’t match what .view() expects. Without .contiguous(), you get a runtime error. This is the kind of thing that costs you 30 minutes of debugging exactly once.
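The failure is easy to reproduce. A small sketch (illustrative, not from the repo) that triggers the error and then fixes it:

```python
import torch

x = torch.randn(2, 8, 10, 64)   # (B, H, T, d_k)
y = x.transpose(1, 2)           # (B, T, H, d_k) — a view with permuted strides

print(y.is_contiguous())        # False

try:
    y.view(2, 10, 512)          # strides are incompatible with view()
except RuntimeError as e:
    print("view() failed:", type(e).__name__)

merged = y.contiguous().view(2, 10, 512)  # copy into contiguous memory first
print(merged.shape)             # torch.Size([2, 10, 512])
```

(.reshape() would make the copy silently when needed; the explicit .contiguous() + .view() makes the copy visible.)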

The Full Module

class MultiHeadAttention(nn.Module):
    """
    Multi-head attention with explicit projections.

    No torch.nn.MultiheadAttention — every operation visible.

    Args:
        d_model: Model dimension.
        n_heads: Number of attention heads. Must divide d_model.
        bias: Whether to use bias in projections.
    """

    def __init__(self, d_model: int, n_heads: int, bias: bool = False):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # =========================================
        # Four learned projections
        # =========================================
        self.W_Q = nn.Linear(d_model, d_model, bias=bias)
        self.W_K = nn.Linear(d_model, d_model, bias=bias)
        self.W_V = nn.Linear(d_model, d_model, bias=bias)
        self.W_O = nn.Linear(d_model, d_model, bias=bias)

    def forward(self, query, key, value, mask=None, kv_cache=None):
        """
        Args:
            query: (B, T_q, d_model)
            key:   (B, T_k, d_model)
            value: (B, T_k, d_model)
            mask:  Additive mask (B|1, 1|H, T_q, T_k)
            kv_cache: Optional (K, V) from previous steps,
                      each (B, H, T_prev, d_k)

        Returns:
            output:       (B, T_q, d_model)
            weights:      (B, H, T_q, T_k)
            new_kv_cache: Updated (K, V) or None
        """
        B, T_q, _ = query.shape

        # =========================================
        # 1. Project
        # =========================================
        Q = self.W_Q(query)   # (B, T_q, d_model)
        K = self.W_K(key)     # (B, T_k, d_model)
        V = self.W_V(value)   # (B, T_k, d_model)

        # =========================================
        # 2. Split heads
        # =========================================
        Q = self._split_heads(Q)   # (B, H, T_q, d_k)
        K = self._split_heads(K)   # (B, H, T_k, d_k)
        V = self._split_heads(V)   # (B, H, T_k, d_k)

        # =========================================
        # 3. KV-cache (for incremental decoding)
        # =========================================
        new_kv_cache = None
        if kv_cache is not None:
            K_prev, V_prev = kv_cache
            K = torch.cat([K_prev, K], dim=2)  # (B, H, T_prev+T_k, d_k)
            V = torch.cat([V_prev, V], dim=2)
            new_kv_cache = (K, V)

        # =========================================
        # 4. Attention
        # =========================================
        attn_output, weights = scaled_dot_product_attention(Q, K, V, mask)

        # =========================================
        # 5. Merge heads + output projection
        # =========================================
        attn_output = self._merge_heads(attn_output)  # (B, T_q, d_model)
        output = self.W_O(attn_output)                # (B, T_q, d_model)

        return output, weights, new_kv_cache

    def _split_heads(self, x):
        """(B, T, d_model) -> (B, H, T, d_k)"""
        B, T, _ = x.shape
        return x.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

    def _merge_heads(self, x):
        """(B, H, T, d_k) -> (B, T, d_model)"""
        B, _, T, _ = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, self.d_model)

Parameter Count

Multi-head attention has exactly four weight matrices:

| Parameter | Shape | Count |
| --- | --- | --- |
| $W^Q$ | $(d_\text{model}, d_\text{model})$ | $d_\text{model}^2$ |
| $W^K$ | $(d_\text{model}, d_\text{model})$ | $d_\text{model}^2$ |
| $W^V$ | $(d_\text{model}, d_\text{model})$ | $d_\text{model}^2$ |
| $W^O$ | $(d_\text{model}, d_\text{model})$ | $d_\text{model}^2$ |
| **Total** | | $4 \cdot d_\text{model}^2$ |

For $d_\text{model} = 512$: $4 \times 512^2 = 1{,}048{,}576$ parameters. The number of heads doesn’t change this — you’re partitioning the same total dimension.
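That count can be verified directly against the four projections (bias=False, matching the module above):

```python
import torch.nn as nn

d_model = 512
# W_Q, W_K, W_V, W_O — the module's four weight matrices
projections = [nn.Linear(d_model, d_model, bias=False) for _ in range(4)]

total = sum(p.numel() for layer in projections for p in layer.parameters())
print(total)  # 1048576
```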

Key Insight: Multi-head attention doesn’t add parameters compared to single-head attention at the same dimension. It partitions the same capacity into parallel subspaces. The model learns to use each head for a different type of relationship — syntax, semantics, position — without any explicit supervision telling it to do so.


Self-Attention vs Cross-Attention

Same mechanism, two modes:

[Diagram: self-attention derives Q, K, V all from the same input X; cross-attention derives Q from the decoder state and K, V from the encoder output.]

Self-attention: Q, K, V all come from the same sequence. Each token attends to every other token in the same input. Used in both encoders and decoders.

# Self-attention: same input for all three
output, weights, _ = mha(x, x, x, mask=causal_mask(T))

Cross-attention: Q comes from one sequence (decoder), K and V come from another (encoder output). The decoder queries the encoder’s representation. Used in encoder-decoder models for translation, summarization, and similar tasks.

# Cross-attention: decoder queries, encoder keys/values
output, weights, _ = mha(decoder_state, encoder_output, encoder_output)

Our MultiHeadAttention handles both — the query, key, value arguments are deliberately separate. For self-attention, pass the same tensor for all three. For cross-attention, pass different tensors.


KV-Cache: Making Generation Fast

The Problem

During training, we process the full sequence at once — one forward pass, all positions in parallel. During generation, we decode one token at a time.

At step $t$, the new token needs to attend to all $t-1$ previous tokens plus itself. Without caching, this means recomputing K and V projections for all previous tokens at every step. For a sequence of length $T$, total projection computation scales as:

$$\sum_{t=1}^{T} t = \frac{T(T+1)}{2} = O(T^2)$$

That’s $O(T^2)$ work just for the linear projections — before we even get to attention.
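The arithmetic is a one-liner to confirm:

```python
T = 4096
# tokens whose K/V projections get recomputed, summed over all T decode steps
recomputed = sum(range(1, T + 1))
print(recomputed)        # 8390656
print(T * (T + 1) // 2)  # 8390656 — matches the closed form
```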

The Solution

Cache the K and V tensors. At each step:

  1. Compute K and V for only the new token — $O(1)$
  2. Concatenate with the cached K, V from all previous steps
  3. Compute attention using the full cache but only the new Q
# =========================================
# Step t: process one new token
# =========================================
# new_token: (B, 1, d_model); project, then split heads as in MultiHeadAttention
new_Q = split_heads(W_Q(new_token))  # (B, H, 1, d_k) — one token
new_K = split_heads(W_K(new_token))  # (B, H, 1, d_k)
new_V = split_heads(W_V(new_token))  # (B, H, 1, d_k)

# Append to cache
K_cache = torch.cat([K_cache, new_K], dim=2)  # (B, H, t, d_k)
V_cache = torch.cat([V_cache, new_V], dim=2)  # (B, H, t, d_k)

# Attention: (B, H, 1, d_k) against (B, H, t, d_k)
output, _ = scaled_dot_product_attention(new_Q, K_cache, V_cache)
# output: (B, H, 1, d_k) — one token's representation

Per-step projection cost drops from $O(T \cdot d^2)$ to $O(d^2)$. The attention computation itself is still $O(T \cdot d)$ per step — you can’t avoid looking at all previous tokens.

The Correctness Invariant

This is the most important property of a KV-cache implementation:

Incremental decoding with KV-cache must produce the exact same output as a full forward pass with a causal mask.

If it doesn’t, your cache is wrong. Token $t$'s output should be identical whether you compute it as part of a full batch or incrementally with cached K, V from steps $0$ through $t-1$.

We test this explicitly:

def test_kv_cache_matches_full_pass(mha):
    """Cached incremental decoding must match full-sequence result."""
    mha.eval()
    seq = torch.randn(B, T, D_MODEL)
    mask = causal_mask(T)

    # Full pass (ground truth)
    with torch.no_grad():
        full_output, _, _ = mha(seq, seq, seq, mask=mask)

    # Incremental pass: token by token with KV-cache
    cache = (torch.empty(B, H, 0, D_K), torch.empty(B, H, 0, D_K))
    incremental_outputs = []
    with torch.no_grad():
        for t in range(T):
            token = seq[:, t:t+1, :]  # (B, 1, d_model)
            out, _, cache = mha(token, token, token, kv_cache=cache)
            incremental_outputs.append(out)

    incremental_output = torch.cat(incremental_outputs, dim=1)

    # These must match
    torch.testing.assert_close(full_output, incremental_output, atol=1e-5, rtol=1e-5)

Key Insight: KV-cache trades memory for time. You store all previous keys and values (memory grows linearly with $T$) but avoid recomputing them (projection cost per step drops from $O(T)$ to $O(1)$). For long sequences, this is the difference between practical and impractical generation speeds.


Full Implementation

The complete, tested implementation lives at src/rlvr_from_scratch/model/attention.py.

What’s in the module

ComponentWhat it doesParameters
scaled_dot_product_attentionCore: score, scale, mask, softmax, aggregateNone (pure function)
causal_maskPrevents attending to future tokensNone (deterministic)
MultiHeadAttentionProjections, head splitting, attention, mergingWQ,WK,WV,WOW^Q, W^K, W^V, W^O
KV-cache supportIncremental decoding without recomputationNone (caches K, V tensors)

Test Coverage

The test suite at tests/model/test_attention.py covers 24 tests across three categories:

Correctness:

  • Output shapes for all configurations (H=1, 2, 4, 8)
  • Attention weights sum to 1
  • Causal mask blocks all future positions
  • Causal mask allows all past positions and self
  • Cross-attention with different Q/K lengths
  • KV-cache matches full forward pass

Robustness:

  • Numerical stability with large $d_k$ (1024)
  • Batch independence (each element processed identically)
  • Determinism (same input → same output)

Training:

  • Gradient flow through Q, K, V
  • Gradient flow through all MHA parameters
  • Invalid configuration raises ValueError

Key Takeaways

The Core Operation

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Three matrix multiplies and a softmax. Everything else is engineering.

Design Choices

  • Masking convention: Additive (0.0 / -inf) — composable, pure, broadcastable
  • Head splitting: Reshape, not loop — same computation, GPU-friendly
  • Bias in projections: Off by default — modern standard (GPT-2+, LLaMA)
  • KV-cache: Concatenation-based — simple, correct, testable

What’s Next

Attention is permutation equivariant — shuffle the input tokens and you get the same output (modulo the shuffling). The model has no sense of order. Token 0 and token 99 are treated identically.
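Permutation equivariance is directly testable. A minimal single-head sketch (a toy check of my own: no mask, no positions, fixed random weights) showing that shuffling input rows just shuffles the output rows:

```python
import torch

torch.manual_seed(0)
T, d = 6, 16
X = torch.randn(T, d)
W_q, W_k, W_v = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)

def attn(X):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    weights = torch.softmax(Q @ K.T / d ** 0.5, dim=-1)
    return weights @ V

perm = torch.randperm(T)
out = attn(X)
out_shuffled = attn(X[perm])

# Same outputs, just reordered — attention has no notion of position
torch.testing.assert_close(out_shuffled, out[perm])
```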

In Part 2: Positional Encoding, I build sinusoidal, learned, and rotary position embeddings from scratch, derive the rotation matrix formulation of RoPE, and show why RoPE won.

After that, Part 3 assembles the full transformer block (attention + FFN + normalization + residuals), and Part 4 builds the training loop with AdamW and cosine warmup from scratch.



Cite this reference

Sousa, V. (2026). Attention Is All You Need to Implement. vitorsousa.com (Foundation Reference). https://www.vitorsousa.com/foundations//

@article{sousa2026,
  title={Attention Is All You Need to Implement},
  author={Sousa, Vitor},
  year={2026},
  note={Foundation Reference},
  url={https://www.vitorsousa.com/foundations//}
}
