[Paper Review] CGAN : Conditional Generative Adversarial Nets 논문 분석

업데이트:

이번 포스팅에서는 CGAN(Conditional Generative Adeversarial Networks)에 대해 살펴본다.

이 논문은 GAN이 나오고 나서 얼마지나지 않아 발표된 논문이다. 쉽게 쓰여져있고, 길이도 짧아서 가볍게 읽기 좋은 논문이다.

1. Introduction

기존의 GAN은 Adversaria nets을 사용하기 때문에 Markov cahin을 사용하지 않아도 됐고, 오직 gradient를 얻기위한 backpropagtion만이 필요했다. 또한 추론도 필요없었기 때문에 학습이 쉬웠고 SoTA를 달성했다.

이 논문은 기존의 GAN에 conditional information(ex.class labels, images, text descriptions)를 추가하였다.

2. Conditional Adversarial Nets

CGAN의 모델은 매우 간단하다. 기존의 GAN이 다음의 수식을 만족하는 Mini-Max Game이었다면,

\[\min _{G} \max _{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{z}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]\]

CGAN은 extra infomation인 $y$가 추가된 아래의 식이다.

\[\min _{G} \max _{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x} \mid \boldsymbol{y})]+\mathbb{E}_{\boldsymbol{z} \sim p_{z}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z} \mid \boldsymbol{y})))\]

사실 엄청나게 특별한 건 없고, 단지 Marginal distributionJoint distribution으로 변한게 전부이다.

latent vactor z를 one-hot vector로 embedding한 후, 이를 class label y와 cancatenation 해줬다. 또한, 이 논문은 GAN이 나오고 얼마 안된 후에 나온 논문이기 때문에 MLP구조를 사용한다.


3. Experiment Results

실험으로는 Unimodal과 multimodal, 두가지의 실험을 진행했다.

3.1 Unimodal

  • dataset : MNIST 사용

  • Generator nets uniform distribution에서 100차원의 noise prior z를 뽑은 후, ReLU Layer를 이용해서 z는 200차원으로, y는 1000차원으로 mapping한다. 이후, 이 둘을 concat하여 학습을 진행한다.
  • Discriminator
  • Results label 별로 잘 학습된다.

3.2 Multimodal

  • dataset : MIR Flickr 25,000dataset
    • UGM(User-generated metadata) 사용 (labeled data)
  • Results
    • tag들이 잘 생성된다.

4. Opinion

🤔 엄청난 performance를 냈다기보다는, GAN에 conditional infomation을 추가하자라는 idea가 괜찮아서 유명해진 논문 같다. GAN이 나오고 괜찮아보이니까 잽싸게 낸 느낌 ?!


5. Code

5.1 Generator

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

5.2 Discriminator

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

5.3 Training

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

Reference

  • Naver AI LAB 최윤제 연구원님 발표자료
  • https://github.com/eriklindernoren/PyTorch-GAN

댓글남기기