Elijah Cole

Where are the parameters in a vision transformer?

A vision transformer like ViT-B/16 has around 86 million parameters. To understand how those parameters are being used, we can take a look at the ViT-B/16 implementation in torchvision. (Spoiler: Far more parameters come from MLP layers than from attention layers!)

How many parameters are there in total?

This blog post from Weights and Biases provides a nice snippet to count the parameters in a PyTorch model:

total_params = sum(param.numel() for param in model.parameters())

We can grab the ViT-B/16 model from torchvision and use this snippet to find the total number of parameters:

import torchvision
model = torchvision.models.vit_b_16()
total_params = sum(param.numel() for param in model.parameters())
print(total_params)  # 86567656

Now let’s try to work through the model to find all of those parameters.

How is the model structured?

To figure out where to start, let’s look at the basic structure of the model by calling print() on the model object.

import torchvision
model = torchvision.models.vit_b_16()
print(model)
VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      ...
      (encoder_layer_11): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  )
  (heads): Sequential(
    (head): Linear(in_features=768, out_features=1000, bias=True)
  )
)

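This shows us all of the torch.nn.Module objects that make up the model. In particular, we have a torch.nn.Module object called VisionTransformer which has three children: conv_proj (a Conv2d that turns each 16 x 16 patch into a 768-dimensional embedding), encoder (an Encoder containing the transformer blocks), and heads (a Sequential containing the final Linear classifier).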

However, there are some “hidden” parameters that are not obvious from the model summary produced by print(model). It turns out that print(model) only shows instances of torch.nn.Module (i.e. a module and its submodules). Any module attributes of type torch.nn.Parameter will not show up because this type subclasses torch.Tensor, not torch.nn.Module. In the case of ViT-B/16, we need to watch out for this in two places: the CLS token (class_token, an attribute of VisionTransformer) and the positional embeddings (pos_embedding, an attribute of the encoder).

With these “hidden” parameters accounted for, we just need to tally up the parameters in the modules from the summary displayed by print(model).

How many parameters are in conv_proj?

We need 16 * 16 * 3 parameters for each convolutional filter, and we have 768 of those. We also have a bias term for each filter.

num_params(conv_proj) = 16 * 16 * 3 * 768 + 768 = 590592
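The same arithmetic can be written as a small helper. This is a minimal sketch in plain Python; conv2d_params is a hypothetical name introduced here for illustration, and it assumes a standard (non-grouped) convolution.

```python
def conv2d_params(in_channels, out_channels, kernel_h, kernel_w, bias=True):
    """Count trainable parameters of a standard (non-grouped) 2D convolution."""
    # Each of the out_channels filters has in_channels * kernel_h * kernel_w weights.
    weights = out_channels * in_channels * kernel_h * kernel_w
    # One bias per filter, if biases are enabled.
    biases = out_channels if bias else 0
    return weights + biases


# conv_proj: 3 input channels, 768 filters, 16 x 16 kernel.
conv_proj_params = conv2d_params(3, 768, 16, 16)
print(conv_proj_params)  # 590592
```

The stride does not appear in the count: it changes how many patches are produced, not how many weights the layer has.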

How many parameters are in heads?

This is just a linear layer with 768 inputs and 1000 outputs, with a bias for each output dimension.

num_params(heads) = 768 * 1000 + 1000 = 769000

How many parameters are in encoder?

Most of this model’s parameters are in the encoder, which consists of 12 identical transformer blocks (EncoderBlock) followed by a final normalization layer (LayerNorm).

Let’s take a look at one encoder block:

EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )

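LayerNorm takes a 768-dimensional embedding, normalizes it to zero mean and unit variance, and then applies a learned per-dimension scale and shift (elementwise_affine=True). That gives one weight and one bias for each of the 768 dimensions.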

num_params(LayerNorm) = 768 + 768 = 1536
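To see why the count is 2 * 768 rather than 768 * 768, here is a minimal plain-Python sketch of layer normalization with an elementwise affine step. The function layer_norm and the names gamma and beta are illustrative assumptions, not torchvision code; eps=1e-6 matches the value in the printed summary.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-6):
    """Normalize x to zero mean / unit variance, then scale and shift per dimension."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance
    inv_std = 1.0 / math.sqrt(var + eps)
    # gamma and beta are the only learned parameters: one of each per dimension.
    return [g * (v - mean) * inv_std + b for v, g, b in zip(x, gamma, beta)]


dim = 768
gamma = [1.0] * dim  # learned scale (initialized to ones here)
beta = [0.0] * dim   # learned shift (initialized to zeros here)

layer_norm_params = len(gamma) + len(beta)
print(layer_norm_params)  # 1536

out = layer_norm([float(i) for i in range(dim)], gamma, beta)
```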

MLPBlock has two Linear submodules that contribute trainable parameters.

num_params(MLPBlock) = (768 * 3072 + 3072) + (3072 * 768 + 768) = 4722432
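As a quick check, the count for a Linear layer can be wrapped in a helper and reused for both the MLP block and the classification head. This is a sketch in plain Python; linear_params is a hypothetical helper name.

```python
def linear_params(in_features, out_features, bias=True):
    """Count trainable parameters of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


hidden_dim, mlp_dim, num_classes = 768, 3072, 1000

# MLPBlock: expand 768 -> 3072, then project 3072 -> 768.
mlp_params = linear_params(hidden_dim, mlp_dim) + linear_params(mlp_dim, hidden_dim)
print(mlp_params)  # 4722432

# The classification head is a single Linear layer.
heads_params = linear_params(hidden_dim, num_classes)
print(heads_params)  # 769000
```

GELU and Dropout contribute no trainable parameters, so they do not appear in the sum.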

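Finally, let’s take a look at the MultiheadAttention block. We will need a few facts about ViT-B/16: it uses 12 attention heads, the per-patch embedding dimension is 768, and the per-head embedding dimension is 768 / 12 = 64.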

Recall that one attention head is often written as

head(Q, K, V) = Attention(Q W^Q, K W^K, V W^V).

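Since we have a per-patch embedding dimension of 768 and a per-head embedding dimension of 64, the matrices W^Q, W^K, and W^V have size 768 x 64. The PyTorch implementation of MultiheadAttention also has bias=True by default, so each of these projections comes with a 64-dimensional bias. One attention head therefore has 3 * (768 * 64 + 64) parameters. To get the final parameter count for MultiheadAttention, we just multiply by 12 and add in the 768 * 768 + 768 parameters for the output projection (i.e. the projection that is applied to the concatenated outputs of the individual attention heads). Note that PyTorch stores the query, key, and value projections for all heads as torch.nn.Parameter attributes (in_proj_weight and in_proj_bias) rather than as submodules, which is why only out_proj appears in the summary printed above.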

num_params(MultiheadAttention) = 12 * (3 * (768 * 64 + 64)) + (768 * 768 + 768) = 2362368
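The per-head count can be cross-checked against a "fused" view in which the Q, K, and V projections for all heads are stacked into one 2304 x 768 weight and one 2304-dimensional bias. The sketch below is plain Python arithmetic; the variable names are illustrative, and the fused layout is stated as an assumption about how the projections are stored.

```python
embed_dim = 768
num_heads = 12
head_dim = embed_dim // num_heads  # 64

# Per-head view: W^Q, W^K, W^V are each 768 x 64, plus a 64-dim bias each.
per_head = 3 * (embed_dim * head_dim + head_dim)
out_proj = embed_dim * embed_dim + embed_dim
per_head_view = num_heads * per_head + out_proj

# Fused view (assumed layout): one (3 * 768) x 768 weight and one (3 * 768) bias.
in_proj = 3 * embed_dim * embed_dim + 3 * embed_dim
fused_view = in_proj + out_proj

print(per_head_view, fused_view)  # 2362368 2362368
```

Both views agree because 12 heads of width 64 tile the full 768-dimensional projection exactly.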

We just need to tally up these modules to get the number of parameters for one encoder block.

num_params(EncoderBlock) = 2 * num_params(LayerNorm) + num_params(MultiheadAttention) + num_params(MLPBlock) = 2 * 1536 + 2362368 + 4722432 = 7087872

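The encoder module consists of 12 EncoderBlock modules followed by one LayerNorm module. We also have the “hidden” parameters from the positional embeddings: one 768-dimensional vector for each of the 14 * 14 = 196 patches (a 224 x 224 input split into 16 x 16 patches gives 14 patches per side), plus one more for the CLS token.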

num_params(encoder) = 12 * num_params(EncoderBlock) + num_params(LayerNorm) + num_params(pos_embedding) = 12 * 7087872 + 1536 + (14 * 14 * 768) + 768 = 85207296
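The positional-embedding term is the easiest one to get wrong, so here is the encoder tally as a short plain-Python sketch. It assumes the default 224 x 224 input resolution for ViT-B/16 and reuses the per-block count derived above; the variable names are illustrative.

```python
image_size, patch_size, hidden_dim = 224, 16, 768  # 224 is the assumed default input size

patches_per_side = image_size // patch_size  # 14
seq_length = patches_per_side ** 2 + 1       # 196 patches + 1 CLS token = 197
pos_embedding_params = seq_length * hidden_dim

encoder_block_params = 7087872       # from the EncoderBlock tally above
final_layer_norm_params = 2 * hidden_dim
num_blocks = 12

encoder_params = (
    num_blocks * encoder_block_params
    + final_layer_norm_params
    + pos_embedding_params
)
print(seq_length, pos_embedding_params)  # 197 151296
print(encoder_params)  # 85207296
```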

Conclusion

The total number of parameters for ViT-B/16 can be found by summing up the parameters of its child modules and adding in the “hidden” parameters from the CLS token.

num_params(VisionTransformer) = num_params(conv_proj) + num_params(encoder) + num_params(heads) + num_params(class_token) = 590592 + 85207296 + 769000 + 768 = 86567656

This is the number we were shooting for based on the code snippet from the beginning!

Unsurprisingly, 98% of the model’s parameters are in the encoder module, whose size is dominated by EncoderBlock modules. A little more surprisingly, only 33% of the parameters in an EncoderBlock are involved in MultiheadAttention. Almost all of the other 67% of the parameters are found in the MLPBlock. Across the whole model, the 12 MLP blocks hold about 65% of the parameters, so MLPs account for roughly 2/3 of the parameters in ViT-B/16 overall!
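For reference, the whole breakdown can be reproduced with a few lines of plain Python. The dictionary below just collects the counts derived in this post; the grouping and labels are chosen here for illustration.

```python
counts = {
    "conv_proj": 590592,
    "class_token": 768,
    "pos_embedding": 151296,
    "attention (12 blocks)": 12 * 2362368,
    "mlp (12 blocks)": 12 * 4722432,
    "layer norms (25 total)": 25 * 1536,  # 2 per block + 1 final
    "heads": 769000,
}

total = sum(counts.values())
for name, n in counts.items():
    # Print each group's count and its share of the total.
    print(f"{name:>24}: {n:>10,d}  ({100 * n / total:5.2f}%)")

print(total)  # 86567656
```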