加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
train_batch.py 835 Bytes
一键复制 编辑 原始数据 按行查看 历史
zhijiezhong 提交于 2021-10-06 14:26 . first commit
import torch
import torch.nn as nn
def train_batch(crnn, train_iter, optimizer, criterion, device, converter):
acc_num = 0
images, labels = train_iter.next()
y = labels
images = images.to(device)
batch_size = images.size(0)
text, length = converter.encode(labels)
preds = crnn(images)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
loss = criterion(preds, text, preds_size, length)
y_hat = nn.functional.softmax(preds, 2).argmax(2).view(preds.size(0), -1)
y_hat = torch.transpose(y_hat, 1, 0)
y_hat = [converter.decode(i, torch.IntTensor([y_hat.size(1)])) for i in y_hat]
for txt, target in zip(y, y_hat):
if txt == target:
acc_num += 1
crnn.zero_grad()
loss.backward()
optimizer.step()
return loss.item(), batch_size, acc_num
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化