DDP from Scratch: a learner-friendly guide

Python basics
Today I Learned
Scratch to Scale
PyTorch
Distributed Training
Learn why dictionary comprehensions in Python elegantly transform HuggingFace data for models, how kwargs unpacking makes model(**batch) ‘just work’, and why gradient averaging and learning-rate (LR) scaling are equivalent in distributed training. Plus: build a mini DDP from scratch to see it all in action.
Author

The Fire Hacker

Published

October 23, 2025

This note is part of the Scratch → Scale series by Zachary Mueller (course link). We’ll implement a toy DDP wrapper, explain why it works, and demystify two Python idioms you’ll see everywhere: dictionary comprehensions and kwargs (argument unpacking).

TL;DR

🔑 Core Python patterns explained:

  • Dictionary comprehensions: Transform raw data (lists, ints) into model-ready tensors in one elegant line — {k: torch.tensor(v).to(device) for k, v in item.items()} converts HuggingFace dataset samples to GPU tensors with proper shapes.
  • Kwargs unpacking (**): Unpack dictionaries into named function arguments — model(**batch) automatically maps dict keys to HuggingFace model’s forward() parameters like input_ids, attention_mask, labels.
  • Gradient averaging ⚖️ learning rate scaling: Dividing gradients by world_size and scaling the LR by 1/world_size are mathematically equivalent — the only choice is where the division happens: before the optimizer step (average the gradients) or inside it (scale the learning rate).

📋 DDP essentials:

  • Seed every process the same way before you create the model.
  • Average grads with dist.all_reduce(param.grad, op=SUM) then divide by world size.
  • Use **kwargs to pass batches to models: model(**batch) works seamlessly with HuggingFace transformers.

0) Visual mental model of distributed training

Rank 0 (GPU0)      Rank 1 (GPU1)      ...
┌──────────────┐   ┌──────────────┐
│ forward      │   │ forward      │  (same model weights)
│ loss.backward│   │ loss.backward│
└──────┬───────┘   └──────┬───────┘
       │   grads            │   grads
       └─────── all_reduce (SUM) ───────▶ (every rank gets sum of all grads)
                    │
              divide by world_size
                    │
                optimizer.step()
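
In code, one step per rank looks roughly like the sketch below (it assumes a process group, model, optimizer, and per-rank batch already exist; the full version is in §4):

# One step per rank, matching the picture above (sketch, not a full script)
out = model(**batch)                      # forward on this rank's data
out.loss.backward()                       # local gradients
for p in model.parameters():
    if p.grad is None:
        continue
    dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)   # sum grads across ranks
    p.grad /= dist.get_world_size()                 # SUM then divide = average
optimizer.step()                          # every rank applies the identical update
optimizer.zero_grad(set_to_none=True)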

1) Seeding: making model replicas identical

Identical initialization across ranks is not optional. If rank 0 samples weights {W} and rank 1 samples different weights {W’}, averaging grads is meaningless. We seed each RNG per process, then construct the model.

def set_seed(seed: int = 43):
    import random, numpy as np, torch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# In your entry point (each process runs this):
set_seed(43)            # must happen BEFORE model creation
model = build_model()   # identical on all ranks

Why no communication?

Each process runs the exact same Python code with the same seeds → same random draws → identical parameters. No dist.broadcast is required to make them equal, though you can use broadcast to enforce equality (see §3).

Pitfall: Seeding after constructing the model doesn’t retroactively change weights.
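
A quick way to convince yourself the replicas match (a sketch, assuming dist is initialized and model and device are defined): compare a cheap parameter checksum across ranks.

# Sanity check: gather a parameter checksum from every rank and compare.
def param_checksum(model):
    # Sum of all parameters: a cheap fingerprint of the weights.
    return sum(p.detach().float().sum().item() for p in model.parameters())

local = torch.tensor([param_checksum(model)], device=device)
gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, local)
if dist.get_rank() == 0:
    print([round(t.item(), 6) for t in gathered])  # all values should be identical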


2) Two Python idioms you’ll see everywhere

2.1 Dictionary comprehension — Why we need this pattern

This line converts a HuggingFace dataset sample (lists/ints) into a batch dictionary of GPU tensors with an added batch dimension:

item = {k: torch.tensor(v).unsqueeze(0).to(device) for k, v in item.items()}

Why this transformation is essential:

HuggingFace datasets return items as Python dicts with lists and ints:

example = {"input_ids": [101, 2023, ...], "attention_mask": [1, 1, ...], "labels": 0}

But PyTorch models expect GPU tensors with batch dimensions:

batch = {"input_ids": tensor([[101, 2023, ...]], device='cuda:0'), ...}

Why tensors are required:

PyTorch models perform tensor operations (matrix multiplications, slicing, etc.) that require PyTorch tensor objects, not Python lists or integers. If you pass Python lists/ints directly, you’ll get errors like:

  • TypeError: expected Tensor as element 0 in argument 0, but got list
  • RuntimeError: Expected all tensors to be on the same device

The dictionary comprehension converts your data to the correct tensor format before passing it to the model. (See §2.2 for how these tensors flow through the model’s forward() method.)

The dictionary comprehension does three transformations in one line:

  1. Preserve structure: Keep the same dict keys (input_ids, attention_mask, etc.)
  2. Convert types: List/int → PyTorch tensor → GPU tensor
  3. Add batch dimension: Shape (seq_len,) → (1, seq_len) for batching

Breakdown:

  • for k, v in item.items() → iterates over each key-value pair
  • torch.tensor(v) → converts list/int to tensor
  • .unsqueeze(0) → adds batch dimension: [a, b, c] → [[a, b, c]]
  • .to(device) → moves to GPU

Without this transformation, you’d pass Python lists/CPU arrays to the model, which would either error or require slow implicit conversion on each forward pass.
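
Here is the pattern end to end on a toy sample (a standalone sketch; the token ids are made up):

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
example = {"input_ids": [101, 2023, 102], "attention_mask": [1, 1, 1], "labels": 0}

batch = {k: torch.tensor(v).unsqueeze(0).to(device) for k, v in example.items()}

print(batch["input_ids"].shape)   # torch.Size([1, 3])  <- batch dimension added
print(batch["labels"].shape)      # torch.Size([1])
print(batch["input_ids"].device)  # cuda:0 (or cpu)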

Alternative: Generator-based streaming with yield

For large datasets or memory-constrained scenarios, dictionary comprehensions can be memory-intensive (they build the entire dict in memory). A better approach uses generators with yield for lazy evaluation:

def stream_to_device(item, device):
    """Generator that yields tensors one at a time - memory efficient"""
    for k, v in item.items():
        yield k, torch.tensor(v).unsqueeze(0).to(device)

# Usage: build dict lazily
batch = dict(stream_to_device(example, device))

Why generators help with large data:

  • Lazy evaluation: tensors are created and moved to GPU one at a time, as the generator is consumed.
  • Lower peak memory: only one tensor is being built at any moment. (Note that wrapping the generator in dict() still materializes the whole batch; the real saving comes when you consume items one at a time, as sketched below.)
  • Scalable: works with streaming datasets that don’t fit in RAM.

When to use each:

  • Dict comprehension: small to medium batches, simple one-liners, readable code.
  • Generator with yield: large datasets, streaming data, memory-constrained environments, production pipelines.
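
To actually get the memory benefit, consume the generator lazily instead of building a dict. A sketch under that assumption (it assumes ds, model, and device exist, and that the dataset columns are all tensor-convertible, i.e. raw text columns have been dropped):

def batches(dataset, device):
    """Yield one model-ready batch at a time; nothing is kept once it is consumed."""
    for example in dataset:
        yield {k: torch.tensor(v).unsqueeze(0).to(device) for k, v in example.items()}

# Usage: only the current batch lives in (GPU) memory at any point.
for batch in batches(ds["train"], device):
    out = model(**batch)
    ...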

2.2 Kwargs unpacking with ** — The HuggingFace connection

Given item = {"input_ids": X, "attention_mask": Y, "labels": Z}:

out = model(**item)
# exactly the same as:
out = model(input_ids=X, attention_mask=Y, labels=Z)

Why this matters for HuggingFace models:

The ** operator unpacks a dict into named arguments that match your model’s forward() signature. This is why HuggingFace workflows are so elegant:

  1. Dataset has standard keys: HuggingFace datasets/tokenizers output dicts with keys like "input_ids", "attention_mask", "labels".
  2. Model expects those keys: All HuggingFace models have a forward() method that accepts these exact parameter names.
  3. **kwargs bridges them: Instead of manually extracting each key, model(**batch) automatically maps dict keys to function parameters.

Without **kwargs (manual, verbose):

out = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["labels"]
)

With **kwargs (clean, scalable):

out = model(**batch)  # Automatically maps all keys!

This works because HuggingFace models define their forward() signature to match the standard dataset keys. It’s a deliberate design pattern that makes training code incredibly clean.
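
The mechanism itself is plain Python, nothing HuggingFace-specific. A tiny illustration with a made-up function (toy names, not a real model API):

def forward(input_ids, attention_mask=None, labels=None):
    # Receives whatever keys the dict provided, matched by parameter name.
    return {"has_labels": labels is not None, "seq_len": len(input_ids)}

batch = {"input_ids": [101, 2023, 102], "attention_mask": [1, 1, 1], "labels": 0}
print(forward(**batch))   # {'has_labels': True, 'seq_len': 3}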

Tracing the forward() call chain:

When you call model(**batch), the unpacked tensors flow through the model’s forward pass. Here’s the call chain for AutoModelForSequenceClassification:

model(**batch)  # batch contains tensors: {"input_ids": tensor(...), ...}
    ↓
AutoModelForSequenceClassification.from_pretrained(...)
    ↓
SmolLM2ForSequenceClassification  # concrete architecture class
    ↓
GenericForSequenceClassification.forward(**kwargs)
    ↓
    # forward() signature receives unpacked kwargs:
    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
        #              ↑ **kwargs unpacking maps dict keys to these parameters
        pooled = outputs[0][:, 0, :]  # CLS token pooling
        logits = self.score(pooled)   # linear classifier head
        loss = self.loss_fn(logits, labels) if labels is not None else None
        return SequenceClassifierOutput(loss=loss, logits=logits, ...)

Key insight: The **batch unpacking automatically maps dictionary keys ("input_ids", "attention_mask", "labels") to the forward() method’s parameter names. This is why model(**batch) works seamlessly — the keys match the function signature exactly.

2.3 Gradient averaging vs learning-rate scaling ⚖️

This is a key insight: When training on N GPUs, you have two mathematically equivalent options for combining gradients:

Option A: Average gradients (most common)

# After backward on each rank
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
param.grad /= world_size  # Average the gradients

# Optimizer update with normal LR
optimizer.step()  # uses original learning rate

Option B: Sum gradients, scale learning rate

# After backward on each rank  
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)  # Keep summed gradients

# Optimizer update with scaled LR
for param in model.parameters():
    param.data -= (lr / world_size) * param.grad

Why they’re equivalent:

\[ \text{param} - \text{lr} \times \frac{\text{grad}}{N} = \text{param} - \frac{\text{lr}}{N} \times \text{grad} \]

Real-world implications:

  • PyTorch DDP: uses Option A (averages gradients), so you keep your learning rate unchanged.
  • Some frameworks and older examples (e.g., Horovod configured with a Sum reduction): keep summed gradients (Option B) and expect you to scale the LR by 1/world_size.
  • The division can happen in two places: during gradient sync (average the gradients) or inside the optimizer step (scale the learning rate) — same math, different location in the algorithm.

Practical tip: The instructor’s comment “it depends where in the algorithm you want the averaging” refers to this choice. Most modern code averages gradients (Option A) because it’s cleaner and doesn’t require you to remember to scale the learning rate.
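
You can check the equivalence numerically with toy values (a standalone sketch, not a training loop; the numbers are arbitrary):

import torch

lr, world_size = 0.1, 4
summed_grad = torch.tensor([8.0])             # pretend this came from all_reduce(SUM)

param_a = torch.tensor([1.0])
param_b = param_a.clone()

param_a -= lr * (summed_grad / world_size)    # Option A: average grads, normal LR
param_b -= (lr / world_size) * summed_grad    # Option B: summed grads, scaled LR

print(param_a, param_b)                       # tensor([0.8000]) tensor([0.8000])
assert torch.allclose(param_a, param_b)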


3) A tiny DDP wrapper (teaching version)

This wrapper (a) verifies parameter equality at init (optionally enforces it) and (b) averages gradients after backward().

import torch
import torch.distributed as dist

class MiniDDP:
    def __init__(self, model: torch.nn.Module, enforce_broadcast: bool = False):
        self.model = model
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1

        # --- verify / enforce identical params across ranks ---
        # (skip entirely for single-process runs: there is nothing to compare)
        if self.world_size > 1:
            for p in self.model.parameters():
                # create a rank0 copy to compare/broadcast
                rank0_buf = p.detach().clone()
                dist.broadcast(rank0_buf, src=0)     # everyone receives rank0's tensor
                if enforce_broadcast:
                    p.data.copy_(rank0_buf)          # enforce equality (optional)
                elif not torch.equal(p.data, rank0_buf):
                    raise ValueError(
                        "Parameters differ at init. Seed all ranks BEFORE model construction, "
                        "or set enforce_broadcast=True."
                    )

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def average_grads(self):
        if self.world_size == 1:
            return
        for p in self.model.parameters():
            if p.grad is None:
                continue
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(self.world_size)

    # convenience passthroughs
    def train(self):
        self.model.train()
    def eval(self):
        self.model.eval()

Understanding enforce_broadcast:

The enforce_broadcast parameter controls how parameter synchronization is handled at initialization:

  1. enforce_broadcast=False (default): Verifies that all ranks already have identical parameters (e.g., via seeding). If parameters differ, it raises an error. This is the “trust but verify” approach — you’re responsible for ensuring equality (via seeding), and the wrapper checks that you did it correctly.

  2. enforce_broadcast=True: Forces all ranks to use rank 0’s parameters by overwriting each rank’s parameters with rank 0’s values. This is the “belt and suspenders” approach — even if seeding failed or parameters diverged, everyone gets rank 0’s exact state.

Why this mirrors PyTorch’s official DDP:

PyTorch’s DistributedDataParallel always performs parameter synchronization at initialization (like enforce_broadcast=True), but it does so internally, automatically, and efficiently:

  • It broadcasts parameters from rank 0 to all other ranks during construction
  • It handles buffers (like BatchNorm running stats) as well
  • It uses optimized communication patterns (coalesced broadcasts, bucketing)

This initial synchronization is a core part of DDP’s design to ensure all model replicas start with identical weights. The PyTorch DDP notes describe the same mechanism: at construction, DDP broadcasts the model state from the rank 0 process to all other processes, so every replica starts from the exact same state.

Key difference: In PyTorch’s DDP, this synchronization happens automatically in the constructor — there’s no user-facing parameter to control it. It’s an internal implementation detail that ensures correctness.

In MiniDDP, we make this synchronization explicit and optional so you can:

  • See exactly what’s happening (educational value)
  • Choose to verify vs. enforce (learning about seeding)
  • Understand the tradeoffs between verification and enforcement

This mirrors what PyTorch’s official DistributedDataParallel does conceptually, but without bucketing, overlap, or autograd hooks. Perfect for learning; use the real DDP for production.
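
Usage of the two modes, for reference (a sketch, assuming the process group is initialized and the model was built after seeding as in §1):

ddp = MiniDDP(model)                          # verify mode: raises if ranks differ
ddp = MiniDDP(model, enforce_broadcast=True)  # force every rank to rank 0's weights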


4) Minimal distributed training loop

# torchrun --nproc_per_node=2 train.py

import os, torch, torch.distributed as dist
from torch.optim import Adam
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"

def set_seed(seed=43):
    import random, numpy as np
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

def main():
    dist.init_process_group("nccl")
    rank  = dist.get_rank();  local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{local_rank}")

    set_seed(43)  # same on every process BEFORE creating the model

    tok = AutoTokenizer.from_pretrained(MODEL)
    if tok.pad_token is None:            # some causal-LM tokenizers ship without one
        tok.pad_token = tok.eos_token
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL, num_labels=2, torch_dtype=torch.bfloat16
    ).to(device)
    model.config.pad_token_id = tok.pad_token_id   # needed once batches are padded

    ddp = MiniDDP(model, enforce_broadcast=False)   # MiniDDP defined in §3
    opt = Adam(ddp.model.parameters(), lr=1e-3)

    ds = load_dataset("glue", "mrpc")
    def encode(ex):
        return tok(ex["sentence1"], ex["sentence2"], padding=True, truncation=True)
    ds = ds.map(encode, batched=True).rename_columns({"label": "labels"})

    # toy per-rank sample (one example per rank to show divergence if not averaged)
    example = ds["train"][rank]
    keep = ("input_ids", "attention_mask", "labels")   # skip raw text columns like sentence1/sentence2
    batch = {k: torch.tensor(v).unsqueeze(0).to(device)
             for k, v in example.items() if k in keep}

    ddp.train()
    out = ddp(**batch)         # kwargs unpacking
    out.loss.backward()
    ddp.average_grads()        # <— key! average across ranks
    opt.step(); opt.zero_grad(set_to_none=True)

    if rank == 0:
        print("step ok; loss:", out.loss.item())

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

What just happened?

  • Both ranks ran the same code and created identical models (thanks to seeding).
  • Each rank used a different example (rank index) → losses differ initially.
  • average_grads() made every GPU apply the same averaged update, keeping replicas in lock‑step.
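
To watch this on your own run, you could gather every rank’s loss right after the step (a small optional addition to the loop above, reusing out, rank, and dist from §4):

# Each rank saw a different example, so the losses differ...
loss_t = out.loss.detach().reshape(1)
losses = [torch.zeros_like(loss_t) for _ in range(dist.get_world_size())]
dist.all_gather(losses, loss_t)
if rank == 0:
    print("per-rank losses:", [l.item() for l in losses])
# ...but the weights stay identical on every rank, because each one applied
# the same averaged gradients in opt.step().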

5) Why broadcast at init if we already seed?

Seeding should guarantee equality. The broadcast operation (when enforce_broadcast=True) is a belt‑and‑suspenders option:

  • Protect against forgotten seeds: If you forgot to seed on some ranks, broadcast ensures everyone still starts identical.
  • Handle divergent code paths: If different ranks take different initialization paths, broadcast syncs them.
  • Deal with non‑deterministic ops: Some operations (e.g., certain CUDA kernels) may not be fully deterministic even with seeds.
  • Enable joining late ranks: If a rank joins after initialization, broadcast can sync it to the current state from rank 0.

In practice: With proper seeding (see §1), enforce_broadcast=False (verify mode) is usually sufficient. Use enforce_broadcast=True only if you intend to force‑sync weights at init or are debugging initialization issues.

Note: PyTorch’s official DDP always performs this synchronization automatically (equivalent to enforce_broadcast=True), but hides it from you. MiniDDP makes it explicit so you can learn about the mechanism.


6) Common pitfalls & fixes

  • Different seeds / seeding too late → parameters differ. Fix: call set_seed() before build_model() on every rank.
  • Forgetting to divide after all_reduce(SUM) → LR effectively × world_size. Fix: divide the grads, or use dist.ReduceOp.AVG, supported by NCCL on recent PyTorch versions (see the one-liner after this list).
  • Grad is None: layers not used in the forward didn’t receive gradients. Fix: check the graph; guard if p.grad is None: continue.
  • CPU tensors in batch: model expects CUDA tensors. Fix: dictionary comprehension that moves tensors to device.
  • Shape mismatches across ranks: ensure each rank’s micro‑batch has identical shapes (padding or a proper DistributedSampler).
  • NCCL init errors: set MASTER_ADDR/PORT, unique RANK, correct CUDA_VISIBLE_DEVICES.
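
For the averaging pitfall, recent PyTorch builds with NCCL can do the divide inside the collective itself (a sketch; verify that your version supports ReduceOp.AVG before relying on it):

# One collective instead of SUM followed by a divide:
dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)   # already the mean across ranks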

7) From toy to real DDP

What we built is the core idea. Production torch.nn.parallel.DistributedDataParallel adds:

  • gradient bucketing and overlap with communication;
  • parameter and buffer broadcast on construction (with versioning);
  • autograd hooks for exact timing;
  • mixed precision, static graph optimizations, etc.

Upgrade path: once you grasp the flow above, swap MiniDDP for DistributedDataParallel(model, device_ids=[local_rank]) and use DistributedSampler in your DataLoader.
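
A minimal sketch of that swap, reusing names from the toy loop in §4 (the MRPC column names, batch size, and epoch count are illustrative; it also assumes the pad token setup from §4):

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from transformers import DataCollatorWithPadding

model = DDP(model, device_ids=[local_rank])             # replaces MiniDDP
train_ds = ds["train"].remove_columns(["sentence1", "sentence2", "idx"])
sampler = DistributedSampler(train_ds)                  # shards data across ranks
loader = DataLoader(train_ds, batch_size=8, sampler=sampler,
                    collate_fn=DataCollatorWithPadding(tok))

for epoch in range(3):
    sampler.set_epoch(epoch)                            # reshuffle differently each epoch
    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch)                            # DDP syncs grads during backward
        out.loss.backward()
        opt.step(); opt.zero_grad(set_to_none=True)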


8) Cheatsheet

  • item = {k: f(v) for k, v in d.items()} → dictionary comprehension.
  • model(**d) → unpack d into named arguments to forward.
  • dist.all_reduce(t, SUM); t /= world_size → average a tensor across ranks.
  • Seed before model creation on every process.
  • If in doubt, force-sync params once with broadcast.

9) Appendix: tiny utilities

def recursively_apply(func, data):
    if isinstance(data, (tuple, list)):
        return type(data)(recursively_apply(func, x) for x in data)
    if isinstance(data, dict):
        return {k: recursively_apply(func, v) for k, v in data.items()}
    return func(data)

# Example: move a nested batch to device
batch = recursively_apply(lambda t: t.to(device) if isinstance(t, torch.Tensor) else t, batch)

10) Bonus: Where does forward() come from with AutoModel?

When we wrote:

model = AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=2)

that helper inspects the model config and dispatches to the architecture‑specific ...ForSequenceClassification class. For the SmolLM family, that class inherits a generic head that already implements forward().

Call chain at runtime (conceptual):

  • AutoModelForSequenceClassification → ArchitectureForSequenceClassification → GenericForSequenceClassification.forward(**kwargs) → ArchitectureModel.forward(…) → CLS pooling → classifier head (self.score) → loss (if labels)

Minimal shape of that forward():

class GenericForSequenceClassification(PreTrainedModel):
    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
        pooled = outputs[0][:, 0, :]
        logits = self.score(pooled)
        loss = self.loss_fn(logits, labels) if labels is not None else None
        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)

This is why model(**batch) (see §2.2) “just works”: the dict keys map to the generic forward() signature, which calls the backbone’s forward() under the hood.


Happy scaling! If you’re following the course, tag this post as TIL/DDP‑from‑scratch and iterate from here. 🧪🚀

11) Quick Reference: Gradient sync patterns

Summary of the two equivalent approaches (see §2.3 for full explanation):

# Pattern A: Average gradients (PyTorch DDP default)
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
param.grad /= world_size
param -= lr * param.grad  # Original LR

# Pattern B: Sum gradients, scale LR (Horovod-style)
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
param -= (lr / world_size) * param.grad  # Scaled LR

Key takeaway: Both produce identical updates. Choose Pattern A for cleaner code that matches PyTorch DDP defaults.