This post explains my PyTorch implementation of the Group Relative Policy Optimization (GRPO) algorithm.
Group Relative Policy Optimization
To expose the implementation details in a minimal way, I implemented GRPO from scratch in PyTorch in 👉 RLFromScratch. Let’s now walk through it step by step.
The objective function of GRPO is the following:
\[\mathcal{J}_{\mathrm{GRPO}}(\theta) = \mathbb{E}_{q,\,\{o_i\}_{i=1}^G \sim \pi_{\theta_{\mathrm{old}}}(\cdot \mid q)} \Bigg[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \Big\{ \min \Big[ \Gamma_{i,t}\,\hat{A}_{i,t},\; \mathrm{clip}\!\big(\Gamma_{i,t},\,1-\varepsilon,\,1+\varepsilon\big)\,\hat{A}_{i,t} \Big] - \beta\,\mathbb{D}_{\mathrm{KL}}\!\big[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\big] \Big\} \Bigg] .\]where \(\Gamma_{i,t}:=\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(o_{i, t} \mid q, o_{i,<t}\right)}\), \(\varepsilon, \beta\) are hyper-parameters, and \(\hat{A}_{i,t}\) is the advantage, computed only from the relative rewards of the outputs within each group: \(\hat{A}_{i, t}=\widetilde{r}_i=\frac{r_i-\operatorname{mean}(\mathbf{r})}{\operatorname{std}(\mathbf{r})}\). The unbiased estimator of the KL divergence is \(\mathbb{D}_{KL}\left[\pi_\theta \| \pi_{ref}\right]=\frac{\pi_{ref}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-\log \frac{\pi_{ref}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-1\).
- Reward Design
I want the model to learn both the correct format and the correct answer. The format score is 1.0 and the answer score is 2.0: a response with both the correct format and the correct answer receives the full reward of 3.0, while a response with only the correct format receives a partial reward of 1.0.
import re

def compute_format_score(batch_responses):
    """Reward function that checks if the completion has the correct format."""
    pattern = r"^<reasoning>(?:(?!</reasoning>).)*</reasoning>\n<answer>(?:(?!</answer>).)*</answer>$"
    matches = [bool(re.match(pattern, g_a)) for g_a in batch_responses]
    format_scores = [1.0 if match else 0.0 for match in matches]
    return format_scores

def compute_reward(batch_answers, answers):
    """Reward function that checks if the answer is correct."""
    reward_scores = [2.0 if g_a == a else 0.0 for g_a, a in zip(batch_answers, answers)]
    return reward_scores
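A quick illustration of how the two scores combine (the response strings and answers below are made up, and the extracted list stands in for what the answer-extraction step later in the post would produce):

responses = [
    "<reasoning>2 + 2 = 4</reasoning>\n<answer>4</answer>",  # correct format, correct answer
    "<reasoning>2 + 2 = 5</reasoning>\n<answer>5</answer>",  # correct format, wrong answer
]
extracted = ["4", "5"]  # illustrative extracted answers for the two responses
gold = ["4", "4"]

format_scores = compute_format_score(responses)   # [1.0, 1.0]
reward_scores = compute_reward(extracted, gold)   # [2.0, 0.0]
totals = [f + r for f, r in zip(format_scores, reward_scores)]  # [3.0, 1.0]: full vs. partial reward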
- Initial Training Loop
for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)
    for step, (prompts, answers) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
        prompt_enc = tokenizer(
            prompts,
            return_tensors = 'pt',
            padding = True,
            padding_side = 'left',
            truncation = True
        )
        input_ids = prompt_enc["input_ids"].to(local_rank)  # (B, prompt_len), left-padded
        attention_mask = prompt_enc["attention_mask"].to(local_rank)
The input_ids and attention_mask are used as inputs to generate responses (online exploration in RL).
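A note on padding_side='left' above: for decoder-only models, new tokens are appended at the right end of each row, so left padding keeps every prompt flush against its continuation and generation starts from a real token. A minimal illustration (gpt2 is only an example tokenizer; passing padding_side to the call requires a reasonably recent transformers version, otherwise set tokenizer.padding_side = 'left' beforehand):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any causal-LM tokenizer; gpt2 is just an example
tok.pad_token = tok.eos_token                # gpt2 has no pad token by default

enc = tok(["short prompt", "a much longer prompt with many more tokens"],
          return_tensors="pt", padding=True, padding_side="left")
# Pad ids sit on the left, so the last column of input_ids holds real prompt tokens
# and generation continues directly from them.
print(enc["input_ids"][:, -1])      # real token ids in the last column
print(enc["attention_mask"][:, 0])  # 0 for the short prompt's leading padding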
- Generate K Samples Per Prompt
The group size is K, i.e., K responses are sampled per prompt.
policy_model.eval()
with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        explore_generations = policy_model.module.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            num_return_sequences=K,
            top_p=0.9,
            temperature=1.0,
            eos_token_id=tokenizer.eos_token_id
        )  # (batch_size * K, prompt_len + max_new_tokens)
policy_model.train()
prompt_len = input_ids.shape[1]
batch_size = input_ids.shape[0]
batch_attention_mask = (explore_generations != tokenizer.pad_token_id).long()
explore_generations and batch_attention_mask will be used to compute the logits.
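One detail the later reshapes rely on (worth double-checking against your transformers version): generate with num_return_sequences=K repeats each prompt K times in place, so the K samples for prompt j sit in rows j*K through j*K + K - 1 of explore_generations. Both answers_K below and the later .view(batch_size, K, -1) calls assume this ordering. A small sanity check, assuming K > 1:

# Group the K samples of each prompt together and check that they share the same
# (left-padded) prompt prefix, which is what the repeat-interleave ordering implies.
gen = explore_generations.view(batch_size, K, -1)  # [batch_size, K, prompt_len + max_new_tokens]
assert torch.equal(gen[:, 0, :prompt_len], gen[:, 1, :prompt_len])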
- Compute logprobs
policy_model.eval()
with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out_old = policy_model(explore_generations, batch_attention_mask, use_cache=False)
        logits_old = out_old.logits  # [batch_size * K, seq_len, vocab_size]
policy_model.train()
# shift logits and labels
# (labels and batch_action_mask come from earlier in the full training script:
#  labels holds the target token ids with ignored positions set to -100, and
#  batch_action_mask marks the valid response tokens.)
logits_old = logits_old[:, :-1, :].contiguous()  # [batch_size * K, seq_len - 1, vocab_size]
labels_old = labels[:, 1:].contiguous()  # [batch_size * K, seq_len - 1]
logprobs_old = -F.cross_entropy(
    logits_old.view(-1, logits_old.shape[-1]), labels_old.view(-1), reduction = 'none', ignore_index=-100
).view(logits_old.shape[0], -1)  # [batch_size * K, seq_len - 1]
assert batch_action_mask.shape[-1] == logprobs_old.shape[-1] + 1 and batch_action_mask.shape == batch_attention_mask.shape
logprobs_old = logprobs_old.view(batch_size, K, -1)  # [batch_size, K, seq_len - 1]
logprobs_old denotes the elementwise \(\log \pi_{\theta_{old}}(o_t\mid o_{<t}, q)\). Similarly, logprobs_ref computes \(\log \pi_{\theta_{ref}}(o_t\mid o_{<t}, q)\), and logprobs_new denotes \(\log \pi_{\theta}(o_t\mid o_{<t}, q)\).
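logprobs_ref and logprobs_new are computed with exactly the same pattern, so it is convenient to factor it into a helper. The sketch below is mine (per_token_logprobs and ref_model are illustrative names, not from the repo); the only differences between the three quantities are which model is called and whether gradients are tracked:

import torch
import torch.nn.functional as F

def per_token_logprobs(model, generations, attention_mask, labels, batch_size, K):
    """Per-token log-probabilities of `labels` under `model`, shaped [batch_size, K, seq_len - 1]."""
    out = model(generations, attention_mask, use_cache=False)
    logits = out.logits[:, :-1, :].contiguous()   # drop the last position (no label for it)
    shifted_labels = labels[:, 1:].contiguous()   # drop the first position (no prediction for it)
    logprobs = -F.cross_entropy(
        logits.view(-1, logits.shape[-1]), shifted_labels.view(-1),
        reduction='none', ignore_index=-100
    ).view(logits.shape[0], -1)
    return logprobs.view(batch_size, K, -1)

# Reference model: frozen, no gradients.
# with torch.no_grad():
#     logprobs_ref = per_token_logprobs(ref_model, explore_generations, batch_attention_mask, labels, batch_size, K)
# Current policy: gradients enabled so the GRPO loss can backpropagate through logprobs_new.
# logprobs_new = per_token_logprobs(policy_model, explore_generations, batch_attention_mask, labels, batch_size, K)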
- Compute advantages
batch_responses_ids = explore_generations[:, prompt_len:] # (batch_size*K, response_length), right-padded
batch_responses = tokenizer.batch_decode(batch_responses_ids, skip_special_tokens = True) # list of batch_size*K decoded response strings
batch_answers = [extract_xml_answer(batch_responses[i]) for i in range(len(batch_responses))] # list of batch_size*K extracted answer strings
answers_K = [a for a in answers for _ in range(K)] # repeat each ground-truth answer K times to match the K samples per prompt
assert len(batch_answers) == len(answers_K)
batch_format_scores = compute_format_score(batch_responses) # list of batch_size*K format scores
batch_reward_scores = compute_reward(batch_answers, answers_K) # list of batch_size*K answer scores
batch_rewards = torch.tensor([bfs + brs for bfs, brs in zip(batch_format_scores, batch_reward_scores)], dtype=torch.float16)
batch_rewards = batch_rewards.view(batch_size, K) # (batch_size, K)
batch_advantages = (batch_rewards - batch_rewards.mean(dim = -1, keepdim = True)) / batch_rewards.std(dim = -1, keepdim = True).clamp_min(1e-6)
batch_advantages = batch_advantages.to(local_rank) # (batch_size, K)
assert batch_advantages.shape == (batch_size, K)
batch_advantages = batch_advantages.unsqueeze(2).expand_as(logprobs_ref) # [batch_size, K, seq_len - 1]
assert batch_advantages.shape == logprobs_ref.shape
The resulting batch_advantages has shape [batch_size, K, seq_len - 1]. The last dimension is simply a replication, since \(\hat{A}_{i, t}=\widetilde{r}_i=\frac{r_i-\operatorname{mean}(\mathbf{r})}{\operatorname{std}(\mathbf{r})}\) does not depend on the token position t.
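As a concrete (made-up) example of the group-relative normalization: suppose the K = 4 responses to one prompt earn rewards [3.0, 1.0, 1.0, 0.0].

import torch

rewards = torch.tensor([[3.0, 1.0, 1.0, 0.0]])  # one prompt, K = 4 sampled responses
advantages = (rewards - rewards.mean(dim=-1, keepdim=True)) / rewards.std(dim=-1, keepdim=True).clamp_min(1e-6)
print(advantages)  # approximately [[ 1.39, -0.20, -0.20, -0.99]]

The fully correct response gets a positive advantage and the rest are pushed down, regardless of the absolute reward scale.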
- Valid mask
valid_mask = batch_action_mask[:, :-1].contiguous().float().view(batch_size, K, -1) # [batch_size, K, seq_len - 1]
This is critical for the normalization \(\frac{1}{\mid o_i\mid}\sum_{t=1}^{\mid o_i \mid}\), since responses have varying lengths. Using .mean() over the last dimension would also count the padded (invalid) token positions, which is not what we want.
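To see the difference concretely, here is a toy example (the numbers are made up); per_token_values stands in for any per-token quantity such as the clipped PPO term below:

import torch

per_token_values = torch.tensor([[2.0, 2.0, 0.0, 0.0]])  # last two positions are padding
valid_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])

masked_mean = (per_token_values * valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)  # tensor([2.0])
naive_mean = per_token_values.mean(dim=-1)  # tensor([1.0]): diluted by the padded positions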
- Compute probability ratios
ratio = torch.exp(logprobs_new - logprobs_old) # [batch_size, K, seq_len - 1]
ratio_clipped = torch.clamp(ratio, 1.0 - ppo_clip_range, 1.0 + ppo_clip_range) # [batch_size, K, seq_len - 1]
individual_ppo_reward = torch.min(ratio * batch_advantages, ratio_clipped * batch_advantages) # [batch_size, K, seq_len - 1]
ratio computes pointwise \(\Gamma_{i,t}:=\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(o_{i, t} \mid q, o_{i,<t}\right)}\), ratio_clipped computes pointwise \(\mathrm{clip}\!\big(\Gamma_{i,t},\,1-\varepsilon,\,1+\varepsilon\big)\), and individual_ppo_reward denotes \(\min\big[\Gamma_{i,t}\,\hat{A}_{i,t},\;\mathrm{clip}\!\big(\Gamma_{i,t},\,1-\varepsilon,\,1+\varepsilon\big)\,\hat{A}_{i,t}\big]\).
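Intuitively, the min with the clipped term makes the update conservative: when \(\hat{A}_{i,t}>0\) the objective gains nothing from pushing \(\Gamma_{i,t}\) above \(1+\varepsilon\), and when \(\hat{A}_{i,t}<0\) it gains nothing from pushing it below \(1-\varepsilon\), so the policy is not rewarded for straying too far from \(\pi_{\theta_{old}}\) in a single update.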
- Compute KL penalty
ratio_ref_log = logprobs_ref - logprobs_new # [batch_size, K, seq_len - 1]
ratio_ref = torch.exp(ratio_ref_log) # [batch_size, K, seq_len - 1]
individual_kl_penality = ratio_ref - ratio_ref_log - 1 # [batch_size, K, seq_len - 1]
individual_kl_penality denotes \(\frac{\pi_{ref}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-\log \frac{\pi_{ref}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-1\).
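Note that this estimator is always non-negative: with \(x=\frac{\pi_{ref}}{\pi_\theta}\), the quantity \(x-\log x-1\ge 0\) for all \(x>0\), with equality exactly when \(\pi_\theta\) and \(\pi_{ref}\) agree on that token, so the penalty only ever pulls the policy toward the reference model.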
- Compute the overall GRPO loss
sum_loss_ave_response = ((individual_ppo_reward - kl_coef * individual_kl_penality) * valid_mask).sum(dim = -1)  # [batch_size, K], padded positions zeroed out by valid_mask
count_ave_response = valid_mask.sum(dim = -1)  # [batch_size, K]
reward_ave_response = sum_loss_ave_response / count_ave_response  # [batch_size, K]
grpo_loss = -reward_ave_response.mean()
sum_loss_ave_response performs the summation \(\sum_{t=1}^{\mid o_i \mid}\) over the valid response tokens, and reward_ave_response applies the normalization \(\frac{1}{\mid o_i \mid}\).
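What remains in the loop is the standard optimization step. A minimal sketch, assuming optimizer and max_grad_norm are set up elsewhere in the full script (they are not shown in the snippets above):

optimizer.zero_grad(set_to_none=True)
grpo_loss.backward()  # gradients flow only through logprobs_new
torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_grad_norm)
optimizer.step()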
This concludes the key steps of GRPO.
Also see this on Twitter:
I implemented GRPO and DPO from scratch in vanilla Pytorch to unravel every piece of training details. Hope it could be helpful for those who care about the implementation details of the algorithms. 👉 https://t.co/1Exq7GTkLY #AI #RL #LLM
— Ming Yin (@MingYin_0312) August 12, 2025