Notebook content¶

  • Equation (1): Patch embedding + class token + positional embeddings → z₀
  • Equations (2)–(3): Transformer encoder blocks (MSA + MLP) with pre-LN and residuals
  • Equation (4): Final LayerNorm on the class token and a linear classification head
  • Printed shapes at each stage to match the math
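
For reference, the four equations implemented below, restated in the paper's notation as used in the code comments (z₀, z'_ℓ, z_ℓ, y):

$$\mathbf{z}_0 = [\mathbf{x}_\text{class};\ \mathbf{x}_p^1\mathbf{E};\ \mathbf{x}_p^2\mathbf{E};\ \dots;\ \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_\text{pos} \tag{1}$$

$$\mathbf{z}'_\ell = \mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}, \quad \ell = 1,\dots,L \tag{2}$$

$$\mathbf{z}_\ell = \mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell, \quad \ell = 1,\dots,L \tag{3}$$

$$\mathbf{y} = \mathrm{LN}(\mathbf{z}_L^0) \tag{4}$$

with $\mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D}$ the patch projection and $\mathbf{E}_\text{pos} \in \mathbb{R}^{(N+1) \times D}$ the positional embeddings.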

Equation components breakdown

| Equation | Component | This notebook |
|---|---|---|
| (1) | Patch embedding (image → patches → embeddings) | Unfold + Linear, class token, positional embeddings |
| (2) | Multi-Head Self-Attention (MSA) | MultiHeadSelfAttention |
| (3) | Feed-forward network (MLP) | MLP |
| (2)+(3) | Transformer encoder layer | ViTEncoderLayer |
| (4) | Final LayerNorm and classification head | LayerNorm + Linear |

Dimension symbol meanings and typical values in ViT-Base

| Symbol | Meaning | Typical value (ViT-Base) |
|---|---|---|
| B | Batch size (number of images processed together) | 1, 8, 32, … |
| H, W | Image height and width (input) | 224 × 224 |
| P | Patch size (side length in pixels) | 16 |
| C | Number of input channels | 3 (RGB) |
| N | Number of patches per image | (H / P) × (W / P) = 14 × 14 = 196 |
| D | Embedding dimension (hidden size) | 768 |
| h (or H_attn) | Number of attention heads | 12 |
| Dₕ | Dimension per head | D / h = 768 / 12 = 64 |
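
The two derived rows (N and Dₕ) follow from simple arithmetic; a quick check with the ViT-Base values (nothing here depends on the model code):

H = W = 224; P = 16; D = 768; h = 12
N = (H // P) * (W // P)   # patches per image
Dh = D // h               # dimension per head
print(N, Dh)              # 196 64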

Mini-implementation¶

In [1]:
%pip install torch --quiet
Note: you may need to restart the kernel to use updated packages.
In [2]:
import torch
import torch.nn as nn
import math

Eq. (2) — Multi-Head Self-Attention (MSA) sublayer¶

In [3]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention (Eq. 2)"""
    def __init__(self, dim, num_heads=12, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):  # x: [B, N, D]
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, H, N, Dh]
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale          # [B, H, N, N]
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = attn @ v                                         # [B, H, N, Dh]
        out = out.transpose(1, 2).reshape(B, N, D)             # [B, N, D]
        out = self.proj(out)
        out = self.proj_drop(out)
        return out
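
A quick shape check of the MSA sublayer in isolation (assuming the cells above have been run; dim=192 and num_heads=3 are illustrative toy values that mirror the MiniViT configuration used later):

msa = MultiHeadSelfAttention(dim=192, num_heads=3)
tokens = torch.randn(2, 197, 192)   # [B, N+1, D]: 196 patch tokens + 1 class token
print(msa(tokens).shape)            # expected: torch.Size([2, 197, 192])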

Eq. (3) — MLP sublayer¶

In [4]:
class MLP(nn.Module):
    """Two-layer feedforward block (Eq. 3 part).
    """
    def __init__(self, dim, mlp_ratio=4.0, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
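
The MLP expands the embedding dimension by mlp_ratio and projects back, so token shapes are preserved. A brief illustrative check with the same toy values:

mlp = MLP(dim=192, mlp_ratio=4.0)
print(mlp.fc1.out_features, mlp.fc2.out_features)   # 768 192 (expand 4x, project back)
print(mlp(torch.randn(2, 197, 192)).shape)          # torch.Size([2, 197, 192])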

Eqs. (2) & (3) — Transformer encoder layer (pre-LN + residual)¶

In [5]:
class ViTEncoderLayer(nn.Module):
    """One transformer encoder block: LN→MSA→residual; LN→MLP→residual."""
    def __init__(self, dim, num_heads=12, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)   # LN in Eq. (2)
        self.msa = MultiHeadSelfAttention(dim, num_heads, proj_drop=drop)
        self.ln2 = nn.LayerNorm(dim)   # LN in Eq. (3)
        self.mlp = MLP(dim, mlp_ratio, drop=drop)

    def forward(self, x):               # x: [B, N+1, D]
        x = x + self.msa(self.ln1(x))   # Eq. (2): z'_ℓ
        x = x + self.mlp(self.ln2(x))   # Eq. (3): z_ℓ
        return x
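
Because both sublayers preserve the [B, N+1, D] shape, one encoder block maps a token sequence to a sequence of the same shape, which is what allows blocks to be stacked. An illustrative check:

block = ViTEncoderLayer(dim=192, num_heads=3)
print(block(torch.randn(2, 197, 192)).shape)   # torch.Size([2, 197, 192])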

Full MiniViT — mapping Equations (1)–(4)¶

In [6]:
class MiniViT(nn.Module):
    """Minimal ViT mapping directly to Equations (1)–(4)."""
    def __init__(self, image_size=224, patch_size=16, in_chans=3,
                 embed_dim=192, depth=2, num_heads=3, num_classes=10):
        super().__init__()
        assert image_size % patch_size == 0
        self.H = self.W = image_size // patch_size   # patch-grid size (patches per side), not image pixels
        self.num_patches = self.H * self.W

        # Eq. (1): patch projection (x_p^i E), plus class token & positional embeddings
        patch_dim = patch_size * patch_size * in_chans  # P^2 * C
        self.proj = nn.Linear(patch_dim, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

        # Eqs. (2)–(3): stack of encoder layers
        self.blocks = nn.ModuleList([
            ViTEncoderLayer(embed_dim, num_heads=num_heads) for _ in range(depth)
        ])

        # Eq. (4): final LN on class token, then a linear classifier
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # lightweight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        B = x.size(0)

        # ---- Eq. (1): construct z0 ----
        patches = self.unfold(x).transpose(1, 2)         # [B, N, P^2*C]
        tokens = self.proj(patches)                      # [B, N, D]
        cls = self.cls_token.expand(B, -1, -1)          # [B, 1, D]
        z0 = torch.cat([cls, tokens], dim=1) + self.pos_embed
        print("Eq.(1) z0:", z0.shape)

        # ---- Eqs. (2) & (3): transformer encoder ----
        z = z0
        for i, blk in enumerate(self.blocks, 1):
            z = blk(z)
            print(f"After layer {i}: {z.shape}")

        # ---- Eq. (4): y = LN(z_L^0) ----
        cls_final = z[:, 0]                # [B, D]
        y = self.ln(cls_final)             # [B, D]
        print("Eq.(4) y (pre-head):", y.shape)
        logits = self.head(y)              # [B, num_classes]
        return logits

Run a quick forward pass and print the shapes¶

In [7]:
# Create dummy input and run the model
x = torch.randn(1, 3, 224, 224) # (B, C, H, W)
model = MiniViT(image_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3, num_classes=10)
logits = model(x)
print("Logits:", logits.shape)
Eq.(1) z0: torch.Size([1, 197, 192])
After layer 1: torch.Size([1, 197, 192])
After layer 2: torch.Size([1, 197, 192])
Eq.(4) y (pre-head): torch.Size([1, 192])
Logits: torch.Size([1, 10])
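
As an implementation note (not part of the paper's equations), the Unfold + Linear patch embedding above should be equivalent to the Conv2d patch embedding with kernel_size = stride = P found in many ViT codebases, since both apply the same linear map to each non-overlapping P×P patch. A minimal check of that claim, reusing model and x from the previous cell:

# Unfold flattens each patch channel-first, matching Conv2d's [out, C, P, P] weight layout,
# so the Linear weights can be copied directly into an equivalent Conv2d.
conv = nn.Conv2d(3, 192, kernel_size=16, stride=16)
with torch.no_grad():
    conv.weight.copy_(model.proj.weight.view(192, 3, 16, 16))
    conv.bias.copy_(model.proj.bias)
    conv_tokens = conv(x).flatten(2).transpose(1, 2)              # [B, N, D]
    unfold_tokens = model.proj(model.unfold(x).transpose(1, 2))   # [B, N, D]
print(torch.allclose(conv_tokens, unfold_tokens, atol=1e-5))      # expected: True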