加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 2.44 KB
一键复制 编辑 原始数据 按行查看 历史
Chenming 提交于 2022-03-04 23:11 . v2
from model.unet_model import UNet
from utils.dataset import ISBI_Loader
from torch import optim
import torch.nn as nn
import torch
from tqdm import tqdm
def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
# 加载训练集
isbi_dataset = ISBI_Loader(data_path)
per_epoch_num = len(isbi_dataset) / batch_size
train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
batch_size=batch_size,
shuffle=True)
# 定义RMSprop算法
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
# 定义Loss算法
criterion = nn.BCEWithLogitsLoss()
# best_loss统计,初始化为正无穷
best_loss = float('inf')
# 训练epochs次
with tqdm(total=epochs*per_epoch_num) as pbar:
for epoch in range(epochs):
# 训练模式
net.train()
# 按照batch_size开始训练
for image, label in train_loader:
optimizer.zero_grad()
# 将数据拷贝到device中
image = image.to(device=device, dtype=torch.float32)
label = label.to(device=device, dtype=torch.float32)
# 使用网络参数,输出预测结果
pred = net(image)
# 计算loss
loss = criterion(pred, label)
# print('{}/{}:Loss/train'.format(epoch + 1, epochs), loss.item())
# 保存loss值最小的网络参数
if loss < best_loss:
best_loss = loss
torch.save(net.state_dict(), 'best_model.pth')
# 更新参数
loss.backward()
optimizer.step()
pbar.update(1)
if __name__ == "__main__":
# 选择设备,有cuda用cuda,没有就用cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载网络,图片单通道1,分类为1。
net = UNet(n_channels=1, n_classes=1) # todo edit input_channels n_classes
# 将网络拷贝到deivce中
net.to(device=device)
# 指定训练集地址,开始训练
data_path = "C:/Users/chenmingsong/Desktop/unetnnn/skin" # todo 修改为你本地的数据集位置
print("进度条出现卡着不动不是程序问题,是他正在计算,请耐心等待")
train_net(net, device, data_path, epochs=40, batch_size=1)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化