dqn.py 2.40 KB
张利峰 authored 2022-03-22 11:26 · drl
import gym
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
class Net(nn.Module):
    """Two-layer MLP that maps a state vector to one Q-value per action."""
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x
class Agent(object):
    def __init__(self, **kwargs):
        # Store all hyperparameters passed as keyword arguments (gamma, lr, capacity,
        # batch_size, epsi_high, epsi_low, decay, delay, state_space_dim,
        # action_space_dim, ...) as attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.eval_net = Net(self.state_space_dim, 256, self.action_space_dim)
        self.target_net = Net(self.state_space_dim, 256, self.action_space_dim)
        self.optimizer = optim.Adam(self.eval_net.parameters(), lr=self.lr)
        self.buffer = []       # replay buffer of (s0, a0, r1, s1) transitions
        self.steps = 0         # total act() calls, drives epsilon decay
        self.learn_steps = 0   # total learn() calls, drives target-network sync
    def act(self, s0):
        """Epsilon-greedy action selection; epsilon decays from epsi_high to epsi_low."""
        self.steps += 1
        epsi = self.epsi_low + (self.epsi_high - self.epsi_low) * math.exp(-1.0 * self.steps / self.decay)
        if random.random() < epsi:
            # Explore: uniformly random action.
            a0 = random.randrange(self.action_space_dim)
        else:
            # Exploit: action with the highest predicted Q-value.
            s0 = torch.unsqueeze(torch.FloatTensor(s0), 0)
            a0 = torch.argmax(self.eval_net(s0)).item()
            print(a0)  # debug output of the greedy action
        return a0
    def put(self, *transition):
        """Append a (s0, a0, r1, s1) transition, evicting the oldest entry when full."""
        if len(self.buffer) == self.capacity:
            self.buffer.pop(0)
        self.buffer.append(transition)
    def learn(self):
        # Every `delay` calls, copy the online network's weights into the target network.
        if self.learn_steps % self.delay == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_steps += 1

        # Wait until the replay buffer holds at least one full batch.
        if len(self.buffer) < self.batch_size:
            return

        # Sample a random minibatch of transitions and convert it to tensors.
        samples = random.sample(self.buffer, self.batch_size)
        s0, a0, r1, s1 = zip(*samples)
        s0 = torch.tensor(np.array(s0), dtype=torch.float)
        a0 = torch.tensor(np.array(a0), dtype=torch.long).view(self.batch_size, -1)
        r1 = torch.tensor(np.array(r1), dtype=torch.float).view(self.batch_size, -1)
        s1 = torch.tensor(np.array(s1), dtype=torch.float)

        # TD target: r + gamma * max_a' Q_target(s', a'), with the target network detached.
        y_true = r1 + self.gamma * torch.max(self.target_net(s1).detach(), dim=1)[0].view(self.batch_size, -1)
        # Q(s, a) of the online network for the actions actually taken.
        y_pred = self.eval_net(s0).gather(1, a0)

        loss_fn = nn.MSELoss()
        loss = loss_fn(y_pred, y_true)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
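
The file defines only the Q-network and the agent; it contains no training loop. Below is a minimal usage sketch, not part of the original script, assuming gym's CartPole-v1 environment, the pre-0.26 gym API (reset() returns an observation, step() returns a 4-tuple), and illustrative hyperparameter values that the author's actual settings may differ from. It reuses the imports already at the top of dqn.py.

# Usage sketch (assumed, not from the original file): train the Agent on CartPole-v1.
if __name__ == '__main__':
    env = gym.make('CartPole-v1')
    agent = Agent(
        state_space_dim=env.observation_space.shape[0],
        action_space_dim=env.action_space.n,
        gamma=0.99,        # discount factor (illustrative value)
        lr=1e-3,           # Adam learning rate (illustrative value)
        epsi_high=0.9,     # initial exploration rate
        epsi_low=0.05,     # final exploration rate
        decay=200,         # epsilon decay time constant, in steps
        capacity=10000,    # replay buffer size
        batch_size=64,     # minibatch size
        delay=100,         # target-network sync interval, in learn() calls
    )
    for episode in range(200):
        s0 = env.reset()
        total_reward = 0.0
        while True:
            a0 = agent.act(s0)
            s1, r1, done, _ = env.step(a0)
            agent.put(s0, a0, r1, s1)   # store the transition
            agent.learn()               # one gradient step (no-op until buffer fills)
            total_reward += r1
            s0 = s1
            if done:
                break
        print(f'episode {episode}: reward {total_reward}')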