diff --git a/application/homework_20220821/group8/yunhe8/CycleGAN.ipynb.zip b/application/homework_20220821/group8/yunhe8/CycleGAN.ipynb.zip new file mode 100644 index 0000000000000000000000000000000000000000..19e979dbc4607f39be53e427bfa737b9bacc26ed Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/CycleGAN.ipynb.zip differ diff --git a/application/homework_20220821/group8/yunhe8/CycleGAN.md b/application/homework_20220821/group8/yunhe8/CycleGAN.md new file mode 100644 index 0000000000000000000000000000000000000000..8824303f12b9b38fc6b0da654cd69b7723111c15 --- /dev/null +++ b/application/homework_20220821/group8/yunhe8/CycleGAN.md @@ -0,0 +1,759 @@ +# CycleGAN 实现Summer2Winter数据集的风格迁移 + +## CycleGAN产生背景 +更早提出的域迁移Domain Adaptation(画风迁移)模型有Pix2Pix,但Pix2Pix要求训练的数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的。由此,CycleGAN诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,从而获得了更广泛地应用。 + +![jpg](figure/p1.jpg) + +## 原理 +CycleGAN通过判别器和生成器的对抗训练,学习数据集图片的像素概率分布来生成图片。完成X域到Y域的图片风格迁移,既要拟合Y域图片的风格分布分布,又要保持X域图片对应的内容特征。打个比方,用草图风格的猫图片生成照片风格的猫图片时,要求生成的猫咪“即要活灵活现,又要姿势不变”。 + +![jpg](figure/p2.jpg) + +因为Pix2Pix是一个CGAN,所以,我们通过用X域图片当约束条件来限制Pix2Pix的输出Y域风格图片时保有X域图片的特征。而送入CycleGAN的两组(X域Y域)图片没有一一对应关系,即使我们将X域图片当成限制条件输入到一个CGAN中,也起不到限制模型输出保有X域图片特征的作用。因为,送入的两组图片完全是随机配在一起,CGAN学不到任何联系。因此,CycleGAN采取了一个绝妙的设计:通过添加“循环生成”并优化一致性损失来代替CGAN中使用的约束条件来限制生成器保有原域图片特征。 + +## 模型结构 + +![jpg](figure/p3.jpg) + +上半部份是生成器G和判别器Dy进行x2y的训练过程,下半部份是生成器F和判别器Dx进行y2x的训练过程。 +> 生成器由编码器、转换器和解码器构成 +- **编码**:第一步利用卷积神经网络从输入图象中提取特征。将图像压缩成256个64*64的特征向量。 +- **转换**:通过组合图像的不相近特征,将图像在DA域中的特征向量转换为DB域中的特征向量。作者使用了6层Reset模块,每个Reset模块是一个由两个卷积层构成的神经网络层,能够达到在转换时同时保留原始图像特征的目标。 +- **解码**:利用反卷积层(decovolution)完成从特征向量中还原出低级特征的工作,最后得到生成图像。 +>鉴别器将一张图像作为输入,并尝试预测其为原始图像或是生成器的输出图像。鉴别器本身属于卷积网络,需要从图像中提取特征,再通过添加产生一维输出的卷积层来确定提取的特征是否属于特定类别。 + +![jpg](figure/p4.jpg) + +> [1] J.-Y. Zhu, T. Park, P. Isola, and A. A. Efros, ‘Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks’. arXiv, Aug. 24, 2020. Accessed: Aug. 24, 2022. [Online]. Available: http://arxiv.org/abs/1703.10593 +> +> [2] K. He, X. Zhang, S. Ren, and J. Sun, ‘Deep Residual Learning for Image Recognition’, in *2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR)*, Las Vegas, NV, USA, Jun. 2016, pp. 770–778. doi: [10.1109/CVPR.2016.90](https://doi.org/10.1109/CVPR.2016.90). +> +> [3] https://gitee.com/mindspore/models/tree/master/research/cv/CycleGAN + +# 模型实现 + +## 1.下载summer2winter_yosemite数据集 + + +```python +import os +from mindvision import dataset + +if os.path.exists("./data/summer2winter_yosemite"): + print("Dataset summer2winter_yosemite already exists!") +else: + dl_path = "./data" + dl_url = "https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/summer2winter_yosemite.zip" + dl = dataset.DownLoad() + dl.download_and_extract_archive(url=dl_url, download_path=dl_path) +``` + + Dataset summer2winter_yosemite already exists! + + +## 2.部分数据集展示 + + +```python +import os +import matplotlib.pyplot as plt +import matplotlib.image as mpimg + +# trainA 数据集夏天图片展示 +path = './data/summer2winter_yosemite/trainA' +plt.figure(figsize=(10, 3), dpi=140) +for _, _, files in os.walk(path): + for i in range(1, 31): + plt.subplot(3, 10, i) + plt.axis("off") + img = mpimg.imread(os.path.join(path, files[i])) + plt.imshow(img) +plt.show() + +``` + + +![png](figure/output_5_0.png) + + + + +```python +# trainB 数据集冬天图片展示 +path = './data/summer2winter_yosemite/trainB' +plt.figure(figsize=(10, 3), dpi=140) +for _, _, files in os.walk(path): + for i in range(1, 31): + plt.subplot(3, 10, i) + plt.axis("off") + img = mpimg.imread(os.path.join(path, files[i])) + plt.imshow(img) +plt.show() +``` + + +![png](figure/output_6_0.png) + + + +## 3.CycleGAN部分代码 + +由于代码繁多,我们只展示模型训练中的部分重要函数代码: + +### 3.1 数据处理 + +框架代码中对训练数据核心预处理过程如下: + +```python +if args.use_random: + trans = [ + C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)), + C.RandomHorizontalFlip(prob=0.5), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + else: + trans = [ + C.Resize((image_size, image_size)), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] +``` + +### 3.2 判别器和生成器 + +![Cycle](figure/Cycle.png) + +生成器G将x、x^生成为y^,生成器F将y、y^生成为x^;判别器Dy与生成器(x2y)G进行对抗训练,判别器Dx与生成器(y2x)F进行对抗训练。 + + +```python +import mindspore.nn as nn + +# 判别器 +class Discriminator(nn.Cell): + # 初始化: + def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'): + super(Discriminator, self).__init__() + kernel_size = 4 # 卷积核的大小 + layers = [ + nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1), # 二维卷积 + nn.LeakyReLU(alpha) # 激活函数 + ] + nf_mult = ndf + for i in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2 ** i, 8) * ndf + layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1)) + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) * ndf + layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1)) + layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1)) + self.features = nn.SequentialCell(layers) + + # 构造函数: + def construct(self, x): + output = self.features(x) + return output + +# 生成器 +class Generator(nn.Cell): + # 初始化: + def __init__(self, G_A, G_B, use_identity=True): + super(Generator, self).__init__() + self.G_A = G_A + self.G_B = G_B + self.ones = ops.OnesLike() + self.use_identity = use_identity + + # 构造函数: + def construct(self, img_A, img_B): + """If use_identity, identity loss will be used.""" + fake_A = self.G_B(img_B) + fake_B = self.G_A(img_A) + rec_A = self.G_B(fake_B) + rec_B = self.G_A(fake_A) + if self.use_identity: + identity_A = self.G_B(img_A) + identity_B = self.G_A(img_B) + else: + identity_A = self.ones(img_A) + identity_B = self.ones(img_B) + return fake_A, fake_B, rec_A, rec_B, identity_A, identity_B +``` + +### 3.3 网络结构(以ResNet为例) +模型中的网络结构以GAN为主,也有一部分的卷积神经网络,为防止卷积网络退化,还加入了深度残差网络,原理如下图: + +![ResNet](figure/ResNet.png) + +残差网络给卷积网络添加了x的恒等映射,能比较加入卷积层之后的训练效果与加入卷积层之前的训练效果,如果模型在加入卷积层之前的训练效果已经足够好了,而加入卷积层后反而使模型的效果变差,那么我们便会使用原模型的输出结果作为整个模型的输出结果。 + + +```python +# ResNet网络结构 +class ResNetGenerator(nn.Cell): + def __init__(self, in_planes=3, ngf=64, n_layers=9, alpha=0.2, norm_mode='batch', dropout=False, + pad_mode="CONSTANT"): + super(ResNetGenerator, self).__init__() + self.conv_in = ConvNormReLU(in_planes, ngf, 7, 1, alpha, norm_mode, pad_mode=pad_mode) + self.down_1 = ConvNormReLU(ngf, ngf * 2, 3, 2, alpha, norm_mode) + self.down_2 = ConvNormReLU(ngf * 2, ngf * 4, 3, 2, alpha, norm_mode) + layers = [ResidualBlock(ngf * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers + self.residuals = nn.SequentialCell(layers) + self.up_2 = ConvTransposeNormReLU(ngf * 4, ngf * 2, 3, 2, alpha, norm_mode) + self.up_1 = ConvTransposeNormReLU(ngf * 2, ngf, 3, 2, alpha, norm_mode) + if pad_mode == "CONSTANT": + self.conv_out = nn.Conv2d(ngf, 3, kernel_size=7, stride=1, pad_mode='pad', padding=3) + else: + pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode) + conv = nn.Conv2d(ngf, 3, kernel_size=7, stride=1, pad_mode='pad') + self.conv_out = nn.SequentialCell([pad, conv]) + self.activate = ops.Tanh() + + def construct(self, x): + x = self.conv_in(x) + x = self.down_1(x) + x = self.down_2(x) + x = self.residuals(x) + x = self.up_2(x) + x = self.up_1(x) + output = self.conv_out(x) + return self.activate(output) + +# RisidualBlock网络 +class ResidualBlock(nn.Cell): + def __init__(self, dim, norm_mode='batch', dropout=False, pad_mode="CONSTANT"): + super(ResidualBlock, self).__init__() + self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode) + self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False) + self.dropout = dropout + if dropout: + self.dropout = nn.Dropout(0.5) + + def construct(self, x): + out = self.conv1(x) + if self.dropout: + out = self.dropout(out) + out = self.conv2(out) + return x + out +``` + +### 3.4 损失函数 + +#### 3.4.1 连接网络 +因为CycleGAN结构上的特殊性,其损失是判别器和生成器的多输出形式,这就导致它和一般的分类网络不同。所以我们需要自定义WithLossCell类,将网络和Loss连接起来,代码如下: + + +```python +import mindspore.nn as nn + +class WithLossCell(nn.Cell): + def __init__(self, network): + super(WithLossCell, self).__init__(auto_prefix=False) + self.network = network + + def construct(self, img_A, img_B): + _, _, lg, _, _, _, _, _, _ = self.network(img_A, img_B) + return lg +``` + +#### 3.4.2 LOSS的总体结构 + +![LOSS](figure/LOSS.jpg) + +##### Adversarial(GAN) Loss + +$$\mathcal{L}_{GAN}(G,D_{Y},X,Y)=E_{y\sim_Pdata(y)}[log(D_{Y}(y))]+E_{x\sim_Pdata(x)}[log(1-D_{Y}G((x)))]$$ + +作用: +- 让生成的图像更加接近真实分布的图像,但不能保证生成风格迁移但内容不变的图像 + +生成器无法改变判别器对真实数据的看法,但通过训练,生成器尽可能生成假图片欺骗判别器 +>枯叶蝶不能改变鸟对真实叶子的看法,但能改变鸟对自己的看法 + + +```python +class GANLoss(nn.Cell): + def __init__(self, mode="lsgan", reduction='mean'): + super(GANLoss, self).__init__() + self.loss = None + self.ones = ops.OnesLike() + if mode == "lsgan": + self.loss = nn.MSELoss(reduction) + elif mode == "vanilla": + self.loss = BCEWithLogits(reduction) + else: + raise NotImplementedError(f'GANLoss {mode} not recognized, we support lsgan and vanilla.') + + def construct(self, predict, target): + target = ops.cast(target, ops.dtype(predict)) + target = self.ones(predict) * target + loss = self.loss(predict, target) + return loss +``` + +##### Generator Loss + +生成器的LOSS函数除了生成cycle loss外,还有额外添加了一项identity loss + +$$\mathcal{L}_{identity}(G,F)=E_{y\sim_Pdata(y)}[||G_{y}(y)-y||]+E_{x\sim_Pdata(x)}[||F_{x}(x)-x||]$$ +identity loss是指:**B输入A2B得到的数据应该还是尽可能的像B而不是变成别的东西** +其作用: +- 保证生成图片色调不变 +- 避免迁移过度,使得G生成的图像与原图像完全无关,而F又过度矫正G生成的图像 + + +在此次训练中,cycle loss 通过`lambda_A`与`lambda_B`控制,identity loss 除通过`lambda_A`与`lambda_B`控制外,还通过`lambda_idt`控制。 + + +```python +class GeneratorLoss(nn.Cell): + def __init__(self, args, generator, D_A, D_B): + super(GeneratorLoss, self).__init__() + self.lambda_A = args.lambda_A + self.lambda_B = args.lambda_B + self.lambda_idt = args.lambda_idt + self.use_identity = args.lambda_idt > 0 + self.dis_loss = GANLoss(args.gan_mode) + self.rec_loss = nn.L1Loss("mean") + self.generator = generator + self.D_A = D_A + self.D_B = D_B + self.true = Tensor(True, mstype.bool_) + + def construct(self, img_A, img_B): + fake_A, fake_B, rec_A, rec_B, identity_A, identity_B = self.generator(img_A, img_B) + loss_G_A = self.dis_loss(self.D_B(fake_B), self.true) + loss_G_B = self.dis_loss(self.D_A(fake_A), self.true) + loss_C_A = self.rec_loss(rec_A, img_A) * self.lambda_A + loss_C_B = self.rec_loss(rec_B, img_B) * self.lambda_B + if self.use_identity: + loss_idt_A = self.rec_loss(identity_A, img_A) * self.lambda_A * self.lambda_idt + loss_idt_B = self.rec_loss(identity_B, img_B) * self.lambda_B * self.lambda_idt + else: + loss_idt_A = 0 + loss_idt_B = 0 + loss_G = loss_G_A + loss_G_B + loss_C_A + loss_C_B + loss_idt_A + loss_idt_B + return (loss_G, fake_A, fake_B, loss_G_A, loss_G_B, loss_C_A, loss_C_B, loss_idt_A, loss_idt_B) + +class DiscriminatorLoss(nn.Cell): + def __init__(self, args, D_A, D_B): + super(DiscriminatorLoss, self).__init__() + self.D_A = D_A + self.D_B = D_B + self.false = Tensor(False, mstype.bool_) + self.true = Tensor(True, mstype.bool_) + self.dis_loss = GANLoss(args.gan_mode) + self.rec_loss = nn.L1Loss("mean") + + def construct(self, img_A, img_B, fake_A, fake_B): + D_fake_A = self.D_A(fake_A) + D_img_A = self.D_A(img_A) + D_fake_B = self.D_B(fake_B) + D_img_B = self.D_B(img_B) + loss_D_A = self.dis_loss(D_fake_A, self.false) + self.dis_loss(D_img_A, self.true) + loss_D_B = self.dis_loss(D_fake_B, self.false) + self.dis_loss(D_img_B, self.true) + loss_D = (loss_D_A + loss_D_B) * 0.5 + return loss_D +``` + + +## 4. 模型训练 + +### 4.1 train函数 + + +```python +import argparse +import mindspore as ms +import mindspore.nn as nn +from src.utils.args import get_args +from src.utils.reporter import Reporter +from src.utils.tools import get_lr, ImagePool, load_ckpt +from src.dataset.cyclegan_dataset import create_dataset +from src.models.losses import DiscriminatorLoss, GeneratorLoss +from src.models.cycle_gan import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD + +ms.set_seed(1) + +def train(): + """Train function.""" + args = get_args("train") + if args.need_profiler: + from mindspore.profiler.profiling import Profiler + profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True) + ds = create_dataset(args) + G_A = get_generator(args) + G_B = get_generator(args) + D_A = get_discriminator(args) + D_B = get_discriminator(args) + if args.load_ckpt: + load_ckpt(args, G_A, G_B, D_A, D_B) + imgae_pool_A = ImagePool(args.pool_size) + imgae_pool_B = ImagePool(args.pool_size) + generator = Generator(G_A, G_B, args.lambda_idt > 0) + + loss_D = DiscriminatorLoss(args, D_A, D_B) + loss_G = GeneratorLoss(args, generator, D_A, D_B) + optimizer_G = nn.Adam(generator.trainable_params(), get_lr(args), beta1=args.beta1) + optimizer_D = nn.Adam(loss_D.trainable_params(), get_lr(args), beta1=args.beta1) + + net_G = TrainOneStepG(loss_G, generator, optimizer_G) + net_D = TrainOneStepD(loss_D, optimizer_D) + + data_loader = ds.create_dict_iterator() + if args.rank == 0: + reporter = Reporter(args) + reporter.info('==========start training===============') + for _ in range(args.max_epoch): + if args.rank == 0: + reporter.epoch_start() + for data in data_loader: + img_A = data["image_A"] + img_B = data["image_B"] + res_G = net_G(img_A, img_B) + fake_A = res_G[0] + fake_B = res_G[1] + res_D = net_D(img_A, img_B, imgae_pool_A.query(fake_A), imgae_pool_B.query(fake_B)) + if args.rank == 0: + reporter.step_end(res_G, res_D) + reporter.visualizer(img_A, img_B, fake_A, fake_B) + if args.rank == 0: + reporter.epoch_end(net_G) + if args.need_profiler: + profiler.analyse() + break + if args.rank == 0: + reporter.info('==========end training===============') + +``` + +### 4.2 sh命令训练 + +框架代码中数据预处理以及模型训练部分可以使用脚本命令执行。 + +我们可以使用如下命令进行默认参数的模型训练: + +```shell +sh ./scripts/run_train_standalone_gpu.sh +``` + +同时,可以修改shell文件指令中的args参数编辑不同训练的参数,因为我们使用的数据集是summer2winter_yosemite数据集,所以需要在shell文件指令中修改dataroot参数。另外,我们主要针对训练中model, batch_size和lambda_idt三个参数的训练效果进行研究,下面以model为ResNet,batch_size为1,lambda_idt为5的训练指令为例。 + +```shell +python train.py --platform GPU --device_id 0 --model ResNet --max_epoch 200 --batch_size 1 --lambda_idt 5 --dataroot ./data/summer2winter_yosemite/ +``` + +我们一共进行了8个模型的训练,具体信息如下表所示: + +| 后缀(_fake_A/B) | model | batch_size | lambda_idt | +| --------------- | ----------- | ---------- | ---------- | +| (Null) | ResNet | 1 | 0.5 | +| _b2 | ResNet | 2 | 0.5 | +| _b4 | ResNet | 4 | 0.5 | +| 005 | ResNet | 1 | 0.05 | +| 5 | ResNet | 1 | 5 | +| _DepthResNet | DepthResNet | 1 | 0.5 | +| _DepthResNet005 | DepthResNet | 1 | 0.05 | +| _DepthResNet5 | DepthResNet | 1 | 5 | + +## 5. 模型评估 + +### 5.1 eval函数 + + +```python +import os +from mindspore import Tensor +from src.models.cycle_gan import get_generator +from src.utils.args import get_args +from src.dataset.cyclegan_dataset import create_dataset +from src.utils.reporter import Reporter +from src.utils.tools import save_image, load_ckpt + + +def predict(): + """Predict function.""" + args = get_args("predict") + G_A = get_generator(args) + G_B = get_generator(args) + G_A.set_train(True) + G_B.set_train(True) + load_ckpt(args, G_A, G_B) + imgs_out = os.path.join(args.outputs_dir, "predict") + if not os.path.exists(imgs_out): + os.makedirs(imgs_out) + if not os.path.exists(os.path.join(imgs_out, "fake_A")): + os.makedirs(os.path.join(imgs_out, "fake_A")) + if not os.path.exists(os.path.join(imgs_out, "fake_B")): + os.makedirs(os.path.join(imgs_out, "fake_B")) + args.data_dir = 'testA' + ds = create_dataset(args) + reporter = Reporter(args) + reporter.start_predict("A to B") + for data in ds.create_dict_iterator(output_numpy=True): + img_A = Tensor(data["image"]) + path_A = str(data["image_name"][0], encoding="utf-8") + path_B = path_A[0:-4] + "_fake_B.jpg" + fake_B = G_A(img_A) + save_image(fake_B, os.path.join(imgs_out, "fake_B", path_B)) + save_image(img_A, os.path.join(imgs_out, "fake_B", path_A)) + reporter.info('save fake_B at %s', os.path.join(imgs_out, "fake_B", path_A)) + reporter.end_predict() + args.data_dir = 'testB' + ds = create_dataset(args) + reporter.dataset_size = args.dataset_size + reporter.start_predict("B to A") + for data in ds.create_dict_iterator(output_numpy=True): + img_B = Tensor(data["image"]) + path_B = str(data["image_name"][0], encoding="utf-8") + path_A = path_B[0:-4] + "_fake_A.jpg" + fake_A = G_B(img_B) + save_image(fake_A, os.path.join(imgs_out, "fake_A", path_A)) + save_image(img_B, os.path.join(imgs_out, "fake_A", path_B)) + reporter.info('save fake_A at %s', os.path.join(imgs_out, "fake_A", path_B)) + reporter.end_predict() + +``` + +### 5.2 sh命令评估 + +相似地,框架代码中结果评估可以使用脚本命令执行。 + +我们可以使用如下命令进行结果的评估: + +```shell +sh ./scripts/run_eval_gpu.sh +``` + +## 6. 结果展示与分析 + + +```python +import os +import matplotlib.pyplot as plt +from PIL import Image + +plt.figure("ResNet", figsize=(12, 6), dpi=200) # 图像窗口名称 +img1 = Image.open(os.path.join('figure/fakeB', '2011-06-03 21:27:20' + '.jpg')) +img2 = Image.open(os.path.join('figure/fakeB', '2011-06-03 21:27:20_fake_B005' + '.jpg')) +img3 = Image.open(os.path.join('figure/fakeB', '2011-06-03 21:27:20_fake_B' + '.jpg')) +img4 = Image.open(os.path.join('figure/fakeB', '2011-06-03 21:27:20_fake_B5' + '.jpg')) +img5 = Image.open(os.path.join('figure/fakeA', '2014-03-26 08:00:41' + '.jpg')) +img6 = Image.open(os.path.join('figure/fakeA', '2014-03-26 08:00:41_fake_A005' + '.jpg')) +img7 = Image.open(os.path.join('figure/fakeA', '2014-03-26 08:00:41_fake_A' + '.jpg')) +img8 = Image.open(os.path.join('figure/fakeA', '2014-03-26 08:00:41_fake_A5' + '.jpg')) +plt.subplot(2,4,1) +plt.title('Original') +plt.xticks([]),plt.yticks([]) +plt.imshow(img1) +plt.subplot(2,4,2) +plt.title('lambda_idt=0.05') +plt.xticks([]),plt.yticks([]) +plt.imshow(img2) +plt.subplot(2,4,3) +plt.title('lambda_idt=0.5') +plt.xticks([]),plt.yticks([]) +plt.imshow(img3) +plt.subplot(2,4,4) +plt.title('lambda_idt=5') +plt.xticks([]),plt.yticks([]) +plt.imshow(img4) +plt.subplot(2,4,5) +plt.title('Original') +plt.xticks([]),plt.yticks([]) +plt.imshow(img5) +plt.subplot(2,4,6) +plt.title('lambda_idt=0.05') +plt.xticks([]),plt.yticks([]) +plt.imshow(img6) +plt.subplot(2,4,7) +plt.title('lambda_idt=0.5') +plt.xticks([]),plt.yticks([]) +plt.imshow(img7) +plt.subplot(2,4,8) +plt.title('lambda_idt=5') +plt.xticks([]),plt.yticks([]) +plt.imshow(img8) +plt.suptitle("model: ResNet (batch_size=1)") +plt.show() +``` + + +![png](figure/output_27_0.png) + + + +由上图可知,改变`lambda_idt`的值,生成的图片整体效果相同,不同`lambda_idt`生成的图片的色调不同,而由于CycleGAN的loss不能直接衡量训练结果的好坏,训练结果的好坏由我们主观判断,故不同的训练集最优的`lambda_idt`也不同。 + + +```python +plt.figure("DepthResNet", figsize=(12, 6), dpi=200) # 图像窗口名称 +img1 = Image.open(os.path.join('figure/fakeB', '2011-06-03 21:27:20' + '.jpg')) +img2 = Image.open(os.path.join('figure/fakeB', '2011-06-03 21:27:20_fake_B_DepthResNet005' + '.jpg')) +img3 = Image.open(os.path.join('figure/fakeB', '2011-06-03 21:27:20_fake_B_DepthResNet' + '.jpg')) +img4 = Image.open(os.path.join('figure/fakeB', '2011-06-03 21:27:20_fake_B_DepthResNet5' + '.jpg')) +img5 = Image.open(os.path.join('figure/fakeA', '2014-03-26 08:00:41' + '.jpg')) +img6 = Image.open(os.path.join('figure/fakeA', '2014-03-26 08:00:41_fake_A_DepthResNet005' + '.jpg')) +img7 = Image.open(os.path.join('figure/fakeA', '2014-03-26 08:00:41_fake_A_DepthResNet' + '.jpg')) +img8 = Image.open(os.path.join('figure/fakeA', '2014-03-26 08:00:41_fake_A_DepthResNet5' + '.jpg')) +plt.subplot(2,4,1) +plt.title('Original') +plt.xticks([]),plt.yticks([]) +plt.imshow(img1) +plt.subplot(2,4,2) +plt.title('lambda_idt=0.05') +plt.xticks([]),plt.yticks([]) +plt.imshow(img2) +plt.subplot(2,4,3) +plt.title('lambda_idt=0.5') +plt.xticks([]),plt.yticks([]) +plt.imshow(img3) +plt.subplot(2,4,4) +plt.title('lambda_idt=5') +plt.xticks([]),plt.yticks([]) +plt.imshow(img4) +plt.subplot(2,4,5) +plt.title('Original') +plt.xticks([]),plt.yticks([]) +plt.imshow(img5) +plt.subplot(2,4,6) +plt.title('lambda_idt=0.05') +plt.xticks([]),plt.yticks([]) +plt.imshow(img6) +plt.subplot(2,4,7) +plt.title('lambda_idt=0.5') +plt.xticks([]),plt.yticks([]) +plt.imshow(img7) +plt.subplot(2,4,8) +plt.title('lambda_idt=5') +plt.xticks([]),plt.yticks([]) +plt.imshow(img8) +plt.suptitle("model: DepthResNet (batch_size=1)") +plt.show() +``` + + +![png](figure/output_29_0.png) + + + +通过改变batch_size进行训练,可以发现batch_size=1,2时可以得到较好的效果,batch_size=4时容易出现其他异常的特征。同时,观察下表给出的训练过程gpu的memory usage,batch_size过大有可能出现out of memory的报错。 + +| batch_size | GPU memory usage(MiB) | +| :--------: | :-------------------: | +| 1 | 5705 | +| 2 | 7753 | +| 4 | 11849 | + + +```python +plt.figure("batch_size", figsize=(12, 6), dpi=200) # 图像窗口名称 +img1 = Image.open(os.path.join('figure/fakeB', '2011-06-03 21:27:20' + '.jpg')) +img2 = Image.open(os.path.join('figure/fakeB', '2011-06-03 21:27:20_fake_B' + '.jpg')) +img3 = Image.open(os.path.join('figure/fakeB', '2011-06-03 21:27:20_fake_B_b2' + '.jpg')) +img4 = Image.open(os.path.join('figure/fakeB', '2011-06-03 21:27:20_fake_B_b4' + '.jpg')) +img5 = Image.open(os.path.join('figure/fakeA', '2014-03-26 08:00:41' + '.jpg')) +img6 = Image.open(os.path.join('figure/fakeA', '2014-03-26 08:00:41_fake_A' + '.jpg')) +img7 = Image.open(os.path.join('figure/fakeA', '2014-03-26 08:00:41_fake_A_b2' + '.jpg')) +img8 = Image.open(os.path.join('figure/fakeA', '2014-03-26 08:00:41_fake_A_b4' + '.jpg')) +plt.subplot(2,4,1) +plt.title('Original') +plt.xticks([]),plt.yticks([]) +plt.imshow(img1) +plt.subplot(2,4,2) +plt.title('batch_size=1') +plt.xticks([]),plt.yticks([]) +plt.imshow(img2) +plt.subplot(2,4,3) +plt.title('batch_size=2') +plt.xticks([]),plt.yticks([]) +plt.imshow(img3) +plt.subplot(2,4,4) +plt.title('batch_size=4') +plt.xticks([]),plt.yticks([]) +plt.imshow(img4) +plt.subplot(2,4,5) +plt.title('Original') +plt.xticks([]),plt.yticks([]) +plt.imshow(img5) +plt.subplot(2,4,6) +plt.title('batch_size=1') +plt.xticks([]),plt.yticks([]) +plt.imshow(img6) +plt.subplot(2,4,7) +plt.title('batch_size=2') +plt.xticks([]),plt.yticks([]) +plt.imshow(img7) +plt.subplot(2,4,8) +plt.title('batch_size=4') +plt.xticks([]),plt.yticks([]) +plt.imshow(img8) +plt.suptitle("Model: ResNet (lambda_idt=0.5)") +plt.show() +``` + + +![png](figure/output_31_0.png) + + + +通过改变`batch_size`的值,可以看到当`batch_size`为4时,图片出现了较大的影响,分析可能是由于当`batch_size`非1时,网络结构学到了batch中其他图片的信息,从而产生错误,故而在CycleGAN中,batchsize为1。 + +最后,我们选取了一张全新的风景图片,将这张图片分别放入summer和winter的真实图片中,利用训练好的生成器生成winter和summer的图片,观察生成对应的效果。我们很难判定原图属于夏天还是冬天,但某些生成器生成的图片能够展现出较明显的冬天和夏天的特征。说明对于CycleGAN来说,选择的图片也会对结果产生一定的影响。 + +但是,如果训练的数据集中存在较多这样特征不明显的图片,模型就很难学习冬天和夏天对应的特征。在summer2winter_yosemite数据集中,存在一定数量的图片未能展现出明显的夏天或冬天的特征,这也就在一定程度上影响了训练出的生成器生成图片的效果。 + + +```python +plt.figure("secret", figsize=(12, 6), dpi=200) # 图像窗口名称 +img1 = Image.open(os.path.join('figure/secret', 'secret1' + '.jpg')) +img2 = Image.open(os.path.join('figure/secret/fakeB', 'secret1_fake_B005' + '.jpg')) +img3 = Image.open(os.path.join('figure/secret/fakeB', 'secret1_fake_B' + '.jpg')) +img4 = Image.open(os.path.join('figure/secret/fakeB', 'secret1_fake_B5' + '.jpg')) +img5 = Image.open(os.path.join('figure/secret', 'secret1' + '.jpg')) +img6 = Image.open(os.path.join('figure/secret/fakeA', 'secret1_fake_A005' + '.jpg')) +img7 = Image.open(os.path.join('figure/secret/fakeA', 'secret1_fake_A' + '.jpg')) +img8 = Image.open(os.path.join('figure/secret/fakeA', 'secret1_fake_A5' + '.jpg')) +plt.subplot(2,4,1) +plt.title('Summer') +plt.xticks([]),plt.yticks([]) +plt.imshow(img1) +plt.subplot(2,4,2) +plt.title('lambda_idt=0.05') +plt.xticks([]),plt.yticks([]) +plt.imshow(img2) +plt.subplot(2,4,3) +plt.title('lambda_idt=0.5') +plt.xticks([]),plt.yticks([]) +plt.imshow(img3) +plt.subplot(2,4,4) +plt.title('lambda_idt=5') +plt.xticks([]),plt.yticks([]) +plt.imshow(img4) +plt.subplot(2,4,5) +plt.title('Winter') +plt.xticks([]),plt.yticks([]) +plt.imshow(img5) +plt.subplot(2,4,6) +plt.title('lambda_idt=0.05') +plt.xticks([]),plt.yticks([]) +plt.imshow(img6) +plt.subplot(2,4,7) +plt.title('lambda_idt=0.5') +plt.xticks([]),plt.yticks([]) +plt.imshow(img7) +plt.subplot(2,4,8) +plt.title('lambda_idt=5') +plt.xticks([]),plt.yticks([]) +plt.imshow(img8) +plt.suptitle("Secret Test (model=ResNet, batchsize=1)") +plt.show() +``` + + +![png](figure/output_34_0.png) + + diff --git a/application/homework_20220821/group8/yunhe8/eval.py b/application/homework_20220821/group8/yunhe8/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..62f6886361621ed8e9d3d3b02408cbf7167d1f1e --- /dev/null +++ b/application/homework_20220821/group8/yunhe8/eval.py @@ -0,0 +1,71 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Cycle GAN test.""" + +import os +from mindspore import Tensor +from src.models.cycle_gan import get_generator +from src.utils.args import get_args +from src.dataset.cyclegan_dataset import create_dataset +from src.utils.reporter import Reporter +from src.utils.tools import save_image, load_ckpt + + +def predict(): + """Predict function.""" + args = get_args("predict") + G_A = get_generator(args) + G_B = get_generator(args) + G_A.set_train(True) + G_B.set_train(True) + load_ckpt(args, G_A, G_B) + imgs_out = os.path.join(args.outputs_dir, "predict") + if not os.path.exists(imgs_out): + os.makedirs(imgs_out) + if not os.path.exists(os.path.join(imgs_out, "fake_A")): + os.makedirs(os.path.join(imgs_out, "fake_A")) + if not os.path.exists(os.path.join(imgs_out, "fake_B")): + os.makedirs(os.path.join(imgs_out, "fake_B")) + args.data_dir = 'testA' + ds = create_dataset(args) + reporter = Reporter(args) + reporter.start_predict("A to B") + for data in ds.create_dict_iterator(output_numpy=True): + img_A = Tensor(data["image"]) + path_A = str(data["image_name"][0], encoding="utf-8") + path_B = path_A[0:-4] + "_fake_B.jpg" + fake_B = G_A(img_A) + save_image(fake_B, os.path.join(imgs_out, "fake_B", path_B)) + save_image(img_A, os.path.join(imgs_out, "fake_B", path_A)) + reporter.info('save fake_B at %s', os.path.join(imgs_out, "fake_B", path_A)) + reporter.end_predict() + args.data_dir = 'testB' + ds = create_dataset(args) + reporter.dataset_size = args.dataset_size + reporter.start_predict("B to A") + for data in ds.create_dict_iterator(output_numpy=True): + img_B = Tensor(data["image"]) + path_B = str(data["image_name"][0], encoding="utf-8") + path_A = path_B[0:-4] + "_fake_A.jpg" + fake_A = G_B(img_B) + save_image(fake_A, os.path.join(imgs_out, "fake_A", path_A)) + save_image(img_B, os.path.join(imgs_out, "fake_A", path_B)) + reporter.info('save fake_A at %s', os.path.join(imgs_out, "fake_A", path_B)) + reporter.end_predict() + +if __name__ == "__main__": + predict() + \ No newline at end of file diff --git a/application/homework_20220821/group8/yunhe8/figure/Cycle.png b/application/homework_20220821/group8/yunhe8/figure/Cycle.png new file mode 100644 index 0000000000000000000000000000000000000000..c91dea0a8ce7b192e7dd3b89d41a183f8d9f0c61 Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/Cycle.png differ diff --git a/application/homework_20220821/group8/yunhe8/figure/LOSS.jpg b/application/homework_20220821/group8/yunhe8/figure/LOSS.jpg new file mode 100644 index 0000000000000000000000000000000000000000..274584f4459d3581ca11e62a286e362b9afb072f Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/LOSS.jpg differ diff --git a/application/homework_20220821/group8/yunhe8/figure/ResNet.png b/application/homework_20220821/group8/yunhe8/figure/ResNet.png new file mode 100644 index 0000000000000000000000000000000000000000..f6212a71e499790e305b5c128ce27603939ca129 Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/ResNet.png differ diff --git a/application/homework_20220821/group8/yunhe8/figure/output_27_0.png b/application/homework_20220821/group8/yunhe8/figure/output_27_0.png new file mode 100644 index 0000000000000000000000000000000000000000..c4d2ecae96c1223ed521c90aa2351ec95e436183 Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/output_27_0.png differ diff --git a/application/homework_20220821/group8/yunhe8/figure/output_29_0.png b/application/homework_20220821/group8/yunhe8/figure/output_29_0.png new file mode 100644 index 0000000000000000000000000000000000000000..06869c72e66c99c0d074b57c7a821671b53b7cbe Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/output_29_0.png differ diff --git a/application/homework_20220821/group8/yunhe8/figure/output_31_0.png b/application/homework_20220821/group8/yunhe8/figure/output_31_0.png new file mode 100644 index 0000000000000000000000000000000000000000..e591f6bec80136b60d09a417233a5f4474daf8fa Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/output_31_0.png differ diff --git a/application/homework_20220821/group8/yunhe8/figure/output_34_0.png b/application/homework_20220821/group8/yunhe8/figure/output_34_0.png new file mode 100644 index 0000000000000000000000000000000000000000..c9d9ee077ff3536af49301db7142365f3a7f60ce Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/output_34_0.png differ diff --git a/application/homework_20220821/group8/yunhe8/figure/output_5_0.png b/application/homework_20220821/group8/yunhe8/figure/output_5_0.png new file mode 100644 index 0000000000000000000000000000000000000000..b4ff06f15ea329a2d4ac999fa26acc3572d7cda9 Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/output_5_0.png differ diff --git a/application/homework_20220821/group8/yunhe8/figure/output_6_0.png b/application/homework_20220821/group8/yunhe8/figure/output_6_0.png new file mode 100644 index 0000000000000000000000000000000000000000..927e8f8114267f25fd37905da5d6abf7a5a90b9c Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/output_6_0.png differ diff --git a/application/homework_20220821/group8/yunhe8/figure/p1.jpg b/application/homework_20220821/group8/yunhe8/figure/p1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4b1893316c9181410c110c0363efb367b2208f67 Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/p1.jpg differ diff --git a/application/homework_20220821/group8/yunhe8/figure/p2.jpg b/application/homework_20220821/group8/yunhe8/figure/p2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..53fae344929cae7dc8eecde447c0e0d28bb41626 Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/p2.jpg differ diff --git a/application/homework_20220821/group8/yunhe8/figure/p3.jpg b/application/homework_20220821/group8/yunhe8/figure/p3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..50e7254d3b5b416b1e6a694f974d9e4288d12366 Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/p3.jpg differ diff --git a/application/homework_20220821/group8/yunhe8/figure/p4.jpg b/application/homework_20220821/group8/yunhe8/figure/p4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..420604db90f05cdc04e3b7ca861b01968e2b5052 Binary files /dev/null and b/application/homework_20220821/group8/yunhe8/figure/p4.jpg differ diff --git a/application/homework_20220821/group8/yunhe8/postprocess.py b/application/homework_20220821/group8/yunhe8/postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..faacce1840ccd74d4dd11fc7df398890bb90473e --- /dev/null +++ b/application/homework_20220821/group8/yunhe8/postprocess.py @@ -0,0 +1,60 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" + postprocess +""" +import os +import numpy as np +from PIL import Image +from src.utils.args import get_args +from mindspore import Tensor + +def save_image(img, img_path): + """Save a numpy image to the disk + + Parameters: + img (numpy array / Tensor): image to save. + image_path (str): the path of the image. + """ + if isinstance(img, Tensor): + img = img.asnumpy() + elif not isinstance(img, np.ndarray): + raise ValueError("img should be Tensor or numpy array, but get {}".format(type(img))) + img = decode_image(img) + + img_pil = Image.fromarray(img) + img_pil.save(img_path + ".jpg") + +def decode_image(img): + """Decode a [1, C, H, W] Tensor to image numpy array.""" + mean = 0.5 * 255 + std = 0.5 * 255 + + return (img * std + mean).astype(np.uint8).transpose((1, 2, 0)) + +if __name__ == '__main__': + args = get_args("predict") + + result_dir = args.outputs_dir + object_imageSize = args.image_size + rst_path = args.dataroot + + for i in range(len(os.listdir(rst_path))): + file_name = os.path.join(rst_path, "CycleGAN_data_bs" + str(args.batch_size) + '_' + str(i) + '_0.bin') + output = np.fromfile(file_name, np.float32).reshape(3, object_imageSize, object_imageSize) + print(output.shape) + save_image(output, result_dir + "/" + str(i + 1)) + print("=======image", i + 1, "saved success=======") + print("Generate images success!") diff --git a/application/homework_20220821/group8/yunhe8/preprocess.py b/application/homework_20220821/group8/yunhe8/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..f611e84d22140dba044301462f25ab143d379c6d --- /dev/null +++ b/application/homework_20220821/group8/yunhe8/preprocess.py @@ -0,0 +1,32 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" + preprocess +""" +import os +from src.utils.args import get_args +from src.dataset.cyclegan_dataset import create_dataset + +if __name__ == '__main__': + args = get_args("predict") + result_path = args.outputs_dir + ds_val = create_dataset(args) + img_path = os.path.join(result_path, "00_data") + os.makedirs(img_path) + for i, data in enumerate(ds_val.create_dict_iterator(output_numpy=True)): + file_name = "CycleGAN_data_bs" + str(args.batch_size) + "_" + str(i) + ".bin" + file_path = img_path + "/" + file_name + data['image'].tofile(file_path) + print("=" * 20, "export bin files finished", "=" * 20) diff --git a/application/homework_20220821/group8/yunhe8/scripts/run_eval_gpu.sh b/application/homework_20220821/group8/yunhe8/scripts/run_eval_gpu.sh new file mode 100644 index 0000000000000000000000000000000000000000..4f4a34e3e7dd6836d2d5d653755080e3b1b034c6 --- /dev/null +++ b/application/homework_20220821/group8/yunhe8/scripts/run_eval_gpu.sh @@ -0,0 +1,16 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +python eval.py --platform GPU --device_id 0 --model ResNet --outputs_dir ./outputs --G_A_ckpt ./outputs/ckpt/G_A_200.ckpt --G_B_ckpt ./outputs/ckpt/G_B_200.ckpt > output.eval.log 2>&1 & diff --git a/application/homework_20220821/group8/yunhe8/scripts/run_train_standalone_gpu.sh b/application/homework_20220821/group8/yunhe8/scripts/run_train_standalone_gpu.sh new file mode 100644 index 0000000000000000000000000000000000000000..4f575bd77cf001c3b4d628055668333717663904 --- /dev/null +++ b/application/homework_20220821/group8/yunhe8/scripts/run_train_standalone_gpu.sh @@ -0,0 +1,16 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +python train.py --platform GPU --device_id 0 --model ResNet --max_epoch 200 --dataroot ./data/summer2winter_yosemite/ --outputs_dir ./outputs > output.train.log 2>&1 & diff --git a/application/homework_20220821/group8/yunhe8/train.py b/application/homework_20220821/group8/yunhe8/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d68096973090795edd055474c48349c8a1c645d8 --- /dev/null +++ b/application/homework_20220821/group8/yunhe8/train.py @@ -0,0 +1,86 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""General-purpose training script for image-to-image translation. +You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model'). +Example: + Train a resnet model: + python train.py --dataroot ./data/horse2zebra --model ResNet +""" + +import mindspore as ms +import mindspore.nn as nn +from src.utils.args import get_args +from src.utils.reporter import Reporter +from src.utils.tools import get_lr, ImagePool, load_ckpt +from src.dataset.cyclegan_dataset import create_dataset +from src.models.losses import DiscriminatorLoss, GeneratorLoss +from src.models.cycle_gan import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD + +ms.set_seed(1) + +def train(): + """Train function.""" + args = get_args("train") + if args.need_profiler: + from mindspore.profiler.profiling import Profiler + profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True) + ds = create_dataset(args) + G_A = get_generator(args) + G_B = get_generator(args) + D_A = get_discriminator(args) + D_B = get_discriminator(args) + if args.load_ckpt: + load_ckpt(args, G_A, G_B, D_A, D_B) + imgae_pool_A = ImagePool(args.pool_size) + imgae_pool_B = ImagePool(args.pool_size) + generator = Generator(G_A, G_B, args.lambda_idt > 0) + + loss_D = DiscriminatorLoss(args, D_A, D_B) + loss_G = GeneratorLoss(args, generator, D_A, D_B) + optimizer_G = nn.Adam(generator.trainable_params(), get_lr(args), beta1=args.beta1) + optimizer_D = nn.Adam(loss_D.trainable_params(), get_lr(args), beta1=args.beta1) + + net_G = TrainOneStepG(loss_G, generator, optimizer_G) + net_D = TrainOneStepD(loss_D, optimizer_D) + + data_loader = ds.create_dict_iterator() + if args.rank == 0: + reporter = Reporter(args) + reporter.info('==========start training===============') + for _ in range(args.max_epoch): + if args.rank == 0: + reporter.epoch_start() + for data in data_loader: + img_A = data["image_A"] + img_B = data["image_B"] + res_G = net_G(img_A, img_B) + fake_A = res_G[0] + fake_B = res_G[1] + res_D = net_D(img_A, img_B, imgae_pool_A.query(fake_A), imgae_pool_B.query(fake_B)) + if args.rank == 0: + reporter.step_end(res_G, res_D) + reporter.visualizer(img_A, img_B, fake_A, fake_B) + if args.rank == 0: + reporter.epoch_end(net_G) + if args.need_profiler: + profiler.analyse() + break + if args.rank == 0: + reporter.info('==========end training===============') + + +if __name__ == "__main__": + train()