Notebook content¶
- Equation (1): Patch embedding + class token + positional embeddings → z₀
- Equations (2)–(3): Transformer encoder blocks (MSA + MLP) with pre-LN and residuals
- Equation (4): Final LayerNorm on the class token and a linear classification head
- Printed shapes at each stage to match the math
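For reference, the four equations as numbered in the ViT paper ("An Image is Worth 16x16 Words"):

$$
\begin{aligned}
\mathbf{z}_0 &= [\mathbf{x}_\text{class};\ \mathbf{x}_p^1\mathbf{E};\ \mathbf{x}_p^2\mathbf{E};\ \dots;\ \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos}, && \mathbf{E}\in\mathbb{R}^{(P^2\cdot C)\times D},\ \mathbf{E}_{pos}\in\mathbb{R}^{(N+1)\times D} && (1)\\
\mathbf{z}'_\ell &= \mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}, && \ell=1,\dots,L && (2)\\
\mathbf{z}_\ell &= \mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell, && \ell=1,\dots,L && (3)\\
\mathbf{y} &= \mathrm{LN}(\mathbf{z}_L^0) && && (4)
\end{aligned}
$$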
| Equation | Component | This notebook |
|---|---|---|
| (1) | Patch embedding (image → patches → embeddings) | Unfold + Linear, class token, positional embeddings |
| (2) | Multi-Head Self-Attention (MSA) | MultiHeadSelfAttention |
| (3) | Feed-forward network (MLP) | MLP |
| (2)+(3) | Transformer encoder layer | ViTEncoderLayer |
| (4) | Final LayerNorm and classification head | LayerNorm + Linear |
| Symbol | Meaning | Typical Value (ViT-Base) |
|---|---|---|
| B | Batch size — number of images processed together | 1, 8, 32, … |
| H, W | Image height and width (input) | 224 × 224 |
| P | Patch size (each side length in pixels) | 16 |
| C | Number of input channels | 3 (RGB) |
| N | Number of patches per image | (H / P) × (W / P) = 14 × 14 = 196 |
| D | Embedding dimension (hidden size) | 768 |
| h or H_attn | Number of attention heads | 12 |
| Dₕ | Dimension per head | D / h = 768 / 12 = 64 |
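As a quick sanity check, the ViT-Base values in the table follow directly from the base symbols. A minimal sketch (variable names are just for illustration):

```python
# Derive the ViT-Base values listed above.
H = W = 224      # input resolution
P = 16           # patch size
C = 3            # RGB channels
D = 768          # embedding dimension
h = 12           # attention heads

N = (H // P) * (W // P)   # number of patches: 14 * 14 = 196
patch_dim = P * P * C     # flattened patch length P^2 * C: 768
D_h = D // h              # per-head dimension: 64

print(N, patch_dim, D_h)  # 196 768 64
```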
Mini-implementation¶
In [1]:
%pip install torch --quiet
Note: you may need to restart the kernel to use updated packages.
In [2]:
import torch
import torch.nn as nn
import math
Eq. (2) — Multi-Head Self-Attention (MSA) sublayer¶
In [3]:
class MultiHeadSelfAttention(nn.Module):
"""Multi-head self-attention (Eq. 2)"""
def __init__(self, dim, num_heads=12, attn_drop=0.0, proj_drop=0.0):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = 1.0 / math.sqrt(self.head_dim)
self.qkv = nn.Linear(dim, dim * 3)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x): # x: [B, N, D]
B, N, D = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # [3, B, H, N, Dh]
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, H, N, N]
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
out = attn @ v # [B, H, N, Dh]
out = out.transpose(1, 2).reshape(B, N, D) # [B, N, D]
out = self.proj(out)
out = self.proj_drop(out)
return out
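A quick shape check of the MSA sublayer; the sizes are illustrative, chosen to match the MiniViT defaults used later in this notebook (D = 192, 3 heads, 196 patches + 1 class token):

```python
# MSA preserves the token shape: [B, N+1, D] in, [B, N+1, D] out.
msa = MultiHeadSelfAttention(dim=192, num_heads=3)
x = torch.randn(1, 197, 192)
print(msa(x).shape)  # torch.Size([1, 197, 192])
```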
Eq. (3) — MLP sublayer¶
In [4]:
class MLP(nn.Module):
"""Two-layer feedforward block (Eq. 3 part).
"""
def __init__(self, dim, mlp_ratio=4.0, act_layer=nn.GELU, drop=0.0):
super().__init__()
hidden = int(dim * mlp_ratio)
self.fc1 = nn.Linear(dim, hidden)
self.act = act_layer()
self.fc2 = nn.Linear(hidden, dim)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
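The MLP expands each token from D to mlp_ratio · D and projects back, leaving the sequence shape unchanged. A minimal check with the same illustrative sizes:

```python
# Hidden layer is dim * mlp_ratio wide; the output matches the input shape.
mlp = MLP(dim=192, mlp_ratio=4.0)
print(mlp.fc1)       # Linear(in_features=192, out_features=768, bias=True)
x = torch.randn(1, 197, 192)
print(mlp(x).shape)  # torch.Size([1, 197, 192])
```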
Eqs. (2) & (3) — Transformer encoder layer (pre-LN + residual)¶
In [5]:
class ViTEncoderLayer(nn.Module):
"""One transformer encoder block: LN→MSA→residual; LN→MLP→residual."""
def __init__(self, dim, num_heads=12, mlp_ratio=4.0, drop=0.0):
super().__init__()
self.ln1 = nn.LayerNorm(dim) # LN in Eq. (2)
self.msa = MultiHeadSelfAttention(dim, num_heads, proj_drop=drop)
self.ln2 = nn.LayerNorm(dim) # LN in Eq. (3)
self.mlp = MLP(dim, mlp_ratio, drop=drop)
def forward(self, x): # x: [B, N+1, D]
x = x + self.msa(self.ln1(x)) # Eq. (2): z'_ℓ
x = x + self.mlp(self.ln2(x)) # Eq. (3): z_ℓ
return x
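Both residual branches preserve the token shape, so a single encoder block can be checked in isolation (same illustrative sizes as above):

```python
# One pre-LN block: LN→MSA→residual, then LN→MLP→residual; shape is unchanged.
block = ViTEncoderLayer(dim=192, num_heads=3)
x = torch.randn(1, 197, 192)
print(block(x).shape)  # torch.Size([1, 197, 192])
```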
Full MiniViT — mapping Equations (1)–(4)¶
In [6]:
class MiniViT(nn.Module):
"""Minimal ViT mapping directly to Equations (1)–(4)."""
def __init__(self, image_size=224, patch_size=16, in_chans=3,
embed_dim=192, depth=2, num_heads=3, num_classes=10):
super().__init__()
assert image_size % patch_size == 0
self.H = self.W = image_size // patch_size
self.num_patches = self.H * self.W
# Eq. (1): patch projection (x_p^i E), plus class token & positional embeddings
patch_dim = patch_size * patch_size * in_chans # P^2 * C
self.proj = nn.Linear(patch_dim, embed_dim)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
# Eqs. (2)–(3): stack of encoder layers
self.blocks = nn.ModuleList([
ViTEncoderLayer(embed_dim, num_heads=num_heads) for _ in range(depth)
])
# Eq. (4): final LN on class token, then a linear classifier
self.ln = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
# lightweight init
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x):
B = x.size(0)
# ---- Eq. (1): construct z0 ----
patches = self.unfold(x).transpose(1, 2) # [B, N, P^2*C]
tokens = self.proj(patches) # [B, N, D]
cls = self.cls_token.expand(B, -1, -1) # [B, 1, D]
z0 = torch.cat([cls, tokens], dim=1) + self.pos_embed
print("Eq.(1) z0:", z0.shape)
# ---- Eqs. (2) & (3): transformer encoder ----
z = z0
for i, blk in enumerate(self.blocks, 1):
z = blk(z)
print(f"After layer {i}: {z.shape}")
# ---- Eq. (4): y = LN(z_L^0) ----
cls_final = z[:, 0] # [B, D]
y = self.ln(cls_final) # [B, D]
print("Eq.(4) y (pre-head):", y.shape)
logits = self.head(y) # [B, num_classes]
return logits
Run a quick forward pass and print the shapes¶
In [7]:
# Create dummy input and run the model
x = torch.randn(1, 3, 224, 224) # (B, C, H, W)
model = MiniViT(image_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3, num_classes=10)
logits = model(x)
print("Logits:", logits.shape)
Eq.(1) z0: torch.Size([1, 197, 192])
After layer 1: torch.Size([1, 197, 192])
After layer 2: torch.Size([1, 197, 192])
Eq.(4) y (pre-head): torch.Size([1, 192])
Logits: torch.Size([1, 10])