代码拉取完成,页面将自动刷新
def resolve_seed(train_config):
    """Return the RNG seed configured for training, or ``None`` to skip seeding.

    Args:
        train_config: The ``config["train"]`` mapping.

    Returns:
        The value of ``train_config["seed"]`` unchanged, or ``None`` when the key
        is missing or set to ``null``/``false``. A seed of ``0`` is a valid seed
        and is returned as-is (a plain truthiness test would silently drop it).
    """
    seed = train_config.get("seed")
    if seed is None or seed is False:
        return None
    return seed


def train(config):
    """Build a ClnTPLinkerBert model from ``config`` and train it on CMeIE data.

    Each epoch runs ``trainer.train()`` followed by ``trainer.validate()``, and the
    validation F1 is passed to ``BestCheckpointSaver.update`` (presumably it keeps
    the best-scoring checkpoint; that behaviour lives in ``util.saver``).

    Args:
        config: Nested mapping (loaded from YAML by the CLI entry point) with keys
            ``model.pretrained``, ``model.add_distance``, ``model.inner_encoder``,
            ``train.dataset.datapath``, ``train.weighted``, ``train.batch_size``,
            ``train.lr.dynamic``, ``train.lr.bert_lr``, ``train.lr.tplinker_lr``,
            ``train.epoch`` and, optionally, ``train.seed``.

    Raises:
        KeyError: if a required configuration key is missing.
    """
    train_cfg = config["train"]
    model_cfg = config["model"]

    # %% Reproducibility: seed torch before any model/data objects are created.
    seed = resolve_seed(train_cfg)
    if seed is not None:
        import torch
        torch.manual_seed(seed)

    # %% Tokenizer
    from transformers import BertTokenizerFast
    tokenizer = BertTokenizerFast.from_pretrained(model_cfg["pretrained"])

    # %% Data
    from dataset.wrappers import CMeIEData
    from dataset.dataset import TPLinkerDataset
    datapath = train_cfg["dataset"]["datapath"]
    train_data = CMeIEData(datapath)
    val_data = CMeIEData(datapath, "validation")
    train_dataset = TPLinkerDataset(train_data, tokenizer)
    val_dataset = TPLinkerDataset(val_data, tokenizer)

    if train_cfg["weighted"]:
        from util.statistics import TPLinkerDatasetBalancer
        # Estimate label weights from the training split only. Using the
        # validation split here would leak its label statistics into the loss.
        label_weights = TPLinkerDatasetBalancer(train_dataset).get_weights4scaling()
    else:
        label_weights = None

    # %% Model
    from model.tplinker import ClnTPLinkerBert
    tplinker = ClnTPLinkerBert(
        model_cfg["pretrained"],
        len(train_data.id2relation),
        add_distance_embedding=model_cfg["add_distance"],
        inner_encoder=model_cfg["inner_encoder"],
    )

    # %% Optimisation / trainer
    from torch.utils.data import DataLoader
    from trainer.optim_schedule import TPLinkerOptimScheduler
    from trainer.trainer import TPlinkerTrainer
    # The int()/float() casts are kept from the original; presumably some values
    # arrive as strings (PyYAML reads e.g. "1e-5" without a dot as str).
    batch_size = int(train_cfg["batch_size"])
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size)
    lr_cfg = train_cfg["lr"]
    scheduler = TPLinkerOptimScheduler(
        tplinker,
        lr_cfg["dynamic"],
        float(lr_cfg["bert_lr"]),
        float(lr_cfg["tplinker_lr"]),
    )
    trainer = TPlinkerTrainer(tplinker, scheduler, train_dataloader, label_weights, val_dataloader)

    # %% Training loop
    from util.saver import BestCheckpointSaver
    saver = BestCheckpointSaver()
    for _ in range(int(train_cfg["epoch"])):
        trainer.train()  # per-epoch training results are not used here
        val_res = trainer.validate()
        saver.update(trainer, val_res["validation_epoch_f1"])
def parse_args(argv=None):
    """Parse command-line arguments for the training entry point.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` whose ``config`` attribute is the YAML path.

    Raises:
        SystemExit: if ``-c/--config`` is missing (argparse exits with code 2).
    """
    import argparse

    parser = argparse.ArgumentParser(description="Train a TPLinker model from a YAML config.")
    # Required: without it the script previously fell through to open(None)
    # and crashed with an unhelpful TypeError.
    parser.add_argument("-c", "--config", type=str, required=True,
                        help="path to the YAML training configuration file")
    return parser.parse_args(argv)


if __name__ == "__main__":
    import yaml

    args = parse_args()
    with open(args.config, "r", encoding="utf-8") as fp:
        # NOTE(review): FullLoader is kept for compatibility. If the config holds
        # only plain YAML data, yaml.safe_load is the safer choice; never point
        # this at untrusted files.
        config = yaml.load(fp, yaml.FullLoader)
    train(config)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。