代码拉取完成,页面将自动刷新
同步操作将从 luoyongcoder/unet_42 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
from model.unet_model import UNet
from utils.dataset import ISBI_Loader
from torch import optim
import torch.nn as nn
import torch
from tqdm import tqdm
def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
# 加载训练集
isbi_dataset = ISBI_Loader(data_path)
per_epoch_num = len(isbi_dataset) / batch_size
train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
batch_size=batch_size,
shuffle=True)
# 定义RMSprop算法
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
# 定义Loss算法
criterion = nn.BCEWithLogitsLoss()
# best_loss统计,初始化为正无穷
best_loss = float('inf')
# 训练epochs次
with tqdm(total=epochs*per_epoch_num) as pbar:
for epoch in range(epochs):
# 训练模式
net.train()
# 按照batch_size开始训练
for image, label in train_loader:
optimizer.zero_grad()
# 将数据拷贝到device中
image = image.to(device=device, dtype=torch.float32)
label = label.to(device=device, dtype=torch.float32)
# 使用网络参数,输出预测结果
pred = net(image)
# 计算loss
loss = criterion(pred, label)
# print('{}/{}:Loss/train'.format(epoch + 1, epochs), loss.item())
# 保存loss值最小的网络参数
if loss < best_loss:
best_loss = loss
torch.save(net.state_dict(), 'best_model.pth')
# 更新参数
loss.backward()
optimizer.step()
pbar.update(1)
if __name__ == "__main__":
# 选择设备,有cuda用cuda,没有就用cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载网络,图片单通道1,分类为1。
net = UNet(n_channels=1, n_classes=1) # todo edit input_channels n_classes
# 将网络拷贝到deivce中
net.to(device=device)
# 指定训练集地址,开始训练
data_path = "C:/Users/chenmingsong/Desktop/unetnnn/skin" # todo 修改为你本地的数据集位置
print("进度条出现卡着不动不是程序问题,是他正在计算,请耐心等待")
train_net(net, device, data_path, epochs=40, batch_size=1)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。