
Commit ffdd8f9

llyycchhee and qinxuye authored
FEAT: [model] support glm-4.1v-thinking (#3756)
Co-authored-by: qinxuye <qinxuye@gmail.com>
1 parent c889075 commit ffdd8f9

4 files changed, +263 -1 lines

xinference/core/chat_interface.py

Lines changed: 1 addition & 1 deletion
@@ -359,7 +359,7 @@ def predict(history, bot, max_tokens, temperature, stream):
         if "content" not in delta:
             continue
         else:
-            response_content += delta["content"]
+            response_content += html.escape(delta["content"])
         bot[-1][1] = response_content
         yield history, bot
     history.append(
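
Note: escaping the streamed delta keeps tag-like tokens visible as literal text in the Gradio chatbot; thinking models such as the one added below emit <think>...</think> markers that could otherwise be interpreted as markup. A minimal standalone sketch of the effect (not part of the commit):

import html

# A streamed delta from a thinking model may carry tag-like tokens.
delta = {"content": "<think>Looking at the image...</think>"}

response_content = ""
response_content += html.escape(delta["content"])
print(response_content)  # -> &lt;think&gt;Looking at the image...&lt;/think&gt;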

xinference/model/llm/llm_family.json

Lines changed: 94 additions & 0 deletions
@@ -18576,5 +18576,99 @@
         "#system_numpy#"
       ]
     }
+  },
+  {
+    "version": 2,
+    "context_length": 65536,
+    "model_name": "glm-4.1v-thinking",
+    "model_lang": [
+      "en",
+      "zh"
+    ],
+    "model_ability": [
+      "chat",
+      "vision",
+      "reasoning"
+    ],
+    "model_description": "GLM-4.1V-9B-Thinking, designed to explore the upper limits of reasoning in vision-language models.",
+    "model_specs": [
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 9,
+        "model_src": {
+          "huggingface": {
+            "quantizations": [
+              "none"
+            ],
+            "model_id": "THUDM/GLM-4.1V-9B-Thinking",
+            "model_revision": "b627c82cd8fc9175ff2b82b33fb439eba260055f"
+          },
+          "modelscope": {
+            "quantizations": [
+              "none"
+            ],
+            "model_id": "ZhipuAI/GLM-4.1V-9B-Thinking",
+            "model_revision": "master"
+          }
+        }
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": 9,
+        "model_src": {
+          "huggingface": {
+            "quantizations": [
+              "Int4"
+            ],
+            "model_id": "dengcao/GLM-4.1V-9B-Thinking-AWQ"
+          },
+          "modelscope": {
+            "quantizations": [
+              "Int4"
+            ],
+            "model_id": "dengcao/GLM-4.1V-9B-Thinking-AWQ",
+            "model_revision": "master"
+          }
+        }
+      },
+      {
+        "model_format": "gptq",
+        "model_size_in_billions": 9,
+        "model_src": {
+          "huggingface": {
+            "quantizations": [
+              "Int4-Int8Mix"
+            ],
+            "model_id": "dengcao/GLM-4.1V-9B-Thinking-GPTQ-Int4-Int8Mix"
+          },
+          "modelscope": {
+            "quantizations": [
+              "Int4-Int8Mix"
+            ],
+            "model_id": "dengcao/GLM-4.1V-9B-Thinking-GPTQ-Int4-Int8Mix",
+            "model_revision": "master"
+          }
+        }
+      }
+    ],
+    "chat_template": "[gMASK]<sop> {%- for msg in messages %} {%- if msg.role == 'system' %} <|system|> {{ msg.content }} {%- elif msg.role == 'user' %} <|user|>{{ '\n' }} {%- if msg.content is string %} {{ msg.content }} {%- else %} {%- for item in msg.content %} {%- if item.type == 'video' or 'video' in item %} <|begin_of_video|><|video|><|end_of_video|> {%- elif item.type == 'image' or 'image' in item %} <|begin_of_image|><|image|><|end_of_image|> {%- elif item.type == 'text' %} {{ item.text }} {%- endif %} {%- endfor %} {%- endif %} {%- elif msg.role == 'assistant' %} {%- if msg.metadata %} <|assistant|>{{ msg.metadata }} {{ msg.content }} {%- else %} <|assistant|> {{ msg.content }} {%- endif %} {%- endif %} {%- endfor %} {% if add_generation_prompt %}<|assistant|> {% endif %}",
+    "stop_token_ids": [
+      151329,
+      151336,
+      151338
+    ],
+    "stop": [
+      "<|endoftext|>",
+      "<|user|>",
+      "<|observation|>"
+    ],
+    "reasoning_start_tag": "<think>",
+    "reasoning_end_tag": "</think>",
+    "virtualenv": {
+      "packages": [
+        "transformers>=4.53.2",
+        "#system_numpy#"
+      ]
+    }
   }
 ]
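
With this entry registered, the family can be launched by name like any other built-in model. A minimal sketch using the Python client; the endpoint and launch parameters are illustrative, not part of the commit:

from xinference.client import Client

# Assumes a local Xinference supervisor is already running.
client = Client("http://localhost:9997")

model_uid = client.launch_model(
    model_name="glm-4.1v-thinking",
    model_format="pytorch",  # or "awq" / "gptq", per the specs above
    model_size_in_billions=9,
)
model = client.get_model(model_uid)

# Vision input follows the OpenAI-style image_url content format.
response = model.chat(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]
)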

New file (filename not shown in this capture; the package-relative imports below place it under xinference/model/llm/transformers/multimodal/)

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
# Copyright 2022-2025 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from concurrent.futures import ThreadPoolExecutor
from threading import Thread
from typing import Any, Dict, Iterator, List, Tuple

import torch

from .....model.utils import select_device
from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
from ...utils import _decode_image
from ..core import register_non_default_model
from .core import PytorchMultiModalModel

logger = logging.getLogger(__name__)


@register_transformer
@register_non_default_model("glm-4.1v-thinking")
class Glm4_1VModel(PytorchMultiModalModel):
    @classmethod
    def match_json(
        cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        family = model_family.model_family or model_family.model_name
        if "glm-4.1v" in family.lower():
            return True
        return False

    def decide_device(self):
        device = self._pytorch_model_config.get("device", "auto")
        self._device = select_device(device)

    def load_processor(self):
        from transformers import AutoProcessor

        self._processor = AutoProcessor.from_pretrained(self.model_path, use_fast=True)
        self._tokenizer = self._processor.tokenizer

    def load_multimodal_model(self):
        from transformers import Glm4vForConditionalGeneration

        kwargs = {"device_map": "auto"}
        kwargs = self.apply_bnb_quantization(kwargs)

        model = Glm4vForConditionalGeneration.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            **kwargs,
        )
        self._model = model.eval()
        self._device = self._model.device

    @staticmethod
    def _get_processed_msgs(messages: List[Dict]) -> List[Dict]:
        res = []
        for message in messages:
            role = message["role"]
            content = message["content"]
            if isinstance(content, str):
                res.append({"role": role, "content": content})
            else:
                texts = []
                image_urls = []
                for c in content:
                    c_type = c.get("type")
                    if c_type == "text":
                        texts.append(c["text"])
                    else:
                        assert (
                            c_type == "image_url"
                        ), "Please follow the image input of the OpenAI API."
                        image_urls.append(c["image_url"]["url"])
                if len(image_urls) > 1:
                    raise RuntimeError("Only one image per message is supported")
                image_futures = []
                with ThreadPoolExecutor() as executor:
                    for image_url in image_urls:
                        fut = executor.submit(_decode_image, image_url)
                        image_futures.append(fut)
                images = [fut.result() for fut in image_futures]
                assert len(images) <= 1
                text = " ".join(texts)
                if images:
                    content = [
                        {"type": "image", "image": images[0]},
                        {"type": "text", "text": text},
                    ]
                    res.append({"role": role, "content": content})
                else:
                    res.append(
                        {"role": role, "content": {"type": "text", "text": text}}
                    )
        return res

    def build_inputs_from_messages(
        self,
        messages: List[Dict],
        generate_config: Dict,
    ):
        msgs = self._get_processed_msgs(messages)
        inputs = self._processor.apply_chat_template(
            msgs,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )  # chat mode
        inputs = inputs.to(self._model.device)
        return inputs

    def get_stop_strs(self) -> List[str]:
        return ["<|endoftext|>"]

    def get_builtin_stop_token_ids(self) -> Tuple:
        from transformers import AutoConfig

        return tuple(AutoConfig.from_pretrained(self.model_path).eos_token_id)

    def build_generate_kwargs(
        self,
        generate_config: Dict,
    ) -> Dict[str, Any]:
        return dict(
            do_sample=True,
            top_p=generate_config.get("top_p", 1e-5),
            repetition_penalty=generate_config.get("repetition_penalty", 1.1),
            top_k=generate_config.get("top_k", 2),
            max_new_tokens=generate_config.get("max_tokens", 512),
        )

    def build_streaming_iter(
        self,
        messages: List[Dict],
        generate_config: Dict,
    ) -> Tuple[Iterator, int]:
        from transformers import TextIteratorStreamer

        generate_kwargs = self.build_generate_kwargs(generate_config)
        inputs = self.build_inputs_from_messages(messages, generate_config)
        streamer = TextIteratorStreamer(
            tokenizer=self._tokenizer,
            timeout=60,
            skip_prompt=True,
            skip_special_tokens=False,
        )
        kwargs = {
            **inputs,
            **generate_kwargs,
            "streamer": streamer,
        }
        logger.debug("Generate with kwargs: %s", generate_kwargs)
        t = Thread(target=self._model.generate, kwargs=kwargs)
        t.start()
        return streamer, len(inputs.input_ids[0])
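
build_streaming_iter follows the usual transformers streaming pattern: generation runs on a background thread while TextIteratorStreamer yields decoded text pieces as they become available. A hedged, standalone sketch of how the returned iterator is consumed (names are illustrative):

# `streamer` is the first element of the tuple returned by build_streaming_iter.
# Iterating blocks until the next decoded chunk arrives (or the 60 s timeout fires).
completion = ""
for new_text in streamer:
    completion += new_text
    print(new_text, end="", flush=True)
# The background generate() thread finishes once the iterator is exhausted.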

xinference/model/llm/vllm/core.py

Lines changed: 1 addition & 0 deletions
@@ -264,6 +264,7 @@ class VLLMGenerateConfig(TypedDict, total=False):

 if VLLM_INSTALLED and vllm.__version__ >= "0.9.2":
     VLLM_SUPPORTED_CHAT_MODELS.append("Ernie4.5")
+    VLLM_SUPPORTED_VISION_MODEL_LIST.append("glm-4.1v-thinking")


 class VLLMModel(LLM):
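
With the guarded registration above, the same family can also be served by the vLLM engine when vllm >= 0.9.2 is installed. A hedged sketch mirroring the earlier client example with the engine switched (parameters illustrative):

from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(
    model_name="glm-4.1v-thinking",
    model_engine="vllm",  # requires vllm >= 0.9.2, per the guard above
    model_format="pytorch",
    model_size_in_billions=9,
)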
