Alibaba Holdings, content review role: first-round interview

import torch

# Assume the following already exist:
# model, old_model: current policy and old (sampling) policy
# optimizer, dataloader
# reward_fn: maps completions to sequence-level rewards
# eps, beta: clip range and KL coefficient
# K: group size per prompt
# OLD_SYNC_INTERVAL: how often (in steps) to refresh old_model

old_model.load_state_dict(model.state_dict())
old_model.eval()

for step, batch in enumerate(dataloader):
    optimizer.zero_grad()

    total_loss = 0.0

    for prompt in batch:
        # 1. Sample with the old policy (no gradient tracking)
        with torch.no_grad():
            completions, old_log_probs = old_model.sample_log_probs(
                prompt, K
            )  # shape: [K, T]

        # 2. Recompute log-probs for the same completions under the current policy
        _, new_log_probs = model.log_probs(
            prompt, completions
        )  # shape: [K, T]

        # 3. Sequence-level reward
        rewards = reward_fn(completions)  # shape: [K]

        # 4. GRPO advantage (normalized within the group)
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
        advantages = advantages.detach()  # explicitly cut the advantages out of the graph
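        # Worked example of the group normalization above (illustrative numbers
        # only): rewards = [1.0, 0.0, 2.0] has mean 1.0 and (unbiased) std 1.0,
        # so advantages ~= [0.0, -1.0, 1.0]: each completion is scored relative
        # to the other completions sampled for the same prompt.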

        # 5. Sum token-level log-probs to get sequence log-probs
        old_lp = old_log_probs.sum(dim=-1)  # [K]
        new_lp = new_log_probs.sum(dim=-1)  # [K]

        # 6. PPO / GRPO ratio
        ratio = torch.exp(new_lp - old_lp)
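        # This exp() is the step that is easy to drop: the importance ratio
        # pi_new / pi_old equals exp(log pi_new - log pi_old), computed here at
        # the sequence level from the summed log-probs.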

        # 7. clipped surrogate objective
        clipped_ratio = torch.clamp(ratio, 1 - eps, 1 + eps)
        policy_loss = -torch.mean(
            torch.min(ratio * advantages, clipped_ratio * advantages)
        )
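        # The min() with the clamped ratio is the PPO-style pessimistic bound:
        # it removes any incentive to push the ratio outside [1 - eps, 1 + eps],
        # keeping the update close to the policy that produced the samples.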

        # 8. KL penalty (old || new)
        kl_loss = beta * torch.mean(old_lp - new_lp)
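        # Because the completions were sampled from old_model, the sample mean
        # of (old_lp - new_lp) is a Monte Carlo estimate of KL(old || new) at
        # the sequence level; beta scales how strongly it regularizes the update.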

        # 9. Total loss for this prompt
        loss = policy_loss + kl_loss
        total_loss += loss

    # 10. Backprop over the whole batch
    total_loss.backward()
    optimizer.step()

    # 11. Periodically sync old_model with the current policy
    if step % OLD_SYNC_INTERVAL == 0:
        old_model.load_state_dict(model.state_dict())
        old_model.eval()
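
To make the loop above runnable end-to-end, here is a minimal toy sketch of the interfaces it assumes (sample_log_probs, log_probs, reward_fn). The tiny GRU policy, fixed completion length, and "fraction of even token ids" reward are hypothetical placeholders for illustration, not the setup discussed in the interview.

import torch
import torch.nn as nn

VOCAB, T_GEN = 32, 8   # toy vocabulary size and fixed completion length

class ToyPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, 64)
        self.rnn = nn.GRU(64, 64, batch_first=True)
        self.head = nn.Linear(64, VOCAB)

    def _next_logits(self, tokens):
        # tokens: [B, L] token ids -> next-token logits [B, VOCAB]
        h, _ = self.rnn(self.embed(tokens))
        return self.head(h[:, -1])

    def sample_log_probs(self, prompt, K):
        # prompt: [Lp] token ids; autoregressively sample K completions and
        # record the per-token log-probs of the sampled tokens.
        ctx = prompt.unsqueeze(0).repeat(K, 1)                        # [K, Lp]
        toks, lps = [], []
        for _ in range(T_GEN):
            dist = torch.distributions.Categorical(logits=self._next_logits(ctx))
            tok = dist.sample()                                       # [K]
            toks.append(tok)
            lps.append(dist.log_prob(tok))                            # [K]
            ctx = torch.cat([ctx, tok.unsqueeze(1)], dim=1)
        return torch.stack(toks, dim=1), torch.stack(lps, dim=1)      # [K, T], [K, T]

    def log_probs(self, prompt, completions):
        # Re-score the given completions under this policy (teacher forcing).
        K = completions.size(0)
        ctx = prompt.unsqueeze(0).repeat(K, 1)
        lps = []
        for t in range(completions.size(1)):
            dist = torch.distributions.Categorical(logits=self._next_logits(ctx))
            lps.append(dist.log_prob(completions[:, t]))              # [K]
            ctx = torch.cat([ctx, completions[:, t:t + 1]], dim=1)
        return completions, torch.stack(lps, dim=1)                   # [K, T]

def reward_fn(completions):
    # Placeholder sequence-level reward: fraction of even token ids.
    return (completions % 2 == 0).float().mean(dim=-1)                # [K]

With these stubs, model = ToyPolicy(), old_model = ToyPolicy(), and a dataloader yielding lists of prompt tensors would let the training loop run as written.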

What I forgot during the interview:

1. Forgot torch.exp(log_logits), i.e. exponentiating the log-prob difference (new_lp - old_lp) to turn it into the probability ratio.
2. Forgot the clipped surrogate objective:

   policy_loss = -torch.mean(
       torch.min(ratio * advantages, clipped_ratio * advantages)
   )

I also forgot how the KL penalty is computed:

   # 8. KL penalty (old || new)
   kl_loss = beta * torch.mean(old_lp - new_lp)
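
For reference, here is the objective the snippet above computes, written out in (assumed) standard GRPO notation where o_i is the i-th sampled completion for prompt q and R_i its reward; this is just a restatement of the code, not a quote from the paper:

$$ A_i = \frac{R_i - \operatorname{mean}(R_{1:K})}{\operatorname{std}(R_{1:K}) + 10^{-8}}, \qquad r_i = \exp\big(\log \pi_\theta(o_i \mid q) - \log \pi_{\text{old}}(o_i \mid q)\big) $$

$$ \mathcal{L} = -\frac{1}{K}\sum_{i=1}^{K} \min\big(r_i A_i,\ \operatorname{clip}(r_i,\, 1-\epsilon,\, 1+\epsilon)\, A_i\big) \,+\, \beta \cdot \frac{1}{K}\sum_{i=1}^{K}\big(\log \pi_{\text{old}}(o_i \mid q) - \log \pi_\theta(o_i \mid q)\big) $$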
