from transformers import TrainingArguments
from transformers import Trainer, HfArgumentParser
from modeling_chatglm import ChatGLMForConditionalGeneration
import torch
import torch.nn as nn
from peft import get_peft_model, LoraConfig, TaskType
from dataclasses import dataclass, field
import datasets
import os


@dataclass
class FinetuneArguments:
    # Script-specific arguments, parsed alongside transformers.TrainingArguments.
    dataset_path: str = field(default="data/alpaca")
    model_path: str = field(default="output")
    lora_rank: int = field(default=8)


class CastOutputToFloat(nn.Sequential):
    # Cast lm_head outputs to float32 so the loss is computed in full precision.
    def forward(self, x):
        return super().forward(x).to(torch.float32)


class ModifiedTrainer(Trainer):
    # Causal-LM loss: the input ids double as labels, with an all-ones attention mask.
    def compute_loss(self, model, inputs, return_outputs=False):
        return model(
            input_ids=inputs["input_ids"],
            attention_mask=torch.ones_like(inputs["input_ids"]).bool(),
            labels=inputs["input_ids"],
        ).loss


def data_collator(features: list) -> dict:
    # Stack each example's input_ids into one LongTensor batch;
    # torch.stack requires every example to already have the same length.
    return {
        "input_ids": torch.stack([
            torch.LongTensor(f["input_ids"]) for f in features
        ])
    }


def save_tunable_parameters(model, path):
    # Save only the trainable (LoRA) parameters, moved to CPU first.
    saved_params = {
        k: v.to("cpu")
        for k, v in model.named_parameters()
        if v.requires_grad
    }
    torch.save(saved_params, path)
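

# The checkpoint written by save_tunable_parameters() holds only the LoRA weights.
# A minimal loading sketch (an assumption, not part of this script): rebuild the same
# PEFT-wrapped model and load the state dict non-strictly, e.g.
#
#     model = get_peft_model(base_model, peft_config)
#     model.load_state_dict(torch.load("chatglm-lora.pt"), strict=False)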


def main():
    finetune_args, training_args = HfArgumentParser(
        (FinetuneArguments, TrainingArguments)
    ).parse_args_into_dataclasses()

    # init model
    model = ChatGLMForConditionalGeneration.from_pretrained(
        "THUDM/chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map="auto"
    )
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    model.is_parallelizable = True
    model.model_parallel = True
    model.lm_head = CastOutputToFloat(model.lm_head)
    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

    # setup peft
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=finetune_args.lora_rank,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    model = get_peft_model(model, peft_config)

    # load dataset
    dataset = datasets.load_from_disk(finetune_args.dataset_path)

    # start train
    trainer = ModifiedTrainer(
        model=model,
        train_dataset=dataset,
        args=training_args,
        data_collator=data_collator,
    )
    trainer.train()

    # save model
    save_tunable_parameters(
        model, os.path.join(training_args.output_dir, "chatglm-lora.pt")
    )


if __name__ == "__main__":
    main()
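

# Example invocation (a sketch; the script name and hyperparameter values are
# assumptions, not prescribed by this file):
#
#   python finetune.py \
#       --dataset_path data/alpaca \
#       --lora_rank 8 \
#       --per_device_train_batch_size 2 \
#       --max_steps 1000 \
#       --learning_rate 1e-4 \
#       --fp16 \
#       --remove_unused_columns false \
#       --logging_steps 50 \
#       --output_dir output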