加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
Batch.py 2.12 KB
一键复制 编辑 原始数据 按行查看 历史
fengquanli 提交于 2021-01-06 05:08 . original
import torch
from torchtext import data
import numpy as np
from torch.autograd import Variable
def nopeak_mask(size, device):
    """Build the causal ("no peeking ahead") attention mask.

    Args:
        size: Sequence length; the mask covers ``size`` x ``size`` positions.
        device: Target device for the mask. Anything accepted by
            ``Tensor.to`` works (``"cpu"``, a ``torch.device``, a CUDA
            index, ...). ``None`` selects the current CUDA device when CUDA
            is available and the CPU otherwise.

    Returns:
        A boolean tensor of shape ``(1, size, size)`` whose entry
        ``[0, i, j]`` is ``True`` exactly when ``j <= i``, i.e. the lower
        triangle including the diagonal.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    positions = torch.arange(size)
    # Row index i (column vector) against column index j (row vector):
    # position i may look at position j only when j <= i.
    allowed = positions.unsqueeze(0) <= positions.unsqueeze(1)
    # .to() instead of .cuda() so the function also works without a GPU.
    return allowed.unsqueeze(0).to(device)
def create_masks(src, trg, device, pad_idx=-1):
    """Create the padding mask for ``src`` and the combined mask for ``trg``.

    Args:
        src: Tensor of source token ids.
        trg: Tensor of target token ids, or ``None`` when no target mask is
            needed. ``trg.size(1)`` is used as the sequence length.
        device: Device the target mask is moved to and the causal mask is
            built on. ``None`` leaves the target padding mask where ``trg``
            lives.
        pad_idx: Token id treated as padding (defaults to ``-1``).

    Returns:
        ``(src_mask, trg_mask)``. ``src_mask`` is ``src != pad_idx`` with an
        extra axis inserted at position ``-2``; it stays on ``src``'s device.
        ``trg_mask`` is the target padding mask AND-ed with the causal mask
        from ``nopeak_mask``, or ``None`` when ``trg`` is ``None``.
    """
    src_mask = (src != pad_idx).unsqueeze(-2)
    if trg is None:
        return src_mask, None

    trg_mask = (trg != pad_idx).unsqueeze(-2)
    if device is not None:
        # Tensor device moves are not in-place: the result has to be
        # re-bound, otherwise the transfer is silently thrown away.
        trg_mask = trg_mask.to(device)
    seq_len = trg.size(1)
    np_mask = nopeak_mask(seq_len, device)
    return src_mask, trg_mask & np_mask
# patch on Torchtext's batching process that makes it more efficient
# from http://nlp.seas.harvard.edu/2018/04/03/attention.html#position-wise-feed-forward-networks
class MyIterator(data.Iterator):
    """torchtext iterator with a custom batching strategy.

    In training mode examples are read in chunks of ``batch_size * 100``,
    each chunk is sorted by ``sort_key`` and cut into batches using
    ``batch_size_fn``, and the batches of a chunk are passed through
    ``random_shuffler`` before being yielded. Outside training mode the data
    is batched directly and each batch is sorted by ``sort_key``.
    """

    def create_batches(self):
        """Populate ``self.batches`` for the current pass over the data."""
        if not self.train:
            # Evaluation: an eager list of batches, each sorted by sort_key.
            self.batches = [
                sorted(chunk, key=self.sort_key)
                for chunk in data.batch(
                    self.data(), self.batch_size, self.batch_size_fn
                )
            ]
            return

        def shuffled_buckets(examples, shuffle):
            # Lazily walk the data one large bucket at a time.
            for bucket in data.batch(examples, self.batch_size * 100):
                ordered = sorted(bucket, key=self.sort_key)
                minibatches = list(
                    data.batch(ordered, self.batch_size, self.batch_size_fn)
                )
                yield from shuffle(minibatches)

        # self.data() is evaluated here, not when the generator is consumed.
        self.batches = shuffled_buckets(self.data(), self.random_shuffler)
# A ``global`` statement at module scope has no effect; the two counters are
# actually created by the first ``batch_size_fn`` call made with ``count == 1``.
global max_src_in_batch, max_tgt_in_batch


def batch_size_fn(new, count, sofar):
    """Return the padded size of the batch after adding ``new``.

    The longest source length and the longest target length (plus 2) seen
    since the last call with ``count == 1`` are tracked in module-level
    globals. The result is ``count`` times the larger of the two, i.e. the
    element count of the batch once every example is padded to the longest.

    Args:
        new: Example being added; must expose ``src`` and ``trg`` with a length.
        count: Number of examples in the batch including ``new``; ``1``
            resets the running maxima.
        sofar: Unused; accepted for the batching callback signature.
    """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        # First example of a fresh batch: restart both running maxima.
        max_src_in_batch = max_tgt_in_batch = 0

    src_len = len(new.src)
    if src_len > max_src_in_batch:
        max_src_in_batch = src_len

    tgt_len = len(new.trg) + 2
    if tgt_len > max_tgt_in_batch:
        max_tgt_in_batch = tgt_len

    padded_src = count * max_src_in_batch
    padded_tgt = count * max_tgt_in_batch
    return padded_src if padded_src >= padded_tgt else padded_tgt
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化