from sb3_contrib import RecurrentPPO
import torch as th
import optuna
from stable_baselines3.common.evaluation import evaluate_policy
from define import *
from helper import *
TB_LOG_PATH = "../tb_log"
LEARN_TIMES = 100000
def optimize_ppo(trial):
    # Sample the number and width of the shared layers
    na_num = trial.suggest_int('na_num', 4, 16)
    net_arch = []
    for i in range(na_num):
        net_arch.append(trial.suggest_categorical("na" + str(i), [32, 64, 128, 256, 512, 768, 1024]))
    # Sample the value-function head layers
    vf_num = trial.suggest_int('vf_num', 4, 16)
    vf_arch = []
    for i in range(vf_num):
        vf_arch.append(trial.suggest_categorical("vf" + str(i), [32, 64, 128, 256, 512, 768, 1024]))
    # Sample the policy head layers
    pi_num = trial.suggest_int('pi_num', 4, 16)
    pi_arch = []
    for i in range(pi_num):
        pi_arch.append(trial.suggest_categorical("pi" + str(i), [32, 64, 128, 256, 512, 768, 1024]))
    # Shared layers followed by separate vf/pi heads (legacy SB3 net_arch format)
    net_arch.append(dict(vf=vf_arch, pi=pi_arch))
    all_fn = [
        th.nn.ReLU,
        th.nn.RReLU,
        th.nn.Hardtanh,
        th.nn.ReLU6,
        th.nn.Sigmoid,
        th.nn.Hardsigmoid,
        th.nn.Tanh,
        th.nn.SiLU,
        th.nn.Mish,
        th.nn.Hardswish,
        th.nn.ELU,
        th.nn.CELU,
        th.nn.SELU,
        th.nn.GELU,
        th.nn.Hardshrink,
        th.nn.LeakyReLU,
        th.nn.LogSigmoid,
        th.nn.Softplus,
        th.nn.Softshrink,
        th.nn.PReLU,
        th.nn.Softsign,
        th.nn.Tanhshrink
    ]
    # suggest_int is inclusive on both ends, so the upper bound must be len(all_fn) - 1
    act_fn_i = trial.suggest_int('act_fn_i', 0, len(all_fn) - 1)
    return {
        'n_steps': trial.suggest_categorical("n_steps", [2048, 3072, 4096, 8192]),
        'gamma': trial.suggest_float('gamma', 0.8, 0.99),
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-4),
        'clip_range': trial.suggest_float('clip_range', 0.1, 0.3),
        'gae_lambda': trial.suggest_float('gae_lambda', 0.8, 0.99),
        'policy_kwargs': dict(
            activation_fn=all_fn[act_fn_i],
            net_arch=net_arch
        )
    }
def optimize_agent(trial):
    try:
        # Create the vectorized training environment
        train_env = get_subproc_env(24)
        # env = VecFrameStack(env, 2, channels_order='last')
        model_params = optimize_ppo(trial)
        model = RecurrentPPO('MlpLstmPolicy', train_env, **model_params)
        model.learn(total_timesteps=LEARN_TIMES)
        # model.save(MODEL_PATH + '/trial_{}'.format(trial.number))
        # Evaluate on a separate environment and report the mean reward to Optuna
        evaluate_env = get_subproc_env(12, True)
        mean_reward, _ = evaluate_policy(model, evaluate_env, n_eval_episodes=3, deterministic=False)
        print("mean_reward", mean_reward)
        return mean_reward
    except Exception as e:
        # A failed trial (e.g. invalid hyperparameters or out-of-memory) scores a large penalty
        print(e)
        return -100000
if __name__ == '__main__':
    study = optuna.create_study(direction='maximize')
    study.optimize(optimize_agent, n_trials=100, gc_after_trial=True)
    print(study.best_params)
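    # Sketch (not in the original script, assumptions marked): retrain a final model with
    # the best hyperparameters and save it. optimize_ppo(study.best_trial) replays the
    # parameter values stored in the best frozen trial; the save path below is hypothetical.
    # best_params = optimize_ppo(study.best_trial)
    # final_env = get_subproc_env(24)
    # final_model = RecurrentPPO('MlpLstmPolicy', final_env, **best_params)
    # final_model.learn(total_timesteps=LEARN_TIMES)
    # final_model.save(MODEL_PATH + '/best_trial')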