Efficient Image Segmentation Using PyTorch: Part 4

A Vision Transformer-based model

In this 4-part series, we’ll implement image segmentation step by step from scratch using deep learning techniques in PyTorch. This part will focus on implementing a Vision Transformer-based model for image segmentation.

Co-authored with Naresh Singh

Figure 1: Result of running image segmentation using a vision transformer model architecture. From top to bottom, input images, ground truth segmentation masks, and predicted segmentation masks. Source: Author(s)

Article outline

We will now visit the transformer architecture, which has taken the world of deep learning by storm. The transformer is a versatile architecture that can model many modalities, such as language, vision, and audio.

In this article, we will:

- Learn about the transformer architecture and the key concepts involved
- Understand the vision transformer architecture
- Introduce a vision transformer model that is written from scratch so that you can appreciate all the building blocks and moving parts
- Follow an input tensor fed into this model and inspect how it changes shape
- Use this model to perform image segmentation on the Oxford-IIIT Pet dataset
- Observe the results of this segmentation task
- Briefly introduce the SegFormer, a state-of-the-art vision transformer for semantic segmentation

Throughout this article, we will reference code and results from this notebook for model training. If you wish to reproduce the results, you’ll need a GPU to ensure that the notebook finishes running in a reasonable amount of time.

Articles in this series

This series is for readers at all experience levels with deep learning. If you want to learn about the practice of deep learning and vision AI along with some solid theory and hands-on experience, you’ve come to the right place! This is expected to be a 4-part series with the following articles:

1. Concepts and Ideas
2. A CNN-based model
3. Depthwise separable convolutions
4. A Vision Transformer-based model (this article)

Let’s start our journey into vision transformers with an introduction and intuitive understanding of the transformer architecture.

The Transformer Architecture

We can think of the transformer architecture as interleaved layers of communication and computation. This idea is depicted visually in Figure 2. The transformer has N processing units (N is 3 in Figure 2), each of which is responsible for processing a 1/N fraction of the input. For those processing units to produce meaningful results, each of them needs a global view of the input. Hence, the system repeatedly communicates information about the data in every processing unit to every other processing unit, shown using the red, green, and blue arrows going from every processing unit to every other processing unit. This is followed by some computation based on this information. After sufficient repetitions of this process, the model is able to produce the desired results.

Figure 2: Interleaved communication and computation in transformers. The image shows just 2 layers of communication and computation. In practice, there are many more such layers. Source: Author(s).

It’s worth noting that most online resources discuss both the encoder and the decoder of the transformer as presented in the paper titled “Attention Is All You Need.” In this article, however, we will describe just the encoder part of the transformer.

Let’s take a closer look at what constitutes communication and computation in transformers.

Communication in transformers: Attention

In transformers, communication is implemented by a layer known as the attention layer. In PyTorch, this is called nn.MultiheadAttention. We’ll get to the reason for that name in a bit.

The documentation says:

“Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention is all you need.”

The attention mechanism consumes an input tensor x of shape (Batch, Length, Features) and produces a tensor y of the same shape, in which the feature vector of each input is updated based on which other inputs in the same sequence it pays attention to. In other words, each of the Length feature vectors (each of size Features) in a sequence is updated using every other feature vector in that sequence. Computing an interaction for every pair of positions requires Length × Length scores, which is where the quadratic cost of the attention mechanism comes from.
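To see this shape contract in code, here is a minimal sketch of self-attention using PyTorch’s nn.MultiheadAttention. The sizes (a batch of 2, 64 positions, 512 features, 8 heads) are arbitrary choices for illustration. Note that batch_first=True is what lets the module accept (Batch, Length, Features) tensors; by default it expects (Length, Batch, Features).

```python
import torch
import torch.nn as nn

batch, length, features = 2, 64, 512
x = torch.randn(batch, length, features)

# batch_first=True: inputs and outputs are laid out as (Batch, Length, Features)
attn = nn.MultiheadAttention(embed_dim=features, num_heads=8, batch_first=True)

# Self-attention: queries, keys, and values all come from the same tensor
y, weights = attn(x, x, x, need_weights=True)

print(y.shape)        # torch.Size([2, 64, 512]), same shape as x
print(weights.shape)  # torch.Size([2, 64, 64]), one score per (query, key) pair, averaged over heads
```

The (2, 64, 64) weight tensor is the Length × Length matrix behind the quadratic cost: doubling the sequence length quadruples its size.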

Figure 3: Attention of the word “it” shown relative to the other words in the sentence. We can see that “it” is paying attention to the words “animal”, “too”, and “tire(d)” in the same sentence. Source: Generated using this colab.

In the context of a vision transformer, the input to the transformer is an image. Let’s assume this to be a 128 x 128 (width, height) image. We chunk it into multiple smaller patches of size (16 x 16). For a 128 x 128 image, we get 64 patches (Length), 8 patches in each row, and 8 rows of patches.
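The patch arithmetic can be checked directly. The sketch below shows one possible way to cut an image into non-overlapping patches, using torch.nn.functional.unfold with the stride equal to the kernel size; the helper name image_to_patches is hypothetical. It also confirms that patches are numbered row by row, so patch k comes from row k // 8 and column k % 8 of the 8 × 8 grid.

```python
import torch
import torch.nn.functional as F

def image_to_patches(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Hypothetical helper: (B, C, H, W) -> (B, T, C * patch_size * patch_size)."""
    _, _, h, w = images.shape
    assert h % patch_size == 0 and w % patch_size == 0, "H and W must be multiples of patch_size"
    # kernel_size == stride extracts non-overlapping blocks: (B, C*P*P, T), T = (H/P) * (W/P)
    patches = F.unfold(images, kernel_size=patch_size, stride=patch_size)
    return patches.transpose(1, 2)  # put the patch (sequence) dimension before the features

image = torch.randn(1, 3, 128, 128)
patches = image_to_patches(image, patch_size=16)
print(patches.shape)  # torch.Size([1, 64, 768]): 8 x 8 patches, each 3 * 16 * 16 values

# Patch 9 is row 1, column 1 of the 8 x 8 grid
row, col = divmod(9, 8)
expected = image[0, :, row * 16:(row + 1) * 16, col * 16:(col + 1) * 16].reshape(-1)
print(torch.equal(patches[0, 9], expected))  # True
```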

Each one of these 64 patches of size 16 x 16 pixels is considered to be a separate input to the transformer model. Without getting too deep into the details, it should be sufficient to think of this process as being driven by 64 different processing units, each of which is processing a single 16×16 image patch.

In each round, the attention mechanism in each processing unit looks at the image patch that unit is responsible for and queries each of the remaining 63 processing units for any information that may help it process its own image patch effectively.
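This query-and-respond process can be sketched from scratch as single-head scaled dot-product attention over 64 patch vectors. The embedding size of 32 and the random projection matrices w_q, w_k, and w_v are placeholders standing in for learned weights. Note that each unit also attends to its own patch, so every row of the attention matrix covers all 64 units.

```python
import math
import torch

torch.manual_seed(0)
num_patches, embed = 64, 32
x = torch.randn(num_patches, embed)  # one feature vector per patch ("processing unit")

# Placeholder projections; in a trained model these are learned weights
w_q, w_k, w_v = (torch.randn(embed, embed) / math.sqrt(embed) for _ in range(3))

q = x @ w_q  # what each unit is asking for
k = x @ w_k  # what each unit advertises about its patch
v = x @ w_v  # the information each unit hands over

scores = q @ k.T / math.sqrt(embed)      # (64, 64): unit i's query against unit j's key
weights = torch.softmax(scores, dim=-1)  # each row is a distribution over all 64 units
out = weights @ v                        # (64, 32): each unit's update mixes every unit's values

print(scores.shape, out.shape)  # torch.Size([64, 64]) torch.Size([64, 32])
```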

The communication step via attention is followed by computation, which we will look at next.

Computation in transformers: Multi-Layer Perceptron

Computation in transformers is nothing but a MultiLayerPerceptron (MLP) unit. This unit is composed of 2 Linear layers with a GELU non-linearity in between; other non-linearities can be used as well. The first Linear layer projects the input to 4x its size, and the second projects it back to 1x, i.e. the same size as the input.

In the code we’ll see in our notebook, this class is called MultiLayerPerceptron. The code is shown below.

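import torch.nn as nn

class MultiLayerPerceptron(nn.Sequential):
    def __init__(self, embed_size, dropout):
        super().__init__(
            nn.Linear(embed_size, embed_size * 4),  # expand to 4x the embedding size
            nn.GELU(),                              # non-linearity between the two Linear layers
            nn.Linear(embed_size * 4, embed_size),  # project back to the embedding size
            nn.Dropout(p=dropout),
        )
    # end def
# end class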
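To show how communication and computation fit together, the sketch below wraps nn.MultiheadAttention and the MultiLayerPerceptron above into one encoder block. The pre-norm layout (LayerNorm before each sub-layer, with a residual connection around each) is an assumption borrowed from the original ViT paper; the notebook’s exact block may differ, and the class name TransformerEncoderBlock is hypothetical.

```python
import torch
import torch.nn as nn

class MultiLayerPerceptron(nn.Sequential):
    def __init__(self, embed_size, dropout):
        super().__init__(
            nn.Linear(embed_size, embed_size * 4),
            nn.GELU(),
            nn.Linear(embed_size * 4, embed_size),
            nn.Dropout(p=dropout),
        )

class TransformerEncoderBlock(nn.Module):
    """Hypothetical pre-norm block: attention (communication), then MLP (computation)."""
    def __init__(self, embed_size, num_heads, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_size)
        self.attn = nn.MultiheadAttention(embed_size, num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_size)
        self.mlp = MultiLayerPerceptron(embed_size, dropout)

    def forward(self, x):  # x: (B, T, C)
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)  # every token reads from every token
        x = x + attn_out                                      # residual connection
        x = x + self.mlp(self.ln2(x))                         # same MLP applied to each token
        return x

block = TransformerEncoderBlock(embed_size=512, num_heads=8, dropout=0.2)
tokens = torch.randn(1, 64, 512)
out = block(tokens)
print(out.shape)  # torch.Size([1, 64, 512]): the block preserves the (B, T, C) shape
```

Stacking several of these blocks gives the repeated communicate-then-compute rounds described in Figure 2.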

Now that we understand the high level working of the transformer architecture, let’s focus our attention on the vision transformer since we’re going to be performing image segmentation.

The Vision Transformer

The vision transformer was first introduced by the paper titled “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale”. The paper discusses how the authors apply the vanilla transformer architecture to the problem of image classification. This is done by splitting the image into patches of size 16×16, and treating each patch as an input token to the model. The transformer encoder model is fed these input tokens, and is asked to predict a class for the input image.

Figure 4: Source: Transformers for image recognition at scale.

In our case, we are interested in image segmentation. We can consider it to be a pixel-level classification task because we intend to predict a target class per pixel.

We make a small but important change to the vanilla vision transformer and replace the MLP head for classification with an MLP head for pixel-level classification. We have a single linear layer in the output that is shared by every patch whose segmentation mask is predicted by the vision transformer. This shared linear layer predicts a segmentation mask for every patch that was sent as input to the model.

In the case of the vision transformer, a patch of size 16×16 is considered to be equivalent to a single input token at a specific time step.
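A quick way to see what “shared” means: nn.Linear operates on the last dimension of its input, so applying one head to the whole (B, T, C) tensor gives the same result as applying it to each patch separately. The sizes below follow the 10-class example used later in this article and are otherwise illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_size, patch_size, num_classes, num_patches = 512, 16, 10, 64
tokens = torch.randn(1, num_patches, embed_size)  # encoder output: (B, T, C)

# One Linear layer shared by all patches: C -> P * P * num_classes logits per patch
head = nn.Linear(embed_size, patch_size * patch_size * num_classes)

all_at_once = head(tokens)  # (1, 64, 2560)
one_by_one = torch.stack([head(tokens[:, t]) for t in range(num_patches)], dim=1)

print(all_at_once.shape)  # torch.Size([1, 64, 2560])
# Same weights for every patch, so both paths agree up to floating-point rounding
print(torch.allclose(all_at_once, one_by_one, atol=1e-5))
```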

Figure 5: The end to end working of the vision transformer for image segmentation. Image generated using this notebook. Source: Author(s).

Building an intuition for tensor dimensions in vision transformers

When working with deep CNNs, the tensor dimensions we used for the most part were (N, C, H, W), where the letters stand for the following:

- N: Batch size
- C: Number of channels
- H: Height
- W: Width

You can see that this format is geared toward 2D image processing, since its dimensions (channels, height, width) are specific to images.

With transformers, on the other hand, things become a lot more generic and domain-agnostic. What we’ll see below applies to vision, text, audio, or any other problem where the input data can be represented as a sequence. It is worth noting that there’s little vision-specific bias in the representation of tensors as they flow through our vision transformer.

When working with transformers and attention in general, we expect the tensors to have the following shape: (B, T, C), where the letters stand for the following:

- B: Batch size (same as that for CNNs)
- T: Time dimension or sequence length. This dimension is also sometimes called L. In the case of vision transformers, each image patch corresponds to one position along this dimension. If we have 16 image patches, then the value of the T dimension will be 16.
- C: The channel or embedding size dimension. This dimension is also sometimes called E. When processing images, each patch of size 3x16x16 (Channel, Height, Width) is mapped via a patch embedding layer to an embedding of size C. We’ll see how this is done later.
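Moving between the two conventions is a matter of flattening and permuting. The sketch below turns a CNN-style (N, C, H, W) grid of features into a (B, T, C) sequence and back; the 8 × 8 grid of 512-dimensional features is an arbitrary example.

```python
import torch

feature_map = torch.randn(2, 512, 8, 8)  # CNN layout (N, C, H, W): an 8 x 8 grid of 512-dim features

# Flatten the grid into a sequence of T = H * W positions, then move channels last: (B, T, C)
tokens = feature_map.flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([2, 64, 512])

# Reverse: (B, T, C) -> (N, C, H, W), assuming the sequence was laid out row by row
restored = tokens.transpose(1, 2).reshape(2, 512, 8, 8)
print(torch.equal(restored, feature_map))  # True
```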

Let’s dive into how the input image tensor gets mutated and processed along its way to predicting the segmentation mask.

The journey of a tensor in a vision transformer

In deep CNNs, the journey of a tensor looks something like this (in a UNet, SegNet, or other CNN based architecture).

The input tensor is typically of shape (1, 3, 128, 128). This tensor goes through a series of convolution and max-pooling operations where its spatial dimensions are reduced and channel dimensions are increased, typically by a factor of 2 each. This is called the feature encoder. After this, we do the reverse operation where we increase the spatial dimensions and reduce the channel dimensions. This is called the feature decoder. After the decoding process, we get a tensor of shape (1, 64, 128, 128). This is then projected into the number of output channels C that we desire as (1, C, 128, 128) using a 1×1 pointwise convolution without bias.
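The final projection step is easy to reproduce in isolation. The sketch assumes 3 output classes (pet, background, and pet border, as in the Oxford-IIIT Pet setup used in this series) and uses random tensors in place of real decoder features. A 1×1 convolution without bias is just a per-pixel matrix multiplication over channels, which the einsum check makes explicit.

```python
import torch
import torch.nn as nn

num_classes = 3                             # assumed: pet, background, pet border
decoder_out = torch.randn(1, 64, 128, 128)  # stand-in for the feature decoder's output

# 1x1 pointwise convolution without bias: maps 64 channels to num_classes at every pixel
project = nn.Conv2d(64, num_classes, kernel_size=1, bias=False)
logits = project(decoder_out)
print(logits.shape)  # torch.Size([1, 3, 128, 128])

# Equivalent per-pixel matrix multiply with the (num_classes, 64) weight matrix
weight = project.weight[:, :, 0, 0]
manual = torch.einsum("oc,bchw->bohw", weight, decoder_out)
print(torch.allclose(logits, manual, atol=1e-5))
```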

Figure 6: Typical progression of tensor shapes through a deep CNN used for image segmentation. Source: Author(s).

With vision transformers, the flow is much more complex. Let’s take a look at the figure below and then try to understand how the tensor changes shape at every step along the way.

Figure 7: Typical progression of tensor shapes through a vision transformer for image segmentation. Source: Author(s).

Let’s look at each step in more detail and see how it updates the shape of the tensor flowing through the vision transformer. To understand this better, let’s take concrete values for our tensor dimensions.

- Batch Normalization: The input and output tensors have shape (1, 3, 128, 128). The shape is unchanged, but the values of each channel are normalized to zero mean and unit variance.
- Image to patches: The input tensor of shape (1, 3, 128, 128) is split into 64 non-overlapping 16×16 patches, and each patch is flattened into a vector of 3 × 16 × 16 = 768 values. The output tensor has shape (1, 64, 768).
- Patch embedding: The patch embedding layer maps the 768 input channels to 512 embedding channels (for this example). The output tensor is of shape (1, 64, 512). The patch embedding layer is basically just an nn.Linear layer in PyTorch.
- Position embedding: The position embedding layer doesn’t have an input tensor, but effectively contributes a learnable parameter (trainable tensor in PyTorch) of the same shape as the patch embedding. This is of shape (1, 64, 512).
- Add: The patch and position embeddings are added together element-wise to produce the input to our vision transformer encoder. This tensor is of shape (1, 64, 512). You’ll notice that the main workhorse of the vision transformer, i.e. the encoder, leaves this tensor shape unchanged.
- Transformer encoder: The input tensor of shape (1, 64, 512) flows through multiple transformer encoder blocks, each of which has multi-head attention (communication) followed by an MLP layer (computation). The tensor shape remains unchanged as (1, 64, 512).
- Linear output projection: If we assume that we want to segment each image into 10 classes, then we need each patch of size 16×16 to have 10 channels. The nn.Linear layer for output projection converts the 512 embedding channels to 16 × 16 × 10 = 2560 output channels, so this tensor looks like (1, 64, 2560). In the diagram above, C’ = 10. Ideally, this would be a multi-layer perceptron, since MLPs are universal function approximators, but we use a single linear layer since this is an educational exercise.
- Patch to image: This layer converts the 64 patches encoded as a (1, 64, 2560) tensor back into something that looks like a segmentation mask. This can be 10 single-channel images, or in this case a single 10-channel image, with each channel being the segmentation mask for one of the 10 classes. The output tensor is of shape (1, 10, 128, 128).
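The whole pipeline fits in a few dozen lines. The sketch below is a hypothetical, compact reconstruction of the steps above, not the notebook’s code: it uses F.unfold and F.fold for the image-to-patches and patch-to-image steps, and PyTorch’s built-in nn.TransformerEncoder (pre-norm, GELU, 4x MLP) as a stand-in for a hand-written encoder. It uses 2 encoder blocks only to keep it light, and the position embedding is zero-initialized for simplicity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyViTForSegmentation(nn.Module):
    """Hypothetical sketch of the pipeline above (not the notebook's implementation)."""
    def __init__(self, image_size=128, patch_size=16, in_channels=3, num_classes=10,
                 embed_size=512, num_blocks=2, num_heads=8, dropout=0.1):
        super().__init__()
        self.image_size, self.patch_size = image_size, patch_size
        num_patches = (image_size // patch_size) ** 2      # 64
        patch_dim = in_channels * patch_size * patch_size  # 768
        self.norm = nn.BatchNorm2d(in_channels)
        self.patch_embed = nn.Linear(patch_dim, embed_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_size))  # learnable
        layer = nn.TransformerEncoderLayer(
            embed_size, num_heads, dim_feedforward=4 * embed_size, dropout=dropout,
            activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_blocks)
        self.head = nn.Linear(embed_size, patch_size * patch_size * num_classes)

    def forward(self, x):
        p, size = self.patch_size, self.image_size
        x = self.norm(x)                          # (B, 3, 128, 128)  batch normalization
        x = F.unfold(x, kernel_size=p, stride=p)  # (B, 768, 64)
        x = x.transpose(1, 2)                     # (B, 64, 768)      image to patches
        x = self.patch_embed(x)                   # (B, 64, 512)      patch embedding
        x = x + self.pos_embed                    # (B, 64, 512)      add position embedding
        x = self.encoder(x)                       # (B, 64, 512)      transformer encoder
        x = self.head(x)                          # (B, 64, 2560)     linear output projection
        # Patch to image: fold reads each 2560-vector as 10 channels x 16 x 16 pixels
        return F.fold(x.transpose(1, 2), output_size=(size, size), kernel_size=p, stride=p)

model = ToyViTForSegmentation()
images = torch.randn(2, 3, 128, 128)  # a batch of 2 random images
masks = model(images)
print(masks.shape)  # torch.Size([2, 10, 128, 128])
```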

That’s it — we’ve successfully segmented an input image using a vision transformer! Next, let’s take a look at an experiment along with some results.

Vision transformers in action

This notebook contains all the code for this section.

As far as the code and class structure are concerned, they closely mimic the block diagram above. Most of the concepts mentioned above have a 1:1 correspondence to class names in this notebook.

There are some concepts related to the attention layers that are critical hyperparameters for our model. We didn’t go into the details of multi-head attention earlier, since they are out of scope for this article. If you don’t have a basic understanding of the attention mechanism in transformers, we highly recommend reading the material listed in the Further reading section before proceeding.

We used the following model parameters for the vision transformer for segmentation.

- 768 embedding dimensions for the PatchEmbedding layer
- 12 transformer encoder blocks
- 8 attention heads in each transformer encoder block
- 20% dropout in multi-head attention and MLP

This configuration can be seen in the VisionTransformerArgs Python dataclass.

from dataclasses import dataclass

@dataclass
class VisionTransformerArgs:
    """Arguments to the VisionTransformerForSegmentation."""
    image_size: int = 128
    patch_size: int = 16
    in_channels: int = 3
    out_channels: int = 3  # pet, background, pet border
    embed_size: int = 768
    num_blocks: int = 12
    num_heads: int = 8
    dropout: float = 0.2
# end class
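As a sanity check on this configuration, the hypothetical helper below counts parameters for a layout assumed from the steps described earlier: a BatchNorm2d input layer, a Linear patch embedding, a learnable position embedding, 12 pre-norm encoder blocks (two LayerNorms, nn.MultiheadAttention, and a 4x MLP each), and a single Linear output head, with no class token and no final LayerNorm. Under these assumptions the count lands at about 86.28M, consistent with the parameter count reported below, although the notebook’s exact layer layout may still differ.

```python
def estimate_params(image_size=128, patch_size=16, in_channels=3, out_channels=3,
                    embed_size=768, num_blocks=12) -> int:
    """Hypothetical helper: rough parameter count under the assumptions stated above."""
    e = embed_size
    num_patches = (image_size // patch_size) ** 2          # 64
    patch_dim = in_channels * patch_size * patch_size      # 768
    head_out = patch_size * patch_size * out_channels      # 768 logits per patch
    batch_norm = 2 * in_channels                           # weight + bias (running stats are buffers)
    patch_embed = patch_dim * e + e                        # Linear(patch_dim, e)
    pos_embed = num_patches * e                            # learnable (1, T, e) tensor
    attention = 4 * e * e + 4 * e                          # fused Q/K/V projection + output projection
    mlp = (e * 4 * e + 4 * e) + (4 * e * e + e)            # Linear(e, 4e) + Linear(4e, e)
    layer_norms = 2 * (2 * e)                              # two LayerNorms, weight + bias each
    encoder = num_blocks * (attention + mlp + layer_norms)
    head = e * head_out + head_out                         # Linear(e, head_out)
    return batch_norm + patch_embed + pos_embed + encoder + head

print(estimate_params())  # 86284806
```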

As in the previous parts of this series, we used a similar configuration for model training and validation, specified below.

- The random horizontal flip and colour jitter data augmentations are applied to the training set to prevent overfitting
- The images are resized to 128×128 pixels in a non-aspect-preserving resize operation
- No input normalization is applied to the images; instead, a batch normalization layer is used as the first layer of the model
- The model is trained for 50 epochs using the Adam optimizer with an LR of 0.0004 and a StepLR scheduler that decays the learning rate by 0.8x every 12 epochs
- The cross-entropy loss function is used to classify a pixel as belonging to a pet, the background, or a pet border
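The optimizer, scheduler, and loss settings above can be sketched with standard torch.optim and torch.nn components. The 1×1 convolution is a hypothetical stand-in for the segmentation model, and random images and masks replace the real data loader, so only the learning-rate schedule and the tensor shapes are meaningful here.

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 3, kernel_size=1)  # hypothetical stand-in for the segmentation model
optimizer = torch.optim.Adam(model.parameters(), lr=0.0004)
# StepLR multiplies the learning rate by gamma once every step_size epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=12, gamma=0.8)
criterion = nn.CrossEntropyLoss()  # 3 classes: pet, background, pet border

lrs = []
for epoch in range(50):
    images = torch.rand(2, 3, 128, 128)         # dummy batch
    masks = torch.randint(0, 3, (2, 128, 128))  # per-pixel class indices
    optimizer.zero_grad()
    loss = criterion(model(images), masks)      # logits (B, 3, H, W) vs targets (B, H, W)
    loss.backward()
    optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()                            # once per epoch

print(lrs[0], lrs[12], lrs[49])  # roughly 0.0004, 0.00032, 0.000164
```

With 50 epochs and a step size of 12, the learning rate is decayed 4 times, ending near 0.0004 × 0.8⁴ ≈ 0.000164. For real segmentation data, a random horizontal flip must be applied identically to the image and its mask; the sketch omits augmentation entirely.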

The model has 86.28M parameters and achieved a validation accuracy of 85.89% after 50 training epochs. This is lower than the 88.28% accuracy achieved by the deep CNN model after 20 training epochs. This could be due to a few factors that would need to be validated experimentally:

- The last output projection layer is a single nn.Linear and not a multi-layer perceptron
- The 16×16 patch size is too large to capture more fine-grained detail
- Not enough training epochs
- Not enough training data; it’s known that transformer models need a lot more data to train effectively compared to deep CNN models
- The learning rate is too low

We plotted a gif showing how the model is learning to predict the segmentation masks for 21 images in the validation set.

Figure 8: A gif showing the progression of segmentation masks predicted by the vision transformer for image segmentation model. Source: Author(s).

We notice something interesting in the early training epochs: the predicted segmentation masks have some strange blocking artifacts. Our best explanation is that we break the image down into patches of size 16×16, and after very few training epochs the model hasn’t learned anything useful beyond very coarse-grained information about whether each 16×16 patch is generally covered by a pet or by background pixels.

Figure 9: The blocking artifacts seen in the predicted segmentation masks when using the vision transformer for image segmentation. Source: Author(s).

Now that we have seen a basic vision transformer in action, let’s turn our attention to a state of the art vision transformer for segmentation tasks.

SegFormer: Semantic segmentation with transformers

The SegFormer architecture was proposed in this paper in 2021. The transformer we saw above is a simpler version of the SegFormer architecture.

Figure 10: The SegFormer architecture. Source: SegFormer paper (2021).

Most notably, the SegFormer:

- Generates 4 sets of images with patches of size 4×4, 8×8, 16×16, and 32×32 instead of a single patched image with patches of size 16×16
- Uses 4 hierarchical encoder stages, each containing transformer encoder blocks, instead of a single stage operating at one resolution. This feels like a model ensemble
- Uses convolutions in the pre and post phases of self-attention
- Doesn’t use positional embeddings
- Each transformer stage processes images at spatial resolution H/4 × W/4, H/8 × W/8, H/16 × W/16, and H/32 × W/32 respectively
- Similarly, the channels increase as the spatial dimensions decrease. This feels similar to deep CNNs
- Predictions at multiple spatial dimensions are upsampled and then merged together in the decoder
- An MLP combines all these predictions to provide a final prediction
- The final prediction is at spatial dimension H/4 × W/4 and not at H × W
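The decoder steps in the last few points (unify channels, upsample to H/4 × W/4, merge, and predict with an MLP) can be sketched as follows. This is a hypothetical, simplified reconstruction rather than the official SegFormer code: the encoder is replaced by random multi-scale feature maps, the per-stage channel counts (32, 64, 160, 256) and the 256-channel decoder width are assumptions loosely modeled on the smallest SegFormer variant, and 1×1 convolutions play the role of per-pixel MLP (Linear) layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AllMLPDecoderSketch(nn.Module):
    """Hypothetical SegFormer-style decoder: unify channels, upsample, concatenate, fuse, classify."""
    def __init__(self, in_channels=(32, 64, 160, 256), embed_dim=256, num_classes=3):
        super().__init__()
        # 1x1 convs act as per-pixel Linear layers mapping each stage to embed_dim channels
        self.proj = nn.ModuleList([nn.Conv2d(c, embed_dim, kernel_size=1) for c in in_channels])
        self.fuse = nn.Conv2d(len(in_channels) * embed_dim, embed_dim, kernel_size=1)
        self.classify = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, features):
        target_size = features[0].shape[-2:]  # H/4 x W/4, the finest stage
        upsampled = [
            F.interpolate(proj(f), size=target_size, mode="bilinear", align_corners=False)
            for proj, f in zip(self.proj, features)
        ]
        fused = F.relu(self.fuse(torch.cat(upsampled, dim=1)))
        return self.classify(fused)  # prediction at H/4 x W/4, not H x W

H = W = 128
strides = (4, 8, 16, 32)
channels = (32, 64, 160, 256)
# Random stand-ins for the four encoder stages' outputs
features = [torch.randn(1, c, H // s, W // s) for c, s in zip(channels, strides)]
logits = AllMLPDecoderSketch(in_channels=channels, num_classes=3)(features)
print(logits.shape)  # torch.Size([1, 3, 32, 32])
```

To compare against full-resolution masks, the 32 × 32 logits would still need to be upsampled to 128 × 128, for example with F.interpolate.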


In part 4 of this series, we were introduced to the transformer architecture and vision transformers in particular. We developed an intuitive understanding of how vision transformers work and the basic building blocks involved in their communication and computation phases. We saw the unique patch-based approach adopted by vision transformers for predicting segmentation masks and then combining the predictions together.

We reviewed an experiment that shows vision transformers in action, and were able to compare results with deep CNN approaches. While our vision transformer is not state of the art, it was able to achieve pretty decent results. We provided a glimpse into state-of-the-art approaches such as SegFormer.

It should be clear by now that transformers have a lot more moving parts and are more complex than deep CNN-based approaches. From a raw FLOPs point of view, transformers hold the promise of being more efficient. In transformers, most of the computation happens in nn.Linear layers and in the matrix multiplications inside attention, both of which are implemented using optimized matrix multiplication on most hardware. Due to this architectural simplicity, transformers hold the promise of being easier to optimize and speed up than deep CNN-based approaches.
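To put rough numbers on this, the hypothetical helper below estimates the FLOPs of one encoder block, counting a multiply-add as 2 FLOPs and ignoring softmax, LayerNorm, biases, and activations. Under these assumptions the Linear layers cost 24·T·E² FLOPs and the attention matrix multiplications 4·T²·E, so their ratio is 6E / T. With 64 patches and 768 embedding dimensions, nn.Linear dominates by a factor of 72; because the attention term grows with the square of the number of tokens, that gap shrinks as patches get smaller and sequences get longer.

```python
def encoder_block_flops(num_tokens: int, embed_size: int) -> dict:
    """Hypothetical helper: rough FLOPs for one encoder block (multiply-add = 2 FLOPs).

    Ignores softmax, LayerNorm, biases, and activations.
    """
    t, e = num_tokens, embed_size
    projections = 4 * (2 * t * e * e)  # Q, K, V, and output projections (nn.Linear)
    mlp = 2 * (2 * t * e * 4 * e)      # Linear(e, 4e) and Linear(4e, e)
    attention = 2 * (2 * t * t * e)    # Q @ K^T and weights @ V, summed over heads
    return {"linear": projections + mlp, "attention_matmuls": attention}

flops = encoder_block_flops(num_tokens=64, embed_size=768)
print(flops)  # {'linear': 905969664, 'attention_matmuls': 12582912}
print(flops["linear"] / flops["attention_matmuls"])  # 72.0
```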

Congratulations on making it this far! We’re glad you enjoyed reading this series on efficient image segmentation in PyTorch. If you have questions or comments, please feel free to leave them in the comments section.

Further reading

The details of the attention mechanism are out of scope for this article. Besides, there are numerous high-quality resources you can refer to for a detailed understanding of the attention mechanism. Here are some that we highly recommend:

- The Illustrated Transformer
- NanoGPT from scratch using PyTorch

We’ll provide links to articles that provide more details on vision transformers below.

- Implementing Vision Transformer (ViT) in PyTorch: This article details the implementation of a vision transformer for image classification in PyTorch. Notably, their implementation uses einops, which we avoid since this is an education-focused exercise (we do recommend learning and using einops for code readability, though). We instead use native PyTorch operators for permuting and rearranging tensor dimensions. Additionally, there are a few places where the author uses Conv2d instead of Linear layers. We wanted to build an implementation of vision transformers entirely without convolutional layers.
- Vision Transformer: AI Summer
- Implementing SegFormer in PyTorch

Efficient Image Segmentation Using PyTorch: Part 4 was originally published in Towards Data Science on Medium, where people are continuing the conversation by highlighting and responding to this story.

