
Understanding Self-Attention: A Visual Walkthrough

A visual and mathematical deep-dive into the self-attention mechanism powering modern transformers.


Self-attention is the operation that gives transformers their ability to relate every token in a sequence to every other token. Let's build intuition from scratch.

The Core Idea

Given an input sequence of tokens, self-attention computes, for each token, a weighted sum over every token in the sequence (including itself), with weights that reflect how relevant each token is to the one being updated. This is fundamentally different from recurrence: all positions are processed in parallel.

Math

Each input vector $\mathbf{x}_i \in \mathbb{R}^d$ is projected into three spaces:

$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$
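
Here $X \in \mathbb{R}^{n \times d}$ stacks the $n$ token vectors as rows, the learned projections $W^Q, W^K \in \mathbb{R}^{d \times d_k}$ and $W^V \in \mathbb{R}^{d \times d_v}$ produce the queries, keys, and values, and $QK^\top$ is the $n \times n$ matrix of pairwise relevance scores.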

The attention output is then:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

The $\sqrt{d_k}$ scaling prevents the dot products from growing large and pushing the softmax into saturation.
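
To make the scaling concrete, here is a quick numerical sketch (illustrative only; the specific numbers are not from the derivation above): for query and key components with zero mean and unit variance, the raw dot products have a standard deviation of roughly $\sqrt{d_k}$, which the division removes.

import numpy as np

# Dot products of unit-variance random vectors have std ~ sqrt(d_k)
rng = np.random.default_rng(0)
d_k = 64
q = rng.normal(size=(10_000, d_k))
k = rng.normal(size=(10_000, d_k))
dots = (q * k).sum(axis=1)

print(dots.std())                   # roughly sqrt(64) = 8
print((dots / np.sqrt(d_k)).std())  # roughly 1 after scaling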

Complexity

A key cost of self-attention is its $O(n^2)$ time and memory in sequence length $n$. For a sequence of length 1024 with $d = 512$:

  • Attention matrix: $1024 \times 1024 \approx 1$M entries per head
  • With 8 heads: $\approx 8$M entries (see the memory sketch just below)
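
In concrete terms, here is a rough back-of-the-envelope sketch (assuming float32 attention weights, a detail not specified above):

# Memory for one layer's attention weight matrices (float32 assumed)
n, n_heads, bytes_per_float = 1024, 8, 4
attn_bytes = n_heads * n * n * bytes_per_float
print(attn_bytes / 2**20)  # 32.0 MiB, growing quadratically with n

At 100k tokens the same matrices would be roughly ten thousand times larger.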

This is why long-context models (100k+ tokens) invest heavily in sparse attention variants.

Code Example

Here's a minimal NumPy implementation of scaled dot-product attention:

import numpy as np
 
def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)
 
def attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)  # (seq, seq)
    weights = softmax(scores)          # row-wise
    return weights @ V                 # (seq, d_v)

Subtracting x.max() inside softmax is a numerical stability trick: softmax is invariant to a constant shift of its inputs, so the output is unchanged, but the shift keeps exp from overflowing.
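
As a quick usage check (the shapes here are arbitrary, chosen just for illustration), the function can be run on random matrices standing in for projected tokens:

# Toy usage: 4 tokens, d_k = d_v = 8 (arbitrary sizes for illustration)
rng = np.random.default_rng(0)
Q = rng.normal(size=(4, 8))
K = rng.normal(size=(4, 8))
V = rng.normal(size=(4, 8))

out = attention(Q, K, V)
print(out.shape)                                  # (4, 8)
print(softmax(Q @ K.T / np.sqrt(8)).sum(axis=1))  # each row of weights sums to 1

Each row of the weight matrix sums to 1, so every output vector is a convex combination of the value vectors.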

Multi-Head Attention

Running $h$ attention heads in parallel and concatenating their outputs lets the model attend to information from different representation subspaces:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$$

where $\text{head}_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i)$.

Each head operates on a $d/h$-dimensional subspace, so the total computation stays comparable to a single full-dimensional attention pass.
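
Here is a minimal sketch of how this might look on top of the earlier snippet (the function and parameter names are illustrative, the softmax helper from above is reused, and the projection matrices are assumed to be passed in as plain NumPy arrays rather than learned):

def multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads):
    # X: (seq, d_model); W_q, W_k, W_v, W_o: (d_model, d_model)
    seq_len, d_model = X.shape
    d_head = d_model // n_heads  # assumes d_model is divisible by n_heads

    def split_heads(M):
        # (seq, d_model) -> (n_heads, seq, d_head)
        return M.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split_heads(X @ W_q), split_heads(X @ W_k), split_heads(X @ W_v)

    # Scaled dot-product attention, batched over the head axis
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # (heads, seq, seq)
    heads = softmax(scores) @ V                           # (heads, seq, d_head)

    # Concatenate heads back to (seq, d_model) and apply the output projection
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_o

Splitting the model dimension with a reshape and transpose, rather than running $h$ separate attention calls, keeps the total cost the same while letting NumPy batch the matrix multiplies over the head axis.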

Further Reading

  • Attention Is All You Need — the original transformer paper
  • Flash Attention: IO-aware exact attention that fuses the softmax and matmul kernels so the full attention matrix is never materialized in GPU memory