Controllable Medical Image Generation with ControlNets


ControlNet Architecture

The ControlNet architecture comprises two main components: a trainable copy of the encoder from the U-Net model, including the middle blocks, and a pre-trained “locked” copy of the diffusion model. The locked copy preserves the generative capability, while the trainable copy is trained on specific image-to-image datasets to learn conditional control. The two components are connected through “zero convolution” layers: 1×1 convolution layers whose weights and biases are initialized to zero. During training, these weights gradually move from zero to optimized values, which ensures that, in the initial training steps, the outputs of the trainable and locked copies remain consistent with what they would be if the ControlNet were absent. In other words, when a ControlNet is applied to a set of neural network blocks before any optimization, it introduces no additional influence or noise into the deep neural features.
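
To make the zero convolution concrete, here is a minimal sketch in PyTorch (an illustrative helper, not the library’s implementation). Because the weights and biases start at zero, the layer maps any input to zeros, so the ControlNet branch contributes nothing until training updates it.

import torch
import torch.nn as nn

# Minimal sketch of a zero convolution: a 1x1 convolution whose weights and
# biases are initialized to zero (illustrative helper, not the library code).
def zero_conv(channels):
    conv = nn.Conv2d(channels, channels, kernel_size=1)
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)
    return conv

# At initialization, the layer outputs zeros, so adding its output to the
# locked copy's features leaves them unchanged.
x = torch.randn(1, 8, 32, 32)
assert torch.allclose(zero_conv(8)(x), torch.zeros_like(x))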

By integrating these two components, the ControlNet enables us to govern the behaviour of each level in the Diffusion Model’s U-Net.

In our example, we instantiate the ControlNet in this script, using the equivalent of the following snippet.

import torch
from generative.networks.nets import ControlNet, DiffusionModelUNet

# Load the pre-trained diffusion model
diffusion_model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_res_blocks=2,
    num_channels=[256, 512, 768],
    attention_levels=[False, True, True],
    with_conditioning=True,
    cross_attention_dim=1024,
    num_head_channels=[0, 512, 768],
)
diffusion_model.load_state_dict(torch.load("diffusion_model.pt"))

# Create the ControlNet with a matching architecture
controlnet = ControlNet(
    spatial_dims=2,
    in_channels=3,
    num_res_blocks=2,
    num_channels=[256, 512, 768],
    attention_levels=[False, True, True],
    with_conditioning=True,
    cross_attention_dim=1024,
    num_head_channels=[0, 512, 768],
    conditioning_embedding_in_channels=1,
    conditioning_embedding_num_channels=[64, 128, 128, 256],
)

# Initialize the ControlNet as a trainable copy of the diffusion model
controlnet.load_state_dict(diffusion_model.state_dict(), strict=False)

# Lock the weights of the diffusion model
for p in diffusion_model.parameters():
    p.requires_grad = False

Since we are using a Latent Diffusion Model, the ControlNet needs to convert the image-based condition into the same spatial size as the latent space, so that it matches the feature maps of the U-Net’s first convolution. For that, we use a small convolutional network trained jointly with the full model. In our case, it has three downsampling levels (similar to the AutoencoderKL), defined by “conditioning_embedding_num_channels=[64, 128, 128, 256]”. Since our conditioning image is a FLAIR image with a single channel, we also need to specify its number of input channels with “conditioning_embedding_in_channels=1”.
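
As a rough sketch of this idea, the snippet below builds a stack of strided convolutions that maps a one-channel image down to the latent resolution. The channel sizes follow the snippet above, but the exact layer layout (kernel sizes, activations, strides) is an assumption for illustration; the library’s implementation differs in detail, and in the ControlNet design the final projection is itself a zero convolution.

import torch
import torch.nn as nn

# Sketch of the conditioning embedder (assumption: stride-2 convolutions
# between the channel sizes above; the library's exact layers differ).
embedder = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3, padding=1),               # in_channels=1 (FLAIR)
    nn.SiLU(),
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),   # downsample 1
    nn.SiLU(),
    nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),  # downsample 2
    nn.SiLU(),
    nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # downsample 3
)

flair = torch.randn(1, 1, 224, 224)  # one-channel FLAIR image (size assumed)
print(embedder(flair).shape)         # torch.Size([1, 256, 28, 28])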

After initialising our networks, we train them similarly to a diffusion model. In the following snippet (and in this part of the code), we first pass our conditioning FLAIR image to the trainable network and obtain the outputs from its skip connections. These values are then fed into the diffusion model when computing the predicted noise. Internally, the diffusion model sums the skip connections from the ControlNet with its own before feeding the decoder part (code).


# Training loop
...

images = batch["t1w"].to(device)
cond = batch["flair"].to(device)

...

noise = torch.randn_like(latent_representation).to(device)
noisy_z = scheduler.add_noise(
    original_samples=latent_representation, noise=noise, timesteps=timesteps
)

# Compute the trainable part
down_block_res_samples, mid_block_res_sample = controlnet(
    x=noisy_z, timesteps=timesteps, context=prompt_embeds, controlnet_cond=cond
)

# Use the ControlNet outputs to control the diffusion model's behaviour
noise_pred = diffusion_model(
    x=noisy_z,
    timesteps=timesteps,
    context=prompt_embeds,
    down_block_additional_residuals=down_block_res_samples,
    mid_block_additional_residual=mid_block_res_sample,
)

# Then compute the diffusion model loss as usual
...
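
For completeness, the same pair of calls is used at sampling time: at every timestep, we compute the ControlNet residuals from the conditioning image and hand them to the locked diffusion model before stepping the scheduler. Below is a hedged sketch reusing the names from the snippets above (scheduler, prompt_embeds, cond, device); the latent shape and the scheduler’s tuple-style return value are assumptions.

# Sampling sketch (assumptions: scheduler, prompt_embeds, cond and device
# come from the snippets above; the latent shape is illustrative).
z = torch.randn((1, 3, 28, 28), device=device)
for t in scheduler.timesteps:
    timesteps = torch.full((z.shape[0],), int(t), device=device, dtype=torch.long)
    down_block_res_samples, mid_block_res_sample = controlnet(
        x=z, timesteps=timesteps, context=prompt_embeds, controlnet_cond=cond
    )
    noise_pred = diffusion_model(
        x=z,
        timesteps=timesteps,
        context=prompt_embeds,
        down_block_additional_residuals=down_block_res_samples,
        mid_block_additional_residual=mid_block_res_sample,
    )
    z, _ = scheduler.step(noise_pred, t, z)  # assumed tuple-style return

# Finally, z is decoded back to image space with the autoencoder's decoder.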


