[Paper Review] GAN-CLS : Generative Adversarial Text to Image Synthesis 논문 분석

업데이트:

✍🏻 이번 포스팅에서는 text로 image를 생성하는 Generative Adversarial Text to Image Synthesis model에 대해 살펴본다.

이 논문이 나왔을 때만 해도 text에서 image를 합성하는 연구가 활발하지 않았었다. 이 논문에서는 당시에 자주 사용하던 architecture인 GAN과 deep architectrue를 사용하여 text to image synthesis 모델을 만들었다.

1. Introduction

본 논문은 text에서 image를 합성하는 모델을 학습한다. 즉, word나 charactor를 image pixel로 mapping 해야 한다.

⭐ Challenging Problem

  • text에서 중요한 visual details을 추출
  • 추출된 정보로부터 진짜같은 이미지를 합성
  • Multimodal Learning
  • Deep convolution decoder network
  • GAN(Generative Adversarial Network)

3. Background

3.1 Generative Adversarial Networks

GAN은 Generator와 Discriminator가 경쟁을 하면서 학습을 하는 minimax game model이다. (GAN에 대한 설명은 이 글을 참고)

\[\begin{aligned} \min _{G} \max _{D} V(D, G)=& \mathbb{E}_{x \sim p_{\text {data }}(x)}[\log D(x)]+\\ & \mathbb{E}_{x \sim p_{z}(z)}[\log (1-D(G(z)))] \end{aligned}\]

3.2 Deep symmetric structured joint embedding

  • 맨 아래의 사진이 text와 가장 비슷하므로 loss가 0에 가까움

  • Classfier
    • $f_{t}$ : text classifier

      \[\left.f_{v}(v)=\underset{y \in \mathcal{Y}}{\arg \max } \mathbb{E}_{t \sim \mathcal{T}(y)}\left[\phi(v)^{T} \varphi(t)\right)\right]\]
    • $f_{v}$ : visual classifier

      \[\left.f_{t}(t)=\underset{y \in \mathcal{Y}}{\arg \max } \mathbb{E}_{v \sim \mathcal{V}(y)}\left[\phi(v)^{T} \varphi(t)\right)\right]\]
    • $v_n$ : image / $\phi$ : image encoder(ex CNN)
    • $t_n$ : corresponding text description / $\varphi$ : text encoder (ex LSTM)
    • $y_n$ : cass label
    • 특정 이미지가 들어갔을 때 가장 비슷한 text를 이끌어내는 분류기
  • Structure Loss

    \[\frac{1}{N} \sum_{n=1}^{N} \Delta\left(y_{n}, f_{v}\left(v_{n}\right)\right)+\Delta\left(y_{n}, f_{t}\left(t_{n}\right)\right)\]

    이미지 분류기로부터 나온 값($f_{v}\left(v_{n}\right)$, text값이 나옴)과 실제 값($y_{n}$)간의 loss를 계산하고,

    text 분류기로 나온 값(image값이 나옴)과 실제 값과의 비교하여 loss를 계산한다.

4. Method

⭐ 본 논문에서는 Convolutional RNN으로 encoding한 text features와 DCGAN을 활용해서 이미지를 합성한다.


4.1 Network Architecture

  • Generator : $\mathbb{R}^{Z} \times \mathbb{R}^{T} \rightarrow \mathbb{R}^{D}$
  • Discriminator : $\mathbb{R}^{D} \times \mathbb{R}^{T} \rightarrow{0,1}$
    • T : dim of text description embedding
    • D : dim of text image embedding
    • Z : dim of noise ($z \in \mathbb{R}^{Z} \sim \mathcal{N}(0,1)$)

이 논문의 architecture는 DCGAN을 변형하여 만들었다.

Generator

  1. text encoder $\varphi$를 사용해서 text query $t$를 encoding 한다.

  2. encoding의 결과값($\varphi(t)$)을 FC layer에 넣어 compress한 후 Leaky-ReLU를 사용해서 128-dim의 작은 차원으로 compression한다.

  3. 이 값을 noise vector z와 concate한후 deconvolutional network를 통해 generate image를 얻는다.

Discriminator

  1. stride-2 convolution layer와 BN기법, leaky ReLU function을 이용해서 학습을 한다.

  2. 1번의 과정을 4x4 conv layer가 될때까지 반복한다. (여러개의 layer를 쌓음)

  3. 4x4 conv layer가 되면 compressiong된 embedding vector $\varphi$를 여러개 복사해서 conv layer 뒤에 이어붙인다.(depth concatenation)

  4. 1x1 conv layer가 되도록 연산을 한 후 final score를 얻는다.

  • 이때 모든 conv layer에 대해서 BN을 해준다.

4.2 Matching-aware discriminator (GAN-CLS)

conditional GAN은 discriminator가 (text, images) pair가 진짜인지 가짜인지 판단하도록 학습한다. 이때 문제점은 discriminator는 real training image가 어떤 text embedding context와 match되는지 모른다는 점이다.

즉, real image가 자기를 설명하지 않은 text와 match될 수도 있다.(mismatch)

따라서 (real image, mismatched text term)을 추가하도록 GAN training algorithm을 수정하여 discriminator가 fake에 대해서도 학습을 할 수 있게 한다.

기존의 GAN은 위의 algorithm에서 7,9의 score만 있었다면 이제는 line 8에 있는 score도 추가된다.


4.3 Learning with manifold interpolation (GAN-INT)

우리가 네트워크를 학습시킬 때에는 주어진 text만을 가지고 image를 만들지만, 실제로 이를 사용할 때에는 training dataset에 없는 text를 줘도 image를 생성할 수 있어야한다.

즉, training dataset에 있는 text와 비슷한 text를 입력했을 때에도 image를 생성해야하기 때문에 interpolation의 방식을 사용하면 조금 더 효과적으로 학습을 할 수 있게 된다.

  • 딥러닝 네트워크는 embedding pair사이의 interpolation에서 representation을 학습한다.
    • interpolation(보간)은 data manifold 근처에서 생김
    • (text1, image1), (text2, image2)가 있을 때 text1과 text2가 space내에서 가까이에 있다면 image1을 학습할 때 text1의 feature외에도 text2의 feature가 사용될 수 있다. text1과 text2의 보간값을 가지고 학습이 가능하다.
  • 따라서 interpolated text embedding($\beta t_{1}+(1-\beta) t_{2}$)을 사용하기 위해 아래의 식을 사용한다.
    • 이때 interpolated text embedding은 사람이 쓴 text는 아니다. text1과 text2의 보간된 embedding값이 text3라면 text3를 사용하게 되는 것!

      \[\mathbb{E}_{t_{1}, t_{2} \sim p_{\text {data }}}\left[\log \left(1-D\left(G\left(z, \beta t_{1}+(1-\beta) t_{2}\right)\right)\right)\right]\]
    • generator는 t1과 t2의 보간값으로 학습을 한다.
    • 보통 $\beta$로는 0.5를 사용한다.

4.3 Inverting the generator for style transfer

text encoding $\varphi(t)$는 image content(ex. flower shape, colors)를 찾아내는 역할을 한다. (text에서 이미지로 표현할 만한 feature들을 추출) 이미지를 생성하려면 사물의 특징들을 잘 추출하여 생성하는 것도 중요하지만, 그 외의 정보들(ex.배경이나 자세)들을 생성하는 것도 중요하다.

우리는 이 정보들을(object외의 정보)를 style factor라고 부르는데, noise sample z에서 style factor를 만드려면 style transfer를 학습시키는 과정이 필요하다.

style transfer는 기존의 GAN과 다르다. 기존에는 $\hat{x} \leftarrow G(z, \varphi(t))$를 학습했다면, style transfer에서는 $\hat{x}$에서 $z$로 거꾸로 훈련을 한다.

정리하자면, Style transfer는 query image를 text(style factor)로 바꿔주는 기법이다.

loss함수로는 다음의 simple squared loss를 사용한다. (S는 style encoder network)

\[\mathcal{L}_{\text {style }}=\mathbb{E}_{t, z \sim \mathcal{N}(0,1)}\|z-S(G(z, \varphi(t)))\|_{2}^{2}\]

5. Experiments

Dataset

  • CUB dataset (bird)
  • Oxford-102 dataset (flower)

Structure

  • Text encoder : pre-train deep convolutional-recurrent
    • pre-train된 model을 사용한 건, 연산 속도를 빠르게 하려고 한거지 모델의 성능에 영향을 끼치지는 않는다.
  • Image encoder : 1024-dim GoogleLeNet image embedding

Hyper-parameter는 논문 참고

5.1 Qualitative Results

GAN-CLS와 GAN-INT를 동시에 적용하면 좋은 결과를 냄을 확인할 수 있다.


5.2 Disentangling style and content / Pose and back ground style transfer

text embedding/text caption는 image content(object의 특징들)에 대한 정보는 갖고 있지만, style information에 대한 정보들은 없다. 5.2 실험은 style information을 찾기 위한 실험이다.

이미지의 similar/dissimilar를 만든 다음, style encoder를 통해 style vector를 예측한다.(4.3처럼)

이를 위해 본 실험에서는 K-mean을 사용해서 image를 100개의 cluster로 grouping했다. 학습을 시킨다면, 비슷한 style을 가진 이미지들의 similarity는 다른 스타일을 가진 similarity보다 클 것이다.

  • Background color : RGB channel로 cluster
  • Bird Pose : 6 keypoint(beak, belly, breast, crown, forehead, and tail)로 cluster

위의 그림을 보면 이미지 생성시에 style을 잘 적용됨을 확인할 수 있다.


5.3 Sentence interpolation

아래의 그림은 Generator가 학습을 할 때 Interpolation(4.2) text maniford에서 학습을 함을 증명한다.

Text Interpolation(Left)

  • noise distribution은 유지하면서 text embedding만을 바꾸면 왼쪽 그림처럼 이미지가 변한다.

Noise Interpolation(Right)

  • text encoding은 고정하고 두 가지의 noise vector만을 바꿔봤을 때 이 결과 역시 smooth하게 변함을 확인할 수 있다.

5.4 Beyond birds and flowers

MS-COCO dataset을 이용하여 GAN-CLS를 학습시켜봤다.

text encoder/GAN architecture, hyperparameter는 CUB, Oxford-102를 학습시킬 때와 동일한 것을 사용했다. 다만, MS-COCO dataset은 class당 object가 여러개이기 때문에 text encoding을 할 때 category level이 아닌 instance level에서 image와 matching을 했다.

일부는 학습이 잘되지만, 일부는 학습이 잘 안된다. 객체가 많은(text가 많은) dataset을 학습에 대해서는 추가적인 연구가 필요⭐

6. Conclusion

  • visual description으로 부터 image를 생성하는 간단하고 효율적인 model을 제안
  • visual interpretation을 사용하면 조금 더 좋은 성능을 낼 수 있다.
  • Style과 content를 분리하는 모델을 제안
  • 이미지 생성이 어려운 MS-COCO dataset에서도 학습이 가능하다.

7. Opinion

아이디어도 신선하고, 실험도 다양한 좋은 논문인 것 같다. text to image 관련 후속 논문들도 읽어보고 싶다!


Reference

  • https://www.youtube.com/watch?v=M6E6ne4PSi0
  • Code : https://github.com/aelnouby/Text-to-Image-Synthesis

댓글남기기