DPO From Scratch

This post explains my PyTorch implementation of the Direct Preference Optimization (DPO) algorithm.

Direct Preference Optimization is an LLM alignment method that directly trains a language model to favor the human-preferred response over the rejected one in each comparison pair by minimizing a contrastive loss between the two. Unlike standard RLHF, it does not require fitting an explicit reward model.

To lay out the implementation details in a minimal way, I implemented DPO from scratch in PyTorch in 👉 RLFromScratch. Let’s walk through it step by step.


Quick Recap of DPO Algorithm

Given a reward function, the RL fine-tuning phase optimizes the LLM via

\[\max _{\pi_\theta} \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_\theta(y \mid x)}\left[r_\phi(x, y)\right]-\beta \mathbb{D}_{\mathrm{KL}}\left[\pi_\theta(y \mid x) \| \pi_{\mathrm{ref}}(y \mid x)\right].\]

Given the partition function $Z(x)=\sum_y \pi_{\mathrm{ref}}(y \mid x) \exp \left(\frac{1}{\beta} r(x, y)\right)$, the closed-form solution takes

\[\pi_r(y \mid x)=\frac{1}{Z(x)} \pi_{\mathrm{ref}}(y \mid x) \exp \left(\frac{1}{\beta} r(x, y)\right)\]

This further gives

\[r(x, y)=\beta \log \frac{\pi_r(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}+\beta \log Z(x) .\]

Plugging this reward into the Bradley-Terry model $p^*(y_1 \succ y_2 \mid x)=\sigma\left(r(x, y_1)-r(x, y_2)\right)$, the intractable $\log Z(x)$ terms cancel and the preference probability becomes:

\[p^*\left(y_1 \succ y_2 \mid x\right)=\frac{1}{1+\exp \left(\beta \log \frac{\pi^*\left(y_2 \mid x\right)}{\pi_{\mathrm{ref}}\left(y_2 \mid x\right)}-\beta \log \frac{\pi^*\left(y_1 \mid x\right)}{\pi_{\mathrm{ref}}\left(y_1 \mid x\right)}\right)}\]

So the negative log-likelihood of the preference data naturally gives the following DPO loss:

\[\mathcal{L}_{\mathrm{DPO}}(\pi_\theta,\pi_{\mathrm{ref}})=-\mathbb{E}_{\left(x, y_w, y_l\right) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta\left(y_w \mid x\right)}{\pi_{\mathrm{ref}}\left(y_w \mid x\right)}-\beta \log \frac{\pi_\theta\left(y_l \mid x\right)}{\pi_{\mathrm{ref}}\left(y_l \mid x\right)}\right)\right].\]

Our code implements the above loss for training.


Code Explanation

- Format Input

import torch
import torch.nn.functional as F  # used by the loss snippets below
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    input_ids, labels_list = [], []
    for prompt, chosen_resp, rejected_resp in batch:
        # Token IDs for prompt & responses
        p_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
        c_ids = tokenizer(chosen_resp, add_special_tokens=False)["input_ids"]
        r_ids = tokenizer(rejected_resp, add_special_tokens=False)["input_ids"]
        # Build input IDs
        input_ids += [
            torch.tensor(p_ids + c_ids, dtype = torch.long),
            torch.tensor(p_ids + r_ids, dtype = torch.long)
        ]
        # Build labels: mask prompt with -100, keep response tokens
        labels_list += [
            torch.tensor([-100]*len(p_ids) + c_ids, dtype = torch.long),
            torch.tensor([-100]*len(p_ids) + r_ids, dtype = torch.long)
        ]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_mask = (input_ids != tokenizer.pad_token_id)
    labels_tensor = pad_sequence(labels_list, batch_first=True, padding_value=-100)
    assert input_ids.shape == attention_mask.shape and attention_mask.shape == labels_tensor.shape

    return input_ids.to(local_rank), attention_mask.to(local_rank), labels_tensor.to(local_rank)

This function interleaves the chosen and rejected sequences into one batch of 2B rows (even rows hold prompt + chosen, odd rows hold prompt + rejected), and pad_sequence is used to align their lengths. attention_mask marks the non-padding positions for the attention computation, and labels masks the prompt tokens with -100 so that only response tokens enter the loss.
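
For concreteness, here is a minimal usage sketch; the tokenizer, pad-token handling, device choice, and toy preference triples below are illustrative assumptions rather than code from the repo.

# Illustrative usage sketch (assumptions: GPT-2 tokenizer, single GPU as device 0, toy data).
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token          # GPT-2 has no pad token by default
local_rank = 0                                     # single-GPU assumption: tensors go to cuda:0

pairs = [
    ("What is 2+2? ", "4", "5"),                   # (prompt, chosen, rejected)
    ("Capital of France? ", "Paris", "Lyon"),
]
loader = DataLoader(pairs, batch_size=2, collate_fn=collate_fn)
input_ids, attention_mask, labels = next(iter(loader))
print(input_ids.shape)                             # [2B, T] = [4, T]; rows alternate chosen/rejected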

- Forward current model

with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    outputs = model(input_ids, attention_mask=attention_mask)
logits  = outputs.logits  # [2B, T, V]; optionally .float() to upcast from bf16

Here logits denotes the output logits of the current policy model, of shape [2B, T, V].

- Forward ref model (no grad)

with torch.no_grad():
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        ref_outputs = ref_model(input_ids, attention_mask=attention_mask)
    ref_logits  = ref_outputs.logits

Here ref_logits denotes the output logits of the frozen reference model.
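
Neither model nor ref_model is defined in these snippets; a minimal setup sketch (the model name is an illustrative assumption) would be a trainable policy plus a frozen copy of the same initial checkpoint:

# Illustrative setup sketch (the "gpt2" checkpoint is an assumption, not the repo's choice).
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2").to(local_rank)       # policy being trained
ref_model = AutoModelForCausalLM.from_pretrained("gpt2").to(local_rank)   # frozen reference copy
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad_(False)  # the reference model is never updated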

- Shift logits and labels

logits = logits[...,:-1,:].contiguous() # [2B, T-1, V]
ref_logits = ref_logits[...,:-1,:].contiguous() # [2B, T-1, V]
labels = labels[...,1:].contiguous() # [2B, T-1]

This shift is required for next-token prediction: the logits at position t score token t+1, so the last logit and the first label are dropped to align predictions with targets before the cross-entropy computation.

- Compute per-token NLLs

V = logits.size(-1)
loss_t  = F.cross_entropy(
    logits.view(-1, V), labels.view(-1),
    ignore_index=-100, reduction="none"
).view(logits.size(0), -1)
ref_loss = F.cross_entropy(
    ref_logits.view(-1, V), labels.view(-1),
    ignore_index=-100, reduction="none"
).view(ref_logits.size(0), -1)

Here loss_t holds the per-token $-\log \pi_\theta(o_t \mid o_{<t},x)$ and ref_loss the per-token $-\log \pi_{\mathrm{ref}}(o_t \mid o_{<t},x)$. With reduction="none" and ignore_index=-100, positions whose label is -100 (prompt and padding) receive a loss of exactly 0.
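
A tiny sanity check (illustrative, not from the repo) confirming that masked positions contribute zero:

# Masked (-100) positions yield exactly zero loss under reduction="none".
import torch
import torch.nn.functional as F

toy_logits = torch.randn(1, 3, 5)            # [B, T, V]
toy_labels = torch.tensor([[-100, 2, 4]])    # first position masked, like a prompt token
per_tok = F.cross_entropy(
    toy_logits.view(-1, 5), toy_labels.view(-1),
    ignore_index=-100, reduction="none"
).view(1, -1)
print(per_tok)                               # first entry is 0.0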

- Sum to get sequence NLL

nll_seq     = loss_t.sum(dim=1)
ref_nll_seq = ref_loss.sum(dim=1)

Note that by the chain rule $\log\pi(y \mid x) = \log \prod_t \pi(o_t \mid o_{<t},x)=\sum_t\log \pi(o_t \mid o_{<t}, x)$, so summing over the token dimension converts the per-token loss_t and ref_loss into the sequence-level $-\log \pi_\theta(y \mid x)$ and $-\log \pi_{\mathrm{ref}}(y \mid x)$ for every row of the batch, chosen and rejected alike; masked prompt and padding positions contribute zero to these sums.
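
The next snippet uses nll_c, nll_r, ref_c, and ref_r, i.e., the chosen and rejected halves of these per-sequence sums. Since collate_fn interleaves chosen and rejected rows, one way to obtain them is the slicing below (a sketch under that layout assumption; the repo may organize these variables differently).

# Sketch, assuming the interleaved layout from collate_fn (even rows = chosen, odd rows = rejected).
nll_c, nll_r = nll_seq[0::2], nll_seq[1::2]            # -log pi_theta(y_w|x), -log pi_theta(y_l|x)
ref_c, ref_r = ref_nll_seq[0::2], ref_nll_seq[1::2]    # same quantities under the reference model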

- Inner

diff_theta = nll_r - nll_c
diff_ref   = ref_r - ref_c
inner      = beta * (diff_theta - diff_ref)

This computes $\beta \log \frac{\pi_\theta\left(y_w \mid x\right)}{\pi_{\mathrm{ref}}\left(y_w \mid x\right)}-\beta \log \frac{\pi_\theta\left(y_l \mid x\right)}{\pi_{\mathrm{ref}}\left(y_l \mid x\right)}$, i.e., the argument of the sigmoid in the DPO loss.
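
Expanding the differences makes the sign bookkeeping explicit: since nll_c $=-\log \pi_\theta(y_w \mid x)$ and nll_r $=-\log \pi_\theta(y_l \mid x)$ (and likewise ref_c, ref_r under the reference model),

\[\beta\big[(\texttt{nll\_r}-\texttt{nll\_c})-(\texttt{ref\_r}-\texttt{ref\_c})\big]=\beta\big[\big(\log \pi_\theta(y_w \mid x)-\log \pi_\theta(y_l \mid x)\big)-\big(\log \pi_{\mathrm{ref}}(y_w \mid x)-\log \pi_{\mathrm{ref}}(y_l \mid x)\big)\big]=\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}-\beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}.\]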

- DPO loss

dpo_loss   = -F.logsigmoid(inner).mean() 

This computes \(\mathcal{L}_{\mathrm{DPO}}(\pi_\theta,\pi_{\mathrm{ref}})\). Note that F.logsigmoid is crucial for numerical stability: when inner is a large negative number, torch.sigmoid underflows to 0 and its log becomes -inf, whereas F.logsigmoid evaluates the same quantity in a numerically stable way.
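
A quick illustrative comparison (not from the repo) of the naive formulation versus F.logsigmoid:

# For large negative inputs, sigmoid underflows to 0 and its log blows up; logsigmoid stays finite.
import torch
import torch.nn.functional as F

x = torch.tensor([-200.0])
print(-torch.log(torch.sigmoid(x)))   # tensor([inf])
print(-F.logsigmoid(x))               # tensor([200.])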

Miscellaneous

Also see this on Twitter: