diff --git a/application/homework_20220821/group7/imgs/.keep b/application/homework_20220821/group7/imgs/.keep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/application/homework_20220821/group7/imgs/CUFS.png b/application/homework_20220821/group7/imgs/CUFS.png
new file mode 100644
index 0000000000000000000000000000000000000000..3b63eb2c065cd465dd873c8506837b23abf03b02
Binary files /dev/null and b/application/homework_20220821/group7/imgs/CUFS.png differ
diff --git a/application/homework_20220821/group7/imgs/aerial2ground.png b/application/homework_20220821/group7/imgs/aerial2ground.png
new file mode 100644
index 0000000000000000000000000000000000000000..654f92547e4d4df59543cc13c9fede3e2ec9f3dd
Binary files /dev/null and b/application/homework_20220821/group7/imgs/aerial2ground.png differ
diff --git a/application/homework_20220821/group7/imgs/compare.png b/application/homework_20220821/group7/imgs/compare.png
new file mode 100644
index 0000000000000000000000000000000000000000..31cf0164023da32f33909bfa02b72c7f5a6cc041
Binary files /dev/null and b/application/homework_20220821/group7/imgs/compare.png differ
diff --git a/application/homework_20220821/group7/imgs/concated_img.jpg b/application/homework_20220821/group7/imgs/concated_img.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..61d94881165dd5077b93d51a69d9388f0197d231
Binary files /dev/null and b/application/homework_20220821/group7/imgs/concated_img.jpg differ
diff --git a/application/homework_20220821/group7/imgs/dayton.jpg b/application/homework_20220821/group7/imgs/dayton.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b2b9a6cde26e9edfc90c5d89376f980d2e2b4094
Binary files /dev/null and b/application/homework_20220821/group7/imgs/dayton.jpg differ
diff --git a/application/homework_20220821/group7/imgs/default_pred.jpg b/application/homework_20220821/group7/imgs/default_pred.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..91a8890701a99e248ad04de7b89d4f8328e73165
Binary files /dev/null and b/application/homework_20220821/group7/imgs/default_pred.jpg differ
diff --git a/application/homework_20220821/group7/imgs/epoch.png b/application/homework_20220821/group7/imgs/epoch.png
new file mode 100644
index 0000000000000000000000000000000000000000..4ab218994b7ecfef85bfdfc1d02f2542339d0202
Binary files /dev/null and b/application/homework_20220821/group7/imgs/epoch.png differ
diff --git a/application/homework_20220821/group7/imgs/fk_img1.jpg b/application/homework_20220821/group7/imgs/fk_img1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..868b65ff0d87ef20143bc91b133c6f95d52730ad
Binary files /dev/null and b/application/homework_20220821/group7/imgs/fk_img1.jpg differ
diff --git a/application/homework_20220821/group7/imgs/ground2aerial.png b/application/homework_20220821/group7/imgs/ground2aerial.png
new file mode 100644
index 0000000000000000000000000000000000000000..526dc64155fd5d98211c03dd7891fa7d8b02631d
Binary files /dev/null and b/application/homework_20220821/group7/imgs/ground2aerial.png differ
diff --git a/application/homework_20220821/group7/imgs/params.png b/application/homework_20220821/group7/imgs/params.png
new file mode 100644
index 0000000000000000000000000000000000000000..0db4c3c16bb6661ce1ba6bacf481944255d8fe45
Binary files /dev/null and b/application/homework_20220821/group7/imgs/params.png differ
diff --git a/application/homework_20220821/group7/imgs/res1.png b/application/homework_20220821/group7/imgs/res1.png
new file mode 100644
index 0000000000000000000000000000000000000000..ef08fb226fb30983a8af2de0184152c91db266dd
Binary files /dev/null and b/application/homework_20220821/group7/imgs/res1.png differ
diff --git a/application/homework_20220821/group7/imgs/res2.png b/application/homework_20220821/group7/imgs/res2.png
new file mode 100644
index 0000000000000000000000000000000000000000..1ede41956c8c7eb80d475f83b0c71171fb899ef2
Binary files /dev/null and b/application/homework_20220821/group7/imgs/res2.png differ
diff --git a/application/homework_20220821/group7/imgs/test.jpg b/application/homework_20220821/group7/imgs/test.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..47e627687b19d24371fab677572ef62e796e630f
Binary files /dev/null and b/application/homework_20220821/group7/imgs/test.jpg differ
diff --git a/application/homework_20220821/group7/imgs/thesis.png b/application/homework_20220821/group7/imgs/thesis.png
new file mode 100644
index 0000000000000000000000000000000000000000..8db300d7858e7e3aab3d20044cd593be7db1ff01
Binary files /dev/null and b/application/homework_20220821/group7/imgs/thesis.png differ
diff --git "a/application/homework_20220821/group7/\345\237\272\344\272\216 Pix2Pix \346\250\241\345\236\213\347\232\204\350\267\250\350\247\206\350\247\222\350\210\252\347\251\272\345\233\276\345\203\217\345\234\260\351\235\242\345\233\276\345\203\217\347\224\237\346\210\220.ipynb" "b/application/homework_20220821/group7/\345\237\272\344\272\216 Pix2Pix \346\250\241\345\236\213\347\232\204\350\267\250\350\247\206\350\247\222\350\210\252\347\251\272\345\233\276\345\203\217\345\234\260\351\235\242\345\233\276\345\203\217\347\224\237\346\210\220.ipynb"
new file mode 100644
index 0000000000000000000000000000000000000000..5cb20218a9675240d9d94a842d415f9376a69d41
--- /dev/null
+++ "b/application/homework_20220821/group7/\345\237\272\344\272\216 Pix2Pix \346\250\241\345\236\213\347\232\204\350\267\250\350\247\206\350\247\222\350\210\252\347\251\272\345\233\276\345\203\217\345\234\260\351\235\242\345\233\276\345\203\217\347\224\237\346\210\220.ipynb"
@@ -0,0 +1,550 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## I. Model Introduction\n",
+    "### 1.1 How pix2pix Works\n",
+    "Image-to-image translation is used more and more widely in machine learning and deep learning applications. At its core, such a translation is a pixel-to-pixel mapping: after training on a given dataset, the pix2pix model learns this mapping and can then translate new images into the desired style and type.\n",
+    "The pix2pix model is not tied to any particular application. Training it on different datasets yields different translators, so the same architecture can be applied to many practical scenarios.\n",
+    "### 1.2 Structured Loss for Image Models\n",
+    "The model computes its structured loss with a conditional GAN (cGAN). Compared with hand-designed losses, the advantage of this scheme is that the loss is learned: in principle, it can penalize any possible structural difference between the output and the target.\n",
+    "### 1.3 The GAN Architecture\n",
+    "A Generative Adversarial Network (GAN) is a deep learning model and one of the most promising approaches of recent years to unsupervised learning over complex distributions. It consists of two modules, a **generator G** and a **discriminator D**:\n",
+    "- Generator: based on the U-Net architecture, it produces \"fake\" images resembling the training images\n",
+    "- Discriminator: a convolutional \"PatchGAN\" classifier that penalizes structure only at the scale of image patches and judges whether the generator's output is real or fake\n",
+    "\n",
+    "Both the generator and the discriminator are built from Conv-BatchNorm-ReLU blocks, and, unlike in an unconditional GAN, both of them also observe the input image.\n",
+    "\n",
+    "### 1.4 The Objective Function\n",
+    "The pix2pix objective has two parts: the cGAN objective and an $L_1$ loss.\n",
+    "\n",
+    "cGAN objective:\n",
+    "$$L_{cGAN}(G,D) = {\\rm E}_{x,y}[\\log D(x,y)] + {\\rm E}_{x,z}[\\log (1 - D(x,G(x,z)))]$$\n",
+    "Here $D(x,y)$ is the discriminator's verdict on a real paired sample, i.e. an input image $x$ together with its ground-truth output $y$, while $D(x,G(x,z))$ is its verdict on the image $G(x,z)$ that the generator produces from $x$ and noise $z$. The generator $G$ tries to minimize this objective while the discriminator $D$ tries to maximize it.\n",
+    "\n",
+    "$L_1$ loss:\n",
+    "$$L_{L_1}(G) = {\\rm E}_{x,y,z} [||y - G(x,z)||_{1}]$$\n",
+    "\n",
+    "The final pix2pix objective is:\n",
+    "$$G^{*} = \\arg\\mathop {\\min }\\limits_G \\mathop {\\max }\\limits_D L_{cGAN}(G,D) + \\lambda L_{L_1}(G)$$\n",
+    "\n",
+    "Here $\\lambda$ is a hyperparameter that can be tuned as needed; $\\lambda = 0$ means the $L_1$ loss is not used.\n",
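+    "\n",
+    "To make the objective concrete, the generator's combined loss can be sketched in a few lines of NumPy. This is an illustration only, not the training code used later; `d_fake`, `fake`, and `target` are hypothetical arrays, and `lam=100` follows the paper's default:\n",
+    "\n",
+    "```python\n",
+    "import numpy as np\n",
+    "\n",
+    "def generator_loss(d_fake, fake, target, lam=100.0):\n",
+    "    # d_fake: discriminator scores in (0, 1) for generated pairs\n",
+    "    # fake / target: generated image and ground truth, float arrays\n",
+    "    eps = 1e-12\n",
+    "    adv = -np.mean(np.log(d_fake + eps))  # adversarial term (non-saturating form)\n",
+    "    l1 = np.mean(np.abs(target - fake))   # L1 reconstruction term\n",
+    "    return adv + lam * l1\n",
+    "```\n",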
"\n", + "Pix2Pix最终的目标函数为:\n", + "$$G^{*} = arg\\mathop {\\min }\\limits_G \\mathop {\\max }\\limits_D V(D,G) + \\lambda L_{L_1}(G)$$\n", + "\n", + "其中 $\\lambda$为超参数,可以根据情况调节,当 $\\lambda=0 $时表示不采用$L_1$损失函数。\n", + "\n", + "### 1.5 预测结果的评估\n", + "对于生成图片的质量,往往难以进行量化的展示,因此我们针对图片的合理性与真实性两个方面对模型生成的图片质量进行评估。\n", + "- AMT感知研究,对模型的合理性进行量化评估\n", + "- 根据合成照片的标签对合成照片进行分类准确度评分,测量城市景观等图像是否够逼真\n", + "\n", + "### 1.6 模型评价\n", + "pix2pix模型针对不同的训练数据集可产生不同的训练效果,可广泛应用于生活实际中;针对数据的训练过程,使用少量的数据集即可得到较为良好的训练效果,训练时间较短,利于学习与应用推广。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 二、案例一——人像素描图转换" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. CUFS 数据集介绍\n", + "\n", + "CUFS 数据集由香港中文大学制作。其中包括188张来自香港中文大学(中大)学生数据库的人脸,123张来自AR数据库,295张来自XM2VTS数据库的人脸。总共有606个面。每张脸都有一幅素描,是由一位艺术家根据一张在正常光照条件下拍摄的正面姿势和中性表情的照片绘制的。可惜的是,除了香港中文大学的人脸数据还能找到下载渠道,AR 数据库和 XM2VTS 数据库中的人脸图片暂时没能找到下载链接。\n", + "\n", + "所以本次的数据集中,只有 188 张人脸和对应的素描图片,尺寸均为:200 * 250。\n", + "\n", + "我们首先做的尝试是:将 188 张图片中的前 40 张作为测试集,后 148 张作为训练集进行模型训练。\n", + "\n", + "![](imgs/cufs.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 数据处理" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# 引入相关库\n", + "import numpy as np\n", + "import cv2\n", + "import os\n", + "from PIL import Image\n", + "import mindspore.dataset.vision.c_transforms as C" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1 文件重命名\n", + "\n", + "mindspore 上下载的 pix2pix 模型要求训练集和测试集图片的命名为 “数字+后缀名”,所以这里需要对下载到的数据集图片进行重命名。" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "image_path = 'data/cufs/origin/photos'\n", + "sketch_path = 'data/cufs/origin/sketches'\n", + "\n", + "images = os.listdir(image_path)\n", + "sketches = os.listdir(sketch_path)\n", + "\n", + "images.sort()\n", + "sketches.sort()\n", + "\n", + "\n", + "def rename_sketches():\n", + " for i in range(1, len(sketches)+1):\n", + " if \"M2\" in sketches[i-1]:\n", + " os.rename(os.path.join(sketch_path, sketches[i-1]), os.path.join(sketch_path, sketches[i-1].replace(\"M2\", \"m\")))\n", + " if \"F2\" in sketches[i-1]:\n", + " os.rename(os.path.join(sketch_path, sketches[i-1]), os.path.join(sketch_path, sketches[i-1].replace(\"F2\", \"f\")))\n", + "\n", + "def rename_images():\n", + " for i in range(1, len(images)+1):\n", + " os.rename(os.path.join(\n", + " image_path, images[i-1]), os.path.join(image_path, \"%d.jpg\" % i))\n", + " os.rename(os.path.join(\n", + " sketch_path, sketches[i-1]), os.path.join(sketch_path, \"%d.jpg\" % i))\n", + " \n", + " \n", + "rename_sketches()\n", + "rename_images()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2 resize 图像" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def resize_image(input_path, dst_path, size=(256, 256)):\n", + " resize = C.Resize(size)\n", + "\n", + " for file in os.listdir(input_path):\n", + " image = Image.open(os.path.join(input_path, file))\n", + " image = resize(image)\n", + " image = Image.fromarray(image)\n", + " image.save(os.path.join(dst_path, file))\n", + "\n", + "resize_image(\"data/cufs/origin/photos\", \"data/cufs/256_256/photos\")\n", + "resize_image(\"data/cufs/origin/sketches\", \"data/cufs/256_256/sketches\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.3 图像拼接\n", + "\n", + 
"pix2pix模型的输入图像是原图和标签图的拼接图,为了让数据可以直接被模型处理,需要将数据集中的图片 photo 和对应的素描图 sketch 拼接为一张图片,效果如下:
\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'os' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32mC:\\Windows\\TEMP/ipykernel_4800/3968237396.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 22\u001b[0m \u001b[0mdst_path\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m\"data/cufs/256_256/train\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 23\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 24\u001b[1;33m \u001b[0mfiles\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mos\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimage_path\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 25\u001b[0m \u001b[0mfiles\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msort\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 26\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mNameError\u001b[0m: name 'os' is not defined" + ] + } + ], + "source": [ + "def concate_photo_sketch(sketch, image, dst):\n", + " '''\n", + " sketch: [string] 素描/ground truth 的路径\n", + " image: [string] 原照片 image 的路径\n", + " dst: [string] 拼接后图片的保存路径\n", + " '''\n", + " image = np.array(image)\n", + " sketch = np.array(sketch)\n", + " if sketch.ndim == 2:\n", + " sketch = np.expand_dims(sketch, 2)\n", + " sketch = np.concatenate([sketch, sketch, sketch], 2)\n", + "\n", + " sketch_photo = np.hstack([sketch, image])\n", + " sketch_photo = Image.fromarray(sketch_photo)\n", + " sketch_photo.save(dst)\n", + " print(\"%s SAVED !!!! \" % dst)\n", + "\n", + "\n", + "\n", + "image_path = \"data/cufs/256_256/photos\"\n", + "sketch_path = \"data/cufs/256_256/sketches\"\n", + "dst_path = \"data/cufs/256_256/train\"\n", + "\n", + "files = os.listdir(image_path)\n", + "files.sort()\n", + "\n", + "for file in files:\n", + " photo = Image.open(os.path.join(image_path, file))\n", + " sketch = Image.open(os.path.join(sketch_path, file)) \n", + " dst = os.path.join(dst_path, file)\n", + "\n", + " concate_photo_sketch(sketch, photo, dst)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.1 photo2sketch\n", + "\n", + "将数据集中的图片尺寸 resize 到 256 * 256 并且拼接 sketch 和 photo 后,以前 40 张图片作为测试集、后 148 张图片作为训练集对 pix2pix 模型进行训练。\n", + "\n", + "首先使用默认超参进行训练 200 个 epoch,训练过程中模型生成的 fake image,如下所示:\n", + "![](imgs/fk_img.jpg)\n", + "\n", + "对测试集的预测情况为:\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "由上可以看到,默认参数的效果还是可以继续改进的;所以我们接下来尝试了调节各种参数,并且考虑到数据集中的图片数量较少,还添加了 randomVerticalFlip、randomRotate、randomSharpness 三种数据增强方法。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "'''\n", + "仿照源代码的格式,编写了随机垂直翻转、随机旋转和随机锐化处理 3 中数据增强方法\n", + "'''\n", + "def sync_random_rotate(input_images, target_images):\n", + " '''\n", + " Randomly flip the input images and the target images.\n", + " '''\n", + " seed = np.random.randint(0, 2000000000)\n", + "\n", + " mindspore.set_seed(seed)\n", + " op = C.RandomRotation(30)\n", + " out_input = op(input_images)\n", + "\n", + " mindspore.set_seed(seed)\n", + " op = C.RandomRotation(30)\n", + " out_target = op(target_images)\n", + " \n", + " return out_input, out_target\n", + "\n", + "def sync_random_Vertical_Flip(input_images, target_images):\n", + " '''\n", + " Randomly flip the input images and the target images.\n", + " '''\n", + " seed = np.random.randint(0, 2000000000)\n", + "\n", + " mindspore.set_seed(seed)\n", + " op = C.RandomVerticalFlip(prob=0.5)\n", + " out_input = op(input_images)\n", + "\n", + " mindspore.set_seed(seed)\n", + " op = C.RandomVerticalFlip(prob=0.5)\n", + " out_target = op(target_images)\n", + " \n", + " return out_input, out_target\n", + "\n", + "\n", + "def sync_random_Sharpness(input_images, target_images):\n", + " '''\n", + " Randomly flip the input images and the target images.\n", + " '''\n", + " seed = np.random.randint(0, 2000000000)\n", + "\n", + " mindspore.set_seed(seed)\n", + " op = C.RandomSharpness((0.9, 1.5))\n", + " out_input = op(input_images)\n", + "\n", + " mindspore.set_seed(seed)\n", + " op = C.RandomSharpness((0.9, 1.5))\n", + " out_target = op(target_images)\n", + " \n", + " return out_input, out_target\n", + "\n", + "\n", + "\n", + "def create_train_dataset(dataset):\n", + " '''\n", + " Create train dataset.\n", + " '''\n", + "\n", + " mean = [0.5 * 255] * 3\n", + " std = [0.5 * 255] * 3\n", + "\n", + " trans = [\n", + " C.Normalize(mean=mean, std=std),\n", + " C.HWC2CHW()\n", + " ]\n", + "\n", + " train_ds = de.GeneratorDataset(dataset, column_names=[\"input_images\", \"target_images\"], shuffle=False)\n", + " train_ds = train_ds.map(operations=[sync_random_Horizontal_Flip, sync_random_rotate, sync_random_Sharpness, sync_random_Vertical_Flip], input_columns=[\"input_images\", \"target_images\"])\n", + "# train_ds = train_ds.map(operations=[sync_random_Horizontal_Flip], input_columns=[\"input_images\", \"target_images\"])\n", + "\n", + " train_ds = train_ds.map(operations=trans, input_columns=[\"input_images\"])\n", + " train_ds = train_ds.map(operations=trans, input_columns=[\"target_images\"])\n", + "\n", + " train_ds = train_ds.batch(1, drop_remainder=True)\n", + "\n", + " return train_ds" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "进行上述调整后,模型的预测输出如下:可以看到,这时的效果已经比较让人满意了。\n", + "\n", + "![](imgs/res1.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.4 sketch2photo 尝试\n", + "\n", + "和 photo2sketch 相比,这里的数据处理方式唯一的不同是:在拼接 photo 和 sketch 时,photo 在左, sketch 在右。出于时间的限制,这里只使用了数据增强来优化模型的输出效果,同时,由sketch转photo要比photo转sketch难度大,所以这个模型的输出效果没有前一个那么逼真。之后需尝试更多的调参,甚至调整网络模型结构来进一步优化模型的预测结果。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def concate_photo_sketch(left, right, 
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "After these adjustments, the model's predictions are shown below. The results are now fairly satisfying.\n",
+    "\n",
+    "![](imgs/res1.png)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 3.2 A sketch2photo Attempt\n",
+    "\n",
+    "Compared with photo2sketch, the only difference in data preparation is that when concatenating photo and sketch, the photo goes on the left and the sketch on the right. Because of time constraints, we only used data augmentation to improve this model's outputs; moreover, sketch-to-photo is harder than photo-to-sketch, so the outputs are not as realistic as the previous model's. More hyperparameter tuning, and possibly changes to the network architecture, would be needed to further improve the predictions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def concate_photo_sketch(left, right, dst):\n",
+    "    '''\n",
+    "    left:  [PIL.Image] image placed on the left of the concatenated pair\n",
+    "    right: [PIL.Image] image placed on the right\n",
+    "    dst:   [string] path where the concatenated image is saved\n",
+    "    '''\n",
+    "    left = np.array(left)\n",
+    "    right = np.array(right)\n",
+    "    # Replicate grayscale images to 3 channels so both halves match.\n",
+    "    if left.ndim == 2:\n",
+    "        left = np.expand_dims(left, 2)\n",
+    "        left = np.concatenate([left, left, left], 2)\n",
+    "    if right.ndim == 2:\n",
+    "        right = np.expand_dims(right, 2)\n",
+    "        right = np.concatenate([right, right, right], 2)\n",
+    "\n",
+    "    print(\"left.shape  = \", left.shape)\n",
+    "    print(\"right.shape = \", right.shape)\n",
+    "\n",
+    "    pair = np.hstack([left, right])\n",
+    "    pair = Image.fromarray(pair)\n",
+    "    pair.save(dst)\n",
+    "    print(\"%s saved\" % dst)\n",
+    "\n",
+    "\n",
+    "image_path = \"data/cufs/256_256/photos\"\n",
+    "sketch_path = \"data/cufs/256_256/sketches\"\n",
+    "dst_path = \"data/cufs/256_256/train2\"\n",
+    "\n",
+    "files = os.listdir(image_path)\n",
+    "files.sort()\n",
+    "\n",
+    "for file in files:\n",
+    "    photo = Image.open(os.path.join(image_path, file))\n",
+    "    sketch = Image.open(os.path.join(sketch_path, file))\n",
+    "    dst = os.path.join(dst_path, file)\n",
+    "\n",
+    "    # For sketch2photo the photo goes on the left and the sketch on the right.\n",
+    "    concate_photo_sketch(photo, sketch, dst)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Model predictions:\n",
+    "\n",
+    "![](imgs/res2.png)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## III. Case 2: Cross-View Aerial-to-Ground Image Generation\n",
+    "\n",
+    "\n",
+    "### 1. The Dayton Dataset\n",
+    "\n",
+    "The Dayton dataset is used for ground-to-aerial (or aerial-to-ground) image translation, i.e. cross-view image synthesis. It contains street-view images and the corresponding bird's-eye views: 76,048 images in total, with a 55,000 / 21,048 train/test split. The images in the original dataset have a resolution of 354 × 354.\n",
+    "\n",
+    "Because of compute and time limits, we used only 1,000 images for training and 1,000 for testing.\n",
+    "\n",
+    "![](imgs/dayton.jpg)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. Parameter Settings\n",
+    "\n",
+    "After preprocessing similar to the previous case, the dataset is fed into the model for training. We tuned the hyperparameters to get the model as close to its best performance as we could.\n",
+    "\n",
+    "![](imgs/params.png)\n",
+    "\n",
+    "With the settings above, the predictions were found to be fairly good and stable.\n",
+    "\n",
+    "![](imgs/epoch.png)\n",
+    "\n",
+    "As the number of epochs grows, the model's predictions keep improving and eventually stabilize."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 3. Prediction Results\n",
+    "\n",
+    "![](imgs/ground2aerial.png)\n",
+    "\n",
+    "**Ground to Aerial**\n",
+    "\n",
+    "![](imgs/aerial2ground.png)\n",
+    "\n",
+    "**Aerial to Ground**"
+   ]
+  },
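+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The comparison above is purely visual. As a minimal quantitative check, PSNR and SSIM between a generated image and its ground truth could be computed as sketched below. This is our own illustration: it assumes scikit-image is installed (the original environment may not include it), and `pred.png` / `gt.png` are hypothetical file names."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch only: assumes scikit-image >= 0.19 (for the channel_axis argument) and\n",
+    "# that pred.png / gt.png are a generated image and its ground truth (hypothetical names).\n",
+    "import numpy as np\n",
+    "from PIL import Image\n",
+    "from skimage.metrics import peak_signal_noise_ratio, structural_similarity\n",
+    "\n",
+    "pred = np.array(Image.open(\"pred.png\").convert(\"RGB\"))\n",
+    "gt = np.array(Image.open(\"gt.png\").convert(\"RGB\"))\n",
+    "\n",
+    "psnr = peak_signal_noise_ratio(gt, pred, data_range=255)\n",
+    "ssim = structural_similarity(gt, pred, channel_axis=2, data_range=255)\n",
+    "print(\"PSNR = %.2f dB, SSIM = %.4f\" % (psnr, ssim))"
+   ]
+  },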
" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Cross-View 模型分析\n", + "\n", + "\n", + "### 4.1 Cross-View 模型效果分析\n", + "总结:实现地面与航空图像的相互翻译生成功能\n", + "\n", + "\t生成的地面照片和航空照片,道路方向颜色,天空颜色面积,路边建筑高度颜色基本一致\n", + " \n", + "\t随着训练epoch的增加,生成器生成的图像更接近于真实图像,并趋于稳定\n", + "\n", + "### 4.2 Cross-View 模型问题分析\n", + "\n", + "\n", + "1.生成图像仍存在模糊失真的问题\n", + "2.部分图像生成与真值差距较大\n", + "3.没有引入量化评估标准:Inception Score,Accuracy,KL(model data),SSIM, PSNR and Sharpness Difference,FID Score\n", + "\n", + "**原因分析**:\n", + "1.相对于该任务的难度,训练集数量较少,仅1000张\n", + "2.pix2pix baseline对于cross-view任务,源域和目标域图像像素几乎没有重合,仅采用一对生成判别器,效果较差\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "通过以pix2pix为baseline发表的论文对比,我们发现训练模型与发表论文展示的pix2pix baseline效果基本一致。所以推测部分图片效果较差不是因为超参的限制,而是网络模型本身的限制。\n", + "\n", + "![](imgs/compare.png)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**提升空间**\n", + "\n", + "1. 添加多个生成器,级联语义信息如分割label,引导生成器生成更逼真的图像\n", + "\n", + "2. 多通道自注意力选择模型 SelectionGAN\n", + "\n", + "3. 建立cross-view,地面航空图像中,对应坐标匹配的单应性矩阵(homograpy)\n", + "\n", + "4. 增加训练集数量和训练epoch\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 四、Mindspore Pix2Pix 代码贡献\n", + "\n", + "1. 添加多个数据集处理脚本\n", + "\n", + "2. 实现auto-resume 功能,保存生成器判别器权重模型,中断训练加载模型重启训练,目前设置了每 20 个 epoch 自动保存一次生成器和判别器的ckpt文件\n", + "\n", + "3. 生成多个数据集config;实验过程中,我们尝试了多种不同的数据集,编写了相对应的 config.yaml 文件\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "vscode": { + "interpreter": { + "hash": "199b987b6994316a3c6b8731a85f82b0addcbcf4ad2e6bf53475c7085cd06d2d" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}