# test_10.py
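# Preprocesses the GLUE SST-2 dataset with the bert-base-chinese tokenizer and
# caches small encoded train/validation subsets to disk for later experiments.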
# from transformers import BertTokenizer
import os.path
from transformers import AutoTokenizer, AutoModelForMaskedLM
model_path = r'C:\Users\admin\.cache\huggingface\hub\models--bert-base-chinese\snapshots\8d2a91f91cc38c96bb8b4556ba70c392f8d5ee55'
tokenizer = AutoTokenizer.from_pretrained(model_path)
fine_tune_model = AutoModelForMaskedLM.from_pretrained(model_path)
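# Alternative (sketch, assumes network access): if the local snapshot above is
# not present, the same checkpoint can be fetched by model id from the
# Hugging Face Hub; it is downloaded and cached on first use.
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
#   fine_tune_model = AutoModelForMaskedLM.from_pretrained('bert-base-chinese')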
from datasets import load_dataset
from datasets import load_from_disk
import os
# Tokenization: pad/truncate every sentence to a fixed length of 30 tokens
def f(data):
    return tokenizer(
        data['sentence'],
        padding='max_length',
        truncation=True,
        max_length=30,
    )
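# Quick sanity check of the mapping function (illustrative sketch; the example
# sentence is hypothetical): every encoding has exactly max_length=30 ids.
#   enc = f({'sentence': ['今天天气很好']})
#   assert len(enc['input_ids'][0]) == 30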
if __name__ == '__main__':
    # Load the data from local disk
    data_dist_path = './data/glue_sst2'
    preprocess_data_dist_path = os.path.join(data_dist_path, 'preprocess')
    train_data_path = os.path.join(preprocess_data_dist_path, 'train')
    test_data_path = os.path.join(preprocess_data_dist_path, 'validation')

    if not os.path.exists(train_data_path):
        # Load the dataset
        datasets = load_from_disk(data_dist_path)  # or load from the network: datasets = load_dataset(path='glue', name='sst2')
        datasets = datasets.map(f, batched=True, batch_size=1000, num_proc=4)

        # Take a subset of the data; the full dataset is too large to run quickly
        dataset_train = datasets['train'].shuffle().select(range(1000))
        dataset_test = datasets['validation'].shuffle().select(range(200))

        dataset_train.save_to_disk(dataset_path=train_data_path)
        dataset_test.save_to_disk(dataset_path=test_data_path)
        del datasets

    dataset_train = load_from_disk(train_data_path)
    dataset_test = load_from_disk(test_data_path)

    print(dataset_train[0].keys())
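
    # Possible next step (sketch; assumes the GLUE SST-2 schema plus the
    # tokenizer output columns): expose the encoded columns as torch tensors
    # so the subsets can feed a DataLoader or a transformers Trainer.
    dataset_train.set_format(type='torch',
                             columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
    print(dataset_train[0]['input_ids'].shape)  # expected: torch.Size([30])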