# gan_cifar10.py
import os, sys
sys.path.append(os.getcwd())
import time
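# `tflib` is the utility package bundled with this repository; it provides the
# image-saving, CIFAR-10 loading, plotting and inception-score helpers used below.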
import tflib as lib
import tflib.save_images
import tflib.mnist
import tflib.cifar10
import tflib.plot
import tflib.inception_score
import numpy as np
import torch
import torchvision
from torch import nn
from torch import autograd
from torch import optim
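# Note: this script targets the pre-0.4 PyTorch API (autograd.Variable, volatile=True,
# tensor arguments to backward()); running it on PyTorch >= 0.4 needs minor changes,
# e.g. torch.no_grad() in place of volatile.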
# Download CIFAR-10 (Python version) at
# https://www.cs.toronto.edu/~kriz/cifar.html and fill in the path to the
# extracted files here!
DATA_DIR = 'cifar-10-batches-py/'
if len(DATA_DIR) == 0:
    raise Exception('Please specify path to data directory in gan_cifar10.py!')
MODE = 'wgan-gp' # Valid options are dcgan, wgan, or wgan-gp (this script only implements the wgan-gp objective)
DIM = 128 # This overfits substantially; you're probably better off with 64
LAMBDA = 10 # Gradient penalty lambda hyperparameter
CRITIC_ITERS = 5 # How many critic iterations per generator iteration
BATCH_SIZE = 64 # Batch size
ITERS = 200000 # How many generator iterations to train for
OUTPUT_DIM = 3072 # Number of pixels in CIFAR10 (3*32*32)
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
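        # Architecture: a 128-dim noise vector is projected to a 4x4 feature map with
        # 4*DIM channels, then upsampled 4x4 -> 8x8 -> 16x16 -> 32x32 by three stride-2
        # transposed convolutions, ending with a tanh over the 3 RGB channels.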
preprocess = nn.Sequential(
nn.Linear(128, 4 * 4 * 4 * DIM),
            nn.BatchNorm1d(4 * 4 * 4 * DIM),  # 1-D: the Linear output is (N, C); BatchNorm2d here fails the 4-D input check
nn.ReLU(True),
)
block1 = nn.Sequential(
nn.ConvTranspose2d(4 * DIM, 2 * DIM, 2, stride=2),
nn.BatchNorm2d(2 * DIM),
nn.ReLU(True),
)
block2 = nn.Sequential(
nn.ConvTranspose2d(2 * DIM, DIM, 2, stride=2),
nn.BatchNorm2d(DIM),
nn.ReLU(True),
)
deconv_out = nn.ConvTranspose2d(DIM, 3, 2, stride=2)
self.preprocess = preprocess
self.block1 = block1
self.block2 = block2
self.deconv_out = deconv_out
self.tanh = nn.Tanh()
def forward(self, input):
output = self.preprocess(input)
output = output.view(-1, 4 * DIM, 4, 4)
output = self.block1(output)
output = self.block2(output)
output = self.deconv_out(output)
output = self.tanh(output)
return output.view(-1, 3, 32, 32)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
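        # Critic: three stride-2 convolutions (32x32 -> 16x16 -> 8x8 -> 4x4) followed by
        # a linear layer producing a single unbounded score; no sigmoid and no batch norm,
        # as is standard for a WGAN-GP critic.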
main = nn.Sequential(
nn.Conv2d(3, DIM, 3, 2, padding=1),
nn.LeakyReLU(),
nn.Conv2d(DIM, 2 * DIM, 3, 2, padding=1),
nn.LeakyReLU(),
nn.Conv2d(2 * DIM, 4 * DIM, 3, 2, padding=1),
nn.LeakyReLU(),
)
self.main = main
self.linear = nn.Linear(4*4*4*DIM, 1)
def forward(self, input):
output = self.main(input)
output = output.view(-1, 4*4*4*DIM)
output = self.linear(output)
return output
netG = Generator()
netD = Discriminator()
print(netG)
print(netD)
use_cuda = torch.cuda.is_available()
if use_cuda:
gpu = 0
if use_cuda:
netD = netD.cuda(gpu)
netG = netG.cuda(gpu)
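# Scalar gradient seeds for .backward(): passing `one` leaves the sign unchanged
# (the optimizer step minimises that term), while `mone` flips it (the term is maximised).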
one = torch.FloatTensor([1])
mone = one * -1
if use_cuda:
one = one.cuda(gpu)
mone = mone.cuda(gpu)
optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
def calc_gradient_penalty(netD, real_data, fake_data):
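    # WGAN-GP gradient penalty (Gulrajani et al., 2017): sample points uniformly along
    # straight lines between real and fake samples, and penalise the squared deviation
    # of the critic's gradient norm from 1 at those points:
    #     LAMBDA * E[(||grad_xhat D(xhat)||_2 - 1)^2]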
# print "real_data: ", real_data.size(), fake_data.size()
alpha = torch.rand(BATCH_SIZE, 1)
    alpha = alpha.expand(BATCH_SIZE, real_data.nelement() // BATCH_SIZE).contiguous().view(BATCH_SIZE, 3, 32, 32)
alpha = alpha.cuda(gpu) if use_cuda else alpha
interpolates = alpha * real_data + ((1 - alpha) * fake_data)
if use_cuda:
interpolates = interpolates.cuda(gpu)
interpolates = autograd.Variable(interpolates, requires_grad=True)
disc_interpolates = netD(interpolates)
gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
grad_outputs=torch.ones(disc_interpolates.size()).cuda(gpu) if use_cuda else torch.ones(
disc_interpolates.size()),
create_graph=True, retain_graph=True, only_inputs=True)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
return gradient_penalty
# For generating samples
def generate_image(frame, netG):
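    # Note: despite the name, the noise below is re-sampled on every call, so successive
    # sample grids do not track one fixed set of latent codes.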
fixed_noise_128 = torch.randn(128, 128)
if use_cuda:
fixed_noise_128 = fixed_noise_128.cuda(gpu)
noisev = autograd.Variable(fixed_noise_128, volatile=True)
samples = netG(noisev)
samples = samples.view(-1, 3, 32, 32)
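    # map the tanh output from [-1, 1] back to [0, 1] for saving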
samples = samples.mul(0.5).add(0.5)
samples = samples.cpu().data.numpy()
lib.save_images.save_images(samples, './tmp/cifar10/samples_{}.jpg'.format(frame))
# For calculating inception score
def get_inception_score(G):
all_samples = []
    for i in range(10):
samples_100 = torch.randn(100, 128)
if use_cuda:
samples_100 = samples_100.cuda(gpu)
samples_100 = autograd.Variable(samples_100, volatile=True)
all_samples.append(G(samples_100).cpu().data.numpy())
all_samples = np.concatenate(all_samples, axis=0)
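    # rescale generator output from [-1, 1] to [0, 255] pixel values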
all_samples = np.multiply(np.add(np.multiply(all_samples, 0.5), 0.5), 255).astype('int32')
all_samples = all_samples.reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)
return lib.inception_score.get_inception_score(list(all_samples))
# Dataset iterator
train_gen, dev_gen = lib.cifar10.load(BATCH_SIZE, data_dir=DATA_DIR)
def inf_train_gen():
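    # cycle through the training set forever, yielding flat (BATCH_SIZE, 3072) image batches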
while True:
for images, target in train_gen():
# yield images.astype('float32').reshape(BATCH_SIZE, 3, 32, 32).transpose(0, 2, 3, 1)
yield images
gen = inf_train_gen()
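# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 then maps them to
# [-1, 1], matching the tanh range of the generator output.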
preprocess = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
for iteration in range(ITERS):
start_time = time.time()
############################
# (1) Update D network
###########################
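    # The critic is updated CRITIC_ITERS times per generator step, minimising
    # D(fake) - D(real) + gradient_penalty (logged below as D_cost).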
for p in netD.parameters(): # reset requires_grad
p.requires_grad = True # they are set to False below in netG update
    for i in range(CRITIC_ITERS):
        _data = next(gen)
netD.zero_grad()
# train with real
_data = _data.reshape(BATCH_SIZE, 3, 32, 32).transpose(0, 2, 3, 1)
real_data = torch.stack([preprocess(item) for item in _data])
if use_cuda:
real_data = real_data.cuda(gpu)
real_data_v = autograd.Variable(real_data)
# import torchvision
# filename = os.path.join("test_train_data", str(iteration) + str(i) + ".jpg")
# torchvision.utils.save_image(real_data, filename)
D_real = netD(real_data_v)
D_real = D_real.mean()
D_real.backward(mone)
# train with fake
noise = torch.randn(BATCH_SIZE, 128)
if use_cuda:
noise = noise.cuda(gpu)
noisev = autograd.Variable(noise, volatile=True) # totally freeze netG
fake = autograd.Variable(netG(noisev).data)
inputv = fake
D_fake = netD(inputv)
D_fake = D_fake.mean()
D_fake.backward(one)
# train with gradient penalty
gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
gradient_penalty.backward()
# print "gradien_penalty: ", gradient_penalty
D_cost = D_fake - D_real + gradient_penalty
Wasserstein_D = D_real - D_fake
optimizerD.step()
############################
# (2) Update G network
###########################
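    # The generator is updated once per CRITIC_ITERS critic steps, maximising D(G(z))
    # (i.e. minimising G_cost = -D(G(z))) with the critic parameters frozen.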
for p in netD.parameters():
p.requires_grad = False # to avoid computation
netG.zero_grad()
noise = torch.randn(BATCH_SIZE, 128)
if use_cuda:
noise = noise.cuda(gpu)
noisev = autograd.Variable(noise)
fake = netG(noisev)
G = netD(fake)
G = G.mean()
G.backward(mone)
G_cost = -G
optimizerG.step()
# Write logs and save samples
lib.plot.plot('./tmp/cifar10/train disc cost', D_cost.cpu().data.numpy())
lib.plot.plot('./tmp/cifar10/time', time.time() - start_time)
lib.plot.plot('./tmp/cifar10/train gen cost', G_cost.cpu().data.numpy())
lib.plot.plot('./tmp/cifar10/wasserstein distance', Wasserstein_D.cpu().data.numpy())
    # Calculate inception score every 1K iters (currently disabled by the
    # `False and` guard; remove it to re-enable)
    if False and iteration % 1000 == 999:
inception_score = get_inception_score(netG)
lib.plot.plot('./tmp/cifar10/inception score', inception_score[0])
# Calculate dev loss and generate samples every 100 iters
if iteration % 100 == 99:
dev_disc_costs = []
for images, _ in dev_gen():
images = images.reshape(BATCH_SIZE, 3, 32, 32).transpose(0, 2, 3, 1)
imgs = torch.stack([preprocess(item) for item in images])
# imgs = preprocess(images)
if use_cuda:
imgs = imgs.cuda(gpu)
imgs_v = autograd.Variable(imgs, volatile=True)
D = netD(imgs_v)
_dev_disc_cost = -D.mean().cpu().data.numpy()
dev_disc_costs.append(_dev_disc_cost)
lib.plot.plot('./tmp/cifar10/dev disc cost', np.mean(dev_disc_costs))
generate_image(iteration, netG)
    # Flush logs every 100 iters (and during the first few iterations)
if (iteration < 5) or (iteration % 100 == 99):
lib.plot.flush()
lib.plot.tick()