加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
main.py 2.27 KB
一键复制 编辑 原始数据 按行查看 历史
bode135 提交于 2023-11-08 17:57 . 1
from tools import *
from bdtime import tt
from bdtime import show_json, show_ls
from pprint import pprint
# Truthy flag: when set, main() prints a short preview of decoded samples
# before training starts.
debug = 1
# Disabled debugging snippets, kept for reference.
# if debug: print([tokenizer.decode(i).strip("P").strip("E").strip("S") for i in tokenizer.get_batch_data(prefix=False)[1]][:10])
# if debug: print([tokenizer.decode(i) for i in tokenizer.get_batch_data(prefix=False)[1]][:10])
# label, input_ids, attention_mask = tokenizer.get_batch_data(prefix=False)
from tools import get_data
# NOTE(review): the return value of this call is discarded; it is either
# invoked for a side effect inside get_data or is a leftover -- confirm.
get_data.get_date_time_data(100)
# get_data.tokenize()
# The call returns three values; none of the three names is read again in
# the code visible in this file.
test_date_cn, test_date_en, lines = get_data.get_date_time_data(200)
batch_size, num_steps = 16, 13
# NOTE(review): train_iter / src_vocab / tgt_vocab are only mentioned in
# commented-out code inside main(); the active training loop does not use them.
train_iter, src_vocab, tgt_vocab = get_data.load_data_nmt(batch_size=batch_size, num_steps=num_steps, total_data=1000)
# Module-level model instance that main() trains and saves.
# ModelCLS is not imported explicitly here -- presumably provided by
# `from tools import *`; confirm.
model_cls = ModelCLS()
# File that main() writes the trained model to.
# `os` and `output_dir` are likewise not imported explicitly in this file --
# presumably re-exported by `tools`; confirm.
output_path = os.path.join(output_dir, 'cls.model')
def main():
    """Train the module-level ``model_cls`` and save it to ``output_path``.

    The loop runs 500 iterations. Although the loop variable is named
    ``epoch``, each iteration fetches a single batch from
    ``tokenizer.get_batch_data(prefix=False)``, converts its three parts to
    ``LongTensor`` on ``device``, and performs one AdamW step (lr=1e-4)
    against a cross-entropy loss.

    Every 50th iteration prints the accuracy on the batch just trained on,
    together with the decoded text and predicted class of its first two
    samples. After the loop the model is moved to the CPU and the whole
    model object is written with ``torch.save``.
    """
    if debug:
        # Preview: decode element [1] of a prefixed batch and show five items.
        preview_ids = tokenizer.get_batch_data(prefix=True)[1]
        show_ls([tokenizer.decode(ids) for ids in preview_ids][:5])

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(params=model_cls.parameters(), lr=1e-4)

    for epoch in range(500):
        # The batch unpacks as (label, input_ids, attention_mask); all three
        # are converted the same way.
        label, input_ids, attention_mask = [
            torch.LongTensor(part).to(device)
            for part in tokenizer.get_batch_data(prefix=False)
        ]

        logits = model_cls(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        # Gradients are cleared after the step so the next iteration starts clean.
        optimizer.zero_grad()

        if epoch % 50 != 0:
            continue

        # Periodic progress report on the batch that was just used for training.
        predicted = logits.argmax(1)
        hits = (predicted == label).sum().item()
        acc = hits / len(label)
        print(f'========== epoch: {epoch}, acc: {round(acc, 3)} ====== now:', tt.now())
        for i in (0, 1):
            print('--- tests:', tokenizer.decode(input_ids[i].tolist()), '--- type:', predicted[i].item())

    # Move to CPU before serialising the whole model object.
    model_cls.to('cpu')
    torch.save(model_cls, output_path)
# Run training only when this file is executed as a script.
if __name__ == "__main__":
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化