# (Web-page residue from the code host: "Code pull complete; the page will refresh automatically.")
from StillGAN.models import create_model
from StillGAN.models.networks import ResnetGenerator
from StillGAN.data import create_dataset
from StillGAN.options.test_options import TestOptions
from StillGAN.util.visualizer import save_images
from StillGAN.util import html
from StillGAN.util.util import *
from StillGAN.models.networks import ResUNet
from torch import nn
import functools
from torch.autograd import Variable
from torchvision import transforms
import cv2
# Utility module.
class Identity(nn.Module):
    """Pass-through module whose ``forward`` returns its argument unchanged.

    ``get_norm_layer('none')`` in this file uses it as a stand-in
    "no normalization" layer.
    """

    def forward(self, x):
        # Hand back the very same object: no copy, no computation.
        passthrough = x
        return passthrough
def get_norm_layer(norm_type='instance'):
    """Return a factory for the requested normalization layer.

    Parameters:
        norm_type (str) -- which normalization to use: batch | instance | none

    Returns:
        A callable taking the channel count:
          * 'batch'    -> functools.partial of nn.BatchNorm2d with learnable
                          affine parameters and tracked running statistics;
          * 'instance' -> functools.partial of nn.InstanceNorm2d with no affine
                          parameters and no running statistics;
          * 'none'     -> a function that ignores its argument and returns an
                          ``Identity`` module.

    Raises:
        NotImplementedError -- if ``norm_type`` is not one of the names above.
    """
    if norm_type == 'batch':
        return functools.partial(
            nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False)
    if norm_type == 'none':
        # The channel count is irrelevant for a pass-through layer.
        return lambda x: Identity()
    raise NotImplementedError(
        'normalization layer [%s] is not found' % norm_type)
# Preprocessing for StillGAN enhancement.
def pre_still(img, size=518):
    """Turn an input image into a normalized 4-D batch tensor for the generator.

    Parameters:
        img         -- the image to preprocess; presumably a PIL image (the
                       ``__main__`` usage passes ``Image.open(...)``), since it
                       goes through ``transforms.Resize`` and ``ToTensor``.
        size (int)  -- side length of the square the image is resized to with
                       bicubic interpolation. Defaults to 518, the value that
                       was previously hard-coded.

    Returns:
        A float32 tensor of shape (1, C, size, size), where C is the channel
        count ``ToTensor`` produces for the image's mode. For 8-bit images,
        ``ToTensor`` yields values in [0, 1], which ``Normalize`` with
        mean=std=0.5 maps to [-1, 1].
    """
    pipeline = transforms.Compose([
        transforms.Resize([size, size], Image.BICUBIC),
        transforms.ToTensor(),
        # Real 1-tuples (``(0.5)`` is just a float). A single-element
        # mean/std broadcasts over every channel in current torchvision, and
        # older versions that zip over mean/std also accept it.
        transforms.Normalize((0.5,), (0.5,)),
    ])
    tensor = pipeline(img)
    # Add the batch dimension. ``Variable`` is a deprecated no-op wrapper
    # since PyTorch 0.4, and a freshly built tensor does not require grad.
    return tensor.unsqueeze(0).float()
from seg_system import ApplicationConfig
# StillGAN enhancement (original comment: "Ma Bo's enhancement").
def stillgan_model(img, model1):
    """Enhance one image with a StillGAN ResNet generator and return it in BGR order.

    Parameters:
        img          -- image to enhance, presumably a PIL image. It is
                        preprocessed by ``pre_still``. Multi-band images
                        (e.g. RGB) are first converted to single-band 'L'.
        model1 (str) -- path to the generator weights (a state_dict) loaded
                        with ``torch.load``.

    Returns:
        The generator output converted by ``tensor2im`` and reordered from RGB
        to BGR with OpenCV, suitable for ``cv2.imwrite``.
    """
    device = ApplicationConfig.SystemConfig.DEVICE

    # The generator below is built as ResnetGenerator(1, 1, ...). Assuming the
    # pix2pix/CycleGAN signature (input_nc, output_nc, ngf) -- confirm -- it
    # accepts exactly one input channel, so a multi-band image would fail in
    # the first convolution. Reduce such images to one band; single-band
    # images pass through unchanged.
    if hasattr(img, 'getbands') and len(img.getbands()) > 1:
        img = img.convert('L')
    pic = pre_still(img).to(device)

    model = ResnetGenerator(1, 1, 64, norm_layer=get_norm_layer(),
                            use_dropout=False, n_blocks=9)
    # map_location='cpu' lets a checkpoint saved on a GPU load on any host;
    # the model is moved to the configured device afterwards.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model1, map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.to(device)
    # Inference mode. With instance norm and no dropout the output is the
    # same either way, but being explicit guards against future changes.
    model.eval()

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        output = model(pic)

    image = tensor2im(output)
    # tensor2im's result is treated as RGB here; OpenCV writers expect BGR.
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
if __name__ == '__main__':
    import os

    # Manual smoke test: enhance one sample image and write the result.
    src_path = './tmp/multi_batch_tmp/tacom/12345/normal/0/batch_1_1.png'  # input image
    weights_path = 'StillGAN/checkpoints/isee_csigan/75_net_G_A.pth'  # generator weights
    out_path = 'tmp/seg/enh.png'  # enhanced output image

    # Open inside a context manager so the file handle is released.
    with Image.open(src_path) as src:
        img = stillgan_model(src, weights_path)

    # cv2.imwrite returns False instead of raising when it cannot write
    # (e.g. the target directory is missing), so create it and check.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    if not cv2.imwrite(out_path, img):
        raise SystemExit('failed to write enhanced image to %s' % out_path)
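# NOTE: the code host's content filter withheld some content here (possibly the rest of the file).
# Its notice read: "Content here may be unsuitable for display and is not shown. You can self-check
# and edit it with the editing tools. If you confirm it contains no inappropriate language, pure
# advertising/traffic diversion, violence, vulgar or pornographic material, infringement, piracy,
# false or worthless content, or anything violating applicable national laws and regulations, you
# can click submit to appeal and it will be handled as soon as possible."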