代码拉取完成,页面将自动刷新
同步操作将从 zhaozhengyun/BPNN_forecast 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
# -*- coding: utf-8 -*-
# 导入必要的库
import pickle
from config.dataset import dataset
from config.logger import *
from evaluate import evaluate
def train():
logger = logger1()
start_time = time.time()
# 加载数据集,模型,及模型保存路径
train_data, test_data, bp1, save_model_path,y_test, df2 = dataset(logger)
# 参数保存到log日志
batch_size = len(train_data)
logger.info(f'batch_size:, {batch_size}')
logger.info(f'epochs:, {args.epochs}')
logger.info(f'学习率:, {args.lr}')
# 神经网络训练
logger.info("==============开始训练================")
j = args.save_item
item_epoch = int(args.epochs / j)
for i in range(1, item_epoch + 1):
bp1.MSGD(train_data, j, batch_size, args.lr, args.min_loss, logger, i)
evaluate(i * j, bp1, logger, test_data, y_test, df2)
pickle.dump(bp1, open("./results/save_model/bp" + str(i * j) + ".pkl", 'wb'))
bp1 = bp1
# 保存模型
pickle.dump(bp1, open(save_model_path, 'wb'))
logger.info("==============训练结束================")
logger.info(f'保存模型的路径:, {save_model_path}')
end_time = time.time()
logger.info(f'Total Time Cost: {(end_time - start_time) / 60} mins!')
if __name__ == '__main__':
# 参数设置在config/parser文件中
train()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。