¹FAIR at Meta, ²Mila, ³Université de Montréal, ⁴Concordia University
DiLoCo is a powerful framework for training large language models (LLMs), enabling larger optimal batch sizes and increased accelerator utilization under networking constraints. However, DiLoCo's performance has been shown to degrade as the number of workers (K) increases. In this work, we posit that a related but often overlooked factor in DiLoCo's behavior is the choice of inner optimizer, which shapes the pseudogradient used by the outer optimizer. Given the recent success of Muon relative to AdamW for data parallel (DP) training, we examine how Muon's normalized optimizer steps can affect the pseudogradient's quality.
We find that, relative to AdamW, Muon yields more directionally correct pseudogradients as the number of workers (K) increases. In language-model pre-training experiments, we conduct extensive hyperparameter tuning across 150M, 416M, 914M, 1.76B, and 3.1B models for DiLoCo, MuLoCo, AdamW DP, and Muon DP. Consistently across all scales, we find that for K≥1 workers, MuLoCo (DiLoCo with Muon as the inner optimizer) outperforms DiLoCo in absolute terms, and that for K>2 it also outperforms DiLoCo relative to their respective data-parallel baselines, all while remaining compatible with quantization, streaming, and long synchronization intervals.
At K=1, we find that MuLoCo can even outperform the data-parallel gold standard while supporting larger critical batch sizes. Finally, we extrapolate optimal hyperparameters to the 15B scale and train a model with each method (six in total) using K=1 and K=16 workers. At this scale, K=16 MuLoCo nearly matches single-worker performance, while K=1 MuLoCo matches the best-performing baseline while using a much larger 16M-token batch size.
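For intuition, the outer step shared by DiLoCo and MuLoCo (the part the inner optimizer's pseudogradient feeds into) can be sketched as follows. This is an illustrative NumPy sketch, not the paper's implementation; the function name and Nesterov formulation are our own assumptions, with the hyperparameter defaults taken from the usage example below.

```python
import numpy as np

def outer_step(theta, worker_params, velocity, outer_lr=0.7, outer_momentum=0.6):
    """One DiLoCo/MuLoCo-style outer update (illustrative sketch).

    theta: global parameters at the start of the round
    worker_params: each worker's parameters after H local inner steps
    velocity: outer Nesterov momentum buffer
    """
    # Pseudogradient: displacement from the global params to the worker average.
    pseudograd = theta - np.mean(worker_params, axis=0)
    # Outer Nesterov SGD applied to the pseudogradient.
    velocity = outer_momentum * velocity + pseudograd
    theta = theta - outer_lr * (pseudograd + outer_momentum * velocity)
    return theta, velocity
```

With K=1 this reduces to periodically re-anchoring a single worker, which is why the K=1 variants behave like (and can beat) their data-parallel counterparts.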
- MuLoCo K=1 outperforms DP Muon, DP AdamW, and DiLoCo K=1 at every scale from 150M to 3.1B parameters with extensive hyperparameter tuning.
- MuLoCo K=1 matches DP Muon's optimal loss while using 8x larger batch sizes at the 3.1B scale, enabling dramatically more parallelism.
- For the same wall-clock training time, MuLoCo K=1 reaches up to ~10% lower loss than DP AdamW, thanks to its ability to leverage large batches.
- At K>2, MuLoCo's performance relative to its DP baseline degrades more slowly than DiLoCo's, and this advantage is maintained at scale.
- Both MuLoCo and DiLoCo achieve effectively lossless communication with 4-bit quantization; MuLoCo outperforms DiLoCo under all compression schemes.
- MuLoCo K=1 trains at a 16M-token batch size while matching DP Muon and DiLoCo K=1 in final loss and downstream accuracy at 15B parameters.
Figure: Performance vs. batch size at the 3.1B scale. MuLoCo K=1 has the largest critical batch size and the best absolute performance.

Figure: Critical-batch-size scaling laws for K=1. MuLoCo's critical batch size grows faster with model scale than all baselines.

Figure: Wall-clock training at 15B. K=16 MuLoCo is fastest under bandwidth constraints (left); K=1 MuLoCo is fastest in high-bandwidth settings (right).
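The lossless-communication finding above depends on quantizing pseudogradients before they are exchanged. A minimal symmetric per-tensor 4-bit scheme looks like the sketch below; this is our own illustrative codec, not the paper's exact compression scheme, and the function names are hypothetical.

```python
import numpy as np

def quantize_int4(x):
    """Symmetric 4-bit quantization: integer codes in [-8, 7] plus one scale."""
    max_abs = np.abs(x).max()
    scale = max_abs / 7.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(x / scale), -8, 7).astype(np.int8)
    return q, scale

def dequantize_int4(q, scale):
    """Recover an approximate float tensor from codes and scale."""
    return q.astype(np.float32) * scale
```

In practice each worker would transmit only the codes and the scale (two 4-bit codes can be packed per byte), cutting pseudogradient communication roughly 8x versus float32.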
Final evaluation loss across model scales and worker counts. Bold indicates best loss per scale. All methods extensively tuned (2,200+ runs).
|  | Method | 150M | 416M | 914M | 1.76B | 3.1B | 15B |
|---|---|---|---|---|---|---|---|
| DP | Muon | 3.124 | 2.641 | 2.402 | 2.246 | 2.128 | **1.864** |
|  | AdamW | 3.158 | 2.682 | 2.440 | 2.266 | 2.145 | 1.887 |
| K=1 | MuLoCo | **3.120** | **2.638** | **2.400** | **2.238** | **2.122** | 1.884 |
|  | DiLoCo | 3.142 | 2.650 | 2.411 | 2.265 | 2.136 | 1.891 |
| K=16 | MuLoCo | 3.222 | 2.713 | 2.448 | 2.291 | 2.165 | 1.917 |
|  | DiLoCo | 3.326 | 2.808 | 2.522 | 2.348 | 2.215 | 1.906 |
```bash
# PyTorch
pip install "muloco[pytorch] @ git+https://github.com/bentherien/muloco-1.git"

# JAX/Optax
pip install "muloco[jax] @ git+https://github.com/bentherien/muloco-1.git"

# From source
git clone https://github.com/bentherien/muloco-1.git
cd muloco-1
pip install -e ".[pytorch]"
```
```python
from muloco.pytorch import MuLoCo1, Muon

# Classify params: Muon for 2D+ matrices, scalar optimizer for the rest
param_groups = [
    {"params": matrix_params, "algorithm": "muon"},
    {"params": other_params, "algorithm": "adamw"},
]

optimizer = MuLoCo1(
    params=param_groups,
    inner_lr=0.02,       # Muon inner LR
    outer_lr=0.7,        # Outer Nesterov SGD LR
    outer_momentum=0.6,  # Outer momentum (lower than DiLoCo's 0.8)
    sync_interval=30,    # H=30 inner steps per outer step
)

# Standard training loop
for batch in dataloader:
    loss = model(batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
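Inside each inner step, Muon orthogonalizes the momentum matrix of every 2D+ parameter before applying it, which is what normalizes the steps that shape the pseudogradient. The sketch below is an illustrative NumPy port of the widely used quintic Newton-Schulz iteration from Muon's reference implementation; it is not the internals of the `MuLoCo1` class above, whose exact details may differ.

```python
import numpy as np

def newton_schulz(G, steps=5, eps=1e-7):
    """Approximately orthogonalize G by pushing its singular values toward 1."""
    a, b, c = 3.4445, -4.7750, 2.0315          # quintic iteration coefficients
    X = G / (np.linalg.norm(G) + eps)          # scale so singular values <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:                             # iterate on the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X  # polynomial update per step
    return X.T if transposed else X
```

Because the result has near-unit singular values regardless of the raw gradient's magnitude, every Muon step has a similar "size", which is one intuition for why the resulting pseudogradients stay directionally informative as K grows.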
```python
import optax

from muloco.jax import muloco, diloco, muloco_wrapper

# MuLoCo with Muon inner optimizer
opt = muloco(learning_rate=0.02, outer_lr=0.7, outer_momentum=0.6, sync_interval=30)

# Standard optax usage
opt_state = opt.init(params)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```
If you use MuLoCo in your research, please cite our paper:
```bibtex
@article{therien2025muloco,
  title={MuLoCo: Muon is a Practical Inner Optimizer for DiLoCo},
  author={Therien, Benjamin and Huang, Xiaolong and Defazio, Aaron and Rish, Irina and Belilovsky, Eugene},
  journal={arXiv preprint arXiv:2505.23725},
  year={2025}
}
```