代码拉取完成,页面将自动刷新
from tools import *
from bdtime import tt
from bdtime import show_json, show_ls
from pprint import pprint
# Module-level setup. Every statement below runs at import time.
# Truthy flag; main() prints a few decoded sample sequences when it is set.
debug = 1
# Disabled debug snippets kept for reference (`tokenizer` presumably comes from the `tools` star import):
# if debug: print([tokenizer.decode(i).strip("P").strip("E").strip("S") for i in tokenizer.get_batch_data(prefix=False)[1]][:10])
# if debug: print([tokenizer.decode(i) for i in tokenizer.get_batch_data(prefix=False)[1]][:10])
# label, input_ids, attention_mask = tokenizer.get_batch_data(prefix=False)
from tools import get_data
# The return value of this call is discarded; presumably it is kept for a side
# effect (e.g. generating/caching data) — TODO confirm, otherwise it duplicates the call below.
get_data.get_date_time_data(100)
# get_data.tokenize()
# NOTE(review): test_date_cn, test_date_en and lines are not referenced elsewhere in this view.
test_date_cn, test_date_en, lines = get_data.get_date_time_data(200)
# Arguments for load_data_nmt below; num_steps is presumably the padded sequence length — confirm.
batch_size, num_steps = 16, 13
# NOTE(review): train_iter / src_vocab / tgt_vocab are only used by commented-out code in main(),
# which unpacked each batch as (X, X_valid_len, Y, Y_valid_len); possibly unused — confirm.
train_iter, src_vocab, tgt_vocab = get_data.load_data_nmt(batch_size=batch_size, num_steps=num_steps, total_data=1000)
# Classifier trained by main(); ModelCLS presumably comes from the `tools` star import.
# Note that it is not moved to a device here.
model_cls = ModelCLS()
# Checkpoint path written by main(); `os` and `output_dir` presumably come from the `tools` star import.
output_path = os.path.join(output_dir, 'cls.model')
def main(*, num_epochs=500, lr=1e-4, log_every=50, num_show=2):
    """Train the module-level ``model_cls`` classifier and save it to ``output_path``.

    Each epoch draws one batch from ``tokenizer.get_batch_data(prefix=False)``
    and performs a single AdamW step on the cross-entropy loss. Every
    ``log_every`` epochs it prints the accuracy on that same training batch and
    a few decoded inputs with their predicted class. After training, the model
    is moved to CPU and saved whole with ``torch.save``.

    Relies on names defined elsewhere in this module (largely via
    ``from tools import *``): ``tokenizer``, ``device``, ``model_cls``,
    ``output_path``, ``debug``, ``show_ls`` and ``tt``.

    Args:
        num_epochs: Number of optimisation steps (one batch per "epoch").
        lr: AdamW learning rate.
        log_every: Logging interval in epochs; a non-positive value disables logging.
        num_show: Maximum number of decoded sample inputs printed per log line.
    """
    import os

    if debug:
        # NOTE(review): this preview uses prefix=True while training below uses
        # prefix=False — confirm the preview is meant to differ.
        show_ls([tokenizer.decode(i) for i in tokenizer.get_batch_data(prefix=True)[1]][:5])

    # The batch tensors are moved to `device` below, so the model must live
    # there too; `.to()` is a no-op if the model is already on `device`.
    model_cls.to(device)
    optimizer = torch.optim.AdamW(params=model_cls.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        # get_batch_data is unpacked as (label, input_ids, attention_mask).
        label, input_ids, attention_mask = tokenizer.get_batch_data(prefix=False)
        label = torch.LongTensor(label).to(device)
        input_ids = torch.LongTensor(input_ids).to(device)
        attention_mask = torch.LongTensor(attention_mask).to(device)

        # Clear gradients before the forward pass so any stale gradients
        # cannot leak into the first update.
        optimizer.zero_grad()
        logits = model_cls(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()

        if log_every > 0 and epoch % log_every == 0:
            # Accuracy on the current training batch, not on a held-out set.
            preds = logits.detach().argmax(1)
            acc = (preds == label).sum().item() / len(label)
            print(f'========== epoch: {epoch}, acc: {round(acc, 3)} ====== now:', tt.now())
            # Never index past the end of a batch smaller than num_show.
            for i in range(min(num_show, len(label))):
                print('--- tests:', tokenizer.decode(input_ids[i].tolist()), '--- type:', preds[i].item())

    model_cls.to('cpu')
    # torch.save does not create missing parent directories.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(model_cls, output_path)
# Entry point: train only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。