Vision Transformer (ViT) Notes
This post focuses on understanding the core idea and model architecture of the Vision Transformer (ViT).
Before ViT, convolutional architectures remained dominant in computer vision. Classic Convolutional Neural Networks (CNNs) such as VGG and especially ResNet, with its residual skip connections, were state of the art. In the original ViT paper (Dosovitskiy et al., 2021, An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale), the authors applied transformers to image classification tasks and found that, when trained on large datasets, ViT can match or even surpass the performance of top convolutional networks.
Core idea
From words to patches
ViT processes sequences of image patches just as transformers process sequences of word tokens in natural language processing (NLP). In NLP, a transformer takes a sequence of word embeddings as input and learns contextual relationships among them through self-attention. Similarly, ViT splits an image into fixed-size patches, flattens each patch, linearly embeds the patches into vectors, and feeds the resulting sequence of patch embeddings into a standard transformer encoder.
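The patch-splitting step can be sketched in NumPy. This is a minimal illustration. It assumes a channels-last \(H \times W \times C\) array whose height and width are divisible by the patch size \(P\). The function name `patchify` is hypothetical and is not taken from the paper or from a library. Patches come out in raster order (left to right, top to bottom), and each is flattened to a vector of length \(P^2 \cdot C\).

```python
import numpy as np


def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into non-overlapping P x P patches.

    Returns an array of shape (N, P*P*C) with N = (H/P) * (W/P),
    patches ordered row by row (raster order).
    """
    h, w, c = image.shape
    p = patch_size
    if h % p or w % p:
        raise ValueError("image height and width must be divisible by patch_size")
    # (H, W, C) -> (H/P, P, W/P, P, C): separate the patch-grid index from
    # the pixel position inside each patch, along both spatial axes.
    x = image.reshape(h // p, p, w // p, p, c)
    # Move the two grid axes to the front: (H/P, W/P, P, P, C).
    x = x.transpose(0, 2, 1, 3, 4)
    # Flatten the grid into a sequence and each patch into one vector.
    return x.reshape((h // p) * (w // p), p * p * c)


# Usage: a 224 x 224 RGB image cut into 16 x 16 patches.
img = np.random.rand(224, 224, 3)
patches = patchify(img, 16)
print(patches.shape)  # (196, 768)
```

Each row of `patches` plays the role of one "word" in the sequence that the transformer later consumes.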
Besides this words-vs.-patches analogy, ViT also mirrors how transformers are trained in NLP. Dosovitskiy et al. (2021) note that:
The dominant approach [in NLP] is to pre-train on a large text corpus and then fine-tune on a smaller task-specific dataset.
Large Transformer-based models are often pre-trained on large corpora and then fine-tuned for the task at hand: BERT uses a denoising self-supervised pre-training task, while the GPT line of work uses language modeling as its pre-training task.
ViT follows a similar pattern: it excels when first trained on large image datasets and then fine-tuned for smaller, domain-specific tasks. We will get to this point later in the post when discussing why ViT benefits so much from large-scale pretraining. For a deeper look at pretraining in large language models, see my previous post.
CNNs vs. ViT
In machine learning, inductive bias refers to the built-in assumptions a model makes about data before seeing any examples. Dosovitskiy et al. (2021) note that ViT has much less image-specific inductive bias than CNNs:
… In CNNs, locality, two-dimensional neighborhood structure, and translation equivariance are baked into each layer throughout the whole model. In ViT, only MLP layers are local and translationally equivariant, while the self-attention layers are global.
The table below summarizes how CNNs and ViT differ in their built-in assumptions about images.
| Inductive biases | CNNs | ViTs |
|---|---|---|
| Locality | Convolution filters look at local pixel neighborhoods (small receptive fields). | Locality only inside each patch and in the per-token MLP layers; self-attention lets every patch attend to every other patch globally. |
| 2D neighborhood | Operates directly on 2D image grids; spatial adjacency is hard-coded. | Flattened into 1D sequence of patches; 2D structure not explicit. |
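| Translation equivariance | A shifted input produces correspondingly shifted feature maps (equivariance, not invariance, to object position). | Only the per-token MLP layers are translation-equivariant; positional embeddings start with no 2D information, so spatial relations between patches must be learned from data. |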
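The built-in-assumption contrast can be checked numerically with a toy NumPy sketch. It relies on several simplifications that are not how either model is actually built: a 1D signal stands in for an image, the convolution uses wrap-around (circular) padding so that shifts are exact, and the attention is single-head with identity query/key/value projections.

The sketch shows two things. Shifting the input of a convolution shifts its output. Self-attention without positional embeddings treats tokens as an unordered set: permuting the tokens merely permutes the outputs, so on its own it carries no information about where a patch sits.

```python
import numpy as np

rng = np.random.default_rng(0)


def circular_conv1d(signal: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """1D cross-correlation (the 'convolution' of deep learning) with wrap-around padding."""
    n = len(signal)
    return np.array([
        sum(kernel[j] * signal[(i + j) % n] for j in range(len(kernel)))
        for i in range(n)
    ])


def self_attention(tokens: np.ndarray) -> np.ndarray:
    """Single-head self-attention with identity Q/K/V projections (a simplification)."""
    scores = tokens @ tokens.T / np.sqrt(tokens.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)      # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)    # row-wise softmax
    return weights @ tokens


# Convolution: shifting the input by 3 positions shifts the output by 3 positions.
x = rng.normal(size=12)
kernel = rng.normal(size=3)
conv_equivariant = np.allclose(
    circular_conv1d(np.roll(x, 3), kernel), np.roll(circular_conv1d(x, kernel), 3)
)

# Self-attention without positional embeddings: reordering the tokens only
# reorders the outputs, so token order (patch position) is invisible to it.
tokens = rng.normal(size=(5, 4))
perm = np.array([2, 0, 4, 1, 3])
attn_equivariant = np.allclose(self_attention(tokens[perm]), self_attention(tokens)[perm])

# Adding a fixed embedding per position breaks that symmetry: the same
# content placed at different positions now produces different outputs.
pos = rng.normal(size=(5, 4))
attn_with_pos = np.allclose(
    self_attention(tokens[perm] + pos), self_attention(tokens + pos)[perm]
)

print(conv_equivariant, attn_equivariant, attn_with_pos)
```

The first two checks hold up to floating-point error. The third fails for generic embeddings, which is why ViT depends on (learned) positional embeddings to know where each patch came from.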
Why ViT excels with large datasets
Because ViT does not come with strong image-specific inductive biases, it treats all patches equally and relies on self-attention to decide which regions of an image are important. This helps explain why ViT underperforms comparable CNNs on small datasets but excels on large ones.
When the training dataset is small, CNNs often outperform ViTs because their built-in assumptions help them learn meaningful patterns even with limited data. In contrast, ViT has to learn locality, edges, textures, and other visual structures from scratch, making it more data-hungry and less efficient at small scales. With large datasets, however, ViT can learn spatial and structural patterns directly from data rather than relying on hard-coded assumptions. Its global self-attention becomes a major advantage, enabling the model to capture long-range dependencies across the entire image and ultimately surpass CNNs.
Architecture
In this section, we walk through the ViT model architecture using Figure 1 and Equations (1)-(4) from Dosovitskiy et al. (2021). Figure 1 (below) is an annotated version of the original figure from the paper, showing how an image is divided into patches, embedded, and processed through the transformer encoder. The Google Research blog also provides an animated visualization that illustrates the ViT workflow.
Here is how ViT is defined mathematically:
\[\begin{aligned} \mathbf{z}_0 &= \big[\, x_{\text{class}};\; x_{p}^{1}\,\mathbf{E};\; x_{p}^{2}\,\mathbf{E};\; \dots;\; x_{p}^{N}\,\mathbf{E} \,\big] + E_{\text{pos}}, \quad \mathbf{E}\in \mathbb{R}^{(P^{2}\!\cdot C)\times D},\; E_{\text{pos}}\in \mathbb{R}^{(N+1)\times D} \quad &(1) \\ \mathbf{z}'_{\ell} &= \mathrm{MSA}\!\left(\mathrm{LN}(\mathbf{z}_{\ell-1})\right) + \mathbf{z}_{\ell-1}, \quad \ell=1,\dots,L \quad &(2) \\ \mathbf{z}_{\ell} &= \mathrm{MLP}\!\left(\mathrm{LN}(\mathbf{z}'_{\ell})\right) + \mathbf{z}'_{\ell}, \quad \ell=1,\dots,L \quad &(3) \\ \mathbf{y} &= \mathrm{LN}\!\left(\mathbf{z}^{\,0}_{L}\right) \quad &(4) \end{aligned}\]
Equation (1): embedding
Equation (1) corresponds to the embedding stage in Figure 1. The input image is represented as \(x \in \mathbb{R}^{H \times W \times C}\), where \(H\), \(W\), and \(C\) denote height, width, and number of channels. The image is divided into \(N = HW/P^2\) patches, \(x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}\), where \((P, P)\) is the resolution of each image patch. Each flattened patch vector is multiplied by a learnable projection matrix \(\mathbf{E}\) to get a \(D\)-dimensional embedding. A special learnable classification token \(x_{\text{class}}\) is prepended; its state at the output of the encoder (\(\mathbf{z}^{0}_{L}\)) will later represent the whole image. A learnable positional embedding \(E_{\text{pos}}\) is also added to retain spatial information. Together these form the input sequence \(\mathbf{z}_0\) that is fed to the transformer.
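Equation (1) can be sketched in a few lines of NumPy. This is a minimal illustration under several assumptions. The learnable parameters \(\mathbf{E}\), \(x_{\text{class}}\), and \(E_{\text{pos}}\) are replaced by random stand-ins (in a real model they are trained). The patches are assumed to be already flattened into an \(N \times (P^2 \cdot C)\) array. The helper name `embed` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

N, P, C, D = 196, 16, 3, 768  # ViT-Base/16 sizes on a 224 x 224 RGB image

# Learnable parameters (random stand-ins here; trained in practice).
E = rng.normal(scale=0.02, size=(P * P * C, D))  # patch projection, (P^2*C) x D
x_class = rng.normal(scale=0.02, size=(1, D))    # [class] token embedding
E_pos = rng.normal(scale=0.02, size=(N + 1, D))  # one position embedding per token


def embed(x_p: np.ndarray) -> np.ndarray:
    """Equation (1): z_0 = [x_class; x_p^1 E; ...; x_p^N E] + E_pos."""
    patch_tokens = x_p @ E                                    # (N, D)
    tokens = np.concatenate([x_class, patch_tokens], axis=0)  # (N + 1, D)
    return tokens + E_pos                                     # (N + 1, D)


x_p = rng.random((N, P * P * C))  # flattened patches (random pixels here)
z0 = embed(x_p)
print(z0.shape)  # (197, 768)
```

Row 0 of `z0` is the class token plus its position embedding. Rows 1 to \(N\) are the projected patches plus theirs.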
Understanding the constant \(D\)
Transformers are designed to process sequences where every token has the same dimensionality across all layers. This uniformity makes stacking layers simple: each layer takes a sequence of vectors of shape \([(N+1) \times D]\) (\(N\) patch tokens plus the class token) and outputs another sequence of shape \([(N+1) \times D]\). Here, \(D\) is the embedding dimension (sometimes called the hidden size). Every patch embedding and every intermediate vector inside the transformer has size \(D\).
\(D\) is a hyperparameter. With enough data, a larger \(D\) usually means higher model capacity and better performance — but it also requires more compute and memory. In a nutshell, \(D\) is a design choice controlling how much information each token can represent and how big the transformer is.
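A back-of-the-envelope parameter count shows how strongly \(D\) drives model size. The sketch below rests on the following assumptions:

- the block layout of Equations (2) and (3), with a fused query/key/value projection and an output projection;
- an MLP whose hidden size is \(4D\), as in the paper's ViT-Base configuration (3072 = 4 × 768);
- two LayerNorms per block, with biases throughout;
- a 1000-class linear head.

The function name `vit_param_count` is hypothetical. Each block then costs \(12D^2 + 13D\) parameters, so doubling \(D\) roughly quadruples the size of every layer.

```python
def vit_param_count(D: int, depth: int, patch: int = 16, channels: int = 3,
                    image: int = 224, num_classes: int = 1000, mlp_ratio: int = 4) -> int:
    """Count parameters of a plain ViT under the assumptions stated above."""
    n_patches = (image // patch) ** 2
    hidden = mlp_ratio * D

    patch_embed = patch * patch * channels * D + D   # projection E plus bias
    cls_token = D
    pos_embed = (n_patches + 1) * D

    ln = 2 * D                                       # LayerNorm scale and shift
    qkv = D * 3 * D + 3 * D                          # fused Q, K, V projection
    attn_out = D * D + D                             # output projection
    mlp = D * hidden + hidden + hidden * D + D       # two linear layers
    per_block = ln + qkv + attn_out + ln + mlp       # = 12*D^2 + 13*D when mlp_ratio = 4

    final_ln = 2 * D
    head = D * num_classes + num_classes
    return patch_embed + cls_token + pos_embed + depth * per_block + final_ln + head


for name, D, depth in [("ViT-Base", 768, 12), ("ViT-Large", 1024, 24)]:
    print(f"{name}: D={D}, L={depth}, ~{vit_param_count(D, depth) / 1e6:.1f}M parameters")
```

Under these assumptions ViT-Base comes out at about 86.6M parameters, in line with the 86M the paper lists for ViT-Base. Almost all of it sits in the \(L\) transformer blocks.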
Understanding the trainable linear layer \(\mathbf{E}\)
A raw image patch of size \(P \times P \times C\) is flattened into a vector, but the transformer expects tokens of dimension \(D\). So a trainable linear layer (denoted as \(\mathbf{E}\)) maps flattened patches into the \(D\)-dimensional embedding space. After this step, all tokens (patches + [CLS]) live in the same latent space.
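The projection \(\mathbf{E}\) is one weight matrix shared by every patch, so all patches can be projected with a single matrix multiply. Several open-source implementations compute the same step as a convolution whose kernel size and stride both equal \(P\); timm's patch-embedding module and the Google Research JAX code are examples. The NumPy sketch below checks that the two views agree. The sizes are scaled down and the weights are random, and the naive convolution loop is written out only for the comparison.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W, C, P, D = 8, 8, 3, 4, 5           # small sizes for illustration
image = rng.random((H, W, C))
E = rng.normal(size=(P * P * C, D))     # shared projection, (P^2*C) x D
b = rng.normal(size=D)                  # bias

# 1) Flatten patches in raster order, then one matrix multiply for all of them.
patches = (image.reshape(H // P, P, W // P, P, C)
                .transpose(0, 2, 1, 3, 4)
                .reshape(-1, P * P * C))          # (N, P^2*C)
tokens_linear = patches @ E + b                   # (N, D)

# 2) The same computation as a strided convolution: with kernel = stride = P,
#    the kernel visits each non-overlapping patch exactly once.
kernel = E.reshape(P, P, C, D)                    # (kh, kw, C_in, C_out)
out = np.zeros((H // P, W // P, D))
for i in range(H // P):
    for j in range(W // P):
        window = image[i * P:(i + 1) * P, j * P:(j + 1) * P, :]  # (P, P, C)
        out[i, j] = np.tensordot(window, kernel, axes=3) + b     # contract kh, kw, C
tokens_conv = out.reshape(-1, D)                  # flatten the grid row by row

print(np.allclose(tokens_linear, tokens_conv))    # True
```

Both paths apply the same weights to the same pixels. Only the bookkeeping differs.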
Let’s walk through the logic using the ViT-Base configuration. Suppose the image size is \(224 \times 224 \times 3\), patch size is \(16 \times 16\), and embedding dimension \(D = 768\).
- Split into patches.
  Each patch = \(16 \times 16 \times 3 = 768\) raw pixel values.
  Total number of patches = \((224 / 16)^2 = 14^2 = 196\).
- Flatten each patch.
  Flatten \(16 \times 16 \times 3\) into a vector of length 768.
  Before projection, each patch is a 768-dimensional vector of raw pixel intensities.
- Project to the embedding dimension \(D\).
  Note: although the patch length (768) and \(D\) (768) are identical in this example (ViT-Base), they are not the same thing.
  The flattened patch values are raw pixel values, while \(D\) is the dimension of the learned embedding space in which the transformer operates.
  Even when \(P \times P \times C = D\), we still need a linear projection. Its purpose is not to match dimensions but to learn a mapping from raw pixels to the latent embedding space. ViT applies a trainable linear layer \(\mathbf{E}\):
  \[\mathbf{e} = \mathbf{x}\,\mathbf{W} + \mathbf{b}\]
  where:
  - \(\mathbf{x}\) = flattened patch, treated as a row vector (size \(P^2 C\), here 768)
  - \(\mathbf{W}\) = learnable weight matrix (size \((P^2 C) \times D\)), which plays the role of \(\mathbf{E}\) in Equation (1)
  - \(\mathbf{b}\) = bias term (size \(D\)); Equation (1) writes the projection without it, but a standard linear layer includes one
  - \(\mathbf{e}\) = patch embedding (size \(D\))
- Build the sequence.
  Each of the 196 patches is now projected to a 768-dimensional embedding vector.
  Prepend one learnable [CLS] token (also 768-dimensional), giving a sequence of 197 vectors, each of size 768. Adding the positional embeddings \(E_{\text{pos}}\) (also \(197 \times 768\)) leaves the shape unchanged, so the input to the transformer is a \([197 \times 768]\) matrix.
- Transformer layers.
  Each transformer layer takes a \([197 \times 768]\) input and outputs another \([197 \times 768]\) matrix.
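The arithmetic in these steps carries over to other configurations. The plain-Python sketch below traces each shape for ViT-Base with \(16 \times 16\) patches. `vit_shapes` is a hypothetical helper, and square images are assumed. For contrast, it also traces \(32 \times 32\) patches, as in the paper's ViT-B/32 variant. There the raw patch length (3072) no longer equals \(D = 768\), which makes it clear that \(\mathbf{E}\) maps between two distinct spaces.

```python
def vit_shapes(image_size: int, patch_size: int, channels: int, D: int) -> dict:
    """Trace the shapes produced by the steps above for a square input image."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    grid = image_size // patch_size
    n_patches = grid * grid
    patch_len = patch_size * patch_size * channels
    return {
        "num_patches": n_patches,                  # split into patches
        "flattened_patch": (patch_len,),           # raw pixel values per patch
        "projection_E": (patch_len, D),            # (P^2*C) x D
        "patch_embeddings": (n_patches, D),        # after projection
        "transformer_input": (n_patches + 1, D),   # + [CLS] token (+ E_pos)
        "transformer_output": (n_patches + 1, D),  # shape preserved by each layer
    }


print(vit_shapes(224, 16, 3, 768))  # ViT-Base, 16 x 16 patches
print(vit_shapes(224, 32, 3, 768))  # same D, 32 x 32 patches
```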
Equation (2): multi-head self-attention (MSA)
Equation (2) represents the self-attention sub-layer in each transformer block. Inside each transformer block, the input passes through layer normalization (LN) and then multi-head self-attention (MSA): \(\mathrm{MSA}\!\left(\mathrm{LN}(\mathbf{z}_{\ell-1})\right)\). Each token (patch embedding) attends to all others globally, capturing relationships across the entire image. The residual connection \(\mathbf{z}_{\ell-1}\) (skip arrow in Figure 1) helps stabilize gradient flow during training.
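Equation (2) can be sketched in NumPy as follows. This is a minimal illustration under these assumptions:

- pre-norm LayerNorm without learnable scale/shift, for brevity;
- a fused query/key/value weight matrix with no biases;
- random weights.

Names such as `msa_sublayer` are illustrative and are not taken from a specific library. The sizes follow ViT-Base (197 tokens, \(D = 768\), 12 heads).

```python
import numpy as np

rng = np.random.default_rng(0)


def layer_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize each token over its D features (learnable scale/shift omitted)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def msa(x: np.ndarray, w_qkv: np.ndarray, w_out: np.ndarray, num_heads: int) -> np.ndarray:
    """Multi-head self-attention over a (T, D) token sequence (biases omitted)."""
    T, D = x.shape
    d_head = D // num_heads

    def split_heads(t: np.ndarray) -> np.ndarray:
        return t.reshape(T, num_heads, d_head).transpose(1, 0, 2)  # (heads, T, d_head)

    q, k, v = (split_heads(t) for t in np.split(x @ w_qkv, 3, axis=-1))
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_head))     # (heads, T, T): every token attends to all
    heads = attn @ v                                                # (heads, T, d_head)
    return heads.transpose(1, 0, 2).reshape(T, D) @ w_out           # concatenate heads, project


def msa_sublayer(z_prev: np.ndarray, w_qkv: np.ndarray, w_out: np.ndarray,
                 num_heads: int) -> np.ndarray:
    """Equation (2): z'_l = MSA(LN(z_{l-1})) + z_{l-1}."""
    return msa(layer_norm(z_prev), w_qkv, w_out, num_heads) + z_prev


T, D, num_heads = 197, 768, 12
z_prev = rng.normal(size=(T, D))
w_qkv = rng.normal(scale=0.02, size=(D, 3 * D))
w_out = rng.normal(scale=0.02, size=(D, D))
z_prime = msa_sublayer(z_prev, w_qkv, w_out, num_heads)
print(z_prime.shape)  # (197, 768)
```

With both weight matrices set to zero, the sub-layer returns its input unchanged. This illustrates the residual path: the attention branch only has to learn a correction on top of \(\mathbf{z}_{\ell-1}\).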
Equation (3): feed-forward (MLP)
Equation (3) describes the feed-forward network (or MLP) that follows self-attention. After the attention step, the output is normalized again and passed through a two-layer MLP (with a GELU activation in between): \(\mathrm{MLP}(x) = W_2 \, \mathrm{GELU}(W_1 x + b_1) + b_2\). Another residual connection adds the input back to the output.
Together, Equations (2) and (3) make up one transformer block, which is stacked \(L\) times as shown in Figure 1.
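Equation (3) can likewise be sketched in NumPy. The sketch assumes the following:

- the exact GELU (\(x\,\Phi(x)\)), computed with the standard library's `math.erf`;
- a LayerNorm without learnable scale/shift;
- the ViT-Base hidden size of 3072;
- random weights.

The code uses the row-vector form \(x W_1\) in place of \(W_1 x\). Function names are illustrative.

```python
import math

import numpy as np

rng = np.random.default_rng(0)
_erf = np.vectorize(math.erf)  # elementwise error function from the standard library


def gelu(x: np.ndarray) -> np.ndarray:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)))


def layer_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize each token over its D features (learnable scale/shift omitted)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def mlp(x, W1, b1, W2, b2):
    """Row-vector form of MLP(x) = W2 GELU(W1 x + b1) + b2."""
    return gelu(x @ W1 + b1) @ W2 + b2


def mlp_sublayer(z_prime, params):
    """Equation (3): z_l = MLP(LN(z'_l)) + z'_l. Each token is processed independently."""
    return mlp(layer_norm(z_prime), *params) + z_prime


T, D, hidden = 197, 768, 3072
params = (
    rng.normal(scale=0.02, size=(D, hidden)), np.zeros(hidden),
    rng.normal(scale=0.02, size=(hidden, D)), np.zeros(D),
)
z_prime = rng.normal(size=(T, D))
z_l = mlp_sublayer(z_prime, params)
print(z_l.shape)  # (197, 768)
```

Running the sub-layer on a single token gives the same result as that token's row in the full output. This is the per-token locality of the MLP layers that the paper's inductive-bias quote refers to.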
Equation (4): classification head
Equation (4) corresponds to the output head of the model. After \(L\) transformer blocks, the output at the class-token position, \(\mathbf{z}^{0}_{L}\), which represents the entire image, is extracted. \(\mathbf{z}^{0}_{L}\) is layer-normalized to give \(\mathbf{y}\), which is fed to a classification head. In the paper, this head is an MLP with one hidden layer during pre-training and a single linear layer during fine-tuning. The linear classifier can be viewed as a multinomial (softmax) logistic regression layer: it maps \(\mathbf{y}\) to one logit per class, and a softmax over the logits gives class probabilities.
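Here is a minimal NumPy sketch connecting Equation (4) to the logistic-regression view. It assumes the following:

- the final encoder output is an \((N+1) \times D\) array with the class token in row 0;
- the LayerNorm has no learnable scale/shift;
- a randomly initialized linear head stands in for the trained one.

The 1000-class setting and the function name `classify` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)


def layer_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize over the last axis (learnable scale/shift omitted)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def classify(z_L: np.ndarray, W_head: np.ndarray, b_head: np.ndarray):
    """Equation (4), y = LN(z_L^0), followed by a linear classifier on y."""
    y = layer_norm(z_L[0])                 # class token sits at position 0
    logits = y @ W_head + b_head           # one score per class
    probs = np.exp(logits - logits.max())  # softmax turns logits into probabilities
    probs /= probs.sum()
    return logits, probs


T, D, num_classes = 197, 768, 1000
z_L = rng.normal(size=(T, D))              # output of the last transformer block
W_head = rng.normal(scale=0.02, size=(D, num_classes))
b_head = np.zeros(num_classes)
logits, probs = classify(z_L, W_head, b_head)
print(logits.shape, int(probs.argmax()))
```

Only row 0 of \(\mathbf{z}_L\) reaches the head; the patch tokens influence the prediction solely through attention in earlier layers. Training this head with cross-entropy on the softmax output is multinomial logistic regression on \(\mathbf{y}\).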
Implementation
In this section, we provide a Jupyter notebook that maps the core ViT equations to a minimal PyTorch implementation. It follows the paper’s Equations (1) to (4) and mirrors the structure of open-source ViT implementations (Google Research, timm, Hugging Face).
Notebook: "ViT: From Equations to Code"
References
- Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. Paper
- Google Research ViT: Repo
- timm: PyTorch Image Models (Ross Wightman), Vision Transformer source: Code
- Hugging Face Transformers, Vision Transformer (ViT): Documentation, code