from sb3_contrib import RecurrentPPO
from save_model import SaveModelCallback
from stable_baselines3.common.evaluation import evaluate_policy
import torch as th
from define import *
from helper import *
LEARN_TIMES = 10000000
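

# NOTE: `SaveModelCallback` is imported from the local `save_model` module, whose
# source is not shown here. The sketch below is only an assumption of what such a
# checkpointing callback might look like, given that it is constructed with
# `check_freq` and `log_dir`; the class name `_SaveModelCallbackSketch`, the file
# naming, and the save cadence are illustrative, not the actual implementation.
import os
from stable_baselines3.common.callbacks import BaseCallback


class _SaveModelCallbackSketch(BaseCallback):
    def __init__(self, check_freq, log_dir, verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.log_dir = log_dir

    def _on_step(self):
        # Periodically checkpoint the model that is being trained.
        if self.n_calls % self.check_freq == 0:
            self.model.save(os.path.join(self.log_dir, f"rppo_checkpoint_{self.num_timesteps}"))
        return True  # returning False would abort training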


def get_params():
    """Return RecurrentPPO hyperparameters and policy network architecture."""
    # Shared layers followed by a dict that splits into separate value-function (vf)
    # and policy (pi) branches (SB3's list-style net_arch format).
    vf_arch = [32, 32, 64, 128, 32, 128, 32, 128, 256]
    pi_arch = [64, 1024, 128, 512, 64]
    net_arch = [32, 256, 128, 32, 1024, 768, 32, 768, 32, dict(vf=vf_arch, pi=pi_arch)]
    return {
        'n_steps': 3072,
        'gamma': 0.9014026506133596,
        'learning_rate': 6.168219628138978e-05,
        'clip_range': 0.14508773369622327,
        'gae_lambda': 0.8364513164493685,
        'policy_kwargs': dict(
            activation_fn=th.nn.CELU,
            net_arch=net_arch,
        ),
    }
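

# NOTE: `get_subproc_env` comes from the local `helper` module (star-imported above)
# and its source is not shown here. Based on the commented-out DummyVecEnv line in
# the main block below, a minimal sketch of what it might do is given here; the name
# `_get_subproc_env_sketch`, the `evaluate` flag, and the Monitor/TrainingEnv wiring
# are assumptions, not the actual helper implementation.
def _get_subproc_env_sketch(n_envs, evaluate=False):
    from stable_baselines3.common.monitor import Monitor
    from stable_baselines3.common.vec_env import SubprocVecEnv

    def make_env():
        # TrainingEnv and TRAINING_BEGIN_TIME are assumed to be provided by the
        # star-imported define/helper modules; `evaluate` presumably switches the
        # environment into an evaluation configuration.
        return Monitor(TrainingEnv(TRAINING_BEGIN_TIME), MODEL_PATH)

    return SubprocVecEnv([make_env for _ in range(n_envs)])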


if __name__ == '__main__':
    # Create the vectorized training environment (10 parallel worker processes).
    # env = DummyVecEnv([lambda: Monitor(TrainingEnv(TRAINING_BEGIN_TIME), MODEL_PATH)])
    train_env = get_subproc_env(10)
    # Stable Baselines also provides the make_vec_env() helper, which performs
    # exactly the previous steps for you. You can choose between `DummyVecEnv`
    # (usually faster) and `SubprocVecEnv`:
    # env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)

    model_params = get_params()
    model = RecurrentPPO('MlpLstmPolicy', train_env, verbose=1,
                         tensorboard_log=TB_LOG_PATH, **model_params)
    # To resume training from an earlier checkpoint instead of starting fresh:
    # model = RecurrentPPO.load("./model/rppo/rppo_finish_2", env=train_env)
    model.learn(total_timesteps=LEARN_TIMES,
                callback=SaveModelCallback(check_freq=128, log_dir=MODEL_PATH))

    # Evaluate the trained policy on a separate vectorized environment.
    evaluate_env = get_subproc_env(10, True)
    mean_reward, std_reward = evaluate_policy(model, evaluate_env,
                                              n_eval_episodes=3, deterministic=False)
    print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")

    model.save(MODEL_PATH + "/rppo_finish")
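
    # To reload the trained policy later (mirroring the commented-out load above),
    # something like the following would presumably be used:
    # model = RecurrentPPO.load(MODEL_PATH + "/rppo_finish", env=train_env)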