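"""Fine-tune a Chinese ALBERT classifier with Hugging Face Transformers.

The dataset folder is expected to contain label.txt (one label per line)
plus train.txt, dev.txt, and test.txt as tab-separated "text<TAB>label"
files.
"""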
import argparse
import os
import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
EarlyStoppingCallback,
Trainer,
TrainingArguments,
)
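
# Command-line arguments: the dataset folder, plus optional model/output
# settings and flags selecting which stages to run.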
parser = argparse.ArgumentParser()
parser.add_argument("dataset_folder", type=str)
parser.add_argument("--save_folder", type=str, default=None)
parser.add_argument(
"--model_name", type=str, default="uer/albert-base-chinese-cluecorpussmall"
)
parser.add_argument("--do-train", action="store_true")
parser.add_argument("--do-eval", action="store_true")
parser.add_argument("--do-interactive", action="store_true")
args = parser.parse_args()
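# Default the save folder to <dataset_folder>/output when not given.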
if args.save_folder is None:
args.save_folder = os.path.join(args.dataset_folder, "output")
with open(os.path.join(args.dataset_folder, "label.txt"), encoding="utf-8") as f:
    all_labels = f.read().splitlines()
label2id = {label: i for i, label in enumerate(all_labels)}
id2label = {i: label for i, label in enumerate(all_labels)}
tokenizer = AutoTokenizer.from_pretrained(args.model_name, model_max_length=512)
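# Load the accuracy metric from a local script rather than the Hub.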
accuracy = evaluate.load("./accuracy.py")
def preprocess_function(examples):
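    """Tokenize the texts and map string labels to integer ids."""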
    # Padding is left to the data collator, which pads per batch dynamically.
    model_inputs = tokenizer(examples["text"], truncation=True, max_length=512)
labels = [label2id[label] for label in examples["label"]]
model_inputs["label"] = labels
return model_inputs
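
# load_dataset("csv", data_files=...) returns a DatasetDict whose only split
# is "train", so every file (including dev and test) is read via ["train"].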
train_dataset = load_dataset(
"csv",
data_files=os.path.join(args.dataset_folder, "train.txt"),
delimiter="\t",
column_names=["text", "label"],
cache_dir=os.path.join(args.dataset_folder, "cache"),
).map(preprocess_function, batched=True)["train"]
dev_dataset = load_dataset(
"csv",
data_files=os.path.join(args.dataset_folder, "dev.txt"),
delimiter="\t",
column_names=["text", "label"],
cache_dir=os.path.join(args.dataset_folder, "cache"),
).map(preprocess_function, batched=True)["train"]
test_dataset = load_dataset(
"csv",
data_files=os.path.join(args.dataset_folder, "test.txt"),
delimiter="\t",
column_names=["text", "label"],
cache_dir=os.path.join(args.dataset_folder, "cache"),
).map(preprocess_function, batched=True)["train"]
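
# Pad each batch to its longest sequence at collation time.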
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
def compute_metrics(eval_pred):
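    """Compute accuracy from the logits/labels pair supplied by the Trainer."""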
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=1)
return accuracy.compute(predictions=predictions, references=labels)
def do_train():
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name,
num_labels=len(all_labels),
id2label=id2label,
label2id=label2id,
)
    # Train in bf16; evaluate and checkpoint each epoch so early stopping and
    # load_best_model_at_end operate on epoch boundaries.
training_args = TrainingArguments(
output_dir=os.path.join(args.dataset_folder, "output"),
learning_rate=1e-5,
per_device_train_batch_size=48,
per_device_eval_batch_size=48,
num_train_epochs=30,
weight_decay=0.01,
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
warmup_steps=100,
save_total_limit=2,
bf16=True,
report_to="wandb",
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=dev_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
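        # Stop once the eval loss has not improved for 5 consecutive epochs.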
callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
trainer.train()
    # To resume an interrupted run, train from a saved checkpoint instead:
    # trainer.train(
    #     resume_from_checkpoint=os.path.join(
    #         args.dataset_folder, "output/checkpoint-7700"
    #     ),
    # )
trainer.save_model(args.save_folder)
def do_eval():
model = AutoModelForSequenceClassification.from_pretrained(args.save_folder)
    # Prediction-only Trainer: only the eval batch size matters here.
training_args = TrainingArguments(
output_dir=os.path.join(args.dataset_folder, "output"),
per_device_eval_batch_size=128,
)
trainer = Trainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
predictions = trainer.predict(dev_dataset)
print("开发集测试结果:")
print(predictions.metrics)
predictions = trainer.predict(test_dataset)
print("测试集测试结果:")
print(predictions.metrics)
    # Print each misclassified test example: predicted label, gold label, and
    # the decoded input length (useful for spotting truncation-related errors).
for i in range(len(test_dataset)):
if (
id2label[predictions.predictions[i].argmax()]
!= id2label[test_dataset[i]["label"]]
):
print(
id2label[predictions.predictions[i].argmax()],
id2label[test_dataset[i]["label"]],
len(
tokenizer.decode(
test_dataset[i]["input_ids"], skip_special_tokens=True
)
),
)
def do_interactive():
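    """Classify sentences typed on stdin in a simple REPL loop."""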
model = AutoModelForSequenceClassification.from_pretrained(args.save_folder)
    while True:
        text = input("Input text: ")
        # Truncate to the same 512-token limit used during training.
        inputs = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
        # Inference only: no need to track gradients.
        with torch.no_grad():
            outputs = model(**inputs)
        print(id2label[outputs.logits[0].argmax(-1).item()])
if __name__ == "__main__":
if args.do_train:
do_train()
if args.do_eval:
do_eval()
if args.do_interactive:
do_interactive()