Generating Images Using VAEs, GANs, and Diffusion Models | by Justin Cheigh | May, 2023


Learn how to generate images using VAEs, DCGANs, and DDPMs

Image by CatBird AI, Image Prompt by ChatGPT, ChatGPT Prompted by Justin Cheigh

Introduction:

We’re currently in the midst of a generative AI boom. In November 2022, OpenAI’s generative language model ChatGPT shook up the world, and in March 2023 we even got GPT-4!

Even though the future of these LLMs is extremely exciting, today we will be focusing on image generation. With the rise of diffusion models, image generation took a giant leap forward. Now we’re surrounded by models like DALL-E 2, Stable Diffusion, and Midjourney. For example, the image at the top of this article was made this way: just to show the power of these models, I gave ChatGPT a very simple prompt, which I then fed into the free CatbirdAI. CatbirdAI uses several different models, including Openjourney, Dreamlike Diffusion, and more.

In this article, Daisuke Yamada (my co-author) and I will work our way toward diffusion. We’ll use 3 different models and generate images in the style of MNIST handwritten digits with each of them. The first model will be a traditional Variational Autoencoder (VAE). We’ll then discuss GANs and implement a Deep Convolutional GAN (DCGAN). Finally, we’ll turn to diffusion models and implement the model described in the paper Denoising Diffusion Probabilistic Models. For each model we’ll go through the theory behind the scenes before implementing it in TensorFlow/Keras.

A quick note on notation: we will try to use subscripts like x₀, but at times we will instead have to write x_T to denote a subscript.

Let’s briefly discuss prerequisites. It’s important to be familiar with deep learning and comfortable with TensorFlow/Keras. Further, you should be familiar with VAEs and GANs; we will go over the main theory, but prior experience will be helpful. If you’ve never seen these models, check out these helpful sources: MIT 6.S191 Lecture, Stanford Generative Model Lecture, VAE Blog. Finally, there’s no need to be familiar with DCGANs or diffusion. Great! Let’s get started.

Generative Model Trilemma:

As an unsupervised process, generative AI often lacks well-defined metrics to track progress. But before we approach any methods to evaluate generative models, we need to understand what generative AI is actually trying to accomplish! The goal of generative AI is to take training samples from some unknown, complex data distribution (e.g., the distribution of human faces) and learn a model that can “capture this distribution”. So, what factors are relevant in evaluating such a model?

We certainly want high-quality samples, i.e. the generated data should be realistic and accurate compared to the actual data distribution. Informally, we can evaluate this subjectively by looking at the outputs. This is formalized and standardized in a benchmark known as HYPE (Human eYe Perceptual Evaluation). Although there are other quantitative methods, today we will just rely on our own subjective evaluation.

It’s also important to have fast sampling (i.e., the speed of generation, or scalability). One particular aspect we will look at is the number of network passes required to generate a new sample. For example, we will see that GANs will require just one pass of the generator network to turn noise into a (hopefully) realistic data sample, while DDPMs require sequential generation, which ends up making it much slower.

A final important quality is known as mode coverage. We don’t just want to learn a specific part of the unknown distribution, but rather we want to capture the entire distribution to ensure sample diversity. For example, we don’t want a model that just outputs images of 0s and 1s, but rather all possible digit classes.
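Here is a minimal sketch that makes mode coverage more concrete by summarizing a batch of generated digits. It is our own illustration and is not part of the experiments below. It assumes each generated image has already been labeled by some separately trained MNIST classifier, which is hypothetical here. The `mode_coverage` helper reports two numbers: the fraction of the 10 classes that appear, and the entropy of the label histogram normalized so that 1.0 means every class appears equally often.

```python
import math
from collections import Counter

def mode_coverage(predicted_labels, num_classes=10):
    """Summarize class diversity of generated samples.

    predicted_labels: class labels assigned to generated images by some
    separately trained classifier (hypothetical here).
    Returns (fraction of classes seen, normalized entropy of the labels).
    """
    counts = Counter(predicted_labels)
    n = len(predicted_labels)
    coverage = len(counts) / num_classes
    entropy = -sum((c / n) * math.log(c / n) for c in counts.values())
    return coverage, entropy / math.log(num_classes)

# A collapsed generator that only draws 0s and 1s vs. a diverse one
collapsed = [0, 1] * 50
diverse = list(range(10)) * 10
print(mode_coverage(collapsed))  # coverage 0.2, normalized entropy about 0.30
print(mode_coverage(diverse))    # coverage 1.0, normalized entropy about 1.0
```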

These three important factors (quality of samples, speed of sampling, and mode coverage) make up the “Generative Model Trilemma”.

Image Created by Daisuke Yamada, Inspired by Figure 1 in DDGANs Paper

Now that we understand how we will compare and contrast these models, let’s dive into VAEs!

Variational Autoencoder:

One of the first generative models that you will encounter is the Variational Autoencoder (VAE). Since VAEs are just traditional autoencoders with a probabilistic spin, let’s remind ourselves of autoencoders.

Autoencoders are dimensionality reduction models that learn to compress data into some latent representation:

Image Created by Justin Cheigh

The encoder compresses the input into a latent representation called the bottleneck, and then the decoder reconstructs the input. Since the decoder’s target is the input itself, we can train with an L2 loss between input and output.

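Plain autoencoders are poorly suited to image generation. Nothing in their training constrains the latent space, so they tend to overfit, which leaves a sparse latent space that is discontinuous and disconnected (i.e., not regularized). As a result, a point sampled between the learned codes may not decode to anything meaningful. VAEs fix this by encoding the input x as a distribution over the latent space: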

The input x gets fed into the encoder E. The output E(x) is a vector of means and a vector of standard deviations (in practice often log-variances), which parameterize a distribution P(z | x). The common choice is a multivariate Gaussian with diagonal covariance, paired with a standard Gaussian prior N(0, I). From here we sample z ~ P(z | x), and finally the decoder attempts to reconstruct x from z (just like with the autoencoder).

Notice that this sampling step is not differentiable, so we need to change something to make backpropagation possible. To do so we use the reparameterization trick, where we move the randomness into an input by first sampling ϵ ~ N(0, I). Then we compute z with a deterministic step: z = μ + σ ϵ. z has the same distribution as before, but now there is a clear path to backpropagate the error through μ and σ, since the only stochastic node is an input!
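The reparameterization trick fits in a few lines of plain Python. This is a scalar illustration, and `reparameterize` is our own hypothetical helper rather than part of our TensorFlow code. z is a deterministic function of μ and log σ² plus external noise ϵ, so ∂z/∂μ = 1 and ∂z/∂σ = ϵ, while the samples still follow N(μ, σ²).

```python
import math
import random

def reparameterize(mu, log_var, rng=random):
    """Draw z ~ N(mu, sigma^2) as mu + sigma * eps with eps ~ N(0, 1).

    All randomness lives in eps, so z is differentiable in mu and log_var.
    Returns (z, eps).
    """
    eps = rng.gauss(0.0, 1.0)          # the only stochastic input
    sigma = math.exp(0.5 * log_var)    # log-variance -> standard deviation
    return mu + sigma * eps, eps

rng = random.Random(0)
mu, log_var = 2.0, math.log(0.25)      # sigma = 0.5
samples = [reparameterize(mu, log_var, rng)[0] for _ in range(20000)]
mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(round(mean, 2), round(std, 2))   # close to 2.0 and 0.5
```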

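Recall that autoencoders are trained with an L2 loss, which acts as a reconstruction term. For VAEs, we also add a regularization term, which is used to make the latent space “well-behaved”:

$$\mathcal{L}(x, \hat{x}) = C\,\lVert x - \hat{x} \rVert^2 + D_{KL}\big(\mathcal{N}(\mu_x, \sigma_x^2)\,\big\|\,\mathcal{N}(0, I)\big)$$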

Here, the first term is a reconstruction term, whereas the second term is a regularization term. Specifically, we use the Kullback-Leibler (KL) divergence between the learned distribution over the latent space and a prior distribution. The KL divergence measures how much one distribution differs from another. Penalizing it pulls each encoded distribution toward the prior, which keeps the latent space continuous and helps prevent overfitting.
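For a diagonal Gaussian encoder and a standard Gaussian prior, the KL term has a well-known closed form: KL = −½ Σⱼ (1 + log σⱼ² − μⱼ² − σⱼ²). The plain-Python sketch below evaluates it for a single latent vector; the helper name is our own. This is the same expression that appears in the `kl` line of our loss function below.

```python
import math

def kl_to_standard_normal(mean, log_var):
    """KL( N(mean, diag(exp(log_var))) || N(0, I) ) for one latent vector.

    Closed form: -0.5 * sum(1 + log_var - mean^2 - exp(log_var)).
    """
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mean, log_var))

# Zero when the encoder's distribution already equals the prior
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]) == 0)  # True
# Shifting the mean by 1 in one dimension costs 0.5 nats
print(kl_to_standard_normal([1.0, 0.0], [0.0, 0.0]))       # 0.5
```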

Great! We’ve recapped the theory and intuition behind VAEs, and we will now discuss implementation details. After importing relevant libraries, we define a few hyperparameter values:

```python
latent_dim = 2  # dimension of latent space

epochs = 200
batch_size = 32
learning_rate = 1e-4
```

After downloading the MNIST dataset and doing some basic preprocessing, we define our loss function and network:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.layers import (Layer, Input, Conv2D, Conv2DTranspose,
                                     Dense, Flatten, Reshape)
from tensorflow.keras.optimizers import Adam

# use a GPU if one is available
device = "/GPU:0" if tf.config.list_physical_devices("GPU") else "/CPU:0"

def get_default_loss(model, x):
    with tf.device(device):
        mean, logvar, z = model.encoder(x)
        xhat = model.decoder(z)
        # reconstruction term: per-pixel binary cross-entropy, summed over the 28x28 pixels
        rl = tf.reduce_mean(keras.losses.binary_crossentropy(x, xhat)) * 28 * 28
        # regularization term: KL(N(mean, exp(logvar)) || N(0, I)),
        # summed over latent dimensions and averaged over the batch
        kl = -0.5 * tf.reduce_mean(
            tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar), axis=1))
        return rl + kl

# Sampling layer
class Sampling(Layer):

    def call(self, prob):
        # reparameterization trick: z = mean + sigma * e with e ~ N(0, I)
        mean, logvar = tf.split(prob, num_or_size_splits=2, axis=1)
        e = tf.random.normal(shape=tf.shape(mean))
        z = mean + e * tf.exp(logvar * 0.5)
        return mean, logvar, z

# Basic convolutional VAE
class VAE(Model):

    def __init__(self, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.encoder = self.get_encoder()
        self.decoder = self.get_decoder()

    def get_encoder(self):
        # encoder + reparameterization (i.e., sampling) layer
        input_x = Input(shape=(28, 28, 1))
        x = Conv2D(filters=64, kernel_size=3, strides=(2, 2), activation='relu')(input_x)
        x = Conv2D(filters=64, kernel_size=3, strides=(2, 2), activation='relu')(x)
        x = Flatten()(x)
        x = Dense(self.latent_dim * 2)(x)  # means and log-variances
        mean, logvar, z = Sampling()(x)
        return Model(input_x, [mean, logvar, z], name="encoder")

    def get_decoder(self):
        input_z = Input(shape=(self.latent_dim,))
        z = Dense(7 * 7 * 64, activation="relu")(input_z)
        z = Reshape((7, 7, 64))(z)
        z = Conv2DTranspose(filters=64, kernel_size=3, strides=2, padding='same', activation='relu')(z)  # 14x14
        z = Conv2DTranspose(filters=64, kernel_size=3, strides=2, padding='same', activation='relu')(z)  # 28x28
        xhat = Conv2DTranspose(filters=1, kernel_size=3, strides=1, padding='same', activation='sigmoid')(z)
        return Model(input_z, xhat, name="decoder")

    def train_step(self, x):
        # custom training step used by vae.fit(...)
        with tf.device(device):
            x = x[0] if isinstance(x, tuple) else x
            with tf.GradientTape() as tape:
                loss = get_default_loss(self, x)
            gradient = tape.gradient(loss, self.trainable_weights)
            self.optimizer.apply_gradients(zip(gradient, self.trainable_weights))
            return {"loss": loss}

    def call(self, inputs):
        # encode -> sample -> decode, so build() and summary() can trace the model
        _, _, z = self.encoder(inputs)
        return self.decoder(z)

vae = VAE(latent_dim=latent_dim)
vae.compile(optimizer=Adam(learning_rate=learning_rate))
vae.build(input_shape=(None, 28, 28, 1))
vae.summary()
```

Ok, let’s break this down. The get_default_loss(model, x) function takes in a VAE model and some input x and returns the VAE loss we defined before (with C = 1). The one difference is that the reconstruction term is binary cross-entropy summed over the 28 × 28 pixels rather than an L2 distance. We defined a convolutional VAE, where the encoder uses Conv2D layers to downsample, and the decoder uses Conv2DTranspose layers (deconvolution layers) to upsample. We optimized with Adam.

Since the other two generative models begin with random noise, we did not encode an input image. Instead, we simply sampled points from the latent space and used the decoder to generate new images. We tested latent_dim = 2 and latent_dim = 100 and obtained the following results:

Generated Images for VAE. Image Created by Daisuke Yamada

Since generating a new sample takes just one decoder forward pass, our sampling is fast. Further, this is a comparatively simple model, so our training is fast too. The results with latent dimension 2 (meaning our bottleneck latent representation is 2-dimensional) are good but a bit blurry. However, our results with latent dimension 100 are not as good. We think either we lacked computing power or the posterior began to spread over non-existent modes. In other words, the model began to learn meaningless latent features.

So, how could one choose the “optimal” latent dimension? 100 clearly did not work well for us, but perhaps some value between 2 and 100 is ideal. There’s a tradeoff here between sample quality and computational efficiency. So, you could decide how important each of these factors is for you and then run something like a grid search over this hyperparameter.

We also plotted the latent space for latent dimension 2. The figure below shows what the decoder outputs depending on which point in the latent space we start from.

Our VAE Latent Space
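A latent-space plot like this is typically made by decoding an evenly spaced grid of 2-D latent points. The sketch below only builds the grid and leaves out the decoding step. Several details are our own assumptions, not a description of exactly how the figure above was produced: the quantile-based spacing, the 0.05 to 0.95 range, and the `latent_grid` name.

```python
from statistics import NormalDist

def latent_grid(n=15, lo=0.05, hi=0.95):
    """Return an n x n grid of 2-D latent points for visualizing a VAE decoder.

    Points sit at evenly spaced quantiles of the N(0, 1) prior, so the grid is
    denser where the prior puts more mass. Each (x, y) point would then be fed
    to the decoder and the resulting images tiled; rows go from top (large y)
    to bottom (small y).
    """
    inv_cdf = NormalDist().inv_cdf
    qs = [lo + (hi - lo) * i / (n - 1) for i in range(n)]
    coords = [inv_cdf(q) for q in qs]
    return [[(x, y) for x in coords] for y in reversed(coords)]

grid = latent_grid(n=5)
print(len(grid), len(grid[0]))   # 5 5
```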

As you can see, the latent space is decently diverse and fairly complete and continuous! So, reflecting on the Generative Model Trilemma, we get the following:

Trilemma for VAE (Red is Good, Blue is Bad). Image Created by Daisuke Yamada

We’ll now shift gears and discuss DCGANs, and we’ll begin with an accelerated explanation of GANs.

Deep Convolutional GANs:

In GANs, there is a generator G and a discriminator D. The generator creates new data, and the discriminator differentiates (or discriminates) between real and fake data. The two are trained against each other as a minimax game, hence the term adversarial.

We are given some training data, and we begin by sampling random noise z using either a standard normal or uniform distribution. This noise is the latent representation of the data to be generated. We start with noise to allow for more diverse data samples and to avoid overfitting.

The noise is fed into the generator, which outputs the generated data x = G(z). The discriminator then takes x and outputs D(x) = P[x is real], i.e. its estimate of the probability that x is a real image. We also feed the discriminator real images from the training set, so that it learns to tell the two apart.

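We typically define the loss as a minimax game:

$$\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

GANs Loss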

Notice that for the discriminator this looks like binary cross-entropy, which makes sense since it’s a binary classifier. The expectations are simply averages over data points drawn from (a) the data distribution (E_{x ~ p(data)}) and (b) random noise (E_{z ~ p(z)}). The first term expresses that the discriminator wants to maximize the likelihood of classifying real data as 1. The second term expresses that it also wants to maximize the likelihood of classifying fake data as 0. The generator appears only in the second term, and it tries to minimize that term, i.e. to make D(G(z)) close to 1. Under the minimax formulation, the discriminator assumes the generator will play optimally, and vice versa.
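As a sanity check on the value function, here is a small plain-Python estimate of V(D, G) from a batch of discriminator outputs; the helper name is ours. Minimizing binary cross-entropy with label 1 for real images and 0 for fake ones is the same as maximizing this V. When D outputs 0.5 everywhere, which is the theoretical equilibrium, V equals −log 4.

```python
import math

def gan_value(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))].

    d_real: discriminator outputs D(x) on a batch of real images.
    d_fake: discriminator outputs D(G(z)) on a batch of generated images.
    D tries to maximize this value; G tries to minimize it.
    """
    real_term = sum(math.log(d) for d in d_real) / len(d_real)
    fake_term = sum(math.log(1 - d) for d in d_fake) / len(d_fake)
    return real_term + fake_term

# A confident, correct discriminator: value close to its maximum of 0
print(gan_value([0.99, 0.98], [0.01, 0.02]))
# At equilibrium D outputs 0.5 everywhere: both printed values are about -1.386
print(gan_value([0.5], [0.5]), -math.log(4))
```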

Ok, we’ll now transition to DCGANs. DCGANs are GANs with a few notable changes in architecture. The main one is that DCGANs replace the multilayer perceptrons of the original GAN with convolutions/deconvolutions. Below are the architectural guidelines for stable DCGANs (from the original paper):

  • Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (i.e. deconvolutions) (generator)
  • Use batchnorm in both the generator and the discriminator
  • Remove fully connected hidden layers for deeper architectures
  • Use ReLU activation in generator for all layers except for the output, which uses Tanh
  • Use LeakyReLU activation in the discriminator for all layers.
DCGAN Architecture. Image Created by Justin Cheigh

We often use GANs for image generation, so intuitively it makes sense to use convolutional layers. We use standard convolutional layers in the discriminator, as we want to downsample the image into hierarchical features. In the generator, by contrast, we use deconvolutional layers to upsample from noise (the latent representation) to the generated image.

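Batch normalization is used to stabilize the training process, improve convergence, and enable faster learning. Leaky ReLU avoids the “dying ReLU” problem, where units that receive negative inputs get zero gradient and stop learning. Finally, Tanh bounds the generator’s output to [−1, 1], matching training images scaled to that range. The DCGAN paper reports that a bounded output activation let the model learn more quickly to saturate and cover the color space of the training distribution.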
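The activation choices above can be illustrated with a few scalar helpers; the names are ours. The 0.2 slope matches the LeakyReLU(0.2) layers in our discriminator. The scaling helpers assume training images are mapped from [0, 255] to [−1, 1] so that they live in the same range as the generator’s tanh output. Our preprocessing is not shown, so treat that scaling as an assumption.

```python
def relu_grad(x):
    return 1.0 if x > 0 else 0.0       # zero gradient for every negative input

def leaky_relu_grad(x, alpha=0.2):
    return 1.0 if x > 0 else alpha     # small, nonzero slope for negative inputs

def to_tanh_range(pixel):
    """Map a pixel in [0, 255] to [-1, 1], the output range of tanh."""
    return pixel / 127.5 - 1.0

def from_tanh_range(value):
    """Map a generator output in [-1, 1] back to [0, 255] for display."""
    return (value + 1.0) * 127.5

print(relu_grad(-3.0), leaky_relu_grad(-3.0))   # 0.0 0.2
print(to_tanh_range(0), to_tanh_range(255))     # -1.0 1.0
```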

Great! Now let’s see how to implement DCGANs. After importing libraries we set hyperparameter values:

```python
latent_dim = 100
epochs = 100
batch_size = 32
learning_rate = 1e-4
```

After some data preprocessing and splitting into batches for computational efficiency, we are ready to define the generator and discriminator:

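```python
from tensorflow.keras import Sequential
from tensorflow.keras.layers import (Dense, Reshape, Conv2D, Conv2DTranspose,
                                     BatchNormalization, ReLU, LeakyReLU, Flatten)
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy
from tensorflow.keras.optimizers import Adam

generator = Sequential([
    # input: project z to a 7x7x128 feature map
    Dense(units=7*7*128, input_shape=(latent_dim,)),
    Reshape((7, 7, 128)),
    # conv 1: 7x7 -> 14x14
    Conv2DTranspose(filters=64, kernel_size=3, strides=2, padding='same'),
    BatchNormalization(),
    ReLU(),  # plain ReLU per the DCGAN guidelines (ReLU(max_value=0.2) would clip activations at 0.2)
    # conv 2: stays 14x14
    Conv2DTranspose(filters=64, kernel_size=3, strides=1, padding='same'),
    BatchNormalization(),
    ReLU(),
    # final tanh: 14x14 -> 28x28, pixel values in [-1, 1]
    Conv2DTranspose(filters=1, kernel_size=3, strides=2, padding='same', activation='tanh')
])

discriminator = Sequential([
    # conv 1: 28x28 -> 14x14
    Conv2D(filters=64, kernel_size=3, strides=2, padding='same', input_shape=(28, 28, 1)),
    LeakyReLU(0.2),
    # conv 2: 14x14 -> 7x7
    Conv2D(filters=64, kernel_size=3, strides=2, padding='same'),
    # BatchNormalization(),
    LeakyReLU(0.2),
    # output: probability that the input is real
    Flatten(),
    Dense(1, activation='sigmoid')
])

# compile the discriminator while it is still trainable...
discriminator.compile(optimizer=Adam(learning_rate=learning_rate),
                      loss=BinaryCrossentropy(),
                      metrics=[BinaryAccuracy()])
# ...then freeze it inside the combined model, so gan updates only the generator
discriminator.trainable = False
gan = Sequential([
    generator,
    discriminator
])
gan.compile(optimizer=Adam(learning_rate=learning_rate), loss=BinaryCrossentropy(), metrics=[BinaryAccuracy()])
```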

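Remember all of the architectural guidelines for DCGANs! We are using Conv2DTranspose for the generator, and regular Conv2D with stride for the discriminator. Notice that we compile the discriminator with a binary cross-entropy loss and only then set discriminator.trainable = False before building and compiling the combined gan model. In tf.keras 2.x, the trainable flag is captured when a model is compiled. As a result, discriminator.train_on_batch still updates the discriminator, while gan.train_on_batch updates only the generator. This matters because we implement the training loop ourselves: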

```python
import numpy as np
import tensorflow as tf

# `device` is the device string used in the VAE section; x_train_digits is the
# batched MNIST training set from preprocessing (assumed scaled to [-1, 1] for tanh).

def generate(n=16):
    # generate images from random noise z
    with tf.device(device):
        random_vector_z = np.random.normal(loc=0, scale=1, size=(n, latent_dim))
        generated = generator(random_vector_z, training=False)
    return generated

# generator and discriminator losses
g_losses, d_losses = [], []

with tf.device(device_name=device):
    for epoch in range(epochs):
        for real_x in x_train_digits:
            n = real_x.shape[0]            # the last batch may be smaller than batch_size
            real = np.ones(shape=(n, 1))   # labels
            fake = np.zeros(shape=(n, 1))

            # discriminator: train on real data...
            d_loss_real = discriminator.train_on_batch(x=real_x, y=real)
            # ...and on fake data from the generator
            z = np.random.normal(loc=0, scale=1, size=(n, latent_dim))
            fake_x = generator.predict_on_batch(x=z)
            d_loss_fake = discriminator.train_on_batch(x=fake_x, y=fake)

            # train_on_batch returns [loss, accuracy]; average only the losses
            d_loss = 0.5 * (d_loss_real[0] + d_loss_fake[0])

            # generator: update through the frozen discriminator, labeling fakes as real
            g_loss = gan.train_on_batch(x=z, y=real)

            g_losses.append(g_loss[0])
            d_losses.append(d_loss)
```
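One detail of this loop is worth spelling out. The generator update `gan.train_on_batch(x=z, y=real)` labels fake images as real, so with binary cross-entropy the generator minimizes −log D(G(z)). That differs from the log(1 − D(G(z))) term of the minimax loss; it is the non-saturating variant suggested in the original GAN paper. The plain-Python sketch below uses helpers of our own to compare the gradients of the two losses with respect to the discriminator’s logit, in the case where the discriminator confidently rejects a fake.

```python
import math

def sigmoid(logit):
    return 1.0 / (1.0 + math.exp(-logit))

def saturating_grad(logit):
    """d/d(logit) of log(1 - D), the generator's term in the minimax loss."""
    return -sigmoid(logit)

def non_saturating_grad(logit):
    """d/d(logit) of -log D, the loss gan.train_on_batch(x=z, y=real) minimizes."""
    return -(1.0 - sigmoid(logit))

# Early in training the discriminator easily rejects fakes: D(G(z)) is tiny
logit = -10.0
print(saturating_grad(logit))      # about -4.5e-05: almost no learning signal
print(non_saturating_grad(logit))  # about -1.0: a strong learning signal
```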

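In each step, we first train the discriminator on real data and on fake data created by the generator. We then train the generator through the combined gan model, labeling its fake images as real so that it learns to push D(G(z)) toward 1. Great! Let’s see what the results look like: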

Image Created by Daisuke Yamada

Our sample quality is better (less blurry)! We still have fast sampling, as inference is just inputting random noise to the generator. Below is our latent space:

Our DCGAN Latent Space

Unfortunately, our latent space is not very diverse (this is especially evident from the samples in Epoch 1 with latent dimension 2). We are likely experiencing mode collapse, a common failure mode of GANs. Informally, this means the generator learns to produce only a small subset of outputs that are especially good at fooling the discriminator. For example, if the discriminator is easily fooled by generated 1s, the generator has little incentive to produce any other digit. So, for the generative model trilemma, we get the following:

Trilemma for DCGANs (Red is Good, Blue is Bad). Image Created by Daisuke Yamada

Now that we explored GANs and DCGANs, it’s time to transition to diffusion models!

Diffusion Models:

Diffusion probabilistic models (or just diffusion models) are at the core of many of today’s top image generation models, such as DALL-E 2 and Stable Diffusion. We will be discussing Denoising Diffusion Probabilistic Models (DDPMs).

For both VAEs and GANs, sample generation goes from noise to a generated image in one step. GANs perform inference by taking noise and doing a forward pass through the generator, and VAEs perform inference by sampling noise and passing it through the decoder. The main idea of diffusion is to generate a sequence of images, where every subsequent image is slightly less noisy, with the final image ideally being realistic! A DDPM has two parts. In the forward process we take real images and iteratively add noise. In the reverse process we learn how to undo the noise added in the forward process:

Let’s explain the forward and reverse process at an intuitive level. We are given a set of training data X, where each data point x₀ ∈ X is sampled from a data distribution, x₀ ~ q(x₀). Recall that q(x₀) is the unknown distribution we want to represent.

From right to left is the fixed (not learned) forward process, where we take some training sample x₀ ~ q(x₀) and iteratively add Gaussian noise. Namely, we generate a sequence of images x₁, x₂, …, x_T, with each subsequent image in the sequence being noisier. In the end, we will end up with something that can be thought of as pure noise! From left to right we have the reverse process, where we learn how to denoise, i.e. predict how to get from x_{t+1} to xₜ. Great! Now that we understand the basics of the forward and reverse process, let’s dive into the theory!

Formally, the forward process is described by a Markov chain that iteratively adds Gaussian noise to the data according to a predetermined variance schedule β₁, …, β_T. The term Markov chain just means that xₜ depends only on xₜ₋₁: given xₜ₋₁, xₜ is conditionally independent of x₀, …, xₜ₋₂, which means q(xₜ | x₀, …, xₜ₋₁) = q(xₜ | xₜ₋₁). The other important concept is the variance schedule. We define values β₁, …, β_T, which parameterize the Gaussian noise we add at each time step. Typically 0 < βₜ < 1, with the schedule increasing from a small β₁ to a larger β_T (the DDPM paper uses a linear schedule from 10⁻⁴ to 0.02, as does our code). All of this is put in our definition of q(xₜ | xₜ₋₁):

$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\big)$$

Forward Diffusion Process

So, we start with x₀. Then, for T timesteps, we follow the above equation to get to the next image in the sequence: xₜ ~ q(xₜ | xₜ₋₁). One can show that as T → ∞, x_T converges to an isotropic Gaussian distribution.
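The Markov chain is easy to simulate for a single pixel. This plain-Python sketch is our own illustration. It uses the same linear schedule as our code below (β from 10⁻⁴ to 0.02) with T = 1000, applies q(xₜ | xₜ₋₁) step by step to many copies of x₀ = 1, and checks that the results look like samples from N(0, 1).

```python
import math
import random

def linear_betas(T, beta_start=1e-4, beta_end=0.02):
    """Linear variance schedule, like np.linspace(beta_start, beta_end, T)."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]

def forward_chain(x0, betas, rng):
    """Apply x_t = sqrt(1 - beta_t) x_{t-1} + sqrt(beta_t) * noise for every t."""
    x = x0
    for beta in betas:
        x = math.sqrt(1.0 - beta) * x + math.sqrt(beta) * rng.gauss(0.0, 1.0)
    return x

rng = random.Random(0)
betas = linear_betas(1000)
xs = [forward_chain(1.0, betas, rng) for _ in range(1000)]
mean = sum(xs) / len(xs)
var = sum((x - mean) ** 2 for x in xs) / len(xs)
print(round(mean, 2), round(var, 2))   # mean near 0, variance near 1
```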

Except we’re not done yet. We intuitively should be able to get from x₀ to any x_t in one step by expanding recursively. We first use a reparameterization trick (like with VAEs):

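$$x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_{t-1}, \qquad \alpha_t = 1 - \beta_t,\quad \epsilon_{t-1} \sim \mathcal{N}(0, I)$$

Reparameterization Trick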

This allows us to do the following:

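$$q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar{\alpha}_t}\,x_0,\ (1-\bar{\alpha}_t) I\big), \quad \text{i.e.}\quad x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s,\ \ \alpha_s = 1 - \beta_s$$

Sampling Arbitrary Timestep t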

By following the above equation we can get from x₀ to any xₜ in one step! For those curious, the derivation expands the recursion and uses the fact that a sum of independent Gaussians is again Gaussian (their variances add). Let’s move on to the reverse process.
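The one-step formula can be checked numerically without any sampling. Each step scales the mean by √αₜ and maps the variance v to αₜ·v + βₜ. That lets us propagate the mean and variance of xₜ exactly and compare them with √ᾱₜ·x₀ and 1 − ᾱₜ. The sketch below does this in plain Python for the linear schedule with T = 1000, which is an illustrative choice.

```python
import math

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]

# alpha_bar_t = product of alpha_s for s <= t
alpha_bars = []
running = 1.0
for a_t in alphas:
    running *= a_t
    alpha_bars.append(running)

# Propagate mean and variance of x_t step by step from a fixed x0:
#   mean_t = sqrt(alpha_t) * mean_{t-1},   var_t = alpha_t * var_{t-1} + beta_t
x0 = 1.0
mean, var = x0, 0.0
for t in range(T):
    mean = math.sqrt(alphas[t]) * mean
    var = alphas[t] * var + betas[t]
    # closed form: q(x_t | x_0) = N(sqrt(alpha_bar_t) * x0, 1 - alpha_bar_t)
    assert abs(mean - math.sqrt(alpha_bars[t]) * x0) < 1e-9
    assert abs(var - (1.0 - alpha_bars[t])) < 1e-9

print(alpha_bars[-1])  # tiny, so x_T is essentially N(0, 1) noise
```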

In the reverse process our goal is to know q(xₜ₋₁ | xₜ), since then we could take random noise and iteratively sample from q(xₜ₋₁ | xₜ) to generate a realistic image. One may think we can easily obtain q(xₜ₋₁ | xₜ) using Bayes’ rule, but it turns out this is intractable: Bayes’ rule needs the marginals q(xₜ₋₁) and q(xₜ), which depend on the entire unknown data distribution. This intuitively makes sense; to reverse the forward step we need to look at xₜ and consider all the possible ways we could have gotten there.

So, rather than directly computing q(xₜ₋₁ | xₜ), we will learn a model p with weights θ that approximates these conditional probabilities. Luckily, q(xₜ₋₁ | xₜ) is well approximated by a Gaussian if βₜ is sufficiently small. This insight comes from the theory of diffusion processes (stochastic differential equations), which is beyond the scope of this article. So we can define p as follows:

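$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\big)$$

Model Definition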

So, what is our loss? Well, if we want to undo the noise added, intuitively it should suffice to just predict the added noise. For a more complete derivation, please check out this great blog by Lilian Weng. It turns out our intuition is correct: the mean of p can be written in terms of the added noise, so rather than modeling p directly, we can train a network ϵ_θ that predicts the added noise. With this we can train using MSE between the actual and predicted noise:

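$$L_{\text{simple}}(\theta) = \mathbb{E}_{t,\,x_0,\,\epsilon}\Big[\big\lVert \epsilon - \epsilon_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t\big)\big\rVert^2\Big], \qquad \bar{\alpha}_t = \prod_{s=1}^{t}(1-\beta_s)$$

Final Loss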

Here, ϵ is the actual noise, whereas ϵ_θ(·, t) is the predicted noise. You may notice the expectation is taken over x₀. This is because the loss is usually written using the one-step reparameterization from above, which lets us obtain xₜ directly from x₀ and ϵ. Thus our network’s inputs are the time t and the current noisy image xₜ.
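Here is the loss for a single training example, written in plain Python. The three-step `alpha_bars` values, the toy image, and the `zero_model` predictor are placeholders of our own; in practice ϵ_θ is the U-Net described below. A predictor that outputs exactly ϵ gets zero loss, while one that always predicts zero noise pays the mean of ϵ².

```python
import math
import random

def l_simple(x0, t, eps, alpha_bars, eps_model):
    """Per-example DDPM loss: mean over pixels of (eps - eps_model(x_t, t))^2.

    x_t is built with the closed form x_t = sqrt(ab) * x0 + sqrt(1 - ab) * eps.
    eps_model: any callable (x_t, t) -> predicted noise (hypothetical here).
    """
    ab = alpha_bars[t]
    x_t = [math.sqrt(ab) * x + math.sqrt(1.0 - ab) * e for x, e in zip(x0, eps)]
    pred = eps_model(x_t, t)
    return sum((e - p) ** 2 for e, p in zip(eps, pred)) / len(eps)

def zero_model(x_t, t):
    # a deliberately bad predictor that always guesses "no noise"
    return [0.0] * len(x_t)

rng = random.Random(0)
alpha_bars = [0.9, 0.5, 0.1]            # toy values for a 3-step chain
x0 = [0.5, -1.0, 0.25]                  # toy 3-pixel "image"
eps = [rng.gauss(0.0, 1.0) for _ in x0]
t = rng.randrange(len(alpha_bars))      # t ~ Uniform over timesteps
print(l_simple(x0, t, eps, alpha_bars, zero_model))  # equals the mean of eps**2
```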

Let’s do a full recap. We train a network ϵ_θ using MSE to learn to predict the added noise. Once trained, we can use ϵ_θ to predict the noise added at any timestep. Using this noise and the equations above, we can carry out the reverse process and effectively “denoise”. We therefore perform inference by taking noise and iteratively denoising it. Both the sampling and training processes are described by the following pieces of pseudocode:

Pseudocode for Training/Sampling

In training we take a real image and sample t ~ Uniform({1, 2, …, T}); we sample t because computing the loss at every step for every image would be computationally expensive. We then take a gradient descent step on the MSE between the target and predicted noise. In sampling we take random noise and then iteratively sample using our predicted noise and our derived equations, until we get to a generated image x₀.
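Algorithm 2 (sampling) can be sketched in plain Python for a flat list of pixels. The `eps_model` argument stands in for the trained network and is hypothetical. We use σₜ² = βₜ, one of the two choices discussed in the DDPM paper, and skip the noise on the last step. As a toy check, suppose every training image equaled a single fixed image. Then the noise could be recovered exactly from xₜ, and the sampler should return that image exactly.

```python
import math
import random

def alpha_bar_schedule(betas):
    """alpha_bar[t] = product of (1 - beta_s) for s <= t (0-indexed)."""
    out, running = [], 1.0
    for b in betas:
        running *= 1.0 - b
        out.append(running)
    return out

def ddpm_sample(eps_model, betas, size, rng):
    """Algorithm 2 of the DDPM paper on a flat list of `size` pixels.

    eps_model(x_t, t) -> predicted noise (any callable; hypothetical here).
    Uses sigma_t^2 = beta_t and adds no noise on the final step (t = 0).
    """
    alphas = [1.0 - b for b in betas]
    alpha_bars = alpha_bar_schedule(betas)
    x = [rng.gauss(0.0, 1.0) for _ in range(size)]        # x_T ~ N(0, I)
    for t in reversed(range(len(betas))):
        eps_hat = eps_model(x, t)
        coef = (1.0 - alphas[t]) / math.sqrt(1.0 - alpha_bars[t])
        mean = [(xi - coef * ei) / math.sqrt(alphas[t]) for xi, ei in zip(x, eps_hat)]
        if t > 0:
            x = [m + math.sqrt(betas[t]) * rng.gauss(0.0, 1.0) for m in mean]
        else:
            x = mean                                         # z = 0 on the last step
    return x

# Toy check: all "training images" equal `target`, so the noise is recoverable
target = [0.5, -0.25, 1.0]
betas = [1e-4 + (0.02 - 1e-4) * i / 199 for i in range(200)]
alpha_bars = alpha_bar_schedule(betas)

def oracle_eps(x_t, t):
    ab = alpha_bars[t]
    return [(xi - math.sqrt(ab) * ci) / math.sqrt(1.0 - ab) for xi, ci in zip(x_t, target)]

sample = ddpm_sample(oracle_eps, betas, len(target), random.Random(0))
print([round(v, 6) for v in sample])   # [0.5, -0.25, 1.0]
```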

Great! We can now move on to implementation details. For our underlying architecture we will use a U-Net:

Image Created by Justin Cheigh; Inspiration

From the architecture it’s pretty clear why this is called a U-Net! U-Nets were initially used in biomedical image segmentation, but they also work very well with diffusion models! Intuitively, this is true for two reasons. First, the input and output shapes are the same, which is exactly what we need. Second, as we will see, U-Nets (thanks to the encoder-decoder structure paired with skip connections) are good at preserving both local and global information. This lets the network recover fine details while still reasoning about the image as a whole.

The U-Net has a similar encoder-decoder structure as past generative models. Specifically if you look at the image following the shape of the “U”, you can see on the way down we have a sequence of downsampling layers, each of which are part of the encoder structure. On the way up, we have a sequence of upsampling layers, which are part of the decoder structure. The input and output have the same shape, which is ideal given our input is a noisy image xₜ and our output is some predicted noise.

However, you may notice there is one important difference between a U-Net and a standard autoencoder, which are the skip connections. At each level we have a downsampling block, which connects to another downsampling block (following the shape of the “U”), and a skip connection to an upsampling block. Remember these downsampling blocks basically are looking at the image at different resolutions (learning different levels of hierarchical features). By having these skip connections we ensure that we account for each of these features at each resolution! Another way to think of a U-Net is as a sequence of stacked autoencoders.

Ok, now let’s look at our specific implementation. First of all, I lied… I said that our input is just the noisy image xₜ. However, we also input the actual timestep t in order to give us a notion of time. The way we do so is using a time step embedding, where we take the time t and use a sinusoidal positional embedding:

Sinusoidal Position Embedding

For those unfamiliar, a high level overview of sinusoidal position embedding is that we encode elements in some sequence (here just the timesteps) using sinusoidal functions, with the intuition being the smooth structure of these functions will be easier for neural networks to learn from. So, our actual input is our noisy image xₜ and our time step t, which initially goes through this time embedding layer.
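Here is a minimal sinusoidal timestep embedding in plain Python. We assume the common layout used by many DDPM implementations: all sines followed by all cosines, with frequencies 10000^(−i/half). The original Transformer paper interleaves sine and cosine instead, and our own time embedding layer may differ in such details.

```python
import math

def timestep_embedding(t, dim, max_period=10000.0):
    """Sinusoidal embedding of an integer timestep t into a vector of length dim.

    Frequencies are w_i = max_period ** (-i / half) for i = 0..half-1, and the
    result is [sin(t*w_0), ..., sin(t*w_{half-1}), cos(t*w_0), ..., cos(t*w_{half-1})].
    Assumes dim is even.
    """
    half = dim // 2
    freqs = [max_period ** (-i / half) for i in range(half)]
    return [math.sin(t * w) for w in freqs] + [math.cos(t * w) for w in freqs]

e0 = timestep_embedding(0, 8)
e5 = timestep_embedding(5, 8)
print(len(e5))   # 8
print(e0)        # sines are 0.0 and cosines are 1.0 at t = 0
```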

We then have our downsampling/upsampling blocks: each downsampling (upsampling) block contains 2 ResNet blocks, 1 Attention layer, and 1 Convolution (deconvolution) layer. Let’s quickly go over each of these.

Residual networks (ResNets) are built from blocks of a few convolutional layers wrapped in a residual skip connection, so a block outputs x + F(x). These shortcuts let information and gradients flow through very deep neural networks. Attention, a revolutionary idea crucial to understanding fundamental architectures like the Transformer, tells the neural network what to focus on. Here, after the 2 ResNet blocks we have a feature map with latent features at each spatial position. The attention layer then tells the network which of these features are most important to focus on. Finally, the standard convolution/deconvolution performs the down/upsampling, respectively.
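To make the attention step concrete, here is single-head scaled dot-product self-attention over a handful of feature vectors, written in plain Python. It is a simplification of our own. Real attention layers apply learned query, key, and value projections (and usually several heads), which we omit, so each position simply attends to the others by dot-product similarity.

```python
import math

def softmax(scores):
    m = max(scores)                      # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(features):
    """Single-head scaled dot-product self-attention with identity projections.

    features: list of N feature vectors (e.g. one per spatial position).
    Returns (outputs, weights); weights[i][j] is how much position i attends to j.
    """
    d = len(features[0])
    weights = []
    for q in features:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in features]
        weights.append(softmax(scores))
    outputs = [
        [sum(w * v[c] for w, v in zip(row, features)) for c in range(d)]
        for row in weights
    ]
    return outputs, weights

feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, w = self_attention(feats)
print([round(sum(row), 6) for row in w])   # [1.0, 1.0, 1.0]: each row sums to 1
```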

We use 4 levels of these blocks in our implementation. Here is the forward pass of our U-Net:

       
```python
import tensorflow as tf

# call() method of our Unet (a keras.Model subclass). Its layers are created in
# __init__ (not shown); `device` is the device string used throughout.
def call(self, x, time, training=True, **kwargs):
    with tf.device(device):
        # front conv
        x = self.init_conv(x)

        # time embedding
        t = self.time_mlp(time)

        # move down the encoder
        h = []
        for down_block1, down_block2, attention, downsample in self.downs:
            x = down_block1(x, t)
            x = down_block2(x, t)
            x = attention(x)
            h.append(x)  # keep for skip connection!
            x = downsample(x)

        # bottleneck: ResNet block -> attention -> ResNet block
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # move up the decoder, concatenating one stored skip connection per level
        for up_block1, up_block2, attention, upsample in self.ups:
            x = tf.concat([x, h.pop()], axis=-1)
            x = up_block1(x, t)
            x = up_block2(x, t)
            x = attention(x)
            x = upsample(x)

        # the last (highest-resolution) skip connection joins just before the output;
        # this assumes self.ups has one fewer level than self.downs
        x = tf.concat([x, h.pop()], axis=-1)

        # back conv
        x = self.final_conv(x)
        return x
```

Great! Now that we have defined our U-Net class, we can move on to using the U-Net for our specific problem. We first define relevant hyperparameters:

```python
image_size = (32, 32)
num_channel = 1
batch_size = 64
timesteps = 200
learning_rate = 1e-4
epochs = 10
```

Due to lack of computing power, we use timesteps = T = 200, even though the original paper used T = 1000. After data preprocessing, we define the forward process:

```python
import numpy as np
import tensorflow as tf

# define forward pass
beta = np.linspace(0.0001, 0.02, timesteps)  # linear variance schedule
alpha = 1 - beta
a = np.cumprod(alpha, 0)  # alpha bar: a[t] = alpha[0] * ... * alpha[t]

def forward(x_0, t):
    # closed-form sample of x_t ~ q(x_t | x_0) for a batch of timesteps t
    with tf.device(device):
        x_0 = np.asarray(x_0, dtype=np.float32)
        noise_t = np.random.normal(size=x_0.shape).astype(np.float32)
        sqrt_a_t = np.reshape(np.take(np.sqrt(a), t), (-1, 1, 1, 1)).astype(np.float32)
        sqrt_one_minus_a_t = np.reshape(np.take(np.sqrt(1 - a), t), (-1, 1, 1, 1)).astype(np.float32)
        x_t = sqrt_a_t * x_0 + sqrt_one_minus_a_t * noise_t
        return noise_t, x_t  # note the order: (noise, noisy image)
```

So, here we define a linear variance schedule with the same β range as the DDPM paper. In the forward function we use the one-step reparameterization from above to sample an arbitrary xₜ directly from x₀. Below is a visualization of the forward process:

Visualization of Our Forward Process
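We use T = 200 with the same β range that the paper pairs with T = 1000, so it is worth checking how much of x₀ survives at the last step. The plain-Python sketch below is our own check, not one of our original experiments, and it computes ᾱ_T for both settings. By our calculation, ᾱ_T comes out to roughly 0.13 for T = 200, so about 36% of the original signal amplitude (√ᾱ_T) remains in x_T. For T = 1000 it is roughly 4 × 10⁻⁵.

```python
import math

def alpha_bar_T(T, beta_start=1e-4, beta_end=0.02):
    """alpha_bar at the last step for a linear schedule (like np.linspace)."""
    prod = 1.0
    for i in range(T):
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        prod *= 1.0 - beta
    return prod

for T in (200, 1000):
    ab = alpha_bar_T(T)
    # x_T = sqrt(ab) * x_0 + sqrt(1 - ab) * eps, so sqrt(ab) is the surviving signal scale
    print(T, ab, math.sqrt(ab))
```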

We then instantiate our U-Net, define our loss function, and define the training process:


```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Unet is the class whose call() is shown above; `dataset` is the batched,
# preprocessed training set.
unet = Unet()  # instantiate model

def loss(noise, predicted):
    # remember we just use MSE!
    with tf.device(device):
        return tf.math.reduce_mean((noise - predicted) ** 2)

optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

def train_step(train_images):
    with tf.device(device):
        # one random timestep per image in the batch (64 in our case)
        timestep_values = tf.random.uniform(shape=[train_images.shape[0]], minval=0,
                                            maxval=timesteps, dtype=tf.int32)

        # forward process: forward() returns (noise, noisy images) in that order
        noise, noised_images = forward(x_0=train_images, t=timestep_values)

        # predict the noise and compute the loss under the tape
        with tf.GradientTape() as tape:
            predicted = unet(noised_images, timestep_values)
            loss_value = loss(noise, predicted)

        # optimize U-Net using Adam
        gradients = tape.gradient(loss_value, unet.trainable_variables)
        optimizer.apply_gradients(zip(gradients, unet.trainable_variables))

        return loss_value

def train(epochs):
    with tf.device(device):
        for epoch in range(epochs):
            losses = []
            for batch_images in dataset:
                losses.append(train_step(batch_images))
            print(f"epoch {epoch + 1}: mean loss {np.mean(losses):.4f}")
```

Remember, our loss (after lots of work) is just MSE! The rest is a fairly standard training loop. After training, we can think about inference. Following the sampling algorithm (Algorithm 2 in the pseudocode above), we implement it as follows:

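```python
import numpy as np
import tensorflow as tf

def denoise_x(x_t, pred_noise, t):
    # one step of Algorithm 2: sample x_{t-1} given x_t and the predicted noise
    with tf.device(device):
        x_t = np.asarray(x_t, dtype=np.float32)
        pred_noise = np.asarray(pred_noise, dtype=np.float32)

        # obtain schedule values at timestep t
        alpha_t = np.take(alpha, t)
        a_t = np.take(a, t)
        beta_t = np.take(beta, t)

        # calculate denoised_x (i.e., x_{t-1}); sigma_t^2 = beta_t, no noise at t = 0
        z = np.random.normal(size=x_t.shape) if int(t[0]) > 0 else np.zeros_like(x_t)
        denoised_x = (1 / np.sqrt(alpha_t)) * (x_t - ((1 - alpha_t) / np.sqrt(1 - a_t)) * pred_noise) + np.sqrt(beta_t) * z
        return denoised_x.astype(np.float32)

def backward(x, i):
    # sampling step i corresponds to timestep t = T - 1 - i
    with tf.device(device):
        t = np.expand_dims(np.array(timesteps - i - 1, np.int32), 0)
        pred_noise = unet(x, t)
        return denoise_x(x, pred_noise, t)
```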

Here we define how to take an image at a certain time step and denoise it. With this we can fully define our inference process:

```python
import numpy as np
import tensorflow as tf
from tqdm import tqdm

def get_sample(x=None, show_progress=False):
    # start from pure noise x_T unless an image is given
    if x is None:
        x = tf.random.normal(shape=(1, 32, 32, 1))

    # array to store images (squeezed to 32x32)
    imgs = [np.squeeze(np.squeeze(x, 0), -1)]

    # backward process
    for i in tqdm(range(timesteps - 1)):
        x = backward(x, i)
        if i in [0, 25, 50, 75, 100, 125, 150, 175, 198]:
            imgs.append(np.squeeze(np.squeeze(x, 0), -1))

    return imgs if show_progress else imgs[-1]
```

We take random noise, then continuously use our backward function to denoise, until we get to a realistic looking image! And here are some of our results:

Examples of Our DDPMs Generated Samples

The samples are of decent quality. Further, we were able to get a diverse range of samples. Presumably our sample quality would improve with more computing power; diffusion is very computationally expensive, which limited how much we could train this model. We can also try a kind of “reverse engineering”: we take a training image, noise it, and then denoise it to see how well we can reconstruct the image. We get the following:

Reverse Engineering. Image Created by Daisuke Yamada

It is important to note that the reverse process is probabilistic, so we don’t always end up with an image similar to our input image.

Great! Let’s go back to the Generative Model Trilemma. We have high-quality samples, a diverse range of samples, and more stable training (the objective is a simple regression on the noise rather than an adversarial game). However, training is slow, and so is sampling: inference requires one U-Net pass per timestep, i.e. about T = 200 sequential passes in our case versus a single pass for the VAE and DCGAN. So, we get the following:

Trilemma for DDPMs (Red is Good, Blue is Bad). Image Created by Daisuke Yamada

Conclusion:

Wow! We covered 3 image generation models, going all the way from standard VAEs to DDPMs. For each we looked at the Generative Model Trilemma and obtained the following results:

Comparing Models. Image Created by Daisuke Yamada

The natural question is: can we get all 3 parts of the Generative Model Trilemma? Well, it seems like diffusion is almost there, as we just need to figure out a way to increase the speed of sampling. Intuitively this is difficult because we relied on the assumption that each reverse step can be modeled as a Gaussian. That assumption only holds when each step adds a small amount of noise (small βₜ), so we need many steps.

However, it turns out getting all 3 factors of the Trilemma is possible! Models like DDIMs or DDGANs build on top of DDPMs, but they have found ways to increase the speed of sampling. One way is to use a strided sampling schedule that visits only a subset of the timesteps. With these and other optimizations, we can obtain all 3 facets of the Generative Model Trilemma!
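A strided sampling schedule can be as simple as visiting every k-th timestep. The sketch below builds such a subsequence; the helper is our own and assumes `num_steps` divides T. DDIM-style samplers pair a schedule like this with their own update rule, which we do not reproduce here, so that sampling takes num_steps network passes instead of T.

```python
def strided_timesteps(T, num_steps):
    """Evenly spaced subsequence of {0, ..., T-1}, in the descending order a sampler visits it.

    Assumes num_steps divides T.
    """
    stride = T // num_steps
    return list(range(0, T, stride))[::-1]

steps = strided_timesteps(200, 20)
print(len(steps), steps[:3], steps[-1])   # 20 [190, 180, 170] 0
```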

So, what’s next? One particularly interesting avenue is conditional generation. Conditional generation allows you to generate new samples conditioned on some class labels or descriptive text. For example, in all of the image generation models listed at the start, you can input something like “Penguin bench pressing 1000 pounds” and get a reasonable output. Although we didn’t have time to explore conditional generation, it seems like a very interesting next step!

Well, that’s all from us. Thank you for reading!

Unless otherwise stated, all images are created by the author(s).


