[Paper Review] StyleGAN2-ADA #02: Training Generative Adversarial Networks with Limited Data 코드 리뷰
업데이트:
✍🏻 최근에는 이미 있는 모델(
pretrained model
)을 잘fine tuning
하여 의미있는 결과를 내는 연구가 대세이다. (FreezeD, GANSpace, StyleCLIP 등등)StyleGAN2-ADA도 이러한 흐름에서 나온 연구로, loss function이나 network의 architecture를 건들이지 않고 이미 학습이 된 GAN을
finetuning
하거나 training과정에서scratch
를 내는 식으로 학습을 한다. 또한, 적은 데이터로 학습을 해도 discriminator가 overfitting 되지 않도록Adaptive Discriminator Augmentation Mechanism
을 제안하였다.⭐ 이번 포스팅에서는 StyleGAN2-ADA의 Official Code를 살펴본다.
-
Paper : Training Generative Adversarial Networks with Limited Data (NeurlPS 2020 /Tero Karras, Miika Aittala, Janne Hellsten, Samuli Laine, Jaakko Lehtinen, Timo Aila)
- 😎 StyleGAN Posting
- GAN-Zoos! (GAN 포스팅 모음집)
Generation Images
generator.py
-
Generate curated MetFaces images without truncation (Fig.10 left)
$ python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
truncation 없이 이미지를 생성하였다. truncation metric은 immediate latent space $W$에서 중요한 부분만을 catch하는 metric이다. 이 metric을 사용하면 다양성은 떨어지지만, quality가 높은 이미지를 생성할 수 있다.
seed 85 seed 297 seed 849 -
Generate uncurated MetFaces images with truncation (Fig.12 upper left)
$ python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-
Generate class conditional CIFAR-10 images (Fig.17 left, Car)
- stylegan2-ada에서는 conditional generation도 가능하다. (stylegan2는 x)
- class label을 또다른 mapping network에 넣어 embedding 한 후에 그 embedding을 $w$와 concate하는 방식으로 진행된다.
$ python generate.py --outdir=out --seeds=0-35 --class=1 \ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
Generator.py
Code- pickle file은 세가지의 network,
G
,D
,G_ema
를 포함하고 있다.G_ema
는 Generator weight의 평균을 exponential moving한 것으로, EMA(Exponential Moving Average)의 방식을 활용하면 GAN을 안정적으로 학습할 수 있다.
- pickle file은 세가지의 network,
style_mixing.py
-
Style mixing example
$ python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-
Result
style_mixing.py
Code
Projecting images to latent space
Projector.py
projector.py
함수를 사용하여 원하는 이미지의 latent vector를 구할 수 있다.
$ python projector.py --outdir=out --target=~/mytargetimg.png \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
projector.py
Code
위에서 생성한 latent vector를 바탕으로 이미지를 생성할 수도 있다.
$ python generate.py --outdir=out --projected-w=out/projected_w.npz \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
target image | Projection image |
---|---|
Training new networks
# Train with custom dataset using 1 GPU.
# dry-run : 중간에 error가 없는지 확인하기 위한 용도
$ python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1 --dry-run
$ python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1
# Train class-conditional CIFAR-10 using 2 GPUs.
$ python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
--gpus=2 --cfg=cifar --cond=1
# Reproduce original StyleGAN2 config F.
$ python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
--gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug
# Transfer learn MetFaces from FFHQ using 4 GPUs.
$ python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
--gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10
Base configs (--cfg):
auto Automatically select reasonable defaults based on resolution
and GPU count. Good starting point for new datasets.
stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024.
paper256 Reproduce results for FFHQ and LSUN Cat at 256x256.
paper512 Reproduce results for BreCaHAD and AFHQ at 512x512.
paper1024 Reproduce results for MetFaces at 1024x1024.
cifar Reproduce results for CIFAR-10 at 32x32.
Transfer learning source networks (--resume):
ffhq256 FFHQ trained at 256x256 resolution.
ffhq512 FFHQ trained at 512x512 resolution.
ffhq1024 FFHQ trained at 1024x1024 resolution.
celebahq256 CelebA-HQ trained at 256x256 resolution.
lsundog256 LSUN Dog trained at 256x256 resolution.
<PATH or URL> Custom network pickle.
train.py
train.py
의main
function
setup_training_loop_kwargs
:train.py
에서 여러 parameter들을 정의하는 파트
main 함수에서는 subprocess_fn
함수를 호출한다. subprocess_fn
함수에서 gpu의 개수에 맞게 세팅을 조정한 후, training_loop
함수를 호출하여 학습을 본격적으로 시작 !
댓글남기기