# -*- coding: utf-8 -*-
from __future__ import division
import numpy as np
import torch
Transition_dtype = np.dtype(
[('timestep', np.int32), ('state', np.uint8, (84, 84)), ('action', np.int32), ('reward', np.float32),
('nonterminal', np.bool_)])
blank_trans = (0, np.zeros((84, 84), dtype=np.uint8), 0, 0.0, False)
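# Illustrative sketch of how the structured array behaves (the values below
# are hypothetical, for exposition only):
#   buf = np.array([blank_trans] * 4, dtype=Transition_dtype)
#   buf[0] = (5, np.ones((84, 84), dtype=np.uint8), 2, 1.0, True)
#   buf['reward']  # -> array([1., 0., 0., 0.], dtype=float32)
# Field access returns contiguous per-field views, which is what makes the
# vectorised masking in ReplayMemory._get_transitions cheap.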
# Segment tree data structure where each parent node's value is the sum of its children's values (the running max priority is tracked separately)
class SegmentTree():
def __init__(self, size):
self.index = 0
self.size = size
    self.full = False  # Tracks whether the buffer has wrapped around, i.e. reached full capacity
self.tree_start = 2 ** (size - 1).bit_length() - 1 # Put all used node leaves on last tree level
self.sum_tree = np.zeros((self.tree_start + self.size,), dtype=np.float32)
self.data = np.array([blank_trans] * size, dtype=Transition_dtype) # Build structured array
    self.max = 1  # Initial max priority to return (1 = 1^ω, a priority of 1 raised to the exponent ω)
  # Propagates a single value up the tree given a tree index (scalar fast path)
def _propagate_index(self, index):
parent = (index - 1) // 2
left, right = 2 * parent + 1, 2 * parent + 2
self.sum_tree[parent] = self.sum_tree[left] + self.sum_tree[right]
if parent != 0:
self._propagate_index(parent)
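  # Index arithmetic sketch for the 0-indexed heap layout used above: node i
  # has parent (i - 1) // 2 and children 2*i + 1 and 2*i + 2, so e.g. leaf 9
  # has parent 4, whose children are 9 and 10; each propagation step
  # recomputes a parent's sum from both of its children.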
  # Updates a single value given a tree index (scalar fast path)
def _update_index(self, index, value):
self.sum_tree[index] = value # Set new value
self._propagate_index(index) # Propagate value
self.max = max(value, self.max)
  # Appends a single datum with its priority value
def append(self, data, value):
self.data[self.index] = data # Store data in underlying data structure
self._update_index(self.index + self.tree_start, value) # Update tree
self.index = (self.index + 1) % self.size # Update index
self.full = self.full or self.index == 0 # Save when capacity reached
self.max = max(value, self.max)
  # Updates node values from the current tree
def _update_nodes(self, indices):
children_indices = indices * 2 + np.expand_dims([1, 2], axis=1)
self.sum_tree[indices] = np.sum(self.sum_tree[children_indices], axis=0)
# Propagates changes up tree given tree indices
def _propagate(self, indices):
parents = (indices - 1) // 2
unique_parents = np.unique(parents)
self._update_nodes(unique_parents)
if unique_parents[0] != 0:
self._propagate(unique_parents)
# Updates values given tree indices
def update(self, indices, values):
self.sum_tree[indices] = values # Set new values
self._propagate(indices) # Propagate values
current_max_value = np.max(values)
self.max = max(current_max_value, self.max)
# Searches for the location of values in sum tree
def _retrieve(self, indices, values):
children_indices = (indices * 2 + np.expand_dims([1, 2], axis=1)) # Make matrix of children indices
# If indices correspond to leaf nodes, return them
if children_indices[0, 0] >= self.sum_tree.shape[0]:
return indices
# If children indices correspond to leaf nodes, bound rare outliers in case total slightly overshoots
elif children_indices[0, 0] >= self.tree_start:
children_indices = np.minimum(children_indices, self.sum_tree.shape[0] - 1)
left_children_values = self.sum_tree[children_indices[0]]
# Classify which values are in left or right branches
successor_choices = np.greater(values, left_children_values).astype(np.int32)
# Use classification to index into the indices matrix
successor_indices = children_indices[successor_choices, np.arange(indices.size)]
# Subtract the left branch values when searching in the right branch
successor_values = values - successor_choices * left_children_values
return self._retrieve(successor_indices, successor_values)
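  # Worked example of the descent above (priorities hypothetical): with leaf
  # priorities [0.5, 1.0, 2.0, 0.5] and a search value of 3.0, the left
  # subtree sums to 1.5; since 3.0 > 1.5 the search moves right with residual
  # 3.0 - 1.5 = 1.5, and at the leaf level 1.5 <= 2.0, so it settles on the
  # leaf holding priority 2.0.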
# Searches for values in sum tree and returns values, data indices and tree indices
def find(self, values):
tree_indices = self._retrieve(np.zeros(values.shape, dtype=np.int32), values)
data_index = tree_indices - self.tree_start
return (self.sum_tree[tree_indices], data_index, tree_indices) # Return weight values, data indices, tree indices
# Returns data given a data index
def get(self, data_index):
return self.data[data_index % self.size]
def total(self):
return self.sum_tree[0]
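# A minimal usage sketch of SegmentTree on its own (illustrative only; the
# priorities are made up):
#   tree = SegmentTree(4)
#   for priority in [0.5, 1.0, 2.0, 0.5]:
#     tree.append(blank_trans, priority)
#   tree.total()                 # -> 4.0, the sum of all leaf priorities
#   _, data_idx, _ = tree.find(np.array([3.0]))  # data_idx -> array([2])
#   tree.get(data_idx)           # returns the transition stored third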
class ReplayMemory():
def __init__(self, args, capacity):
self.device = args.device
self.capacity = capacity
self.history = args.history_length
self.discount = args.discount
self.n = args.multi_step
self.priority_weight = args.priority_weight # Initial importance sampling weight β, annealed to 1 over course of training
self.priority_exponent = args.priority_exponent
self.t = 0 # Internal episode timestep counter
self.n_step_scaling = torch.tensor([self.discount ** i for i in range(self.n)], dtype=torch.float32,
device=self.device) # Discount-scaling vector for n-step returns
self.transitions = SegmentTree(
capacity) # Store transitions in a wrap-around cyclic buffer within a sum tree for querying priorities
# Adds state and action at time t, reward and terminal at time t + 1
def append(self, state, action, reward, terminal):
state = state[-1].mul(255).to(dtype=torch.uint8,
device=torch.device('cpu')) # Only store last frame and discretise to save memory
self.transitions.append((self.t, state, action, reward, not terminal),
self.transitions.max) # Store new transition with maximum priority
self.t = 0 if terminal else self.t + 1 # Start new episodes with t = 0
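  # Illustrative call (names hypothetical): after env.step(action) returns
  # (observation, reward, done), the agent calls
  #   mem.append(state, action, reward, done)
  # with the full history-length frame stack; only the newest frame is kept,
  # since older frames are reassembled from neighbouring buffer slots when
  # sampling.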
# Returns the transitions with blank states where appropriate
def _get_transitions(self, idxs):
transition_idxs = np.arange(-self.history + 1, self.n + 1) + np.expand_dims(idxs, axis=1)
transitions = self.transitions.get(transition_idxs)
transitions_firsts = transitions['timestep'] == 0
blank_mask = np.zeros_like(transitions_firsts, dtype=np.bool_)
for t in range(self.history - 2, -1, -1): # e.g. 2 1 0
blank_mask[:, t] = np.logical_or(blank_mask[:, t + 1],
transitions_firsts[:, t + 1]) # True if future frame has timestep 0
for t in range(self.history, self.history + self.n): # e.g. 4 5 6
blank_mask[:, t] = np.logical_or(blank_mask[:, t - 1],
transitions_firsts[:, t]) # True if current or past frame has timestep 0
transitions[blank_mask] = blank_trans
return transitions
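  # Worked example of the masking above (history=4, n=3, so indices cover
  # t-3..t+3): if the transition at t-1 has timestep 0 (an episode start),
  # the frames at t-3 and t-2 belong to the previous episode and are blanked;
  # if t+1 starts a new episode, slots t+1..t+3 are blanked so their rewards
  # contribute 0 to the n-step return.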
# Returns a valid sample from each segment
def _get_samples_from_segments(self, batch_size, p_total):
segment_length = p_total / batch_size # Batch size number of segments, based on sum over all probabilities
segment_starts = np.arange(batch_size) * segment_length
valid = False
while not valid:
samples = np.random.uniform(0.0, segment_length,
[batch_size]) + segment_starts # Uniformly sample from within all segments
probs, idxs, tree_idxs = self.transitions.find(
samples) # Retrieve samples from tree with un-normalised probability
if np.all((self.transitions.index - idxs) % self.capacity > self.n) and np.all(
(idxs - self.transitions.index) % self.capacity >= self.history) and np.all(probs != 0):
        valid = True  # The conditions are sufficient but deliberately conservative around buffer index 0
# Retrieve all required transition data (from t - h to t + n)
transitions = self._get_transitions(idxs)
# Create un-discretised states and nth next states
all_states = transitions['state']
states = torch.tensor(all_states[:, :self.history], device=self.device, dtype=torch.float32).div_(255)
next_states = torch.tensor(all_states[:, self.n:self.n + self.history], device=self.device,
dtype=torch.float32).div_(255)
# Discrete actions to be used as index
actions = torch.tensor(np.copy(transitions['action'][:, self.history - 1]), dtype=torch.int64,
device=self.device)
# Calculate truncated n-step discounted returns R^n = Σ_k=0->n-1 (γ^k)R_t+k+1 (note that invalid nth next states have reward 0)
rewards = torch.tensor(np.copy(transitions['reward'][:, self.history - 1:-1]), dtype=torch.float32,
device=self.device)
R = torch.matmul(rewards, self.n_step_scaling)
# Mask for non-terminal nth next states
nonterminals = torch.tensor(np.expand_dims(transitions['nonterminal'][:, self.history + self.n - 1], axis=1),
dtype=torch.float32, device=self.device)
return probs, idxs, tree_idxs, states, actions, R, next_states, nonterminals
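  # Stratified-sampling sketch for the segment logic above (numbers
  # hypothetical): with p_total = 8.0 and batch_size = 4, segment_length is
  # 2.0 and segment_starts is [0, 2, 4, 6]; one uniform draw per segment
  # spreads samples across the whole priority mass instead of letting them
  # cluster in one high-priority region.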
def sample(self, batch_size):
p_total = self.transitions.total() # Retrieve sum of all priorities (used to create a normalised probability distribution)
probs, idxs, tree_idxs, states, actions, returns, next_states, nonterminals = self._get_samples_from_segments(
batch_size, p_total) # Get batch of valid samples
probs = probs / p_total # Calculate normalised probabilities
capacity = self.capacity if self.transitions.full else self.transitions.index
weights = (capacity * probs) ** -self.priority_weight # Compute importance-sampling weights w
weights = torch.tensor(weights / weights.max(), dtype=torch.float32,
device=self.device) # Normalise by max importance-sampling weight from batch
return tree_idxs, states, actions, returns, next_states, nonterminals, weights
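  # Worked example of the weight computation (numbers hypothetical): with
  # N = capacity = 1e6 and P(i) = 1e-5, w_i = (N * P(i))^-β = 10^-β; dividing
  # by the batch maximum keeps every weight in (0, 1], so importance sampling
  # only ever scales gradients down.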
def update_priorities(self, idxs, priorities):
priorities = np.power(priorities, self.priority_exponent)
self.transitions.update(idxs, priorities)
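  # Illustrative call (names hypothetical): after a learning step the agent
  # might call
  #   mem.update_priorities(tree_idxs, loss.detach().abs().cpu().numpy())
  # and this method raises those values to the exponent ω before writing them
  # back into the sum tree.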
# Set up internal state for iterator
def __iter__(self):
self.current_idx = 0
return self
# Return valid states for validation
def __next__(self):
if self.current_idx == self.capacity:
raise StopIteration
transitions = self.transitions.data[np.arange(self.current_idx - self.history + 1, self.current_idx + 1)]
transitions_firsts = transitions['timestep'] == 0
blank_mask = np.zeros_like(transitions_firsts, dtype=np.bool_)
for t in reversed(range(self.history - 1)):
blank_mask[t] = np.logical_or(blank_mask[t + 1],
transitions_firsts[t + 1]) # If future frame has timestep 0
transitions[blank_mask] = blank_trans
    state = torch.tensor(transitions['state'], dtype=torch.float32, device=self.device).div_(255)  # Agent will convert this into a batch
self.current_idx += 1
return state
next = __next__ # Alias __next__ for Python 2 compatibility
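# A minimal end-to-end sketch (everything below is hypothetical; 'Args' stands
# in for the real argparse namespace this module expects):
#   class Args:
#     device = torch.device('cpu')
#     history_length = 4
#     discount = 0.99
#     multi_step = 3
#     priority_weight = 0.4
#     priority_exponent = 0.5
#   mem = ReplayMemory(Args(), capacity=10000)
#   state = torch.zeros(4, 84, 84)  # history-length stack of 84x84 frames
#   for t in range(500):
#     mem.append(state, action=0, reward=0.0, terminal=(t % 100 == 99))
#   tree_idxs, states, actions, returns, next_states, nonterms, weights = \
#     mem.sample(batch_size=32)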