Implementing ViTGAN (A novel realistic image generation model) in PyTorch
2021/09/28
Dataset
The dataset used in my implementation of ViTGAN is CIFAR-10.
It consists of 60,000 32x32 color images in the following 10 classes:
- airplane
- automobile
- bird
- cat
- deer
- dog
- frog
- horse
- ship
- truck
Results of training ViTGAN on CIFAR-10 are also reported in the original paper.
ViTGAN Implementation Details
Just like any other Generative Adversarial Network (GAN), ViTGAN consists of a generator and a discriminator.
The generator is trained to produce images similar to those in a given dataset, so that the discriminator believes they are part of that dataset. At the same time, the discriminator is trained to distinguish between images from the dataset and images produced by the generator.
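The adversarial setup above can be written as a pair of losses. The sketch below assumes the non-saturating logistic loss, which StyleGAN2-style code commonly uses. The post does not say which GAN loss the author used, and the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def discriminator_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    # Push D towards high logits for real images and low logits for generated ones.
    # softplus(-x) = -log(sigmoid(x)), softplus(x) = -log(1 - sigmoid(x))
    return F.softplus(-real_logits).mean() + F.softplus(fake_logits).mean()


def generator_loss(fake_logits: torch.Tensor) -> torch.Tensor:
    # Non-saturating form: G maximizes log D(G(z)) instead of minimizing log(1 - D(G(z)))
    return F.softplus(-fake_logits).mean()
```

The two losses are minimized in alternating steps. During the discriminator step, the generated images are detached so that only D is updated.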
Since ViTGAN uses Vision Transformers, it treats each image as a sequence of square patches of a predefined size (a fixed number of pixels per side).
Generator
The goal of the generator is to create images similar to the images in a given dataset.
The generator has the following architecture:
[Figure: ViTGAN generator architecture]
Equalized learning rate
The equalized learning rate resembles careful weight initialization in that it depends on the number of input features. However, instead of scaling the weights once at initialization, it multiplies the weights of a linear layer by a weight gain at every forward pass. The gain is inversely proportional to the square root of the number of input features and is scaled by lr_multiplier, i.e. weight_gain = lr_multiplier / sqrt(in_features):
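import numpy as np, torch, torch.nn as nn, torch.nn.functional as F
class FullyConnectedLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, lr_multiplier=1.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) / lr_multiplier)
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None  # StyleGAN2 also scales the bias by lr_multiplier; omitted here
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
    def forward(self, x):
        w = self.weight.to(x.dtype) * self.weight_gain  # weights are rescaled at every forward pass
        return F.linear(x, w, self.bias)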
This technique is used in every fully-connected layer of both the generator and the discriminator. However, lr_multiplier is set to a value other than 1 only in the latent vector mapping network, which is explained in the following section.
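The following minimal sketch shows why the weight gain changes the effective learning rate. It reproduces the StyleGAN2 convention of storing the weights as N(0, 1) / lr_multiplier and then checks two things. First, the weights actually used in the forward pass still start at the usual scale of 1 / sqrt(in_features). Second, the gradient reaching the stored parameter is scaled by weight_gain. The sizes are arbitrary.

```python
import numpy as np
import torch

torch.manual_seed(0)
in_features, out_features, lr_multiplier = 512, 8, 0.01
weight_gain = lr_multiplier / np.sqrt(in_features)

# Stored parameter, initialized as N(0, 1) / lr_multiplier
raw = (torch.randn(out_features, in_features) / lr_multiplier).requires_grad_()
effective = raw * weight_gain      # weight actually used in the forward pass
effective.retain_grad()

x = torch.randn(16, in_features)
loss = (x @ effective.t()).pow(2).mean()
loss.backward()

# Effective weights start with std ~ 1 / sqrt(in_features), independent of lr_multiplier
print(effective.std().item(), 1 / np.sqrt(in_features))
# Chain rule: d loss / d raw = weight_gain * d loss / d effective
print(torch.allclose(raw.grad, effective.grad * weight_gain))
```

With an optimizer such as Adam, the update to `raw` is roughly independent of the gradient scale. The effective weight therefore moves about weight_gain times the raw step, so lr_multiplier=0.01 makes the layer learn about 100 times more slowly than a layer with lr_multiplier=1.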
Mapping of the Latent vector
The implementation of the mapping network is taken from StyleGAN2.
The latent vector z is sampled at random from a standard Gaussian distribution:
np.random.normal(0, 1, (batch_size, latent_dim))
It is then mapped to a new latent space by a multi-layer perceptron (MLP) that uses an equalized learning rate with lr_multiplier=0.01. This makes the mapping network learn more slowly than the rest of the network. This value is the standard setting for the mapping network in StyleGAN2; apparently, the mapping network needs a much lower learning rate for training to stay balanced.
The mapping network is implemented as follows:
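import torch
import torch.nn as nn
import torch.nn.functional as F

class MappingNetwork(nn.Module):  # class name assumed
    def __init__(self, layer_features, latent_dim, w_dim, num_layers, lr_multiplier=0.01):
        super().__init__()
        self.num_layers = num_layers
        features_list = [latent_dim] + [layer_features] * (num_layers - 1) + [w_dim]
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

    def forward(self, z):
        # Elided in the original; assumed here to follow StyleGAN2: normalize z to
        # unit second moment, then apply a leaky ReLU after every layer
        x = z * torch.rsqrt(z.pow(2).mean(dim=1, keepdim=True) + 1e-8)
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = F.leaky_relu(layer(x), negative_slope=0.2)
        return x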
Sampling from a Gaussian distribution is not guaranteed to suit the task of generating realistic images with ViTGAN's architecture.
The MLP maps the latent vector to a space better suited to the task. Because its parameters are trainable, it can learn this mapping during training.
Generator - Transformer Encoder
The only input to the generator's Transformer encoder is a trainable position embedding. It contains no outside information about any concrete image.
Since the position embedding is the same for all images, we repeat it along the batch dimension as follows:
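import torch
import torch.nn as nn
from einops import repeat

class Generator(nn.Module):  # other generator layers and arguments elided
    def __init__(self, hidden_size, num_patches):
        super().__init__()
        self.pos_emb = nn.Parameter(torch.randn(num_patches, hidden_size))
    def forward(self, z):
        # sine activation, then one copy per sample: [batch, num_patches, hidden_size]
        pos = repeat(torch.sin(self.pos_emb), 'n e -> b n e', b=z.shape[0])
        return pos  # fed to the Transformer encoder blocks (elided)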
Each image is split into num_patches patches of the same size, and we create a separate position embedding for each of them.
Since the standard Vision Transformer flattens each patch into a 1-dimensional vector, the position embeddings are also 1-dimensional (of length hidden_size) to match.
As stated in the paper, the position embeddings are passed through a sine activation function.
Self-Modulated LayerNorm (SLN)
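The Self-Modulated LayerNorm is the only place where the input latent vector influences the network.
Standard layer normalization computes:
$$\mathrm{LN}(h) = \gamma \odot \frac{h - \mu}{\sigma} + \beta$$
where $\mu$ and $\sigma$ are the mean and standard deviation of $h$ over its features, and $\gamma$ and $\beta$ are learned parameters.
SLN instead computes $\gamma$ and $\beta$ from the latent vector $w$:
$$\mathrm{SLN}(h, w) = \gamma(w) \odot \frac{h - \mu}{\sigma} + \beta(w)$$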
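import torch
import torch.nn as nn

class SLN(nn.Module):
    def __init__(self, input_size, parameter_size=None):
        super().__init__()
        parameter_size = parameter_size or input_size  # default: same size as the hidden features
        self.ln = nn.LayerNorm(input_size)
        # gamma(w) and beta(w) are learned linear functions of the latent vector w
        # (w is assumed to have input_size features)
        self.gamma = nn.Linear(input_size, parameter_size)
        self.beta = nn.Linear(input_size, parameter_size)
    def forward(self, hidden, w):
        # hidden: [batch, num_patches, input_size], w: [batch, input_size]
        gamma = self.gamma(w).unsqueeze(1)  # [batch, 1, parameter_size], broadcast over patches
        beta = self.beta(w).unsqueeze(1)
        ln = self.ln(hidden)
        return gamma * ln + beta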
parameter_size is equal to input_size, i.e. the hidden size of the Transformer (hidden_size).
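The same computation can be written as a plain function, which makes the SLN formula easier to check. This is a minimal sketch with hypothetical names; `gamma_*` and `beta_*` stand for the weights of the two linear maps from w. It uses `F.layer_norm` without an affine transform, so the only scale and shift are γ(w) and β(w).

```python
import torch
import torch.nn.functional as F


def self_modulated_layer_norm(hidden, w, gamma_weight, gamma_bias, beta_weight, beta_bias):
    # hidden: [batch, num_patches, D]; w: [batch, D_w]; gamma/beta maps: D_w -> D
    normalized = F.layer_norm(hidden, hidden.shape[-1:])         # (h - mu) / sigma, no affine
    gamma = F.linear(w, gamma_weight, gamma_bias).unsqueeze(1)   # gamma(w): [batch, 1, D]
    beta = F.linear(w, beta_weight, beta_bias).unsqueeze(1)      # beta(w):  [batch, 1, D]
    return gamma * normalized + beta                             # broadcast over all patches
```

If both weight matrices are zero and the biases are 1 and 0, then γ(w) = 1 and β(w) = 0 for every w, and the function reduces to a plain LayerNorm. With trained maps, two different latent vectors modulate the same hidden states differently.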
Transformer Encoder Block
The implementation of the Vision Transformer blocks, including Multi-Head Self-Attention (MSA), is heavily based on [Blog Post].
After creating the position embeddings, we define the Transformer encoder blocks:
class GeneratorTransformerEncoderBlock(nn.Module):
    def __init__(self,
                 hidden_size=384,
                 sln_parameter_size=384,
                 forward_expansion=4,
                 # ... (further arguments elided in the original)
                 **kwargs):
        super().__init__()
        self.sln = SLN(hidden_size, parameter_size=sln_parameter_size)
        self.msa = MultiHeadAttention(hidden_size, **kwargs)
        self.feed_forward = FeedForwardBlock(hidden_size, expansion=forward_expansion)  # further arguments elided
    def forward(self, hidden, w):
        res = hidden
        hidden = self.sln(hidden, w)
        hidden = self.msa(hidden)
        ...
        hidden += res  # residual connection around the attention
        res = hidden
        hidden = self.sln(hidden, w)
        hidden = self.feed_forward(hidden)
        ...
        hidden += res  # residual connection around the feed-forward network
        return hidden
The Transformer block of the ViTGAN generator differs from a standard Transformer block only in that it uses an SLN instead of a regular LayerNorm. Before it is modified, the input to the block is stored in res so that a residual connection can be made later. The input is passed through the SLN and the Multi-Head Self-Attention. The residual connection then adds the stored values to the result (hidden += res).
The result is then stored in res again and passed through the same SLN and a 2-layer feed-forward network, implemented as follows:
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion=4, drop_p=0.0):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),  # assumed: the elided part is dropout, as in the ViT blog post this is based on
            nn.Linear(expansion * emb_size, emb_size),
        )
The FeedForwardBlock is implemented as a sequential module. It takes an input of size $D$ (emb_size) and projects it to size $kD$, where $k$ = expansion. It then passes the result through a GELU activation and projects it back to size $D$.
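The following self-contained sketch shows how SLN, self-attention and the feed-forward network fit together in one generator block. It makes several assumptions. It swaps in PyTorch's `nn.MultiheadAttention` for the blog post's `MultiHeadAttention` and omits dropout. Its LayerNorm has no affine parameters of its own, because γ and β come from w. The class names `SLNSketch` and `GeneratorBlockSketch` are hypothetical.

```python
import torch
import torch.nn as nn


class SLNSketch(nn.Module):
    def __init__(self, hidden_size, w_dim):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.gamma = nn.Linear(w_dim, hidden_size)
        self.beta = nn.Linear(w_dim, hidden_size)

    def forward(self, hidden, w):
        # gamma(w) * LN(hidden) + beta(w), broadcast over the patch dimension
        return self.gamma(w).unsqueeze(1) * self.ln(hidden) + self.beta(w).unsqueeze(1)


class GeneratorBlockSketch(nn.Module):
    def __init__(self, hidden_size=384, w_dim=384, num_heads=6, forward_expansion=4):
        super().__init__()
        self.sln = SLNSketch(hidden_size, w_dim)  # one SLN reused twice, as in the post
        self.msa = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, forward_expansion * hidden_size),
            nn.GELU(),
            nn.Linear(forward_expansion * hidden_size, hidden_size),
        )

    def forward(self, hidden, w):
        # hidden: [batch, num_patches, hidden_size]; w: [batch, w_dim]
        x = self.sln(hidden, w)
        hidden = hidden + self.msa(x, x, x, need_weights=False)[0]  # residual around attention
        hidden = hidden + self.feed_forward(self.sln(hidden, w))    # residual around the MLP
        return hidden
```

The block keeps the shape of its input. Because SLN is the only place where w enters, changing w changes the output even when the position embeddings are identical.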
SIREN
A description of SIREN: [Blog Post] [Paper] [Colab Notebook]
The code for implementing SIREN is taken from the official implementation: [Colab Notebook].
Unlike a regular SIREN, which represents a single image, the SIREN in ViTGAN has to produce a wide range of different images. For this purpose, the patch embedding is combined with a position embedding.
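As background for the modulated layers described below, here is a minimal sketch of a plain (unmodulated) SIREN layer. It is modelled on the `SineLayer` in the official SIREN notebook: a linear layer followed by sin(ω₀ · x) with ω₀ = 30. The first layer is initialized uniformly in ±1/in_features and later layers in ±sqrt(6/in_features)/ω₀. Treat it as a sketch rather than ViTGAN's exact layer.

```python
import numpy as np
import torch
import torch.nn as nn


class SineLayer(nn.Module):
    """sin(omega_0 * (W x + b)) with the SIREN initialization scheme."""

    def __init__(self, in_features, out_features, is_first=False, omega_0=30.0):
        super().__init__()
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features)
        with torch.no_grad():
            if is_first:
                bound = 1 / in_features
            else:
                # keeps pre-activations well distributed after multiplying by omega_0
                bound = np.sqrt(6 / in_features) / omega_0
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x):
        return torch.sin(self.omega_0 * self.linear(x))
```

Stacking such layers on 2-D coordinates gives a network whose output is smooth yet able to represent fine detail. ViTGAN replaces the plain linear layer with the weight-modulated one described later in this section.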
Fourier Position Encoding
The SIREN network in ViTGAN uses a positional encoding, similar to the position encoding in a Transformer. For a detailed explanation, please visit Merging Vision Transformers (ViT) with SIRENs to form a ViTGAN. A novel approach to generate realistic images. - Fusic Tech Blog.
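The position encoding used in ViTGAN is the Fourier position encoding, whose code was taken from CIPS.
The Fourier position encoding of a single pixel position $(x, y)$ in an $H \times W$ patch is computed as:
$$e_{fo}(x, y) = \sin\left(B_{fo}\left(\frac{2x}{W-1} - 1,\ \frac{2y}{H-1} - 1\right)^{\top}\right)$$
where $B_{fo}$ is a learnable matrix (the ffm linear layer, which also adds a bias).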
The implementation in PyTorch is as follows:
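import numpy as np
import torch
import torch.nn as nn

class LFF(nn.Module):
    def __init__(self, hidden_size, bias=True):
        super().__init__()
        # Learnable B_fo: maps a 2-D coordinate (x, y) to hidden_size features
        self.ffm = nn.Linear(2, hidden_size, bias=bias)
        # CIPS first-layer init U(-sqrt(9 / ch_in), sqrt(9 / ch_in)), with ch_in = 2 coordinates
        nn.init.uniform_(self.ffm.weight, -np.sqrt(9 / 2), np.sqrt(9 / 2))
    def forward(self, x):
        # x: [..., 2] coordinates in [-1, 1] -> [..., hidden_size]
        x = self.ffm(x)
        x = torch.sin(x)
        return x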
Next, we use this module to generate a position embedding for every pixel position within an output patch:
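def fourier_pos_embedding(self):
    # Normalized pixel coordinates in [-1, 1] along each side of an output patch
    coords = np.linspace(-1, 1, self.out_patch_size, endpoint=True)
    pos = np.stack(np.meshgrid(coords, coords), -1)  # [P, P, 2]
    pos = torch.tensor(pos, dtype=torch.float)
    # self.lff is an LFF(hidden_size=self.siren_in_features)
    result = self.lff(pos).reshape([self.out_patch_size**2, self.siren_in_features])
    return result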
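The coordinate grid determines which embedding row belongs to which pixel. The standalone sketch below checks that ordering. The sizes are illustrative and are not values from the post. With `np.meshgrid`'s default `'xy'` indexing, `pos[row, col]` holds `(x_col, y_row)`, so after the row-major reshape the pixel at `(row, col)` ends up at index `row * P + col`.

```python
import numpy as np
import torch
import torch.nn as nn

out_patch_size, hidden_size = 4, 16   # illustrative sizes (assumed)
coords = np.linspace(-1, 1, out_patch_size, endpoint=True)
pos = torch.tensor(np.stack(np.meshgrid(coords, coords), -1), dtype=torch.float)  # [P, P, 2]

ffm = nn.Linear(2, hidden_size)       # the learnable B_fo (plus a bias)
nn.init.uniform_(ffm.weight, -np.sqrt(9 / 2), np.sqrt(9 / 2))

# One hidden_size-dimensional embedding per pixel, enumerated row by row
embedding = torch.sin(ffm(pos)).reshape(out_patch_size ** 2, hidden_size)
```

This row-major pixel order is the same one assumed later, when the patch outputs are reassembled into an image.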
Finally, the position and patch embeddings have to be combined. This is implemented with weight modulation.
Weight Modulation
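Weight modulation is a technique in which the weights are scaled by an input embedding (the style). The multiplication is performed element-wise over the input dimension:
$$w'_{ij} = s_i \cdot w_{ij}$$
where $s_i$ is the $i$-th component of the style vector, $i$ indexes the input features and $j$ the output features.
There is an optional normalization of the modulated weights, called "demodulation", defined as follows:
$$w''_{ij} = \frac{w'_{ij}}{\sqrt{\sum_i {w'_{ij}}^2 + \epsilon}}$$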
After testing the network, I concluded that demodulation is not used in ViTGAN.
My implementation of the weight modulation is heavily based on CIPS. I adjusted it to work for a fully-connected network by using a 1D convolution. The reason for using a 1D convolution instead of a linear layer is its groups argument. With groups=batch_size, every sample is multiplied by its own modulated weights in a single call, instead of looping over the batch_size samples.
Each SIREN layer applies a sine activation to the output of a weight modulation layer. The sizes of the input, hidden and output layers of a SIREN network can differ. Therefore, if a layer's input size differs from the size of the patch embedding, I define an additional fully-connected layer that converts the patch embedding to the appropriate size.
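import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModulatedLinear(nn.Module):  # hypothetical class name; the original shows only the method bodies
    def __init__(self, in_channels, out_channels, style_size, demodulation=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.style_size = style_size
        self.scale = 1 / np.sqrt(in_channels)
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, 1)
        )
        if self.style_size != self.in_channels:
            # projects the patch embedding (style) to in_channels features
            self.style_modulation = FullyConnectedLayer(style_size, in_channels, bias=False)
        self.demodulation = demodulation

    def forward(self, input, style):
        # input: [batch_size, img_size, in_channels], where img_size = pixels per patch
        # style: [batch_size, style_size]
        batch_size, img_size, _ = input.shape
        if self.style_size != self.in_channels:
            style = self.style_modulation(style)
        style = style.view(batch_size, 1, self.in_channels, 1)
        # modulation: scale every input column of the weight by the per-sample style
        weight = self.scale * self.weight * style
        if self.demodulation:
            demod = torch.rsqrt(weight.pow(2).sum([2]) + 1e-8)
            weight = weight * demod.view(batch_size, self.out_channels, 1, 1)
        weight = weight.view(
            batch_size * self.out_channels, self.in_channels, 1
        )
        # conv1d expects channels first; each sample forms one group
        input = input.transpose(1, 2).reshape(1, batch_size * self.in_channels, img_size)
        out = F.conv1d(input, weight, groups=batch_size)
        return out.view(batch_size, self.out_channels, img_size).transpose(1, 2)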
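To see why the grouped convolution is equivalent to applying a separate modulated weight matrix to each sample, the sketch below compares it with an explicit per-sample loop. The sizes are arbitrary. Only standard PyTorch calls are used; the demodulation and the 1/sqrt(in_channels) scale are left out because they do not affect the equivalence.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, pixels, c_in, c_out = 3, 16, 8, 5
x = torch.randn(batch, pixels, c_in)     # per-sample inputs, channels last
style = torch.randn(batch, c_in)         # per-sample modulation vector
weight = torch.randn(c_out, c_in)

# Per-sample modulated weights w'_{ij} = s_i * w_{ij}: [batch, c_out, c_in]
modulated = weight.unsqueeze(0) * style.unsqueeze(1)

# Reference: an explicit loop, one matrix multiplication per sample
loop_out = torch.stack([x[b] @ modulated[b].t() for b in range(batch)])

# Grouped 1D convolution: all samples in one call, group b uses modulated[b]
conv_out = F.conv1d(
    x.transpose(1, 2).reshape(1, batch * c_in, pixels),  # [1, batch * c_in, pixels]
    modulated.reshape(batch * c_out, c_in, 1),           # [batch * c_out, c_in, 1]
    groups=batch,
).view(batch, c_out, pixels).transpose(1, 2)             # back to [batch, pixels, c_out]
```

`torch.einsum('bpi,boi->bpo', x, modulated)` would compute the same result. The convolution form is the one CIPS uses for its modulated layers.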
Convert to an image
The results have the shape [batch_size * num_patches, patch_size^2, out_features], and we have to convert them to [batch_size, image_size^2, out_features], where patch_size * sqrt(num_patches) = image_size.
This is done by first converting them to a tensor of shape [batch_size, sqrt(num_patches), sqrt(num_patches), patch_size, patch_size, out_features].
Next, we reorder the dimensions so that all pixels of an image row are consecutive, i.e. [batch_size, sqrt(num_patches), patch_size, sqrt(num_patches), patch_size, out_features].
And finally, we convert to the desired shape.
img = img.view([-1, num_patches_x, num_patches_x, patch_size, patch_size, out_features])
img = img.permute([0, 1, 3, 2, 4, 5])
img = img.reshape([-1, image_size**2, out_features])
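One way to check this reordering is a round trip. The sketch below cuts a known image into patches laid out the way the generator outputs them: patches row by row, pixels row-major inside each patch. It then applies the three lines above and compares the result with the original image. The sizes are illustrative.

```python
import torch

batch_size, image_size, patch_size, out_features = 2, 8, 4, 3
num_patches_x = image_size // patch_size   # sqrt(num_patches)
image = torch.randn(batch_size, image_size, image_size, out_features)

# Cut the image into patches: [batch_size * num_patches, patch_size^2, out_features]
patches = (image.view(batch_size, num_patches_x, patch_size, num_patches_x, patch_size, out_features)
           .permute(0, 1, 3, 2, 4, 5)
           .reshape(-1, patch_size ** 2, out_features))

# The conversion described above
img = patches.view([-1, num_patches_x, num_patches_x, patch_size, patch_size, out_features])
img = img.permute([0, 1, 3, 2, 4, 5])
img = img.reshape([-1, image_size ** 2, out_features])
```

`reshape` is needed in the last step instead of `view`, because the permuted tensor is no longer contiguous in memory.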
Discriminator
The discriminator has the following architecture:
[Figure: ViTGAN discriminator architecture]
The ViTGAN Discriminator is mostly a standard Vision Transformer network, with the following modifications:
- DiffAugment;
- Overlapping Image Patches;
- Use vectorized L2 distance in attention;
- Improved Spectral Normalization (ISN);
- Balanced Consistency Regularization (bCR).
DiffAugment
To implement DiffAugment, I used the code from: [GitHub] [Paper]
I implemented it as part of the discriminator module. It randomly alters the colors of the image, translates it and cuts out a small portion of it. Differentiable augmentation is applied to both the "real" and the "fake" images.
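The key property is that every augmentation is differentiable with respect to the image, so gradients still reach the generator. The following simplified sketch illustrates this with a color shift, a translation and a cutout. It is not the official DiffAugment code. The magnitudes are illustrative assumptions, the translation is a wrap-around `torch.roll` shared by the whole batch, and `diff_augment_sketch` is a hypothetical name.

```python
import torch


def diff_augment_sketch(x: torch.Tensor, cutout_ratio: float = 0.5) -> torch.Tensor:
    """Simplified differentiable color / translation / cutout for [B, C, H, W] images."""
    b, _, h, w = x.shape
    # Color: random brightness shift per image (gradient w.r.t. x is 1)
    x = x + (torch.rand(b, 1, 1, 1, device=x.device) - 0.5)
    # Translation: shift by up to 1/8 of the size (wrap-around simplification)
    dy, dx = torch.randint(-(h // 8), h // 8 + 1, (2,)).tolist()
    x = torch.roll(x, shifts=(dy, dx), dims=(2, 3))
    # Cutout: zero out a random square whose side is cutout_ratio of the image size
    ch, cw = int(h * cutout_ratio), int(w * cutout_ratio)
    top = torch.randint(0, h - ch + 1, (1,)).item()
    left = torch.randint(0, w - cw + 1, (1,)).item()
    mask = torch.ones(1, 1, h, w, device=x.device)
    mask[..., top:top + ch, left:left + cw] = 0
    return x * mask
```

Since the output is a differentiable function of the input, the cut-out pixels simply receive zero gradient. The real DiffAugment samples the translation and cutout position independently for every image.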
Overlapping Image Patches
[Figure: an image of a lizard (left) and its overlapping 16x16 patches (right)]
The overlapping image patches are created with a convolution layer:
nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=stride_size)
As stated in the ViTGAN paper, the overlap is equal to patch_size/2 on each edge, resulting in patches of size $2P \times 2P$, where $P$ is patch_size.
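The sketch below compares overlapping and non-overlapping patch extraction. It assumes that each patch is extended by o = P/2 pixels on every edge while the stride stays P. The padding of o, which keeps the number of tokens unchanged, is also an assumption of this example, and the sizes are illustrative.

```python
import torch
import torch.nn as nn

patch_size, overlap = 4, 2          # assumed: overlap = patch_size / 2 on each edge
emb_size, image_size = 64, 32

# Each token sees a (P + 2o) x (P + 2o) window, but windows are still P pixels apart
overlapping = nn.Conv2d(3, emb_size, kernel_size=patch_size + 2 * overlap,
                        stride=patch_size, padding=overlap)
plain = nn.Conv2d(3, emb_size, kernel_size=patch_size, stride=patch_size)

x = torch.randn(1, 3, image_size, image_size)
print(overlapping(x).shape, plain(x).shape)  # both torch.Size([1, 64, 8, 8])
```

Both layers produce the same 8 x 8 grid of tokens. With overlapping patches, however, each token is computed from an 8x8 window instead of a 4x4 one, so neighboring tokens share pixels.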
Use vectorized L2 distance in attention for Discriminator
[Paper]
When calculating self-attention in the discriminator, the dot product between the query and key vectors is replaced by their L2 distance:
torch.cdist(queries, keys, p=2)
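A complete attention layer needs the distance to be negated before the softmax, so that nearby keys receive large weights. The single-head sketch below follows the L2 self-attention of Kim et al. (The Lipschitz Constant of Self-Attention). It uses negative squared distances scaled by 1/sqrt(dim) and a projection shared between queries and keys, which that paper requires for Lipschitz continuity. Whether the author's code ties the two projections is not stated in the post, and `L2Attention` is a hypothetical name.

```python
import torch
import torch.nn as nn


class L2Attention(nn.Module):
    """Single-head self-attention scored by negative squared L2 distance (tied W_q = W_k)."""

    def __init__(self, dim):
        super().__init__()
        self.to_qk = nn.Linear(dim, dim, bias=False)  # shared query/key projection
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.scale = dim ** -0.5

    def forward(self, x):                              # x: [batch, tokens, dim]
        qk = self.to_qk(x)
        diff = qk.unsqueeze(2) - qk.unsqueeze(1)       # [batch, tokens, tokens, dim]
        dist_sq = diff.pow(2).sum(-1)                  # squared L2 distance, smooth everywhere
        attn = torch.softmax(-dist_sq * self.scale, dim=-1)  # nearer keys get larger weights
        return attn @ self.to_v(x), attn
```

The squared distances match `torch.cdist(qk, qk, p=2) ** 2`. They are computed directly here, which keeps the expression differentiable even where two tokens coincide.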
Improved Spectral Normalization (ISN)
The ISN implementation is based on the following implementation of Spectral Normalization:
[GitHub]
[Paper]
Improved spectral normalization is used in all layers of the discriminator.
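Spectral Normalization (SN), where $\sigma(W)$ is the spectral norm (largest singular value) of $W$:
$$\bar{W}_{SN} = \frac{W}{\sigma(W)}$$
Improved Spectral Normalization (ISN), which rescales the normalized weights by the spectral norm at initialization:
$$\bar{W}_{ISN} = \sigma(W_{init}) \cdot \frac{W}{\sigma(W)}$$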
We store the initial spectral norm of the weights, $\sigma(W_{init})$, in self.w_init_sigma:
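class spectral_norm(nn.Module):
    def __init__(self, module, name='weight'):  # further arguments elided
        ...
        self.w_init_sigma = None
    def _update_u_v(self):
        ...  # power iteration that estimates sigma, the spectral norm of w (elided)
        # `is None`: once set, w_init_sigma is an array, and `if array == None` would raise a ValueError
        if self.w_init_sigma is None:
            self.w_init_sigma = np.array(sigma.expand_as(w).detach().cpu())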
It is necessary to detach the initial σ from the computational graph. Otherwise, PyTorch will attempt to back-propagate through it multiple times, resulting in an error.
Each time the weights are updated, we set them to:
setattr(self.module, self.name, torch.tensor(self.w_init_sigma).to(device) * w / sigma.expand_as(w))
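The effect of ISN can be checked numerically. In the minimal sketch below, σ is computed exactly with `torch.linalg.matrix_norm(..., ord=2)` instead of the power iteration used by spectral-normalization implementations, and `improved_spectral_norm` is a hypothetical helper. The "trained" weights are a made-up perturbation of the initial ones.

```python
import torch


def improved_spectral_norm(weight: torch.Tensor, init_sigma: torch.Tensor) -> torch.Tensor:
    # sigma(W_init) * W / sigma(W); sigma = largest singular value, computed exactly here
    sigma = torch.linalg.matrix_norm(weight, ord=2)
    return init_sigma * weight / sigma


torch.manual_seed(0)
w_init = torch.randn(16, 32)
init_sigma = torch.linalg.matrix_norm(w_init, ord=2).detach()   # stored once, no gradient

w_later = 3.0 * w_init + 0.1 * torch.randn(16, 32)               # weights after some "training"
w_sn = w_later / torch.linalg.matrix_norm(w_later, ord=2)        # plain SN: spectral norm 1
w_isn = improved_spectral_norm(w_later, init_sigma)              # ISN: spectral norm sigma(W_init)
```

At initialization, ISN returns the original weights unchanged. Plain SN always forces a spectral norm of 1, whereas ISN keeps the spectral norm at its initial value.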
Balanced Consistency Regularization (bCR)
Zhengli Zhao, Sameer Singh, Honglak Lee, Zizhao Zhang, Augustus Odena, Han Zhang; Improved Consistency Regularization for GANs; AAAI 2021 [Paper]
In addition to the regular GAN losses for the discriminator, consistency regularization losses are used.
They ensure that the output of the discriminator does not change when some small augmentation is applied to the input image.
For the augmentations, I used DiffAugment.
A loss is calculated for both "fake" and "real" images, in the form:
lossD_bCR_fake = F.mse_loss(
discriminator(f_img, do_augment=True),
discriminator(f_img, do_augment=False))
lossD_bCR_real = F.mse_loss(
discriminator(r_img, do_augment=True),
discriminator(r_img, do_augment=False))
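The two terms are added to the discriminator loss with weights. The sketch below uses a hypothetical interface: a separate `augment` function instead of the post's `do_augment` flag. The weights λ_real = λ_fake = 10 are assumed defaults, not values confirmed by the post.

```python
import torch
import torch.nn.functional as F


def bcr_loss(discriminator, augment, real, fake, lambda_real=10.0, lambda_fake=10.0):
    """Balanced consistency regularization: penalize D for changing its output under augmentation."""
    loss_real = F.mse_loss(discriminator(augment(real)), discriminator(real))
    loss_fake = F.mse_loss(discriminator(augment(fake)), discriminator(fake))
    return lambda_real * loss_real + lambda_fake * loss_fake
```

During the discriminator update, `fake` should be generated without tracking gradients (or detached), so that the regularization only trains D. With an identity augmentation the loss is exactly zero. Any augmentation that changes D's output is penalized.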
My thoughts on implementing ViTGAN
ViTGAN is a novel approach that combines the Vision Transformer with SIREN. It performs strongly on popular benchmarks and produces fantastic results, which in my opinion is a marvelous achievement.
Implementing ViTGAN took me around one month, mostly because of bugs in my code that made training inefficient.
Recently, I spoke to Kwonjoon Lee, the main author of ViTGAN. Their code seems to be based on StyleGAN2, which has a fantastic implementation. This is why they leverage "equalized learning rate", which allows them to use a learning rate of .
References
SIREN: Implicit Neural Representations with Periodic Activation Functions
Vision Transformer: [Blog Post]
L2 distance attention: The Lipschitz Constant of Self-Attention
Spectral Normalization reference code: [GitHub] [Paper]
DiffAugment: [GitHub] [Paper]
Fourier Position Embedding: [Jupyter Notebook]
Exponential Moving Average: [GitHub]
Balanced Consistency Regularization (bCR): [Paper]
StyleGAN2 Discriminator: [GitHub]
Teodor TOSHKOV
I am an intern at Fusic, a company in Fukuoka, Japan. From 2022, I will be joining the Machine Learning team. I develop mostly deep learning models, using PyTorch and Tensorflow.