Fusic Tech Blog

Fusicエンジニアによる技術ブログ

【論文読み】Exploring Simple Siamese Representation Learning
2024/03/25

【論文読み】Exploring Simple Siamese Representation Learning

こんにちは、鷲﨑です。弊社の機械学習チーム勉強会で、Exploring Simple Siamese Representation Learning という画像の表現学習に関する論文を読みました。自然言語処理におけるBERTのように、画像処理においても表現学習(Representation Learning)の手法が重要だと考えられており、この論文も、画像処理にて表現学習を行う手法の一つであるSiamese学習の新しいアーキテクチャを提案していました。この論文では、教師なしで表現学習を行うこれまでの手法の中でも、より簡単に実装可能な手法を提案しており、また、提起している仮説にもとても感動したので、紹介したいと思います。

今までのSiamese Learningとの違い

下図(論文中 Fig.3)は、論文にて提案されているSimSiamと、代表的なSiamese学習を用いた表現獲得の手法を示しています。 各手法は、下図の2つのencorderの出力を、それぞれz1z_1, z2z_2とした時、z1=z2z_1 = z_2に収束してしまうという崩壊問題に対して、様々な方法で対処しています。 崩壊問題がなぜ、問題であるかというと、ここでは類似度を学習したいわけではなく、画像の表現を学習したいためです。

以下、SimSiam以外の各手法を簡単に説明します。

  • SimCLR
    崩壊問題を解決するため、大量のネガティブサンプリング(対象と無関係なデータ)が必要で、様々な負例との比較を行うため、バッチサイズを大きくする必要があります。実際、論文中でもバッチサイズが4096や8192など大きなバッチサイズで、良好な精度を達成していました。しかし、この手法は、各encoderに入力している拡張した画像同士の類似度を近くしているだけで、表現を学習しているわけではありません。
  • BYOL
    momentum encorderの導入と、stop gradientによる勾配伝搬の停止によりパラメータ更新の停止により、ネガティブサンプリングと崩壊の問題を解決しています。しかし、momentum enconderは、勾配停止により学習されないため、encoderのパラメータとmomentum encoderのパラメータを適当な割合で足し合わせて更新する必要があり、実装が複雑になります。
  • SwAV
    この手法は、崩壊問題に対して、クラスタリングを用いて、対処しています。具体的には、各encoderの出力とクラスタベクトルとの類似度を取り、Sinkhorn-Knoppアルゴリズムを用いて最適化したクラスタ割当確率と、Softmaxから得られるクラス割当確率が一貫していることを仮定して問題を解いています。(詳細は、この記事がわかりやすかったです。) この手法も、ネガティブサンプリングの問題を解決していますが、実装が複雑です。

Siamesees arcs

上記のように、これまでの手法は、ネガティブサンプリングを行い必要なデータが増える、または、momentum encoderやクラスタリングを行い実装が複雑なるなどの問題がありました。そこで、提案されたのが、SimSiamです。これは、ネガティブサンプリングが必要なく、stop gradient(勾配停止)と、予測レイヤを追加するのみで、以下の疑似コードのように簡単に実装可能です。ここでは、予測レイヤ(h)を導入しており、また、損失(D)の中では、detachを使用して、勾配計算を止めています。

Psuede code

論文の結果を見ると、既存の手法より良い結果を示していました。この理由を、勾配停止と、予測モデルにあるとしており、実際、勾配停止がない場合と予測レイヤーがない場合の実験で、大きく精度が低下していました。

なぜ、勾配停止と予測モデルが必要なのかの仮説

表現学習の損失を定式化すると、以下のように書くことができます。

L(θ,η)=Ex,T[Fθ(T(x))ηx22]\mathcal{L}(\theta, \eta)=\mathbb{E}_{x, \mathcal{T}}\lbrack||\mathcal{F}_{\theta}(\mathcal{T}(x)) - \eta_x||^2_2\rbrack

ここで、F\mathcal{F}は、パラメータθ\thetaで定義されたニューラルネットワークで、T\mathcal{T}はAugmentation(データの拡張)、xxは画像です。 η\etaは、画像xxの特徴表現で、最終的には、あらゆるデータ拡張の平均で更新されていきます。ここでの目的は、画像の表現を出力する、ネットワークFθ\mathcal{F}_{\theta}を学習することです。しかし、minθ,η(L(θ,η))\rm{min}_{\theta, \eta} (\mathcal{L}(\theta, \eta))のように損失を最小化するためには、ネットワークを学習するだけでなく、特徴表現η\etaも学習する必要があります。

そこで、EMアルゴリズムのように、以下のように更新することで、この損失を最小化しています。

  1. η\etaを固定し、損失が最小となるθ\thetaを計算
    θtargminθL(θ,ηt1)\theta^t \leftarrow \rm{arg min}_{\theta} \mathcal{L}(\theta, \eta^{t-1})
  2. θ\thetaを固定し、損失が最小となるη\etaを計算
    ηtargminηL(θt,η)\eta^t \leftarrow \rm{arg min}_{\eta} \mathcal{L}(\theta^{t}, \eta)

これらの計算を繰り返すことで、L(θ,η)\mathcal{L}(\theta, \eta)を最小化することができます。

まず、2.から考えていきます。ηt\eta^tの更新は表現学習の損失より、ηxtEx,T[Fθt(T(x))]\eta^t_x \leftarrow \mathbb{E}_{x, \mathcal{T}}\lbrack\mathcal{F}_{\theta^t}(\mathcal{T}(x))\rbrackで計算できます。これは、データ拡張に関する分布における画像xxの平均的な表現を割り当てていることになります。

しかし、データ拡張に関する分布は未知であるため、実際にこれを計算することは困難です。

SimSiamにおける近似

そこで、SimSiamでは、期待値計算をせず、以下の式でη\etaを更新するように近似しています。これは、適当な一回のサンプリングによる画像とデータ拡張で更新していることになります。

ηxtFθt(T(x))\eta^t_x \leftarrow \mathcal{F}_{\theta^t}(\mathcal{T}(x))

また、この結果、θ\thetaの更新も以下のようになります。この式は、一般的なSiamese学習の使用する定理ですね。

θt+1argminθEx,T[Fθ(T(x))Fθt(T(x))22]\theta^{t+1} \leftarrow \rm{arg min}_{\theta} \mathbb{E}_{x, \mathcal{T}}\lbrack||\mathcal{F}_{\theta}(\mathcal{T}(x)) - \mathcal{F}_{\theta^t}(\mathcal{T'}(x))||^2_2\rbrack

さて、上のθ\thetaを更新する式では、片方のネットワークのパラメータは定数(θt\theta^t)で固定されています。つまり、SimSiamでは勾配停止を適用することになります。

一方で、SimSiamの近似により、表現学習の損失に対するEMアルゴリズムライクな更新とのギャップが発生しています。それは、η\etaの計算時に、期待値計算を除去したことです。 この期待値計算を補助するために、予測レイヤーhhを導入しています。各ネットワークの出力をz1,z2z_1, z_2とした時、予測レイヤの出力h(z1)h(z_1)は、z2z_2との誤差が最小となることがゴールになります。これは、任意の画像xxに対して、h(z1)=Ez[z2]=ET[f(T(x))]h(z_1)=\mathbb{E}_z\lbrack z_2 \rbrack = \mathbb{E}_{\mathcal{T}} \lbrack f(\mathcal{T}(x)) \rbrackを満たすhhを計算することになり、画像の特徴表現η\etaの更新式になります。このhhは実際に期待値を求めるわけではありませんが、論文では、この予測レイヤを用いることで、近似によるギャップを埋めてくれるのではないかと仮定していました。

まとめ

上記のように、順序だって学習のアーキテクチャを作ることは、憧れます。そして、実際に性能が出ていることも、すごいと思います。 SimSiamは、実装も簡単なので、今後、Siamese学習を行う際には、使用も検討に入れて良いのではと思います。 論文中には、他にも、パラメータに関する考察や比較実験があり、面白いので是非読んでみてください。

Kai Washizaki

Kai Washizaki

Conpany: Fusic Co., Ltd. Program Language: Python, Go, PHP Interest: Machine Learning, MLOps