# from transformers import BertTokenizer
import os

from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load the tokenizer and masked-LM model from a locally cached
# bert-base-chinese snapshot in the Hugging Face hub cache.
model_path = r'C:\Users\admin\.cache\huggingface\hub\models--bert-base-chinese\snapshots\8d2a91f91cc38c96bb8b4556ba70c392f8d5ee55'
tokenizer = AutoTokenizer.from_pretrained(model_path)
fine_tune_model = AutoModelForMaskedLM.from_pretrained(model_path)
# Tokenize a batch of sentences, padding/truncating each one to a fixed length.
def f(data):
    return tokenizer(
        data['sentence'],
        padding='max_length',
        truncation=True,
        max_length=30,
    )
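

# Illustrative sketch (not part of the original script): applied to a single
# made-up sentence, the tokenizer configured above typically returns
# input_ids, token_type_ids and attention_mask, each padded/truncated to 30
# tokens. Defined as a helper but never called, so it does not change the
# script's behavior; call it manually to verify the tokenization setup.
def _peek_tokenizer_output():
    example = tokenizer('今天天气不错', padding='max_length', truncation=True, max_length=30)
    print(list(example.keys()))        # expected: ['input_ids', 'token_type_ids', 'attention_mask']
    print(len(example['input_ids']))   # expected: 30
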
if __name__ == '__main__':
    # Load the data from local disk
    data_dist_path = './data/glue_sst2'
    preprocess_data_dist_path = os.path.join(data_dist_path, 'preprocess')
    train_data_path = os.path.join(preprocess_data_dist_path, 'train')
    test_data_path = os.path.join(preprocess_data_dist_path, 'validation')
    if not os.path.exists(train_data_path):
        # Load the dataset
        datasets = load_from_disk(data_dist_path)  # datasets = load_dataset(path='glue', name='sst2')  # load over the network
        datasets = datasets.map(f, batched=True, batch_size=1000, num_proc=4)
        # Take a subset of the data; the full set is too large to run here
        dataset_train = datasets['train'].shuffle().select(range(1000))
        dataset_test = datasets['validation'].shuffle().select(range(200))
        dataset_train.save_to_disk(dataset_path=train_data_path)
        dataset_test.save_to_disk(dataset_path=test_data_path)
        del datasets
    dataset_train = load_from_disk(train_data_path)
    dataset_test = load_from_disk(test_data_path)
    print(dataset_train[0].keys())
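
    # Minimal sketch of downstream usage (an assumption; the original script
    # stops after preprocessing): decode one example to confirm the
    # tokenization, then expose the tokenized columns as PyTorch tensors so a
    # DataLoader can batch them for fine-tuning `fine_tune_model`. This
    # assumes the SST-2 'label' column was kept alongside the tokenizer
    # outputs, and the batch size of 16 is arbitrary.
    print(tokenizer.decode(dataset_train[0]['input_ids']))
    dataset_train.set_format(type='torch',
                             columns=['input_ids', 'token_type_ids',
                                      'attention_mask', 'label'])
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset_train, batch_size=16, shuffle=True)
    print(next(iter(loader))['input_ids'].shape)  # expected: torch.Size([16, 30])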