加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
gymtest.py 1.39 KB
一键复制 编辑 原始数据 按行查看 历史
张利峰 提交于 2022-05-14 13:25 . update
import gym
from IPython import display
import matplotlib.pyplot as plt
from dqn import Agent
def plot(score, mean):
display.clear_output(wait=True)
display.display(plt.gcf())
plt.figure(figsize=(20, 10))
plt.clf()
plt.title('Training...')
plt.xlabel('Episode')
plt.ylabel('Duration')
plt.plot(score)
plt.plot(mean)
plt.text(len(score) - 1, score[-1], str(score[-1]))
plt.text(len(mean) - 1, mean[-1], str(mean[-1]))
if __name__ == '__main__':
env = gym.make('CartPole-v1')
params = {
'gamma': 0.8,
'epsi_high': 0.9,
'epsi_low': 0.05,
'decay': 200,
'delay': 200,
'lr': 0.001,
'capacity': 10000,
'batch_size': 64,
'state_space_dim': env.observation_space.shape[0],
'action_space_dim': env.action_space.n
}
agent = Agent(**params)
score = []
mean = []
for episode in range(1000):
s0 = env.reset()
# print(s0.dtype)
total_reward = 1
while True:
env.render()
a0 = agent.act(s0)
s1, r1, done, _ = env.step(a0)
if done:
r1 = -1
agent.put(s0, a0, r1, s1)
if done:
break
total_reward += r1
s0 = s1
agent.learn()
score.append(total_reward)
mean.append(sum(score[-100:]) / 100)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化