加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
taught.py 2.59 KB
一键复制 编辑 原始数据 按行查看 历史
tacom 提交于 2022-07-25 15:23 . [init] code for board init
from StillGAN.models import create_model
from StillGAN.data import create_dataset
from StillGAN.options.test_options import TestOptions
from StillGAN.util.visualizer import save_images
from StillGAN.util import html
from StillGAN.util.util import *
from StillGAN.models.networks import ResUNet
from torch import nn
import functools
from torch.autograd import Variable
from torchvision import transforms
import cv2
# Utility helpers
class Identity(nn.Module):
    """Pass-through module whose output is its input, unchanged."""

    def forward(self, x):
        """Return ``x`` itself; no copy is made and nothing is computed."""
        return x
def get_norm_layer(norm_type='instance'):
    """Build a factory for the requested normalization layer.

    Parameters:
        norm_type (str) -- which normalization to use: batch | instance | none

    Returns:
        A callable that creates the layer:
        * 'batch'    -- nn.BatchNorm2d with learnable affine parameters and
                        tracked running statistics.
        * 'instance' -- nn.InstanceNorm2d with neither affine parameters nor
                        running statistics.
        * 'none'     -- a function that ignores its argument and returns an
                        Identity module.

    Raises:
        NotImplementedError: if ``norm_type`` is not one of the names above.
    """
    if norm_type == 'batch':
        return functools.partial(
            nn.BatchNorm2d, affine=True, track_running_stats=True
        )

    if norm_type == 'instance':
        return functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )

    if norm_type == 'none':
        # The argument (normally the channel count) is accepted only so the
        # factory has the same call shape as the two partials above.
        def norm_layer(x):
            return Identity()

        return norm_layer

    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
# Preprocessing for the enhancement network
def pre_still(img):
    """Turn a PIL image into the batched tensor fed to the generator.

    Parameters:
        img (PIL.Image.Image) -- source image in any PIL mode.

    Returns:
        A float tensor of shape (1, 3, 512, 512): the image is resized to
        512x512 with bicubic interpolation, converted to a tensor, and
        normalized per channel with mean 0.5 and std 0.5.
    """
    # Normalize below is hard-wired to three channels, so RGBA / grayscale /
    # palette images (common for PNG files) must be converted first or the
    # normalization step fails on a channel-count mismatch.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    # The former RandomCrop(512) step is omitted: after resizing to exactly
    # 512x512 the only possible 512x512 crop is the whole image.
    trans = transforms.Compose([
        transforms.Resize([512, 512], Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Add the leading batch dimension. The tensor produced by ToTensor does
    # not require grad, so no Variable wrapper is needed.
    return trans(img).unsqueeze(0).float()
# StillGAN enhancement (Dr. Ma's model)
def stillgan(img_path):
    """Enhance the image at ``img_path`` with the pretrained StillGAN generator.

    Parameters:
        img_path (str) -- path of the image file to enhance.

    Returns:
        The enhanced image as an array in OpenCV channel order (BGR), resized
        back to the width and height of the input image, ready to be passed
        to ``cv2.imwrite``.
    """
    # Open inside a context manager so the file handle is released as soon as
    # the pixels have been read.
    with Image.open(img_path) as current:
        # PIL reports size as (width, height).
        width, height = current.size
        # convert() loads the pixel data and guarantees the three channels
        # that the generator (3 input channels) expects.
        pic = pre_still(current.convert('RGB'))

    model = ResUNet(3, 3, 64, norm_layer=get_norm_layer())
    # NOTE(security): torch.load unpickles the checkpoint; only load weight
    # files from a trusted source.
    net = torch.load('StillGAN/checkpoints/isee_csigan/120_net_G_A.pth', map_location=torch.device('cpu'))
    model.load_state_dict(net)
    # NOTE(review): the model is left in its default (training) mode, as
    # before; confirm whether model.eval() is wanted for this network.

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        runned = model(pic)

    # The result of tensor2im is treated as an RGB image (as the original
    # r/g/b split did); OpenCV expects BGR, so swap the channel order.
    image = tensor2im(runned)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # cv2.resize takes the target size as (width, height).
    return cv2.resize(image, (width, height), interpolation=cv2.INTER_CUBIC)
if __name__ == '__main__':
    import sys

    # Input image path: first command-line argument, falling back to the
    # original hard-coded location when none is given.
    src_path = sys.argv[1] if len(sys.argv) > 1 else '/Users/xfdw/Desktop/project/back/tmp/model1/raw/1.png'
    # Output path for the enhanced image: second command-line argument, with
    # the original hard-coded location as the default.
    dst_path = sys.argv[2] if len(sys.argv) > 2 else '/Users/xfdw/Desktop/re/a.png'

    # img is the enhanced image
    img = stillgan(src_path)
    print(1)

    # cv2.imwrite reports failure (e.g. missing output directory) by
    # returning False rather than raising, so check it explicitly.
    if not cv2.imwrite(dst_path, img):
        raise SystemExit('failed to write enhanced image to %s' % dst_path)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化