This post walks through my PyTorch implementation of the Direct Preference Optimization (DPO) algorithm.
Direct Preference Optimization
To expose the implementation details in a minimal way, I implemented DPO from scratch in PyTorch in 👉 RLFromScratch. Let’s now go through it step by step.
Given a reward function $r_\phi$, the RL fine-tuning phase optimizes the LLM $\pi_\theta$ via
\[\max _{\pi_\theta} \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_\theta(y \mid x)}\left[r_\phi(x, y)\right]-\beta \mathbb{D}_{\mathrm{KL}}\left[\pi_\theta(y \mid x) \| \pi_{\mathrm{ref}}(y \mid x)\right].\]With the partition function $Z(x)=\sum_y \pi_{\mathrm{ref}}(y \mid x) \exp \left(\frac{1}{\beta} r(x, y)\right)$, the optimal policy admits the closed-form solution
\[\pi_r(y \mid x)=\frac{1}{Z(x)} \pi_{\mathrm{ref}}(y \mid x) \exp \left(\frac{1}{\beta} r(x, y)\right)\]This further gives
\[r(x, y)=\beta \log \frac{\pi_r(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}+\beta \log Z(x) .\]Under the Bradley-Terry model, $p^*\left(y_1 \succ y_2 \mid x\right)=\sigma\left(r\left(x, y_1\right)-r\left(x, y_2\right)\right)$; substituting the reward expression above, the $\beta \log Z(x)$ terms cancel and the preference model becomes:
\[p^*\left(y_1 \succ y_2 \mid x\right)=\frac{1}{1+\exp \left(\beta \log \frac{\pi^*\left(y_2 \mid x\right)}{\pi_{\mathrm{ref}}\left(y_2 \mid x\right)}-\beta \log \frac{\pi^*\left(y_1 \mid x\right)}{\pi_{\mathrm{ref}}\left(y_1 \mid x\right)}\right)}\]So the negative log-likelihood of the preference data naturally gives the following DPO loss:
\[\mathcal{L}_{\mathrm{DPO}}(\pi_\theta,\pi_{\mathrm{ref}})=-\mathbb{E}_{\left(x, y_w, y_l\right) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta\left(y_w \mid x\right)}{\pi_{\mathrm{ref}}\left(y_w \mid x\right)}-\beta \log \frac{\pi_\theta\left(y_l \mid x\right)}{\pi_{\mathrm{ref}}\left(y_l \mid x\right)}\right)\right].\]Our code implements the above loss for training.
- Format Input
def collate_fn(batch):
    input_ids, labels_list = [], []
    for prompt, chosen_resp, rejected_resp in batch:
        # Token IDs for prompt & responses
        p_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
        c_ids = tokenizer(chosen_resp, add_special_tokens=False)["input_ids"]
        r_ids = tokenizer(rejected_resp, add_special_tokens=False)["input_ids"]
        # Build input IDs
        input_ids += [
            torch.tensor(p_ids + c_ids, dtype=torch.long),
            torch.tensor(p_ids + r_ids, dtype=torch.long)
        ]
        # Build labels: mask prompt with -100, keep response tokens
        labels_list += [
            torch.tensor([-100]*len(p_ids) + c_ids, dtype=torch.long),
            torch.tensor([-100]*len(p_ids) + r_ids, dtype=torch.long)
        ]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_mask = (input_ids != tokenizer.pad_token_id)
    labels_tensor = pad_sequence(labels_list, batch_first=True, padding_value=-100)
    assert input_ids.shape == attention_mask.shape and attention_mask.shape == labels_tensor.shape
    return input_ids.to(local_rank), attention_mask.to(local_rank), labels_tensor.to(local_rank)
This function interleaves chosen and rejected responses into one batch (row $2i$ holds the chosen and row $2i{+}1$ the rejected completion of example $i$), and pad_sequence is used to align the sequence lengths. attention_mask marks which positions are real tokens rather than padding for the attention computation, and labels mask the prompt with -100 so that only response tokens enter the loss.
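A quick usage sketch follows; preference_dataset is a hypothetical placeholder for any dataset that yields (prompt, chosen, rejected) string triples.

```python
from torch.utils.data import DataLoader

# `preference_dataset` is hypothetical: any dataset yielding (prompt, chosen, rejected) strings
loader = DataLoader(preference_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
for input_ids, attention_mask, labels in loader:
    # each batch holds 2*batch_size rows: [chosen_1, rejected_1, chosen_2, rejected_2, ...]
    ...
```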
- Forward current model
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits  # [2B, T, V]
Here logits are the output logits of the current policy model, with shape [2B, T, V] since each example contributes one chosen and one rejected row.
- Forward ref model (no grad)
with torch.no_grad():
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        ref_outputs = ref_model(input_ids, attention_mask=attention_mask)
        ref_logits = ref_outputs.logits
Here ref_logits are the output logits of the frozen reference model; torch.no_grad() ensures no gradient is tracked for this forward pass.
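For completeness, one common way to obtain ref_model is to snapshot and freeze the initial policy; the sketch below assumes that setup (RLFromScratch may instead load the same checkpoint twice).

```python
import copy

# Reference model as a frozen copy of the starting policy (an assumption, see note above)
ref_model = copy.deepcopy(model)
ref_model.eval()                       # disable dropout etc.
for p in ref_model.parameters():
    p.requires_grad_(False)            # the reference is never updated
```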
- Shift logits and labels
logits = logits[...,:-1,:].contiguous() # [2B, T-1, V]
ref_logits = ref_logits[...,:-1,:].contiguous() # [2B, T-1, V]
labels = labels[...,1:].contiguous() # [2B, T-1]
This shift aligns predictions with targets: the logits at position $t$ score token $t{+}1$, so we drop the last logit and the first label before computing the cross-entropy.
- Compute per-token NLLs
V = logits.size(-1)
loss_t = F.cross_entropy(
    logits.view(-1, V), labels.view(-1),
    ignore_index=-100, reduction="none"
).view(logits.size(0), -1)
ref_loss = F.cross_entropy(
    ref_logits.view(-1, V), labels.view(-1),
    ignore_index=-100, reduction="none"
).view(ref_logits.size(0), -1)
Here loss_t holds the per-token negative log-probabilities $-\log \pi_\theta(o_t \mid o_{<t},x)$ and ref_loss the corresponding $-\log \pi_{\mathrm{ref}}(o_t \mid o_{<t},x)$; positions labeled -100 (prompt and padding) contribute zero thanks to ignore_index.
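As an illustrative cross-check (not part of the repo code), the same per-token quantities can be recovered with log_softmax and gather:

```python
# Per-token log-probs via log_softmax + gather; tok_logp should equal -loss_t
logp = F.log_softmax(logits, dim=-1)                                   # [2B, T-1, V]
tok_logp = logp.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
tok_logp = tok_logp * (labels != -100)                                 # zero out masked positions
```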
- Sum to get sequence NLL
nll_seq = loss_t.sum(dim=1)
ref_nll_seq = ref_loss.sum(dim=1)
By the chain rule, $\log\pi(y \mid x) = \log \prod_t \pi(o_t \mid o_{<t},x)=\sum_t\log \pi(o_t \mid o_{<t}, x)$, so summing loss_t and ref_loss over the time dimension converts the per-token losses into the sequence-level $-\log \pi_\theta(y \mid x)$ and $-\log \pi_{\mathrm{ref}}(y \mid x)$ for each row's response (chosen or rejected); the masked prompt tokens contribute nothing to the sum. The interleaved rows are then split back into chosen and rejected parts, as sketched below.
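Because collate_fn interleaves the rows as [chosen, rejected, chosen, rejected, ...], the sequence-level NLLs used in the next step can be separated by strided slicing. This is a minimal sketch whose variable names are chosen to match the next snippet; the exact code in RLFromScratch may differ.

```python
# Split interleaved rows: even indices are chosen, odd indices are rejected
nll_c, nll_r = nll_seq[0::2], nll_seq[1::2]          # -log pi_theta(y_w|x), -log pi_theta(y_l|x)
ref_c, ref_r = ref_nll_seq[0::2], ref_nll_seq[1::2]  # -log pi_ref(y_w|x),   -log pi_ref(y_l|x)
```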
- Inner
diff_theta = nll_r - nll_c
diff_ref = ref_r - ref_c
inner = beta * (diff_theta - diff_ref)
Since nll_c and ref_c are negative log-probabilities of the chosen response and nll_r, ref_r of the rejected one, diff_theta = nll_r - nll_c equals $\log \pi_\theta\left(y_w \mid x\right)-\log \pi_\theta\left(y_l \mid x\right)$ (and analogously for diff_ref), so inner computes $\beta \log \frac{\pi_\theta\left(y_w \mid x\right)}{\pi_{\mathrm{ref}}\left(y_w \mid x\right)}-\beta \log \frac{\pi_\theta\left(y_l \mid x\right)}{\pi_{\mathrm{ref}}\left(y_l \mid x\right)}$, the argument of the sigmoid in the DPO loss.
- DPO loss
dpo_loss = -F.logsigmoid(inner).mean()
This computes \(\mathcal{L}_{\mathrm{DPO}}(\pi_\theta,\pi_{\mathrm{ref}})\). Note that F.logsigmoid is crucial for numerical stability: composing torch.log with torch.sigmoid can underflow for large negative inputs, whereas F.logsigmoid evaluates $\log\sigma(\cdot)$ stably.
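For context, a bare-bones optimization step around this loss could look like the sketch below; the optimizer and the gradient-clipping value are illustrative assumptions, and the actual training loop in RLFromScratch may differ.

```python
# One illustrative training step (assumes `optimizer` is e.g. torch.optim.AdamW over model.parameters())
optimizer.zero_grad(set_to_none=True)
dpo_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # optional gradient clipping
optimizer.step()
```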
Also see this on Twitter:
I implemented GRPO and DPO from scratch in vanilla Pytorch to unravel every piece of training details. Hope it could be helpful for those who care about the implementation details of the algorithms. 👉 https://t.co/1Exq7GTkLY #AI #RL #LLM
— Ming Yin (@MingYin_0312) August 12, 2025