[Paper Review] Generative Adversarial Networks(GAN) 논문 설명 및 pytorch 코드 구현
업데이트:
이번 포스팅에서는 대표적인 생성모델 GAN의 original 논문인 Generative Adversarial Networks(GAN)에 대해 알아본다.
GAN은 training data의 distribution을 따라하려는 generator model(G)과 이미지가 진짜 인지 가짜인지 판별하려는 discriminative model(D)의 minimax game이다. G와 D는 multilayer perceptron으로 구성이 되어 있으며, backpropgagtion으로 학습이 된다.
1. Introduction
GAN - minimax two player game
- Generative Adversarial Networks (GAN)은 adversarial process를 적용한 생성모델이다. 주어진 입력 데이터와 유사한 데이터를 생성하는 것을 목표로 하며, Generator model과 Discriminative model이 경쟁하며 서로의 성능을 높여가는 과정으로 학습이 진행된다.
Generator(G)
모델은 위조 지폐를 만드는 사람들과 유사하며,Discriminator (D)
모델은 위조지폐를 발견하는 경찰과 유사하다. 생성자G
는 최대한 기존의 데이터(실제 지폐)와 유사한 지폐를 만들려고 노력하고, 판별자D
는 데이터 샘플이 모델 분포에서 왔는지(위조지폐), 실제 데이터 분포에서 왔는지(실제 지폐) 판별한다.- G는 가짜 Data를 잘 만들어서 D가 진짜와 가짜를 구별 못하게 하는 것이 목적이고, D는 진짜 Data와 G가 만들어낸 가짜 Data를 잘 구별하는 것이 목적
- 이렇게 D와 G가 서로 경쟁적으로(Adversarially) 학습을 하다보면, 실제로 서로에게 학습의 방향성을 제시해주게 되어
Unsupervised Learning
이 가능해짐
2. Adversarial Nets
Generator
Goal
: data $x$에 대한 generator의 distribution $p_g$ 학습-
이를 위해 (1) 우선 Gaussian과 같은 정규 분포에서 random noise
\[\boldsymbol{z} \sim p_{\boldsymbol{z}}(z)\]z
를 추출한 후 (2) Neural Network $G(z; \theta_g)$ 를 거쳐 Fake image인 $G(z)$를 생성 - $x = G(z)$ 는 $P_g(x)$라는 확률 분포에서 추출된 $x$라고 생각해도 무방함
- z vector가 존재하던 공간 :
latent space
Discriminator
- G가 fake image를 생성하고 나면, D는 Fake Image와 Real Image를 input으로 받은 후 Neural Network $D(x;\theta_d)$를 거쳐 0과 1사이의 값을 출력
- D가 가짜 이미지라고 판별을 하면 0과 가까운 숫자를 출력하고, 진짜 이미지라고 판별을 하면 1과 가까운 숫자를 출력
Adversarial nets
G
는 그럴싸한 생성 이미지를 만들어서D
를 속이려고 하고,D
는 진짜 이미지를 찾아내도록 적대적인 network를 구성한다.
G와 D의 minmax game
\[\min _{G} \max _{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))\]✔ GAN 논문에서는 실제 구현한 알고리즘과 이론 상의 괴리가 있다. 이론적으로는 $\log (1-D(G(\boldsymbol{z}))$를 최소화하는 방향으로 증명을 하지만, 실제 알고리즘에서는 $\log (D(G(\boldsymbol{z}))$를 최대화하는 방향으로 학습을 시킨다.
3. Theoretical Results
- Discrimitive distribution (blue)
- Data generating distribution (black)
- Generative distribution (green)
- (그림) latent space에서 sampling된
z
가 생성 모델을 거쳐 $x = G(z)$로 mapping되는 과정을 그린 그림
만약에 D가 real/fake image를 잘 판별하도록 학습을 한다면, optimal D는 $D^*(x) = \frac{p_{\text {data }}(\boldsymbol{x})}{p_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}$로 근사되고,
G가 실제 data distribution을 잘 따라가도록 학습이 된다면, optimal G는 $p_g = p_{data}$가 된다.
⭐ D와 G가 모두 optimal 하다면 $D(\boldsymbol{x})=\frac{1}{2}$
훈련 알고리즘
훈련의 내부 loop에서 D를 학습하도록 최적화하는 것은 overfitting이 될 수 있기 때문에 이렇게 하지는 않고, 대신 D를 k번, G를 1번 optimize하는 식으로 번갈아 학습을 시킨다. 이런 과정으로 학습을 시키면 G가 충분히 느리게 변화하기 때문에 D는 최적의 solution을 가질 수 있게 된다. (최근의 GAN 연구에서는 k를 1로 두기도 함)
3.1 Global Optimality of $p_g = p_{data}$
\[D_{G}^{*}(\boldsymbol{x})=\frac{p_{\text {data }}(\boldsymbol{x})}{p_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}\]Proposition 1. For $G$ fixed, the optimal discriminator $D$ is
\[\begin{aligned} C(G) &=\max _{D} V(G, D) \\ &=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}}\left[\log D_{G}^{*}(\boldsymbol{x})\right]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}\left[\log \left(1-D_{G}^{*}(G(\boldsymbol{z}))\right)\right] \\ &=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}}\left[\log D_{G}^{*}(\boldsymbol{x})\right]+\mathbb{E}_{\boldsymbol{x} \sim p_{g}}\left[\log \left(1-D_{G}^{*}(\boldsymbol{x})\right)\right] \\ &=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}}\left[\log \frac{p_{\text {data }}(\boldsymbol{x})}{P_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}\right]+\mathbb{E}_{\boldsymbol{x} \sim p_{g}}\left[\log \frac{p_{g}(\boldsymbol{x})}{p_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}\right]\\ &=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}}\left[\log \frac{ p_{\text {data }}(\boldsymbol{x})} {\frac{P_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}{2}}\right]+\mathbb{E}_{\boldsymbol{x} \sim p_{g}}\left[\log \frac{p_{g}(\boldsymbol{x})}{\frac{p_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}{2}}\right]-\log (4)\\ &=D_{K L}\left(p_{\text {data }} \| \frac{p_{\text {data }}+p_{g}}{2}\right)+D_{K L}\left(p_{g} \| \frac{p_{\text {data }}+p_{g}}{2}\right)-\log (4)\\ &=2 \cdot J S D\left(p_{\text {data }} \| p_{g}\right)-\log (4) \end{aligned}\]Theorem 1. The global minimum of the virtual training criterion $C(G)$ is achieved if and only if $p_g = p_{data}$. At that point, $C(G)$ achieves the value $− log 4$.
Jensen-Shannon Divergence는 늘 0 이상이다. 따라서 $C(G)$ 값이 최소가 될 때는 $p_g = p_{data}$ !
⭐ (정리)
minimax game
의 global optimum에 도달하면, D는 $D_G(x)=\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}$, G는 $p_{data}(x)=p_{g}(x)$
4. Experiments
최근 연구들에서 generator와 discriminator의 모델들은 이 논문의 G와 D model을 따르지 않는다. (모델이 보다 좋은 quality의 이미지를 생성할 수 있도록 architecture를 발전시킴. ex. BigGAN, StyleGAN)
이 논문은 G와 D의 적대적 신경망을 통해 성능을 발전시켰다는 점에서 의의가 있는 논문이기 때문에 G와 D의 network에 대해서는 간단히만 서술하였다.
- Dataset : MNIST, CIFAR-10, TFD
- Generator net : linear activation과 sigmoid activation
- generator에서는 dropout을 사용하지 않았고, intermediate layer에 noise를 주지도 않았음
- Discriminator net : max out activation, dropout
5. Advantages and disadvantages
5.1 Advantages
- Markov chain을 피할 수 있고, backprop을 통해 gradient를 계산할 수 있음
- 학습 중에 추론이 필요 없음
- 모델에 다양한 function들이 추가될 수 있음
5.2 Disdvantages
- $p_g$의 명시적인 표현이 없음
- D와 G는 잘 동기화 되어야함
- 만약, G가 D를 업데이트 하지 않고 너무 많이 훈련을 한다면
Helvetica senario
(G가 z의 값을 너무나도 많이 축소해서 $p_{data}$의 다양성이 사라지는 것)에 빠질 수 있음)
- 만약, G가 D를 업데이트 하지 않고 너무 많이 훈련을 한다면
6. Conclusions and Future work
- A conditional generative model $p(x / c)$ can be obtained by adding $c$ as input to both G and D.
- Learned approximate inference can be performed by training an auxiliary network to predict $z$ given $x$. This is similar to the inference net trained by the wake-sleep algorithm but with the advantage that the inference net may be trained for a fixed generator net after the generator net has finished training.
- One can approximately model all conditionals $p(x_S / x_{S/})$ where S is a subset of the indices of x by training a family of conditional models that share parameters. Essentially, one can use adversarial nets to implement a stochastic extension of the deterministic MP-DB.
- Semi-supervised learning: features from the discriminator or inference net could improve performance of classifiers when limited labeled data is available.
- Efficiency improvements: training could be accelerated greatly by divising better methods for coordinating G and D or determining better distributions to sample z from during training.
7. Code Practice
import torch
import torch.nn as nn
import numpy as np
# for MNIST data
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import matplotlib.pyplot as plt
Preparing data
Loading MNIST Data
- 이번 예제에서는 실제 MNIST training images를 활용하여 MNIST 숫자를 생성하는 GAN model을 만들 예정입니다.
- How to Build a Streaming DataLoader with PyTorch
# download the MINST data
batch_size = 64
transforms_train = transforms.Compose([
transforms.Resize(28),
transforms.ToTensor(), # data를 pytorch의 tensor형식으로 바꿉니다
transforms.Normalize([0.5], [0.5]) # 픽셀값을 0 ~ 1에서 -1 ~ 1 로 바꿔줍니다.
])
train_dataset = datasets.MNIST(root="./dataset", train=True, download=True, transform=transforms_train)
# data를 batch size만큼만 가져오는 dataloader를 만듭니다.
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers=4)
- 하나의 batch에 들어있는 mnist data를 출력해보았습니다.
images, labels = next(iter(dataloader))
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1,2,0)
std = [0.5,0.5,0.5]
mean = [0.5,0.5,0.5]
img = img*std+mean
print([labels[i] for i in range(64)])
plt.imshow(img)
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:477: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
cpuset_checked))
[tensor(5), tensor(6), tensor(2), tensor(9), tensor(0), tensor(9), tensor(4), tensor(6), tensor(9), tensor(6), tensor(8), tensor(9), tensor(1), tensor(7), tensor(3), tensor(9), tensor(2), tensor(8), tensor(9), tensor(3), tensor(4), tensor(6), tensor(2), tensor(8), tensor(8), tensor(4), tensor(8), tensor(4), tensor(6), tensor(2), tensor(3), tensor(0), tensor(2), tensor(3), tensor(8), tensor(2), tensor(4), tensor(9), tensor(2), tensor(6), tensor(7), tensor(0), tensor(3), tensor(1), tensor(2), tensor(5), tensor(0), tensor(5), tensor(3), tensor(2), tensor(0), tensor(4), tensor(6), tensor(6), tensor(8), tensor(2), tensor(7), tensor(5), tensor(4), tensor(9), tensor(4), tensor(5), tensor(0), tensor(7)]
<matplotlib.image.AxesImage at 0x7fa5802841d0>
# image
channels = 1
img_size = 28
img_shape = (channels, img_size, img_size)
Build Model
Generator
- 생성자는 random vector
z
를 입력받아 가짜 이미지를 출력하는 함수입니다. 여기서 z는 정규분포(Normal Distribution)에서 무작위로 추출한 값으로, z vector가 존재하는 공간을 잠재공간(latent space)라고 부릅니다.- 이 튜토리얼에서는 잠재공간의 크기를 100으로 뒀으며, 잠재공간의 크기에는 제한이 없으나 나타내려고 하는 대상의 정보를 충분히 담을 수 있을 만큼 커야합니다.
-
즉, 생성자는 단순한 분포에서 사람 얼굴 이미지와 같은 복잡한 분포로 mapping하는 함수라고 볼 수 있습니다.
- 생성자에 충분히 많은 매개변수를 확보하기 위해 여러개의 layer를 쌓아서 생성자를 만들었습니다.
- 참고
# dimensionality of the latent space
# latent vector를 추출하기 위한 noise 분포의 dimension (정규분포를 따름)
latent_dim = 100
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
def block(input_dim, output_dim, normalize=True):
layers = [nn.Linear(input_dim, output_dim)]
if normalize:
layers.append(nn.BatchNorm1d(output_dim, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
# generater의 model은 여러개의 block을 쌓아서 만들어짐
self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)
def forward(self, z):
# z : input noise vector
img = self.model(z)
img = img.view(img.size(0), *img_shape)
return img
Discriminator
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid(),
)
# 이미지에 대한 판별 결과를 반환
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
Loss Function & Optimizer
- 손실 함수로는 Binary Cross Entropy를, 최적화 함수로는 Adam을 사용합니다.
''' Hyper parameter '''
# learning rate
lr = 0.0002
# decay of first order momentum of gradient
b1 = 0.5
b2 = 0.999
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
# Loss function
adversarial_loss = nn.BCELoss()
# Adam Optimizer
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
# GPU
cuda = True if torch.cuda.is_available() else False
if cuda :
generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()
Training
- GAN model에서는 근사적인 추론이나 Markov chains을 사용하지 않고, back-propagation만을 이용하여 gradient를 업데이트합니다.
import time
# number of epochs of training
n_epochs = 200
# interval between image samples
sample_interval = 2000
start_time = time.time()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
for epoch in range(n_epochs):
for i, (imgs, _) in enumerate(dataloader):
# Adversarial ground truths
## 실제 이미지는 1로, 가짜 이미지는 0으로 label됩니다.
real = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)
# Configure input
real_imgs = Variable(imgs.type(Tensor))
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
# Generate a batch of images
## random sampling한 값인 z를 생성자에 넣어 이미지를 생성합니다.
generated_imgs = generator(z)
# Loss measures generator's ability to fool the discriminator
## 생성된 이미지를 discriminator가 판별하게 한 후, loss값을 계산합니다.
g_loss = adversarial_loss(discriminator(generated_imgs), real)
# 생성자(generator) 업데이트
g_loss.backward()
optimizer_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Measure discriminator's ability to classify real from generated samples
## 실제 이미지는 real(1)로, 가짜 이미지는 fake(0)으로 판별하도록 계산합니다.
real_loss = adversarial_loss(discriminator(real_imgs), real)
fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
# 판별자(discriminator) 업데이트
d_loss.backward()
optimizer_D.step()
done = epoch * len(dataloader) + i
if done % sample_interval == 0:
# 생성된 이미지 중에서 25개만 선택하여 5 X 5 격자 이미지에 출력
save_image(generated_imgs.data[:25], f"data{epoch}.png", nrow=5, normalize=True)
# 하나의 epoch이 끝날 때마다 로그(log) 출력
print(f"[Epoch {epoch}/{n_epochs}] [D loss: {d_loss.item():.6f}] [G loss: {g_loss.item():.6f}] [Elapsed time: {time.time() - start_time:.2f}s]")
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:477: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
cpuset_checked))
[Epoch 0/200] [D loss: 0.458743] [G loss: 1.942712] [Elapsed time: 14.48s]
[Epoch 1/200] [D loss: 0.390978] [G loss: 1.275545] [Elapsed time: 28.62s]
[Epoch 2/200] [D loss: 0.265740] [G loss: 1.080804] [Elapsed time: 43.18s]
[Epoch 3/200] [D loss: 0.891665] [G loss: 4.901868] [Elapsed time: 57.52s]
[Epoch 4/200] [D loss: 0.476425] [G loss: 0.863494] [Elapsed time: 71.60s]
[Epoch 5/200] [D loss: 0.454903] [G loss: 0.576118] [Elapsed time: 85.74s]
[Epoch 6/200] [D loss: 0.133300] [G loss: 2.433179] [Elapsed time: 100.24s]
[Epoch 7/200] [D loss: 0.094781] [G loss: 2.981544] [Elapsed time: 114.45s]
[Epoch 8/200] [D loss: 0.181251] [G loss: 3.181638] [Elapsed time: 129.06s]
[Epoch 9/200] [D loss: 0.268062] [G loss: 1.650167] [Elapsed time: 143.39s]
[Epoch 10/200] [D loss: 0.148337] [G loss: 1.861697] [Elapsed time: 157.57s]
[Epoch 11/200] [D loss: 0.336009] [G loss: 6.562683] [Elapsed time: 171.61s]
[Epoch 12/200] [D loss: 0.116051] [G loss: 3.702429] [Elapsed time: 186.11s]
[Epoch 13/200] [D loss: 0.330651] [G loss: 2.613504] [Elapsed time: 200.35s]
[Epoch 14/200] [D loss: 0.194057] [G loss: 2.378803] [Elapsed time: 214.38s]
[Epoch 15/200] [D loss: 0.627167] [G loss: 6.705003] [Elapsed time: 229.09s]
[Epoch 16/200] [D loss: 0.317059] [G loss: 1.656479] [Elapsed time: 243.51s]
[Epoch 17/200] [D loss: 0.201614] [G loss: 1.418605] [Elapsed time: 257.77s]
[Epoch 18/200] [D loss: 0.116628] [G loss: 3.220607] [Elapsed time: 271.63s]
[Epoch 19/200] [D loss: 0.227662] [G loss: 1.625189] [Elapsed time: 286.03s]
[Epoch 20/200] [D loss: 0.215400] [G loss: 2.033691] [Elapsed time: 300.15s]
[Epoch 21/200] [D loss: 0.107763] [G loss: 3.114661] [Elapsed time: 315.20s]
[Epoch 22/200] [D loss: 0.042492] [G loss: 3.386240] [Elapsed time: 329.06s]
[Epoch 23/200] [D loss: 0.058856] [G loss: 5.231410] [Elapsed time: 343.38s]
[Epoch 24/200] [D loss: 0.044925] [G loss: 4.694878] [Elapsed time: 357.78s]
[Epoch 25/200] [D loss: 0.147997] [G loss: 2.623037] [Elapsed time: 371.85s]
[Epoch 26/200] [D loss: 0.097325] [G loss: 2.946077] [Elapsed time: 386.36s]
[Epoch 27/200] [D loss: 0.205090] [G loss: 4.296311] [Elapsed time: 400.57s]
[Epoch 28/200] [D loss: 0.049005] [G loss: 5.919545] [Elapsed time: 414.36s]
[Epoch 29/200] [D loss: 3.039774] [G loss: 13.492342] [Elapsed time: 428.89s]
[Epoch 30/200] [D loss: 0.067113] [G loss: 2.853446] [Elapsed time: 443.79s]
[Epoch 31/200] [D loss: 0.191761] [G loss: 3.105765] [Elapsed time: 457.95s]
[Epoch 32/200] [D loss: 0.141423] [G loss: 2.469535] [Elapsed time: 471.96s]
[Epoch 33/200] [D loss: 0.097303] [G loss: 2.226081] [Elapsed time: 485.90s]
[Epoch 34/200] [D loss: 0.063614] [G loss: 3.540272] [Elapsed time: 500.08s]
[Epoch 35/200] [D loss: 0.158653] [G loss: 1.822363] [Elapsed time: 514.32s]
[Epoch 36/200] [D loss: 0.075912] [G loss: 2.991791] [Elapsed time: 528.57s]
[Epoch 37/200] [D loss: 0.030585] [G loss: 4.666013] [Elapsed time: 542.33s]
[Epoch 38/200] [D loss: 0.211738] [G loss: 2.987039] [Elapsed time: 556.73s]
[Epoch 39/200] [D loss: 0.090581] [G loss: 4.426782] [Elapsed time: 570.63s]
[Epoch 40/200] [D loss: 0.049283] [G loss: 3.451766] [Elapsed time: 585.06s]
[Epoch 41/200] [D loss: 0.168959] [G loss: 3.365350] [Elapsed time: 599.09s]
[Epoch 42/200] [D loss: 0.055633] [G loss: 3.116106] [Elapsed time: 613.63s]
[Epoch 43/200] [D loss: 0.081555] [G loss: 4.972229] [Elapsed time: 627.88s]
[Epoch 44/200] [D loss: 0.107015] [G loss: 4.608972] [Elapsed time: 641.69s]
[Epoch 45/200] [D loss: 0.093549] [G loss: 3.470685] [Elapsed time: 656.27s]
[Epoch 46/200] [D loss: 0.040654] [G loss: 4.167437] [Elapsed time: 670.29s]
[Epoch 47/200] [D loss: 0.067234] [G loss: 4.105668] [Elapsed time: 684.37s]
[Epoch 48/200] [D loss: 0.193317] [G loss: 4.751004] [Elapsed time: 698.28s]
[Epoch 49/200] [D loss: 0.104301] [G loss: 3.192440] [Elapsed time: 712.30s]
[Epoch 50/200] [D loss: 0.138632] [G loss: 2.988059] [Elapsed time: 726.25s]
[Epoch 51/200] [D loss: 0.163484] [G loss: 4.701645] [Elapsed time: 740.59s]
[Epoch 52/200] [D loss: 0.211467] [G loss: 1.915079] [Elapsed time: 754.83s]
[Epoch 53/200] [D loss: 0.032364] [G loss: 6.920775] [Elapsed time: 769.31s]
[Epoch 54/200] [D loss: 0.135302] [G loss: 3.067870] [Elapsed time: 783.61s]
[Epoch 55/200] [D loss: 0.130398] [G loss: 5.209129] [Elapsed time: 797.33s]
[Epoch 56/200] [D loss: 0.235790] [G loss: 5.278724] [Elapsed time: 811.48s]
[Epoch 57/200] [D loss: 0.045859] [G loss: 3.403255] [Elapsed time: 825.54s]
[Epoch 58/200] [D loss: 0.063515] [G loss: 4.804600] [Elapsed time: 839.19s]
[Epoch 59/200] [D loss: 0.033467] [G loss: 3.569349] [Elapsed time: 853.37s]
[Epoch 60/200] [D loss: 0.144997] [G loss: 3.207227] [Elapsed time: 867.60s]
[Epoch 61/200] [D loss: 0.036432] [G loss: 3.991818] [Elapsed time: 881.62s]
[Epoch 62/200] [D loss: 0.072093] [G loss: 3.450863] [Elapsed time: 895.77s]
[Epoch 63/200] [D loss: 0.042299] [G loss: 4.151020] [Elapsed time: 909.89s]
[Epoch 64/200] [D loss: 0.191555] [G loss: 4.481242] [Elapsed time: 923.89s]
[Epoch 65/200] [D loss: 0.046976] [G loss: 3.281461] [Elapsed time: 938.14s]
[Epoch 66/200] [D loss: 0.157537] [G loss: 4.302243] [Elapsed time: 952.55s]
[Epoch 67/200] [D loss: 0.105616] [G loss: 3.489889] [Elapsed time: 966.54s]
[Epoch 68/200] [D loss: 0.088697] [G loss: 4.909249] [Elapsed time: 980.85s]
[Epoch 69/200] [D loss: 0.084548] [G loss: 2.747079] [Elapsed time: 994.79s]
[Epoch 70/200] [D loss: 0.105164] [G loss: 4.200035] [Elapsed time: 1008.93s]
[Epoch 71/200] [D loss: 0.169672] [G loss: 2.961955] [Elapsed time: 1022.98s]
[Epoch 72/200] [D loss: 0.017534] [G loss: 4.261449] [Elapsed time: 1037.11s]
[Epoch 73/200] [D loss: 0.205933] [G loss: 1.466237] [Elapsed time: 1050.83s]
[Epoch 74/200] [D loss: 0.181587] [G loss: 2.550884] [Elapsed time: 1065.09s]
[Epoch 75/200] [D loss: 0.189081] [G loss: 1.525171] [Elapsed time: 1079.47s]
[Epoch 76/200] [D loss: 0.170224] [G loss: 4.031328] [Elapsed time: 1093.65s]
[Epoch 77/200] [D loss: 0.210215] [G loss: 1.816445] [Elapsed time: 1107.68s]
[Epoch 78/200] [D loss: 0.263621] [G loss: 1.674155] [Elapsed time: 1122.18s]
[Epoch 79/200] [D loss: 0.085889] [G loss: 3.241511] [Elapsed time: 1136.09s]
[Epoch 80/200] [D loss: 0.301190] [G loss: 2.394245] [Elapsed time: 1150.36s]
[Epoch 81/200] [D loss: 0.128341] [G loss: 3.396279] [Elapsed time: 1164.71s]
[Epoch 82/200] [D loss: 0.191707] [G loss: 2.869485] [Elapsed time: 1178.81s]
[Epoch 83/200] [D loss: 0.123201] [G loss: 2.764148] [Elapsed time: 1192.81s]
[Epoch 84/200] [D loss: 0.044546] [G loss: 6.066619] [Elapsed time: 1207.00s]
[Epoch 85/200] [D loss: 0.224089] [G loss: 1.769315] [Elapsed time: 1221.24s]
[Epoch 86/200] [D loss: 0.125998] [G loss: 2.424690] [Elapsed time: 1235.79s]
[Epoch 87/200] [D loss: 0.257117] [G loss: 1.532984] [Elapsed time: 1250.00s]
[Epoch 88/200] [D loss: 0.090517] [G loss: 3.309340] [Elapsed time: 1264.04s]
[Epoch 89/200] [D loss: 0.171175] [G loss: 3.907306] [Elapsed time: 1278.22s]
[Epoch 90/200] [D loss: 0.168205] [G loss: 2.541603] [Elapsed time: 1292.23s]
[Epoch 91/200] [D loss: 0.415478] [G loss: 0.807221] [Elapsed time: 1306.15s]
[Epoch 92/200] [D loss: 0.222719] [G loss: 5.097611] [Elapsed time: 1320.32s]
[Epoch 93/200] [D loss: 0.065872] [G loss: 2.917621] [Elapsed time: 1334.55s]
[Epoch 94/200] [D loss: 0.200377] [G loss: 3.596398] [Elapsed time: 1349.03s]
[Epoch 95/200] [D loss: 0.413539] [G loss: 1.967691] [Elapsed time: 1363.17s]
[Epoch 96/200] [D loss: 0.229722] [G loss: 1.654135] [Elapsed time: 1377.75s]
[Epoch 97/200] [D loss: 0.187748] [G loss: 3.326653] [Elapsed time: 1391.98s]
[Epoch 98/200] [D loss: 0.250217] [G loss: 3.498106] [Elapsed time: 1406.14s]
[Epoch 99/200] [D loss: 0.071938] [G loss: 3.358513] [Elapsed time: 1420.34s]
[Epoch 100/200] [D loss: 0.315330] [G loss: 3.291533] [Elapsed time: 1433.96s]
[Epoch 101/200] [D loss: 0.274061] [G loss: 1.418579] [Elapsed time: 1448.00s]
[Epoch 102/200] [D loss: 0.106627] [G loss: 3.502706] [Elapsed time: 1462.21s]
[Epoch 103/200] [D loss: 0.086707] [G loss: 2.944184] [Elapsed time: 1476.08s]
[Epoch 104/200] [D loss: 0.286308] [G loss: 1.633663] [Elapsed time: 1490.62s]
[Epoch 105/200] [D loss: 0.383339] [G loss: 1.785617] [Elapsed time: 1504.78s]
[Epoch 106/200] [D loss: 0.153502] [G loss: 2.274015] [Elapsed time: 1518.63s]
[Epoch 107/200] [D loss: 0.071326] [G loss: 3.263655] [Elapsed time: 1532.88s]
[Epoch 108/200] [D loss: 0.066765] [G loss: 4.009148] [Elapsed time: 1546.92s]
[Epoch 109/200] [D loss: 0.174222] [G loss: 5.633485] [Elapsed time: 1560.80s]
[Epoch 110/200] [D loss: 0.048926] [G loss: 4.407144] [Elapsed time: 1574.86s]
[Epoch 111/200] [D loss: 0.145895] [G loss: 2.280872] [Elapsed time: 1589.18s]
[Epoch 112/200] [D loss: 0.053878] [G loss: 3.833207] [Elapsed time: 1602.85s]
[Epoch 113/200] [D loss: 0.089428] [G loss: 2.974336] [Elapsed time: 1616.87s]
[Epoch 114/200] [D loss: 0.060831] [G loss: 4.322665] [Elapsed time: 1630.73s]
[Epoch 115/200] [D loss: 0.116165] [G loss: 2.789042] [Elapsed time: 1644.71s]
[Epoch 116/200] [D loss: 0.079149] [G loss: 4.856327] [Elapsed time: 1659.04s]
[Epoch 117/200] [D loss: 0.150054] [G loss: 3.824501] [Elapsed time: 1673.54s]
[Epoch 118/200] [D loss: 0.223480] [G loss: 3.520049] [Elapsed time: 1687.67s]
[Epoch 119/200] [D loss: 0.112249] [G loss: 3.149460] [Elapsed time: 1702.19s]
[Epoch 120/200] [D loss: 0.042854] [G loss: 3.803446] [Elapsed time: 1716.72s]
[Epoch 121/200] [D loss: 0.149308] [G loss: 3.718749] [Elapsed time: 1730.79s]
[Epoch 122/200] [D loss: 0.183937] [G loss: 4.126133] [Elapsed time: 1744.93s]
[Epoch 123/200] [D loss: 0.235792] [G loss: 3.088817] [Elapsed time: 1758.81s]
[Epoch 124/200] [D loss: 0.068913] [G loss: 3.624157] [Elapsed time: 1772.78s]
[Epoch 125/200] [D loss: 0.135954] [G loss: 3.560077] [Elapsed time: 1786.69s]
[Epoch 126/200] [D loss: 0.379730] [G loss: 2.913143] [Elapsed time: 1800.80s]
[Epoch 127/200] [D loss: 0.133501] [G loss: 5.789097] [Elapsed time: 1815.09s]
[Epoch 128/200] [D loss: 0.309985] [G loss: 4.953413] [Elapsed time: 1828.91s]
[Epoch 129/200] [D loss: 0.261843] [G loss: 6.617579] [Elapsed time: 1842.97s]
[Epoch 130/200] [D loss: 0.255948] [G loss: 3.763939] [Elapsed time: 1856.66s]
[Epoch 131/200] [D loss: 0.185810] [G loss: 2.438904] [Elapsed time: 1871.22s]
[Epoch 132/200] [D loss: 0.093917] [G loss: 3.691961] [Elapsed time: 1885.59s]
[Epoch 133/200] [D loss: 0.628793] [G loss: 7.800147] [Elapsed time: 1899.49s]
[Epoch 134/200] [D loss: 0.182042] [G loss: 3.037481] [Elapsed time: 1913.25s]
[Epoch 135/200] [D loss: 0.202837] [G loss: 2.250508] [Elapsed time: 1927.17s]
[Epoch 136/200] [D loss: 0.162933] [G loss: 3.666778] [Elapsed time: 1941.18s]
[Epoch 137/200] [D loss: 0.200897] [G loss: 3.234807] [Elapsed time: 1955.14s]
[Epoch 138/200] [D loss: 0.191355] [G loss: 3.124365] [Elapsed time: 1969.11s]
[Epoch 139/200] [D loss: 0.346639] [G loss: 2.832288] [Elapsed time: 1983.40s]
[Epoch 140/200] [D loss: 0.175804] [G loss: 2.732238] [Elapsed time: 1997.45s]
[Epoch 141/200] [D loss: 0.102703] [G loss: 2.937870] [Elapsed time: 2011.73s]
[Epoch 142/200] [D loss: 0.125220] [G loss: 3.047290] [Elapsed time: 2026.15s]
[Epoch 143/200] [D loss: 0.254521] [G loss: 3.729692] [Elapsed time: 2040.00s]
[Epoch 144/200] [D loss: 0.270499] [G loss: 2.384753] [Elapsed time: 2054.47s]
[Epoch 145/200] [D loss: 0.294633] [G loss: 2.181049] [Elapsed time: 2068.42s]
[Epoch 146/200] [D loss: 0.177180] [G loss: 2.044919] [Elapsed time: 2082.32s]
[Epoch 147/200] [D loss: 0.056352] [G loss: 3.474565] [Elapsed time: 2096.29s]
[Epoch 148/200] [D loss: 0.141344] [G loss: 1.978526] [Elapsed time: 2110.62s]
[Epoch 149/200] [D loss: 0.120316] [G loss: 2.453007] [Elapsed time: 2124.73s]
[Epoch 150/200] [D loss: 0.138787] [G loss: 2.114994] [Elapsed time: 2138.86s]
[Epoch 151/200] [D loss: 0.286204] [G loss: 3.877319] [Elapsed time: 2152.94s]
[Epoch 152/200] [D loss: 0.203801] [G loss: 3.630583] [Elapsed time: 2166.96s]
[Epoch 153/200] [D loss: 0.074083] [G loss: 3.694077] [Elapsed time: 2181.11s]
[Epoch 154/200] [D loss: 0.140657] [G loss: 2.875687] [Elapsed time: 2195.48s]
[Epoch 155/200] [D loss: 0.157358] [G loss: 2.242649] [Elapsed time: 2209.58s]
[Epoch 156/200] [D loss: 0.200039] [G loss: 2.878366] [Elapsed time: 2224.22s]
[Epoch 157/200] [D loss: 0.325206] [G loss: 2.763648] [Elapsed time: 2238.28s]
[Epoch 158/200] [D loss: 0.113586] [G loss: 4.866838] [Elapsed time: 2252.13s]
[Epoch 159/200] [D loss: 0.252594] [G loss: 3.457329] [Elapsed time: 2266.20s]
[Epoch 160/200] [D loss: 0.073561] [G loss: 2.482955] [Elapsed time: 2280.43s]
[Epoch 161/200] [D loss: 0.245648] [G loss: 2.080934] [Elapsed time: 2294.87s]
[Epoch 162/200] [D loss: 0.095687] [G loss: 2.422149] [Elapsed time: 2309.49s]
[Epoch 163/200] [D loss: 0.187337] [G loss: 2.427534] [Elapsed time: 2323.50s]
[Epoch 164/200] [D loss: 0.257954] [G loss: 3.781885] [Elapsed time: 2337.57s]
[Epoch 165/200] [D loss: 0.432916] [G loss: 7.190818] [Elapsed time: 2351.79s]
[Epoch 166/200] [D loss: 0.221108] [G loss: 2.897677] [Elapsed time: 2366.26s]
[Epoch 167/200] [D loss: 0.092894] [G loss: 4.951985] [Elapsed time: 2380.21s]
[Epoch 168/200] [D loss: 0.113605] [G loss: 3.088411] [Elapsed time: 2394.95s]
[Epoch 169/200] [D loss: 0.224174] [G loss: 3.478052] [Elapsed time: 2409.36s]
[Epoch 170/200] [D loss: 0.381373] [G loss: 4.609982] [Elapsed time: 2423.26s]
[Epoch 171/200] [D loss: 0.271674] [G loss: 3.405561] [Elapsed time: 2437.54s]
[Epoch 172/200] [D loss: 0.256950] [G loss: 2.007202] [Elapsed time: 2451.84s]
[Epoch 173/200] [D loss: 0.407985] [G loss: 4.499944] [Elapsed time: 2465.98s]
[Epoch 174/200] [D loss: 0.190228] [G loss: 4.008648] [Elapsed time: 2480.12s]
[Epoch 175/200] [D loss: 0.128769] [G loss: 4.746367] [Elapsed time: 2494.61s]
[Epoch 176/200] [D loss: 0.130729] [G loss: 2.741333] [Elapsed time: 2509.04s]
[Epoch 177/200] [D loss: 0.147498] [G loss: 2.339512] [Elapsed time: 2523.38s]
[Epoch 178/200] [D loss: 0.286283] [G loss: 3.510903] [Elapsed time: 2537.71s]
[Epoch 179/200] [D loss: 0.162441] [G loss: 2.513192] [Elapsed time: 2551.83s]
[Epoch 180/200] [D loss: 0.101355] [G loss: 6.710031] [Elapsed time: 2565.97s]
[Epoch 181/200] [D loss: 0.160138] [G loss: 1.960122] [Elapsed time: 2580.97s]
[Epoch 182/200] [D loss: 0.021202] [G loss: 4.239496] [Elapsed time: 2595.23s]
[Epoch 183/200] [D loss: 0.150069] [G loss: 2.767830] [Elapsed time: 2609.29s]
[Epoch 184/200] [D loss: 0.305992] [G loss: 3.022749] [Elapsed time: 2623.83s]
[Epoch 185/200] [D loss: 0.154886] [G loss: 2.489990] [Elapsed time: 2637.73s]
[Epoch 186/200] [D loss: 0.189738] [G loss: 2.579984] [Elapsed time: 2651.86s]
[Epoch 187/200] [D loss: 0.092674] [G loss: 2.880884] [Elapsed time: 2666.22s]
[Epoch 188/200] [D loss: 0.240747] [G loss: 2.788605] [Elapsed time: 2680.37s]
[Epoch 189/200] [D loss: 0.202244] [G loss: 2.370766] [Elapsed time: 2694.83s]
[Epoch 190/200] [D loss: 0.073636] [G loss: 3.369381] [Elapsed time: 2709.10s]
[Epoch 191/200] [D loss: 0.060109] [G loss: 3.445685] [Elapsed time: 2723.15s]
[Epoch 192/200] [D loss: 0.138614] [G loss: 3.958589] [Elapsed time: 2737.41s]
[Epoch 193/200] [D loss: 0.044754] [G loss: 4.216989] [Elapsed time: 2751.42s]
[Epoch 194/200] [D loss: 0.183014] [G loss: 3.495627] [Elapsed time: 2765.50s]
[Epoch 195/200] [D loss: 0.072254] [G loss: 2.941065] [Elapsed time: 2779.85s]
[Epoch 196/200] [D loss: 0.113577] [G loss: 4.764092] [Elapsed time: 2794.83s]
[Epoch 197/200] [D loss: 0.088866] [G loss: 2.900314] [Elapsed time: 2808.97s]
[Epoch 198/200] [D loss: 0.179538] [G loss: 2.701117] [Elapsed time: 2823.20s]
[Epoch 199/200] [D loss: 0.041067] [G loss: 3.463143] [Elapsed time: 2837.76s]
Image
- generate model이 학습됨에 따라 발전해나가는 모습을 출력해보았습니다.
epoch 0 | epoch 51 | epoch 100 | epoch 151 | epoch 198 |
---|---|---|---|---|
Reference
- https://hyeongminlee.github.io/post/gan001_gan/
- https://www.slideshare.net/ssuser5ac863/gan-77392547
- eriklindernoren/PyTorch-GAN
댓글남기기