Fusic Tech Blog

Fusion of Society, IT and Culture

Implementing ViTGAN (A novel realistic image generation model) in PyTorch


ViTGAN is a Generative Adversarial Network (GAN) whose generator combines a Vision Transformer with a SIREN. Its discriminator is also a Vision Transformer.

For a detailed description of the paper by Kwonjoon Lee et al., please visit VITGAN: Training GANs with Vision Transformers - Fusic Tech Blog.

In this post, I am going to describe my implementation of ViTGAN: Training GANs with Vision Transformers.

Feel free to look at my implementation in Google Colaboratory and on GitHub:

GitHub: teodorToshkov/ViTGAN-pytorch

Open In Colab

Dataset

The dataset used in my implementation of ViTGAN is CIFAR-10.

Images from CIFAR-10

CIFAR-10 consists of 60,000 32×32 color images in the following 10 classes:

  • airplane
  • automobile
  • bird
  • cat
  • deer
  • dog
  • frog
  • horse
  • ship
  • truck

Results of training ViTGAN on CIFAR-10 are also reported in the original paper.

ViTGAN Implementation Details

Just like any other Generative Adversarial Network (GAN), ViTGAN consists of a generator and a discriminator.

The generator is trained to produce images similar to those in a given dataset, so that the discriminator classifies them as part of the dataset. At the same time, the discriminator is trained to distinguish images from the dataset from images produced by the generator.
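The adversarial game can be illustrated with a single training step. This is a minimal sketch under stated assumptions, not my ViTGAN training code: both networks are tiny fully-connected stand-ins (hypothetical sizes), the "dataset" is random, and I assume the non-saturating logistic GAN loss written with softplus. The discriminator is updated on detached fake images first, then the generator is updated so that its samples are scored as real:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
latent_dim, data_dim, batch_size = 8, 16, 32

# Hypothetical toy stand-ins for the ViTGAN generator and discriminator
generator = nn.Sequential(nn.Linear(latent_dim, 32), nn.ReLU(), nn.Linear(32, data_dim))
discriminator = nn.Sequential(nn.Linear(data_dim, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
opt_g = torch.optim.Adam(generator.parameters(), lr=0.002)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=0.002)


def train_step(real):
    z = torch.randn(real.shape[0], latent_dim)

    # Discriminator: raise scores of real samples, lower scores of (detached) fakes
    fake = generator(z).detach()
    loss_d = F.softplus(-discriminator(real)).mean() + F.softplus(discriminator(fake)).mean()
    opt_d.zero_grad()
    loss_d.backward()
    opt_d.step()

    # Generator: make the discriminator score freshly generated fakes as real
    loss_g = F.softplus(-discriminator(generator(z))).mean()
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()


real_batch = torch.randn(batch_size, data_dim) + 3.0  # placeholder for dataset images
loss_d, loss_g = train_step(real_batch)
```

Since softplus(-x) equals -log(sigmoid(x)), these are the usual cross-entropy GAN losses computed directly on the discriminator's raw logits.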

Since ViTGAN uses Vision Transformers, it treats images as a sequence of patches of a predefined size (in this example, 4×4 pixels).
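For a 32×32 CIFAR-10 image and 4×4 patches, this gives (32/4)² = 64 patches, each flattened to 3·4·4 = 48 values. A minimal sketch of this split using only tensor reshapes (the helper name image_to_patches is hypothetical, and it produces plain non-overlapping patches):

```python
import torch


def image_to_patches(images, patch_size):
    """Split [B, C, H, W] images into a sequence of flattened patches [B, N, C*p*p]."""
    b, c, h, w = images.shape
    p = patch_size
    x = images.reshape(b, c, h // p, p, w // p, p)
    x = x.permute(0, 2, 4, 1, 3, 5)          # [B, H/p, W/p, C, p, p]
    return x.reshape(b, (h // p) * (w // p), c * p * p)


cifar_batch = torch.randn(2, 3, 32, 32)       # two random CIFAR-10-sized images
patches = image_to_patches(cifar_batch, patch_size=4)
print(patches.shape)  # torch.Size([2, 64, 48])
```

Patches are ordered row by row, so patch 0 is the top-left 4×4 block and patch 1 is the block to its right.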

Generator

The goal of the generator is to create images similar to the images in a given dataset.

The generator has the following architecture:

ViTGAN Generator architecture

Equalized learning rate

The equalized learning rate is similar to careful weight initialization in that it depends on the number of input features. It multiplies the weights of a linear layer by a weight gain, which is inversely proportional to the square root of the number of input features and is multiplied by an lr_multiplier:

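class FullyConnectedLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, lr_multiplier=1.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) / lr_multiplier)
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)  # equalized learning rate
        self.bias_gain = lr_multiplier

    def forward(self, x):
        w = self.weight.to(x.dtype) * self.weight_gain
        b = None if self.bias is None else self.bias.to(x.dtype) * self.bias_gain
        return F.linear(x, w, b)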

This technique is used in every fully-connected layer in both the generator and the discriminator. However, lr_multiplier is set to a value other than 1 only in the latent vector mapping network, which is explained in the following section.
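To see what the weight gain does in practice, here is a small self-contained sketch. The layer name EqualizedLinear and the helper adam_step_size are hypothetical, and the StyleGAN2-style initialization `randn / lr_multiplier` is assumed. It checks that the effective weights start with a standard deviation of about 1/√in_features regardless of lr_multiplier, and measures how much one Adam step changes the effective weights:

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class EqualizedLinear(nn.Module):
    """Sketch of an equalized-learning-rate linear layer (StyleGAN2-style)."""

    def __init__(self, in_features, out_features, lr_multiplier=1.0):
        super().__init__()
        # Stored parameters are N(0, 1) / lr_multiplier ...
        self.weight = nn.Parameter(torch.randn(out_features, in_features) / lr_multiplier)
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def effective_weight(self):
        # ... and rescaled at runtime, so the effective init is N(0, 1 / in_features)
        return self.weight * self.weight_gain

    def forward(self, x):
        return F.linear(x, self.effective_weight(), self.bias * self.bias_gain)


def adam_step_size(lr_multiplier, in_features=512, lr=0.002):
    """Mean absolute change of the effective weights after one Adam step."""
    torch.manual_seed(0)
    layer = EqualizedLinear(in_features, 512, lr_multiplier)
    optimizer = torch.optim.Adam(layer.parameters(), lr=lr)
    before = layer.effective_weight().detach().clone()
    layer(torch.randn(16, in_features)).pow(2).mean().backward()
    optimizer.step()
    return (layer.effective_weight().detach() - before).abs().mean().item()
```

Adam normalizes each update to roughly lr in parameter space, so the change in the effective weights scales with weight_gain: with lr_multiplier=0.01 it comes out around two orders of magnitude smaller than with lr_multiplier=1, which is how the mapping network ends up learning more slowly.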

Mapping of the Latent vector

The implementation of the mapping network is taken from StyleGAN2.

The latent vector $z$ is sampled from a standard Gaussian distribution:

np.random.normal(0, 1, (batch_size, latent_dim))

It is then mapped to a new latent space by a multi-layer perceptron (MLP) that uses an equalized learning rate with lr_multiplier=0.01. This makes the mapping network learn more slowly than the rest of the network. The value 0.01 is the standard setting for the mapping network in StyleGAN2; apparently, the mapping network requires a much lower learning rate for training to stay balanced.


The mapping network is implemented as follows:

class MappingNetwork(nn.Module):
    def __init__(self, layer_features, latent_dim, w_dim, num_layers, lr_multiplier=0.01):
        super().__init__()
        self.num_layers = num_layers
        features_list = [latent_dim] + [layer_features] * (num_layers - 1) + [w_dim]
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

    def forward(self, z):
        x = z
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            # Assumption: leaky ReLU after each layer, as in StyleGAN2's mapping network
            x = F.leaky_relu(layer(x), 0.2)
        return x

There is no guarantee that sampling from a Gaussian distribution is appropriate for generating realistic images with the ViTGAN architecture.
The MLP maps the latent vector to a space that is better suited to the task; since its parameters are trainable, it can learn this mapping during training.

Generator - Transformer Encoder

The input to the generator is only the trainable position encoding, without including any outside information about the concrete image.

Since the position embedding is the same for all images, we repeat it along the batch dimension:

def __init__(self, hidden_size, ...):
    self.pos_emb = nn.Parameter(torch.randn(num_patches, hidden_size))

def forward(self, z):
    pos = repeat(torch.sin(self.pos_emb), 'n e -> b n e', b=z.shape[0])

Notice that the position embedding has size num_patches × hidden_size, i.e. there is a separate embedding vector for each patch of the image.

Since the standard Vision Transformer flattens patches into 1-dimensional representations, the position embeddings are created to match this dimensionality.

As stated in the paper, the position embeddings are passed through a sine activation function.

Self-Modulated LayerNorm (SLN)

The Self-Modulated LayerNorm is the only place where the input latent vector influences the network.

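Standard layer normalization computes:

$$\mathrm{LN}(\mathbf{h}_\ell) = \boldsymbol{\gamma} \odot \frac{\mathbf{h}_\ell - \mu}{\sigma} + \boldsymbol{\beta}$$

where $\mu$ and $\sigma$ are the mean and standard deviation of $\mathbf{h}_\ell$ over its features, and $\boldsymbol{\gamma}$, $\boldsymbol{\beta}$ are learned parameters. SLN, in contrast, computes the scale and shift from the latent vector $\mathbf{w}$:

$$\mathrm{SLN}(\mathbf{h}_\ell, \mathbf{w}) = \boldsymbol{\gamma}_\ell(\mathbf{w}) \odot \frac{\mathbf{h}_\ell - \mu}{\sigma} + \boldsymbol{\beta}_\ell(\mathbf{w})$$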

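class SLN(nn.Module):
    def __init__(self, input_size, parameter_size=None):
        super().__init__()
        parameter_size = parameter_size or input_size
        self.ln = nn.LayerNorm(input_size)
        # gamma(w) and beta(w) are linear functions of the latent vector w
        self.gamma = nn.Linear(input_size, parameter_size)
        self.beta = nn.Linear(input_size, parameter_size)

    def forward(self, hidden, w):
        # hidden: [batch, num_patches, input_size], w: [batch, input_size]
        gamma = self.gamma(w).unsqueeze(1)  # broadcast over all patches
        beta = self.beta(w).unsqueeze(1)
        ln = self.ln(hidden)
        return gamma * ln + beta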

parameter_size is equal to input_size, i.e. $\gamma_\ell, \beta_\ell \in \mathbb{R}^D$, where $D$ is the hidden size.
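As a quick sanity check, the sketch below (standalone, with hypothetical sizes) compares an SLN-style computation against the formula γ(w) ⊙ (h − μ)/σ + β(w) written out by hand. It uses a LayerNorm without its own affine parameters so that the only scale and shift come from w; the post's SLN keeps nn.LayerNorm's default affine parameters, which start at 1 and 0 and so give the same result at initialization. It also confirms that identical hidden states are modulated differently by different latent vectors:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, num_patches, hidden_size = 2, 4, 8

ln = nn.LayerNorm(hidden_size, elementwise_affine=False)  # plain (h - mu) / sigma
to_gamma = nn.Linear(hidden_size, hidden_size)             # gamma_l(w)
to_beta = nn.Linear(hidden_size, hidden_size)              # beta_l(w)


def sln(hidden, w):
    # Per-sample scale and shift, broadcast over all patches
    return to_gamma(w).unsqueeze(1) * ln(hidden) + to_beta(w).unsqueeze(1)


hidden = torch.randn(batch, num_patches, hidden_size)
w = torch.randn(batch, hidden_size)
out = sln(hidden, w)

# The same computation written out by hand (biased variance, as LayerNorm uses)
mu = hidden.mean(dim=-1, keepdim=True)
var = ((hidden - mu) ** 2).mean(dim=-1, keepdim=True)
manual = to_gamma(w)[:, None, :] * (hidden - mu) / torch.sqrt(var + ln.eps) + to_beta(w)[:, None, :]

# Same hidden state for both samples, different latent vectors
same_hidden = hidden[:1].expand(2, -1, -1)
modulated = sln(same_hidden, w)
```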

Transformer Encoder Block

The implementation of the Vision Transformer blocks, including Multi-Head Self-Attention (MSA), is heavily based on [Blog Post].

After creating the position embeddings, we define the Transformer encoder blocks:

class GeneratorTransformerEncoderBlock(nn.Module):
    def __init__(self, hidden_size, sln_parameter_size, forward_expansion=4, **kwargs):
        super().__init__()
        self.sln = SLN(hidden_size, parameter_size=sln_parameter_size)
        self.msa = MultiHeadAttention(hidden_size, **kwargs)
        self.feed_forward = FeedForwardBlock(hidden_size, expansion=forward_expansion)

    def forward(self, hidden, w):
        # SLN -> MSA -> residual connection
        res = hidden
        hidden = self.sln(hidden, w)
        hidden = self.msa(hidden)
        hidden += res

        # (same) SLN -> feed-forward -> residual connection
        res = hidden
        hidden = self.sln(hidden, w)
        hidden = self.feed_forward(hidden)
        hidden += res
        return hidden

The Transformer block of a ViTGAN generator differs from a standard Transformer block only in that it uses SLN instead of a regular LayerNorm. The input to the block is stored in res before it is modified, so that a residual connection can be made later. The input is passed through an SLN and a Multi-Head Self-Attention layer, after which the residual connection is made by adding the stored values to the result (hidden += res).

Then the result is stored again in res and passed through the same SLN and a two-layer feed-forward network, followed by another residual connection. The feed-forward network is implemented as follows:

class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion=4):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),  # x -> x * a
            nn.GELU(),
            nn.Linear(expansion * emb_size, emb_size),  # x * a -> x
        )

The FeedForwardBlock is implemented as a sequential module: it maps an input of size $x$ to a hidden layer of size $x \times a$, where $a$ = expansion, applies a GELU activation, and maps the result back to size $x$.
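Because MultiHeadAttention comes from the referenced blog post and is not shown here, the following self-contained sketch substitutes PyTorch's built-in nn.MultiheadAttention (an assumption, not the module used in my implementation) to show how one generator block fits together. The class name GeneratorBlockSketch and the sizes are hypothetical. The block preserves the [batch, num_patches, hidden_size] shape, and since both samples receive the same position-embedding input, any difference between their outputs comes from the latent vectors through SLN:

```python
import torch
import torch.nn as nn


class SLN(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size)
        self.gamma = nn.Linear(hidden_size, hidden_size)
        self.beta = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden, w):
        return self.gamma(w).unsqueeze(1) * self.ln(hidden) + self.beta(w).unsqueeze(1)


class GeneratorBlockSketch(nn.Module):
    """Hypothetical block: SLN -> MSA -> residual, then SLN -> MLP -> residual."""

    def __init__(self, hidden_size, num_heads=4, expansion=4):
        super().__init__()
        self.sln = SLN(hidden_size)  # one SLN shared by both halves, as in the post
        self.msa = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, expansion * hidden_size),
            nn.GELU(),
            nn.Linear(expansion * hidden_size, hidden_size),
        )

    def forward(self, hidden, w):
        x = self.sln(hidden, w)
        hidden = hidden + self.msa(x, x, x, need_weights=False)[0]
        hidden = hidden + self.feed_forward(self.sln(hidden, w))
        return hidden


torch.manual_seed(0)
block = GeneratorBlockSketch(hidden_size=32)
pos = torch.sin(torch.randn(64, 32)).expand(2, -1, -1)  # same input for both samples
w = torch.randn(2, 32)                                   # different latent per sample
out = block(pos, w)
```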

SIREN

A description of SIREN: [Blog Post] [Paper] [Colab Notebook]

The code for implementing SIREN is taken from the official implementation: [Colab Notebook].

In contrast to regular SIREN, the desired output is not a single image, but a wide range of different images. For this purpose, the patch embedding is combined with a position embedding.
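For reference, a single plain SIREN layer computes sin(ω₀(Wx + b)). The sketch below assumes ω₀ = 30 and the initialization recommended in the SIREN paper (weights uniform in ±1/in_features for the first layer and ±√(6/in_features)/ω₀ for later layers). It illustrates plain SIREN fitting coordinates to colors, not the modulated layers ViTGAN uses, and the class name SineLayer is hypothetical:

```python
import numpy as np
import torch
import torch.nn as nn


class SineLayer(nn.Module):
    """A plain SIREN layer: sin(omega_0 * (W x + b))."""

    def __init__(self, in_features, out_features, is_first=False, omega_0=30.0):
        super().__init__()
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features)
        bound = 1 / in_features if is_first else np.sqrt(6 / in_features) / omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x):
        return torch.sin(self.omega_0 * self.linear(x))


torch.manual_seed(0)
# A tiny SIREN mapping 2-D pixel coordinates to RGB values
siren = nn.Sequential(SineLayer(2, 64, is_first=True), SineLayer(64, 64), nn.Linear(64, 3))
grid = torch.linspace(-1, 1, 4)
ys, xs = torch.meshgrid(grid, grid, indexing="ij")
coords = torch.stack([xs, ys], dim=-1).reshape(-1, 2)  # 16 (x, y) pairs
rgb = siren(coords)
```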

Fourier Position Encoding

The SIREN network in ViTGAN uses a positional encoding, similar to the position encoding in a Transformer. For a detailed explanation, please visit Merging Vision Transformers (ViT) with SIRENs to form a ViTGAN. A novel approach to generate realistic images. - Fusic Tech Blog.

The position encoding used in ViTGAN is the Fourier position encoding; the code for it was taken from CIPS.

The Fourier position encoding for a single position $\mathbf{v}$ is implemented according to the following function:

$$E_{fou}(\mathbf{v}) = \sin(\mathbf{W}\mathbf{v})$$

The implementation in PyTorch is as follows:

class LFF(nn.Module):
    def __init__(self, hidden_size, bias=True):
        super(LFF, self).__init__()
        # Maps a 2-D coordinate (x, y) to a hidden_size-dimensional vector
        self.ffm = nn.Linear(2, hidden_size, bias=bias)
        # CIPS initialization for a first layer, with fan-in = 2 coordinates
        nn.init.uniform_(self.ffm.weight, -np.sqrt(9 / 2), np.sqrt(9 / 2))

    def forward(self, x):
        x = self.ffm(x)
        x = torch.sin(x)
        return x

Next, we use the function above to generate position embeddings for each position of the image:

def fourier_pos_embedding(self):
    coords = np.linspace(-1, 1, self.out_patch_size, endpoint=True)
    pos = np.stack(np.meshgrid(coords, coords), -1)
    pos = torch.tensor(pos, dtype=torch.float)
    result = self.lff(pos).reshape([self.out_patch_size**2, self.siren_in_features])
    return result
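For experimenting outside the generator, here is a standalone version of the two snippets above. Turning the method into a function that takes lff, out_patch_size and siren_in_features as arguments is my adaptation, and hidden_size must equal siren_in_features for the final reshape to work:

```python
import numpy as np
import torch
import torch.nn as nn


class LFF(nn.Module):
    """Learned Fourier features: sin(W v + b) for 2-D coordinates v."""

    def __init__(self, hidden_size):
        super().__init__()
        self.ffm = nn.Linear(2, hidden_size)
        nn.init.uniform_(self.ffm.weight, -np.sqrt(9 / 2), np.sqrt(9 / 2))

    def forward(self, x):
        return torch.sin(self.ffm(x))


def fourier_pos_embedding(lff, out_patch_size, siren_in_features):
    # Regular grid of (x, y) coordinates in [-1, 1] covering one output patch
    coords = np.linspace(-1, 1, out_patch_size, endpoint=True)
    pos = torch.tensor(np.stack(np.meshgrid(coords, coords), -1), dtype=torch.float)
    return lff(pos).reshape([out_patch_size ** 2, siren_in_features])


torch.manual_seed(0)
emb = fourier_pos_embedding(LFF(hidden_size=32), out_patch_size=4, siren_in_features=32)
```

Each of the 4×4 = 16 pixel positions in a patch gets its own 32-dimensional embedding, with values in [-1, 1] because of the sine.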

Finally, the position and patch embeddings have to be combined. This is implemented with weight modulation.

Weight Modulation

Weight modulation is a technique in which the weights are multiplied by an input embedding, the style vector $s$. Each weight is scaled by the style component of its input feature:

$$w'_{ijk} = s_i \cdot w_{ijk}$$

There is an optional term, used to normalize the weights, called "demodulation", defined as follows:
$$w''_{ijk} = \frac{w'_{ijk}}{\sqrt{\sum_{i,k} {w'_{ijk}}^2 + \epsilon}}$$

After testing the network, I concluded that demodulation is not used in ViTGAN.

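My implementation of weight modulation is heavily based on CIPS. I have adapted it to a fully-connected network by using a 1D convolution with kernel size 1. The reason for using a 1D convolution instead of a linear layer is its groups argument: since every sample in the batch has its own modulated weights, folding the batch into the channel dimension and setting groups=batch_size processes the whole batch with a single convolution call instead of batch_size separate matrix multiplications.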

Each SIREN layer consists of a sine activation applied to the output of a weight-modulated layer. The sizes of the input, hidden and output layers of a SIREN network can vary, so if the input size differs from the size of the patch embedding, I define an additional fully-connected layer that projects the patch embedding to the appropriate size.

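class ModulatedLinear(nn.Module):  # illustrative class name
    def __init__(self, in_channels, out_channels, style_size, demodulation=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.style_size = style_size
        self.scale = 1 / np.sqrt(in_channels)
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, 1)
        )
        if self.style_size != self.in_channels:
            # Project the style vector to one scale per input feature
            self.style_modulation = FullyConnectedLayer(style_size, in_channels, bias=False)
        self.demodulation = demodulation

    def forward(self, input, style):
        # input: [batch_size, img_size, in_channels], style: [batch_size, style_size]
        batch_size, img_size, _ = input.shape
        if self.style_size != self.in_channels:
            style = self.style_modulation(style)
        style = style.view(batch_size, 1, self.in_channels, 1)
        weight = self.scale * self.weight * style  # w'_ijk = s_i * w_ijk, per sample

        if self.demodulation:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3]) + 1e-8)
            weight = weight * demod.view(batch_size, self.out_channels, 1, 1)

        weight = weight.view(
            batch_size * self.out_channels, self.in_channels, 1
        )
        # Fold the batch into the channel axis: [1, batch_size * in_channels, img_size]
        input = input.transpose(1, 2).reshape(1, batch_size * self.in_channels, img_size)

        out = F.conv1d(input, weight, groups=batch_size)  # one weight set per sample

        # [1, batch_size * out_channels, img_size] -> [batch_size, img_size, out_channels]
        out = out.view(batch_size, self.out_channels, img_size).transpose(1, 2)
        return out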
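To convince yourself that the grouped convolution really applies a different weight matrix to every sample, here is a compact functional sketch compared against a straightforward einsum. The helper modulated_linear is hypothetical and, for brevity, leaves out the 1/√in_channels scale and the style projection:

```python
import torch
import torch.nn.functional as F


def modulated_linear(x, weight, style, demodulate=False, eps=1e-8):
    """Per-sample modulated linear layer via one grouped 1-D convolution.

    x: [B, L, C_in], weight: [C_out, C_in], style: [B, C_in] -> [B, L, C_out]
    """
    b, length, c_in = x.shape
    c_out = weight.shape[0]
    w = weight[None] * style[:, None, :]                   # w'_ijk = s_i * w_ijk
    if demodulate:
        w = w * torch.rsqrt(w.pow(2).sum(dim=2, keepdim=True) + eps)
    w = w.reshape(b * c_out, c_in, 1)                       # one kernel set per sample
    x = x.transpose(1, 2).reshape(1, b * c_in, length)      # fold batch into channels
    out = F.conv1d(x, w, groups=b)                          # [1, B * C_out, L]
    return out.view(b, c_out, length).transpose(1, 2)


torch.manual_seed(0)
x = torch.randn(3, 16, 8)      # 3 samples, 16 pixels, 8 input features
weight = torch.randn(5, 8)     # 5 output features
style = torch.randn(3, 8)      # one style vector per sample
out = modulated_linear(x, weight, style)

# Reference: apply each sample's modulated weight matrix explicitly
reference = torch.einsum("bli,boi->blo", x, weight[None] * style[:, None, :])
```

With demodulate=True, each output feature's weights are renormalized, so scaling the style vector by a constant no longer changes the output.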

Convert to an image

The results are in the form [batch_size * num_patches, patch_size^2, out_features] and we have to convert them to [batch_size, image_size^2, out_features], where patch_size * sqrt(num_patches) = image_size.

This is done by first converting to a tensor of shape [batch_size, sqrt(num_patches), sqrt(num_patches), patch_size, patch_size, out_features].

Next, we reorder them so that all pixels from a row are ordered consecutively, i.e. [batch_size, sqrt(num_patches), patch_size, sqrt(num_patches), patch_size, out_features].

And finally, we convert to the desired shape.

img = img.view([-1, num_patches_x, num_patches_x, patch_size, patch_size, out_features])
img = img.permute([0, 1, 3, 2, 4, 5])
img = img.reshape([-1, image_size**2, out_features])
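A quick way to check the index juggling is a round trip: build an image with known pixel values, cut it into patches in the [batch_size * num_patches, patch_size^2, out_features] layout, and reassemble it with the same three lines, wrapped here in a hypothetical helper patches_to_image:

```python
import torch


def patches_to_image(img, num_patches_x, patch_size, out_features):
    # The three reshaping lines from above, wrapped in a function
    image_size = num_patches_x * patch_size
    img = img.view([-1, num_patches_x, num_patches_x, patch_size, patch_size, out_features])
    img = img.permute([0, 1, 3, 2, 4, 5])
    return img.reshape([-1, image_size ** 2, out_features])


batch, num_patches_x, patch_size, out_features = 2, 8, 4, 3
image_size = num_patches_x * patch_size
image = torch.arange(batch * image_size ** 2 * out_features, dtype=torch.float)
image = image.view(batch, image_size, image_size, out_features)

# Cut the known image into patches in the generator's output layout: [B * N, p * p, F]
patches = (image.view(batch, num_patches_x, patch_size, num_patches_x, patch_size, out_features)
                .permute(0, 1, 3, 2, 4, 5)
                .reshape(batch * num_patches_x ** 2, patch_size ** 2, out_features))

restored = patches_to_image(patches, num_patches_x, patch_size, out_features)
```

If the permutation were wrong, pixels would land in the wrong rows and the comparison with the row-major flattened image would fail.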

Discriminator

The discriminator has the following architecture:

ViTGAN Discriminator architecture

The ViTGAN Discriminator is mostly a standard Vision Transformer network, with the following modifications:

  • DiffAugment;
  • Overlapping Image Patches;
  • Use vectorized L2 distance in attention;
  • Improved Spectral Normalization (ISN);
  • Balanced Consistency Regularization (bCR).

DiffAugment

For implementing DiffAugment, I used the following code:
[GitHub] [Paper]

I have implemented it as part of the discriminator module. It randomly alters the color of the image, translates it and cuts out a small portion of it. Differentiable augmentation is applied to both the "real" and the "fake" images.
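The key property of differentiable augmentation is that every transformation is built from differentiable tensor operations, so the generator still receives gradients through the augmented fake images. Below is a simplified sketch of three such operations (brightness, contrast and cutout; translation is omitted). The helper names and the simplified cutout logic are my own assumptions, not the official DiffAugment code:

```python
import torch


def rand_brightness(x):
    # Add one random offset in [-0.5, 0.5) per image
    return x + (torch.rand(x.size(0), 1, 1, 1, device=x.device) - 0.5)


def rand_contrast(x):
    # Scale each image's deviation from its mean by a factor in [0.5, 1.5)
    mean = x.mean(dim=[1, 2, 3], keepdim=True)
    return (x - mean) * (torch.rand(x.size(0), 1, 1, 1, device=x.device) + 0.5) + mean


def rand_cutout(x, ratio=0.5):
    # Zero out one random square of side ratio * image size per image
    b, _, h, w = x.shape
    ch, cw = int(h * ratio), int(w * ratio)
    top = torch.randint(0, h - ch + 1, (b, 1, 1), device=x.device)
    left = torch.randint(0, w - cw + 1, (b, 1, 1), device=x.device)
    rows = torch.arange(h, device=x.device).view(1, h, 1)
    cols = torch.arange(w, device=x.device).view(1, 1, w)
    inside = (rows >= top) & (rows < top + ch) & (cols >= left) & (cols < left + cw)
    return x * (~inside).unsqueeze(1).to(x.dtype)


def diff_augment(x):
    return rand_cutout(rand_contrast(rand_brightness(x)))


torch.manual_seed(0)
images = torch.randn(4, 3, 32, 32, requires_grad=True)
augmented = diff_augment(images)
augmented.sum().backward()  # gradients flow back to the input images
```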

Overlapping Image Patches

Figure: an image of a lizard and its overlapping image patches (16x16)

The overlapping image patches are created with a convolution layer:

nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=stride_size)

As stated in the ViTGAN paper, the overlap is equal to patch_size/2 on each edge, resulting in patches of size $2p \times 2p$, where $p$ is patch_size.
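With a convolution, overlapping patches come down to the choice of kernel size and stride: a kernel of 2p with stride p makes each window extend p/2 pixels beyond its p×p cell on every edge. The padding of p/2 below is my assumption, chosen so that the number of tokens stays (32/p)², the same as for non-overlapping patches; emb_size is also hypothetical. The check at the end perturbs a single pixel near a patch border and counts how many tokens change:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 4              # patch_size
emb_size = 64      # hypothetical embedding size

# Kernel 2p with stride p: windows overlap by p/2 on each edge of a p x p cell
overlapping = nn.Conv2d(3, emb_size, kernel_size=2 * p, stride=p, padding=p // 2)
non_overlapping = nn.Conv2d(3, emb_size, kernel_size=p, stride=p)

images = torch.randn(2, 3, 32, 32)
tokens = overlapping(images).flatten(2).transpose(1, 2)        # [B, num_patches, emb_size]
baseline = non_overlapping(images).flatten(2).transpose(1, 2)

# Perturb one pixel (row 0, column 5) and see which tokens react
perturbed = images.clone()
perturbed[:, :, 0, 5] += 1.0
changed_over = (overlapping(perturbed) - overlapping(images)).abs().sum(dim=(0, 1)) > 1e-6
changed_base = (non_overlapping(perturbed) - non_overlapping(images)).abs().sum(dim=(0, 1)) > 1e-6
```

The pixel at column 5 falls into two overlapping windows but into only one non-overlapping patch.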

Use vectorized L2 distance in attention for Discriminator


When calculating self-attention in the discriminator, instead of taking the dot product of the query and key vectors, the L2 distance between them is calculated:

torch.cdist(queries, keys, p=2)
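The sketch below follows the formulation in The Lipschitz Constant of Self-Attention, where the logits are the negative squared L2 distance scaled by √d_h and the query and key projections are tied. Both choices are assumptions on my part here; the one-liner above uses the plain distance returned by torch.cdist. The function name l2_attention is hypothetical. Computing the squared distance explicitly also avoids differentiating a square root at zero distance:

```python
import torch
import torch.nn.functional as F


def l2_attention(x, w_qk, w_v):
    """Single-head self-attention with negative squared L2 distance as the logits.

    x: [B, N, D]; w_qk (shared query/key projection) and w_v: [D, d_h].
    """
    q = k = x @ w_qk                                              # tied projections (assumption)
    d_h = q.shape[-1]
    sq_dist = (q.unsqueeze(-2) - k.unsqueeze(-3)).pow(2).sum(-1)  # [B, N, N]
    attn = F.softmax(-sq_dist / d_h ** 0.5, dim=-1)
    return attn @ (x @ w_v), attn


torch.manual_seed(0)
x = torch.randn(2, 5, 16)
w_qk, w_v = torch.randn(16, 8), torch.randn(16, 8)
out, attn = l2_attention(x, w_qk, w_v)
```

With tied projections, every token is at distance zero from itself, so each token attends most strongly to itself.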

Improved Spectral Normalization (ISN)

The ISN implementation is based on the following implementation of Spectral Normalization:
[GitHub] [Paper]

Improved spectral normalization is used in all layers of the discriminator.

Spectral Normalization (SN):
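$$\bar{W}_{SN}(W) := W / \sigma(W)$$

where $\sigma(W)$ is the spectral norm (largest singular value) of $W$.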

Improved Spectral Normalization (ISN):

$$\bar{W}_{ISN}(W) := \sigma(W_{init}) \cdot W / \sigma(W)$$

We store the initial $\sigma$ of the weights in self.w_init_sigma:

class spectral_norm(nn.Module):
    def __init__(self, module, name='weight', ...):
        self.w_init_sigma = None

    def _update_u_v(self):
        if self.w_init_sigma is None:
            # Store sigma(W_init) once, detached from the computation graph
            self.w_init_sigma = np.array(sigma.expand_as(w).detach().cpu())

It is necessary to detach the initial $\sigma(W)$ from the computation graph. Otherwise, PyTorch will attempt to back-propagate through it multiple times, resulting in an error.

Each time the weights are updated, the normalized weight is set with:

setattr(self.module, self.name, torch.tensor(self.w_init_sigma).to(device) * w / sigma.expand_as(w))
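Putting the pieces together, a compact ISN linear layer could look as follows. This is a sketch, not my implementation: the class name ISNLinear is hypothetical, and instead of wrapping another module it stores σ(W_init) in a buffer on the first call. Like standard spectral normalization, it estimates σ(W) by power iteration and treats u and v as constants during back-propagation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ISNLinear(nn.Module):
    """Hypothetical linear layer with Improved Spectral Normalization:
    W_bar = sigma(W_init) * W / sigma(W), sigma estimated by power iteration."""

    def __init__(self, in_features, out_features, n_power_iters=1):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) / in_features ** 0.5)
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.n_power_iters = n_power_iters
        self.register_buffer("u", F.normalize(torch.randn(out_features), dim=0))
        self.register_buffer("w_init_sigma", torch.tensor(-1.0))  # -1 means "not stored yet"

    def normalized_weight(self):
        with torch.no_grad():  # u and v are constants, as in standard SN
            u = self.u
            for _ in range(max(1, self.n_power_iters)):
                v = F.normalize(self.weight.t() @ u, dim=0)
                u = F.normalize(self.weight @ v, dim=0)
            self.u.copy_(u)
        sigma = u @ self.weight @ v  # differentiable w.r.t. the weight
        if self.w_init_sigma < 0:
            self.w_init_sigma.fill_(sigma.item())  # detached copy of sigma(W_init)
        return self.w_init_sigma * self.weight / sigma

    def forward(self, x):
        return F.linear(x, self.normalized_weight(), self.bias)
```

At the first call the normalized weight equals the initial weight, and afterwards rescaling W does not change the normalized weight: only its direction is learned, while its spectral norm stays pinned to σ(W_init) instead of 1.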

Balanced Consistency Regularization (bCR)

Zhengli Zhao, Sameer Singh, Honglak Lee, Zizhao Zhang, Augustus Odena, Han Zhang; Improved Consistency Regularization for GANs; AAAI 2021 [Paper]

In addition to the regular GAN losses for the discriminator, consistency regularization losses are used.

They ensure that the output of the discriminator does not change when some small augmentation is applied to the input image.

For the implementation of the augmentations, I have used DiffAugment.

A loss is calculated for both "fake" and "real" images, in the form:

lossD_bCR_fake = F.mse_loss(
    discriminator(f_img, do_augment=True),
    discriminator(f_img, do_augment=False))
lossD_bCR_real = F.mse_loss(
    discriminator(r_img, do_augment=True),
    discriminator(r_img, do_augment=False))
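To make the effect of the consistency term concrete, the sketch below uses two toy discriminators (hypothetical, for illustration only) and a random circular shift as a stand-in for DiffAugment. The bCR term is essentially zero for a discriminator that ignores pixel positions and positive for one that does not. In training, these terms would be added to the discriminator loss with weighting factors, which are left out here:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def bcr_loss(discriminator, images, augment):
    """Consistency term: D should score augmented and original images alike."""
    return F.mse_loss(discriminator(augment(images)), discriminator(images))


def random_shift(images):
    # Hypothetical stand-in for DiffAugment: circular shift by 1-4 pixels in each direction
    dx, dy = torch.randint(1, 5, (2,)).tolist()
    return torch.roll(images, shifts=(dx, dy), dims=(2, 3))


torch.manual_seed(0)
# A toy discriminator that ignores pixel positions (1x1 conv + global average pooling)
invariant_d = nn.Sequential(nn.Conv2d(3, 8, kernel_size=1), nn.ReLU(),
                            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 1))
# A toy discriminator whose output depends on where things are in the image
positional_d = nn.Sequential(nn.Flatten(), nn.Linear(3 * 16 * 16, 1))

real_images = torch.randn(4, 3, 16, 16)
fake_images = torch.randn(4, 3, 16, 16)
lossD_bCR = (bcr_loss(positional_d, real_images, random_shift)
             + bcr_loss(positional_d, fake_images, random_shift))
```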

My thoughts on implementing ViTGAN

ViTGAN is a novel approach that combines the Vision Transformer with SIREN. It achieves strong results on popular benchmarks and produces fantastic images, which in my opinion is a marvelous achievement.

Implementing ViTGAN took me around one month, mostly due to bugs in my code that led to inefficient training.

Recently, I spoke to Kwonjoon Lee, the main author of ViTGAN. Their code seems to be based on StyleGAN2, which has a fantastic implementation. This is why they use the "equalized learning rate", which allows them to use a learning rate of 0.002.

References

SIREN: Implicit Neural Representations with Periodic Activation Functions
Vision Transformer: [Blog Post]
L2 distance attention: The Lipschitz Constant of Self-Attention
Spectral Normalization reference code: [GitHub] [Paper]
Diff Augment: [GitHub] [Paper]
Fourier Position Embedding: [Jupyter Notebook]
Exponential Moving Average: [GitHub]
Balanced Consistency Regularization (bCR): [Paper]
StyleGAN2 Discriminator: [GitHub]



I am an intern at Fusic, a company in Fukuoka, Japan. From 2022, I will be joining the Machine Learning team. I mostly develop deep learning models using PyTorch and TensorFlow.