import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

# Environment and device setup.
# Note: render_mode="human" opens a window and renders every step, which slows training
# considerably; use render_mode=None if you only care about the learning curve.
env = gym.make("CartPole-v1", render_mode="human")
state_dim = env.observation_space.shape[0] if len(env.observation_space.shape) == 1 else env.observation_space.n
action_dim = env.action_space.n
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


class Actor(nn.Module):
    """Policy network: maps a state to a probability distribution over actions."""

    def __init__(self, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1),
        )
    def forward(self, x):
        return self.net(x)


class Critic(nn.Module):
    """Value network: maps a state to a scalar state-value estimate V(s)."""

    def __init__(self, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
    def forward(self, x):
        return self.net(x)


class ActorCritic:
    """Actor-critic agent with separate optimizers for the policy and value networks."""

    def __init__(self, gamma):
        self.actor = Actor().to(device)
        self.critic = Critic().to(device)
        # actor_lr and critic_lr are module-level hyperparameters defined before instantiation
        self.optimizer_a = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.optimizer_c = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
    def act(self, state):
        """Sample an action from the current policy for a single state."""
        state = torch.from_numpy(state).float().to(device)
        with torch.no_grad():
            probs = self.actor(state)
        dist = torch.distributions.Categorical(probs)
        return dist.sample().item()
    def train(self, states, actions, rewards, next_states, dones):
        """One update of both networks from an episode batch using TD(0) targets."""
        # TD target and TD error (the advantage estimate for the policy gradient)
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)
        # Policy-gradient loss, weighted by the detached advantage
        log_probs = torch.log(self.actor(states).gather(1, actions) + 1e-9)
        actor_loss = torch.mean(-log_probs * td_delta.detach())
        # Critic regresses toward the detached TD target
        critic_loss = nn.functional.mse_loss(self.critic(states), td_target.detach())
        self.optimizer_c.zero_grad()
        self.optimizer_a.zero_grad()
        critic_loss.backward()
        actor_loss.backward()
        self.optimizer_c.step()
        self.optimizer_a.step()


# Hyperparameters
torch.manual_seed(0)
actor_lr = 1e-4
critic_lr = 1e-3
gamma = 0.99
episodes = 1000

scores = []
model = ActorCritic(gamma)

pbar = tqdm(range(episodes), desc="Training")
for episode in pbar:
    score = 0
    state, _ = env.reset()
    done = False
    states, actions, rewards, next_states, dones = [], [], [], [], []
    while not done:
        action = model.act(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        score += reward
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        next_states.append(next_state)
        dones.append(done)
        state = next_state
    # Convert the whole episode to tensors and run one actor-critic update
    states = torch.FloatTensor(np.array(states)).to(device)
    actions = torch.LongTensor(np.array(actions)).view(-1, 1).to(device)
    rewards = torch.FloatTensor(np.array(rewards)).view(-1, 1).to(device)
    next_states = torch.FloatTensor(np.array(next_states)).to(device)
    dones = torch.FloatTensor(np.array(dones)).view(-1, 1).to(device)
    model.train(states, actions, rewards, next_states, dones)

    scores.append(score)
    pbar.set_postfix(ep=episode, score=score, avg100=np.mean(scores[-100:]))

torch.save(model.actor.state_dict(), '../../model/cartpole-a.pt')
torch.save(model.critic.state_dict(), '../../model/cartpole-c.pt')
print(np.mean(scores[-100:]))
plt.plot(scores)
plt.show()
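

# --- Optional: reload the saved actor and run one greedy episode ---
# A minimal evaluation sketch, assuming the checkpoint saved above at
# '../../model/cartpole-a.pt' exists and that state_dim/action_dim match the
# training environment. It picks the argmax action instead of sampling.
eval_env = gym.make("CartPole-v1", render_mode="human")
eval_actor = Actor().to(device)
eval_actor.load_state_dict(torch.load('../../model/cartpole-a.pt', map_location=device))
eval_actor.eval()

obs, _ = eval_env.reset()
done, total_reward = False, 0.0
while not done:
    with torch.no_grad():
        probs = eval_actor(torch.from_numpy(obs).float().to(device))
    action = torch.argmax(probs).item()  # greedy action instead of sampling
    obs, reward, terminated, truncated, _ = eval_env.step(action)
    done = terminated or truncated
    total_reward += reward
print("evaluation return:", total_reward)
eval_env.close()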