代码拉取完成,页面将自动刷新
import torch
from torchvision import transforms
from PIL import ImageOps
import numpy as np
import os
import glob
import cv2
def segment(nerve_img=None):
    """
    API for corneal nerve segmentation.

    Thin wrapper that forwards ``nerve_img`` to ``predict`` and returns its
    result unchanged.

    :param nerve_img: the image to segment. ``predict`` only handles image
        objects (it calls ``.resize`` on the argument), not directory paths;
        passing ``None`` makes ``predict`` raise ``ValueError``.
    :return: the mask returned by ``predict`` -- a ``uint8`` numpy array whose
        pixels are 0 or 255.
    """
    return predict(nerve_img=nerve_img)
def ReScaleSize(image, re_size=512):
    """
    Pad an image to a centred square, then resize it.

    The shorter side is padded on both ends with fill value 0 so that the
    image becomes ``max(w, h)`` pixels square; when the padding amount is odd,
    the extra pixel goes to the right/bottom edge. The square image is then
    resized to ``re_size`` x ``re_size``.

    :param image: PIL image to rescale.
    :param re_size: side length, in pixels, of the returned square image.
    :return: the padded and resized image.
    """
    width, height = image.size
    side = max(width, height)
    extra_w = side - width
    extra_h = side - height
    left, top = extra_w // 2, extra_h // 2
    border = (left, top, extra_w - left, extra_h - top)
    squared = ImageOps.expand(image, border, fill=0)
    return squared.resize((re_size, re_size))
def load_net(model_path="./trained_model/FinerCSNet.pkl"):
    """
    Load the pickled segmentation network onto the CPU.

    :param model_path: path of a file written with ``torch.save(model)`` that
        stores the whole model object. The default is a relative path, so it
        is resolved against the current working directory.
    :return: the loaded ``torch.nn.Module``. If the checkpoint holds a
        ``torch.nn.DataParallel`` wrapper, the wrapped module is returned.
    """
    import inspect

    # NOTE: the checkpoint is a full pickled module, so unpickling needs the
    # model's class importable under its original module path (presumably
    # ``model.finerseg.FinerCSNet``) and executes pickle code -- only load
    # trusted files.
    load_kwargs = {'map_location': torch.device('cpu')}
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # a whole nn.Module. Older releases (< 1.13) do not accept the keyword.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    net = torch.load(model_path, **load_kwargs)
    # Models trained with DataParallel are saved wrapped; unwrap for CPU use.
    if isinstance(net, torch.nn.DataParallel):
        net = net.module
    return net
def save_prediction(pred, save_path, filename=''):
    """
    Binarise a network prediction and write it as a PNG file.

    The prediction goes through ``return_prediction``: scaled by 255,
    thresholded at 127 into a 0/255 ``uint8`` mask.

    :param pred: prediction tensor with 4 dimensions; presumably
        (N, C, H, W) with N == 1 and C == 1 so the mask is 2-D -- TODO confirm.
    :param save_path: existing directory to write the PNG into.
    :param filename: file name without extension; ``'.png'`` is appended.
    :return: the path of the written file.
    :raises OSError: if OpenCV reports that the image could not be written.
    """
    # for MSELoss()
    mask = return_prediction(pred)
    out_path = os.path.join(save_path, filename + '.png')
    # cv2.imwrite needs the image itself (it was previously omitted, which
    # always raised) and returns False instead of raising on failure.
    if not cv2.imwrite(out_path, mask):
        raise OSError("Failed to write prediction to %s" % out_path)
    return out_path
def return_prediction(pred):
    """
    Convert a network prediction tensor into a binary ``uint8`` mask.

    The tensor is permuted channels-last, its leading and trailing singleton
    dimensions are squeezed away, and values are scaled by 255. Values below
    127 then become 0 and values at or above 127 become 255.

    :param pred: prediction tensor with 4 dimensions; presumably
        (N, C, H, W) with N == 1 and C == 1 so the mask is 2-D -- TODO confirm.
    :return: numpy ``uint8`` array holding the thresholded mask.
    """
    # for MSELoss()
    channels_last = pred.permute(0, 2, 3, 1).contiguous()
    plane = channels_last.squeeze(0).squeeze(-1)
    scaled = plane.detach().cpu().numpy() * 255
    below = scaled < 127
    at_or_above = scaled >= 127
    scaled[below] = 0
    scaled[at_or_above] = 255
    return scaled.astype(np.uint8)
def load_nerve(path, pattern='*'):
    """
    List the image files contained in a directory.

    :param path: directory to scan.
    :param pattern: glob pattern matched against the entries of ``path``
        (default ``'*'``: every entry whose name does not start with a dot).
    :return: sorted list of paths (``path`` joined with each matching name)
        that are regular files; sub-directories are skipped.
    :raises ValueError: if the directory is empty, or if no regular file in it
        matches ``pattern``.
    :raises FileNotFoundError: if ``path`` does not exist (from ``os.listdir``).
    """
    if not os.listdir(path):
        raise ValueError("The directory is empty.")
    # Escape the directory part so names such as "scan[1]" are matched
    # literally instead of being read as glob character classes.
    candidates = glob.glob(os.path.join(glob.escape(path), pattern))
    # glob returns entries in arbitrary order; sort for reproducible runs.
    test_images = sorted(c for c in candidates if os.path.isfile(c))
    if not test_images:
        raise ValueError("The directory contains no matching files.")
    return test_images
def predict(nerve_img=None, net=None, size=384):
    """
    Run the segmentation network on one image and return its binary mask.

    :param nerve_img: image object to segment. It must support
        ``.resize((w, h))`` and be accepted by ``transforms.ToTensor``
        (presumably a PIL image).
    :param net: an already-loaded network. When ``None`` (the default), the
        model is loaded with ``load_net()`` on every call; pass a preloaded
        model to avoid re-reading the checkpoint per image. The network is
        switched to eval mode.
    :param size: side length the image is resized to before inference
        (default 384, the value used originally).
    :return: ``uint8`` numpy mask of 0/255 values from ``return_prediction``.
        It is presumably ``size`` x ``size`` if the network preserves spatial
        resolution; it is not resized back to the input's size.
    :raises ValueError: if ``nerve_img`` is ``None``.
    """
    # Validate the input before paying the cost of loading the model.
    if nerve_img is None:
        raise ValueError(
            "Segmentation object should be directory path or an image")
    if net is None:
        net = load_net()
    to_tensor = transforms.Compose([
        transforms.ToTensor()
    ])
    net.eval()
    with torch.no_grad():
        # resize() already yields a size x size image, so a crop to the same
        # box would be a no-op.
        image = nerve_img.resize((size, size))
        image = to_tensor(image).unsqueeze(0)  # add the batch dimension
        # The network returns a (coarse, fine) pair; only the fine map is used.
        _coarse, fine = net(image)
    return return_prediction(fine)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。