from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import PPO
from training_env import TrainingEnv
from save_model import SaveModelCallback
import torch as th
TB_LOG_PATH = "../tb_log"
MODEL_PATH = "./model/ppo"
LEARN_TIMES = 1000000
TRAINING_BEGIN_TIME = [
    "2022-08-14", "2022-08-15",
    "2022-08-18", "2022-08-19", "2022-08-20", "2022-08-21",
]
# The algorithm requires a vectorized environment to run; Monitor records
# episode statistics, and PPO auto-wraps a single env in a DummyVecEnv.
env = Monitor(TrainingEnv(TRAINING_BEGIN_TIME), MODEL_PATH)
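# Explicit vectorization would be equivalent (a sketch; PPO performs this
# wrapping automatically for a single non-vectorized env):
# env = DummyVecEnv([lambda: Monitor(TrainingEnv(TRAINING_BEGIN_TIME), MODEL_PATH)])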
# Set model_path to a saved checkpoint (e.g. "ppo_4_stock") to resume training.
model_path = None
policy = dict(
    activation_fn=th.nn.ReLU,
    # Shared 1024-512-256-128 trunk, then separate 1024-unit value (vf)
    # and policy (pi) heads (SB3 < 1.8 net_arch format).
    net_arch=[1024, 512, 256, 128, dict(vf=[1024], pi=[1024])],
    # net_arch=[128, 128],
    # features_extractor_class=CustomCombinedExtractor,
    # features_extractor_kwargs=dict(features_dim=128),
)
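# Note: stable-baselines3 >= 1.8 removed shared layers inside net_arch lists.
# A rough equivalent there duplicates the trunk per head, with separate
# weights (a sketch, untested):
# policy = dict(
#     activation_fn=th.nn.ReLU,
#     net_arch=dict(pi=[1024, 512, 256, 128, 1024],
#                   vf=[1024, 512, 256, 128, 1024]),
# )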
if model_path is None:
    model = PPO("MlpPolicy", env, learning_rate=0.0001, policy_kwargs=policy,
                verbose=1, tensorboard_log=TB_LOG_PATH)
else:
    model = PPO.load(model_path, env)
# model = PPO.load("./ppo_1_stock", env)
# model = PPO('MultiInputPolicy', env, policy_kwargs=policy, verbose=1,
#             batch_size=20480, tensorboard_log=TB_LOG_PATH)
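# Note: 'MultiInputPolicy' (commented above) is the policy class SB3 requires
# for Dict observation spaces, typically paired with a custom features
# extractor such as the CustomCombinedExtractor referenced earlier.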
# SaveModelCallback is a custom callback from save_model; judging by its
# arguments, it checkpoints to MODEL_PATH every check_freq steps.
callback = SaveModelCallback(check_freq=1024, log_dir=MODEL_PATH, verbose=1)
model.learn(total_timesteps=LEARN_TIMES, callback=callback)
model.save("ppo_4_stock")
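# --- Usage sketch (not part of the original script) ---
# Reload the checkpoint and run one deterministic episode; assumes
# TrainingEnv follows the classic gym API (reset() -> obs,
# step() -> obs, reward, done, info):
# trained = PPO.load("ppo_4_stock")
# eval_env = Monitor(TrainingEnv(TRAINING_BEGIN_TIME), MODEL_PATH)
# obs = eval_env.reset()
# done = False
# while not done:
#     action, _ = trained.predict(obs, deterministic=True)
#     obs, reward, done, info = eval_env.step(action)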