import math
import os
import random
from collections import deque
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd
from datetime import datetime
import subprocess
import json
import glob

class eval_mode(object):
    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        return False
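
# Usage sketch (illustrative only; `actor` and `critic` are assumed to be
# nn.Module instances defined elsewhere in the training code):
#
#     with eval_mode(actor, critic):
#         action = actor(obs)  # both modules run with training=False here
#     # on exit, each module's previous train/eval state is restored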

def soft_update_params(net, target_net, tau):
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data +
                                (1 - tau) * target_param.data)
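
# Usage sketch: Polyak averaging of a target network toward the online network,
# i.e. target <- tau * online + (1 - tau) * target. The names `critic`,
# `critic_target` and the value tau=0.005 are illustrative, not fixed by this module:
#
#     soft_update_params(critic, critic_target, tau=0.005)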

def set_seed_everywhere(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

def make_dir(*path_parts):
    dir_path = os.path.join(*path_parts)
    try:
        os.mkdir(dir_path)
    except OSError:
        pass
    return dir_path

def tie_weights(src, trg):
    assert type(src) == type(trg)
    trg.weight = src.weight
    trg.bias = src.bias

def weight_init(m):
    """Custom weight init for Conv2D and Linear layers."""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        # gain = nn.init.calculate_gain('relu')
        # nn.init.orthogonal_(m.weight.data, gain)
        # if hasattr(m.bias, 'data'):
        #     m.bias.data.fill_(0.0)
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
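
# Usage sketch: `weight_init` is written to be passed to nn.Module.apply, which
# calls it on every submodule. The `actor` network below is an assumed example:
#
#     actor = mlp(obs_dim, 256, act_dim, hidden_depth=2)
#     actor.apply(weight_init)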

def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for i in range(hidden_depth - 1):
            mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
        mods.append(nn.Linear(hidden_dim, output_dim))
    if output_mod is not None:
        mods.append(output_mod)
    trunk = nn.Sequential(*mods)
    return trunk
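
# Usage sketch: a 2-hidden-layer ReLU network mapping a 17-dim input to a
# 6-dim output (the dimensions are illustrative), with an optional Tanh head:
#
#     net = mlp(17, 256, 6, hidden_depth=2, output_mod=nn.Tanh())
#     # net == Sequential(Linear(17,256), ReLU, Linear(256,256), ReLU, Linear(256,6), Tanh)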

def to_np(t):
    if t is None:
        return None
    elif t.nelement() == 0:
        return np.array([])
    else:
        return t.cpu().detach().numpy()

class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self._k = k
        self._frames = deque([], maxlen=k)
        # shp = env.observation_space.shape
        # # adapt for the panda environments
        # if shp is None:
        #     shp = env.observation_space['observation'].shape
        # self.observation_space = gym.spaces.Box(
        #     low=0,
        #     high=1,
        #     shape=((shp[0] * k,) + shp[1:]),
        #     dtype=env.observation_space.dtype)
        self._max_episode_steps = env._max_episode_steps

    def reset(self, colour=None):
        obs = self.env.reset()
        # Flatten the goal-conditioned dict observation into a single vector.
        obs = np.concatenate((obs[0]['observation'], obs[0]['achieved_goal'], obs[0]['desired_goal']), axis=0)
        for _ in range(self._k):
            self._frames.append(obs)
        return self._get_obs()

    def step(self, action):
        # The wrapped env follows the 5-tuple step API; `truncated` is dropped
        # so the wrapper exposes the older 4-tuple interface.
        obs, reward, done, truncated, info = self.env.step(action)
        obs = np.concatenate((obs['observation'], obs['achieved_goal'], obs['desired_goal']), axis=0)
        self._frames.append(obs)
        return self._get_obs(), reward, done, info

    def _get_obs(self):
        assert len(self._frames) == self._k
        return np.concatenate(list(self._frames), axis=0)

    def custom_reset(self, frames):
        for _ in range(self._k):
            self._frames.append(frames[0])
        return self._get_obs()
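
# Usage sketch (assumes a goal-conditioned env such as a panda-gym task whose
# observations are dicts with 'observation', 'achieved_goal' and 'desired_goal',
# and which exposes `_max_episode_steps`; the env id is illustrative):
#
#     env = FrameStack(gym.make('PandaReach-v3'), k=3)
#     obs = env.reset()                    # flattened and stacked, shape (k * obs_dim,)
#     obs, reward, done, info = env.step(env.action_space.sample())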

class TanhTransform(pyd.transforms.Transform):
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here as it may degrade the performance
        # of certain algorithms; one should use `cache_size=1` instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # We use a more numerically stable formula, see details in the following link
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))

class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.base_dist = pyd.Normal(loc, scale)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu
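
# Usage sketch: a tanh-squashed Gaussian as typically used for SAC-style actor
# outputs; `mu` and `std` below are illustrative tensors, not produced by this module:
#
#     mu, std = torch.zeros(1, 6), torch.ones(1, 6) * 0.1
#     dist = SquashedNormal(mu, std)
#     action = dist.rsample()                      # reparameterized sample in (-1, 1)
#     log_prob = dist.log_prob(action).sum(-1)     # includes the tanh Jacobian correction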

def write_info(args, fp):
    try:
        data = {
            'timestamp': str(datetime.now()),
            'git': subprocess.check_output(["git", "describe", "--always"]).strip().decode(),
            'args': vars(args)
        }
    except Exception:
        # Fall back to logging without the git hash (e.g. when not run inside a git repo).
        data = {
            'timestamp': str(datetime.now()),
            'args': vars(args)
        }
    with open(fp, 'w') as f:
        json.dump(data, f, indent=4, separators=(',', ': '))
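
# Usage sketch: dump the run's argparse namespace (plus the git hash, when
# available) to a JSON file; `args` and the output path are illustrative:
#
#     args = parser.parse_args()
#     write_info(args, os.path.join(work_dir, 'info.json'))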

def listdir(dir_path, filetype='jpg', sort=True):
    fpath = os.path.join(dir_path, f'*.{filetype}')
    fpaths = glob.glob(fpath, recursive=True)
    if sort:
        return sorted(fpaths)
    return fpaths