加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
train.py 1.33 KB
一键复制 编辑 原始数据 按行查看 历史
zhaozhengyun 提交于 2022-12-14 10:37 . Initial commit
# -*- coding: utf-8 -*-
# 导入必要的库
import pickle
from config.dataset import dataset
from config.logger import *
from evaluate import evaluate
def train():
logger = logger1()
start_time = time.time()
# 加载数据集,模型,及模型保存路径
train_data, test_data, bp1, save_model_path,y_test, df2 = dataset(logger)
# 参数保存到log日志
batch_size = len(train_data)
logger.info(f'batch_size:, {batch_size}')
logger.info(f'epochs:, {args.epochs}')
logger.info(f'学习率:, {args.lr}')
# 神经网络训练
logger.info("==============开始训练================")
j = args.save_item
item_epoch = int(args.epochs / j)
for i in range(1, item_epoch + 1):
bp1.MSGD(train_data, j, batch_size, args.lr, args.min_loss, logger, i)
evaluate(i * j, bp1, logger, test_data, y_test, df2)
pickle.dump(bp1, open("./results/save_model/bp" + str(i * j) + ".pkl", 'wb'))
bp1 = bp1
# 保存模型
pickle.dump(bp1, open(save_model_path, 'wb'))
logger.info("==============训练结束================")
logger.info(f'保存模型的路径:, {save_model_path}')
end_time = time.time()
logger.info(f'Total Time Cost: {(end_time - start_time) / 60} mins!')
if __name__ == '__main__':
# 参数设置在config/parser文件中
train()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化