Fusic Tech Blog

Fusion of Society, IT and Culture

Tips and Tricks for training GANs
2021/09/28


Training GANs can be very difficult and confusing at times. In this post, I introduce some techniques for training more stable GANs.

Note that this post is based on my personal experience, and results may vary depending on the task.

Not $G(z_1)$ and $D(G(z_2))$, but $G(z_1)$ and $D(G(z_1))$

When training a GAN, there are two losses: a loss for the generator, $L_G$, and a loss for the discriminator, $L_D$.

The discriminator is trained to classify real images as "real" and generated images as "fake", so its loss consists of two terms:

$$L_D = L_{D_{fake}} + L_{D_{real}}$$

The losses are calculated as follows:

$$z_1, z_2 \sim \mathcal{N}(0, I_k), \quad z_1, z_2 \in \mathbb{R}^k$$
$$L_{D_{fake}} = D(G(z_1))$$
$$L_{D_{real}} = 1 - D(\text{image})$$
$$L_G = 1 - D(G(z_2))$$
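To make the notation concrete, here is a minimal sketch that evaluates these losses exactly as written, assuming $D$ outputs the probability (in $[0, 1]$) that its input is real and averaging over a batch. The function name `gan_losses` and the numbers are illustrative only; the training code in this post uses binary cross-entropy rather than this simplified form.

```python
from statistics import mean


def gan_losses(d_real, d_fake_1, d_fake_2):
    """Batch-averaged losses as defined above (illustrative helper).

    d_real:   D(image) for a batch of real images
    d_fake_1: D(G(z_1)), used for the discriminator loss
    d_fake_2: D(G(z_2)), used for the generator loss
    """
    loss_d_fake = mean(d_fake_1)               # L_{D_fake} = D(G(z_1))
    loss_d_real = mean(1 - d for d in d_real)  # L_{D_real} = 1 - D(image)
    loss_d = loss_d_fake + loss_d_real         # L_D = L_{D_fake} + L_{D_real}
    loss_g = mean(1 - d for d in d_fake_2)     # L_G = 1 - D(G(z_2))
    return loss_d, loss_g


# With z_1 = z_2, the same D(G(z)) values feed both losses.
d_fake = [0.2, 0.1]
loss_d, loss_g = gan_losses(d_real=[0.9, 0.8], d_fake_1=d_fake, d_fake_2=d_fake)
print(round(loss_d, 3), round(loss_g, 3))  # 0.3 0.85
```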

When training a non-conditional GAN, the input is a random seed (a noise vector) $z$, which the generator converts into an image. This noise is used twice: once when calculating $L_{D_{fake}}$ and once again when calculating $L_G$.

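In PyTorch and TensorFlow, you can back-propagate through a given computation graph only once by default: PyTorch frees the intermediate buffers after backward() unless retain_graph=True is passed, and a non-persistent tf.GradientTape can compute gradients only once. Since the gradient through the generator is needed for both losses, a straightforward implementation has to generate an image twice, once for each loss.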
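A minimal sketch of this behavior in PyTorch, using tiny placeholder models rather than a real GAN: without .detach(), both losses back-propagate through the same generator forward pass, so the second backward() raises a RuntimeError unless the first call passes retain_graph=True.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Placeholder models (assumptions for illustration); Tanh stores its output for backward.
generator = nn.Sequential(nn.Linear(4, 8), nn.Tanh())
discriminator = nn.Linear(8, 1)

z = torch.randn(2, 4)
fake = generator(z)  # a single forward pass through the generator

loss_d = discriminator(fake).mean()
loss_g = -discriminator(fake).mean()

loss_d.backward()  # frees the saved buffers of the generator's graph
try:
    loss_g.backward()  # needs the same generator graph again
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True
print(second_backward_failed)  # True

# Keeping the graph alive makes a second backward pass possible.
fake = generator(z)
discriminator(fake).mean().backward(retain_graph=True)
(-discriminator(fake).mean()).backward()
```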

However, in my experience, if $z_1$ and $z_2$ are different, the model does not train well. Using the same seed instead, i.e. $z_1 = z_2$, yields much better results.

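An even better solution is to generate the images only once and call .detach() on the generator output when training the discriminator. .detach() returns a tensor that is cut off from the computation graph, so the discriminator update treats it as a static input and no gradient flows back into the generator. For the generator update, the same images are passed to the discriminator without .detach(), so that the gradient can reach the generator.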

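import torch
import torch.nn as nn

# Assumed to be defined elsewhere: generator, discriminator (returning raw
# logits of shape [batch_size, 1]), dataloader, lr, latent_dim, device
criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid internally

optim_generator = torch.optim.Adam(generator.parameters(), lr=lr)
optim_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)

for real_images, _ in dataloader:  # assumes (image, label) batches
    real_images = real_images.to(device)
    batch_size = real_images.size(0)  # the last batch may be smaller
    real_label = torch.ones(batch_size, 1, device=device)
    fake_label = torch.zeros(batch_size, 1, device=device)

    # Generate the fake images only once, i.e. z_1 = z_2
    z = torch.randn(batch_size, latent_dim, device=device)
    fake_images = generator(z)

    # Train discriminator
    ## with fake images (detached: no gradient flows into the generator)
    fake_logits = discriminator(fake_images.detach())
    dis_loss_fake = criterion(fake_logits, fake_label)

    ## with real images
    real_logits = discriminator(real_images)
    dis_loss_real = criterion(real_logits, real_label)

    ## combine losses
    dis_loss = dis_loss_fake + dis_loss_real

    ## train
    optim_discriminator.zero_grad()
    dis_loss.backward()
    optim_discriminator.step()

    # Train generator
    ## NOT detached, so the gradient reaches the generator through the discriminator
    fake_logits = discriminator(fake_images)
    gen_loss = criterion(fake_logits, real_label)

    ## train
    optim_generator.zero_grad()
    gen_loss.backward()
    optim_generator.step()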
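To see what .detach() changes, the sketch below runs the two updates with placeholder models (a tiny MLP generator and a linear discriminator, both assumptions for illustration). After the discriminator step the generator has no gradient at all, and only the non-detached generator step produces one. Optimizer steps are omitted because only the gradient flow matters here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
generator = nn.Sequential(nn.Linear(4, 8), nn.Tanh())  # placeholder generator
discriminator = nn.Linear(8, 1)                         # placeholder, outputs logits
criterion = nn.BCEWithLogitsLoss()

z = torch.randn(2, 4)
fake_images = generator(z)  # generated once, reused for both losses

# Discriminator step: detached fakes, so nothing reaches the generator
dis_loss = criterion(discriminator(fake_images.detach()), torch.zeros(2, 1))
dis_loss.backward()
gen_grad_after_dis_step = generator[0].weight.grad  # still None

# Generator step: the same fakes, not detached
gen_loss = criterion(discriminator(fake_images), torch.ones(2, 1))
gen_loss.backward()
gen_grad_after_gen_step = generator[0].weight.grad  # now a tensor

print(gen_grad_after_dis_step is None, gen_grad_after_gen_step is not None)  # True True
```

Because the discriminator step never back-propagates into the generator, the generator's graph is used by exactly one backward() call, which avoids the problem of back-propagating through it twice.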

Use equalized learning rate

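Equalized learning rate is a technique for improving training stability, described in "Progressive Growing of GANs for Improved Quality, Stability, and Variation" (Karras et al.). Instead of relying on a careful initialization of the weights, it scales them dynamically at runtime. This counteracts the scale invariance of common adaptive optimizers such as Adam and RMSProp: they normalize each update by an estimate of the gradient's standard deviation, so the step size does not depend on the scale of the parameter, and parameters with a larger dynamic range would otherwise take longer to adjust. With runtime scaling, all weights learn at the same speed.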

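The weights are initialized with $\mathcal{N}(0,1)$ and scaled at runtime by lr_multiplier / sqrt(in_features), i.e. according to the number of input features and a given multiplier. The multiplier is fixed at initialization and can be set for each layer. In the StyleGAN2 implementation, the initial weights are additionally divided by the multiplier, so that it changes only the effective learning rate of the layer, not its initial scale. For example, setting lr_multiplier=0.01 for the mapping network in StyleGAN2 and similar networks helps tremendously.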

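import torch

class EqualizedLinear(torch.nn.Module):  # illustrative name
    def __init__(self, in_features, out_features, lr_multiplier=1.0):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features) / lr_multiplier)
        self.weight_gain = lr_multiplier / in_features ** 0.5  # runtime scale; init is / lr_multiplier as in StyleGAN2

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight.to(x.dtype) * self.weight_gain)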
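A minimal sketch of why the multiplier acts as a per-layer learning rate, assuming Adam and a layer built like the one above (EqualizedLinear is re-implemented here for self-containment and is illustrative, not code copied from StyleGAN2). Adam's first step moves every stored weight by roughly lr, so the weight actually used in forward() moves by roughly lr * weight_gain, and lr_multiplier=0.01 makes that step about 100 times smaller, even though the initial effective weights are identical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EqualizedLinear(nn.Module):
    """Linear layer with equalized learning rate (illustrative sketch)."""

    def __init__(self, in_features, out_features, lr_multiplier=1.0):
        super().__init__()
        # Stored weights ~ N(0, 1) / lr_multiplier
        self.weight = nn.Parameter(torch.randn(out_features, in_features) / lr_multiplier)
        self.weight_gain = lr_multiplier / in_features ** 0.5

    def effective_weight(self):
        return self.weight * self.weight_gain  # scaled at runtime, not at init

    def forward(self, x):
        return F.linear(x, self.effective_weight())


def effective_step_size(lr_multiplier, lr=1e-3):
    """Mean |change| of the effective weights after one Adam step."""
    torch.manual_seed(0)  # same weights and inputs for every multiplier
    layer = EqualizedLinear(16, 4, lr_multiplier)
    optimizer = torch.optim.Adam(layer.parameters(), lr=lr)
    before = layer.effective_weight().detach().clone()
    layer(torch.randn(8, 16)).pow(2).mean().backward()
    optimizer.step()
    after = layer.effective_weight().detach()
    return (after - before).abs().mean().item()


ratio = effective_step_size(0.01) / effective_step_size(1.0)
print(f"{ratio:.3f}")  # close to 0.01
```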

A standard alternative is to initialize the weights according to the number of inputs, as described in the following section.

Use Gaussian distribution for weight initialization

Because GAN training is unstable, it is also sensitive to the initialization of the weights.

A best practice is to draw the initial weights from a normal distribution whose variance is computed for each layer from its number of input features $n_{in}$, e.g. $\mathcal{N}(0, 1/n_{in})$ (LeCun-style initialization).

An easy way to implement this in PyTorch:

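# PyTorch
import math

import torch.nn as nn

def init_normal(m):
    # Draw the weights of every linear layer from N(0, 1/in_features)
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=1.0 / math.sqrt(m.in_features))

model.apply(init_normal)  # visits every submodule recursively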
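Why the variance depends on the number of inputs: with weights drawn from $\mathcal{N}(0, 1/n_{in})$ and unit-variance inputs, each output is a sum of $n_{in}$ terms with variance $1/n_{in}$ each, so the output variance stays close to 1 regardless of the layer width. The sketch below checks this numerically for a few arbitrarily chosen widths, using an initializer like init_normal above (the default PyTorch bias initialization is kept; its contribution is negligible).

```python
import math

import torch
import torch.nn as nn


def init_normal(m):
    # Same idea as above: weights ~ N(0, 1/in_features)
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=1.0 / math.sqrt(m.in_features))


torch.manual_seed(0)
output_stds = {}
for in_features in (64, 512, 4096):
    layer = nn.Linear(in_features, 256)
    layer.apply(init_normal)
    x = torch.randn(1024, in_features)  # unit-variance inputs
    with torch.no_grad():
        output_stds[in_features] = layer(x).std().item()

print(output_stds)  # every value is close to 1.0
```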

When using weight modulation, increase the learning rate

Weight modulation is a powerful tool for combining two input streams. It is commonly used in generative models to combine a style embedding with position embeddings.

Instead of standard techniques such as concatenation or summation of the two inputs, it modulates the weights of a layer directly, as follows:

$$w'_{ijk} = s_i \cdot w_{ijk}$$

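Here $s_i$ is the style scale for input channel $i$, while $j$ and $k$ enumerate the output channels and the spatial footprint of the kernel. An optional step, called "demodulation", normalizes the weights: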

$$w''_{ijk} = \frac{w'_{ijk}}{\sqrt{\sum_{i,k} {w'_{ijk}}^2 + \epsilon}}$$
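A minimal sketch of the two formulas applied to a weight tensor indexed [j, i, k] (output channel, input channel, kernel position). The shapes, the random styles and the helper names modulate and demodulate are illustrative. After demodulation, the weights feeding each output channel have unit L2 norm, and multiplying all style scales by the same constant no longer changes anything, which shows how demodulation can remove scale information.

```python
import torch


def modulate(w, s):
    """w'_{ijk} = s_i * w_{ijk}; w is indexed [j, i, k], s is indexed [i]."""
    return w * s.view(1, -1, 1)


def demodulate(w_mod, eps=1e-8):
    """w''_{ijk} = w'_{ijk} / sqrt(sum_{i,k} w'_{ijk}^2 + eps), per output channel j."""
    return w_mod / torch.sqrt(w_mod.pow(2).sum(dim=(1, 2), keepdim=True) + eps)


torch.manual_seed(0)
out_channels, in_channels, kernel_size = 4, 3, 5
w = torch.randn(out_channels, in_channels, kernel_size)
s = torch.rand(in_channels) * 3  # style scales s_i

w_demod = demodulate(modulate(w, s))
norms = w_demod.pow(2).sum(dim=(1, 2)).sqrt()
print(norms)  # every entry is ~1.0

# Multiplying every style scale by the same constant has no effect after demodulation
same = torch.allclose(w_demod, demodulate(modulate(w, 2 * s)), atol=1e-6)
print(same)  # True
```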

However, introducing demodulation can alter the scale of the position embeddings, which is crucial information in some models.

When integrating weight modulation, please try it both with and without demodulation.

An example implementation for a linear layer, heavily based on CIPS:

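import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModulatedLinear(nn.Module):  # illustrative name
    # A 1x1 modulated convolution, i.e. a per-sample linear layer over channels.
    # input: [batch, in_channels, length], style: [batch, style_size]
    def __init__(self, in_channels, out_channels, style_size, demodulation=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.style_size = style_size
        self.scale = 1 / np.sqrt(in_channels)  # equalized learning rate
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, 1)
        )
        if self.style_size != self.in_channels:
            # nn.Linear stands in for the FullyConnectedLayer used in the original
            self.modulation = nn.Linear(style_size, in_channels, bias=False)
        self.demodulation = demodulation

    def forward(self, input, style):
        batch_size, _, length = input.shape
        if self.style_size != self.in_channels:
            style = self.modulation(style)
        # Modulation: w'_ijk = s_i * w_ijk, one weight tensor per sample
        style = style.view(batch_size, 1, self.in_channels, 1)
        weight = self.scale * self.weight * style

        if self.demodulation:
            # Demodulation: normalize over input channels i and kernel positions k
            demod = torch.rsqrt(weight.pow(2).sum([2, 3]) + 1e-8)
            weight = weight * demod.view(batch_size, self.out_channels, 1, 1)

        # Fold the batch into the channel dimension so that a grouped
        # convolution applies each sample's own weights
        weight = weight.view(
            batch_size * self.out_channels, self.in_channels, 1
        )
        input = input.reshape(1, batch_size * self.in_channels, length)
        out = F.conv1d(input, weight, groups=batch_size)

        return out.view(batch_size, self.out_channels, length)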
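The core trick in the forward pass above is folding the batch into the channel dimension, so that a single grouped F.conv1d applies a different weight matrix to every sample. A minimal check with random tensors (shapes chosen arbitrarily) compares it against an explicit per-sample matrix product computed with torch.einsum.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, in_channels, out_channels, length = 3, 5, 4, 7
x = torch.randn(batch_size, in_channels, length)
# One (already modulated) weight matrix per sample: [batch, out, in]
weights = torch.randn(batch_size, out_channels, in_channels)

# Grouped-convolution trick: the batch becomes groups of channels
grouped = F.conv1d(
    x.reshape(1, batch_size * in_channels, length),
    weights.reshape(batch_size * out_channels, in_channels, 1),
    groups=batch_size,
).view(batch_size, out_channels, length)

# Reference: an explicit per-sample matrix product
reference = torch.einsum("boi,bil->bol", weights, x)

print(torch.allclose(grouped, reference, atol=1e-5))  # True
```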

For a 2d convolution implementation, refer to CIPS.

Weight modulation is a very powerful technique, but it requires a significantly larger learning rate than traditional approaches. I noticed that it trains comparatively slowly, and most networks that use it set the learning rate to 0.002 for a batch size of 128.

Adjusting the learning rate is extremely important

GANs are notoriously difficult to train, and it is essential to find a good balance between the generator and the discriminator. The images produced by the generator start to look similar to the training dataset only after a while, and waiting to see whether a given learning rate will eventually produce good results can be tedious.

For this reason, I recommend first training the generator alone to reproduce a fixed batch of images from the dataset, given a fixed batch of noise. If the generator cannot learn to map fixed noise to a fixed set of images, it will almost certainly not train in the real, adversarial training either.

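import torch

# Assumed to be defined elsewhere: generator, trainloader, latent_dim, device,
# lr, total_steps, steps_til_summary
ground_truth, _ = next(iter(trainloader))  # one fixed batch of real images
ground_truth = ground_truth.to(device)
z = torch.randn(ground_truth.size(0), latent_dim, device=device)  # fixed noise

optimizer = torch.optim.Adam(generator.parameters(), lr=lr)

for step in range(total_steps):
    model_output = generator(z)
    loss = ((model_output - ground_truth) ** 2).mean()  # MSE to the fixed images

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % steps_til_summary == 0:
        print("Step %d, Total loss %0.6f" % (step, loss.item()))
        # Plot some of the generated images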

Such a toy experiment for the generator allows rapid testing of different learning rates. Finding a good learning rate is much easier when one experiment takes only 5-10 minutes to train.
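A minimal sketch of such a learning-rate sweep, under clearly stated assumptions: a tiny placeholder MLP stands in for the generator, and random vectors in $[-1, 1]$ stand in for a fixed batch of real images. The function name toy_run and all sizes are hypothetical. Each run starts from the same initialization, noise and targets, so only the learning rate differs.

```python
import torch
import torch.nn as nn


def toy_run(lr, steps=300, latent_dim=16, image_dim=64, batch_size=8):
    """Fit a fixed noise batch to a fixed target batch; return the final MSE."""
    torch.manual_seed(0)  # same init, noise and targets for every learning rate
    generator = nn.Sequential(  # placeholder generator
        nn.Linear(latent_dim, 128), nn.ReLU(), nn.Linear(128, image_dim)
    )
    z = torch.randn(batch_size, latent_dim)
    ground_truth = torch.rand(batch_size, image_dim) * 2 - 1  # stand-in "images"
    optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
    for _ in range(steps):
        loss = ((generator(z) - ground_truth) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()


results = {lr: toy_run(lr) for lr in (1e-5, 1e-4, 1e-3, 1e-2)}
for lr, final_loss in results.items():
    print(f"lr={lr:g}: final loss {final_loss:.4f}")
```

A learning rate whose loss stays near its initial value is too small, and one whose loss oscillates or grows is too large; the candidates in between are the ones worth trying in the full GAN training.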

However, a learning rate that is optimal for the toy experiment does not guarantee a good balance with the discriminator; it only shows that the generator can learn steadily without diverging. The real training may require a slightly lower learning rate, and the learning rate of the discriminator needs to be adjusted separately.

Teodor TOSHKOV

I am an intern at Fusic, a company in Fukuoka, Japan. From 2022, I will be joining the Machine Learning team. I mostly develop deep learning models using PyTorch and TensorFlow.