# Fetch the repository succeeded.
import gym
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
class Net(nn.Module):
    """Small fully connected network: Linear -> ReLU -> Linear.

    Input vectors of width ``input_size`` are projected to ``hidden_size``
    units, passed through a ReLU, then projected to ``output_size`` values.
    The output layer has no activation, so outputs are unbounded reals.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(Net, self).__init__()
        # The attribute names define the state_dict keys ("linear1.weight", ...),
        # and creating linear1 before linear2 fixes the seeded initialization order.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``linear2(relu(linear1(x)))``."""
        return self.linear2(F.relu(self.linear1(x)))
class Agent(object):
    """DQN agent with epsilon-greedy acting, a replay buffer and a target network.

    Hyper-parameters are passed as keyword arguments and stored verbatim as
    attributes. Attributes read by this class:

    - state_space_dim, action_space_dim: network input / output widths.
    - hidden_size: optional hidden-layer width (default 256).
    - lr: Adam learning rate for eval_net.
    - capacity: maximum number of stored transitions.
    - batch_size: minibatch size sampled by learn().
    - gamma: discount factor used in the TD target.
    - delay: target_net is overwritten with eval_net every `delay` learn() calls.
    - epsi_high, epsi_low, decay: exploration schedule used by act().
    - verbose: optional; when truthy, act() prints the chosen action.
    """

    def __init__(self, **kwargs):
        # Local import: the module header does not import collections.
        from collections import deque

        for key, value in kwargs.items():
            setattr(self, key, value)
        hidden_size = getattr(self, "hidden_size", 256)
        self.eval_net = Net(self.state_space_dim, hidden_size, self.action_space_dim)
        self.target_net = Net(self.state_space_dim, hidden_size, self.action_space_dim)
        self.optimizer = optim.Adam(self.eval_net.parameters(), lr=self.lr)
        # A bounded deque evicts the oldest transition in O(1) on append;
        # list.pop(0) shifted the whole buffer on every insert once full.
        self.buffer = deque(maxlen=self.capacity)
        self.steps = 0  # number of act() calls; drives epsilon decay
        self.learn_steps = 0  # number of learn() calls; drives target sync

    def act(self, s0):
        """Pick an action for state ``s0`` with an epsilon-greedy policy.

        Epsilon decays exponentially from ``epsi_high`` towards ``epsi_low``
        as ``self.steps`` grows, with time constant ``decay``. On the greedy
        branch the action is the argmax of ``eval_net`` over the single state.

        Returns:
            int: the chosen action index in ``range(action_space_dim)``.
        """
        self.steps += 1
        epsi = self.epsi_low + (self.epsi_high - self.epsi_low) * math.exp(-1.0 * self.steps / self.decay)
        if random.random() < epsi:
            a0 = random.randrange(self.action_space_dim)
        else:
            state = torch.tensor(np.asarray(s0), dtype=torch.float).unsqueeze(0)
            # Inference only: no_grad avoids building an autograd graph every step.
            with torch.no_grad():
                a0 = torch.argmax(self.eval_net(state)).item()
        # Previously an unconditional debug print; now opt-in.
        if getattr(self, "verbose", False):
            print(a0)
        return a0

    def put(self, *transition):
        """Store one transition ``(s0, a0, r1, s1)`` or ``(s0, a0, r1, s1, done)``.

        The buffer is bounded by ``capacity``: appending to a full buffer
        drops the oldest transition.
        """
        self.buffer.append(transition)

    def learn(self):
        """Perform one DQN gradient step on a random minibatch from the buffer.

        Every ``delay``-th call first copies eval_net's weights into
        target_net. The count includes calls that return early. Returns
        without updating while the buffer holds fewer than ``batch_size``
        transitions.

        Transitions may be ``(s0, a0, r1, s1)`` or ``(s0, a0, r1, s1, done)``.
        When ``done`` is truthy, the target is just ``r1`` (no bootstrap from
        ``s1``). 4-tuples are always treated as non-terminal.
        """
        if self.learn_steps % self.delay == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_steps += 1
        if len(self.buffer) < self.batch_size:
            return

        samples = random.sample(self.buffer, self.batch_size)
        s0, a0, r1, s1 = zip(*(t[:4] for t in samples))
        done = [t[4] if len(t) > 4 else False for t in samples]

        s0 = torch.tensor(np.array(s0), dtype=torch.float)
        a0 = torch.tensor(np.array(a0), dtype=torch.long).view(self.batch_size, -1)
        r1 = torch.tensor(np.array(r1), dtype=torch.float).view(self.batch_size, -1)
        s1 = torch.tensor(np.array(s1), dtype=torch.float)
        done = torch.tensor(np.array(done, dtype=np.float32)).view(self.batch_size, -1)

        y_true = self._td_targets(r1, s1, done)
        # Q(s0, a0) for the action actually taken in each sampled transition.
        y_pred = self.eval_net(s0).gather(1, a0)
        loss = F.mse_loss(y_pred, y_true)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def _td_targets(self, r1, s1, done):
        """Return ``r1 + gamma * (1 - done) * max_a target_net(s1)[a]``, shaped (batch_size, 1)."""
        # Targets are constants for the regression, so no gradient is tracked.
        with torch.no_grad():
            next_q = self.target_net(s1).max(dim=1)[0].view(self.batch_size, -1)
        return r1 + self.gamma * (1.0 - done) * next_q
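# NOTE: the hosting page withheld content at this point and showed this notice instead:
# 此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
# (English: "Content here may be unsuitable for display, so the page does not show it. You can self-check and edit it with the editing tools.")
# 如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。
# (English: "If you confirm the content contains no inappropriate language, pure ad traffic, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or anything violating national laws and regulations, you can click submit to appeal and we will handle it as soon as possible.")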