MuLoCo: Muon is a Practical Inner Optimizer for DiLoCo

Benjamin Therien1,2,3, Xiaolong Huang2,4, Aaron Defazio1, Irina Rish2,3, Eugene Belilovsky2,4

1FAIR at Meta    2Mila    3Université de Montréal    4Concordia University

Abstract

DiLoCo is a powerful framework for training large language models (LLMs), enabling larger optimal batch sizes and increased accelerator utilization under networking constraints. However, DiLoCo's performance has been shown to degrade as the number of workers (K) increases. In this work, we posit that a related but often overlooked factor in DiLoCo's behavior is the choice of inner optimizer, which shapes the pseudogradient used by the outer optimizer. Given the recent success of Muon relative to AdamW for data parallel (DP) training, we examine how Muon's normalized optimizer steps can affect the pseudogradient's quality.

We find that, relative to AdamW, Muon yields more directionally correct pseudogradients as the number of workers (K) increases. In our language-model pre-training experiments, we conduct extensive hyperparameter tuning across 150M, 416M, 914M, 1.76B, and 3.1B models for DiLoCo, MuLoCo, AdamW DP, and Muon DP. Consistently across all scales, we find that with K≥1 workers, MuLoCo (DiLoCo with Muon as the inner optimizer) achieves lower loss than DiLoCo in absolute terms, and for K>2 it also degrades less than DiLoCo relative to its data-parallel baseline, while remaining compatible with quantization, streaming, and long synchronization intervals.

At K=1, we find that MuLoCo can even outperform the data-parallel gold standard while supporting larger critical batch sizes. Finally, we extrapolate optimal hyperparameters to the 15B scale and train a model with each method (six in total) using K=1 and K=16 workers. At this scale, K=16 MuLoCo nearly matches single-worker performance, and K=1 MuLoCo matches the best-performing baseline despite using a much larger 16M-token batch size.
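The outer update that DiLoCo and MuLoCo share can be written in a few lines: each of the K workers runs H local inner steps, and the pseudogradient is the average displacement of the workers from the shared starting point, which an outer Nesterov SGD step then consumes. A minimal framework-agnostic NumPy sketch (the function name and signature are illustrative, not the package's API; the 0.7/0.6 values match the quick-start defaults below):

```python
import numpy as np

def outer_step(theta_start, worker_params, velocity,
               outer_lr=0.7, outer_momentum=0.6):
    """One DiLoCo-style outer update (illustrative sketch).

    theta_start:   shared parameters at the start of the round
    worker_params: (K, dim) array of each worker's parameters after
                   H local inner steps (AdamW for DiLoCo, Muon for MuLoCo)
    velocity:      outer momentum buffer
    """
    # Pseudogradient: average displacement of the K workers
    pseudograd = theta_start - np.mean(worker_params, axis=0)
    # Nesterov SGD on the pseudogradient (outer optimizer)
    velocity = outer_momentum * velocity + pseudograd
    theta_new = theta_start - outer_lr * (pseudograd + outer_momentum * velocity)
    return theta_new, velocity
```

The inner optimizer never sees the outer step; it only shapes the displacement each worker contributes, which is exactly the knob this work studies.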

Key Findings

Beats Data-Parallel at K=1

MuLoCo K=1 outperforms DP Muon, DP AdamW, and DiLoCo K=1 at every scale from 150M to 3.1B parameters with extensive hyperparameter tuning.

Much Larger Critical Batch Sizes

MuLoCo K=1 matches DP Muon's optimal loss while using 8x larger batch sizes at 3.1B scale, enabling dramatically more parallelism.

Pareto-Optimal Training Time

For the same wall-clock training time, K=1 MuLoCo reaches up to ~10% lower loss than DP AdamW, thanks to its ability to leverage large batches.

Better Worker Scaling

At K>2, MuLoCo's performance relative to its DP baseline degrades more slowly than DiLoCo's, and this advantage is maintained at scale.

Lossless 4-bit Compression

Both MuLoCo and DiLoCo achieve effectively lossless communication with 4-bit quantization. MuLoCo outperforms DiLoCo under all compression schemes.
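Because only the pseudogradient crosses the network, it is the natural target for compression. A minimal sketch of uniform symmetric 4-bit quantization in NumPy (illustrative only; the paper's exact quantization scheme may differ, and this ignores the degenerate all-zeros case):

```python
import numpy as np

def quantize_4bit(x):
    """Map values to 15 symmetric integer levels in [-7, 7], one scale per tensor."""
    scale = np.abs(x).max() / 7.0
    q = np.clip(np.round(x / scale), -7, 7).astype(np.int8)
    return q, scale

def dequantize_4bit(q, scale):
    """Recover an approximation of the original tensor."""
    return q.astype(np.float32) * scale
```

With a per-tensor scale, the worst-case round-trip error is half a quantization step, which is evidently small enough relative to pseudogradient noise to leave the loss unchanged.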

Validated at 15B Scale

MuLoCo K=1 trains at 16M token batch size while matching DP Muon and DiLoCo K=1 final loss and downstream accuracy at 15B parameters.

Results

MuLoCo K=1: Larger Critical Batch Sizes, Better Performance

Figure: Performance vs. batch size at 3.1B scale. MuLoCo K=1 has the largest critical batch size and best absolute performance.

K=1 Batch Size Scaling Laws

Figure: Critical batch size scaling laws for K=1. MuLoCo's critical batch size grows faster with model scale than all baselines.

15B Scale: Idealized Wall-Clock Training Time Under Bandwidth Constraints

Figure: Wall-clock training at 15B. K=16 MuLoCo is fastest under bandwidth constraints (left); K=1 MuLoCo is fastest in high-bandwidth settings (right).

Scaling Study Results

Final evaluation loss across model scales and worker counts. An asterisk (*) marks the best loss per scale. All methods extensively tuned (2,200+ runs).

Method        150M    416M    914M    1.76B   3.1B    15B
DP    Muon    3.124   2.641   2.402   2.246   2.128   1.864*
DP    AdamW   3.158   2.682   2.440   2.266   2.145   1.887
K=1   MuLoCo  3.120*  2.638*  2.400*  2.238*  2.122*  1.884
K=1   DiLoCo  3.142   2.650   2.411   2.265   2.136   1.891
K=16  MuLoCo  3.222   2.713   2.448   2.291   2.165   1.917
K=16  DiLoCo  3.326   2.808   2.522   2.348   2.215   1.906

Get Started

Install

# PyTorch
pip install "muloco[pytorch] @ git+https://github.com/bentherien/muloco-1.git"

# JAX/Optax
pip install "muloco[jax] @ git+https://github.com/bentherien/muloco-1.git"

# From source
git clone https://github.com/bentherien/muloco-1.git
cd muloco-1
pip install -e ".[pytorch]"

PyTorch Quick Start

from muloco.pytorch import MuLoCo1

# Classify params: Muon for 2D+ weight matrices, AdamW for the rest
matrix_params = [p for p in model.parameters() if p.ndim >= 2]
other_params = [p for p in model.parameters() if p.ndim < 2]
param_groups = [
    {"params": matrix_params, "algorithm": "muon"},
    {"params": other_params,  "algorithm": "adamw"},
]

optimizer = MuLoCo1(
    params=param_groups,
    inner_lr=0.02,       # Muon inner LR
    outer_lr=0.7,        # Outer Nesterov SGD LR
    outer_momentum=0.6,  # Outer momentum (lower than DiLoCo's 0.8)
    sync_interval=30,    # H=30 inner steps per outer step
)

# Standard training loop
for batch in dataloader:
    loss = model(batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

JAX/Optax Quick Start

import optax
from muloco.jax import muloco  # diloco and muloco_wrapper are also exported

# MuLoCo with Muon inner optimizer
opt = muloco(learning_rate=0.02, outer_lr=0.7, outer_momentum=0.6, sync_interval=30)

# Standard optax usage
opt_state = opt.init(params)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Citation

If you use MuLoCo in your research, please cite our paper:

@article{therien2025muloco,
    title={MuLoCo: Muon is a Practical Inner Optimizer for DiLoCo},
    author={Therien, Benjamin and Huang, Xiaolong and Defazio, Aaron
            and Rish, Irina and Belilovsky, Eugene},
    journal={arXiv preprint arXiv:2505.23725},
    year={2025}
}