learn_rppo.py
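# Trains a RecurrentPPO (LSTM policy) agent from sb3_contrib on the project's
# custom vectorized environment, then evaluates it and saves the final model.
# get_subproc_env, MODEL_PATH and TB_LOG_PATH are assumed to come from the
# project's helper/define modules, which are star-imported below but not shown here.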
from sb3_contrib import RecurrentPPO
from save_model import SaveModelCallback
from stable_baselines3.common.evaluation import evaluate_policy
import torch as th
from define import *
from helper import *
LEARN_TIMES = 10_000_000  # total environment steps passed to model.learn()

def get_params():
    # Layer sizes for the separate value-function and policy heads.
    vf_arch = [32, 32, 64, 128, 32, 128, 32, 128, 256]
    pi_arch = [64, 1024, 128, 512, 64]
    # Shared layers followed by a dict of separate vf/pi heads
    # (the list-plus-dict net_arch format accepted by Stable-Baselines3 before v1.8).
    net_arch = [32, 256, 128, 32, 1024, 768, 32, 768, 32, dict(vf=vf_arch, pi=pi_arch)]
    return {
        'n_steps': 3072,
        'gamma': 0.9014026506133596,
        'learning_rate': 6.168219628138978e-05,
        'clip_range': 0.14508773369622327,
        'gae_lambda': 0.8364513164493685,
        'policy_kwargs': dict(
            activation_fn=th.nn.CELU,
            net_arch=net_arch,
        ),
    }
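# Everything in this dict is forwarded to the RecurrentPPO constructor below;
# policy_kwargs in turn goes to the MlpLstmPolicy (activation function and
# actor/critic layer sizes). The very specific float values look like the output
# of a hyperparameter search, though that is an assumption.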

if __name__ == '__main__':
    # Create the vectorized environment
    # env = DummyVecEnv([lambda: Monitor(TrainingEnv(TRAINING_BEGIN_TIME), MODEL_PATH)])
    train_env = get_subproc_env(10)  # 10 parallel training environments (project helper, imported above)
    # Stable Baselines also provides the make_vec_env() helper,
    # which does exactly the previous steps for you.
    # You can choose between `DummyVecEnv` (usually faster) and `SubprocVecEnv`:
    # env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)
    model_params = get_params()
    model = RecurrentPPO('MlpLstmPolicy', train_env, verbose=1,
                         tensorboard_log=TB_LOG_PATH, **model_params)
    # To resume from a previous checkpoint instead of training from scratch:
    # model = RecurrentPPO.load("./model/rppo/rppo_finish_2", env=train_env)
    model.learn(total_timesteps=LEARN_TIMES,
                callback=SaveModelCallback(check_freq=128, log_dir=MODEL_PATH))
    # Evaluate the trained model on a separate set of environments, then save it.
    evaluate_env = get_subproc_env(10, True)
    mean_reward, std_reward = evaluate_policy(model, evaluate_env,
                                              n_eval_episodes=3, deterministic=False)
    print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
    model.save(MODEL_PATH + "/rppo_finish")
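
    # A minimal usage sketch (not part of the original script), assuming the same
    # helper environment is available: load the saved model and run inference.
    # With a recurrent policy, predict() needs the LSTM state and episode_start
    # flags threaded through between calls.
    # import numpy as np
    # model = RecurrentPPO.load(MODEL_PATH + "/rppo_finish")
    # env = get_subproc_env(10, True)
    # obs = env.reset()
    # lstm_states = None
    # episode_starts = np.ones((env.num_envs,), dtype=bool)
    # for _ in range(1000):
    #     action, lstm_states = model.predict(obs, state=lstm_states,
    #                                         episode_start=episode_starts,
    #                                         deterministic=True)
    #     obs, rewards, dones, infos = env.step(action)
    #     episode_starts = dones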