Table of Contents
データセット
ViTGANの私の実装で使用されているデータセットはCIFAR-10です。
このデータセットは,以下の10クラスの32x32画像60,000枚から構成されています.
- 飛行機
- 自動車
- 鳥
- 猫
- 鹿
- 犬
- カエル
- 馬
- 船
- トラック
CIFAR-10でViTGANを学習した結果は、論文で報告されています。
ViTGAN実装内容
ViTGANは、他のGAN(Generative Adversarial Network)と同様に、ジェネレーターとディスクリミネータで構成されています。
ジェネレーターは、ディスクリミネータがデータセットの一部であると考えるような、与えられたデータセットに類似した画像を生成するように訓練されています。そして同時に、ディスクリミネータはデータセットからの画像とジェネレーターによって生成された画像を区別するように訓練されます。
ViTGANはVision Transformerを使用しているため、画像をあらかじめ定義されたサイズ(この例ではピクセル)のパッチのシーケンスとして扱います。
ジェネレーター
ジェネレーターの目的は、与えられたデータセットからの画像に類似した画像を作成することです。
ジェネレーターは以下のようなアーキテクチャになっています。
均等化された学習率(Equalized learning rate)
重みを慎重に初期化することと同様に、均等化された学習率も入力特徴数に応じて計算されます。線形層の重みに、入力特徴数の平方根に逆比例する重みゲインを乗じ、lr_multiplier
を乗じたものです。
class FullyConnectedLayer(nn.Module):
def __init__(self, lr_multiplier, in_features):
self.weight_gain = lr_multiplier / np.sqrt(in_features)
def forward(self, x):
w = self.weight.to(x.dtype) * self.weight_gain
均等化された学習率はViTGANのジェネレーターとディスクリミネータの全ての層に適応されています。ただし、パラメーターlr_multiplier
が1
以外の数字に設定されている所は、潜在空間ベクトルの射影ネットワークだけです。潜在空間ベクトルの射影ネットが以下のセクションで紹介しています。
潜在空間ベクトルの射影
射影ネットワークの実装はStyleGAN2から取っています。
潜在空間ベクトルは、ガウス分布にしたがってランダムに生成されます。
np.random.normal(0, 1, (batch_size, latent_dim))
をMLPによって新しい潜在空間へ射影します。lr_multiplier=0.01
という均等化された学習率を使用します。射影ネットワークは他のネットワークに比べてゆっくりとした速度で学習することができます。StyleGAN2でも使用されている学習レートです。学習のいいバランスを取れますように、射影ネットワークの学習レートが非常に小さくしなければならないそうです。
以下のように実装されています:
def __init__(self, layer_features, latent_dim, w_dim, num_layers, ...)
features_list = [latent_dim] + [layer_features] * (num_layers - 1) + [w_dim]
for idx in range(num_layers):
in_features = features_list[idx]
out_features = features_list[idx + 1]
layer = FullyConnectedLayer(in_features, out_features, lr_multiplier=lr_multiplier)
setattr(self, f'fc{idx}', layer)
...
def forward(self, z):
...
for idx in range(self.num_layers):
layer = getattr(self, f'fc{idx}')
x = layer(x)
...
ViTGANの構造を持つ自然な画像を生成する課題に、ガウス分布からのサンプリングが適切であることは保証できません。
MLPは、潜在空間ベクトルを、課題に適した空間へ射影するのに役立ちます。MLPのパラメータは学習可能で、より良い潜在空間への射影を学習することができます。
ジェネレーター - Transformerのエンコーダー
ジェネレーターへの入力は、具体的な画像に関する外部情報を一切含まない、学習可能な位置エンコーディングのみである。
位置エンコーディングは全ての画像に対して同じであるため、以下のようにバッチサイズごとに繰り返します。
def __init__(self, hidden_size, ...):
self.pos_emb = nn.Parameter(torch.randn(num_patches, hidden_size))
...
def forward(self, z):
pos = repeat(torch.sin(self.pos_emb), 'n e -> b n e', b=z.shape[0])
...
パッチの大きさは、となっています。画像のパッチごとに異なる位置の埋め込みを行います。
標準的なVision Transformerはパッチを1次元の表現にフラット化するため、この次元に合わせて位置エンベッディングを作成します。
論文で述べられているように、位置エンベッディングは活性化関数を遠いしています。
Self-Modulated LayerNorm (SLN)
Self-Modulated LayerNormは潜在空間ベクトルがネットワークに影響する1つだけの所です。
標準レイヤーの正規化は、以下の関数に従っています。
SLNは以下の関数に従っています。
class SLN(nn.Module):
def __init__(self, input_size, parameter_size=None):
super().__init__()
self.ln = nn.LayerNorm(input_size)
self.gamma = nn.Linear(input_size, parameter_size)
self.beta = nn.Linear(input_size, parameter_size)
def forward(self, hidden, w):
gamma = self.gamma(w).unsqueeze(1)
beta = self.beta(w).unsqueeze(1)
ln = self.ln(hidden)
return gamma * ln + beta
parameter_size
はinput_size
()です。
Transformerエンコーダーブロック
MSA(Multi-Head Self-Attention)をはじめとするVision Transformerブロックの実装は、[ブログの記事]に大きく依存しています。
位置エンベッディングを作成した後、Transformerのエンコーダブロックを実装します。
class GeneratorTransformerEncoderBlock(nn.Module):
def __init__(self,
hidden_size=384,
sln_paremeter_size=384,
forward_expansion=4,
...
**kwargs):
super().__init__()
self.sln = SLN(hidden_size, parameter_size=sln_paremeter_size)
self.msa = MultiHeadAttention(hidden_size, **kwargs)
self.feed_forward = FeedForwardBlock(hidden_size, expansion=forward_expansion, ...)
def forward(self, hidden, w):
res = hidden
hidden = self.sln(hidden, w)
hidden = self.msa(hidden)
...
hidden += res
res = hidden
hidden = self.sln(hidden, w)
self.feed_forward(hidden)
...
hidden += res
return hidden
ViTGANジェネレーターのTransformerブロックは、通常のLayerNormの代わりにSLNを使用している点だけが、標準的なTransformerブロックと異なります。ブロックへの入力は、後で残余の接続を行うために、修正する前に保存されます(res
)。入力はSLNとMulti-Head Self-Attentionを通過した後、格納されている値を結果に追加することで残留接続を行います(hidden += res
)。
その後、結果は再び res
に格納され、その結果は同じSLNと2層のフィードフォワードネットワークを通します。
class FeedForwardBlock(nn.Sequential):
def __init__(self, emb_size, expansion=4, ...):
super().__init__(
nn.Linear(emb_size, expansion * emb_size),
nn.GELU(),
...
nn.Linear(expansion * emb_size, emb_size),
)
FeedForwardBlockは、Sequentialモジュールとして実装しています。サイズの入力を(=expansion
)の層とGELU活性化関数を通し、サイズの層を通します。
SIREN
SIRENの説明:[ブログ記事] [論文] [Colab Notebook]
SIRENを実装するためのコードは、公式の実装を参考にしています。[Colab Notebook]。
一般的なSIRENとは違い、目標の出力は単一の画像ではなく、広範囲の異なる画像です。そのために、パッチエンベディングと位置エンベディングを組み合わせています。
フーリエ位置エンベッディング
ViTGANのSIRENネットワークは、Transformerと似ている位置エンコーディングを使用されています。詳しい説明はMerging Vision Transformers (ViT) with SIRENs to form a ViTGAN. A novel approach to generate realistic images. - Fusic Tech Blogに書いています。
ViTGANで使用されているポジションエンコーディングはフーリエ位置エンベッディングで、そのコードはCIPSから引用したものです。
1つの位置に対するフーリエ位置エンコーディングは、以下の関数に従って実行されます。
PyTorchでの実装は以下の通りです。
class LFF(nn.Module):
def __init__(self, hidden_size, **kwargs):
super(LFF, self).__init__()
self.ffm = nn.Linear(2, hidden_size, bias=bias)
nn.init.uniform_(self.conv.weight, -np.sqrt(9 / ch_in), np.sqrt(9 / ch_in))
def forward(self, x):
x = x
x = self.ffm(x)
x = torch.sin(x)
return x
上記の関数を用いて、画像の各位置に対する位置エンコーディングを生成します。
def fourier_pos_embedding(self):
coords = np.linspace(-1, 1, self.out_patch_size, endpoint=True)
pos = np.stack(np.meshgrid(coords, coords), -1)
pos = torch.tensor(pos, dtype=torch.float)
result = self.lff(pos).reshape([self.out_patch_size**2, self.siren_in_features])
return result
最後は、位置とパッチのエンコーディングを組み合わせる必要があります。そのために重みモジュレーションが使われます。
重みモジュレーション
重みモジュレーションとは、入力のエンベディングと重みを掛け算する技術です。要素ごとの掛け算は以下のように行います。
また、重みを正規化するための適応可能な「ディモジュレーション」が次のように定義されています。
実験を行った後、ViTGANでは「ディモジュレーション」が使用されていないと結論しました。
重みモジュレーションの私の実装は、CIPSをベースにしています。全結合層が必要ですが、1次元の畳み込み層として実装しています。理由は、畳み込み層のgroups
によって、batch_size
を考え、最適化されているからです。
各SIREN層は、重みモジュレーション層に適用される活性化で構成されています。SIRENネットワークでは,入力層,隠れ層,出力層のサイズが異なることがあります。したがって、入力サイズがパッチエンベディングのサイズと異なる場合には、パッチエンベディングを適切なサイズに変換するために、追加の完全連結層を通しています。
def __init__(self, demodulation=False, ...):
...
self.scale = 1 / np.sqrt(in_channels)
self.weight = nn.Parameter(
torch.randn(1, out_channels, in_channels, 1)
)
if self.style_size != self.in_channels:
self.style_modulation = FullyConnectedLayer(style_size, in_channels, bias=False)
self.demodulation = demodulation
def forward(self, input, style):
if self.style_size != self.in_channels:
style = self.style_modulation(style)
style = style.view(batch_size, 1, self.in_channels, 1)
weight = self.scale * self.weight * style
if self.demodulation:
demod = torch.rsqrt(weight.pow(2).sum([2]) + 1e-8)
weight = weight * demod.view(batch_size, self.out_channels, 1, 1)
weight = weight.view(
batch_size * self.out_channels, self.in_channels, 1
)
input = input.reshape(1, batch_size * self.in_channels, img_size)
out = F.conv1d(input, weight, groups=batch_size)
out = out.view(batch_size, -1, self.out_channels)
出力を画像の形へ
SIRENネットワークからの出力が[batch_size * num_patches, patch_size^2, out_features]
の形になっていますが、[batch_size, image_size^2, out_features]
に変更しなければなりません(patch_size
* sqrt(num_patches)
= image_size
)。
最初は結果を[batch_size, sqrt(num_patches), sqrt(num_patches), patch_size, patch_size, out_features]
の形に変更します。
その後、各行のすべてのピクセルが連続して並ぶように、次元を次のように並べ替えます。
[batch_size, sqrt(num_patches), patch_size, sqrt(num_patches), patch_size, out_features]
.
最後は、結果を画像の形に変更します。
img = img.view([-1, num_patches_x, num_patches_x, patch_size, patch_size, out_features])
img = img.permute([0, 1, 3, 2, 4, 5])
img = img.reshape([-1, image_size**2, out_features])
ディスクリミネータ
ディスクリミネータが以下のアーキテクチャーで定義されています。
ViTGANのディスクリミネータは、標準的なVision Transformerに、以下のような変更を加えたものです。
- DiffAugment。
- 重複する画像パッチ。
- ベクトル化されたL2距離によるself-attention。
- 改良型スペクトル正規化(ISN)。
- バランスド・コンシステンシー正則化(bCR)。
DiffAugment
DiffAugmentを適応するために、以下のコードを使用しています。
[GitHub] [論文]
微分可能な増強(DiffAugment)とはランダム敵にカラーシフトや移動、ランダムクロッピングを微分可能な関数です。
ディスクリミネータの一部として実装しています。微分可能な増強が適応される画像は本物"real"の写真だけではなく、ジェネレーターで生成された"fake"の画像でも使用します。
重複する画像パッチ
画像 | 重複する画像パッチ(16x16) |
---|---|
トカゲの画像 | トカゲの画像の重複するパッチ |
重複する画像パッチを、畳み込み層として実装しています。
nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=stride_size)
ViTGANの論文で使用されている各辺が重複する部分はpatch_size/2
です。パッチサイズが( patch_size
)。
ベクトル化されたL2距離によるself-attention
[論文]
self-attentionを計算する時、クエリとキーの行列掛け算の代わりに、クエリとキーのL2距離を計算します。
torch.cdist(queries, keys, p=2)
改良型スペクトル正規化 (ISN)
ISNの私の実装は次のスペクトル正規化の実装をベースにしています。
[GitHub]
[論文]
ISNがディスクリミネータの全ての層で使用されています。
スペクトル正規化 (SN):
改良型スペクトル正規化 (ISN):
重みのの初期値をself.w_init_sigma
で保存します。
class spectral_norm(nn.Module):
def __init__(self, module, name='weight', ...):
...
self.w_init_sigma = None
def _update_u_v(self):
if self.w_init_sigma == None:
self.w_init_sigma = np.array(sigma.expand_as(w).detach().cpu())
の初期値を勾配からdetach
する必要があります。detach
しないとPyTorchがその初期値を複数回通ろうとし、エラーが発生します。
重みを更新する時は以下の通りに更新します。
setattr(self.module, self.name, torch.tensor(self.w_init_sigma).to(device) * w / sigma.expand_as(w))
バランスド・コンシステンシー正則化(bCR)
Zhengli Zhao, Sameer Singh, Honglak Lee, Zizhao Zhang, Augustus Odena, Han Zhang; Improved Consistency Regularization for GANs; AAAI 2021 [Paper]
GANの一般のディスクリミネータの損失に加え、バランスド・コンシステンシー正則化(bCR)も使用されています。
bCRは、入力に小さな増強を加えても、ディスクリミネータが同じ出力を出すための損失です。
bCRを実装するために、DiffAugment
を使用しています。
損失が生成された"fake"の画像と本当の写真の"real"の画像で計算します。
lossD_bCR_fake = F.mse_loss(
discriminator(f_img, do_augment=True),
discriminator(f_img, do_augment=False))
lossD_bCR_real = F.mse_loss(
discriminator(r_img, do_augment=True),
discriminator(r_img, do_augment=False))
ViTGANの実装・コメント
ViTGANは、Vision TransformerとSIRENを組み合わせた、新たなGANのアーキテクチャーです。有名なモデルを超えている結果が得られ、素晴らしいだと思います。
ViTGANを実装するのに、約一ヶ月がかかりました。その理由は、コードのエラーのせいで、学習が不満足になっていました。
最近は、ViTGANの作者であるKwonjoon Leeと話しました。実装が素晴らしいStyleGAN2をベースにしています。その実装の均等化された学習率を使い、学習レートをという高いレートに設定する事が出来ています。
関連リンク
SIREN:Implicit Neural Representations with Periodic Activation Functions
Vision Transformer:[Blog Post]
L2距離self-attention:The Lipschitz Constant of Self-Attention
スペクトル正規化関連コード:[GitHub] [Paper]
DiffAugment:[GitHub] [Paper]
フーリエ位置エンコーディング:[Jupyter Notebook]
Exponential Moving Average:[GitHub]
バランスド・コンシステンシー正則化(bCR):[Paper]
SyleGAN2のディスクリミネータ:[GitHub]
Teodor TOSHKOV
I am an intern at Fusic, a company in Fukuoka, Japan. From 2022, I will be joining the Machine Learning team. I develop mostly deep learning models, using PyTorch and Tensorflow.