Training
Training is driven by gpt-simple train (or gpt_simple.train(...) from
Python) and configured entirely through the training, optimizer, and
data config sections. Multi-GPU is handled with
Accelerate.
Running
Single GPU:
```shell
gpt-simple train --config config.yaml
```
Multi-GPU (launches torchrun automatically):
```shell
gpt-simple train --config config.yaml --nproc_per_node 4
```
Override config values at launch, and start fresh if needed:
```shell
gpt-simple train --config config.yaml \
  --training.max_steps 5000 \
  --optimizer.learning_rate 1e-4 \
  --force  # discard existing checkpoints in output_dir
```
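The mapping from dotted override flags onto the nested config sections can be sketched as follows. This is a hypothetical illustration, not gpt-simple's actual parser: `apply_overrides` and `parse_value` are invented names, and the coercion rules (`true`/`false` to bool, then int, then float, else string) are assumptions.

```python
from typing import Any, Dict, List


def parse_value(raw: str) -> Any:
    """Coerce a CLI string to bool, int, or float where possible, else keep it as str."""
    if raw.lower() in ("true", "false"):
        return raw.lower() == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(config: Dict[str, Any], args: List[str]) -> Dict[str, Any]:
    """Apply '--section.key value' pairs from args to a nested config dict in place."""
    i = 0
    while i < len(args):
        flag = args[i]
        if not (flag.startswith("--") and "." in flag):
            i += 1  # not a dotted override (e.g. --force); left to other handling
            continue
        if i + 1 >= len(args):
            raise ValueError(f"missing value for {flag}")
        *sections, key = flag[2:].split(".")
        node = config
        for section in sections:
            node = node.setdefault(section, {})  # create missing sections on the fly
        node[key] = parse_value(args[i + 1])
        i += 2
    return config


config = {"training": {"max_steps": 1000}, "optimizer": {"learning_rate": 3e-4}}
apply_overrides(
    config,
    ["--training.max_steps", "5000", "--optimizer.learning_rate", "1e-4", "--force"],
)
print(config)
```

Parsing `"5000"` with `int` before `float` keeps step counts integral, while `"1e-4"` fails `int` and falls through to `float`.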
From Python:
```python
import gpt_simple

# Runs training from a config file and returns a summary of the run.
result = gpt_simple.train(config="config.yaml")
print(result.final_loss, result.total_tokens, result.checkpoint_path)
```
Effective batch size
The number of tokens per optimizer step is:
per_device_batch_size × gradient_accumulation_steps × world_size × max_length
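As a worked example of the formula, with arbitrary illustrative values (the helper name is hypothetical, and it counts `max_length` positions per sequence exactly as the formula does):

```python
def tokens_per_optimizer_step(
    per_device_batch_size: int,
    gradient_accumulation_steps: int,
    world_size: int,
    max_length: int,
) -> int:
    """Tokens consumed per optimizer step: product of the four factors."""
    return per_device_batch_size * gradient_accumulation_steps * world_size * max_length


# 8 sequences per GPU, 4 accumulation micro-steps, 4 GPUs, 1024-token sequences.
print(tokens_per_optimizer_step(8, 4, 4, 1024))  # 131072
```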
Per-rank batch size is independent of world_size: adding GPUs scales the
global batch rather than shrinking each GPU's work. The learning-rate
schedule advances once per optimizer step regardless of accumulation or
GPU count.
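The "once per optimizer step" rule can be illustrated with a generic PyTorch accumulation loop. This is a sketch, not train.py's code: the toy `nn.Linear` model, the linear-decay `LambdaLR` schedule, and the batch shapes are all assumptions. Gradients from `gradient_accumulation_steps` micro-batches accumulate in `.grad`, and `optimizer.step()` and `scheduler.step()` run once per window.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(16, 1)  # toy stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Hypothetical linear-decay schedule defined over optimizer steps, not micro-batches.
total_optimizer_steps = 10
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: max(0.0, 1.0 - step / total_optimizer_steps)
)

gradient_accumulation_steps = 4
num_micro_batches = total_optimizer_steps * gradient_accumulation_steps

for micro_step in range(num_micro_batches):
    x, y = torch.randn(8, 16), torch.randn(8, 1)
    # Divide so the accumulated gradient is the mean over the accumulation window.
    loss = nn.functional.mse_loss(model(x), y) / gradient_accumulation_steps
    loss.backward()  # gradients add up in .grad across micro-batches
    if (micro_step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        scheduler.step()  # the schedule advances once per optimizer step
        optimizer.zero_grad(set_to_none=True)

print(scheduler.last_epoch)  # 10 scheduler steps for 40 micro-batches
```

Under DDP each rank would run this same loop on its own data shard, so the number of scheduler steps per accumulation window does not depend on world_size.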
Mixed precision
training.mixed_precision selects the compute precision. When left as null
(the default), it is auto-detected per device: bf16 on Ampere and newer
GPUs, fp16 on older CUDA GPUs, and no mixed precision (full fp32) on CPU.
Prefer bf16 where available: it has the same 8-bit exponent, and therefore
the same dynamic range, as fp32, so it needs no loss scaling. See Hardware
tuning.
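The auto-detection rule can be written as a small pure function. This is a sketch, not the trainer's code: `resolve_mixed_precision` is a hypothetical name, and returning Accelerate-style strings (`"bf16"`, `"fp16"`, `"no"`) is an assumption about what gets passed through. Ampere corresponds to CUDA compute capability 8.0; on a real machine the capability tuple would come from `torch.cuda.get_device_capability()`.

```python
from typing import Optional, Tuple


def resolve_mixed_precision(
    setting: Optional[str],
    device_type: str,
    cuda_capability: Optional[Tuple[int, int]] = None,
) -> str:
    """Return "bf16", "fp16", or "no" following the auto-detect rule."""
    if setting is not None:
        return setting  # an explicit training.mixed_precision value wins
    if device_type != "cuda":
        return "no"  # CPU: no mixed precision
    if cuda_capability is not None and cuda_capability >= (8, 0):
        return "bf16"  # Ampere (sm_80) and newer
    return "fp16"  # older CUDA GPUs (fp16 needs loss scaling)


print(resolve_mixed_precision(None, "cuda", (7, 5)))  # fp16 (e.g. a Turing GPU)
```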
torch.compile
training.compile: true (default) wraps the model with torch.compile
for a meaningful throughput gain. Compilation happens once, on the first
step (so step 1 is slow). The attention call is treated as an opaque
op and is not decomposed by the compiler.
compile + DDP + gradient checkpointing
When multi-GPU training (DDP), compile: true, and
gradient_checkpointing: true are all enabled, the trainer disables Dynamo's
DDP graph-splitter (torch._dynamo.config.optimize_ddp = False) before
compiling. The graph-splitter does not support the higher-order ops that
gradient checkpointing introduces, and recent PyTorch versions hard-error on
the combination (pytorch/pytorch#104674).
torch.compile itself stays fully enabled; only the graph splitting is
turned off, so the module compiles as a single graph. The only cost is
reduced overlap between gradient communication and backward compute (the
whole graph acts as one gradient bucket), which is minor on a single NVLink
node and more noticeable across multiple nodes. The setting is a no-op
without DDP.
A compile-compatible bucketed reducer
(torch._dynamo.config.optimize_ddp = "python_reducer") can recover most of
that overlap on newer PyTorch. It is less battle-tested than the safe
default, so validate it on a short multi-GPU run before adopting it for a
long job.
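The gating rule can be sketched as a predicate plus a setter. The helper names are hypothetical, and treating `world_size > 1` as "DDP is active" is an assumption; `torch._dynamo.config.optimize_ddp` is the real PyTorch setting named above. The setter has to run before `torch.compile(model)`.

```python
def should_disable_ddp_graph_splitter(
    world_size: int, compile_model: bool, gradient_checkpointing: bool
) -> bool:
    """True only when DDP, torch.compile, and gradient checkpointing are all active."""
    return world_size > 1 and compile_model and gradient_checkpointing


def configure_dynamo_for_ddp(
    world_size: int, compile_model: bool, gradient_checkpointing: bool
) -> None:
    """Apply the graph-splitter switch; call this before torch.compile(model)."""
    if should_disable_ddp_graph_splitter(world_size, compile_model, gradient_checkpointing):
        import torch._dynamo  # lazy import keeps the predicate free of a torch dependency

        # Compile the module as one graph instead of splitting at DDP bucket boundaries.
        torch._dynamo.config.optimize_ddp = False
```

Trying the bucketed reducer would mean assigning `"python_reducer"` in the same place, on PyTorch versions that accept that value.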
Gradient checkpointing
training.gradient_checkpointing: true (default) recomputes each block's
activations during the backward pass instead of storing them, trading
compute for memory. Turn it off if you have memory headroom and want
maximum throughput.
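Per-block recomputation can be sketched with PyTorch's `torch.utils.checkpoint.checkpoint`. `Block` and `Stack` are hypothetical stand-ins, not gpt-simple's model classes. The comparison at the end checks that checkpointing changes what is stored, not the gradients.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Stand-in for a transformer block: residual MLP with LayerNorm."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim), nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)


class Stack(nn.Module):
    def __init__(self, dim: int, depth: int, gradient_checkpointing: bool) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                # Store only the block input; recompute its internals during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


torch.manual_seed(0)
plain = Stack(dim=32, depth=4, gradient_checkpointing=False)
ckpt = Stack(dim=32, depth=4, gradient_checkpointing=True)
ckpt.load_state_dict(plain.state_dict())  # identical weights

x = torch.randn(2, 8, 32)
plain(x).sum().backward()
ckpt(x).sum().backward()
max_diff = max(
    (a.grad - b.grad).abs().max().item()
    for a, b in zip(plain.parameters(), ckpt.parameters())
)
print(f"max gradient difference: {max_diff:.2e}")
```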
Logging and evaluation
- logging_steps: interval for logging loss, learning rate, gradient norm, and throughput.
- eval_steps: interval for validation over the val/ data (max_eval_batches caps the work).
- save_steps: checkpoint interval; see Checkpointing & resume.
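A sketch of how these interval settings are commonly interpreted. `events_for_step` and `evaluate` are hypothetical helpers; the rule that an action fires when the 1-indexed optimizer step is a multiple of its interval, and `max_eval_batches = None` meaning "no cap", are assumptions, since train.py may count steps differently.

```python
from itertools import islice
from typing import Callable, Iterable, List, Optional


def events_for_step(step: int, logging_steps: int, eval_steps: int, save_steps: int) -> List[str]:
    """Periodic actions due at a 1-indexed optimizer step (assumed multiple-of-interval rule)."""
    events = []
    if step % logging_steps == 0:
        events.append("log")
    if step % eval_steps == 0:
        events.append("eval")
    if step % save_steps == 0:
        events.append("save")
    return events


def evaluate(
    batches: Iterable, loss_fn: Callable[[object], float], max_eval_batches: Optional[int]
) -> float:
    """Mean loss over at most max_eval_batches batches (None = whole validation set)."""
    losses = [loss_fn(batch) for batch in islice(batches, max_eval_batches)]
    if not losses:
        raise ValueError("validation set produced no batches")
    return sum(losses) / len(losses)


for step in range(1, 1001):
    fired = events_for_step(step, logging_steps=10, eval_steps=250, save_steps=500)
    if "save" in fired:
        print(step, fired)
# prints:
# 500 ['log', 'eval', 'save']
# 1000 ['log', 'eval', 'save']
```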
Weights & Biases
Set training.wandb_project to enable W&B (install the extra:
pip install ".[wandb]", then wandb login). The run id is persisted in
the checkpoint, so a stop/resume chain reports as one continuous run.
Leave wandb_project unset to disable logging entirely.
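A minimal sketch of run-id persistence, assuming the checkpoint state is a dict and the id is stored under a hypothetical `wandb_run_id` key. `wandb.init(project=..., id=..., resume="allow")` is W&B's API for continuing a run by id; the helper names and the id format are illustrative.

```python
import secrets
from typing import Any, Dict, Optional


def resolve_wandb_run_id(checkpoint_state: Optional[Dict[str, Any]]) -> str:
    """Reuse the run id saved in a checkpoint, or mint a new one for a fresh run."""
    if checkpoint_state and checkpoint_state.get("wandb_run_id"):
        return checkpoint_state["wandb_run_id"]
    return secrets.token_hex(4)  # new random 8-character id


def start_wandb(project: str, checkpoint_state: Optional[Dict[str, Any]]) -> str:
    import wandb  # optional dependency: pip install ".[wandb]"

    run_id = resolve_wandb_run_id(checkpoint_state)
    # resume="allow" continues the run with this id if it exists, otherwise creates it.
    wandb.init(project=project, id=run_id, resume="allow")
    return run_id  # save into the next checkpoint so later resumes reuse it
```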
Validation before training
gpt-simple validate --config config.yaml checks a config (and,
optionally, runtime/memory feasibility) without starting a run — useful
as a submission gate. The trainer also runs validation automatically at
startup.
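As a submission gate, a wrapper script can run the validator and refuse to submit on failure. This sketch assumes `gpt-simple validate` exits with a non-zero status on an invalid config, which this page does not state; `passes_gate` is a hypothetical helper.

```python
import subprocess
import sys
from typing import Sequence


def passes_gate(cmd: Sequence[str]) -> bool:
    """Run a validation command and report whether it exited with status 0."""
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        # Surface the validator's diagnostics before refusing to submit.
        print(result.stdout, result.stderr, sep="\n", file=sys.stderr)
    return result.returncode == 0
```

For example, `passes_gate(["gpt-simple", "validate", "--config", "config.yaml"])` would return True only if validation succeeds, under the exit-code assumption above.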
Authoritative source: src/gpt_simple/train.py.