In this series, we’ll train multiple models for class segmentation from scratch. There are many considerations to account for when building and training a model from scratch. Below, we will look at some of the key decisions that you need to make when doing so.
Choosing the right model for your task
There are many factors to consider when choosing the right deep-learning model for image segmentation. Some of the most important factors include:
- The type of image segmentation task: There are two main types of image segmentation tasks: class (semantic) segmentation and object (instance) segmentation. Since we’re focusing on the simpler class segmentation problem, we shall consider modeling our problem accordingly.
- Size and complexity of the dataset: The size and complexity of the dataset will affect the complexity of the model that we need to use. For example, if we are working with images with a small spatial dimension, we may use a simpler (or shallower) model, such as a fully convolutional network (FCN). If we are working with a large and complex dataset, we may use a more complex (or deeper) model such as a U-Net.
- Availability of pre-trained models: There are many pre-trained models available for image segmentation. These models can be used as a starting point for our own model or we can use them directly. However, if we use a pre-trained model, we may be constrained by the spatial dimensions of the input image to the model. In this series, we shall focus on training a model from scratch.
- Computational resources available: Deep learning models can be computationally expensive to train. If we have limited computational resources, we may need to choose simpler models or more efficient model architectures.
In this series, we are going to work with the Oxford IIIT Pet dataset, since it’s big enough to train a medium-sized model and to require the use of a GPU. We highly recommend creating an account on kaggle.com or using Google Colab’s free GPU to run the notebooks and code referenced in this series.
Here are some of the most popular deep learning model architectures for image segmentation:
- U-Net: The U-Net is a convolutional neural network that is commonly used for image segmentation tasks. It uses skip connections, which can help train the network faster and result in better overall accuracy. If you have to choose, U-Net is always an excellent default choice!
- FCN: The Fully Convolutional Network (FCN) is built entirely from convolutional layers, but it is not as deep as the U-Net. The lack of depth is mainly because accuracy drops at higher network depths. This makes it faster to train, but it may not be as accurate as the U-Net.
- SegNet: SegNet is a popular model architecture similar to U-Net that uses less activation memory than U-Net. We shall use SegNet in this series.
- Vision Transformer (ViT): Vision Transformers have recently gained popularity due to their simple structure and applicability of the attention mechanism to text, vision, and other domains. Vision Transformers can be more efficient (compared to CNNs) for both training and inference, but historically have needed more data to train compared to convolutional neural networks. We shall also use ViT in this series.
These are just a few of the many deep learning models that can be used for image segmentation. The best model for your specific task will depend on the factors mentioned earlier, on the specific task, and your own experiments.
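To make the idea of a skip connection concrete, here is a minimal sketch of a U-Net-style network with a single downsampling stage. It is not the architecture used later in this series; the class name `TinyUNet`, the layer widths, and the default of 3 output classes are all illustrative assumptions, and the input height and width are assumed to be even.

```python
import torch
import torch.nn as nn


class TinyUNet(nn.Module):
    """A one-level encoder-decoder with a single skip connection (illustrative only)."""

    def __init__(self, in_channels=3, num_classes=3, width=16):
        super().__init__()
        self.enc = nn.Sequential(nn.Conv2d(in_channels, width, 3, padding=1), nn.ReLU())
        self.down = nn.MaxPool2d(2)  # halves H and W
        self.bottleneck = nn.Sequential(nn.Conv2d(width, 2 * width, 3, padding=1), nn.ReLU())
        self.up = nn.ConvTranspose2d(2 * width, width, kernel_size=2, stride=2)  # doubles H and W
        self.dec = nn.Sequential(nn.Conv2d(2 * width, width, 3, padding=1), nn.ReLU())
        self.head = nn.Conv2d(width, num_classes, kernel_size=1)  # per-pixel class scores

    def forward(self, x):
        skip = self.enc(x)                      # (N, width, H, W)
        x = self.bottleneck(self.down(skip))    # (N, 2*width, H/2, W/2)
        x = self.up(x)                          # (N, width, H, W)
        x = torch.cat([x, skip], dim=1)         # skip connection: reuse encoder features
        return self.head(self.dec(x))           # (N, num_classes, H, W)


# One forward pass on a random batch of two 64x64 RGB images.
logits = TinyUNet()(torch.randn(2, 3, 64, 64))
```

The output has one score per class for every input pixel, which is the shape that the segmentation losses discussed in this article expect.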
Choosing the right loss function
The choice of loss function for image segmentation tasks is an important one, as it can have a significant impact on the performance of the model. There are many different loss functions available, each with its own advantages and disadvantages. The most popular loss functions for image segmentation are:
- Cross-entropy loss: Cross-entropy loss is a measure of the difference between the predicted probability distribution and the ground truth probability distribution.
- IoU loss: IoU loss measures the amount of overlap between the predicted mask and ground-truth mask per class. IoU loss penalizes cases where either the precision or recall would suffer. IoU as defined is not differentiable, so we need to slightly tweak it to use it as a loss function.
- Dice loss: Dice loss is also a measure of the overlap between the predicted mask and the ground truth mask.
- Tversky loss: Tversky loss is proposed as a robust loss function that can be used to handle imbalanced datasets.
- Focal loss: Focal loss is designed to focus on hard examples, which are examples that are difficult to classify. This can be helpful for improving the performance of the model on challenging datasets.
The best loss function for a particular task will depend on the specific requirements of the task. For example, if accuracy is more important, then IoU loss or Dice loss may be better choices. If the task is imbalanced, then Tversky loss or Focal loss may be good choices. The specific loss function used may impact the rate of convergence of your model when training it.
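The overlap-based losses above are closely related. As a rough sketch, in the common formulation of the Tversky index the false positives and false negatives get separate weights, and Dice and IoU fall out as special cases. The function names below are our own, the inputs are flat lists for a single class, and returning 1.0 when both masks are empty is a convention we assume here.

```python
def tversky_index(pred, target, alpha, beta):
    """Tversky index for one class.

    pred:   per-pixel probabilities (or hard 0/1 predictions) for the class
    target: per-pixel ground truth, 1 if the pixel belongs to the class else 0
    alpha:  weight on false positives
    beta:   weight on false negatives
    """
    tp = sum(p * t for p, t in zip(pred, target))
    fp = sum(p * (1 - t) for p, t in zip(pred, target))
    fn = sum((1 - p) * t for p, t in zip(pred, target))
    denom = tp + alpha * fp + beta * fn
    # Assumed convention: two empty masks agree perfectly.
    return tp / denom if denom > 0 else 1.0


def tversky_loss(pred, target, alpha, beta):
    return 1.0 - tversky_index(pred, target, alpha, beta)


pred = [1, 1, 0, 0]
target = [1, 0, 1, 0]
dice = tversky_index(pred, target, alpha=0.5, beta=0.5)  # Dice coefficient
iou = tversky_index(pred, target, alpha=1.0, beta=1.0)   # IoU (Jaccard index)
```

Choosing beta larger than alpha penalizes missed pixels more than spurious ones, which is one way such a loss can be tuned for imbalanced classes.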
The loss function is a hyperparameter of your model, and using a different loss based on the results we see can allow us to reduce the loss faster and improve the model’s accuracy.
Default: In this series, we shall use cross entropy loss, since it’s always a good default to choose when the results are not known.
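As a minimal sketch of what this default looks like in PyTorch, `nn.CrossEntropyLoss` accepts per-pixel class scores of shape (N, C, H, W) and integer class labels of shape (N, H, W). The batch size, image size, and the choice of 3 classes below are arbitrary values chosen for illustration.

```python
import math

import torch
import torch.nn as nn

num_classes = 3  # illustrative assumption

# Raw, unnormalized scores from a segmentation model: (N, C, H, W).
logits = torch.randn(4, num_classes, 32, 32)
# Ground-truth class id for every pixel: (N, H, W), dtype int64.
targets = torch.randint(0, num_classes, (4, 32, 32))

criterion = nn.CrossEntropyLoss()  # applies log-softmax over the class dimension
loss = criterion(logits, targets)  # scalar, averaged over all pixels

# Sanity check: a model that predicts all classes as equally likely scores log(C).
uniform_loss = criterion(torch.zeros(4, num_classes, 32, 32), targets)
expected_uniform_loss = math.log(num_classes)
```

Comparing the loss early in training against log(C) is a quick way to check that the model is doing better than a uniform guess.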
You can use the following resources to learn more about loss functions.
Let’s take a detailed look at the IoU Loss we define below as a robust alternative to the Cross Entropy Loss for segmentation tasks.
The Custom IoU Loss
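IoU is defined as intersection over union. For image segmentation tasks, we compute it per class: the intersection is the set of pixels assigned to that class by both the model and the ground truth segmentation mask, the union is the set of pixels assigned to that class by either one, and the IoU is the size of the intersection divided by the size of the union.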
For example, if we have 2 classes:
- Person
- Background
Then we can determine which pixels were classified as a person, and compare that with the ground truth pixels for a person, and compute the IoU for the person class. Similarly, we can compute the IoU for the background class.
Once we have these class-specific IoU metrics, we can choose to average them unweighted or weigh them before averaging them to account for any sort of class imbalance as we saw in the example earlier.
The IoU metric as defined requires us to compute hard labels for each pixel. This requires the use of the argmax() function, which isn’t differentiable, so we can’t use this metric as a loss function. Hence, instead of using hard labels, we apply softmax() and use the predicted probabilities as soft labels to compute the IoU metric. This results in a differentiable metric that we can then compute the loss from. For this reason, the IoU metric is sometimes called the soft IoU metric when used in the context of a loss function.
If we have a metric (M) that takes values between 0.0 and 1.0, we can compute the loss (L) as:
L = 1 - M
However, here’s another trick one can use to convert a metric into a loss if your metric has the value between 0.0 and 1.0. Compute:
L = -log(M)
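I.e. compute the negative log of the metric. This is meaningfully different from the previous formulation, and you can read about it here and here. The gradient of -log(M) with respect to M is -1/M, so the loss pushes the model harder when M is close to 0, whereas 1 - M has a constant gradient of -1 regardless of how poor the overlap is. In practice, this tends to result in better learning for your model.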
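Putting the pieces together, a soft IoU loss could be sketched as follows. This is one possible implementation under stated assumptions rather than the exact loss used in this series: the function name `soft_iou_loss`, the `use_log` switch, the `eps` smoothing term, and the unweighted average over classes are all our own choices.

```python
import torch
import torch.nn.functional as F


def soft_iou_loss(logits, targets, use_log=False, eps=1e-6):
    """Soft IoU loss.

    logits:  (N, C, H, W) raw class scores from the model
    targets: (N, H, W) integer (int64) class ids
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)                        # soft labels, (N, C, H, W)
    one_hot = F.one_hot(targets, num_classes)               # (N, H, W, C)
    one_hot = one_hot.permute(0, 3, 1, 2).to(probs.dtype)   # (N, C, H, W)

    # Per-class soft intersection and union, summed over the batch and all pixels.
    intersection = (probs * one_hot).sum(dim=(0, 2, 3))
    union = (probs + one_hot - probs * one_hot).sum(dim=(0, 2, 3))

    iou = (intersection + eps) / (union + eps)              # per-class soft IoU in (0, 1]
    mean_iou = iou.mean()                                   # unweighted average over classes

    # Convert the metric M into a loss: either -log(M) or 1 - M.
    return -torch.log(mean_iou) if use_log else 1.0 - mean_iou
```

Because every step uses differentiable tensor operations, calling `backward()` on the returned value propagates gradients to the model’s logits. A weighted average over classes could replace `iou.mean()` to account for class imbalance.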
Using IoU as our loss also brings the loss function closer to capturing what we really care about. There are pros and cons of using an evaluation metric as the loss function. If you’re interested in exploring this space more, you can start with this discussion on stackexchange.
To train your model efficiently and effectively for good accuracy, you need to be mindful of the amount and kind of training data you use. The choice of training data will significantly impact the final model’s accuracy, so if there’s one thing you take away from this article series, this should be it!
Typically, we’d split our data into 3 parts with the parts being roughly in the proportions mentioned below.
- Training (80%)
- Validation (10%)
- Test (10%)
You’d train your model on the training set, evaluate accuracy on the validation set, and repeat the process till you’re happy with the reported metrics. Only then would you evaluate the model on the test set, and then report the numbers. This is done to prevent any sort of bias from creeping into your model’s architecture and hyperparameters used during training and evaluation. In general, the more you tweak your setup based on the outcomes you see with the test data, the less reliable your results will get. Hence, we must limit our decision making to only the results we see on the training and validation datasets.
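A split along these lines can be sketched with nothing but the standard library. The helper name `split_indices` and the fixed seed are assumptions made for illustration; in a real project you would map the returned indices back to your image and mask files.

```python
import random


def split_indices(num_samples, train_frac=0.8, val_frac=0.1, seed=0):
    """Shuffle sample indices and split them into train, validation, and test sets."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # fixed seed keeps the split reproducible

    n_train = int(num_samples * train_frac)
    n_val = int(num_samples * val_frac)

    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]      # whatever remains, roughly 10%
    return train, val, test


train_idx, val_idx, test_idx = split_indices(1000)
```

Fixing the seed matters: if the split changes between runs, test images can leak into training and the reported numbers stop being trustworthy.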
In this series, we shall not use a separate test dataset. Instead, we’ll use the dataset’s test split as our validation dataset, and apply data augmentation to it so that we’re always validating our models on data that’s slightly different. This helps prevent us from overfitting our decisions to the validation dataset. This is a bit of a hack, and we’re doing it just for expediency. For production model development, you should try to stick with the standard recipe mentioned above.
The dataset we’re going to use in this series has 3680 images in the training set. While this may seem like a large number of images, we want to make sure that our model doesn’t overfit on these images since we’ll be training the model for numerous epochs.
In a single training epoch, we train the model on the entire training dataset, and we’d typically train models in production for 60 or more epochs. In this series, we shall train the model only for 20 epochs for faster iteration times. To prevent overfitting, we’ll employ a technique called data augmentation that is used to generate new input data from existing input data. The basic idea behind data augmentation for image inputs is that if you change the image slightly, it feels like a new image to the model, but one can reason about whether the expected outputs would be the same. Here are some examples of data augmentations that we’ll apply in this series.
While we’re going to use the Torchvision library for applying the data augmentations above, we’d encourage you to evaluate the Albumentations data augmentation library for vision tasks as well. Both libraries have a rich set of transformations available for use with image data. We personally continue to use Torchvision simply because it’s what we started with. Albumentations supports richer primitives for data augmentation that can make changes to the input image as well as the ground truth labels or masks at the same time. For example, if you were to resize or flip an image, you’d need to make the same change to the ground truth segmentation mask. Albumentations can do this for you out of the box.
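If your augmentation pipeline does not handle the mask for you, the key point is to make the random decision once and apply it to both tensors. Here is a minimal sketch using plain `torch.flip`; the helper name `random_hflip_pair` is hypothetical, and we assume the image is a (C, H, W) tensor and the mask an (H, W) tensor.

```python
import random

import torch


def random_hflip_pair(image, mask, p=0.5, rng=random):
    """Horizontally flip an image and its segmentation mask together with probability p."""
    if rng.random() < p:                       # one random draw shared by both tensors
        image = torch.flip(image, dims=[-1])   # flip along the width axis
        mask = torch.flip(mask, dims=[-1])     # same flip keeps pixels and labels aligned
    return image, mask


image = torch.rand(3, 4, 6)                    # (C, H, W)
mask = torch.randint(0, 3, (4, 6))             # (H, W) class ids
aug_image, aug_mask = random_hflip_pair(image, mask)
```

The same pattern applies to any geometric transform such as resizing or cropping. Pixel-level transforms such as brightness changes should be applied to the image only, since they do not move any pixels.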
Broadly speaking, both libraries support transformations that either change the image at the pixel level or change its spatial dimensions. Torchvision calls the pixel-level transforms color transforms and the spatial transforms geometric transforms.
Below, we shall see some examples of both pixel-level as well as geometric transforms applied by the Torchvision and Albumentations libraries.
Evaluating your model’s performance
When evaluating your model’s performance, you’d want to know how it performs on a metric that is representative of the quality of the model’s performance on real world data. For example, for the image segmentation task, we’d want to know how accurately a model is able to predict the correct class for a pixel. Hence, we say that Pixel Accuracy is the validation metric for this model.
You could use your evaluation metric as the loss function (why not optimize what you really care about!) except that this may not always be possible.
To read more about various accuracy metrics applicable to image segmentation tasks, please see:
The downside of using pixel accuracy as a performance metric
While the accuracy metric may be a good default choice to measure the performance of image segmentation tasks, it does have its own drawbacks, which may be significant based on your specific situation.
For example, consider an image segmentation task to identify a person’s eyes in a picture, and mark those pixels accordingly. The model will hence classify each pixel as one of:
- Eye
- Background
Assume that there’s just 1 person in each image, and 98% of the pixels don’t correspond to an eye. In this case, the model can simply learn to predict every pixel as being a background pixel and achieve 98% pixel accuracy on the segmentation task. Wow!
In such cases, using the IoU or Dice metric may be a much better idea, since IoU would capture how much of the prediction was correct, and wouldn’t necessarily be biased by the region that each class or category occupies in the original image. You could even consider using the IoU or Dice coefficient per class as a metric. This may better capture the performance of your model for the task at hand.
If you do rely on pixel accuracy, also reporting the precision and recall for the object whose segmentation mask we’re computing (eyes in the example above) can capture the details that accuracy alone hides.
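The eye example can be worked through numerically. The sketch below assumes a hypothetical 100 x 100 image (10,000 pixels, flattened into a list) in which 2% of the pixels are eye pixels, and a model that predicts background everywhere. The helper names are our own.

```python
EYE, BACKGROUND = 1, 0


def pixel_accuracy(pred, target):
    """Fraction of pixels whose predicted class matches the ground truth."""
    return sum(p == t for p, t in zip(pred, target)) / len(target)


def class_iou(pred, target, cls):
    """Hard IoU for a single class; NaN if the class appears in neither mask."""
    intersection = sum(p == cls and t == cls for p, t in zip(pred, target))
    union = sum(p == cls or t == cls for p, t in zip(pred, target))
    return intersection / union if union else float("nan")


def class_recall(pred, target, cls):
    """Fraction of ground-truth pixels of this class that the model found."""
    found = sum(p == cls and t == cls for p, t in zip(pred, target))
    actual = sum(t == cls for t in target)
    return found / actual if actual else float("nan")


target = [EYE] * 200 + [BACKGROUND] * 9800   # 2% eye pixels
pred = [BACKGROUND] * 10000                  # model predicts background everywhere

accuracy = pixel_accuracy(pred, target)            # 0.98
iou_eye = class_iou(pred, target, EYE)             # 0.0
iou_background = class_iou(pred, target, BACKGROUND)  # 0.98
mean_iou = (iou_eye + iou_background) / 2          # 0.49
recall_eye = class_recall(pred, target, EYE)       # 0.0
```

Pixel accuracy reports 98% for a model that never finds a single eye pixel, while the per-class IoU and the recall for the eye class both expose the failure immediately.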
Now that we have covered a large part of the theoretical underpinnings of image segmentation, let’s take a detour into considerations related to inference and deployment of image segmentation for real-world workloads.
Model size and inference latency
Last but not least, we’d want to ensure that our model has a reasonable number of parameters but not too many, since we want a small and efficient model. We shall look into this aspect in greater detail in a future post related to reducing model size using efficient model architectures.
As far as inference latency is concerned, what matters is the number of mathematical operations (mult-adds) our model executes. Both the model size and mult-adds can be displayed using the torchinfo package. While mult-adds is a great proxy for determining the model’s latency, there can be a large variation in latency across various backends. The only real way to determine the performance of your model on a specific backend or device is to profile and benchmark it on that specific device with the set of inputs you expect to see in production settings.
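import torch.nn as nn
from torchinfo import summary
model = nn.Linear(1000, 500)
# Report per-layer shapes, parameter counts, and mult-adds for one 1000-dim input.
summary(model, input_size=(1, 1000),
        col_names=["kernel_size", "output_size", "num_params", "mult_adds"])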
Layer (type:depth-idx) Kernel Shape Output Shape Param # Mult-Adds
Linear -- [1, 500] 500,500 500,500
Total params: 500,500
Trainable params: 500,500
Non-trainable params: 0
Total mult-adds (M): 0.50
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 2.00
Estimated Total Size (MB): 2.01
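Since mult-adds are only a proxy, it helps to time the model directly. Below is a rough CPU-only sketch that also recomputes the parameter count by hand; the helper name `mean_latency_ms`, the warm-up count, and the iteration count are arbitrary choices for illustration, and the timing it reports will vary from machine to machine.

```python
import time

import torch
import torch.nn as nn

model = nn.Linear(1000, 500).eval()

# 1000 * 500 weights + 500 biases = 500,500 parameters.
num_params = sum(p.numel() for p in model.parameters())
# float32 parameters take 4 bytes each: about 2.00 MB.
param_size_mb = num_params * 4 / 1e6


def mean_latency_ms(model, example, warmup=5, iters=50):
    """Average forward-pass time in milliseconds on the current device (CPU here)."""
    with torch.no_grad():
        for _ in range(warmup):          # warm-up runs are excluded from the timing
            model(example)
        start = time.perf_counter()
        for _ in range(iters):
            model(example)
        elapsed = time.perf_counter() - start
    return 1000.0 * elapsed / iters


latency_ms = mean_latency_ms(model, torch.randn(1, 1000))
```

On a GPU, kernels run asynchronously, so you would need to call `torch.cuda.synchronize()` before reading the clock; otherwise the measured time mostly reflects how long it took to queue the work.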