
A beginner-friendly exploration of Vision Transformers (ViT), implemented in PyTorch.


Bengal1/Simple-ViT


Simple ViT

This repository showcases the Vision Transformer (ViT), highlighting its core architectural ideas and training dynamics. ViTs mark a turning point in computer vision: they use self-attention to capture global relationships in images without relying on convolutional filters. The project provides an accessible yet thorough presentation of how ViTs operate, their advantages and limitations, and how they achieve state-of-the-art results on large-scale image recognition tasks. It also includes a comparison with Convolutional Neural Networks (CNNs) to illustrate their differences and relative strengths.

For more information about the Transformer model, I recommend Simple Transformer.

Requirements

  • Python
  • PyTorch

Vision Transformer

ViT Architecture

The Vision Transformer (ViT) is a deep learning architecture that adapts the Transformer, originally developed for natural language processing, to image recognition tasks. Introduced by Dosovitskiy et al. in “An Image is Worth 16x16 Words” (2020), ViT replaces traditional convolutional feature extractors with a sequence of image patches processed by self-attention. This approach demonstrated that, with sufficient data and compute, Transformers can outperform convolutional neural networks (CNNs) in computer vision benchmarks, paving the way for a broad family of vision transformer models.

In practice, ViT transforms an image into a sequence of smaller patches, which are then processed using the same self-attention mechanism that made Transformers successful in language tasks. By modeling relationships between patches directly, ViT captures both local details and long-range dependencies within an image, offering a flexible alternative to the strictly hierarchical representations of CNNs. Positional information is added to preserve the spatial arrangement of the patches, and a dedicated learnable [CLS] token provides the representation used for classification. This design shifts the focus from handcrafted inductive biases toward a more data-driven approach, where the model learns to interpret visual structure primarily from large-scale training data.

Patch Embedding


A key step in the Vision Transformer (ViT) is the patch embedding stage, which transforms an image into a sequence suitable for a Transformer. Instead of processing pixels individually or relying on convolutional filters, the input image is divided into fixed-size patches (for example, 16×16 pixels). Each patch is then flattened into a vector and projected through a linear layer to a chosen embedding space. The result is a sequence of patch embeddings that can be treated similarly to word tokens in natural language processing, allowing the Transformer to apply self-attention mechanisms across the entire image.
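The patch embedding step can be sketched in PyTorch as follows. This is a minimal sketch, not necessarily how this repository implements it: the class name `PatchEmbedding` and the default sizes (224×224 RGB input, 16×16 patches, 768-dimensional embeddings) are illustrative assumptions. It reshapes the image into non-overlapping P×P patches, flattens each one, and applies a single linear projection shared by all patches.

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Hypothetical patch-embedding module: (B, C, H, W) -> (B, N, D)."""

    def __init__(self, in_channels: int = 3, patch_size: int = 16, embed_dim: int = 768):
        super().__init__()
        self.patch_size = patch_size
        # One linear layer shared by all patches: C*P*P inputs -> D outputs
        self.proj = nn.Linear(in_channels * patch_size * patch_size, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        P = self.patch_size
        if H % P != 0 or W % P != 0:
            raise ValueError("image height and width must be divisible by patch_size")
        # Split H and W into (patches along the axis, P): (B, C, H/P, P, W/P, P)
        x = x.reshape(B, C, H // P, P, W // P, P)
        # Bring the patch-grid axes forward: (B, H/P, W/P, C, P, P)
        x = x.permute(0, 2, 4, 1, 3, 5)
        # Flatten every patch into one vector, row-major over the patch grid: (B, N, C*P*P)
        x = x.reshape(B, (H // P) * (W // P), C * P * P)
        return self.proj(x)  # (B, N, D)


embed = PatchEmbedding(in_channels=3, patch_size=16, embed_dim=768)
tokens = embed(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 196, 768])
```

A 224×224 image with 16×16 patches yields a 14×14 grid, i.e. 196 tokens. Many implementations instead use `nn.Conv2d` with `kernel_size` and `stride` both equal to the patch size, which computes the same kind of per-patch linear projection in a single call.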

CLS Token

The [CLS] token is a learnable embedding prepended to the sequence of patch embeddings in a Vision Transformer. Its primary purpose is to serve as a global representation of the entire image. During the forward pass, the Transformer encoder processes the sequence of patch embeddings along with the [CLS] token, allowing the self-attention mechanism to integrate information from all patches into this special token. After the final encoder block, the [CLS] token contains a summary of the image’s content and is typically fed into the classification head to produce the output logits. By using the [CLS] token in this way, ViT can perform classification from a single learned representation rather than pooling (for example, averaging) the final patch embeddings.
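Prepending the token comes down to one shared parameter that is expanded across the batch and concatenated along the sequence dimension. The sketch below assumes the hypothetical class name `PrependCLSToken`; the truncated-normal initialization with standard deviation 0.02 is a common choice, not something this README specifies.

```python
import torch
import torch.nn as nn


class PrependCLSToken(nn.Module):
    """Hypothetical helper: prepend one learnable [CLS] vector to every sequence."""

    def __init__(self, embed_dim: int):
        super().__init__()
        # A single (1, 1, D) parameter shared by all images, learned during training
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)  # assumed init scheme

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        B = patches.shape[0]
        cls = self.cls_token.expand(B, -1, -1)    # (B, 1, D), a view, no copy
        return torch.cat([cls, patches], dim=1)   # (B, N + 1, D)


prepend = PrependCLSToken(embed_dim=768)
patches = torch.randn(2, 196, 768)
seq = prepend(patches)
print(seq.shape)  # torch.Size([2, 197, 768])
# After the encoder, the classification head reads only position 0: seq[:, 0]
```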

Positional Encoding

Since Transformers process input sequences without any inherent notion of order, it is necessary to provide information about the position of each patch in the image. In the Vision Transformer, this is achieved through positional encoding, which adds a vector to each patch embedding to indicate its location within the image. Unlike the fixed sinusoidal encodings used in the original Transformer for NLP, ViT often uses learnable positional embeddings, which are initialized randomly and updated during training. These learnable embeddings allow the model to adaptively encode spatial relationships between patches, helping the self-attention mechanism capture both local and global structure in the image.
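A learnable positional embedding is simply a trainable table with one vector per token position, added element-wise to the token embeddings. In this sketch the class name `LearnablePositionalEmbedding`, the initialization, and the sizes (196 patches plus one [CLS] token) are illustrative assumptions.

```python
import torch
import torch.nn as nn


class LearnablePositionalEmbedding(nn.Module):
    """Hypothetical module: one learnable D-dim vector per token position."""

    def __init__(self, num_tokens: int, embed_dim: int):
        super().__init__()
        # Randomly initialized, then updated by backpropagation like any other weight
        self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)  # assumed init scheme

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, num_tokens, D); the (1, num_tokens, D) table broadcasts over the batch
        return x + self.pos_embed


pos = LearnablePositionalEmbedding(num_tokens=197, embed_dim=768)  # 196 patches + [CLS]
x = torch.randn(2, 197, 768)
y = pos(x)
print(y.shape)  # torch.Size([2, 197, 768])
```

Because the table is added rather than concatenated, the embedding dimension stays D, and every position's vector receives gradients from all images in the batch.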

Transformer Encoder


The Transformer encoder is a fundamental component of the Vision Transformer (ViT), responsible for processing the sequence of patch embeddings and capturing relationships between them. Each encoder block contains a multi-head self-attention layer, which allows the model to weigh the importance of each patch relative to all others, followed by a feed-forward network (MLP) that transforms the representations. Residual connections and layer normalization are applied throughout to stabilize training and improve gradient flow: in ViT, layer normalization is applied before each of the two sub-layers (pre-norm), and a residual connection adds the sub-layer's input to its output. By stacking multiple encoder blocks, the Transformer encoder can build complex, high-level representations of the image, integrating both local and global information for downstream tasks such as classification.
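One pre-norm encoder block can be sketched with PyTorch's built-in `nn.MultiheadAttention`. The class name `EncoderBlock`, the small sizes (64-dimensional tokens, 4 heads, MLP ratio 4), and the depth of 6 are illustrative assumptions, and dropout is omitted for brevity.

```python
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """Hypothetical pre-norm ViT block: x + MSA(LN(x)), then x + MLP(LN(x))."""

    def __init__(self, embed_dim: int = 64, num_heads: int = 4, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Sub-layer 1: multi-head self-attention (queries, keys, values all come from x)
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out                  # residual connection
        # Sub-layer 2: token-wise feed-forward network
        x = x + self.mlp(self.norm2(x))   # residual connection
        return x


encoder = nn.Sequential(*[EncoderBlock(embed_dim=64, num_heads=4) for _ in range(6)])
tokens = torch.randn(2, 65, 64)  # e.g. 64 patches + [CLS]
print(encoder(tokens).shape)  # torch.Size([2, 65, 64])
```

Each block maps (B, N + 1, D) to the same shape, which is what makes stacking them straightforward.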

Attention

Attention is a core mechanism in Transformers that allows the model to selectively focus on the most relevant parts of an input sequence when making predictions. Instead of processing information uniformly, attention assigns weights to different elements, enabling the network to capture both local and long-range dependencies. In the context of Vision Transformers (ViTs), self-attention is applied directly to image patches, treating them as a sequence of tokens similar to words in natural language processing. This mechanism allows each patch to attend to every other patch, capturing global spatial relationships across the image. Unlike convolutional operations, which have a fixed receptive field, self-attention provides a flexible and adaptive way of modeling dependencies, making it particularly powerful for understanding complex visual structures. The Vision Transformer uses Multi-Head Self-Attention (MHSA). Given an input sequence of tokens (patch embeddings) $X∈ℝ^{N×D}$, where $N$ is the sequence length (the number of patches, plus one when a [CLS] token is prepended) and $D$ is the embedding dimension, self-attention computes interactions between all tokens as follows:

  1. Linear projections for queries, keys, and values:

$$ Q = X·W_{Q} \quad;\quad K = X·W_{K} \quad;\quad V = X·W_{V} $$

where $W_{Q}, W_{K}, W_{V} ∈ℝ^{D×d}$ are learnable weight matrices, and $d$ is the attention head dimension.

  2. Scaled dot-product attention:
$$\text{Attention}(Q,K,V) = \text{Softmax} \Bigg(\frac{Q K^{T}}{\sqrt{d}} \Bigg)·V$$
  • $Q K^{T}∈ℝ^{N×N}$ computes the similarity between every pair of tokens.
  • Dividing by $\sqrt{d}$ keeps the dot products from growing with the head dimension, which would otherwise push the softmax into saturated regions with very small gradients.
  • The softmax, applied row-wise, converts similarities into attention weights that sum to 1 for each token.
  3. Multi-head attention (concatenation of the heads):
$$\text{MultiHead}(X) = \text{Concat}(head_1,...,head_h)·W_{O} \quad,\quad head_i = \text{Attention}(X·W_{Q}^{i}, X·W_{K}^{i}, X·W_{V}^{i})$$
  • Multiple attention heads, each with its own projection matrices $W_{Q}^{i}, W_{K}^{i}, W_{V}^{i}$, allow the model to capture different types of interactions.
  • $W_{O} ∈ℝ^{hd×D}$ projects the concatenated outputs back to the embedding dimension.

For more detailed information about the attention mechanism, see Simple Transformer.
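The attention equations of this section translate almost line by line into PyTorch. This is a from-scratch sketch for illustration, assuming bias-free projections to match the formulas; the names `scaled_dot_product_attention` (a local helper, not the `torch.nn.functional` function of the same name) and `MultiHeadSelfAttention` are hypothetical. Each `nn.Linear` stacks the h per-head matrices side by side, so one matrix multiply computes all heads at once.

```python
import math

import torch
import torch.nn as nn


def scaled_dot_product_attention(q, k, v):
    """Softmax(Q K^T / sqrt(d)) V for tensors shaped (..., N, d)."""
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)  # (..., N, N) pairwise similarities
    weights = scores.softmax(dim=-1)                 # each row sums to 1
    return weights @ v, weights


class MultiHeadSelfAttention(nn.Module):
    """Hypothetical MHSA with bias-free projections, mirroring the equations."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.h = num_heads
        self.d = embed_dim // num_heads  # head dimension d
        # Each Linear holds the h per-head matrices W^i (D x d) side by side
        self.w_q = nn.Linear(embed_dim, self.h * self.d, bias=False)
        self.w_k = nn.Linear(embed_dim, self.h * self.d, bias=False)
        self.w_v = nn.Linear(embed_dim, self.h * self.d, bias=False)
        self.w_o = nn.Linear(self.h * self.d, embed_dim, bias=False)  # W_O: hd x D

    def forward(self, x: torch.Tensor):
        B, N, _ = x.shape

        def split_heads(t):  # (B, N, h*d) -> (B, h, N, d)
            return t.view(B, N, self.h, self.d).transpose(1, 2)

        q = split_heads(self.w_q(x))
        k = split_heads(self.w_k(x))
        v = split_heads(self.w_v(x))
        heads, weights = scaled_dot_product_attention(q, k, v)         # (B, h, N, d)
        concat = heads.transpose(1, 2).reshape(B, N, self.h * self.d)  # Concat(head_1..head_h)
        return self.w_o(concat), weights                               # (B, N, D), (B, h, N, N)


mhsa = MultiHeadSelfAttention(embed_dim=64, num_heads=4)
out, attn = mhsa(torch.randn(2, 65, 64))
print(out.shape, attn.shape)  # torch.Size([2, 65, 64]) torch.Size([2, 4, 65, 65])
```

In practice `nn.MultiheadAttention` provides an optimized equivalent (with biases by default); the explicit version is mainly useful for inspecting the (B, h, N, N) attention maps.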

Feed-Forward Network


The Feed-Forward Network (FFN) in the Vision Transformer (ViT) is a crucial component of each encoder block. It consists of two fully connected layers with a non-linear activation function, often GELU (Gaussian Error Linear Unit), applied between them. Unlike self-attention, which enables tokens to exchange information globally, the FFN operates on each token independently, refining and transforming its representation in a higher-dimensional space. This allows the model to capture more complex, non-linear relationships within the data. In ViT, the FFN complements self-attention by enhancing the expressive power of the patch embeddings, ensuring that both global context and token-wise transformations contribute to the learned image representation.

$$y = f(x·W_{1} + b_{1})·W_{2} + b_{2}$$

Where:

  • $x$ is the input vector (one token, written as a row vector as in the attention equations).
  • $W_i$ is the weight matrix of layer i.
  • $b_i$ is the bias vector of layer i.
  • $f$ is the activation function (GELU in ViT).
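The formula maps directly onto two `nn.Linear` layers with a GELU in between. This sketch assumes the hypothetical class name `FeedForward` and a hidden size of 4× the embedding dimension (a common choice, not stated in this README). Since `nn.Linear` stores its weight with shape (out, in), $W_1$ in the formula corresponds to `fc1.weight.T`; the usage below checks the module against the formula written out by hand.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    """Hypothetical ViT MLP: y = GELU(x W1 + b1) W2 + b2, applied to each token independently."""

    def __init__(self, embed_dim: int = 64, hidden_dim: int = 256):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)  # W1, b1: expand
        self.fc2 = nn.Linear(hidden_dim, embed_dim)  # W2, b2: project back

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(F.gelu(self.fc1(x)))


ffn = FeedForward(embed_dim=64, hidden_dim=256)
x = torch.randn(2, 65, 64)
y = ffn(x)
# Same computation written as the formula (nn.Linear computes x @ weight.T + bias)
manual = F.gelu(x @ ffn.fc1.weight.T + ffn.fc1.bias) @ ffn.fc2.weight.T + ffn.fc2.bias
print(torch.allclose(y, manual, atol=1e-5))  # True
```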

Layer Normalization

Layer Normalization is used to stabilize and accelerate training by normalizing the inputs to each layer.
For each input vector (for each token in a sequence), subtract the mean and divide by the standard deviation of the vector's values. This centers the data around 0 with unit variance:

$$\hat{x} = \frac{(x - μ)}{\sqrt{σ^{2} + ε}}$$

where $μ$ and $σ^{2}$ are the mean and variance of the input vector's features, and $ε$ is a small constant added for numerical stability.

Then apply scaling (gamma) and shifting (beta) parameters (trainable):

  • γ (scale): A parameter to scale the normalized output.
  • β (shift): A parameter to shift the normalized output.
$$y = γ·\hat{x} + β$$
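The two equations can be written out by hand and compared with PyTorch's `torch.nn.functional.layer_norm`. The function name `layer_norm` and the test sizes are illustrative; note that the variance is the biased (population) variance over the last dimension, which is what PyTorch uses for layer normalization.

```python
import torch
import torch.nn.functional as F


def layer_norm(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Normalize each token's D features to zero mean and unit variance, then scale and shift."""
    mu = x.mean(dim=-1, keepdim=True)                 # per-token mean μ
    var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)  # per-token (biased) variance σ²
    x_hat = (x - mu) / torch.sqrt(var + eps)
    return gamma * x_hat + beta                       # γ and β have shape (D,)


D = 64
x = torch.randn(2, 65, D) * 3 + 5                     # tokens far from zero mean / unit variance
gamma, beta = torch.ones(D), torch.zeros(D)
y = layer_norm(x, gamma, beta)
print(torch.allclose(y, F.layer_norm(x, (D,), gamma, beta, eps=1e-5), atol=1e-5))  # True
```

Each token is normalized using only its own D features, so the result does not depend on batch size or on the other tokens in the sequence.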

ViT vs CNN

Convolutional Neural Networks (CNNs), exemplified early by LeNet-5 (LeCun et al., 1998) and popularized by AlexNet (2012), dominated computer vision for roughly a decade before Vision Transformers appeared. They rely on convolutional filters applied to local receptive fields, pooling for downsampling, and fully connected layers for classification. This design encodes strong inductive biases: locality (features are learned from neighboring pixels) and translation equivariance (a shifted pattern produces the same response, shifted accordingly), which pooling turns into approximate translation invariance. Variants like VGG, ResNet, and DenseNet advanced CNNs by increasing depth and introducing innovations such as residual connections.
To learn more about Convolutional Neural Networks (CNNs), I recommend Simple CNN Guide.

The Vision Transformer (ViT), introduced by Dosovitskiy et al. (2020), replaces convolutions with a pure Transformer encoder. An image is split into fixed-size patches (e.g., 16×16), flattened, linearly projected into embeddings, and combined with positional encodings. These are processed by Multi-Head Self-Attention (MHSA), which models global dependencies between all patches in parallel, something CNNs only capture gradually via deeper layers. A special [CLS] token aggregates global features for classification.

CNN vs ViT - Receptive Field

CNNs build hierarchical representations through stacked convolutions, gradually expanding their receptive fields and excelling at capturing local patterns such as edges and textures. ViTs, on the other hand, operate directly in patch-embedding space, where self-attention provides a global receptive field from the very first layer. This fundamental difference means CNNs have built-in biases for vision that make them data-efficient and effective on smaller datasets, while ViTs rely more heavily on large-scale data to learn the spatial relationships that CNNs encode by design. Consequently, CNNs tend to be more efficient and perform strongly when training data is limited, whereas ViTs scale more effectively with model size and dataset size, often surpassing CNNs in accuracy and adaptability.

Additionally, CNNs naturally form structured feature hierarchies that emphasize local detail and are relatively easy to interpret, making them well suited to tasks such as object detection or medical imaging. ViTs, by modeling long-range dependencies from the start, offer greater flexibility in capturing global context, which can lead to stronger performance on complex recognition challenges and in transfer learning scenarios.
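The "gradual" growth of a CNN's receptive field can be made concrete with the standard recurrence $r_l = r_{l-1} + (k_l - 1)·j_{l-1}$, $j_l = j_{l-1}·s_l$, where $k$ is the kernel size, $s$ the stride, and $j$ the spacing in input pixels between neighboring outputs. The helper `receptive_field` below is a hypothetical illustration that assumes square kernels, no dilation, and a layer list written as (kernel_size, stride) pairs.

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of a stack of conv/pool layers.

    layers: list of (kernel_size, stride) pairs, first layer first.
    """
    r, j = 1, 1  # a single input pixel; adjacent outputs are 1 pixel apart
    for k, s in layers:
        r += (k - 1) * j  # each layer widens the window by (k - 1) steps of size j
        j *= s            # striding spreads later outputs further apart
    return r


# Five stacked 3x3, stride-1 convolutions see only an 11x11 window of the input
print(receptive_field([(3, 1)] * 5))  # 11
# Interleaving 2x2, stride-2 pooling makes the window grow faster
print(receptive_field([(3, 1), (2, 2), (3, 1), (2, 2), (3, 1)]))  # 18
```

By contrast, in the first ViT encoder block every patch token can already attend to all other patches, so its effective window is the whole image.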

Training and Optimization

Adam Optimizer

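The Adam optimization algorithm is an extension of stochastic gradient descent (SGD). Unlike SGD, it computes individual adaptive learning rates for different parameters from estimates of the first and second moments of the gradients, combining the benefits of two other methods: momentum and RMSProp. Step 5 below adds decoupled weight decay (see Decoupled Weight Decay Regularization in the references), which turns Adam into the AdamW variant.

Adam Algorithm (with decoupled weight decay):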

  1. Compute gradients:

    $$g_t = \nabla_\theta J(\theta_t)$$
  2. Update moment estimates:

    $$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \quad;\quad v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$
  3. Bias correction:

    $$\hat{m}_t = \frac{m_t}{1 - \beta_1^t} \quad;\quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
  4. Parameter update:

    $$\theta_{t+1} = \theta_t - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
  5. Decoupled weight decay:

    $$\theta_{t+1} \leftarrow \theta_{t+1} - \alpha \cdot \lambda \cdot \theta_t$$
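Steps 2 to 5 can be sketched as a single update function and checked against PyTorch's `torch.optim.AdamW`, which applies the decay as $θ_t·(1 - α·λ)$ and therefore matches step 5. The function name `adamw_step`, the toy objective $J(θ) = \lVert θ \rVert^2$, and the hyperparameters are illustrative assumptions. By contrast, the `weight_decay` argument of `torch.optim.Adam` adds $λ·θ$ to the gradient (classic L2 regularization), which is not decoupled.

```python
import torch


def adamw_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """One update following steps 2-5; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad                 # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad ** 2            # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                       # bias correction
    v_hat = v / (1 - beta2 ** t)
    new_theta = theta - lr * m_hat / (torch.sqrt(v_hat) + eps)
    new_theta = new_theta - lr * weight_decay * theta  # decoupled weight decay uses θ_t
    return new_theta, m, v


# Minimize J(θ) = ||θ||² (gradient 2θ) and compare with PyTorch's built-in AdamW
init = torch.tensor([1.0, -2.0, 3.0])
theta, m, v = init.clone(), torch.zeros(3), torch.zeros(3)
p = torch.nn.Parameter(init.clone())
opt = torch.optim.AdamW([p], lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)
for t in range(1, 6):
    theta, m, v = adamw_step(theta, 2 * theta, m, v, t, lr=0.1, weight_decay=0.01)
    opt.zero_grad()
    (p ** 2).sum().backward()
    opt.step()
print(torch.allclose(theta, p.detach(), atol=1e-6))  # True
```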

Cross Entropy Loss Function

The Cross Entropy Loss Function is widely used for classification tasks, as it measures the difference between the predicted probability distribution and the true distribution. Given a predicted probability vector $\hat{y}$ and a one-hot encoded target vector $y$, the loss for a single example is defined as:

$$ \mathcal{L}_{CE} = - \sum_{i} y_i \log \hat{y}_i $$

This loss penalizes confident incorrect predictions more heavily than less certain ones, encouraging the model to assign higher probabilities to the correct classes. Minimizing cross-entropy effectively maximizes the likelihood of the correct labels under the model’s predicted distribution.
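The formula can be evaluated directly and compared with `torch.nn.functional.cross_entropy`, which takes raw logits and integer class indices and applies log-softmax internally. The logits and targets below are made-up numbers for illustration.

```python
import math

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])       # (batch, classes): raw model outputs
targets = torch.tensor([0, 2])                 # correct class index per example

# Direct translation of L = -Σ_i y_i log ŷ_i, averaged over the batch
y = F.one_hot(targets, num_classes=3).float()  # one-hot targets
log_y_hat = F.log_softmax(logits, dim=-1)      # log of predicted probabilities
manual = -(y * log_y_hat).sum(dim=-1).mean()

builtin = F.cross_entropy(logits, targets)     # expects logits, not probabilities
print(torch.allclose(manual, builtin))         # True

# A uniform prediction over 3 classes costs log(3) ≈ 1.0986 nats per example
uniform = F.cross_entropy(torch.zeros(1, 3), torch.tensor([1])).item()
print(abs(uniform - math.log(3)) < 1e-6)       # True
```

Computing log-probabilities with `log_softmax`, rather than taking `log` of `softmax`, avoids underflow when a predicted probability is very small.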

ViT Forward-Pass (Pseudo-Code)

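```python
import torch


def vit_forward_pass(x, patch_embedding, cls_token, positional_encoding,
                     encoder_layers, layer_norm, head):
    """
    x : Tensor of shape (B, C, H, W)
    returns: Tensor of shape (B, num_classes)

    The other arguments are the model's components:
    patch_embedding maps (B, C, H, W) -> (B, N, D), cls_token has shape (1, 1, D),
    positional_encoding has shape (1, N + 1, D).
    """

    # Patch embedding
    X = patch_embedding(x)                          # (B, N, D)

    # Prepend CLS token (one copy per image in the batch)
    CLS = cls_token.expand(X.shape[0], -1, -1)      # (B, 1, D)
    X = torch.cat([CLS, X], dim=1)                  # (B, N + 1, D)

    # Positional encoding
    X = X + positional_encoding                     # (B, N + 1, D)

    # Transformer encoder
    for encoder_layer in encoder_layers:
        X = encoder_layer(X)

    # Final normalization
    X = layer_norm(X)

    # Classification head (CLS token only)
    logits = head(X[:, 0])                          # (B, num_classes)

    return logits
```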
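For an end-to-end check, the forward pass above can be packaged into a small self-contained model. This is a minimal sketch, not the repository's actual model: the class name `SimpleViT`, all hyperparameters (32×32 RGB inputs, 4×4 patches, 10 classes, 64-dimensional tokens, 4 blocks), and the use of a strided `nn.Conv2d` as the patch projection are assumptions. The encoder reuses PyTorch's `nn.TransformerEncoderLayer` with `norm_first=True`, which gives the pre-norm block structure described in the Transformer Encoder section.

```python
import torch
import torch.nn as nn


class SimpleViT(nn.Module):
    """Hypothetical small ViT whose forward() mirrors the forward pass above."""

    def __init__(self, image_size=32, patch_size=4, in_channels=3, num_classes=10,
                 embed_dim=64, depth=4, num_heads=4, mlp_dim=128):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # kernel_size = stride = patch_size: one linear projection per non-overlapping patch
        self.patch_embedding = nn.Conv2d(in_channels, embed_dim,
                                         kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.positional_encoding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.positional_encoding, std=0.02)
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=mlp_dim,
                                       dropout=0.0, activation="gelu",
                                       batch_first=True, norm_first=True)
            for _ in range(depth)
        ])
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        X = self.patch_embedding(x)                     # (B, D, H/P, W/P)
        X = X.flatten(2).transpose(1, 2)                # (B, N, D)
        cls = self.cls_token.expand(X.shape[0], -1, -1)
        X = torch.cat([cls, X], dim=1)                  # (B, N + 1, D)
        X = X + self.positional_encoding                # (B, N + 1, D)
        for layer in self.encoder_layers:
            X = layer(X)
        X = self.layer_norm(X)
        return self.head(X[:, 0])                       # (B, num_classes)


model = SimpleViT()
logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])
```

The strided convolution is a common shortcut for "flatten each patch and apply a shared linear layer"; the explicit reshape-plus-`nn.Linear` version from the Patch Embedding section would work equally well here.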

Data

Evaluation

Reference

BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

Decoupled Weight Decay Regularization
