import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
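
# --- Hedged sketch of the collaborating classes this agent expects. ---
# The original file does not define Actor, Critic, or ReplayBuffer; the
# classes below are illustrative assumptions inferred only from how SAC
# uses them (constructor arguments, forward outputs, sample() return
# values), not the project's actual implementations. In particular, the
# real buffer is assumed to return tensors already on the agent's device;
# this sketch simply keeps everything on CPU.
import torch.nn as nn


class Actor(nn.Module):
    """Gaussian policy head: outputs [mean, log_std] concatenated (2 * action_dim)."""
    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2 * action_dim),
        )

    def forward(self, state):
        return self.net(state)


class Critic(nn.Module):
    """Twin Q-networks; forward returns (q1, q2)."""
    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()

        def q_net():
            return nn.Sequential(
                nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, 1),
            )

        self.q1 = q_net()
        self.q2 = q_net()

    def forward(self, state, action):
        sa = torch.cat([state, action], dim=-1)
        return self.q1(sa), self.q2(sa)


class ReplayBuffer:
    """Fixed-size FIFO buffer; sample() returns float32 tensors (CPU in this sketch)."""
    def __init__(self, state_dim, action_dim, capacity, device='cpu'):
        self.capacity, self.device = capacity, device
        self.size, self.ptr = 0, 0
        self.states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
        self.rewards = np.zeros((capacity, 1), dtype=np.float32)
        self.next_states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.dones = np.zeros((capacity, 1), dtype=np.float32)

    def add(self, state, action, reward, next_state, done):
        i = self.ptr
        self.states[i], self.actions[i] = state, action
        self.rewards[i], self.next_states[i], self.dones[i] = reward, next_state, float(done)
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        idx = np.random.randint(0, self.size, size=batch_size)

        def to_t(x):
            return torch.as_tensor(x[idx], device=self.device)

        return (to_t(self.states), to_t(self.actions), to_t(self.rewards),
                to_t(self.next_states), to_t(self.dones))
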
class SAC:
    """Soft Actor-Critic agent with twin Q-critics and automatic temperature tuning."""

    def __init__(self, state_dim, action_dim, hidden_dim=256,
                 actor_lr=3e-4, critic_lr=3e-4, alpha_lr=3e-4,
                 gamma=0.99, tau=0.005, alpha=0.2,
                 device='cuda' if torch.cuda.is_available()
                 else 'mps' if torch.backends.mps.is_available()
                 else 'cpu',
                 replay_buffer_capacity=10000):
        self.alpha_lr = alpha_lr
        self.gamma = gamma
        self.tau = tau
        self.device = device
        # Target entropy heuristic: -|A| (one nat per action dimension).
        self.target_entropy = -action_dim
        self.replay_buffer = ReplayBuffer(state_dim, action_dim, replay_buffer_capacity)

        self.actor = Actor(state_dim, action_dim, hidden_dim).to(device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)

        self.critic = Critic(state_dim, action_dim, hidden_dim).to(device)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

        # Target critic starts as an exact copy and is soft-updated afterwards.
        self.target_critic = Critic(state_dim, action_dim, hidden_dim).to(device)
        self.target_critic.load_state_dict(self.critic.state_dict())

        self.log_std_min = -20
        self.log_std_max = 2

        # Temperature is optimized in log space so that alpha stays positive.
        self.log_alpha = torch.tensor(np.log(alpha), dtype=torch.float32,
                                      requires_grad=True, device=device)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)

    @property
    def alpha(self):
        return self.log_alpha.exp()
    def store_transition(self, state, action, reward, next_state, done):
        self.replay_buffer.add(state, action, reward, next_state, done)

    def act(self, obs, evaluate=False):
        if isinstance(obs, np.ndarray):
            obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0)

        pred = self.actor(obs)
        action_mean, action_log_std = torch.chunk(pred, 2, dim=-1)

        # Deterministic action for evaluation: just squash the mean.
        if evaluate:
            return torch.tanh(action_mean), None

        log_std = torch.clamp(action_log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)
        dist = torch.distributions.Normal(action_mean, std)

        # Reparameterized sample, squashed through tanh to bound the action.
        normal_sample = dist.rsample()
        action = torch.tanh(normal_sample)

        # Log-prob of the squashed action: subtract the tanh change-of-variables
        # term, using the numerically stable softplus form of log(1 - tanh(x)^2).
        log_prob = dist.log_prob(normal_sample)
        correction = 2. * (np.log(2.) - normal_sample - F.softplus(-2. * normal_sample))
        log_prob -= correction
        log_prob = log_prob.sum(dim=-1, keepdim=True)

        return action, log_prob
    def train(self, batch_size=512):
        if self.replay_buffer.size < batch_size:
            return

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)

        # --- Critic update: soft Bellman backup against the twin target critics. ---
        with torch.no_grad():
            next_actions, new_log_prob = self.act(next_states)
            target_Q1, target_Q2 = self.target_critic(next_states, next_actions)
            target_Q = torch.min(target_Q1, target_Q2)
            y = rewards + (1 - dones) * self.gamma * (target_Q - self.alpha.item() * new_log_prob)

        curr_Q1, curr_Q2 = self.critic(states, actions)
        critic_loss = F.mse_loss(curr_Q1, y) + F.mse_loss(curr_Q2, y)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        self.update_target()
        # --- Actor update: maximize expected Q minus the entropy penalty. ---
        new_actions, log_prob = self.act(states)
        q1, q2 = self.critic(states, new_actions)
        q_min = torch.min(q1, q2)
        actor_loss = (self.alpha.item() * log_prob - q_min).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # --- Temperature update: drive the policy entropy towards the target entropy. ---
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
    def update_target(self):
        for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
    def save(self, filename):
        """Save the full agent state so the checkpoint works both for
        evaluation and for resuming training."""
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'target_critic': self.target_critic.state_dict(),
            'log_alpha': self.log_alpha.detach(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict(),
            'alpha_optimizer': self.alpha_optimizer.state_dict(),
        }, filename)
    def load(self, filename, evaluate=False):
        """Load a checkpoint.

        :param filename: path to the checkpoint file
        :param evaluate: True  -> load only the actor and critic (testing/validation)
                         False -> also restore the target critic, temperature and
                                  optimizer states (resume training)
        """
        checkpoint = torch.load(filename, map_location=self.device)

        self.actor.load_state_dict(checkpoint['actor'])
        self.critic.load_state_dict(checkpoint['critic'])

        if evaluate:
            self.actor.eval()
            self.critic.eval()
            print(f"Loaded model from {filename} (Evaluation Mode)")
            return

        self.target_critic.load_state_dict(checkpoint['target_critic'])
        self.log_alpha.data.copy_(checkpoint['log_alpha'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer'])
        self.alpha_optimizer.load_state_dict(checkpoint['alpha_optimizer'])

        self.actor.train()
        self.critic.train()
        print(f"Loaded model from {filename} (Resume Training Mode)")
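
# --- Hedged usage sketch (not part of the original file). ---
# Shows how the agent's public methods (act, store_transition, train, save)
# fit together in a training loop. The environment ("Pendulum-v1" via
# gymnasium), episode count, batch size, and checkpoint filename are
# illustrative placeholders, not values taken from the original project.
if __name__ == "__main__":
    import gymnasium as gym

    env = gym.make("Pendulum-v1")
    agent = SAC(state_dim=env.observation_space.shape[0],
                action_dim=env.action_space.shape[0],
                device='cpu',  # kept on CPU to match the CPU replay-buffer sketch above
                replay_buffer_capacity=100_000)

    for episode in range(200):
        state, _ = env.reset()
        done, episode_return = False, 0.0
        while not done:
            action, _ = agent.act(state)
            # tanh-squashed actions lie in [-1, 1]; rescale here if the env needs it.
            action = action.detach().cpu().numpy().flatten()
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.store_transition(state, action, reward, next_state, float(terminated))
            agent.train(batch_size=256)
            state, episode_return = next_state, episode_return + float(reward)
        print(f"episode {episode}: return {episode_return:.1f}")

    agent.save("sac_checkpoint.pt")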