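"""train_SAC.py

Entry point for training a SAC-style agent (selected via ``cfg.algorithm``) on the
panda_gym ``PandaPushDense-v3`` task. The agent first interacts with ``envA`` (default
dynamics) for ``num_train_steps``, then the workspace switches to ``envB`` (reduced
finger friction, see ``make_env``) for ``num_test_steps`` of continued interaction.
Transitions are stored in a ReplayBuffer and metrics are written via Logger.
"""
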
import os
import time
import dmc2gym
import torch
import numpy as np
import utils
from logger import Logger
from replay_buffer import ReplayBuffer
import algorithms
from arguments import parse_args
import panda_gym  # noqa: F401 -- importing registers the Panda* tasks with gymnasium
import gymnasium
import shutil

os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # pin training to GPU 1
torch.backends.cudnn.benchmark = True


def make_env(cfg, pixels=False, train=False):
    # per dreamer: https://github.com/danijar/dreamer/blob/02f0210f5991c7710826ca7881f19c64a012290c/wrappers.py#L26
    # env = gymnasium.make("PandaPickAndPlaceDense-v3")
    env = gymnasium.make("PandaPushDense-v3")
    if not train:
        # test env: reduce finger friction to create a dynamics shift relative to the training env
        # block_uid = env.unwrapped.sim._bodies_idx['object']
        # env.unwrapped.sim.physics_client.changeDynamics(bodyUniqueId=block_uid, linkIndex=-1, mass=2.0)
        env.unwrapped.sim.set_lateral_friction(env.unwrapped.robot.body_name, env.unwrapped.robot.fingers_indices[0], lateral_friction=0.5)
        env.unwrapped.sim.set_lateral_friction(env.unwrapped.robot.body_name, env.unwrapped.robot.fingers_indices[1], lateral_friction=0.5)
        env.unwrapped.sim.set_spinning_friction(env.unwrapped.robot.body_name, env.unwrapped.robot.fingers_indices[0], spinning_friction=0.001)
        env.unwrapped.sim.set_spinning_friction(env.unwrapped.robot.body_name, env.unwrapped.robot.fingers_indices[1], spinning_friction=0.001)
    env = utils.FrameStack(env, k=cfg.frame_stack)
    # env.seed(cfg.seed)
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1
    return env
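
# Optional sanity check (a sketch, not part of the training flow): the commented-out lines in
# make_env suggest the PyBullet client is reachable as env.unwrapped.sim.physics_client and that
# bodies are indexed via sim._bodies_idx. Under those assumptions, the friction applied to the
# test env could be inspected like this:
#
#   env = make_env(cfg, train=False)
#   client = env.unwrapped.sim.physics_client
#   body_uid = env.unwrapped.sim._bodies_idx[env.unwrapped.robot.body_name]
#   finger = env.unwrapped.robot.fingers_indices[0]
#   print(client.getDynamicsInfo(body_uid, finger)[1])  # index 1 is the lateral friction coefficient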


class Workspace(object):
    def __init__(self, cfg):
        self.work_dir = os.path.join(os.getcwd(), cfg.log_dir, cfg.exp_name, cfg.algorithm, str(cfg.seed))
        if os.path.exists(self.work_dir):
            shutil.rmtree(self.work_dir)
        os.makedirs(self.work_dir)
        print(f'workspace: {self.work_dir}')
        self.save_dir = os.path.join(self.work_dir, "trained_models")
        utils.write_info(cfg, os.path.join(self.work_dir, 'config.log'))
        self.cfg = cfg
        self.logger = Logger(self.work_dir,
                             save_tb=False,
                             log_frequency=cfg.log_freq,
                             action_repeat=cfg.action_repeat,
                             agent=cfg.algorithm)
        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)

        # self.train_env = make_env(cfg)
        # self.env = self.train_env
        # # set up correlated environments
        # if self.cfg.correlated_with_colour:
        #     original_rgb = np.copy(self.env.physics.model.mat_rgba)[:, :3]
        #     self.colourA = np.copy(original_rgb)
        #     self.colourB = np.copy(original_rgb)
        #     self.colourA[1, :] = [0., 0., 1.0]
        #     self.colourB[1, :] = [0., 1.0, 0.]
        #     self.probabilities = [[self.cfg.correlation_probability, 1 - self.cfg.correlation_probability], [1 - self.cfg.correlation_probability, self.cfg.correlation_probability]]
        #     self.test_probabilities = [[self.cfg.test_correlation_probability, 1 - self.cfg.test_correlation_probability], [1 - self.cfg.test_correlation_probability, self.cfg.test_correlation_probability]]
        #     xml_pathA = os.path.join("world_models", f"{cfg.domain_name}_A.xml")
        #     envA = make_env(cfg)
        #     envA.physics.reload_from_xml_path(xml_pathA)
        #     xml_pathB = os.path.join("world_models", f"{cfg.domain_name}_B.xml")
        #     envB = make_env(cfg)
        #     envB.physics.reload_from_xml_path(xml_pathB)
        #     self.envs = [envA, envB]
        #     object = np.random.choice([0, 1])
        #     self.current_colour = eval(np.random.choice(["self.colourA", "self.colourB"], p=self.probabilities[object]))
        #     self.env = self.envs[object]
        #     self.env.reset(colour=self.current_colour)
        # else:
        #     self.envs = False
        #     self.env.reset()

        # xml_pathA = os.path.join("world_models", f"{cfg.domain_name}_{cfg.description}A.xml")
        self.envA = make_env(cfg, train=True)
        # self.envA.physics.reload_from_xml_path(xml_pathA)
        # xml_pathB = os.path.join("world_models", f"{cfg.domain_name}_{cfg.description}B.xml")
        self.envB = make_env(cfg, train=False)
        # self.envB.physics.reload_from_xml_path(xml_pathB)
        self.env = self.envA  # default (training) env is envA

        action_range = [
            float(self.env.action_space.low.min()),
            float(self.env.action_space.high.max())
        ]
        # panda_gym observations are dicts; the flat shape stacks observation, achieved_goal and desired_goal over frame_stack frames
        observation_shape = (self.env.observation_space["observation"].shape[0] * self.cfg.frame_stack +
                             self.env.observation_space["achieved_goal"].shape[0] * self.cfg.frame_stack +
                             self.env.observation_space["desired_goal"].shape[0] * self.cfg.frame_stack,)
        self.agent = algorithms.make_agent(observation_shape, self.env.action_space.shape, action_range, cfg)
        self.replay_buffer = ReplayBuffer(observation_shape,
                                          self.env.action_space.shape,
                                          self.cfg.replay_buffer_capacity,
                                          self.cfg.image_pad, self.device,
                                          self.cfg.algorithm == 'svea_cmid')
        self.step = 0

    def evaluate(self):
        eval_env = self.env
        # eval_envs = self.envs
        if self.cfg.correlated_with_colour:
            eval_probs = self.probabilities
        average_episode_reward = 0
        # self.video_recorder.init(enabled=True)
        for episode in range(self.cfg.num_eval_episodes):
            # if eval_envs:
            #     object = np.random.choice(range(len(eval_envs)))
            #     eval_env = eval_envs[object]
            #     if self.cfg.correlated_with_colour:
            #         self.current_colour = eval(np.random.choice(["self.colourA", "self.colourB"], p=eval_probs[object]))
            #         obs = eval_env.reset(colour=self.current_colour)
            #     else:
            obs = eval_env.reset()
            done = False
            episode_reward = 0
            episode_step = 0
            while episode_step < 100:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, sample=False)
                obs, reward, done, info = eval_env.step(action)
                # self.video_recorder.record(eval_env)
                episode_reward += reward
                episode_step += 1
            average_episode_reward += episode_reward
            # self.video_recorder.save(f'{self.step}.mp4')
        average_episode_reward /= self.cfg.num_eval_episodes
        self.logger.log('eval/episode_reward', average_episode_reward, self.step)
        self.logger.dump(self.step)

    def run(self, initial_trans=True):
        episode, episode_reward, episode_step, done = 0, 0, 1, True
        total_num_steps = self.cfg.num_train_steps + self.cfg.num_test_steps
        start_time = time.time()
        while self.step <= (total_num_steps + 1):
            if done:
                # print("************", episode, episode_reward, episode_step, done)
                if self.step > 0:
                    self.logger.log('train/duration', time.time() - start_time, self.step)
                    start_time = time.time()
                    self.logger.dump(self.step, save=(self.step > self.cfg.num_seed_steps))
                # evaluate agent periodically
                if self.step % self.cfg.eval_freq == 0:
                    self.logger.log('eval/episode', episode, self.step)
                    self.evaluate()
                self.logger.log('train/episode_reward', episode_reward, self.step)
                obs = self.env.reset()
                prev_obs = obs.copy()
                done = False
                episode_reward = 0
                episode_step = 0
                episode += 1
                self.logger.log('train/episode', episode, self.step)
            # sample action for data collection
            if self.step < self.cfg.num_seed_steps and initial_trans:
                action = self.env.action_space.sample()
            else:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, sample=True)
            # run training update
            if self.step >= self.cfg.num_seed_steps:
                for _ in range(self.cfg.num_train_iters):
                    self.agent.update(self.replay_buffer, self.logger, self.step)
            elif not initial_trans:
                for _ in range(self.cfg.num_train_iters):
                    self.agent.update(self.replay_buffer, self.logger, self.step)
            if self.step > 0 and self.step % self.cfg.save_freq == 0:
                saveables = {
                    "actor": self.agent.actor.state_dict(),
                    "critic": self.agent.critic.state_dict(),
                    "critic_target": self.agent.critic_target.state_dict()
                }
                save_at = os.path.join(self.save_dir, f"env_step{self.step * self.cfg.action_repeat}")
                os.makedirs(save_at, exist_ok=True)
                torch.save(saveables, os.path.join(save_at, "models.pt"))
            # assumes utils.FrameStack exposes the old 4-tuple Gym step API
            next_obs, reward, done, info = self.env.step(action)
            # allow infinite bootstrap
            done = float(done)
            done_no_max = 0 if episode_step + 1 == self.env._max_episode_steps else done
            episode_reward += reward
            self.replay_buffer.add(obs, action, reward, next_obs, done, done_no_max, episode, prev_obs)
            prev_obs = obs
            obs = next_obs
            episode_step += 1
            self.step += 1
            if self.step == self.cfg.num_train_steps:
                print("Switching to test env")
                if self.cfg.correlated_with_colour:
                    self.probabilities = self.test_probabilities
                # force an episode boundary at the train/test switch
                done = True


def main(cfg):
    from train_SAC import Workspace as W
    global workspace
    workspace = W(cfg)
    # training period => envA
    workspace.run()
    print("Change envA => envB")
    # switch to envB
    workspace.env = workspace.envB
    workspace.step = 0
    workspace.run(initial_trans=False)


if __name__ == '__main__':
    args = parse_args()
    print(args.exp_name)
    args.exp_name = args.domain_name + "_damping_slide_envA_10_envB_5e-10_" + str(args.num_train_steps) + "_seed" + str(args.seed)
    main(args)
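
# Example invocation (a sketch; assumes arguments.py exposes the attributes used above --
# domain_name, algorithm, seed, num_train_steps -- as command-line flags with these names):
#   python train_SAC.py --domain_name panda_push --algorithm sac --seed 1 --num_train_steps 500000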