import torch
from torch import nn
from torch.distributions import Categorical
import numpy as np
class Actor(nn.Module):
    """Policy network: maps a state to a probability distribution over discrete actions."""

    def __init__(self, state_dim, action_dim, hidden_dim):
        super(Actor, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.net(x)
class Critic(nn.Module):
    """Value network: maps a state to a scalar estimate of V(s)."""

    def __init__(self, state_dim, hidden_dim):
        super(Critic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        return self.net(x)
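As a quick, illustrative sanity check (not part of the original listing), the two networks can be probed with a dummy batch: the actor's softmax output sums to 1 per row and the critic returns one value per state. The dimensions (4-dimensional state, 2 actions, batch of 8) are placeholders.

probe_actor = Actor(state_dim=4, action_dim=2, hidden_dim=64)
probe_critic = Critic(state_dim=4, hidden_dim=64)
dummy_states = torch.randn(8, 4)              # batch of 8 fake states
print(probe_actor(dummy_states).sum(dim=-1))  # each row sums to ~1.0
print(probe_critic(dummy_states).shape)       # torch.Size([8, 1])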
class PPO:
    def __init__(self, env, hidden_dim, actor_lr, critic_lr, K_epochs):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tau = 0.001            # not used in this implementation
        self.gamma = 0.99           # discount factor
        self.lamda = 0.95           # GAE smoothing parameter
        self.epsilon = 0.2          # PPO clipping range
        self.c1_vf = 0.5            # value-loss coefficient
        self.c2_entropy = 0.01      # entropy-bonus coefficient
        self.K_epochs = K_epochs    # optimization epochs per rollout

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.hidden_dim = hidden_dim

        self.actor = Actor(self.state_dim, self.action_dim, self.hidden_dim).to(self.device)
        self.critic = Critic(self.state_dim, self.hidden_dim).to(self.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
    def select_action(self, state):
        # Sample an action from the current policy; no gradient is needed during rollout
        state = torch.from_numpy(state).float().to(self.device)
        with torch.no_grad():
            probs = self.actor(state)
        dist = Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob.item()
    def GAE(self, states, rewards, next_states, dones):
        # Generalized Advantage Estimation, computed backwards over the rollout:
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        advantage = torch.zeros_like(rewards)
        values = self.critic(states).detach()
        next_values = self.critic(next_states).detach()
        last_advantage = 0
        for t in reversed(range(states.shape[0])):
            td_target = rewards[t] + self.gamma * next_values[t] * (1 - dones[t])
            td_delta = td_target - values[t]
            advantage[t] = td_delta + self.gamma * self.lamda * (1 - dones[t]) * last_advantage
            last_advantage = advantage[t]
        returns = advantage + values  # regression targets for the critic
        return advantage, returns
    def train(self, states, actions, rewards, next_states, dones, old_log_probs):
        # Convert the collected rollout into tensors on the training device
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(np.array(actions)).view(-1, 1).to(self.device)
        rewards = torch.FloatTensor(np.array(rewards)).view(-1, 1).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(np.array(dones)).view(-1, 1).to(self.device)
        old_log_probs = torch.FloatTensor(np.array(old_log_probs)).view(-1, 1).to(self.device)

        # Compute advantages and returns, then normalize advantages for stability
        advantages, returns = self.GAE(states, rewards, next_states, dones)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        batch_size = 64
        data_length = states.size(0)

        for _ in range(self.K_epochs):
            # Shuffle the rollout and iterate over mini-batches
            indices = torch.randperm(data_length).to(self.device)
            for start_index in range(0, data_length, batch_size):
                sample_indices = indices[start_index: start_index + batch_size]
                mb_states = states[sample_indices]
                mb_actions = actions[sample_indices]
                mb_old_log_probs = old_log_probs[sample_indices]
                mb_advantages = advantages[sample_indices]
                mb_returns = returns[sample_indices]

                # Probability ratio between the current and the old policy
                probs = self.actor(mb_states)
                dist = Categorical(probs)
                new_log_probs = dist.log_prob(mb_actions.squeeze(-1)).view(-1, 1)
                ratio = torch.exp(new_log_probs - mb_old_log_probs)

                # Clipped surrogate objective
                surr1 = ratio * mb_advantages
                clipped_ratio = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon)
                surr2 = clipped_ratio * mb_advantages
                actor_loss = -torch.min(surr1, surr2).mean()

                # Value-function loss
                current_values = self.critic(mb_states)
                critic_loss = self.c1_vf * nn.functional.mse_loss(current_values, mb_returns)

                # Entropy bonus encourages exploration (subtracted from the actor loss)
                entropy_loss = -self.c2_entropy * dist.entropy().mean()
                actor_loss += entropy_loss

                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                self.critic_optimizer.step()

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()
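A minimal training-loop sketch for the class above, assuming the classic Gym API (env.reset() returns only the observation, env.step() returns four values) and the CartPole-v1 environment; the hyperparameters and episode count are illustrative, not taken from the original listing.

import gym

env = gym.make("CartPole-v1")
agent = PPO(env, hidden_dim=128, actor_lr=3e-4, critic_lr=1e-3, K_epochs=10)

for episode in range(500):
    states, actions, rewards, next_states, dones, log_probs = [], [], [], [], [], []
    state = env.reset()
    done = False
    while not done:
        action, log_prob = agent.select_action(state)
        next_state, reward, done, _ = env.step(action)
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        next_states.append(next_state)
        dones.append(float(done))
        log_probs.append(log_prob)
        state = next_state
    # One PPO update per collected episode
    agent.train(states, actions, rewards, next_states, dones, log_probs)
    print(f"episode {episode}, return {sum(rewards):.1f}")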