(3) Gymnasium -- Testing CartPole with DQN


1. A DQN-based implementation with PyTorch

1.1 Main references

(1) The official PyTorch tutorial (recommended):

Reinforcement Learning (DQN) Tutorial — PyTorch Tutorials 2.0.1+cu117 documentation

(2)

Pytorch 深度强化学习 – CartPole问题 (PyTorch deep reinforcement learning: the CartPole problem) | 极客笔记 (Geek Notes)

1.2 How the official PyTorch tutorial works

To be continued; I have a lot on my plate these days, so I will tidy this part up in a couple of days.
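In the meantime, here is a brief summary of my own (it is not the pending write-up, and it is not quoted from the tutorial). DQN learns an action-value network Q(s, a; θ) (policy_net below) by regressing it toward a bootstrapped target computed with a slowly-moving copy Q(s, a; θ′) (target_net):

    TD target:    y = r + γ · max_a′ Q(s′, a′; θ′)      (y = r if s′ is terminal)
    Loss:         SmoothL1( Q(s, a; θ), y )             (Huber loss, averaged over a replay batch)
    Soft update:  θ′ ← τ·θ + (1 − τ)·θ′                  (τ = TAU = 0.005 in the code)

Actions are chosen ε-greedily, with ε decaying exponentially from EPS_START to EPS_END at a rate set by EPS_DECAY; transitions are stored in a replay buffer and sampled at random for each optimization step.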

1.3 Code implementation

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("CartPole-v1")

# set up matplotlib
# is_ipython = 'inline' in matplotlib.get_backend()
# if is_ipython:
#     from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)


# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return the largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


episode_durations = []


def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    # if is_ipython:
    #     if not show_result:
    #         display.display(plt.gcf())
    #         display.clear_output(wait=True)
    #     else:
    #         display.display(plt.gcf())


def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


if torch.cuda.is_available():
    num_episodes = 600
else:
    # num_episodes = 50
    num_episodes = 600

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()
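The listing above stops at the training curve. For completeness, here is a minimal evaluation sketch of my own (it is not part of the referenced tutorial): it assumes the training code above has just been run, so policy_net, device, gym and torch are still in scope, and the checkpoint filename "cartpole_dqn.pth" is arbitrary.

# Save the trained weights so the run can be reloaded later (file name is arbitrary)
torch.save(policy_net.state_dict(), "cartpole_dqn.pth")

# Watch the greedy policy for a few episodes with on-screen rendering
eval_env = gym.make("CartPole-v1", render_mode="human")
policy_net.eval()
for episode in range(5):
    obs, info = eval_env.reset()
    total_reward = 0.0
    while True:
        state = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            # Greedy action: no epsilon-exploration at evaluation time
            action = policy_net(state).max(1)[1].item()
        obs, reward, terminated, truncated, _ = eval_env.step(action)
        total_reward += reward
        if terminated or truncated:
            print(f"episode {episode}: return = {total_reward}")
            break
eval_env.close()

With render_mode="human", Gymnasium opens a window showing the pole being balanced; a well-trained agent should regularly reach the 500-step limit of CartPole-v1.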
