Efficient Image Segmentation Using PyTorch: Part 4 | by Dhruv Matani | Jun, 2023


The vision transformer was first introduced by the paper titled “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale”. The paper discusses how the authors apply the vanilla transformer architecture to the problem of image classification. This is done by splitting the image into patches of size 16×16, and treating each patch as an input token to the model. The transformer encoder model is fed these input tokens, and is asked to predict a class for the input image.

Figure 4: Source: Transformers for image recognition at scale.

In our case, we are interested in image segmentation. We can consider it a pixel-level classification task, since we intend to predict a target class for every pixel.

We make a small but important change to the vanilla vision transformer: we replace the MLP head used for classification with an MLP head for pixel-level classification. The output uses a single linear layer that is shared across all patches; this shared layer predicts a segmentation mask for every patch that was fed to the model as input.
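As a rough sketch (the variable names and class count here are hypothetical, not the article’s code), such a shared head amounts to a single nn.Linear whose output size is patch_size x patch_size x num_classes:

import torch.nn as nn

# Hypothetical sketch of the shared output head: one linear layer applied
# identically to every patch token, predicting a 16x16 mask with
# num_classes channels for that patch.
embed_size, patch_size, num_classes = 768, 16, 10
shared_head = nn.Linear(embed_size, patch_size * patch_size * num_classes)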

In the case of the vision transformer, a patch of size 16×16 is considered to be equivalent to a single input token at a specific time step.

Figure 5: The end to end working of the vision transformer for image segmentation. Image generated using this notebook. Source: Author(s).

Building an intuition for tensor dimensions in vision transformers

When working with deep CNNs, the tensor layout we used for the most part was (N, C, H, W), where the letters stand for the following:

  • N: Batch size
  • C: Number of channels
  • H: Height
  • W: Width

You can see that this format is geared toward 2D image processing, since its dimensions describe properties that are very specific to images.

With transformers, on the other hand, things become a lot more generic and domain agnostic. What we’ll see below applies to vision, text, audio, or any other problem where the input data can be represented as a sequence. It is worth noting that there is little vision-specific bias in the representation of tensors as they flow through our vision transformer.

When working with transformers and attention in general, we expect the tensors to have the following shape: (B, T, C), where the letters stand for the following:

  • B: Batch size (same as that for CNNs)
  • T: Time dimension or sequence length. This dimension is also sometimes called L. In the case of vision transformers, each image patch corresponds to this dimension. If we have 16 image patches, then the value of the T dimension will be 16
  • C: The channel or embedding size dimension. This dimension is also sometimes called E. When processing images, each patch of size 3x16x16 (Channels, Height, Width) is mapped via a patch embedding layer to an embedding of size C. We’ll see how this is done later; a short sketch after this list shows how the two layouts relate.
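As a short, illustrative sketch (not the notebook’s code) of how the two layouts relate, a (1, 3, 128, 128) image in (N, C, H, W) layout can be reshaped into 64 flattened 16×16 patches in (B, T, C) layout:

import torch

# (N, C, H, W): a batch of one 3-channel, 128x128 image.
x = torch.randn(1, 3, 128, 128)
N, C, H, W = x.shape
P = 16                                   # patch size

# Split H and W into 8x8 blocks of 16x16 pixels, then flatten each block.
# Result is (B, T, C) = (1, 64, 768), where 768 = 3 * 16 * 16.
patches = (
    x.reshape(N, C, H // P, P, W // P, P)
     .permute(0, 2, 4, 1, 3, 5)
     .reshape(N, (H // P) * (W // P), C * P * P)
)
print(patches.shape)                     # torch.Size([1, 64, 768])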

Let’s dive into how the input image tensor gets mutated and processed along its way to predicting the segmentation mask.

The journey of a tensor in a vision transformer

In deep CNNs, the journey of a tensor looks something like this (in a UNet, SegNet, or other CNN based architecture).

The input tensor is typically of shape (1, 3, 128, 128). This tensor goes through a series of convolution and max-pooling operations in which its spatial dimensions are reduced and its channel dimension is increased, typically by a factor of 2 at each stage. This is called the feature encoder. After this, we do the reverse operation, increasing the spatial dimensions and reducing the channel dimension. This is called the feature decoder. After the decoding process, we get a tensor of shape (1, 64, 128, 128). This is then projected into the number of output channels C that we desire, i.e. (1, C, 128, 128), using a 1×1 pointwise convolution without bias.
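Here is a minimal sketch of that shape progression; the layer counts and channel sizes are illustrative, not an actual UNet or SegNet implementation.

import torch
import torch.nn as nn

# Illustrative encoder/decoder shape progression; not an actual UNet or SegNet.
x = torch.randn(1, 3, 128, 128)

# Encoder: spatial dimensions shrink, channel count grows.
enc1 = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
enc2 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
f = enc2(enc1(x))                        # (1, 128, 32, 32)

# Decoder: spatial dimensions grow back, channel count shrinks.
dec1 = nn.Sequential(nn.ConvTranspose2d(128, 64, 2, stride=2), nn.ReLU())
dec2 = nn.Sequential(nn.ConvTranspose2d(64, 64, 2, stride=2), nn.ReLU())
f = dec2(dec1(f))                        # (1, 64, 128, 128)

# Final 1x1 pointwise convolution (no bias) projects to C output channels.
num_classes = 3
head = nn.Conv2d(64, num_classes, kernel_size=1, bias=False)
out = head(f)                            # (1, 3, 128, 128)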

Figure 6: Typical progression of tensor shapes through a deep CNN used for image segmentation. Source: Author(s).

With vision transformers, the flow is much more complex. Let’s take a look at the figure below and then try to understand how the tensor’s shape transforms at every step along the way.

Figure 7: Typical progression of tensor shapes through a vision transformer for image segmentation. Source: Author(s).

Let’s look at each step in more detail and see how it updates the shape of the tensor flowing through the vision transformer. To understand this better, let’s take concrete values for our tensor dimensions; a short code sketch after the list walks through the same shape flow.

  1. Batch Normalization: The input and output tensors have shape (1, 3, 128, 128). The shape is unchanged, but the values are normalized to zero mean and unit variance.
  2. Image to patches: The input tensor of shape (1, 3, 128, 128) is broken into 64 non-overlapping 16×16 patches, each of which is flattened into a 768-dimensional vector (3x16x16 = 768). The output tensor has shape (1, 64, 768).
  3. Patch embedding: The patch embedding layer maps the 768 input channels to 512 embedding channels (for this example). The output tensor is of shape (1, 64, 512). The patch embedding layer is basically just an nn.Linear layer in PyTorch.
  4. Position embedding: The position embedding layer doesn’t have an input tensor, but effectively contributes a learnable parameter (a trainable tensor in PyTorch) of the same shape as the patch embedding. This is of shape (1, 64, 512).
  5. Add: The patch and position embeddings are added together elementwise to produce the input to our vision transformer encoder. This tensor is of shape (1, 64, 512). You’ll notice that the main workhorse of the vision transformer, i.e. the encoder, basically leaves this tensor shape unchanged.
  6. Transformer encoder: The input tensor of shape (1, 64, 512) flows through multiple transformer encoder blocks, each of which has multiple attention heads (communication) followed by an MLP layer (computation). The tensor shape remains unchanged as (1, 64, 512).
  7. Linear output projection: If we assume that we want to segment each image into 10 classes, then we will need each patch of size 16×16 to have 10 channels. The nn.Linear layer for output projection will now convert the 512 embedding channels to 16x16x10 = 2560 output channels, and this tensor will look like (1, 64, 2560). In the diagram above, C’ = 10. Ideally, this would be a multi-layer perceptron, since MLPs are universal function approximators, but we use a single linear layer since this is an educational exercise.
  8. Patch to image: This layer converts the 64 patches encoded as a (1, 64, 2560) tensor back into something that looks like a segmentation mask. This can be 10 single channel images, or in this case a single 10 channel image, with each channel being the segmentation mask for one of the 10 classes. The output tensor is of shape (1, 10, 128, 128).
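Here is the code sketch referred to above, walking through the same eight steps with built-in PyTorch modules. It is illustrative only: the variable names, the use of nn.TransformerEncoder as a stand-in for the notebook’s custom encoder blocks, and the channel layout assumed when folding patches back into an image are all assumptions rather than the notebook’s actual code.

import torch
import torch.nn as nn

B, C_in, H, W = 1, 3, 128, 128          # input batch: one 3-channel 128x128 image
P = 16                                  # patch size
T = (H // P) * (W // P)                 # 64 patches per image
E = 512                                 # embedding size used in this example
NUM_CLASSES = 10                        # C' in the diagram above

x = torch.randn(B, C_in, H, W)

# 1. Batch normalization: shape unchanged, values normalized per channel.
x = nn.BatchNorm2d(C_in)(x)                                   # (1, 3, 128, 128)

# 2. Image to patches: flatten each 16x16 patch into a 768-dim vector.
patches = (
    x.reshape(B, C_in, H // P, P, W // P, P)
     .permute(0, 2, 4, 1, 3, 5)
     .reshape(B, T, C_in * P * P)
)                                                             # (1, 64, 768)

# 3. Patch embedding: a single nn.Linear mapping 768 -> 512.
patch_embed = nn.Linear(C_in * P * P, E)
tokens = patch_embed(patches)                                 # (1, 64, 512)

# 4 + 5. Position embedding: a learnable parameter added elementwise.
pos_embed = nn.Parameter(torch.randn(1, T, E))
tokens = tokens + pos_embed                                   # (1, 64, 512)

# 6. Transformer encoder: the shape is left unchanged by the encoder blocks.
encoder_layer = nn.TransformerEncoderLayer(
    d_model=E, nhead=8, dropout=0.2, batch_first=True
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)
tokens = encoder(tokens)                                      # (1, 64, 512)

# 7. Linear output projection: 512 -> 16 * 16 * 10 = 2560 channels per patch.
out_proj = nn.Linear(E, P * P * NUM_CLASSES)
logits = out_proj(tokens)                                     # (1, 64, 2560)

# 8. Patch to image: fold per-patch predictions back into a segmentation mask.
mask = (
    logits.reshape(B, H // P, W // P, NUM_CLASSES, P, P)
          .permute(0, 3, 1, 4, 2, 5)
          .reshape(B, NUM_CLASSES, H, W)
)                                                             # (1, 10, 128, 128)
print(mask.shape)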

That’s it — we’ve successfully segmented an input image using a vision transformer! Next, let’s take a look at an experiment along with some results.

Vision transformers in action

This notebook contains all the code for this section.

As far as the code and class structure are concerned, it closely mimics the block diagram above. Most of the concepts mentioned above have a 1:1 correspondence with class names in this notebook.

Some concepts related to the attention layers are critical hyperparameters for our model. We didn’t go into the details of multi-head attention earlier, since it’s out of scope for the purposes of this article. We highly recommend reading the reference material mentioned above before proceeding if you don’t have a basic understanding of the attention mechanism in transformers.

We used the following model parameters for the vision transformer for segmentation.

  1. 768 embedding dimensions for the PatchEmbedding layer
  2. 12 Transformer encoder blocks
  3. 8 attention heads in each transformer encoder block
  4. 20% dropout in multi-head attention and MLP

This configuration can be seen in the VisionTransformerArgs Python dataclass.

from dataclasses import dataclass

@dataclass
class VisionTransformerArgs:
    """Arguments to the VisionTransformerForSegmentation."""
    image_size: int = 128
    patch_size: int = 16
    in_channels: int = 3
    out_channels: int = 3
    embed_size: int = 768
    num_blocks: int = 12
    num_heads: int = 8
    dropout: float = 0.2
# end class
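The defaults above match the configuration listed earlier; individual fields can be overridden when experimenting, for example:

# For example, a smaller variant with a 512-dim embedding and 6 encoder blocks.
vit_args = VisionTransformerArgs(embed_size=512, num_blocks=6)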

A configuration similar to the one used before was applied during model training and validation. It is specified below, followed by a sketch of the corresponding optimizer, scheduler, and loss setup.

  1. The random horizontal flip and colour jitter data augmentations are applied to the training set to prevent overfitting
  2. The images are resized to 128×128 pixels in a non-aspect preserving resize operation
  3. No input normalization is applied to the images — instead a batch normalization layer is used as the first layer of the model
  4. The model is trained for 50 epochs using the Adam optimizer with a LR of 0.0004 and a StepLR scheduler that decays the learning rate by 0.8x every 12 epochs
  5. The cross-entropy loss function is used to classify a pixel as belonging to a pet, the background, or a pet border

The model has 86.28M parameters and achieved a validation accuracy of 85.89% after 50 training epochs. This is less than the 88.28% accuracy achieved by the deep CNN model after 20 training epochs. The gap could be due to a few factors that need to be validated experimentally.

  1. The last output projection layer is a single nn.Linear and not a multi-layer perceptron
  2. The 16×16 patch size is too large to capture fine-grained detail
  3. Not enough training epochs
  4. Not enough training data — it’s known that transformer models need a lot more data to train effectively compared to deep CNN models
  5. The learning rate is too low

We plotted a gif showing how the model is learning to predict the segmentation masks for 21 images in the validation set.

Figure 8: A gif showing the progression of segmentation masks predicted by the vision transformer for image segmentation model. Source: Author(s).

We notice something interesting in the early training epochs: the predicted segmentation masks have strange blocking artifacts. The most likely explanation we could come up with is that, because we break the image down into 16×16 patches, after only a few training epochs the model has learned nothing more useful than very coarse-grained information about whether a given 16×16 patch is mostly covered by a pet or by background pixels.

Figure 9: The blocking artifacts seen in the predicted segmentation masks when using the vision transformer for image segmentation. Source: Author(s).

Now that we have seen a basic vision transformer in action, let’s turn our attention to a state of the art vision transformer for segmentation tasks.

SegFormer: Semantic segmentation with transformers

The SegFormer architecture was proposed in this paper in 2021. The transformer we built above can be seen as a much-simplified version of the SegFormer architecture.

Figure 10: The SegFormer architecture. Source: SegFormer paper (2021).

Most notably, the SegFormer:

  1. Generates 4 sets of images with patches of size 4×4, 8×8, 16×16, and 32×32 instead of a single patched image with patches of size 16×16
  2. Uses 4 transformer encoder blocks instead of just 1. This feels like a model ensemble
  3. Uses convolutions in the pre and post phases of self-attention
  4. Doesn’t use positional embeddings
  5. Each transformer block processes images at spatial resolutions H/4 x W/4, H/8 x W/8, H/16 x W/16, and H/32 x W/32
  6. Similarly, the channels increase when the spatial dimensions reduce. This feels similar to deep CNNs
  7. Predictions at multiple spatial dimensions are upsampled and then merged together in the decoder
  8. An MLP combines all these predictions to provide a final prediction
  9. The final prediction is at spatial dimension H/4 x W/4 and not at H x W; a rough sketch of this multi-scale fusion follows below
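Here is the rough sketch mentioned above for points 7 through 9; the channel sizes and class count are made up, and this is a generic illustration rather than the actual SegFormer decoder:

import torch
import torch.nn as nn
import torch.nn.functional as F

H, W = 512, 512
# Made-up multi-scale features at H/4, H/8, H/16, and H/32 resolutions.
feats = [
    torch.randn(1, 64, H // 4, W // 4),
    torch.randn(1, 128, H // 8, W // 8),
    torch.randn(1, 320, H // 16, W // 16),
    torch.randn(1, 512, H // 32, W // 32),
]

# Upsample every feature map to H/4 x W/4 and concatenate along channels.
upsampled = [
    F.interpolate(f, size=(H // 4, W // 4), mode="bilinear", align_corners=False)
    for f in feats
]
fused = torch.cat(upsampled, dim=1)              # (1, 64 + 128 + 320 + 512, H/4, W/4)

# An MLP (here implemented with 1x1 convolutions) merges the scales and
# predicts per-pixel class logits at H/4 x W/4 resolution.
num_classes = 19                                 # made-up class count
mlp_head = nn.Sequential(
    nn.Conv2d(fused.shape[1], 256, kernel_size=1),
    nn.ReLU(),
    nn.Conv2d(256, num_classes, kernel_size=1),
)
logits = mlp_head(fused)                         # (1, num_classes, H/4, W/4)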


