Author Teodor TOSHKOV

Merging Vision Transformers (ViT) with SIRENs to form ViTGAN, a novel approach to generating realistic images.

2021/08/24

Vision Transformer (ViT)

Vision Transformers are Transformers that segment an image into patches and treat the sequence of 1-dimensional patch representations the same way a standard Transformer treats a sequence of words.

Credit: Google AI blog

A good description of the Vision Transformer can be found on the Google AI blog.
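To make the patch-to-sequence step concrete, here is a minimal PyTorch sketch of my own (not from the papers) of splitting a 64x64 image into 8x8 patches and projecting each flattened patch to a token embedding; the image size, patch size, and hidden size of 256 are illustrative choices.

```python
import torch

# Split an image into non-overlapping patches and embed them as "words".
image = torch.randn(1, 3, 64, 64)                 # (batch, channels, height, width)
patch_size = 8

# Cut the image into an 8x8 grid of 8x8-pixel patches.
patches = image.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.permute(0, 2, 3, 1, 4, 5).flatten(3)   # (1, 8, 8, 3*8*8)
tokens = patches.flatten(1, 2)                           # (1, 64, 192): 64 patch "words"

# A learnable linear projection maps each flattened patch to the hidden size.
embed = torch.nn.Linear(3 * patch_size * patch_size, 256)
token_embeddings = embed(tokens)                         # (1, 64, 256)
```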

ViTGAN architecture

The Generator has the following architecture:


The Generator differs from a standard Vision Transformer in the following ways:

  • The input to the Transformer consists only of the position embedding;
  • Self-Modulated LayerNorm (SLN) is used in place of LayerNorm. This is the only place where the seed, defined by the latent vector z, influences the network;
  • Instead of a classification head, the embedding of each patch is passed to a SIREN, which generates an image for each patch. The patches are then stitched together to form the final image.

The Discriminator has the following architecture:


The ViTGAN Discriminator is a standard Vision Transformer network, with the following modifications:

  • DiffAugment;
  • Overlapping Image Patches;
  • Use of vectorized L2 distance in self-attention;
  • Improved Spectral Normalization (ISN);
  • Balanced Consistency Regularization (bCR).

Ensuring Lipschitzness in the Discriminator

Lipschitz continuity refers to functions whose gradient is bounded, i.e.:

\bigg | \frac{df(x)}{dx} \bigg | \leq C, \qquad C \geq 0

If the gradient is unbounded, training becomes very unstable: a single training step could push the weights of the model towards infinity. Thus, we would like to guarantee that the gradient stays bounded.
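As a simple example of my own (not from the paper): the sine function is 1-Lipschitz, because its derivative is bounded by 1, whereas the square function is not Lipschitz on the whole real line, because its derivative grows without bound:

\bigg | \frac{d}{dx} \sin(x) \bigg | = |\cos(x)| \leq 1, \qquad \bigg | \frac{d}{dx} x^2 \bigg | = |2x| \to \infty \ \text{as} \ |x| \to \infty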

Using vectorized L2 distance in attention for Discriminator

The standard self-attention mechanism, which uses a dot product, has recently been shown to have an unbounded gradient [paper].

ViTGAN's Discriminator therefore uses a recently proposed self-attention mechanism based on the Euclidean distance d(\cdot,\cdot).

Normal Attention Mechanism

Attention_h(X) = softmax \bigg ( \frac{QK^T}{\sqrt{d_h}} \bigg ) V

Lipschitz Attention Mechanism

Attention_h(X) = softmax \bigg ( \frac{d(Q,K)}{\sqrt{d_h}} \bigg ) V
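As an illustration, here is a minimal PyTorch sketch of attention scored with pairwise Euclidean distance instead of a dot product. The single-head setup and the negative sign (so that closer query/key pairs receive more weight) are my assumptions for the sketch; the paper additionally ties the query and key projection weights, which is omitted here.

```python
import torch
import torch.nn.functional as F

def l2_attention(q, k, v):
    """Self-attention scored with (negative) squared L2 distance instead of
    a dot product. q, k, v: (batch, seq_len, d_h)."""
    d_h = q.size(-1)
    dist = torch.cdist(q, k, p=2) ** 2               # pairwise squared distances, (batch, seq, seq)
    attn = F.softmax(-dist / (d_h ** 0.5), dim=-1)   # closer pairs -> larger weights
    return attn @ v

x = torch.randn(2, 16, 64)
out = l2_attention(x, x, x)                           # (2, 16, 64)
```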

Improved Spectral Normalization

Using Spectral Normalization of the form:

\bar{W}_{SN}(W):=W/\sigma(W)

further strengthens the Lipschitz continuity of the Discriminator. It guarantees a Lipschitz constant equal to 1, i.e. the norm of the gradient never exceeds 1.

However, the authors of ViTGAN noticed that Transformer networks are very sensitive to the Lipschitz constant and that a small constant leads to slow training and can even cause a collapse of information.

This is why they introduce the Improved Spectral Normalization (ISN):

\bar{W}_{ISN}(W):=\sigma(W_{init})\cdot W/\sigma(W)

As a result, each layer keeps the Lipschitz constant it had at initialization, so different layers end up with different Lipschitz constants.
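Here is a minimal PyTorch sketch of what ISN could look like for a single linear layer; it assumes the spectral norm is computed exactly via SVD, whereas in practice it is usually estimated with power iteration (as in torch.nn.utils.spectral_norm).

```python
import torch

class ImprovedSpectralNorm(torch.nn.Module):
    """Sketch of ISN: rescale the spectrally normalized weight by the
    spectral norm the layer had at initialization."""
    def __init__(self, linear: torch.nn.Linear):
        super().__init__()
        self.linear = linear
        # sigma(W_init): spectral norm (largest singular value) at initialization.
        sigma_init = torch.linalg.matrix_norm(linear.weight.detach(), ord=2)
        self.register_buffer("sigma_init", sigma_init)

    def forward(self, x):
        w = self.linear.weight
        sigma = torch.linalg.matrix_norm(w, ord=2)    # sigma(W): current spectral norm
        w_isn = self.sigma_init * w / sigma           # W_ISN = sigma(W_init) * W / sigma(W)
        return torch.nn.functional.linear(x, w_isn, self.linear.bias)

layer = ImprovedSpectralNorm(torch.nn.Linear(64, 64))
y = layer(torch.randn(8, 64))                         # (8, 64)
```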

Overlapping Image Patches

The discriminator is prone to overfitting.

In standard Vision Transformers, an image is segmented according to a grid of a predefined size:

[Figure: an image of a lizard and its patches (8x8)]

However, the choice of grid size is extremely important and can lead to very different results.

The problem is that the discriminator memorizes local cues and does not propagate any meaningful information to the generator, halting the learning process.

ViTGAN's discriminator introduces overlapping patches:

[Figure: an image of a lizard and its overlapping patches (16x16)]

Using overlapping patches instead of standard patches improves the Transformer in the following two ways:

  • Less sensitivity to the patch grid size;
  • Improved sense of locality, thanks to the shared pixels between neighboring patches.
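For illustration, here is a small PyTorch sketch of my own of extracting overlapping patches with torch.nn.Unfold; 16x16 patches with an 8-pixel stride are an illustrative choice.

```python
import torch

image = torch.randn(1, 3, 64, 64)
patch_size, stride = 16, 8                       # patches larger than the stride overlap

unfold = torch.nn.Unfold(kernel_size=patch_size, stride=stride)
patches = unfold(image).transpose(1, 2)          # (1, num_patches, 3*16*16)

# num_patches = ((64 - 16) / 8 + 1) ** 2 = 49 overlapping patches,
# versus only 16 non-overlapping 16x16 patches for the same image.
print(patches.shape)                             # torch.Size([1, 49, 768])
```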

DiffAugment

Traditional image augmentation techniques, such as translation, random cropping, and color shifting, are applied only to the real dataset. They are used to make the discriminator more robust and less sensitive to minor changes in the input, which in turn helps alleviate over-fitting.

Differentiable Augmentation (DiffAugment) [Paper] is a way of augmenting an image so that gradients can be back-propagated through the augmentation. This allows us to apply augmentations not only to the real dataset but also to the generated "fake" images.

Using DiffAugment instead of regular image augmentation has been shown to improve the results by a very large margin.
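As a rough illustration of the idea (not the official DiffAugment code, which also includes translation and cutout), the sketch below builds brightness and saturation shifts entirely from differentiable operations, so gradients flow back to the generator through the augmented fake images.

```python
import torch

def diff_augment(x):
    """A differentiable augmentation sketch: only differentiable ops are used,
    so it can be applied to generated images without blocking gradients."""
    # Random brightness shift (a plain addition is differentiable).
    x = x + (torch.rand(x.size(0), 1, 1, 1, device=x.device) - 0.5)
    # Random saturation scaling around the per-image mean color.
    mean = x.mean(dim=1, keepdim=True)
    x = (x - mean) * (torch.rand(x.size(0), 1, 1, 1, device=x.device) * 2) + mean
    return x

fake = torch.randn(4, 3, 64, 64, requires_grad=True)
diff_augment(fake).sum().backward()               # gradients reach the generator's output
```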

SLN ・ Self-modulated LayerNorm

ViTGAN's Generator uses a self-modulated LayerNorm, which is computed by:

SLN(h_{\ell},w)=\gamma_{\ell}(w)\odot\frac{h_{\ell}-\mu}{\sigma}+\beta_{\ell}(w)

where \gamma_{\ell}(w) and \beta_{\ell}(w) are 1-layer networks, which compute the scale and the shift for the layer, based on the latent value z.
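A minimal PyTorch sketch of how SLN could be implemented, assuming \gamma_{\ell} and \beta_{\ell} are single linear layers acting on the latent code w; the sizes are illustrative and details may differ from the authors' code.

```python
import torch

class SelfModulatedLayerNorm(torch.nn.Module):
    """Sketch of SLN: LayerNorm whose scale and shift are predicted from w."""
    def __init__(self, hidden_size, latent_size):
        super().__init__()
        self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.gamma = torch.nn.Linear(latent_size, hidden_size)   # gamma_l(w)
        self.beta = torch.nn.Linear(latent_size, hidden_size)    # beta_l(w)

    def forward(self, h, w):
        # h: (batch, num_patches, hidden), w: (batch, latent_size)
        scale = self.gamma(w).unsqueeze(1)       # broadcast over the patch dimension
        shift = self.beta(w).unsqueeze(1)
        return scale * self.norm(h) + shift

sln = SelfModulatedLayerNorm(hidden_size=256, latent_size=128)
out = sln(torch.randn(2, 64, 256), torch.randn(2, 128))          # (2, 64, 256)
```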

SIREN

SIRENs are networks that use a sine activation function and are used to learn implicit representations of natural continuous signals. For a detailed explanation, please visit Sinusoidal Representation Networks SIREN - Fusic Tech Blog.

Usually, SIRENs are trained to represent a single image, but the authors of ViTGAN combine the patch embedding with a Fourier position embedding. This allows the SIREN network to represent a wide variety of images.
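For reference, here is a minimal PyTorch sketch of a SIREN layer, roughly following the SIREN paper's initialization recipe (omega_0 = 30, uniform weight initialization scaled by omega_0); this is a simplified illustration, not the authors' implementation.

```python
import torch

class SirenLayer(torch.nn.Module):
    """A linear map followed by a sine activation, as used in SIREN."""
    def __init__(self, in_features, out_features, omega_0=30.0):
        super().__init__()
        self.omega_0 = omega_0
        self.linear = torch.nn.Linear(in_features, out_features)
        bound = (6.0 / in_features) ** 0.5 / omega_0
        torch.nn.init.uniform_(self.linear.weight, -bound, bound)

    def forward(self, x):
        return torch.sin(self.omega_0 * self.linear(x))

# Map 2D pixel coordinates to RGB values with a small SIREN.
siren = torch.nn.Sequential(SirenLayer(2, 256), SirenLayer(256, 256), torch.nn.Linear(256, 3))
coords = torch.rand(64 * 64, 2) * 2 - 1          # (x, y) coordinates in [-1, 1]
rgb = siren(coords)                               # (4096, 3)
```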

E_{fou} ・ Fourier position embedding

The Fourier position embedding is similar to the position embedding of Transformers in that it is based on sine waves. It is calculated by the following equation:

E_{fou}(\mathbf v)= \sin(\mathbf W \mathbf v)

where \mathbf v holds the x and y coordinates of a given pixel, normalized to lie between -1 and 1, and \mathbf W is a trainable weight matrix, mapping the 2-dimensional vector to the desired hidden size.
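A small PyTorch sketch of this embedding, with an illustrative hidden size of 256 and a 16x16 coordinate grid:

```python
import torch

hidden_size = 256
W = torch.nn.Linear(2, hidden_size, bias=False)        # trainable weight matrix W

# (x, y) coordinates of every pixel of a 16x16 patch, normalized to [-1, 1].
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 16), torch.linspace(-1, 1, 16), indexing="ij")
v = torch.stack([xs, ys], dim=-1).reshape(-1, 2)       # (256, 2)

E_fou = torch.sin(W(v))                                # (256, hidden_size)
```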

Weight modulation

ViTGAN's generator utilizes weight modulation, a powerful tool for combining two different inputs. As in StyleGAN2 and CIPS, it is used here to combine the style embedding with the position embedding:

w^{'}_{ijk}=s_i\cdot w_{ijk}

Weight modulation yields much better results than a simple concatenation or summation of the two embeddings, and it even outperforms a skip-connection implementation.
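A minimal sketch of the modulation step in PyTorch: a per-sample style vector s scales the input dimension of a shared weight before the weight is applied; StyleGAN2's demodulation step is omitted here for brevity.

```python
import torch

def modulate(weight, style):
    """w'_ijk = s_i * w_ijk: each input channel i of a shared weight is
    scaled by the corresponding entry s_i of a per-sample style vector.
    weight: (out_features, in_features), style: (batch, in_features)."""
    return weight.unsqueeze(0) * style.unsqueeze(1)    # (batch, out_features, in_features)

weight = torch.randn(256, 128)                          # shared layer weight
style = torch.randn(4, 128)                             # one style vector per sample
w_mod = modulate(weight, style)                         # (4, 256, 128)

# The modulated weight is applied per sample, e.g. with a batched matmul.
x = torch.randn(4, 128)
y = torch.bmm(w_mod, x.unsqueeze(-1)).squeeze(-1)       # (4, 256)
```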

Balanced Consistency Regularization (bCR)

ViTGAN utilizes bCR [Improved Consistency Regularization for GANs].

It trains the discriminator to output the same prediction for a given image and its augmentations:


Whether the image is shifted slightly, its colors are altered slightly, or any other small augmentation is applied, we want the discriminator to produce the same output. Otherwise, even small changes to the input could result in drastically different outputs from the discriminator.

Using bCR stabilizes the training of both the discriminator and the generator.
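A rough sketch of what the bCR loss term could look like in PyTorch; the augment function and the lambda weights are placeholders of mine, not values from the paper.

```python
import torch.nn.functional as F

def bcr_loss(discriminator, real, fake, augment, lambda_real=10.0, lambda_fake=10.0):
    """Penalize the discriminator when its output changes under augmentation,
    for both real and generated images (balanced consistency regularization)."""
    loss_real = F.mse_loss(discriminator(augment(real)), discriminator(real))
    loss_fake = F.mse_loss(discriminator(augment(fake)), discriminator(fake))
    return lambda_real * loss_real + lambda_fake * loss_fake

# Added to the usual discriminator loss during training, e.g.:
# d_loss = gan_loss + bcr_loss(D, real_images, fake_images.detach(), diff_augment)
```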

Summary

In my opinion, ViTGAN is a very well-thought-out model.

The authors have focused a lot on the weaknesses in the training of Transformers. In particular, ensuring Lipschitz continuity has proven to increase the quality of the results significantly.

However, the factor that seems to have contributed the most to improving the results is the use of DiffAugment.

[Figure: comparison of different Discriminators]

Kwonjoon Lee et al. have researched the weaknesses in training a Transformer, in particular a Vision Transformer.

Although Vision Transformers have existed since the end of 2020, Kwonjoon Lee et al. are one of the first teams to apply them to a GAN. Achieving such remarkable results with their first publication on Vision Transformer GANs is incredible. What's more, they have managed to successfully leverage SIREN in a novel way.

Teodor TOSHKOV

I am an intern at Fusic, a company in Fukuoka, Japan. From 2022, I will be joining the Machine Learning team. I develop mostly deep learning models, using PyTorch and TensorFlow.