GRPO From Scratch

This post explains my PyTorch implementation of the Group Relative Policy Optimization (GRPO) algorithm.

Group Relative Policy Optimization is an RL algorithm that improves a model's reasoning capability. It is a variant of Proximal Policy Optimization (PPO), which was originally designed for tasks such as robotic locomotion control and Atari game playing. GRPO turned out to be effective with rule-based rewards, as demonstrated in DeepSeek-R1.

To uncover the implementation details in a minimal way, I implemented GRPO from scratch in PyTorch in 👉 RLFromScratch. Let's now walk through it step by step.


Quick Recap of the GRPO Algorithm

The GRPO objective function is the following:

\[\mathcal{J}_{\mathrm{GRPO}}(\theta) = \mathbb{E}_{q,\,\{o_i\}_{i=1}^G \sim \pi_{\theta_{\mathrm{old}}}(\cdot \mid q)} \Bigg[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \Big\{ \min \Big[ \Gamma_{i,t}\,\hat{A}_{i,t},\; \mathrm{clip}\!\big(\Gamma_{i,t},\,1-\varepsilon,\,1+\varepsilon\big)\,\hat{A}_{i,t} \Big] - \beta\,\mathbb{D}_{\mathrm{KL}}\!\big[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\big] \Big\} \Bigg] .\]

where \(\Gamma_{i,t}:=\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{\mathrm{old}}}\left(o_{i, t} \mid q, o_{i,<t}\right)}\), \(\varepsilon\) and \(\beta\) are hyper-parameters, and \(\hat{A}_{i,t}\) is the advantage computed from the relative rewards of the outputs within each group only: \(\hat{A}_{i, t}=\widetilde{r}_i=\frac{r_i-\operatorname{mean}(\mathbf{r})}{\operatorname{std}(\mathbf{r})}\). The unbiased estimator of the KL divergence is \(\mathbb{D}_{\mathrm{KL}}\left[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right]=\frac{\pi_{\mathrm{ref}}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-\log \frac{\pi_{\mathrm{ref}}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-1\).
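As a concrete example of the group-relative advantage: suppose one prompt's group of \(G=4\) sampled outputs receives rewards \(\mathbf{r}=(3.0,\,1.0,\,0.0,\,0.0)\). Using the unbiased (sample) standard deviation, which is what torch.std computes by default,

\[\operatorname{mean}(\mathbf{r}) = 1.0,\qquad \operatorname{std}(\mathbf{r}) = \sqrt{2} \approx 1.41,\qquad \hat{A}_{1,t} \approx 1.41,\quad \hat{A}_{2,t} = 0,\quad \hat{A}_{3,t} = \hat{A}_{4,t} \approx -0.71 .\]

Every token of output \(i\) shares the same advantage: outputs that beat the group average are pushed up and the rest are pushed down, with no learned value function involved.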


Code Explanation

- Reward Design

I want the model to learn both the correct format and the correct answer. The format score is 1.0 and the answer score is 2.0, so the full reward is 3.0, and a response that only gets the format right earns a partial reward of 1.0.

import re

def compute_format_score(batch_responses):
    """Reward function that checks if the completion has the correct format."""
    pattern = r"^<reasoning>(?:(?!</reasoning>).)*</reasoning>\n<answer>(?:(?!</answer>).)*</answer>$"
    matches = [bool(re.match(pattern, g_a)) for g_a in batch_responses]
    format_scores = [1.0 if match else 0.0 for match in matches]
    return format_scores

def compute_reward(batch_answers, answers):
    """Reward function that checks if the answer is correct."""
    reward_scores = [2.0 if g_a == a else 0.0 for g_a, a in zip(batch_answers, answers)]
    return reward_scores
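
A quick sanity check of these two reward functions on toy responses (the strings below are made up for illustration):

good = "<reasoning>2 + 2 = 4</reasoning>\n<answer>4</answer>"
wrong = "<reasoning>2 + 2 = 5</reasoning>\n<answer>5</answer>"

batch_responses = [good, wrong]
batch_answers = ["4", "5"]   # what would be extracted from the <answer> tags
answers = ["4", "4"]         # ground-truth answers

print(compute_format_score(batch_responses))   # [1.0, 1.0]  both match the format
print(compute_reward(batch_answers, answers))  # [2.0, 0.0]  only the first answer is correct
# total rewards: 3.0 (format + answer) and 1.0 (format only)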

- Initial Training Loop

for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)
    for step, (prompts, answers) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
        prompt_enc = tokenizer(
            prompts,
            return_tensors = 'pt',
            padding = True,
            padding_side = 'left',
            truncation = True
        )
        input_ids = prompt_enc["input_ids"].to(local_rank) # (B, prompt_len) and left_pad
        attention_mask = prompt_enc["attention_mask"].to(local_rank)

The input_ids and attention_mask are used as inputs to generate responses (online exploration in RL).
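
Left padding matters because generation appends new tokens on the right of each prompt; with right padding, shorter prompts would have to continue from pad tokens. Here is a minimal, self-contained sketch of what the tokenizer call above produces (the checkpoint name is just a placeholder, not necessarily the model used in the repo):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # any causal-LM tokenizer
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

enc = tokenizer(
    ["What is 2 + 2?", "A much longer question that needs more tokens to encode?"],
    return_tensors="pt",
    padding=True,
    padding_side="left",
    truncation=True,
)
print(enc["input_ids"].shape)    # (2, prompt_len), both rows padded to the same length
print(enc["attention_mask"][0])  # shorter prompt: leading 0s (pad), trailing 1s (real tokens)

With left padding, the last position of every row is a real token, so generate can continue each prompt directly.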

- Generate K Samples Per Prompt

The group size is K, i.e. the number of sampled responses per prompt.

policy_model.eval()
with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        explore_generations = policy_model.module.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                num_return_sequences=K,
                top_p=0.9,
                temperature=1.0,
                eos_token_id=tokenizer.eos_token_id
            ) # (batch_size * K, prompt_len + max_new_tokens)
policy_model.train()
prompt_len = input_ids.shape[1]
batch_size = input_ids.shape[0]
batch_attention_mask = (explore_generations != tokenizer.pad_token_id).long() 

explore_generations and batch_attention_mask will be used to compute logits.
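
One detail worth spelling out: with num_return_sequences=K, Hugging Face generate repeats each prompt K times consecutively, so the K samples of prompt i occupy rows i*K through i*K + K - 1 of explore_generations. That contiguous layout is exactly what makes the later .view(batch_size, K, -1) reshapes valid. A toy illustration of the indexing:

import torch

# stand-in for explore_generations: batch_size*K rows, grouped per prompt
batch_size, K, seq_len = 2, 3, 5
flat = torch.arange(batch_size * K * seq_len).view(batch_size * K, seq_len)

grouped = flat.view(batch_size, K, seq_len)  # same reshape pattern used later in the post
for i in range(batch_size):
    for k in range(K):
        assert torch.equal(grouped[i, k], flat[i * K + k])  # row i*K + k is sample k of prompt i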

- Compute logprobs

policy_model.eval()
with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out_old = policy_model(explore_generations,  batch_attention_mask, use_cache=False)
    logits_old = out_old.logits # [batch_size * K, seq_len, vocab_size]
policy_model.train()
# shift logits and labels so position t predicts token t+1
# (labels is built earlier in the loop from explore_generations, with ignored positions set to -100)
logits_old = logits_old[:, :-1, :].contiguous() # [batch_size * K, seq_len - 1, vocab_size]
labels_old = labels[:, 1:].contiguous() # [batch_size * K, seq_len - 1]
logprobs_old = -F.cross_entropy(
    logits_old.view(-1, logits_old.shape[-1]), labels_old.view(-1), reduction = 'none', ignore_index=-100
).view(logits_old.shape[0], -1) # [batch_size * K, seq_len - 1]
assert batch_action_mask.shape[-1] == logprobs_old.shape[-1] + 1 and batch_action_mask.shape == batch_attention_mask.shape
logprobs_old = logprobs_old.view(batch_size, K, -1) # [batch_size, K, seq_len - 1]

logprobs_old denotes elementwise \(\log \pi_{\theta_{old}}(o_t\mid o_{<t}, q)\). Similarly, logprobs_ref computes \(\log \pi_{\theta_{ref}}(o_t\mid o_{<t}, q)\), and logprobs_new denotes \(\log \pi_{\theta}(o_t\mid o_{<t}, q)\).
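
The same shift-and-cross-entropy pattern is reused for logprobs_ref (frozen reference model, no gradients) and logprobs_new (current policy, gradients enabled). Below is a hedged sketch of that pattern as a standalone helper; the function name and signature are mine, not necessarily what the repo uses:

import torch
import torch.nn.functional as F

def per_token_logprobs(model, input_ids, attention_mask, labels):
    """log pi(o_t | q, o_<t) per position; positions labelled -100 (prompt/padding) contribute 0."""
    logits = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
    logits = logits[:, :-1, :].contiguous()          # prediction for token t comes from prefix < t
    shifted_labels = labels[:, 1:].contiguous()
    return -F.cross_entropy(
        logits.view(-1, logits.shape[-1]),
        shifted_labels.view(-1),
        reduction="none",
        ignore_index=-100,
    ).view(logits.shape[0], -1)                      # [batch_size * K, seq_len - 1]

# logprobs_old / logprobs_ref: call inside torch.no_grad();
# logprobs_new: call with gradients enabled so the GRPO loss can backpropagate.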

- Compute advantages

batch_responses_ids = explore_generations[:, prompt_len:] # (batch_size*K, response_length) right pad
batch_responses = tokenizer.batch_decode(batch_responses_ids, skip_special_tokens = True) # list of batch_size*K decoded response strings
batch_answers = [extract_xml_answer(batch_responses[i]) for i in range(len(batch_responses))] # list of batch_size*K extracted answer strings
answers_K = [a for a in answers for _ in range(K)] # repeat each ground-truth answer K times to align with the K samples
assert len(batch_answers) == len(answers_K)
batch_format_scores = compute_format_score(batch_responses) # list of batch_size*K format scores
batch_reward_scores = compute_reward(batch_answers, answers_K) # list of batch_size*K answer scores
batch_rewards = torch.tensor([bfs + brs for bfs, brs in zip(batch_format_scores, batch_reward_scores)], dtype=torch.float16)
batch_rewards = batch_rewards.view(batch_size, K) # (batch_size, K)
batch_advantages = (batch_rewards - batch_rewards.mean(dim = -1, keepdim = True)) / batch_rewards.std(dim = -1, keepdim = True).clamp_min(1e-6)
batch_advantages = batch_advantages.to(local_rank) # (batch_size, K)
assert batch_advantages.shape == (batch_size, K)
batch_advantages = batch_advantages.unsqueeze(2).expand_as(logprobs_ref) # [batch_size, K, seq_len - 1]
assert batch_advantages.shape == logprobs_ref.shape

The resulting batch_advantages has shape [batch_size, K, seq_len - 1]. The last dimension is simply a replication, since \(\hat{A}_{i, t}=\widetilde{r}_i=\frac{r_i-\operatorname{mean}(\mathbf{r})}{\operatorname{std}(\mathbf{r})}\) does not depend on the token index t.

- Valid mask

valid_mask = batch_action_mask[:, :-1].contiguous().float().view(batch_size, K, -1) # [batch_size, K, seq_len - 1] 

This is critical for the per-response average \(\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\), since the responses have varying lengths. Calling .mean() directly would also count the padded (invalid) tokens, which is not what we want.
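
A tiny numeric example of the difference, where the second response has two padded positions:

import torch

# per-token quantities for two responses of true lengths 4 and 2 (padded to 4)
per_token = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                          [2.0, 2.0, 0.0, 0.0]])
valid_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                           [1.0, 1.0, 0.0, 0.0]])

print(per_token.mean(dim=-1))                                 # tensor([1., 1.])  pads counted: (2+2+0+0)/4
print((per_token * valid_mask).sum(-1) / valid_mask.sum(-1))  # tensor([1., 2.])  correct: (2+2)/2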

- Compute probability ratios

ratio = torch.exp(logprobs_new - logprobs_old) # [batch_size, K, seq_len - 1]
ratio_clipped = torch.clamp(ratio, 1.0 - ppo_clip_range, 1.0 + ppo_clip_range) # [batch_size, K, seq_len - 1]
individual_ppo_reward = torch.min(ratio * batch_advantages, ratio_clipped * batch_advantages) # [batch_size, K, seq_len - 1]

ratio computes pointwise \(\Gamma_{i,t}:=\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(o_{i, t} \mid q, o_{i,<t}\right)}\), ratio_clipped computes pointwise \(\mathrm{clip}\!\big(\Gamma_{i,t},\,1-\varepsilon,\,1+\varepsilon\big)\), and individual_ppo_reward denotes \(\min[ \Gamma_{i,t}\,\hat{A}_{i,t},\; \mathrm{clip}\!\big(\Gamma_{i,t},\,1-\varepsilon,\,1+\varepsilon\big)\,\hat{A}_{i,t}]\).

- Compute KL penalty

ratio_ref_log = logprobs_ref - logprobs_new # [batch_size, K, seq_len - 1]
ratio_ref = torch.exp(ratio_ref_log) # [batch_size, K, seq_len - 1]
individual_kl_penality = ratio_ref - ratio_ref_log - 1 # [batch_size, K, seq_len - 1]

individual_kl_penality denotes \(\frac{\pi_{\mathrm{ref}}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-\log \frac{\pi_{\mathrm{ref}}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-1\).
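
This is the so-called k3 estimator from Schulman's note on approximating KL divergence: it is always non-negative, and its expectation under \(\pi_\theta\) equals the true KL divergence. A quick numerical check on two arbitrary categorical distributions:

import torch

p_theta = torch.tensor([0.40, 0.30, 0.20, 0.10])  # current policy pi_theta over 4 tokens
p_ref   = torch.tensor([0.25, 0.25, 0.25, 0.25])  # reference policy pi_ref

# analytic KL(pi_theta || pi_ref)
kl_exact = (p_theta * (p_theta / p_ref).log()).sum()

# k3 estimator averaged exactly over pi_theta: E[ r - log r - 1 ] with r = pi_ref / pi_theta
ratio = p_ref / p_theta
k3_mean = (p_theta * (ratio - ratio.log() - 1)).sum()

print(kl_exact.item(), k3_mean.item())  # both ~0.106; the estimator matches the true KL in expectation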

- Compute the overall GRPO loss

sum_loss_ave_response = ((individual_ppo_reward - kl_coef * individual_kl_penality) * valid_mask).sum(dim = -1) # mask out padded positions, then sum over tokens -> [batch_size, K]
count_ave_response = valid_mask.sum(dim = -1) # [batch_size, K]
reward_ave_response = sum_loss_ave_response / count_ave_response # [batch_size, K]
grpo_loss = -reward_ave_response.mean()

sum_loss_ave_response performs the inner summation \(\sum_{t=1}^{|o_i|}\) over the valid response tokens, and reward_ave_response performs the division by \(|o_i|\), i.e. the per-response length normalization.
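
For completeness, here is a hedged sketch of the optimization step that would follow; the optimizer choice, learning rate, and gradient-clipping norm are assumptions for illustration, not necessarily what the repo uses:

# assume: optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-6)
optimizer.zero_grad(set_to_none=True)
grpo_loss.backward()
torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=1.0)  # keep updates stable
optimizer.step()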

This concludes the key steps of GRPO.

Miscellaneous

Also see this on Twitter: