Alibaba Holdings, content review team, first-round interview
# Assumed to already exist:
#   model, old_model     - current policy and old (sampling) policy
#   optimizer            - optimizer over model's parameters
#   dataloader           - yields batches of prompts
#   reward_fn            - sequence-level reward function
#   eps, beta            - PPO clip range and KL penalty coefficient
#   K                    - group size (completions sampled per prompt)
#   OLD_SYNC_INTERVAL    - how often (in steps) to refresh old_model
import torch

old_model.load_state_dict(model.state_dict())
old_model.eval()

for step, batch in enumerate(dataloader):
    optimizer.zero_grad()
    total_loss = 0.0

    for prompt in batch:
        # 1. Sample a group of K completions from the old policy (no gradients)
        with torch.no_grad():
            completions, old_log_probs = old_model.sample_log_probs(
                prompt, K
            )  # shape: [K, T]

        # 2. Re-score the sampled completions under the current policy
        _, new_log_probs = model.log_probs(
            prompt, completions
        )  # shape: [K, T]

        # 3. Rewards (sequence-level)
        rewards = reward_fn(completions)  # shape: [K]

        # 4. GRPO advantage: normalize rewards within the group
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
        advantages = advantages.detach()  # explicitly block gradient flow

        # 5. Sum token-level log-probs to sequence level
        #    (assumes padding positions already contribute zero)
        old_lp = old_log_probs.sum(dim=-1)  # [K]
        new_lp = new_log_probs.sum(dim=-1)  # [K]

        # 6. PPO / GRPO importance ratio
        ratio = torch.exp(new_lp - old_lp)

        # 7. Clipped surrogate objective
        clipped_ratio = torch.clamp(ratio, 1 - eps, 1 + eps)
        policy_loss = -torch.mean(
            torch.min(ratio * advantages, clipped_ratio * advantages)
        )

        # 8. KL penalty: simple sequence-level estimate of KL(old || new)
        #    (see the per-token GRPO estimator sketched below)
        kl_loss = beta * torch.mean(old_lp - new_lp)

        # 9. Total loss for this prompt
        loss = policy_loss + kl_loss
        total_loss += loss

    # 10. Backprop the batch loss (averaged over prompts)
    total_loss = total_loss / len(batch)
    total_loss.backward()
    optimizer.step()

    # 11. Periodically sync old_model with the current policy
    if step % OLD_SYNC_INTERVAL == 0:
        old_model.load_state_dict(model.state_dict())
        old_model.eval()
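
Written out, the per-prompt loss the loop above computes is a GRPO-style clipped surrogate with group-normalized advantages plus a KL penalty. The notation maps to the code: r_i are the K rewards, rho_i is the sequence-level ratio for completion o_i given prompt q, epsilon = eps, beta = beta. Note this is a transcription of the code, not exactly the paper's formula: the GRPO paper applies the min/clip form per token and length-normalizes, whereas the code uses one ratio per sequence.

\[
\hat{A}_i = \frac{r_i - \operatorname{mean}(r_1,\dots,r_K)}{\operatorname{std}(r_1,\dots,r_K) + 10^{-8}}, \qquad
\rho_i = \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)}
\]
\[
\mathcal{L}(\theta) = -\frac{1}{K}\sum_{i=1}^{K} \min\!\big(\rho_i \hat{A}_i,\ \operatorname{clip}(\rho_i,\,1-\epsilon,\,1+\epsilon)\,\hat{A}_i\big) \;+\; \beta\,\widehat{\mathbb{D}}_{\mathrm{KL}}\big[\pi_{\theta_{\text{old}}} \,\big\|\, \pi_\theta\big]
\]

One design choice worth noting: the code uses old_model both as the sampling policy and as the KL reference, whereas the GRPO paper keeps a separate frozen reference model for the KL term.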
- Forgot torch.exp(log_logits), i.e. exponentiating the log-prob difference to get the importance ratio.
- Forgot the clipped surrogate objective:
      policy_loss = -torch.mean(
          torch.min(ratio * advantages, clipped_ratio * advantages)
      )
- Also forgot how the KL divergence term is computed (step 8; see the sketch below):
      # 8. KL penalty (old || new)
      kl_loss = beta * torch.mean(old_lp - new_lp)
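
On the last point: the sequence-level difference old_lp - new_lp is an unbiased Monte-Carlo estimate of KL(old || new), but it can go negative for individual samples. The GRPO paper instead uses a per-token, always-non-negative estimator (the "k3" estimator from Schulman's note on approximating KL) against a frozen reference policy. A minimal runnable sketch; the tensors, shapes, and beta value here are hypothetical placeholders, not part of the original code:

import torch

K, T = 4, 16
ref_log_probs = -torch.rand(K, T)   # hypothetical per-token log-probs of the frozen reference policy
new_log_probs = -torch.rand(K, T)   # hypothetical per-token log-probs of the current policy
beta = 0.04                         # hypothetical KL coefficient

log_ratio = ref_log_probs - new_log_probs              # log(pi_ref / pi_theta) per token
kl_per_token = torch.exp(log_ratio) - log_ratio - 1.0  # e^x - x - 1 >= 0 elementwise
kl_loss = beta * kl_per_token.mean()

Because e^x - x - 1 >= 0, each token's contribution is non-negative, and the mean estimates KL(pi_theta || pi_ref) when the tokens are sampled from the current (or, in practice, the old) policy. Padding positions would normally be masked out before the mean; the sketch skips that for brevity.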
