learn_ppo.py
from stable_baselines3.common.monitor import Monitor
from stable_baselines3 import PPO
from training_env import TrainingEnv
from save_model import SaveModelCallback
import torch as th
TB_LOG_PATH = "../tb_log"   # TensorBoard log directory
MODEL_PATH = "./model/ppo"  # Directory for Monitor logs and saved checkpoints
LEARN_TIMES = 1_000_000     # Total environment timesteps to train for
# Start dates handed to TrainingEnv (presumably selecting the training episodes)
TRAINING_BEGIN_TIME = ["2022-08-14", "2022-08-15", "2022-08-18",
                       "2022-08-19", "2022-08-20", "2022-08-21"]
# SB3 requires a vectorized environment; a single Gym env passed to PPO is
# auto-wrapped in a DummyVecEnv. Monitor records episode stats to MODEL_PATH.
env = Monitor(TrainingEnv(TRAINING_BEGIN_TIME), MODEL_PATH)
# Set model_path to a saved checkpoint (e.g. "ppo_4_stock") to resume training;
# leave as None to train from scratch.
model_path = None
policy_kwargs = dict(
    activation_fn=th.nn.ReLU,
    # Shared 1024-512-256-128 trunk, then one extra 1024-unit layer each for
    # the value (vf) and policy (pi) heads (SB3 < 1.8 net_arch format).
    net_arch=[1024, 512, 256, 128, dict(vf=[1024], pi=[1024])],
)
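# Note: Stable-Baselines3 1.8+ removed the shared-layers list+dict net_arch
# format used above. A sketch of a roughly equivalent (but non-shared) spec
# for newer SB3 versions, kept commented out so it does not affect this script:
# policy_kwargs = dict(
#     activation_fn=th.nn.ReLU,
#     net_arch=dict(pi=[1024, 512, 256, 128, 1024],
#                   vf=[1024, 512, 256, 128, 1024]),
# )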
if model_path is None:
    model = PPO("MlpPolicy", env, learning_rate=1e-4, policy_kwargs=policy_kwargs,
                verbose=1, tensorboard_log=TB_LOG_PATH)
else:
    # Resume training from a saved checkpoint, re-attaching the environment.
    model = PPO.load(model_path, env)
# Checkpoint via the project's SaveModelCallback, checked every 1024 steps.
callback = SaveModelCallback(check_freq=1024, log_dir=MODEL_PATH, verbose=1)
model.learn(total_timesteps=LEARN_TIMES, callback=callback)
model.save("ppo_4_stock")
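# A minimal sketch of running the trained policy for inference, assuming
# TrainingEnv follows the classic Gym API (reset() -> obs,
# step(action) -> (obs, reward, done, info)); kept commented out since the
# env's exact interface is not shown here.
# eval_model = PPO.load("ppo_4_stock", env)
# obs = env.reset()
# done = False
# while not done:
#     action, _states = eval_model.predict(obs, deterministic=True)
#     obs, reward, done, info = env.step(action)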