加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 2.43 KB
一键复制 编辑 原始数据 按行查看 历史
朱金阳 提交于 2022-10-29 14:23 . Add inner_encoder type.
def train(config):
    """Run the full TPLinker training loop described by ``config``.

    The function builds a tokenizer, the CMeIE train/validation datasets, a
    ``ClnTPLinkerBert`` model, the optimiser/scheduler wrapper and a trainer,
    then alternates ``trainer.train()`` and ``trainer.validate()`` for the
    configured number of epochs. After each epoch the validation F1 is handed
    to a ``BestCheckpointSaver``.

    Imports are kept local to the function, next to the ``# %%`` cell that
    uses them, as in the original script.

    Args:
        config: Nested mapping (the entry point loads it from YAML). Keys
            read here:

            * ``train.seed`` (optional) -- seed passed to
              ``torch.manual_seed``. ``0`` is a valid seed; a missing key,
              ``None``, ``False`` or an empty value disables seeding.
            * ``train.dataset.datapath``
            * ``train.weighted``
            * ``train.batch_size``, ``train.epoch`` (converted with ``int``)
            * ``train.lr.dynamic``, ``train.lr.bert_lr``,
              ``train.lr.tplinker_lr`` (the two rates converted with ``float``)
            * ``model.pretrained``, ``model.add_distance``,
              ``model.inner_encoder``

    Returns:
        None.

    Raises:
        KeyError: If a required configuration key is missing.
    """
    train_cfg = config["train"]
    model_cfg = config["model"]

    # %% Reproducibility
    seed = train_cfg.get("seed")
    # A plain truthiness test would skip the perfectly valid seed 0, so it is
    # accepted explicitly. bool is excluded so `seed: false` still means
    # "do not seed".
    seed_is_zero = seed == 0 and not isinstance(seed, bool)
    if seed or seed_is_zero:
        import torch

        torch.manual_seed(seed)

    # %% Tokenizer
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained(model_cfg["pretrained"])

    # %% Datasets
    from dataset.wrappers import CMeIEData
    from dataset.dataset import TPLinkerDataset

    datapath = train_cfg["dataset"]["datapath"]
    train_data = CMeIEData(datapath)
    val_data = CMeIEData(datapath, "validation")
    train_dataset = TPLinkerDataset(train_data, tokenizer)
    val_dataset = TPLinkerDataset(val_data, tokenizer)

    label_weights = None
    if train_cfg["weighted"]:
        from util.statistics import TPLinkerDatasetBalancer

        # NOTE(review): the weights are derived from the *validation* dataset,
        # unchanged from the original code. Confirm whether train_dataset was
        # intended here.
        balancer = TPLinkerDatasetBalancer(val_dataset)
        label_weights = balancer.get_weights4scaling()

    # %% Model
    from model.tplinker import ClnTPLinkerBert

    tplinker = ClnTPLinkerBert(
        model_cfg["pretrained"],
        len(train_data.id2relation),
        add_distance_embedding=model_cfg["add_distance"],
        inner_encoder=model_cfg["inner_encoder"],
    )

    # %% Data loaders, optimiser schedule and trainer
    from torch.utils.data import DataLoader
    from trainer.optim_schedule import TPLinkerOptimScheduler
    from trainer.trainer import TPlinkerTrainer

    batch_size = int(train_cfg["batch_size"])
    lr_cfg = train_cfg["lr"]

    # Only the training loader is shuffled.
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size)
    scheduler = TPLinkerOptimScheduler(
        tplinker,
        lr_cfg["dynamic"],
        float(lr_cfg["bert_lr"]),
        float(lr_cfg["tplinker_lr"]),
    )
    trainer = TPlinkerTrainer(
        tplinker, scheduler, train_dataloader, label_weights, val_dataloader
    )

    # %% Epoch loop
    from util.saver import BestCheckpointSaver

    saver = BestCheckpointSaver()
    for _ in range(int(train_cfg["epoch"])):
        trainer.train()
        val_res = trainer.validate()
        # The saver receives the epoch-level validation F1 together with the
        # trainer after every epoch.
        saver.update(trainer, val_res["validation_epoch_f1"])
if __name__ == "__main__":
    # %% Command-line entry point: load the YAML config and start training.
    import argparse

    parser = argparse.ArgumentParser()
    # The option is required: without it args.config would be None and
    # open() would fail with an unhelpful TypeError instead of a usage message.
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        required=True,
        help="path to the YAML training configuration file",
    )
    args = parser.parse_args()

    import yaml

    with open(args.config, "r", encoding="utf-8") as fp:
        # NOTE(review): FullLoader is kept for compatibility with existing
        # config files. Only pass trusted files here; consider yaml.safe_load
        # if the configs use plain YAML types only.
        config = yaml.load(fp, yaml.FullLoader)
    train(config)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化