
Outlier-Safe Pre-Training


This repository contains the evaluation code used in Outlier-Safe Pre-Training for Robust 4-Bit Quantization of Large Language Models. The codebase builds heavily on QuaRot and SpinQuant, adapted to cover the quantization scenarios evaluated in the paper.

Introduction

Quantization plays a crucial role in deploying Large Language Models (LLMs) in resource-constrained environments. However, the presence of outlier features significantly hinders low-bit quantization. While many studies address this problem in a post-hoc manner to make use of already pre-trained models, the importance of handling outliers during pre-training is often underestimated.

Our work, Outlier-Safe Pre-Training (OSP), proposes a practical approach to training models that are robust to outliers from the start, without sacrificing performance or efficiency. Specifically, OSP focuses on the following goals:

  1. 📈Scaling to production-level training requirements
    Prior methods for quantization-friendly pre-training are often limited to small-scale experiments (e.g., models under 1B parameters or 100B tokens). In contrast, we train a 1.4B-parameter model on 1 trillion tokens, demonstrating that OSP is effective at production scale.

  2. ⚑Maintaining computational efficiency comparable to standard training
    A method that prevents outliers but significantly reduces efficiency is unlikely to gain adoption. OSP introduces only a ~2% slowdown while reducing GPU memory usage, making it appealing for those seeking to train quantization-friendly foundation models from scratch.

  3. 🧩Ensuring full compatibility with existing inference pipelines
    We prioritize compatibility with widely adopted inference frameworks such as vLLM and SGLang. Rather than introducing architectural changes that break compatibility, OSP preserves computational invariance, allowing models to be directly integrated into existing pipelines without additional effort.


News

  • 2025-06-25: Released Outlier-Safe Pre-Training for Robust 4-Bit Quantization of Large Language Models on arXiv, together with the code and model checkpoints.
  • 2025-05-16: Our paper has been accepted to ACL 2025! 🎉

Model Checkpoints

Final Models

The models were trained on 1 trillion tokens, following the pre-training recipe of SmolLM. Specifically, training was conducted using the smollm-corpus, a mixture of FineWeb-Edu, Cosmopedia, and Python-Edu.

Ablation Models

| Model | Optimizer | SSNorm | EmbProj | Ex. Kurt. | 4-4-4 Avg. (w/o Had. / w/ Had.) | PPL (w/o Had. / w/ Had.) |
|---|---|---|---|---|---|---|
| 🤗 OSP-1.4B-100B-Adam | Adam | ✗ | ✗ | 1818.56 | 26.8 / 26.9 | 8e4 / 3e4 |
| 🤗 OSP-1.4B-100B-Muon-Only | Muon† (w/o Adam) | ✗ | ✗ | 361.35 | 26.3 / 33.1 | 8e5 / 24.8 |
| 🤗 OSP-1.4B-100B-Muon | Muon | ✗ | ✗ | 1575.12 | 29.0 / 38.4 | 1e4 / 15.8 |
| 🤗 OSP-1.4B-100B-Muon-SSNorm | Muon | ✔ | ✗ | 66.69 | 36.4 / 38.3 | 44.2 / 34.1 |
| 🤗 OSP-1.4B-100B-Muon-EmbProj | Muon | ✗ | ✔ | 703.23 | 30.4 / 36.2 | 114.6 / 22.3 |
| 🤗 OSP-1.4B-100B-Muon-SSNorm-EmbProj | Muon | ✔ | ✔ | 0.04 | 37.5 / 38.9 | 19.6 / 13.5 |
†Trained with the Muon optimizer on all parameters, i.e., without the decoupled Adam optimization of the embedding layers.
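
The Ex. Kurt. column reports excess kurtosis, a standard measure of how heavy-tailed (outlier-prone) a distribution is; here it is used to quantify activation outliers. As a point of reference, the sketch below shows one common way to compute it; the exact aggregation used for the table may differ.

import jax.numpy as jnp

def excess_kurtosis(x):
    # Fourth standardized moment minus 3 (the Gaussian baseline); large values
    # indicate heavy-tailed, outlier-prone distributions.
    z = (x - jnp.mean(x)) / jnp.std(x)
    return jnp.mean(z ** 4) - 3.0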

Getting Started

Environment Setup

To begin, install the libraries needed to evaluate the quantized model's performance. We recommend creating a conda environment using:

$ conda env create -f environment.yaml

Evaluation on WikiText-2

We support evaluating WikiText-2 under various quantization settings:

  1. Round-to-nearest (RTN): This method reduces the bitwidth by rounding normalized values to the nearest level of the quantization grid (see the sketch after the results table below). To evaluate performance, run:
    $ bash scripts/1_eval_ppl_rtn.sh [model_path] 4 4 4
  2. FFN Hadamard Rotation: Based on QuaRot, this method applies online Hadamard rotation within the feedforward (FFN) layers.
    $ bash scripts/2_eval_ppl_rtn_had.sh [model_path] 4 4 4
  3. GPTQ: Uses GPTQ for weight quantization to achieve further improvements. FFN rotation is also applied.
    $ bash scripts/3_eval_ppl_gptq_had.sh [model_path] 4 4 4
  4. QuaRot: A full replication of QuaRot, applying Hadamard rotation to all weight parameters while preserving computational equivalence.
    $ bash scripts/4_eval_ppl_quarot.sh [model_path] 4 4 4

The three integers (e.g. 4 4 4) represent the number of bits used for weight, activation, and key-value quantization, respectively. Based on this setup, you can replicate the following table:

| Quantization | Adam | Muon (OSP) |
|---|---|---|
| RTN | 14475.51 | 45.92 |
| + FFN Had.* | 4794.00 | 19.27 |
| + GPTQ | 3723.46 | 14.29 |
| + QuaRot | 16.62 | 14.38 |
| + SpinQuant | 14.94 | 13.66 |

*Applies the Hadamard transform only to the FFN hidden states.
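
For intuition on the RTN baseline referenced above, the following is a minimal, illustrative sketch of symmetric per-tensor round-to-nearest quantization; the evaluation code in this repository (adapted from QuaRot) handles more cases than this sketch.

import jax.numpy as jnp

def rtn_quantize(x, bits=4):
    # Symmetric round-to-nearest: map values onto a signed (2^bits)-level integer
    # grid, round, clamp, and dequantize with the same per-tensor scale.
    qmax = 2 ** (bits - 1) - 1
    scale = jnp.max(jnp.abs(x)) / qmax
    return jnp.clip(jnp.round(x / scale), -qmax - 1, qmax) * scale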

Evaluation on Lighteval Benchmarks

To further assess model performance on various tasks, including multiple-choice and open-ended questions, we integrate lighteval into our quantization framework. We evaluate on 10 benchmarks, as used in the paper:

  • ARC
  • CommonsenseQA
  • GSM8K
  • HellaSwag
  • MMLU
  • OpenBookQA
  • PIQA
  • SIQA
  • TriviaQA
  • WinoGrande

For the full implementation details and task definitions, refer to lighteval_tasks.py and tasks.txt. You can customize the tasks by following this tutorial. To run the evaluation pipeline, use the following command:

$ python lighteval_ptq.py \
    [model_path] \
    --w_bits 4 \
    --a_bits 4 \
    --kv_bits 4 \
    --tasks tasks_tmp.txt

You can add the following arguments to enable additional configurations: FFN Hadamard Rotation (--rotate_down_proj), GPTQ (--no_rtn), and QuaRot (--no_rtn --rotate). The --rotate flag applies Hadamard rotation to all weight parameters, while --rotate_down_proj enables online rotation within the FFN layers only.
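
These rotations can be applied without retraining because of computational invariance: multiplying activations and weights by the same orthogonal matrix leaves the layer output unchanged. The snippet below is a small illustrative check of this identity using a normalized Hadamard matrix; it is a sketch, not the repository's rotation code.

import jax.numpy as jnp
from scipy.linalg import hadamard

d = 8
H = jnp.asarray(hadamard(d), dtype=jnp.float32) / jnp.sqrt(d)  # orthogonal: H @ H.T == I
x = jnp.ones((2, d))                                           # stand-in activations
W = jnp.arange(d * d, dtype=jnp.float32).reshape(d, d)         # stand-in weights

# Rotating activations by H and weights by H.T cancels out: (x H)(H^T W) = x W.
assert jnp.allclose(x @ W, (x @ H) @ (H.T @ W), atol=1e-4)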

Using vLLM

For faster evaluation, we also provide a lighteval pipeline based on vLLM:

$ python lighteval_ptq_vllm.py \
    [model_path] \
    --w_bits 4 \
    --a_bits 4 \
    --kv_bits 4 \
    --tasks tasks_tmp.txt

Note: Key-value cache quantization is not supported in the vLLM pipeline.

Distributed Muon on TPUs

As described in the paper, we implement distributed Muon to accelerate Newton-Schulz orthogonalization using parallel computation. The full implementation can be found in optimization.py, and can be used as follows:

from optimization import muon

tx = muon(
    learning_rate=5e-4,
    beta=0.95,
    steps=5,  # Number of Newton-Schulz iterations
    eps=1e-8,
    weight_decay=1e-2,
)
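
Assuming muon returns an optax-style gradient transformation (as the tx = ... pattern suggests), it can be dropped into the usual init/update loop; the parameter and gradient pytrees below are placeholders:

import jax.numpy as jnp
import optax

params = {"w": jnp.zeros((512, 512))}  # placeholder parameter pytree
grads = {"w": jnp.ones((512, 512))}    # placeholder gradient pytree

opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)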

This is the SPMD (Single Program, Multiple Data) version of the Muon optimizer. The device mesh should follow the shape (dp, op, fsdp), representing data parallelism, optimizer parallelism, and fully sharded data parallelism, respectively. Note that op is orthogonal to fsdp, and we recommend setting op to evenly divide the number of transformer layers. This will be discussed in more detail later.

For example, on a TPU v4-512 Pod slice, the device mesh can be defined as:

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh

mesh = np.arange(jax.device_count()).reshape(2, 8, 16)
mesh = mesh_utils.create_device_mesh(mesh.shape, allow_split_physical_axes=True)
Mesh(mesh, axis_names=("dp", "op", "fsdp")).__enter__()

In this configuration:

  • The model is replicated across 2 data-parallel groups (dp).
  • Each group shards model parameters and optimizer states across 16 devices (fsdp).
  • Orthogonalization is performed in parallel across 8 devices (op).

How it works

The implementation has two key components:

  1. Gradient grouping by shape for batched orthogonalization.
  2. Sharding over the op axis to parallelize Newton-Schulz steps.
# Gather gradients with the same shape so they can be orthogonalized together
# and then reassigned to their corresponding parameters.
new_updates, grad_groups, name_groups = {}, defaultdict(list), defaultdict(list)
for name, grad in flatten_dict(updates).items():
    if isinstance(grad, jax.Array):
        grad_groups[grad.shape].append(grad)
        name_groups[grad.shape].append(name)
    else:
        new_updates[name] = grad

Once the gradients are grouped by shape, they are stacked and explicitly sharded across devices via sharding constraints:

# If gradients can be distributed across devices based on the optimizer
# parallelism rank, they will be stacked and orthogonalized in a single
# operation. Otherwise, they will be normalized individually.
for shape, grads in grad_groups.items():
    print(f"[*] Muon Parallelsim: {len(grads)} x {shape}")
    if len(shape) == 2 and (chunk_size := len(grads) // op_rank * op_rank) > 0:
        chunks, grads = jnp.stack(grads[:chunk_size]), grads[chunk_size:]
        chunks = jax.lax.with_sharding_constraint(chunks, P("op", "fsdp"))
        grads = list(batch_ortho_fn(chunks)) + list(map(batch_ortho_fn, grads))
    else:
        grads = list(map(batch_ortho_fn, grads))
    for name, grad in zip(name_groups[shape], grads):
        if len(grad.shape) > 0:
            grad = jax.lax.with_sharding_constraint(grad, P("fsdp"))
        new_updates[name] = grad

If the stacked shape is not divisible by op, the mismatched gradients are orthogonalized by iterating manually. This design enables Muon to leverage optimizer parallelism effectively, achieving up to 97.9% of the speed of standard Adam.
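
For reference, the orthogonalization step itself (batch_ortho_fn above) is based on the Newton-Schulz iteration. The following is a minimal sketch of the quintic Newton-Schulz iteration used in the public Muon reference implementation; the coefficients and details may differ from those in optimization.py.

import jax.numpy as jnp

def newton_schulz(G, steps=5, eps=1e-8):
    # Quintic Newton-Schulz iteration that drives the singular values of G
    # toward 1, approximating the nearest semi-orthogonal matrix.
    a, b, c = 3.4445, -4.7750, 2.0315   # coefficients from the public Muon reference
    X = G / (jnp.linalg.norm(G) + eps)  # normalize so the iteration converges
    if G.shape[0] > G.shape[1]:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if G.shape[0] > G.shape[1] else X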

Citation

@article{park2025osp,
  title={Outlier-Safe Pre-Training for Robust 4-Bit Quantization of Large Language Models},
  author={Jungwoo Park and Taewhoo Lee and Chanwoong Yoon and Hyeon Hwang and Jaewoo Kang},
  year={2025},
  eprint={2506.19697},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2506.19697},
}
