This repository contains the evaluation code used in Outlier-Safe Pre-Training for Robust 4-Bit Quantization of Large Language Models. The codebase builds heavily on QuaRot and SpinQuant and has been adapted to cover the quantization scenarios evaluated below.
Quantization plays a crucial role in deploying Large Language Models (LLMs) in resource-constrained environments. However, the presence of outlier features significantly hinders low-bit quantization. While many studies address this problem post hoc so that already pre-trained models can still be used, the importance of handling outliers during pre-training is often underestimated.
Our work, Outlier-Safe Pre-Training (OSP), proposes a practical approach to training models that are robust to outliers from the start, without sacrificing performance or efficiency. Specifically, OSP focuses on the following goals:
- **Scaling to production-level training requirements.** Prior methods for quantization-friendly pre-training are often limited to small-scale experiments (e.g., models under 1B parameters or 100B tokens). In contrast, we train a 1.4B-parameter model on 1 trillion tokens, demonstrating that OSP is effective at production scale.
- **Maintaining computational efficiency comparable to standard training.** A method that prevents outliers but significantly reduces efficiency is unlikely to gain adoption. OSP introduces only a ~2% slowdown while reducing GPU memory usage, making it appealing for those seeking to train quantization-friendly foundation models from scratch.
- **Ensuring full compatibility with existing inference pipelines.** We prioritize compatibility with widely adopted inference frameworks such as vLLM and SGLang. Rather than introducing architectural changes that break compatibility, OSP preserves computational invariance, allowing models to be directly integrated into existing pipelines without additional effort.
- 2025-06-25: Released *Outlier-Safe Pre-Training for Robust 4-Bit Quantization of Large Language Models* on arXiv, together with the GitHub code and model checkpoints.
- 2025-05-16: Our paper has been accepted to ACL 2025!
The models were trained on 1 trillion tokens, following the pre-training recipe of SmolLM. Specifically, training was conducted using the smollm-corpus, a mixture of FineWeb-Edu, Cosmopedia, and Python-Edu.
- 🤗 OSP-1.4B-1T-Adam: Trained with the standard Adam optimizer, without any modifications.
- 🤗 OSP-1.4B-1T-Muon-SSNorm-EmbProj: Trained with the full OSP framework. This is our final model.
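Because OSP preserves computational invariance, the released checkpoints are intended to work with standard tooling. As a rough sketch (not taken from the repository, and with a placeholder model id), loading through the usual `transformers` interface should look like:

```python
# Minimal loading sketch; the repository id below is a placeholder, not the real one.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "<org>/OSP-1.4B-1T-Muon-SSNorm-EmbProj"   # placeholder Hugging Face id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Outlier-Safe Pre-Training", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```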
| Model | Optimizer | SSNorm | EmbProj | Ex. Kurt. | 4-4-4 Avg. (w/o → w/ Had.) | 4-4-4 PPL (w/o → w/ Had.) |
|---|---|---|---|---|---|---|
| 🤗 OSP-1.4B-100B-Adam | Adam | ❌ | ❌ | 1818.56 | 26.8 → 26.9 | 8e4 → 3e4 |
| 🤗 OSP-1.4B-100B-Muon-Only | Muon (w/o Adam) | ❌ | ❌ | 361.35 | 26.3 → 33.1 | 8e5 → 24.8 |
| 🤗 OSP-1.4B-100B-Muon | Muon | ❌ | ❌ | 1575.12 | 29.0 → 38.4 | 1e4 → 15.8 |
| 🤗 OSP-1.4B-100B-Muon-SSNorm | Muon | ✅ | ❌ | 66.69 | 36.4 → 38.3 | 44.2 → 34.1 |
| 🤗 OSP-1.4B-100B-Muon-EmbProj | Muon | ❌ | ✅ | 703.23 | 30.4 → 36.2 | 114.6 → 22.3 |
| 🤗 OSP-1.4B-100B-Muon-SSNorm-EmbProj | Muon | ✅ | ✅ | 0.04 | 37.5 → 38.9 | 19.6 → 13.5 |

4-4-4 denotes 4-bit weight, activation, and key-value quantization; each cell shows results without → with Hadamard rotation (Had.). Ex. Kurt. is excess kurtosis.
To begin, install the libraries needed to evaluate the quantized model's performance. We recommend creating a conda environment using:
$ conda env create -f environment.yaml
We support evaluating WikiText-2 under various quantization settings:
- Round-to-nearest (RTN): This method reduces the bitwidth by rounding scaled values to the nearest quantization level. To evaluate performance, run:
$ bash scripts/1_eval_ppl_rtn.sh [model_path] 4 4 4
- FFN Hadamard Rotation: Based on QuaRot, this method applies online Hadamard rotation within the feedforward (FFN) layers.
$ bash scripts/2_eval_ppl_rtn_had.sh [model_path] 4 4 4
- GPTQ: Uses GPTQ for weight quantization to achieve further improvements. FFN rotation is also applied.
$ bash scripts/3_eval_ppl_gptq_had.sh [model_path] 4 4 4
- QuaRot: A full replication of QuaRot, applying Hadamard rotation to all weight parameters while preserving computational invariance.
$ bash scripts/4_eval_ppl_quarot.sh [model_path] 4 4 4
The three integers (e.g., `4 4 4`) represent the number of bits used for weight, activation, and key-value quantization, respectively. Based on this setup, you can replicate the following table:
| Quantization | Adam | Muon (OSP) |
|---|---|---|
| RTN | 14475.51 | 45.92 |
| + FFN Had.⚡ | 4794.00 | 19.27 |
| + GPTQ | 3723.46 | 14.29 |
| + QuaRot | 16.62 | 14.38 |
| + SpinQuant | 14.94 | 13.66 |

⚡ Only applies the Hadamard transform to FFN hidden states.
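For intuition about why the Hadamard rows improve so sharply over plain RTN, here is a minimal NumPy sketch (illustrative only, not the repository's QuaRot/SpinQuant kernels): an activation matrix with a single outlier channel quantizes poorly under round-to-nearest, but much better once an orthogonal Hadamard rotation spreads the outlier energy across channels.

```python
import numpy as np
from scipy.linalg import hadamard

def rtn_quantize(x: np.ndarray, bits: int = 4) -> np.ndarray:
    """Symmetric per-tensor round-to-nearest quantization (illustration only)."""
    qmax = 2 ** (bits - 1) - 1                      # 7 for signed 4-bit
    scale = np.abs(x).max() / qmax                  # one scale for the whole tensor
    return np.clip(np.round(x / scale), -qmax - 1, qmax) * scale

rng = np.random.default_rng(0)
x = rng.standard_normal((128, 64))
x[:, 3] *= 50.0                                     # a single outlier channel

# Plain RTN: the outlier channel inflates the scale, crushing normal channels.
err_rtn = np.abs(x - rtn_quantize(x)).mean()

# A Hadamard rotation is orthogonal (computation is preserved after rotating back)
# and spreads the outlier energy across all channels before quantization.
H = hadamard(64) / np.sqrt(64)
err_had = np.abs(x - rtn_quantize(x @ H) @ H.T).mean()

print(f"RTN error: {err_rtn:.3f} | RTN + Hadamard error: {err_had:.3f}")
```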
To further assess model performance on various tasks, including multiple-choice and open-ended questions, we integrate lighteval into our quantization framework. We evaluate on 10 benchmarks, as used in the paper:
- ARC
- CommonsenseQA
- GSM8K
- HellaSwag
- MMLU
- OpenBookQA
- PIQA
- SIQA
- TriviaQA
- WinoGrande
For the full implementation details and task definitions, refer to lighteval_tasks.py and tasks.txt. You can customize the tasks by following this tutorial. To run the evaluation pipeline, use the following command:
$ python lighteval_ptq.py \
[model_path] \
--w_bits 4 \
--a_bits 4 \
--kv_bits 4 \
--tasks tasks_tmp.txt
You can add the following arguments to enable additional configurations: FFN Hadamard Rotation (`--rotate_down_proj`), GPTQ (`--no_rtn`), and QuaRot (`--no_rtn --rotate`). The `--rotate` flag applies Hadamard rotation to all weight parameters, while `--rotate_down_proj` enables online rotation within the FFN layers only.
For faster evaluation, we also provide a lighteval pipeline based on vLLM:
$ python lighteval_ptq_vllm.py \
[model_path] \
--w_bits 4 \
--a_bits 4 \
--kv_bits 4 \
--tasks tasks_tmp.txt
Note: Key-value cache quantization is not supported in the vLLM pipeline.
As described in the paper, we implement distributed Muon to accelerate Newton-Schulz orthogonalization using parallel computation. The full implementation can be found in optimization.py, and can be used as follows:
from optimization import muon
tx = muon(
learning_rate=5e-4,
beta=0.95,
steps=5, # Number of Newton-Schulz iterations
eps=1e-8,
weight_decay=1e-2,
)
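Assuming `muon` returns an optax-style `GradientTransformation` (suggested by the `tx` naming and the JAX stack, though not confirmed here), it plugs into the usual init/update loop:

```python
# Sketch only: assumes `tx` above follows the optax GradientTransformation API.
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((256, 256))}        # toy parameter pytree
grads = {"w": jnp.full((256, 256), 0.1)}    # stand-in for real gradients

opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```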
This is the SPMD (Single Program, Multiple Data) version of the Muon optimizer. The device mesh should follow the shape (`dp`, `op`, `fsdp`), representing data parallelism, optimizer parallelism, and fully sharded data parallelism, respectively. Note that `op` is orthogonal to `fsdp`, and we recommend setting `op` to evenly divide the number of transformer layers. This will be discussed in more detail later.
For example, on a TPU v4-512 Pod slice, the device mesh can be defined as:
import jax, numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh

mesh = np.arange(jax.device_count()).reshape(2, 8, 16)
mesh = mesh_utils.create_device_mesh(mesh.shape, allow_split_physical_axes=True)
Mesh(mesh, axis_names=("dp", "op", "fsdp")).__enter__()
In this configuration:
- The model is replicated across 2 data-parallel groups (`dp`).
- Each group shards model parameters and optimizer states across 16 devices (`fsdp`).
- Orthogonalization is performed in parallel across 8 devices (`op`).
The implementation has two key components:
- Gradient grouping by shape for batched orthogonalization.
- Sharding over the `op` axis to parallelize Newton-Schulz steps.
# Gather gradients with the same shape so they can be orthogonalized together
# and then reassigned to their corresponding parameters.
new_updates, grad_groups, name_groups = {}, defaultdict(list), defaultdict(list)
for name, grad in flatten_dict(updates).items():
    if isinstance(grad, jax.Array):
        grad_groups[grad.shape].append(grad)
        name_groups[grad.shape].append(name)
    else:
        new_updates[name] = grad
Once the gradients are grouped by shape, they are stacked and explicitly sharded across devices:
# If gradients can be distributed across devices based on the optimizer
# parallelism rank, they will be stacked and orthogonalized in a single
# operation. Otherwise, they will be normalized individually.
for shape, grads in grad_groups.items():
    print(f"[*] Muon Parallelism: {len(grads)} x {shape}")
    if len(shape) == 2 and (chunk_size := len(grads) // op_rank * op_rank) > 0:
        chunks, grads = jnp.stack(grads[:chunk_size]), grads[chunk_size:]
        chunks = jax.lax.with_sharding_constraint(chunks, P("op", "fsdp"))
        grads = list(batch_ortho_fn(chunks)) + list(map(batch_ortho_fn, grads))
    else:
        grads = list(map(batch_ortho_fn, grads))
    for name, grad in zip(name_groups[shape], grads):
        if len(grad.shape) > 0:
            grad = jax.lax.with_sharding_constraint(grad, P("fsdp"))
        new_updates[name] = grad
If the number of stacked gradients is not divisible by `op`, the mismatched gradients are orthogonalized by iterating over them individually. This design enables Muon to leverage optimizer parallelism effectively, achieving up to 97.9% of the speed of standard Adam.
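To make the chunking rule concrete (the layer count below is illustrative, not the actual model configuration): with `op = 8`, only a multiple of 8 same-shape gradients can be stacked and sharded, and the remainder falls back to per-gradient orthogonalization, which is why choosing `op` to divide the number of transformer layers keeps every gradient on the fast path.

```python
# Mirrors chunk_size = len(grads) // op_rank * op_rank from the snippet above.
op_rank = 8        # size of the "op" mesh axis
num_grads = 28     # e.g. one same-shape weight gradient per transformer layer
chunk_size = num_grads // op_rank * op_rank
print(chunk_size, num_grads - chunk_size)   # 24 stacked and sharded, 4 handled one by one
```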
@article{park2025osp,
title={Outlier-Safe Pre-Training for Robust 4-Bit Quantization of Large Language Models},
author={Jungwoo Park and Taewhoo Lee and Chanwoong Yoon and Hyeon Hwang and Jaewoo Kang},
year={2025},
eprint={2506.19697},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2506.19697},
}