[Paper Review] StyleGAN2-ADA #02: Training Generative Adversarial Networks with Limited Data 코드 리뷰

업데이트:

View On GitHub

✍🏻 최근에는 이미 있는 모델(pretrained model)을 잘 fine tuning하여 의미있는 결과를 내는 연구가 대세이다. (FreezeD, GANSpace, StyleCLIP 등등)

StyleGAN2-ADA도 이러한 흐름에서 나온 연구로, loss function이나 network의 architecture를 건들이지 않고 이미 학습이 된 GAN을 finetuning하거나 training과정에서 scratch를 내는 식으로 학습을 한다. 또한, 적은 데이터로 학습을 해도 discriminator가 overfitting 되지 않도록 Adaptive Discriminator Augmentation Mechanism을 제안하였다.

⭐ 이번 포스팅에서는 StyleGAN2-ADA의 Official Code를 살펴본다.


Generation Images

generator.py

  • Generate curated MetFaces images without truncation (Fig.10 left)

      $ python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \
          --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
    

    truncation 없이 이미지를 생성하였다. truncation metric은 immediate latent space $W$에서 중요한 부분만을 catch하는 metric이다. 이 metric을 사용하면 다양성은 떨어지지만, quality가 높은 이미지를 생성할 수 있다.

    seed 85 seed 297 seed 849
  • Generate uncurated MetFaces images with truncation (Fig.12 upper left)

      $ python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \
          --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
    
  • Generate class conditional CIFAR-10 images (Fig.17 left, Car)

    • stylegan2-ada에서는 conditional generation도 가능하다. (stylegan2는 x)
    • class label을 또다른 mapping network에 넣어 embedding 한 후에 그 embedding을 $w$와 concate하는 방식으로 진행된다.
    $ python generate.py --outdir=out --seeds=0-35 --class=1 \
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
    

  • Generator.py Code
    • pickle file은 세가지의 network, G, D, G_ema를 포함하고 있다.
      • G_ema는 Generator weight의 평균을 exponential moving한 것으로, EMA(Exponential Moving Average)의 방식을 활용하면 GAN을 안정적으로 학습할 수 있다.

style_mixing.py

  • Style mixing example

      $ python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \
          --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
    
  • Result

  • style_mixing.py Code

Projecting images to latent space

Projector.py

projector.py 함수를 사용하여 원하는 이미지의 latent vector를 구할 수 있다.

$ python projector.py --outdir=out --target=~/mytargetimg.png \
    --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
  • projector.py Code

위에서 생성한 latent vector를 바탕으로 이미지를 생성할 수도 있다.

$ python generate.py --outdir=out --projected-w=out/projected_w.npz \
    --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
target image Projection image

Training new networks

# Train with custom dataset using 1 GPU.
# dry-run : 중간에 error가 없는지 확인하기 위한 용도
$ python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1 --dry-run
$ python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1

# Train class-conditional CIFAR-10 using 2 GPUs.
$ python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
    --gpus=2 --cfg=cifar --cond=1

# Reproduce original StyleGAN2 config F.
$ python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
    --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug

# Transfer learn MetFaces from FFHQ using 4 GPUs.
$ python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
    --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10
Base configs (--cfg):
    auto       Automatically select reasonable defaults based on resolution
                and GPU count. Good starting point for new datasets.
    stylegan2  Reproduce results for StyleGAN2 config F at 1024x1024.
    paper256   Reproduce results for FFHQ and LSUN Cat at 256x256.
    paper512   Reproduce results for BreCaHAD and AFHQ at 512x512.
    paper1024  Reproduce results for MetFaces at 1024x1024.
    cifar      Reproduce results for CIFAR-10 at 32x32.


Transfer learning source networks (--resume):
    ffhq256        FFHQ trained at 256x256 resolution.
    ffhq512        FFHQ trained at 512x512 resolution.
    ffhq1024       FFHQ trained at 1024x1024 resolution.
    celebahq256    CelebA-HQ trained at 256x256 resolution.
    lsundog256     LSUN Dog trained at 256x256 resolution.
    <PATH or URL>  Custom network pickle.

train.py

  • train.pymain function
  • setup_training_loop_kwargs : train.py에서 여러 parameter들을 정의하는 파트

main 함수에서는 subprocess_fn함수를 호출한다. subprocess_fn 함수에서 gpu의 개수에 맞게 세팅을 조정한 후, training_loop함수를 호출하여 학습을 본격적으로 시작 !

training_loop.py

network.py

댓글남기기