加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
test1.py 2.45 KB
一键复制 编辑 原始数据 按行查看 历史
SunLucky 提交于 2023-08-08 16:33 . 初始化
from datetime import datetime
from pathlib import Path

import torch
from torch.utils.data import TensorDataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
# Load the pretrained Chinese MacBERT tokenizer and attach a fresh
# 10-way sequence-classification head on top of the encoder.
tokenizer = BertTokenizer.from_pretrained('model/chinese-macbert-base')
model = BertForSequenceClassification.from_pretrained('model/chinese-macbert-base', num_labels=10)

# Toy training corpus: three sentences, one per class id.
texts = ['这是一个正样本', '这是一个负样本', '这是一个错误样本']

# Tokenize the whole corpus in a single call; sequences are padded to a
# common length and truncated if too long, returned as PyTorch tensors.
encoding = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
input_ids = encoding['input_ids']
attention_mask = encoding['attention_mask']
labels = torch.tensor([0, 1, 2])

# Bundle the tensors into a dataset; shuffle mini-batches every epoch.
dataset = TensorDataset(input_ids, attention_mask, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Optimizer over all model parameters. NOTE(review): loss_fn is defined
# but the training loop relies on the loss the model itself returns when
# labels are supplied — kept here for parity with the original script.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
loss_fn = torch.nn.CrossEntropyLoss()
# Fine-tune the classifier for a fixed number of epochs.
# (The scraped original lost all indentation here; the nesting below is
# the reconstruction: per-epoch loop -> per-batch loop -> optimizer step.)
model.train()  # enable dropout / training-mode behavior
for epoch in range(5):
    total_loss = 0.0
    for step, batch in enumerate(dataloader):
        # TensorDataset yields (input_ids, attention_mask, labels).
        batch_input_ids, batch_attention_mask, batch_labels = batch
        optimizer.zero_grad()
        # Passing labels makes the model compute cross-entropy internally.
        outputs = model(batch_input_ids,
                        attention_mask=batch_attention_mask,
                        labels=batch_labels)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    # Mean loss over all batches this epoch.
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch + 1}, average loss: {avg_loss:.4f}')
# Save the fine-tuned weights under a timestamped filename so repeated
# runs never overwrite each other.
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")  # second resolution
model_path = f"train/trained_model{timestamp}.pt"
# torch.save fails if the target directory is missing — create it first.
Path("train").mkdir(parents=True, exist_ok=True)
print(model_path)
torch.save(model.state_dict(), model_path)
# Reload the fine-tuned weights into a fresh model instance and run a
# single-sentence sanity check.
model = BertForSequenceClassification.from_pretrained('model/chinese-macbert-base', num_labels=10)
model.load_state_dict(torch.load(model_path))
model.eval()  # fix: disable dropout so inference is deterministic

# Validation sentence (same class-2 example used in training).
text = '这是一个错误样本'
encoded = tokenizer.encode_plus(
    text,
    add_special_tokens=True,
    max_length=128,
    padding=True,
    truncation=True,
    return_tensors='pt',
)
input_ids = encoded['input_ids']
attention_mask = encoded['attention_mask']

# fix: inference needs no autograd graph — saves memory and compute.
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits

# argmax over the class dimension yields the predicted label id.
predicted_label = torch.argmax(logits, dim=1).item()
print(f"预测标签: {predicted_label}")
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化