[Paper Review] StyleGAN : A Style-Based Generator Architecture for Generative Adversarial Networks 논문 분석

업데이트:

✍🏻 이번 포스팅에서는 style transfer를 PGGAN에 적용한 nvidia research의 StyleGAN model에 대해 살펴본다.


본 model은 discrimintator나 loss function은 건들이지 않고, style을 더 잘 학습시키도록 generator의 architecture를 발전시킨 모델이다. 기존의 PGGAN(ProGAN) 모델을 변형시켜 image합성을 style scale-specific cotrol 할 수 있도록 발전시켰다.

아래의 사진처럼 특정 style을 scaling할 수 있는 것이 style scale-specific control 이다.

또한, 본 논문에서는 interpolation quality와 disentanglement를 측정할 수 있는 2가지 방식에 대해 소개하고 있다.(Section 4)

1. Introduction

GAN의 generator는 아직도 black box같다는 비판을 받고 있다. latent space의 interpolation에 대한 연구는 어느 정도 진행이 되었지만, 아직도 latent space에 관한 이해가 부족하다. (GAN-interpolation은 이 글 참고)

본 논문은 style transfer literature를 기존의 generator에 적용하였다. latent vetor $z$에 style을 넣을 수 있도록 learned constant input $w$(mapping network를 통해 학습된 parmeter)을 조금씩 조정하였다.(style : w1, w2, w3…)

또한, style vector를 넣을 때 noise와 함께 넣어주어서 style간에 correlation이 없도록 했고, 자세나 identity와 같은 unsupervised separation부터 주근깨나 머리같은 stochastic variation까지 자동으로 학습할 수 있도록 하였다.

즉, scale-specific mixing과 interpolation이 가능해졌다.


정리하자면, StyleGAN은 latent space를 disentangle하도록 구조를 약간 수정하였으며, 추가적으로 latent space disentanglement를 측정할 수 있도록 perceptual path length와 linear separability의 2가지 metric을 제안하였다.

Disentanglement

StyleGAN의 architecture에 대해 소개하기 앞서, 이 논문에서 자주 사용되는 표현인 Disentanglement에 대해 설명하고자 한다. (논문의 section4)

다양한 style의 이미지를 생성하려면, variation의 factor를 조정할 수 있어야한다. 즉, style을 조금씩 조정할 때 latent space에서는 linear하게 변하는 것처럼 표현되어야한다.

위의 그림에서 이상적인 경우는 (a)이다. (a)의 latent space에서는 남자에서 여자로 갈때 linear하게 attribute이 변하게 된다.

반면, (b)는 기존의 generator들이 사용하는 방식으로, fixed distribution에서 latent space를 구성한다.(entangle된 상태). 이 경우에는 성별과 머리의 style들이 entangle되어 남자에서 여자로 갈때 머리의 길이도 변하게 된다. 즉 style들이 서로 독립적이지 않기 때문에 이미지를 생성할 때에도 영향을 끼치게 된다.

따라서 목표로 하는 바는 (c)처럼 disentangle하게 latent space를 만드는 것이다. (c)와 같이 latent space를 $\mathcal{W}$로 mapping한다면, 우리는 각 특성의 변화를 linear하게 조절할 수 있게 될 것이다.

2. Style-based Generator

본 논문에서는 style-based generator를 사용한다.

Traditional GAN (그림 (a)) : latent vector가 normalize 되는 구조 를 따른다.

  • PGGAN이나 BigGAN 등이 이러한 구조를 따른다.
    • PGGAN에 대한 설명은 이 글 참고
  • 이러한 방식을 따르게 되면 training data가 latent space의 probability density를 따라야 하기 때문에 entanglement하게 된다. 즉, 주근깨나 머리처럼 stochastic variation을 변경하는데 제한이 생기게 된다. (각 feature들이 correlation을 갖게 되어 주근깨를 넣으려고 했는데, 머리색깔이 바뀌게 되거나 하는 현상이 발생)

StyleGAN은 이러한 문제를 해결하고자 latent vector z를 generator의 input으로 바로 사용하지 않고, latent vector를 non-linear mapping network에 넣어 disentanglement하게 만들고자 하였다. (각 feature들이 상관관계를 갖지 않도록 보정함)


Style-based Generator (그림 (b))

우선 latent space $\mathcal{Z}$에서 non-linear mapping network $f: \mathcal{Z} \rightarrow \mathcal{W}$를 거쳐 style code $\mathbf{w} \in \mathcal{W}$를 생성한다. ($\mathcal{W}$ : intermediate latent space)


이후에는 각 convolutional layer에 있는 adaptive instance normalization(AdaIN)을 사용하여 generator를 조절한다.

  • A-block (learned affine transform) : style code인 $w$를 뽑고나면, affine transformation을 통해 $w$를 style $y = (y_s, y_b)$로 바꿔준다.

    (아핀 변환은 일종의 fc layer이며, 위의 모델에서 A-block은 각기 다른 파라미터를 가진 fc layer)

  • AdaIN : feature map $x_i$를 normalize한 후, style vector $y$로 scaling & biasing

    \[\operatorname{AdaIN}\left(\mathbf{x}_{i}, \mathbf{y}\right)=\mathbf{y}_{s, i} \frac{\mathbf{x}_{i}-\mu\left(\mathbf{x}_{i}\right)}{\sigma\left(\mathbf{x}_{i}\right)}+\mathbf{y}_{b, i}\]
    class AdaptiveInstanceNorm(nn.Module):
        def __init__(self, in_channel, style_dim):
            super().__init__()
    
            self.norm = nn.InstanceNorm2d(in_channel)
            # A-block : 각 conv마다 channel의 개수가 다르기 때문에 A-block들은 서로 다른 fc layer
            self.style = EqualLinear(style_dim, in_channel * 2)
    
            self.style.linear.bias.data[:in_channel] = 1
            self.style.linear.bias.data[in_channel:] = 0
    
        def forward(self, input, style):
            style = self.style(style).unsqueeze(2).unsqueeze(3)
            gamma, beta = style.chunk(2, 1)
    
            # Instance Normalize
            out = self.norm(input)
            # scaling & bias
            out = gamma * out + beta
    
            return out
    
    


또한, 특이한 점은 더이상 첫번째 convolution layer의 input을 latent code에서 feeding하지 않는다는 점이다. 본 논문에서는 traditional input layer를 제거하고 learned constant tensor 4 x 4 x 512에서 이미지 합성을 시작한다.

  • 쉽게 말하자면, latent code를 synthesis network $g$에 넣어 학습을 시키면 synthesis network의 입력으로 주었던 latent code $z$는 의미 없어진다.
  • input 자체를 학습시켰기 때문에 Learned라는 표현을 쓰며, inference시에 learned value들이 고정되어있기 때문에 Constant라는 표현을 사용하였다.
  • 여러 가지 사진을 만들때, 왼쪽의 네트워크만 바꿔주고 오른쪽의 synthesis network는 고정한다.
class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out

B-block (Stochastic Variation) : 마지막으로는 stochastic detail을 생성하기 위해 각 convolution layer후에 per-pixel noise input을 넣어준다.

  • Gaussian에서 noise를 sampling한 후, channel scaling factor를 통해 dimmension을 맞춰준다.
  • 이때 noise값을 multiplication하게 되면 결과값이 이 noise에 너무 dependent하게 되기 때문에 단순히 add만 해준다. (noise를 더해주는 건 머리카락의 움직임과 같은 미세한 변화만을 조정해주기 위해 더해주는 것)
class NoiseInjection(nn.Module):
    def __init__(self, channel):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        return image + self.weight * noise
  • 다음 그림을 보면, noise를 변화시킬 때마다 머리카락과 같은 stochastic detail들이 조금씩 변화한다.

  • 또한, Figure5를 보면 각 layer마다 더해지는 noise가 각기 다른 stochastic detail에 관여함을 확인할 수 있다.

  • 정리하자면, noise들은 local한 변화에만 영향을 주고 전체적인 특성을 바꾸지는 못한다.

⭐ A-block은 전체적인 global attribute이나 identity를 결정하고, B-block은 stochastic detail을 결정한다.

  • 우리가 global attribute를 latent space로 modeling하려면 spatial correlation을 고려해야하지만, stochastic detail을 modeling할 때에는 전체적인 특성을 고려하지 않아도 된다. 따라서 detail한 변화를 주고싶을 때에는 spatially independent noise를 사용한다.

3. Properties of the style-based generator

⭐ Our generator architecture makes it possible to control the image synthesis via scale-specific modifications to the styles.

We can view the mapping network and affine transformations as a way to draw samples for each style from a learned distribution, and the synthesis network as a way to generate a novel image based on a collection of styles

3.1 Style Mixing

style mixing은 단어 그대로 style을 섞어주는 것이다. latent code 2개를 sampling ($z_1, z_2$)한 후, mapping network를 거쳐 2개의 style code($w_1, w_2$)를 만드는 것이다. 이후 어떠한 기준점을 정한 후, 두 개의 스타일을 넣어주면 style이 섞이게 된다

이렇게 mixing regularization을 하면 style들이 correlation되지 않게 된다.


4. Disentanglement studies

latent space의 interpolation 정도를 측정하기 위한 2가지의 metric을 소개하낟.

4.1 Perceptual path length(PPL)

The average perceptual path length in latent space $\mathcal{Z}$

  • $\mathbf{z}{1}, \mathbf{z}{2} \sim P(\mathbf{z}), t \sim U(0,1)$
  • $\epsilon=10^{-4}$
  • $\mathcal{Z}$ 상에서 interpolation을 할 때에는 spherical interpolation(slerp)을 사용한다.
  • latent space에서 굉장히 가까이에 있는 이미지를 2개 생성한 후에 이 이미지들간의 거리를 측정한다. 이 값이 작게 나올 수록 이미지들이 비슷한거니까 disentangle되었다고 생각한다.
  • epsilon의 값에 따라 차이가 나지 않도록 epsilon의 값에 따라 normalize를 해주며, 10만개의 sample에 대해 평균값을 계산한다.
\[\begin{array}{r} l_{\mathcal{Z}}=\mathbb{E}\left[\frac { 1 } { \epsilon ^ { 2 } } d \left(G\left(\operatorname{slerp}\left(\mathbf{z}_{1}, \mathbf{z}_{2} ; t\right)\right)\right.\right.,\left.\left.G\left(\operatorname{slerp}\left(\mathbf{z}_{1}, \mathbf{z}_{2} ; t+\epsilon\right)\right)\right)\right] \end{array}\]

The average perceptual path length in latent space $\mathcal{W}$ \(l_{\mathcal{W}}=\mathbb{E}\left[\frac { 1 } { \epsilon ^ { 2 } } d \left(\begin{array}{l} g\left(\operatorname{lerp}\left(f\left(\mathbf{z}_{1}\right), f\left(\mathbf{z}_{2}\right) ; t\right)\right),\left.\left.g\left(\operatorname{lerp}\left(f\left(\mathbf{z}_{1}\right), f\left(\mathbf{z}_{2}\right) ; t+\epsilon\right)\right)\right)\right] \end{array}\right.\right.\)

$\mathcal{Z}$보다는 $\mathcal{W}$를 사용했을 때, 일반 모델보다는 AdaIN을 사용한 style-based model을 사용했을 때 disentanglement하다.

4.2 Linear separability

생략

5. Code

5.1 Generator

class StyledConvBlock(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=3,
        padding=1,
        style_dim=512,
        initial=False,
        upsample=False,
        fused=False,
    ):
        super().__init__()

        # Learned constant tensor (4x4x512)
        if initial:
            self.conv1 = ConstantInput(in_channel)

        else:
            if upsample:
                if fused:
                    self.conv1 = nn.Sequential(
                        FusedUpsample(
                            in_channel, out_channel, kernel_size, padding=padding
                        ),
                        Blur(out_channel),
                    )

                else:
                    self.conv1 = nn.Sequential(
                        nn.Upsample(scale_factor=2, mode='nearest'),
                        EqualConv2d(
                            in_channel, out_channel, kernel_size, padding=padding
                        ),
                        Blur(out_channel),
                    )

            else:
                self.conv1 = EqualConv2d(
                    in_channel, out_channel, kernel_size, padding=padding
                )

        self.noise1 = equal_lr(NoiseInjection(out_channel))
        self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)
        self.lrelu1 = nn.LeakyReLU(0.2)

        self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding)
        self.noise2 = equal_lr(NoiseInjection(out_channel))
        self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim)
        self.lrelu2 = nn.LeakyReLU(0.2)

    def forward(self, input, style, noise):
        out = self.conv1(input)
        out = self.noise1(out, noise)
        out = self.lrelu1(out)
        out = self.adain1(out, style)

        out = self.conv2(out)
        out = self.noise2(out, noise)
        out = self.lrelu2(out)
        out = self.adain2(out, style)

        return out
  • EqualConv2d
class EqualConv2d(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        return self.conv(input)
  • Generator

StyleGAN은 PGGAN architecture를 따른다.

class Generator(nn.Module):
    def __init__(self, code_dim, fused=True):
        super().__init__()

        self.progression = nn.ModuleList(
            [
                StyledConvBlock(512, 512, 3, 1, initial=True),  # 4
                StyledConvBlock(512, 512, 3, 1, upsample=True),  # 8
                StyledConvBlock(512, 512, 3, 1, upsample=True),  # 16
                StyledConvBlock(512, 512, 3, 1, upsample=True),  # 32
                StyledConvBlock(512, 256, 3, 1, upsample=True),  # 64
                StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused),  # 128
                StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused),  # 256
                StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused),  # 512
                StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused),  # 1024
            ]
        )

        self.to_rgb = nn.ModuleList(
            [
                EqualConv2d(512, 3, 1),
                EqualConv2d(512, 3, 1),
                EqualConv2d(512, 3, 1),
                EqualConv2d(512, 3, 1),
                EqualConv2d(256, 3, 1),
                EqualConv2d(128, 3, 1),
                EqualConv2d(64, 3, 1),
                EqualConv2d(32, 3, 1),
                EqualConv2d(16, 3, 1),
            ]
        )

        # self.blur = Blur()

    def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1)):
        out = noise[0]

        if len(style) < 2:
            inject_index = [len(self.progression) + 1]

        else:
            inject_index = sorted(random.sample(list(range(step)), len(style) - 1))

        crossover = 0

        for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)):
            if mixing_range == (-1, -1):
                if crossover < len(inject_index) and i > inject_index[crossover]:
                    crossover = min(crossover + 1, len(style))

                style_step = style[crossover]

            else:
                if mixing_range[0] <= i <= mixing_range[1]:
                    style_step = style[1]

                else:
                    style_step = style[0]

            if i > 0 and step > 0:
                out_prev = out
                
            out = conv(out, style_step, noise[i])

            if i == step:
                out = to_rgb(out)

                if i > 0 and 0 <= alpha < 1:
                    skip_rgb = self.to_rgb[i - 1](out_prev)
                    skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest')
                    out = (1 - alpha) * skip_rgb + alpha * out

                break

        return out
  • Styled Generator
class StyledGenerator(nn.Module):
    def __init__(self, code_dim=512, n_mlp=8):
        super().__init__()

        self.generator = Generator(code_dim)

        layers = [PixelNorm()]
        for i in range(n_mlp):
            layers.append(EqualLinear(code_dim, code_dim))
            layers.append(nn.LeakyReLU(0.2))

        self.style = nn.Sequential(*layers)

    def forward(
        self,
        input,
        noise=None,
        step=0,
        alpha=-1,
        mean_style=None,
        style_weight=0,
        mixing_range=(-1, -1),
    ):
        styles = []
        if type(input) not in (list, tuple):
            input = [input]

        for i in input:
            styles.append(self.style(i))

        batch = input[0].shape[0]

        if noise is None:
            noise = []

            for i in range(step + 1):
                size = 4 * 2 ** i
                noise.append(torch.randn(batch, 1, size, size, device=input[0].device))

        if mean_style is not None:
            styles_norm = []

            for style in styles:
                styles_norm.append(mean_style + style_weight * (style - mean_style))

            styles = styles_norm

        return self.generator(styles, noise, step, alpha, mixing_range=mixing_range)

    def mean_style(self, input):
        style = self.style(input).mean(0, keepdim=True)

        return style

5.2 Discriminator

class Discriminator(nn.Module):
    def __init__(self, fused=True, from_rgb_activate=False):
        super().__init__()

        self.progression = nn.ModuleList(
            [
                ConvBlock(16, 32, 3, 1, downsample=True, fused=fused),  # 512
                ConvBlock(32, 64, 3, 1, downsample=True, fused=fused),  # 256
                ConvBlock(64, 128, 3, 1, downsample=True, fused=fused),  # 128
                ConvBlock(128, 256, 3, 1, downsample=True, fused=fused),  # 64
                ConvBlock(256, 512, 3, 1, downsample=True),  # 32
                ConvBlock(512, 512, 3, 1, downsample=True),  # 16
                ConvBlock(512, 512, 3, 1, downsample=True),  # 8
                ConvBlock(512, 512, 3, 1, downsample=True),  # 4
                ConvBlock(513, 512, 3, 1, 4, 0),
            ]
        )

        def make_from_rgb(out_channel):
            if from_rgb_activate:
                return nn.Sequential(EqualConv2d(3, out_channel, 1), nn.LeakyReLU(0.2))

            else:
                return EqualConv2d(3, out_channel, 1)

        self.from_rgb = nn.ModuleList(
            [
                make_from_rgb(16),
                make_from_rgb(32),
                make_from_rgb(64),
                make_from_rgb(128),
                make_from_rgb(256),
                make_from_rgb(512),
                make_from_rgb(512),
                make_from_rgb(512),
                make_from_rgb(512),
            ]
        )

        # self.blur = Blur()

        self.n_layer = len(self.progression)

        self.linear = EqualLinear(512, 1)

    def forward(self, input, step=0, alpha=-1):
        for i in range(step, -1, -1):
            index = self.n_layer - i - 1

            if i == step:
                out = self.from_rgb[index](input)

            if i == 0:
                out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
                mean_std = out_std.mean()
                mean_std = mean_std.expand(out.size(0), 1, 4, 4)
                out = torch.cat([out, mean_std], 1)

            out = self.progression[index](out)

            if i > 0:
                if i == step and 0 <= alpha < 1:
                    skip_rgb = F.avg_pool2d(input, 2)
                    skip_rgb = self.from_rgb[index + 1](skip_rgb)

                    out = (1 - alpha) * skip_rgb + alpha * out

        out = out.squeeze(2).squeeze(2)
        # print(input.size(), out.size(), step)
        out = self.linear(out)

        return out

5.3 Training

def train(args, dataset, generator, discriminator):
    step = int(math.log2(args.init_size)) - 2
    resolution = 4 * 2 ** step
    loader = sample_data(
        dataset, args.batch.get(resolution, args.batch_default), resolution
    )
    data_loader = iter(loader)

    adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
    adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

    pbar = tqdm(range(3_000_000))

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    disc_loss_val = 0
    gen_loss_val = 0
    grad_loss_val = 0

    alpha = 0
    used_sample = 0

    max_step = int(math.log2(args.max_size)) - 2
    final_progress = False

    for i in pbar:
        discriminator.zero_grad()

        alpha = min(1, 1 / args.phase * (used_sample + 1))

        if (resolution == args.init_size and args.ckpt is None) or final_progress:
            alpha = 1

        if used_sample > args.phase * 2:
            used_sample = 0
            step += 1

            if step > max_step:
                step = max_step
                final_progress = True
                ckpt_step = step + 1

            else:
                alpha = 0
                ckpt_step = step

            resolution = 4 * 2 ** step

            loader = sample_data(
                dataset, args.batch.get(resolution, args.batch_default), resolution
            )
            data_loader = iter(loader)

            torch.save(
                {
                    'generator': generator.module.state_dict(),
                    'discriminator': discriminator.module.state_dict(),
                    'g_optimizer': g_optimizer.state_dict(),
                    'd_optimizer': d_optimizer.state_dict(),
                    'g_running': g_running.state_dict(),
                },
                f'checkpoint/train_step-{ckpt_step}.model',
            )

            adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
            adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

        try:
            real_image = next(data_loader)

        except (OSError, StopIteration):
            data_loader = iter(loader)
            real_image = next(data_loader)

        used_sample += real_image.shape[0]

        b_size = real_image.size(0)
        real_image = real_image.cuda()

        if args.loss == 'wgan-gp':
            real_predict = discriminator(real_image, step=step, alpha=alpha)
            real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
            (-real_predict).backward()

        elif args.loss == 'r1':
            real_image.requires_grad = True
            real_scores = discriminator(real_image, step=step, alpha=alpha)
            real_predict = F.softplus(-real_scores).mean()
            real_predict.backward(retain_graph=True)

            grad_real = grad(
                outputs=real_scores.sum(), inputs=real_image, create_graph=True
            )[0]
            grad_penalty = (
                grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
            ).mean()
            grad_penalty = 10 / 2 * grad_penalty
            grad_penalty.backward()
            if i%10 == 0:
                grad_loss_val = grad_penalty.item()

        if args.mixing and random.random() < 0.9:
            gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(
                4, b_size, code_size, device='cuda'
            ).chunk(4, 0)
            gen_in1 = [gen_in11.squeeze(0), gen_in12.squeeze(0)]
            gen_in2 = [gen_in21.squeeze(0), gen_in22.squeeze(0)]

        else:
            gen_in1, gen_in2 = torch.randn(2, b_size, code_size, device='cuda').chunk(
                2, 0
            )
            gen_in1 = gen_in1.squeeze(0)
            gen_in2 = gen_in2.squeeze(0)

        fake_image = generator(gen_in1, step=step, alpha=alpha)
        fake_predict = discriminator(fake_image, step=step, alpha=alpha)

        if args.loss == 'wgan-gp':
            fake_predict = fake_predict.mean()
            fake_predict.backward()

            eps = torch.rand(b_size, 1, 1, 1).cuda()
            x_hat = eps * real_image.data + (1 - eps) * fake_image.data
            x_hat.requires_grad = True
            hat_predict = discriminator(x_hat, step=step, alpha=alpha)
            grad_x_hat = grad(
                outputs=hat_predict.sum(), inputs=x_hat, create_graph=True
            )[0]
            grad_penalty = (
                (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2
            ).mean()
            grad_penalty = 10 * grad_penalty
            grad_penalty.backward()
            if i%10 == 0:
                grad_loss_val = grad_penalty.item()
                disc_loss_val = (-real_predict + fake_predict).item()

        elif args.loss == 'r1':
            fake_predict = F.softplus(fake_predict).mean()
            fake_predict.backward()
            if i%10 == 0:
                disc_loss_val = (real_predict + fake_predict).item()

        d_optimizer.step()

        if (i + 1) % n_critic == 0:
            generator.zero_grad()

            requires_grad(generator, True)
            requires_grad(discriminator, False)

            fake_image = generator(gen_in2, step=step, alpha=alpha)

            predict = discriminator(fake_image, step=step, alpha=alpha)

            if args.loss == 'wgan-gp':
                loss = -predict.mean()

            elif args.loss == 'r1':
                loss = F.softplus(-predict).mean()

            if i%10 == 0:
                gen_loss_val = loss.item()

            loss.backward()
            g_optimizer.step()
            accumulate(g_running, generator.module)

            requires_grad(generator, False)
            requires_grad(discriminator, True)

        if (i + 1) % 100 == 0:
            images = []

            gen_i, gen_j = args.gen_sample.get(resolution, (10, 5))

            with torch.no_grad():
                for _ in range(gen_i):
                    images.append(
                        g_running(
                            torch.randn(gen_j, code_size).cuda(), step=step, alpha=alpha
                        ).data.cpu()
                    )

            utils.save_image(
                torch.cat(images, 0),
                f'sample/{str(i + 1).zfill(6)}.png',
                nrow=gen_i,
                normalize=True,
                range=(-1, 1),
            )

        if (i + 1) % 10000 == 0:
            torch.save(
                g_running.state_dict(), f'checkpoint/{str(i + 1).zfill(6)}.model'
            )

        state_msg = (
            f'Size: {4 * 2 ** step}; G: {gen_loss_val:.3f}; D: {disc_loss_val:.3f};'
            f' Grad: {grad_loss_val:.3f}; Alpha: {alpha:.5f}'
        )

        pbar.set_description(state_msg)


if __name__ == '__main__':
    code_size = 512
    batch_size = 16
    n_critic = 1

    parser = argparse.ArgumentParser(description='Progressive Growing of GANs')

    parser.add_argument('path', type=str, help='path of specified dataset')
    parser.add_argument(
        '--phase',
        type=int,
        default=600_000,
        help='number of samples used for each training phases',
    )
    parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
    parser.add_argument('--sched', action='store_true', help='use lr scheduling')
    parser.add_argument('--init_size', default=8, type=int, help='initial image size')
    parser.add_argument('--max_size', default=1024, type=int, help='max image size')
    parser.add_argument(
        '--ckpt', default=None, type=str, help='load from previous checkpoints'
    )
    parser.add_argument(
        '--no_from_rgb_activate',
        action='store_true',
        help='use activate in from_rgb (original implementation)',
    )
    parser.add_argument(
        '--mixing', action='store_true', help='use mixing regularization'
    )
    parser.add_argument(
        '--loss',
        type=str,
        default='wgan-gp',
        choices=['wgan-gp', 'r1'],
        help='class of gan loss',
    )

    args = parser.parse_args()

    generator = nn.DataParallel(StyledGenerator(code_size)).cuda()
    discriminator = nn.DataParallel(
        Discriminator(from_rgb_activate=not args.no_from_rgb_activate)
    ).cuda()
    g_running = StyledGenerator(code_size).cuda()
    g_running.train(False)

    g_optimizer = optim.Adam(
        generator.module.generator.parameters(), lr=args.lr, betas=(0.0, 0.99)
    )
    g_optimizer.add_param_group(
        {
            'params': generator.module.style.parameters(),
            'lr': args.lr * 0.01,
            'mult': 0.01,
        }
    )
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99))

    accumulate(g_running, generator.module, 0)

    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)

        generator.module.load_state_dict(ckpt['generator'])
        discriminator.module.load_state_dict(ckpt['discriminator'])
        g_running.load_state_dict(ckpt['g_running'])
        g_optimizer.load_state_dict(ckpt['g_optimizer'])
        d_optimizer.load_state_dict(ckpt['d_optimizer'])

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    dataset = MultiResolutionDataset(args.path, transform)

    if args.sched:
        args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        args.batch = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32}

    else:
        args.lr = {}
        args.batch = {}

    args.gen_sample = {512: (8, 4), 1024: (4, 2)}

    args.batch_default = 32

    train(args, dataset, generator, discriminator)

6. Result

6.1 Quality of generated images

CELEBA-HG와 이 논문을 통해 Nvidia가 새롭게 공개한 데이터셋 FFHQ을 사용하여 기존의 baseline과 비교를 했다.

baseline인 PGGAN 모델과 비교했을 때 stylegan이 더 좋은 성능을 냄을 확인할 수 있다.

6.2 Prior art

  • discriminator를 향상시키려던 연구들
    • multiple discriminators
    • multi-resolution discrimination
    • self-attention
  • generator를 향상시키려는 연구들
    • the exact distribution in the input latent space [5] or
    • shaping the input latent space via Gaussian mixture models
    • clustering
    • encouraging convexity

7. Conclusion

StyleGAN은 high-level attributes와 stochastic effects를 잘 나눠서 학습시키고, intermediate latent space를 linear하게 만듦으로써 고해상도의 이미지를 생성할 수 있도록 하였다.

8. Opinion

✍🏻 그동안은 잘 생성하지 못했던 고해상도의 이미지를 생성하는 방법을 제시한 놀라운 모델인 것 같다.(style들의 종류를 attribute과 stochastic detail로 나눈 후 각각의 생성방식을 전개한 것도 흥미롭다.) 새로운 dataset을 공개하고, architecture 뿐만 아니라 disentanglement까지 측정하는 방법에 대해 제안을 한 것을 보면 nvidia에서 이를 갈고 낸 논문이라고 생각이 든다.

PGGAN을 발전시킨 방식도 굉장히 흥미로웠다. 이 논문의 후속 연구인 stylegan2도 궁금하다!


Reference

댓글남기기