Training GANs can be very difficult and confusing at times. In this post, I introduce some techniques for training more stable GANs.
Note that this is based on my personal experience, and results may vary depending on the task.
Use the same noise for both losses and detach the generated images
When training a GAN, there are two losses: a loss for the generator, $\mathcal{L}_G$, and a loss for the discriminator, $\mathcal{L}_D$.
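The discriminator is trained to classify real images $x$ as "real" and generated images $G(z)$, produced from random noise $z$, as "fake", while the generator is trained to make the discriminator classify its images as "real".
Using binary cross-entropy, the discriminator loss $\mathcal{L}_D$ and the generator loss $\mathcal{L}_G$ are calculated as follows:

$$\mathcal{L}_D = -\mathbb{E}_{x}\left[\log D(x)\right] - \mathbb{E}_{z}\left[\log\left(1 - D(G(z))\right)\right]$$

$$\mathcal{L}_G = -\mathbb{E}_{z}\left[\log D(G(z))\right]$$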
When training an unconditional GAN, the generator's input is random noise $z$, which it converts into an image. This noise is used twice: once when calculating the discriminator loss $\mathcal{L}_D$ and once again when calculating the generator loss $\mathcal{L}_G$.
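By default, both PyTorch and TensorFlow allow only a single backward pass through a given computation graph: PyTorch frees the graph's intermediate buffers after `backward()` unless `retain_graph=True` is passed, and a non-persistent `tf.GradientTape` can compute gradients only once. Since both losses depend on the generator's output, a naive workaround is to generate the images twice, once for each loss.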
However, if the two images are generated from different noise vectors, $z_1 \neq z_2$, the model does not train well. Using the same noise instead, i.e. $z_1 = z_2$, yields much better results.
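A better solution is to generate the images only once and call `.detach()` on them when training the discriminator. This detaches the generator's output from the computation graph and treats it as a static input, so the discriminator's backward pass does not reach the generator, and the same, non-detached images can still be used for the generator loss.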
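```python
import torch
import torch.nn as nn

# Assumes `generator`, `discriminator`, `dataloader`, `latent_dim` and `lr`
# are defined elsewhere, that `dataloader` yields batches of real images,
# and that the discriminator returns raw logits (no final sigmoid).
criterion = nn.BCEWithLogitsLoss()
optim_generator = torch.optim.Adam(generator.parameters(), lr=lr)
optim_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)

for real_images in dataloader:
    batch_size = real_images.size(0)
    # Sample the noise once; the same fake images are used for both losses
    z = torch.randn(batch_size, latent_dim, device=real_images.device)
    fake_images = generator(z)

    # Train discriminator
    ## with fake images (detached: no gradient flows into the generator)
    fake_logits = discriminator(fake_images.detach())
    dis_loss_fake = criterion(fake_logits, torch.zeros_like(fake_logits))
    ## with real images
    real_logits = discriminator(real_images)
    dis_loss_real = criterion(real_logits, torch.ones_like(real_logits))
    ## combine losses
    dis_loss = dis_loss_fake + dis_loss_real
    ## train
    optim_discriminator.zero_grad()
    dis_loss.backward()
    optim_discriminator.step()

    # Train generator (no detach: the gradient must reach the generator)
    fake_logits = discriminator(fake_images)
    gen_loss = criterion(fake_logits, torch.ones_like(fake_logits))
    ## train
    optim_generator.zero_grad()
    gen_loss.backward()
    optim_generator.step()
```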
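To see why `.detach()` matters, here is a minimal sketch in which two hypothetical toy `nn.Linear` modules stand in for the generator and the discriminator. Reusing the generator's output in two backward passes fails in PyTorch, while detaching it for the discriminator loss works and leaves the generator's gradients untouched until the generator loss is computed.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Toy stand-ins for the generator and discriminator (hypothetical sizes)
G = nn.Linear(8, 8)
D = nn.Linear(8, 1)
z = torch.randn(4, 8)


def naive_step():
    """Reuse G(z) in both losses without detaching."""
    fake = G(z)
    D(fake).mean().backward()  # also frees the saved tensors of G(z)'s graph
    try:
        (-D(fake).mean()).backward()  # needs those freed tensors again
        return "ok"
    except RuntimeError:
        return "error"


def detached_step():
    """Detach G(z) for the discriminator loss, reuse it for the generator."""
    G.zero_grad(set_to_none=True)
    D.zero_grad(set_to_none=True)
    fake = G(z)
    D(fake.detach()).mean().backward()  # stops at the detached tensor
    grad_after_d = G.weight.grad        # still None: nothing reached G
    (-D(fake).mean()).backward()        # first pass through G(z): works
    return grad_after_d, G.weight.grad


print(naive_step())  # the second backward raises a RuntimeError
g_before, g_after = detached_step()
print(g_before is None, g_after is not None)
```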
Use equalized learning rate
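Equalized learning rate is a technique for improving training stability, described in "Progressive Growing of GANs for Improved Quality, Stability, and Variation" (Karras et al.). Instead of carefully initializing the weights, it scales them dynamically at runtime. This counteracts the scale invariance of adaptive optimizers such as Adam: because they normalize each update by an estimate of the gradient's magnitude, parameters with a larger dynamic range would otherwise take longer to adjust.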
The weights are initialized from $\mathcal{N}(0, 1)$ and, in every forward pass, scaled according to the number of input features $n_{\text{in}}$ and a given multiplier $\alpha$ (lr_multiplier): $\hat{w} = \frac{\alpha}{\sqrt{n_{\text{in}}}}\, w$. The multiplier is fixed at initialization and can be set for each layer. For example, setting lr_multiplier=0.01 for the mapping network in StyleGAN2 and similar networks helps tremendously.
```python
def __init__(self, lr_multiplier, in_features):
    ...
    self.weight_gain = lr_multiplier / np.sqrt(in_features)

def forward(self, x):
    # Scale the stored weights at runtime
    w = self.weight.to(x.dtype) * self.weight_gain
    ...
```
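A self-contained version of such a layer might look like the sketch below. It follows the structure of StyleGAN2-ADA's `FullyConnectedLayer` as I understand it: the stored weights are drawn from $\mathcal{N}(0, 1)$ and divided by the multiplier, then scaled by `lr_multiplier / sqrt(in_features)` in the forward pass, and the bias is scaled by `lr_multiplier`. The class name `EqualizedLinear` and its arguments are hypothetical.

```python
import numpy as np
import torch
from torch import nn


class EqualizedLinear(nn.Module):
    """Hypothetical linear layer with an equalized learning rate."""

    def __init__(self, in_features, out_features, lr_multiplier=1.0, bias=True):
        super().__init__()
        # Stored weights ~ N(0, 1) / lr_multiplier
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features) / lr_multiplier
        )
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        # Runtime scale: fan-in constant times the multiplier
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        # Effective weights ~ N(0, 1 / in_features) at initialization
        w = self.weight.to(x.dtype) * self.weight_gain
        b = self.bias.to(x.dtype) * self.bias_gain if self.bias is not None else None
        return nn.functional.linear(x, w, b)


# Example: a small mapping network using lr_multiplier=0.01
mapping = nn.Sequential(
    EqualizedLinear(512, 512, lr_multiplier=0.01),
    nn.LeakyReLU(0.2),
    EqualizedLinear(512, 512, lr_multiplier=0.01),
)
w = mapping(torch.randn(4, 512))  # shape (4, 512)
```

Why the multiplier acts on the learning rate: Adam moves each stored weight by roughly the learning rate per step regardless of its scale, so the effective weight moves by roughly `lr * weight_gain`. Dividing the initialization by the multiplier keeps the effective initial weights at $\mathcal{N}(0, 1/n_{\text{in}})$ for any multiplier.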
A standard alternative is to initialize the weights according to the number of inputs, as described in the following section.
Use Gaussian distribution for weight initialization
Because GAN training is unstable, it is sensitive to weight initialization.
A good practice is to draw the weights from a normal distribution whose variance is calculated for each layer from the number of input features $n_{\text{in}}$, e.g. $\mathcal{N}(0, 1/n_{\text{in}})$.
An easy way to implement this in PyTorch:
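```python
import numpy as np
from torch import nn

def init_normal(m):
    # Re-initialize linear layers only
    if isinstance(m, nn.Linear):
        y = m.in_features
        # std = 1/sqrt(fan_in), i.e. variance 1/fan_in
        m.weight.data.normal_(0.0, 1 / np.sqrt(y))

model.apply(init_normal)  # `model` is your network
```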
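The function above only touches `nn.Linear` layers. A possible extension to convolutions, sketched below as a hypothetical helper `init_normal_fan_in`, computes the fan-in as the number of weights feeding one output unit, which for a convolution is `in_channels / groups * prod(kernel_size)`. Zeroing the bias is my own assumption, not something the post prescribes.

```python
import math
import torch
from torch import nn


def init_normal_fan_in(m):
    """Hypothetical variant of init_normal that also covers convolutions."""
    if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d)):
        # Fan-in = number of weights per output unit: in_features for Linear,
        # in_channels / groups * prod(kernel_size) for convolutions
        fan_in = m.weight.numel() // m.weight.shape[0]
        nn.init.normal_(m.weight, mean=0.0, std=1.0 / math.sqrt(fan_in))
        if m.bias is not None:
            nn.init.zeros_(m.bias)  # zero bias: an assumption


model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3),
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, kernel_size=3),
)
model.apply(init_normal_fan_in)
```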
When using weight modulation, increase the learning rate
Weight modulation is a powerful tool for combining two input streams. It is commonly used in generative models to combine a style embedding with position embeddings.
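Instead of standard techniques, such as concatenation or summation of the two inputs, it modulates the weights directly, scaling each input channel $i$ of the weights by the corresponding style value $s_i$:

$$w'_{ijk} = s_i \cdot w_{ijk}$$

where $j$ indexes the output channels and $k$ the kernel positions. An optional step is to normalize the resulting weights, called "demodulation":

$$w''_{ijk} = \frac{w'_{ijk}}{\sqrt{\sum_{i,k} {w'_{ijk}}^2 + \epsilon}}$$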
However, introducing demodulation can alter the scale of the position embeddings, which is crucial information in some models.
When integrating weight modulation, please try it both with and without demodulation.
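Modulation scales each input channel of the weights by a style value, and demodulation rescales each output channel of the result to unit L2 norm. The small standalone sketch below checks both numerically; the tensor shapes, the helper name `modulate`, and the random style values are made up for illustration. It shows that multiplying the style by a constant no longer changes the demodulated weights, which is one way to see how demodulation discards scale information.

```python
import torch

torch.manual_seed(0)
out_ch, in_ch, batch = 4, 6, 2
weight = torch.randn(1, out_ch, in_ch, 1)     # shared weights w_ijk
style = torch.rand(batch, 1, in_ch, 1) + 0.5  # per-sample scales s_i (made up)


def modulate(weight, style, demodulate=True, eps=1e-8):
    # Modulation: scale every input channel i of the weights by s_i
    w = weight * style  # (batch, out, in, 1)
    if demodulate:
        # Demodulation: sum over input channels and kernel positions
        w = w * torch.rsqrt(w.pow(2).sum(dim=[2, 3], keepdim=True) + eps)
    return w


w_demod = modulate(weight, style)
norms = w_demod.pow(2).sum(dim=[2, 3]).sqrt()  # one norm per (sample, output)
# With demodulation, the overall magnitude of the style has no effect
same = torch.allclose(modulate(weight, 3.0 * style), w_demod, atol=1e-6)
print(norms, same)
```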
An example implementation for a linear layer, heavily based on CIPS:
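```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModulatedLinear(nn.Module):  # hypothetical class name
    """Per-pixel linear layer whose weights are modulated by a style vector."""

    def __init__(self, in_channels, out_channels, style_size, demodulation=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.style_size = style_size
        self.scale = 1 / np.sqrt(in_channels)
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, 1)
        )
        if self.style_size != self.in_channels:
            # FullyConnectedLayer: an equalized-learning-rate linear layer,
            # assumed to be defined elsewhere (e.g. as in StyleGAN2-ADA)
            self.modulation = FullyConnectedLayer(style_size, in_channels, bias=False)
        self.demodulation = demodulation

    def forward(self, input, style):
        # input: (batch, in_channels, n_pixels), style: (batch, style_size)
        batch_size, _, img_size = input.shape
        if self.style_size != self.in_channels:
            style = self.modulation(style)
        style = style.view(batch_size, 1, self.in_channels, 1)
        # Modulation: (batch, out_channels, in_channels, 1)
        weight = self.scale * self.weight * style
        if self.demodulation:
            # Normalize per sample and output channel (sum over input channels)
            demod = torch.rsqrt(weight.pow(2).sum([2, 3]) + 1e-8)
            weight = weight * demod.view(batch_size, self.out_channels, 1, 1)
        # Grouped conv1d with kernel size 1 = one linear layer per sample
        weight = weight.view(
            batch_size * self.out_channels, self.in_channels, 1
        )
        input = input.reshape(1, batch_size * self.in_channels, img_size)
        out = F.conv1d(input, weight, groups=batch_size)
        # (batch, out_channels, n_pixels)
        return out.view(batch_size, self.out_channels, img_size)
```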
For a 2d convolution implementation, refer to CIPS.
Weight modulation is a very powerful technique, but it requires a significantly larger learning rate than traditional approaches. I noticed that it trains comparatively slowly, and in most networks that use it, the learning rate is set to 0.002 for a batch size of
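If only part of a network uses weight modulation, one option (my assumption, not something prescribed here) is to give those layers a larger learning rate through Adam parameter groups. In the sketch below, `mod` is a hypothetical placeholder for the weight-modulated layers and the default learning rate of 0.0002 is a made-up value; 0.002 is the value mentioned above.

```python
import torch
from torch import nn

# Hypothetical model: "mod" stands in for weight-modulated layers,
# "head" for ordinary layers
model = nn.ModuleDict({
    "mod": nn.Linear(16, 16),
    "head": nn.Linear(16, 3),
})

optimizer = torch.optim.Adam(
    [
        # Larger learning rate for the weight-modulated layers
        {"params": model["mod"].parameters(), "lr": 0.002},
        # No "lr" key: this group uses the default passed below
        {"params": model["head"].parameters()},
    ],
    lr=0.0002,  # hypothetical default learning rate
)

for group in optimizer.param_groups:
    print(group["lr"])
```

If the whole network is modulated, simply raising the single learning rate passed to the optimizer achieves the same effect.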
Adjusting the learning rate is extremely important
GANs are notoriously difficult to train, and it is essential to find a good balance between the generator and the discriminator. The images produced by the generator only start to resemble the training dataset after a while, and waiting to see whether a given learning rate will eventually produce good results can be tedious.
For this reason, I recommend first training the generator on its own to output a fixed batch of images from the dataset, given a fixed noise input. If the generator cannot learn to map fixed noise to a fixed set of images, it will definitely not train in the real, adversarial training.
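```python
import torch

# Assumes `generator`, `trainloader`, `latent_dim`, `lr`, `device`,
# `total_steps` and `steps_til_summary` are defined elsewhere
ground_truth, _ = next(iter(trainloader))  # one fixed batch of images
ground_truth = ground_truth.to(device)
z = torch.randn(ground_truth.size(0), latent_dim, device=device)  # fixed noise
optim = torch.optim.Adam(generator.parameters(), lr=lr)
for step in range(total_steps):
    model_output = generator(z)
    loss = ((model_output - ground_truth) ** 2).mean()  # MSE to fixed targets
    optim.zero_grad()
    loss.backward()
    optim.step()
    if step % steps_til_summary == 0:
        print("Step %d, Total loss %0.6f" % (step, loss.item()))
        # Plot some of the generated images
```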
Such a toy experiment allows for rapid testing of different learning rates: finding a good learning rate is much easier when one experiment takes only 5-10 minutes to train.
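The learning-rate search itself can be scripted. The sketch below is a hypothetical, fully self-contained version of the toy experiment: random tensors stand in for the fixed batch of images, a tiny MLP stands in for the generator, and the candidate learning rates and step count are arbitrary. Each run re-seeds the initialization so that only the learning rate differs between runs.

```python
import torch
from torch import nn

torch.manual_seed(0)
latent_dim, n_images, img_dim = 16, 8, 64
# Hypothetical stand-ins: random "images" in [-1, 1] and fixed noise
ground_truth = torch.rand(n_images, img_dim) * 2 - 1
z = torch.randn(n_images, latent_dim)


def toy_run(lr, steps=300):
    """Fit a tiny generator to the fixed images; return the final MSE."""
    torch.manual_seed(1)  # same initialization for every learning rate
    generator = nn.Sequential(
        nn.Linear(latent_dim, 128), nn.LeakyReLU(0.2),
        nn.Linear(128, img_dim), nn.Tanh(),
    )
    optim = torch.optim.Adam(generator.parameters(), lr=lr)
    for _ in range(steps):
        loss = ((generator(z) - ground_truth) ** 2).mean()
        optim.zero_grad()
        loss.backward()
        optim.step()
    return loss.item()


results = {lr: toy_run(lr) for lr in (1e-4, 1e-3, 1e-2)}
for lr, final_loss in results.items():
    print(f"lr={lr:g}: final loss {final_loss:.4f}")
```

With a real generator and dataset batch, it is worth plotting the outputs as well, as in the loop above, rather than comparing final losses alone.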
However, a learning rate that is optimal for the toy experiment does not guarantee a good balance with the discriminator; it only shows that the generator can learn steadily without diverging. The real training might require a slightly lower learning rate, and the discriminator's learning rate needs to be adjusted separately.