[論文メモ] Multimodal Conditional Image Synthesis with Product-of-Experts GANs

arxiv.org
deepimagination.cc

NVIDIA

スケッチやテキストなどのマルチモーダルを条件としたProduct-of-Experts Generative Adversarial Networks (PoE-GAN) の提案。

f:id:Ninhydrin:20220217090514p:plain



既存のConditional GANは条件としてスケッチやテキストなど1種類の入力を条件としていた。しかし、それだとスケッチのが説明しやすいときや逆にテキストのが説明しやすいときなどの場合に対応ができないよねというお気持ち。

手法

 M個モダリティ (y_1,...y_M)と画像がペアになったデータセット xが与えられたとき、これらのモダリティの部分集合を入力として画像を生成する生成モデルを学習する( p(x|\mathcal{Y}), \forall\mathcal{Y} \subseteq \{y_1,...,y_M\})。
条件が M個あるので2^M通りの条件での生成が可能で、当然モダリティが1つの場合( p(x|y_i), \forall i \in \{1,...,M\})や空の場合( p(x|\varnothing)=p(x))でも生成できる必要がある。
モダリティとしてはテキスト、スケッチ、セグメンテーションマップ、画像(スタイルのリファレンス)を採用するが、フレームワーク的には他のモダリティも簡単に導入できる。

Product-of-experts modeling

条件として入力するモダリティは画像生成時に満たすべき制約であり、与えられた条件を満たす画像集合はそれらの条件を1つずつ満たす画像集合の積集合の部分になる(図2)。
f:id:Ninhydrin:20220217092548p:plain

ということで同時分布 p(x|y_i, y_j)がある1つのモダリティ条件とした分布 p(x|y_i) p(x|y_j)の積として表現できると仮定する。

このように個々モダリティから生成する”エキスパート”を"掛ける"のでproduct-of-experts(PoE)と呼ばれる。

Generatorはlatent code  zを入力して画像 xを生成するが、 x zに対して一意に決定するので p(z|\mathcal{Y})を求めるのは p(x|\mathcal{Y})を求めるのに等しい。
 p'(z)をpriorとして p(z|\mathcal{Y})を構成するのにPoEを下記のようにモデリングする。
f:id:Ninhydrin:20220217093631p:plain

 q(z|y_i)は一つのモダリティを条件としたエンコーダ。

Multiscale and hierarchical latent space

テキストとスケッチなどでモダリティの解像度が異なるのでlatent spaceもそれに合わせて階層化する。
 z = (z^0,...z^N)とし、 z^0 \in \mathbb{R}^{c_0}は特徴ベクトル、 z^k \in \mathbb{R}^{c_k \times r_k \times r_k}, 1 \leq k \leq Nを解像度ごとの特徴マップ( r_{k+1}=2r_k, r_1=4,  r_Nは画像と同じ解像度)とする。
したがって先程の式(1)の各要素が分解でき、下記の式(2)のようにできる。
f:id:Ninhydrin:20220217100011p:plain
より解像度の低い、抽象度の高い z^{\lt k}を条件として、その積の形式になる。

ここで p'(z^k|z^{\lt k})=\mathcal{N}(\mu^k_0, \sigma^k_0) q'(z^k|z^{\lt k}, y_i)=\mathcal{N}(\mu^k_i, \sigma^k_i)は平均と分散をニューラルネットでパラメータ化された独立したガウシアン。
なお p(z^k|z^{\lt k}, \mathcal{Y})はガウシアンの積なのでガウシアンになる(下記の式3)。
f:id:Ninhydrin:20220221091334p:plain

. Generator architecture

アーキテクチャは図3を参照。各モダリティをエンコード後、Global PoE-Netで集約する。
f:id:Ninhydrin:20220217101034p:plain

セグメンテーションマップとスケッチは入力をskip connectしたCNNで、スタイルはResNet、テキストはCLIPを使ってそれぞれエンコードする(図11)。
f:id:Ninhydrin:20220221095748p:plain


Global PoE-Netは図4参照。
MLPでガウシアン q(z^0|y_i) = \mathcal{N}(\mu^0_i, \sigma^-0_i)を予測し、 z^0をサンプリング。これがdecoderのメインの入力になる。
f:id:Ninhydrin:20220221091214p:plain

decoderはResBlockからなる(図5を参照)。
f:id:Ninhydrin:20220221091945p:plain

 z^0を入力としてそれに対してconvolutionをするのがメインだが、途中で解像度毎のセグメンテーションマップとスケッチを式(3)のPoEで処理した z^kを使ったSPADEと z^0MLPに入れて得た特徴ベクトル wを使ったAdaINを行うlocal-global adaptive instance normalization(LG-AdaIN) レイヤーを挟む。
f:id:Ninhydrin:20220221093417p:plain

テキストやスタイル情報は画像全体の大まかな情報を持っているのに対して、セグメンテーションマップやスケッチは画像の詳細も含んでいることを考えると、セグメンテーションマップとスケッチのみ途中で注入するのは理解できる。

Multiscale multimodal projection discriminator

Discriminatorは画像 xと条件 \mathcal{Y}を受け取り本物かどうかのスコア D(x, \mathcal Y) = sigmoid(f(x, \mathcal Y))を出力する。

各モダリティが独立と仮定すると下記の式。
f:id:Ninhydrin:20220221094139p:plain

Projection Discriminator(PD)をマルチモーダル用に一般化したMultimodal PD(MPD)を提案(図6)。
f:id:Ninhydrin:20220221094547p:plain

画像、条件それぞれを特徴空間に埋め込み、条件なし項に関しては画像埋め込みをLinearで出力、条件ありの項に関しては条件の埋め込みと内積を取り、それらの和として表現するのがPDだが、MPDでは条件ありの部分を条件の数だけ増やす(式6)。
f:id:Ninhydrin:20220221094657p:plain

セグメンテーションマップとスケッチに関しては空間的なモダリティなのでそれぞれの解像度毎でMPDを行う(Multiscale MPD)
f:id:Ninhydrin:20220221095540p:plain

Losses and training procedure

Latent regularization

PoEの仮定(式1)の元、条件について周辺化したpriorと条件なしのpriorは一致する。
f:id:Ninhydrin:20220221100446p:plain

なので各解像度でKLを最小化。 \omega_iはモダリティに対して、 \omega^kは解像度に対しての重み。
f:id:Ninhydrin:20220221095923p:plain

Contrastive losses

contrastive lossはペアデータのバッチ (\textbf{u}, \textbf{v})=\{(u_i, v_i), i=1,...,N \}が与えられたとき、式9のように、ペアでないものに関しては離したままペア同士は似るようにするloss。
f:id:Ninhydrin:20220221101440p:plain

ここでは画像に関してと条件に関しての2種類のcontrastive lossを利用。
画像に関しては学習済みのVGG encoderを用いて本物の画像 xと、その条件から生成した画像 \tilde xについて類似度を最大化する。
f:id:Ninhydrin:20220221101649p:plain
perceptual lossに似てるが、それよりパフォーマンスが良いらしい。

条件に関しては2つあり、1つ目は本物の画像 \textbf xと条件 \textbf y_iそれぞれの埋め込みに対してで下記。
f:id:Ninhydrin:20220224084555p:plain
 D_x, D_{y_i}はDiscriminatorの中間表現で式6と図6(b)を参照。
このlossはあくまでDiscriminatorのupdate用で、Generatorには本物の画像ではなく生成画像 \tilde xに関してlossをとる。
f:id:Ninhydrin:20220224084955p:plain

最終的なlossは以下。
f:id:Ninhydrin:20220224085050p:plain

 \lambda_Xは重み、 \mathcal{L}_{GP} R_1のgradient penalty。

実験・結果

f:id:Ninhydrin:20220224085237p:plain
f:id:Ninhydrin:20220224085259p:plain
f:id:Ninhydrin:20220224085316p:plain
f:id:Ninhydrin:20220224085352p:plain
f:id:Ninhydrin:20220224085409p:plain

所感

複数のモダリティを条件としたときのcGANの提案。
思っていたよりシンプルだった(全部埋め込みにして生成)。
確かに、テキストやスタイルのが表現しやすいこともあるし、スケッチで表現したいときもある。
参考になった点は
1つ目はモダリティによって入力する場所を分割すること。確かにスケッチやセグメンテーションマップは生成結果の詳細部分への制約になるのでそれなりの解像度を保ったまま入力したい。逆にテキストやスタイルは画像全体に関わることなので入力で十分。StyleGANの生成結果制御でも似た議論があった。
2つ目はKL loss。確かに周辺化すれば一致するよね。
3つ目はLG-AdaIN(AdaINとSPADEの融合)。取り扱う情報の解像度に合わせて2種類を使い分けるのはなるほどとなる。

自分でcGANを使うにあたって参考になることが多かった。