Commit 892eb42

merveenoyan, kashif, and sergiopaniego authored
TRL multimodal update blog (#3005)
--------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
1 parent e600cfd commit 892eb42

File tree: 3 files changed (+238, -1 lines)

_blog.yml

Lines changed: 13 additions & 1 deletion

@@ -6473,6 +6473,7 @@
   - community
   - open-source
+
 - local: welcome-openai-gpt-oss
   title: "Welcome GPT OSS, the new open-source model family from OpenAI!"
   author: reach-vb
@@ -6484,4 +6485,15 @@
   - gpt-oss
   - open-source
   - LLM
-  - community
+  - community
+
+- local: trl-vlm-alignment
+  title: "Vision Language Model Alignment in TRL ⚡️"
+  author: qgallouedec
+  thumbnail: /blog/assets/trl_vlm/thumbnail.png
+  date: Aug 7, 2025
+  tags:
+  - trl
+  - vlm
+  - vision

assets/trl_vlm/thumbnail.png

316 KB

trl-vlm-alignment.md

Lines changed: 225 additions & 0 deletions

---
title: "Vision Language Model Alignment in TRL ⚡️"
thumbnail: assets/trl_vlm/thumbnail.png
authors:
- user: qgallouedec
- user: kashif
- user: sergiopaniego
- user: merve
- user: ariG23498
---

# Vision Language Model Alignment in TRL ⚡️

## Introduction

Vision Language Models (VLMs) are getting stronger, but *aligning* them to human preferences still matters. In TRL, we already showed how to post-train VLMs with [**Supervised Fine-Tuning (SFT)**](https://huggingface.co/docs/trl/main/en/training_vlm_sft) and [**Direct Preference Optimization (DPO)**](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct). This time, we're going further.

**tl;dr** We have added new multimodal alignment methods to TRL: **Group Relative Policy Optimization (GRPO)**, its variant **Group Sequence Policy Optimization (GSPO)**, and **Mixed Preference Optimization (MPO)**. All of them let you go beyond pairwise DPO, extracting more signal from preference data and scaling better with modern VLMs. We release training scripts and demo notebooks so you can easily get started with them!

## Table of Contents

- [Vision Language Model Alignment in TRL ⚡️](#vision-language-model-alignment-in-trl-️)
- [Introduction](#introduction)
- [Alignment for Vision Language Models](#alignment-for-vision-language-models)
- [Mixed Preference Optimization (MPO)](#mixed-preference-optimization-mpo)
- [Multimodal Group Relative Policy Optimization (GRPO)](#multimodal-group-relative-policy-optimization-grpo)
- [Group Sequence Policy Optimization (GSPO)](#group-sequence-policy-optimization-gspo)
- [Comparison](#comparison)
- [vLLM Integration in TRL](#vllm-integration-in-trl)
- [Useful Resources](#useful-resources)

## Alignment for Vision Language Models

Traditionally, you would take a base model, apply SFT so that it follows instructions, and then apply DPO to align it with preference data. Previously, [we adapted this approach to Vision Language Models (VLMs)](https://huggingface.co/blog/dpo_vlm) and validated it on IDEFICS2, showing improvement in model responses.

DPO optimizes preferences between pairs of model responses with a contrastive loss: given a chosen and a rejected answer to the same prompt, it increases the likelihood of the chosen answer while decreasing that of the rejected one.
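
As a refresher, the objective being minimized is the standard one from the DPO paper (shown here for context, this is not TRL-specific notation):

$$
\mathcal{L}_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x,\, y_w,\, y_l)}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]
$$

where \\(y_w\\) is the chosen answer, \\(y_l\\) the rejected one, \\(\pi_{\mathrm{ref}}\\) a frozen reference model, and \\(\beta\\) controls how far the policy may drift from the reference; for a VLM, the prompt \\(x\\) simply includes the image.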

But over the last year, new multimodal alignment methods such as GRPO and MPO have gained popularity and can push VLM performance even further. At the end of this blog post you can find a table that showcases the differences between model responses.

### Mixed Preference Optimization (MPO)

Aligning multimodal models for reasoning tasks with SFT alone falls short due to distribution shift, while models aligned with DPO can fail to generate coherent rationales and may produce repetitive responses. To address this, there is a technique called [Mixed Preference Optimization](https://huggingface.co/papers/2411.10442) (MPO), made specifically for multimodal models. It is essentially an extension of DPO with multiple losses: the preference loss from DPO (sigmoid), the quality loss from Binary Classifier Optimization (BCO), and the generation loss from SFT. According to the [paper](https://huggingface.co/papers/2411.10442), simply switching to this combined loss yields a 6.2-point improvement on MathVista!
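
In other words, the MPO objective is a weighted sum of the three terms, with the weights below matching the `loss_weights` used in the config that follows:

$$
\mathcal{L}_{\mathrm{MPO}} = w_p\,\mathcal{L}_{\mathrm{preference}} + w_q\,\mathcal{L}_{\mathrm{quality}} + w_g\,\mathcal{L}_{\mathrm{generation}}, \qquad (w_p,\, w_q,\, w_g) = (0.8,\, 0.2,\, 1.0)
$$

where the preference term is the DPO sigmoid loss, the quality term comes from BCO, and the generation term is the plain SFT loss.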

![MPO](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/trl-vlm/image_1.png)

Since this only modifies the loss, we added combined-loss support to TRL's `DPOTrainer` class. To use it, you can initialize the `DPOConfig` as follows:

```python
mpo_config = DPOConfig(
    output_dir=tmp_dir,                        # directory where checkpoints will be written
    per_device_train_batch_size=2,
    learning_rate=5e-6,                        # use a small learning rate for preference training
    loss_type=["sigmoid", "bco_pair", "sft"],  # loss types to combine, as used in the MPO paper
    loss_weights=[0.8, 0.2, 1.0],              # corresponding weights, as used in the MPO paper
    report_to="none",
    bf16=False,
    fp16=False,
    use_cpu=True,  # these flags keep the snippet runnable on CPU for a quick smoke test;
    max_steps=1,   # drop them (and raise max_steps) for a real training run
)
```
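
The config above and the trainer below assume a few names (`tmp_dir`, `model_id`, `processor`, `dataset`) are already defined. A minimal sketch of that setup, with a placeholder model and dataset chosen for illustration rather than prescribed by this post, could look like:

```python
from datasets import load_dataset
from transformers import AutoProcessor
from trl import DPOConfig, DPOTrainer  # imports used by the snippets in this section

tmp_dir = "mpo-vlm-output"                 # directory where checkpoints will be written
model_id = "Qwen/Qwen2.5-VL-3B-Instruct"   # any VLM supported by DPOTrainer
processor = AutoProcessor.from_pretrained(model_id)

# A preference dataset containing an image plus chosen/rejected answers per prompt,
# e.g. the formatted RLAIF-V subset used in TRL's VLM DPO examples
dataset = load_dataset("HuggingFaceH4/rlaif-v_formatted", split="train[:1%]")
```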

Then initialize the `DPOTrainer`:

```python
mpo_trainer = DPOTrainer(
    model=model_id,              # model name or path of the VLM to fine-tune
    args=mpo_config,
    processing_class=processor,  # for a VLM, pass the model's processor rather than a plain tokenizer
    train_dataset=dataset,
)
mpo_trainer.train()
```

And that's it! If you want to explore further, you can find a complete notebook example [here](https://huggingface.co/learn/cookbook/fine_tuning_vlm_mpo).

### Multimodal Group Relative Policy Optimization (GRPO)

Group Relative Policy Optimization (GRPO) is a cutting-edge alignment method initially introduced in the [DeepSeek Math](https://huggingface.co/papers/2402.03300) paper and later integrated into DeepSeek R1, the groundbreaking LLM. It builds on PPO, with policy updates computed over groups (batches of trajectories that represent how a dialogue rolls out). This makes the method more robust to reward noise, as the noise averages out within groups, and because the model learns a broader sense of what a good response looks like rather than chasing singular high-reward samples, it also tends to yield highly performant models.
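
Concretely, for every prompt the trainer samples a group of \\(G\\) completions, scores each of them with the reward function(s), and converts the rewards into group-relative advantages, as in the GRPO paper:

$$
\hat{A}_i = \frac{r_i - \operatorname{mean}(\{r_1, \dots, r_G\})}{\operatorname{std}(\{r_1, \dots, r_G\})}
$$

so a completion is only reinforced if it is better than its siblings from the same prompt, which is what averages out reward noise within the group.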

![GRPO](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/trl-vlm/image_2.png)

In TRL, we now introduce GRPO support for vision language models. We will not provide a full training script example, as you can find it in the notebook; instead, we'll focus on highlighting the key components and concepts.

To make the training script work effectively, we need to check both that the answer follows the expected format and that the proposed solution matches the ground truth, so we write two reward functions. To really see improvements from the latter reward, you would need a rather maximalist setup, with relatively larger models, a lot of generations, and a high-quality, diverse dataset.

```python
import re

from math_verify import LatexExtractionConfig, parse, verify


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


def accuracy_reward(completions, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    solutions = kwargs["solution"]
    completion_contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, solution in zip(completion_contents, solutions):
        gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        answer_parsed = parse(content, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        if len(gold_parsed) != 0:
            # binary reward: 1.0 if the parsed answer matches the parsed gold solution
            try:
                rewards.append(float(verify(answer_parsed, gold_parsed)))
            except Exception:
                rewards.append(0.0)
        else:
            # if the gold solution cannot be parsed, do not penalize the model
            rewards.append(1.0)
    return rewards
```

Then, you can initialize `GRPOConfig` and `GRPOTrainer`, pass in the reward functions we defined above, and call `train()` to start training.

```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    learning_rate=1e-5,
    remove_unused_columns=False,  # keep extra dataset columns (e.g. image, solution) available during training
    max_prompt_length=None,
    # ... set other params of choice here
)
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[format_reward, accuracy_reward],
    args=training_args,
    train_dataset=train_dataset,
    processing_class=processor,
)
trainer.train()
```
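
The snippet above assumes `model`, `processor`, and `train_dataset` are already defined, with prompts in conversational form that include the image and elicit the `<think>…</think><answer>…</answer>` format checked by `format_reward`. A rough sketch of that preprocessing, where the model, dataset name, and column names are illustrative assumptions:

```python
from datasets import load_dataset
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype="bfloat16")
processor = AutoProcessor.from_pretrained(model_id)

SYSTEM_PROMPT = (
    "Reason about the problem between <think> and </think> tags, "
    "then give the final answer between <answer> and </answer> tags."
)

def make_conversation(example):
    # Build a conversational prompt with the image and the question;
    # the "solution" column is left untouched so accuracy_reward can read it via **kwargs.
    return {
        "prompt": [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": example["problem"]},
                ],
            },
        ]
    }

# Hypothetical multimodal math dataset with "image", "problem", and "solution" columns
train_dataset = load_dataset("lmms-lab/multimodal-open-r1-8k-verified", split="train")
train_dataset = train_dataset.map(make_conversation)
```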

Explore the full notebook example [here](https://huggingface.co/learn/cookbook/fine_tuning_vlm_grpo_trl).

### Group Sequence Policy Optimization (GSPO)

[Group Sequence Policy Optimization](https://huggingface.co/papers/2507.18071) (GSPO) is an RL alignment algorithm recently released by Qwen that overcomes some limitations of GRPO. It achieves more stable training by computing importance-sampling weights at the sequence level instead of per token. Its benefits are particularly [relevant](https://github.com/volcengine/verl/pull/2775#issuecomment-3134375131) for MoE-style models.
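
The key change is the importance ratio itself: instead of GRPO's per-token ratios, GSPO clips a length-normalized, sequence-level ratio (following the definition in the GSPO paper):

$$
s_i(\theta) = \left(\frac{\pi_\theta(y_i \mid x)}{\pi_{\theta_{\mathrm{old}}}(y_i \mid x)}\right)^{1/|y_i|} = \exp\!\left(\frac{1}{|y_i|}\sum_{t=1}^{|y_i|} \log \frac{\pi_\theta(y_{i,t} \mid x, y_{i,<t})}{\pi_{\theta_{\mathrm{old}}}(y_{i,t} \mid x, y_{i,<t})}\right)
$$

which is then combined with the same group-relative advantages as in GRPO.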

The latest TRL also introduces support for GSPO, and since it is a variant of GRPO's loss, it comes with multimodal support. To create the trainer, the process is the same as with GRPO, except for the following extra params (values are taken from the paper).

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    # ... same base params as in the GRPO example above
    importance_sampling_level="sequence",
    epsilon=3e-4,
    epsilon_high=4e-4,
    beta=0.0,
    loss_type="grpo",
    steps_per_generation=4,
)
```

### Comparison

Here's a table summarizing model outputs for Qwen2.5VL-3B fine-tuned with the techniques discussed above. Note that we've done minimal runs on dataset subsets, and the models were fine-tuned on different datasets, so the comparison is meant as a vibe check rather than a rigorous evaluation.

<details><summary>Check out the answers here</summary>

**Image:** ![Image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/vlm_math_question.png)

**Question:** In the given figure, if the angle APO is 25.0 degrees, what is the measure of angle AOB? Choices: A: 140° B: 130° C: 120° D: 110°

**Answer in the Dataset:** To find the measure of angle AOB, we can use the property that the sum of the interior angles of a quadrilateral is equal to 360 degrees. Applying this property, we can calculate angle AOB as 360 - angle OAP - angle OBP - angle P. Since angle OAP and angle OBP are both 90 degrees, and angle P is given as 50 degrees, we can substitute these values into the equation to get angle AOB = 360 - 90 - 90 - 50 = 130 degrees. Therefore, the measure of angle AOB is 130 degrees. Therefore, the answer is option B. Answer:B

**Base Model Output:** To find the measure of angle AOB, we need to use the properties of angles subtended by the same arc in a circle. Specifically, the angle subtended by an arc at the center of the circle (angle AOB) is twice the angle subtended by the same arc at any point on the circumference (angle APO).Given that angle APO = 25°, we can calculate angle AOB as follows: Since angle AOB is twice angle APO, we have:[ Angle AOB = 2 * Angle APO ][ Angle AOB = 2 * 25° ][ Angle AOB = 50° ]However, this calculation seems incorrect based on the provided choices. Let's re-evaluate using the correct property.The correct property states that the angle subtended by an arc at the center of the circle is twice the angle subtended by the same arc at any point on the circumference. Therefore, the correct calculation should be: Angle AOB = 2 * Angle APO = 2 * 25° = 50°But since this does not match the given choices, let's consider another approach. If angle APO is 25°, and considering the properties of cyclic quadrilaterals or the fact that the sum of opposite angles in a cyclic quadrilateral is 180°, we might need to look at the supplementary angles formed by the arc and the chord.Let's assume the correct interpretation involves finding the supplementary angle to the given angle APO which would give us the measure of angle AOB. Angle AOB = 180° - Angle APO = 180° - 25° = 155°This still doesn't match the choices. Given the choices provided, it seems there might be a misunderstanding or a different interpretation needed. However, based on standard geometric principles, the correct answer should be: Angle AOB = 2 * Angle APO = 2 * 25° = 50° If this is not the case, please provide additional context or clarify the problem statement.

**MPO Output:** To find the measure of angle AOB, we need to use the properties of angles subtended by the same arc in a circle. Specifically, the angle subtended by an arc at the center of the circle (angle AOB) is twice the angle subtended by the same arc at any point on the circumference (angle APO).Given that angle APO = 25°, we can calculate angle AOB as follows: Since angle AOB is twice the angle APO, we have:[ Angle AOB = 2 * Angle APO ][ Angle AOB = 2 * 25° ][ Angle AOB = 50° ]However, this calculation seems incorrect based on the provided choices. Let's re-evaluate using the correct property.The correct property states that the angle subtended by an arc at the center of the circle is twice the angle subtended by the same arc at any point on the circumference. Therefore, the correct calculation should be: Angle AOB = 2 * Angle APO = 2 * 25° = 50°But since this does not match the given choices, let's consider another approach. If we look at the triangle OAP, angle AOP is 90° because it is a right angle (as it is inscribed in a semicircle). Now, angle APO is given as 25°, so angle AOA would be 65° (since it is supplementary to angle APO). Therefore, angle AOB would be half of this, which is 32.5°. However, this also doesn't match the choices.Let's consider the possibility that the problem might involve a different interpretation or a different geometric configuration. Given the choices, let's assume there might be a mistake in the problem statement or the choices themselves. If we were to choose from the given options, none of them seem to fit correctly based on standard geometric principles for this type of problem.

**GRPO Output:** To find the measure of angle AOB, we need to understand that angle AOB is an inscribed angle that intercepts the same arc as the central angle AOP. According to the Inscribed Angle Theorem, the measure of an inscribed angle is half the measure of the central angle that subtends the same arc.Given:- Angle APO = 25.0 degreesSince angle AOB is an inscribed angle intercepting the same arc as the central angle AOP, we can use the Inscribed Angle Theorem:The measure of angle AOB is half the measure of angle AOP.Therefore, the measure of angle AOB is 130 degrees.

**GSPO Output:** To find the measure of angle AOB, we need to understand that angle AOB is an inscribed angle that intercepts the same arc as the central angle AOP. According to the Inscribed Angle Theorem, the measure of an inscribed angle is half the measure of the central angle that subtends the same arc.Given:- Angle APO = 25.0 degreesSince angle AOB is an inscribed angle that intercepts the same arc as the central angle AOP, we can use the Inscribed Angle Theorem to find the measure of angle AOB:The measure of angle AOB is half the measure of angle AOP.Therefore, the answer is B: 130°.

</details>

## vLLM Integration in TRL

vLLM is integrated into TRL to support online alignment methods, where you need to generate samples during training. Running the example scripts like the following enables vLLM:

```bash
CUDA_VISIBLE_DEVICES=1,2 python3 examples/scripts/grpo_vlm.py --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct … --log_completions --use_vllm --vllm_mode colocate
```

There are mainly two modes: `colocate` and `server`. [`colocate`](https://huggingface.co/blog/vllm-colocate) runs vLLM in the same process as the training loop, sharing the same GPU between training and generation by creating a vLLM LLM instance inside the `GRPOTrainer`. Meanwhile, `server` requires you to serve vLLM separately in a different process that the training job sends generation requests to. You can start this server with the command:

```bash
trl vllm-serve --model Qwen/Qwen2.5-VL-3B-Instruct --tensor-parallel-size 1
```

Then you can run the training script as follows:

```bash
CUDA_VISIBLE_DEVICES=1,2 python3 examples/scripts/grpo_vlm.py --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct … --log_completions --use_vllm --vllm_mode server
```
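
If you configure the trainer in Python rather than through CLI flags, the same switches map to `GRPOConfig` parameters; a minimal sketch (values here are placeholders to adapt to your setup):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    # ... other params as in the GRPO/GSPO examples above
    use_vllm=True,
    vllm_mode="colocate",  # or "server" if you launched `trl vllm-serve` in a separate process
)
```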

One more tip: we have added support for using vLLM with the transformers backend in TRL. You can enable it, either when running a script in colocate mode or when serving the model, by passing the `--vllm_model_impl transformers` flag.

You can read more about vLLM integration in TRL [here](https://huggingface.co/docs/trl/en/vllm_integration).

## Useful Resources

Below, you can find a compilation of resources to explore the alignment of VLMs in more detail. Enjoy!

- [**Vision Language Models (Better, Faster, Stronger)**](https://huggingface.co/blog/vlms-2025)
- [**Enhancing the Reasoning Ability of Multimodal Large Language Models via Mixed Preference Optimization**](https://huggingface.co/papers/2411.10442) (**MPO paper**)
- [**DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models**](https://huggingface.co/papers/2402.03300) (**GRPO paper**)
- [**Open-R1**](https://github.com/huggingface/open-r1) **repository** and [**Open-R1 reward functions**](https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py)
- [**TRL documentation**](https://huggingface.co/docs/trl/en/index) and [**TRL repository**](https://github.com/huggingface/trl)
- [**MPO VLM recipe**](https://huggingface.co/learn/cookbook/fine_tuning_vlm_mpo)
- [**GRPO VLM recipe**](https://huggingface.co/learn/cookbook/fine_tuning_vlm_grpo_trl)
- [**More multimodal alignment recipes**](https://huggingface.co/learn/cookbook/index)
- [**TRL multimodal training scripts**](https://github.com/huggingface/trl/tree/main/examples/scripts)
- [**vLLM integration in TRL docs**](https://huggingface.co/docs/trl/en/vllm_integration)
- [**Transformers backend integration in vLLM**](https://blog.vllm.ai/2025/04/11/transformers-backend.html)
