A vision transformer like ViT-B/16 has around 86 million parameters. To understand how those parameters are being used, we can take a look at the ViT-B/16 implementation in torchvision. (Spoiler: Far more parameters come from MLP layers than from attention layers!)
This blog post from Weights and Biases provides a nice snippet to count the parameters in a PyTorch model:
total_params = sum(param.numel() for param in model.parameters())
We can grab the ViT-B/16 model from torchvision and use this snippet to find the total number of parameters:
import torchvision
model = torchvision.models.vit_b_16()
total_params = sum(param.numel() for param in model.parameters())
print(total_params)
86567656
Now let’s try to work through the model to find all of those parameters. To figure out where to start, let’s look at the basic structure of the model by calling print() on the model object.
import torchvision
model = torchvision.models.vit_b_16()
print(model)
VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      ...
      (encoder_layer_11): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  )
  (heads): Sequential(
    (head): Linear(in_features=768, out_features=1000, bias=True)
  )
)
This shows us all of the torch.nn.Module objects that make up the model. In particular, we have a torch.nn.Module object called VisionTransformer which has three children:

- conv_proj: A convolutional layer that implements patch extraction (breaking a 224 x 224 x 3 image into 16 x 16 x 3 blocks) and projection (learning a linear transformation of the 16 * 16 * 3 = 768 values corresponding to each block).
- encoder: A stack of 12 transformer blocks, followed by a final normalization layer.
- heads: A linear layer that maps the 768-dimensional CLS token embedding to the 1000 ImageNet classes.

However, there are some “hidden” parameters that are not obvious from the model summary produced by print(model). It turns out that print(model) only shows instances of torch.nn.Module (i.e. a module and its submodules). Any module attributes of type torch.nn.Parameter will not show up, because this type subclasses torch.Tensor, not torch.nn.Module. In the case of ViT-B/16, we need to watch out for this in two places (see the snippet below):

- The positional embedding is implemented as a torch.nn.Parameter attribute of encoder called pos_embedding.
- The CLS token is implemented as a torch.nn.Parameter attribute of VisionTransformer called class_token.
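As a quick check (just a sketch; the attribute names below match the current torchvision implementation), we can access these parameters directly and confirm that they are plain torch.nn.Parameter tensors:

import torchvision

model = torchvision.models.vit_b_16()

# These are torch.nn.Parameter attributes, so they hold trainable weights
# but do not show up as modules in print(model).
print(type(model.class_token))            # <class 'torch.nn.parameter.Parameter'>
print(model.class_token.shape)            # torch.Size([1, 1, 768])
print(model.encoder.pos_embedding.shape)  # torch.Size([1, 197, 768])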
With these “hidden” parameters accounted for, we just need to tally up the parameters in the modules from the summary displayed by print(model).
How many parameters are in conv_proj?

We need 16 * 16 * 3 parameters for each convolutional filter, and we have 768 of those. We also have a bias term for each filter.

num_params(conv_proj) = 16 * 16 * 3 * 768 + 768 = 590592
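As a sanity check (assuming the conv_proj module path from the summary above), we can compare this arithmetic to the actual parameter tensors:

import torchvision

model = torchvision.models.vit_b_16()

# 768 filters, each spanning a 16 x 16 x 3 patch, plus one bias per filter.
print(model.conv_proj.weight.shape)  # torch.Size([768, 3, 16, 16])
print(model.conv_proj.bias.shape)    # torch.Size([768])
print(sum(p.numel() for p in model.conv_proj.parameters()))  # 590592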
How many parameters are in heads?

This is just a linear layer with 768 inputs and 1000 outputs, with a bias for each output dimension.

num_params(heads) = 768 * 1000 + 1000 = 769000
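The same kind of check works for heads (again just a sketch against the module names shown by print(model)):

import torchvision

model = torchvision.models.vit_b_16()

# One Linear layer: a 1000 x 768 weight matrix plus a 1000-dimensional bias.
print(model.heads.head.weight.shape)  # torch.Size([1000, 768])
print(model.heads.head.bias.shape)    # torch.Size([1000])
print(sum(p.numel() for p in model.heads.parameters()))  # 769000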
How many parameters are in encoder?

Most of this model’s parameters are in the encoder, which consists of 12 identical transformer blocks (EncoderBlock) followed by a final normalization layer (LayerNorm).
Let’s take a look at one encoder block:
EncoderBlock(
  (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (self_attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (mlp): MLPBlock(
    (0): Linear(in_features=768, out_features=3072, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=3072, out_features=768, bias=True)
    (4): Dropout(p=0.0, inplace=False)
  )
)
LayerNorm normalizes a 768-dimensional embedding and then applies a learned elementwise affine transformation: one scale and one bias per dimension.

num_params(LayerNorm) = 768 + 768 = 1536
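For example, peeking at ln_1 in the first encoder block (a minimal check; the indexing assumes the Sequential layout shown above):

import torchvision

model = torchvision.models.vit_b_16()

ln_1 = model.encoder.layers[0].ln_1
# One scale and one bias per embedding dimension.
print(ln_1.weight.shape)  # torch.Size([768])
print(ln_1.bias.shape)    # torch.Size([768])
print(sum(p.numel() for p in ln_1.parameters()))  # 1536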
MLPBlock has two Linear submodules that contribute trainable parameters.

num_params(MLPBlock) = (768 * 3072 + 3072) + (3072 * 768 + 768) = 4722432
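We can confirm this count directly (same assumptions about the module layout as above):

import torchvision

model = torchvision.models.vit_b_16()

mlp = model.encoder.layers[0].mlp
# Two Linear layers: 768 -> 3072 and 3072 -> 768, each with a bias.
print(mlp[0].weight.shape)  # torch.Size([3072, 768])
print(mlp[3].weight.shape)  # torch.Size([768, 3072])
print(sum(p.numel() for p in mlp.parameters()))  # 4722432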
Finally, let’s take a look at the MultiheadAttention block. We will need a few facts about ViT-B/16: it uses 12 attention heads, the per-patch embedding dimension is 768, and so the per-head embedding dimension is 768 / 12 = 64.

Recall that one attention head is often written as

head(Q, K, V) = Attention(Q W^Q, K W^K, V W^V).

Since we have a per-patch embedding dimension of 768 and a per-head embedding dimension of 64, the matrices W^Q, W^K, and W^V each have size 768 x 64. However, the PyTorch implementation of MultiheadAttention has bias=True by default, so each projection also carries a 64-dimensional bias. One attention head therefore has 3 * (768 * 64 + 64) parameters. To get the final parameter count for MultiheadAttention, we just multiply by 12 heads and add in the 768 * 768 + 768 parameters for the output projection (i.e. the projection that is applied to the concatenated outputs of the individual attention heads).
num_params(MultiheadAttention) = 12 * (3 * (768 * 64 + 64)) + (768 * 768 + 768) = 2362368
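Note that nn.MultiheadAttention does not store 12 separate 768 x 64 matrices; when the query, key, and value dimensions match, it keeps a single fused Q/K/V projection as a torch.nn.Parameter, which is why only out_proj appears in print(model). A quick check (assuming this fused in_proj_weight layout):

import torchvision

model = torchvision.models.vit_b_16()

attn = model.encoder.layers[0].self_attention
# Fused Q/K/V projection: (3 * 768) x 768 weights and 3 * 768 biases.
print(attn.in_proj_weight.shape)   # torch.Size([2304, 768])
print(attn.in_proj_bias.shape)     # torch.Size([2304])
# Output projection: 768 x 768 weights and 768 biases.
print(attn.out_proj.weight.shape)  # torch.Size([768, 768])
print(sum(p.numel() for p in attn.parameters()))  # 2362368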
We just need to tally up these modules to get the number of parameters for one encoder block.

num_params(EncoderBlock) = 2 * num_params(LayerNorm) + num_params(MultiheadAttention) + num_params(MLPBlock) = 2 * 1536 + 2362368 + 4722432 = 7087872
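Summing over one whole block gives the same number (a quick check under the indexing assumptions above):

import torchvision

model = torchvision.models.vit_b_16()

block = model.encoder.layers[0]
print(sum(p.numel() for p in block.parameters()))  # 7087872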
The encoder module consists of 12 EncoderBlock modules followed by one LayerNorm module. We also have the “hidden” parameters from the positional embedding, which covers the 14 * 14 = 196 patch tokens plus the CLS token, so it contributes (196 + 1) * 768 parameters.

num_params(encoder) = 12 * num_params(EncoderBlock) + num_params(LayerNorm) + num_params(pos_embedding) = 12 * 7087872 + 1536 + (14 * 14 * 768) + 768 = 85207296
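Because pos_embedding is a registered torch.nn.Parameter on encoder, it is already included when we iterate over encoder.parameters(), so the whole tally can be checked in one line (a sketch under the same assumptions as the earlier snippets):

import torchvision

model = torchvision.models.vit_b_16()

# 196 patch positions + 1 CLS position, each with a 768-dimensional embedding.
print(model.encoder.pos_embedding.shape)  # torch.Size([1, 197, 768])
print(sum(p.numel() for p in model.encoder.parameters()))  # 85207296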
The total number of parameters for ViT-B/16 can be found by summing up the parameters of its child modules, and adding in the “hidden” parameters from the CLS token.

num_params(VisionTransformer) = num_params(conv_proj) + num_params(encoder) + num_params(heads) + num_params(class_token) = 590592 + 85207296 + 769000 + 768 = 86567656
This is the number we were shooting for based on the code snippet from the beginning!
Unsurprisingly, 98% of the model’s parameters are in the encoder module, whose size is dominated by the EncoderBlock modules. A little more surprisingly, only 33% of the parameters in an EncoderBlock are involved in MultiheadAttention; almost all of the remaining 67% are found in the MLPBlock. This means that MLPs account for roughly 2/3 of the parameters in ViT-B/16 overall!
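If you want to reproduce these percentages, here is a rough breakdown by parameter name (a sketch that assumes the substrings ".mlp." and "self_attention" appear in the names reported by named_parameters(), as in the summary above):

import torchvision

model = torchvision.models.vit_b_16()
total = sum(p.numel() for p in model.parameters())

# Group parameters by whether they belong to an MLPBlock or a MultiheadAttention.
mlp_params = sum(p.numel() for name, p in model.named_parameters() if ".mlp." in name)
attn_params = sum(p.numel() for name, p in model.named_parameters() if "self_attention" in name)

print(f"MLP blocks: {mlp_params / total:.1%}")         # roughly 65%
print(f"Attention blocks: {attn_params / total:.1%}")  # roughly 33%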