
Commit 3cb1367

FEAT: support MLX engine (#1765)
1 parent 8fff9e7 commit 3cb1367

File tree

12 files changed: +661 -13 lines changed


.github/workflows/python.yaml

Lines changed: 8 additions & 0 deletions
@@ -82,6 +82,7 @@ jobs:
           - { os: windows-latest, python-version: 3.10 }
         include:
           - { os: self-hosted, module: gpu, python-version: 3.9}
+          - { os: macos-latest, module: metal, python-version: "3.10" }

     steps:
       - name: Check out code
@@ -109,6 +110,9 @@ jobs:
            sudo rm -rf "/usr/local/share/boost"
            sudo rm -rf "$AGENT_TOOLSDIRECTORY"
          fi
+         if [ "$MODULE" == "metal" ]; then
+           pip install mlx-lm
+         fi
          pip install "llama-cpp-python==0.2.77"
          pip install transformers
          pip install attrdict
@@ -162,6 +166,10 @@ jobs:
            ${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
              -W ignore::PendingDeprecationWarning \
              --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_chattts.py
+         elif [ "$MODULE" == "metal" ]; then
+           pytest --timeout=1500 \
+             -W ignore::PendingDeprecationWarning \
+             --cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/llm/mlx/tests/test_mlx.py
          else
            pytest --timeout=1500 \
              -W ignore::PendingDeprecationWarning \

doc/source/models/builtin/llm/qwen2-instruct.rst

Lines changed: 68 additions & 4 deletions
@@ -206,7 +206,71 @@ chosen quantization method from the options listed above::

    xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 72 --model-format awq --quantization ${quantization}


-Model Spec 13 (ggufv2, 0_5 Billion)
+Model Spec 13 (mlx, 0_5 Billion)
+++++++++++++++++++++++++++++++++++++++++
+
+- **Model Format:** mlx
+- **Model Size (in billions):** 0_5
+- **Quantizations:** 4-bit
+- **Engines**: MLX
+- **Model ID:** Qwen/Qwen2-0.5B-Instruct-MLX
+- **Model Hubs**: `Hugging Face <https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-MLX>`__, `ModelScope <https://modelscope.cn/models/qwen/Qwen2-0.5B-Instruct-MLX>`__
+
+Execute the following command to launch the model, remember to replace ``${quantization}`` with your
+chosen quantization method from the options listed above::
+
+   xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 0_5 --model-format mlx --quantization ${quantization}
+
+
+Model Spec 14 (mlx, 1_5 Billion)
+++++++++++++++++++++++++++++++++++++++++
+
+- **Model Format:** mlx
+- **Model Size (in billions):** 1_5
+- **Quantizations:** 4-bit
+- **Engines**: MLX
+- **Model ID:** Qwen/Qwen2-1.5B-Instruct-MLX
+- **Model Hubs**: `Hugging Face <https://huggingface.co/Qwen/Qwen2-1.5B-Instruct-MLX>`__, `ModelScope <https://modelscope.cn/models/qwen/Qwen2-1.5B-Instruct-MLX>`__
+
+Execute the following command to launch the model, remember to replace ``${quantization}`` with your
+chosen quantization method from the options listed above::
+
+   xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 1_5 --model-format mlx --quantization ${quantization}
+
+
+Model Spec 15 (mlx, 7 Billion)
+++++++++++++++++++++++++++++++++++++++++
+
+- **Model Format:** mlx
+- **Model Size (in billions):** 7
+- **Quantizations:** 4-bit
+- **Engines**: MLX
+- **Model ID:** Qwen/Qwen2-7B-Instruct-MLX
+- **Model Hubs**: `Hugging Face <https://huggingface.co/Qwen/Qwen2-7B-Instruct-MLX>`__, `ModelScope <https://modelscope.cn/models/qwen/Qwen2-7B-Instruct-MLX>`__
+
+Execute the following command to launch the model, remember to replace ``${quantization}`` with your
+chosen quantization method from the options listed above::
+
+   xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 7 --model-format mlx --quantization ${quantization}
+
+
+Model Spec 16 (mlx, 72 Billion)
+++++++++++++++++++++++++++++++++++++++++
+
+- **Model Format:** mlx
+- **Model Size (in billions):** 72
+- **Quantizations:** 4-bit
+- **Engines**: MLX
+- **Model ID:** mlx-community/Qwen2-72B-4bit
+- **Model Hubs**: `Hugging Face <https://huggingface.co/mlx-community/Qwen2-72B-4bit>`__
+
+Execute the following command to launch the model, remember to replace ``${quantization}`` with your
+chosen quantization method from the options listed above::
+
+   xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 72 --model-format mlx --quantization ${quantization}
+
+
+Model Spec 17 (ggufv2, 0_5 Billion)
 ++++++++++++++++++++++++++++++++++++++++

 - **Model Format:** ggufv2
@@ -222,7 +286,7 @@ chosen quantization method from the options listed above::

    xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 0_5 --model-format ggufv2 --quantization ${quantization}


-Model Spec 14 (ggufv2, 1_5 Billion)
+Model Spec 18 (ggufv2, 1_5 Billion)
 ++++++++++++++++++++++++++++++++++++++++

 - **Model Format:** ggufv2
@@ -238,7 +302,7 @@ chosen quantization method from the options listed above::

    xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 1_5 --model-format ggufv2 --quantization ${quantization}


-Model Spec 15 (ggufv2, 7 Billion)
+Model Spec 19 (ggufv2, 7 Billion)
 ++++++++++++++++++++++++++++++++++++++++

 - **Model Format:** ggufv2
@@ -254,7 +318,7 @@ chosen quantization method from the options listed above::

    xinference launch --model-engine ${engine} --model-name qwen2-instruct --size-in-billions 7 --model-format ggufv2 --quantization ${quantization}


-Model Spec 16 (ggufv2, 72 Billion)
+Model Spec 20 (ggufv2, 72 Billion)
 ++++++++++++++++++++++++++++++++++++++++

 - **Model Format:** ggufv2
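For readers who prefer the Python client over the ``xinference launch`` command used in the specs above, a minimal sketch follows. It is not part of this diff: it assumes a running Xinference server at the default local address and that ``mlx-lm`` is installed, and the keyword arguments simply mirror the CLI flags shown above::

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")       # assumed local server address
    model_uid = client.launch_model(
        model_name="qwen2-instruct",
        model_engine="MLX",                        # engine added by this commit
        model_format="mlx",
        model_size_in_billions="0_5",
        quantization="4-bit",
    )
    model = client.get_model(model_uid)
    print(model.chat("Briefly introduce yourself."))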

setup.cfg

Lines changed: 3 additions & 0 deletions
@@ -103,6 +103,7 @@ all =
     optimum
     outlines==0.0.34 # sglang errored for outlines > 0.0.34
     sglang[all] ; sys_platform=='linux'
+    mlx-lm ; sys_platform=='darwin' and platform_machine=='arm64'
     attrdict # For deepseek VL
     timm>=0.9.16 # For deepseek VL
     torchvision # For deepseek VL
@@ -143,6 +144,8 @@ vllm =
     vllm>=0.2.6
 sglang =
     sglang[all]
+mlx =
+    mlx-lm
 embedding =
     sentence-transformers>=2.7.0
 rerank =
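With the new ``mlx`` extra in place, the dependency should be installable on Apple-silicon machines via standard setuptools extras syntax, e.g. ``pip install "xinference[mlx]"`` (not shown in this diff), while the ``all`` extra only pulls in ``mlx-lm`` when the ``darwin``/``arm64`` platform markers match.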

xinference/model/llm/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -34,6 +34,7 @@
     BUILTIN_MODELSCOPE_LLM_FAMILIES,
     LLAMA_CLASSES,
     LLM_ENGINES,
+    MLX_CLASSES,
     SGLANG_CLASSES,
     SUPPORTED_ENGINES,
     TRANSFORMERS_CLASSES,
@@ -42,6 +43,7 @@
     GgmlLLMSpecV1,
     LLMFamilyV1,
     LLMSpecV1,
+    MLXLLMSpecV1,
     PromptStyleV1,
     PytorchLLMSpecV1,
     get_cache_status,
@@ -112,6 +114,7 @@ def generate_engine_config_by_model_family(model_family):
 def _install():
     from .ggml.chatglm import ChatglmCppChatModel
     from .ggml.llamacpp import LlamaCppChatModel, LlamaCppModel
+    from .mlx.core import MLXChatModel, MLXModel
     from .pytorch.baichuan import BaichuanPytorchChatModel
     from .pytorch.chatglm import ChatglmPytorchChatModel
     from .pytorch.cogvlm2 import CogVLM2Model
@@ -147,6 +150,7 @@ def _install():
     )
     SGLANG_CLASSES.extend([SGLANGModel, SGLANGChatModel])
     VLLM_CLASSES.extend([VLLMModel, VLLMChatModel])
+    MLX_CLASSES.extend([MLXModel, MLXChatModel])
     TRANSFORMERS_CLASSES.extend(
         [
             BaichuanPytorchChatModel,
@@ -176,6 +180,7 @@ def _install():
     SUPPORTED_ENGINES["SGLang"] = SGLANG_CLASSES
     SUPPORTED_ENGINES["Transformers"] = TRANSFORMERS_CLASSES
     SUPPORTED_ENGINES["llama.cpp"] = LLAMA_CLASSES
+    SUPPORTED_ENGINES["MLX"] = MLX_CLASSES

     json_path = os.path.join(
         os.path.dirname(os.path.abspath(__file__)), "llm_family.json"
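The registration above follows the same pattern as the existing engines: ``_install()`` imports the concrete classes lazily and then maps the engine name to them. A stripped-down, self-contained sketch of that pattern (class names taken from the diff; everything else is simplified and hypothetical)::

    from typing import Dict, List, Type

    # Simplified stand-ins for the real registries defined in llm_family.py
    MLX_CLASSES: List[Type] = []
    SUPPORTED_ENGINES: Dict[str, List[Type]] = {}

    class MLXModel: ...        # real implementations live in xinference/model/llm/mlx/core.py
    class MLXChatModel: ...

    def _install() -> None:
        # the real code imports these lazily: from .mlx.core import MLXChatModel, MLXModel
        MLX_CLASSES.extend([MLXModel, MLXChatModel])
        SUPPORTED_ENGINES["MLX"] = MLX_CLASSES   # "MLX" now resolves to the new engine classes

    _install()
    print([cls.__name__ for cls in SUPPORTED_ENGINES["MLX"]])  # ['MLXModel', 'MLXChatModel']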

xinference/model/llm/llm_family.json

Lines changed: 32 additions & 0 deletions
@@ -2549,6 +2549,38 @@
             ],
             "model_id": "Qwen/Qwen2-72B-Instruct-AWQ"
         },
+        {
+            "model_format": "mlx",
+            "model_size_in_billions": "0_5",
+            "quantizations": [
+                "4-bit"
+            ],
+            "model_id": "Qwen/Qwen2-0.5B-Instruct-MLX"
+        },
+        {
+            "model_format": "mlx",
+            "model_size_in_billions": "1_5",
+            "quantizations": [
+                "4-bit"
+            ],
+            "model_id": "Qwen/Qwen2-1.5B-Instruct-MLX"
+        },
+        {
+            "model_format": "mlx",
+            "model_size_in_billions": 7,
+            "quantizations": [
+                "4-bit"
+            ],
+            "model_id": "Qwen/Qwen2-7B-Instruct-MLX"
+        },
+        {
+            "model_format": "mlx",
+            "model_size_in_billions": 72,
+            "quantizations": [
+                "4-bit"
+            ],
+            "model_id": "mlx-community/Qwen2-72B-Instruct-4bit"
+        },
         {
             "model_format": "ggufv2",
             "model_size_in_billions": "0_5",

xinference/model/llm/llm_family.py

Lines changed: 32 additions & 8 deletions
@@ -107,6 +107,28 @@ def validate_model_size_with_radix(cls, v: object) -> object:
         return v


+class MLXLLMSpecV1(BaseModel):
+    model_format: Literal["mlx"]
+    # Must in order that `str` first, then `int`
+    model_size_in_billions: Union[str, int]
+    quantizations: List[str]
+    model_id: Optional[str]
+    model_hub: str = "huggingface"
+    model_uri: Optional[str]
+    model_revision: Optional[str]
+
+    @validator("model_size_in_billions", pre=False)
+    def validate_model_size_with_radix(cls, v: object) -> object:
+        if isinstance(v, str):
+            if (
+                "_" in v
+            ):  # for example, "1_8" just returns "1_8", otherwise int("1_8") returns 18
+                return v
+            else:
+                return int(v)
+        return v
+
+
 class PromptStyleV1(BaseModel):
     style_name: str
     system_prompt: str = ""
@@ -226,7 +248,7 @@ def parse_raw(


 LLMSpecV1 = Annotated[
-    Union[GgmlLLMSpecV1, PytorchLLMSpecV1],
+    Union[GgmlLLMSpecV1, PytorchLLMSpecV1, MLXLLMSpecV1],
     Field(discriminator="model_format"),
 ]
@@ -249,6 +271,8 @@ def parse_raw(

 VLLM_CLASSES: List[Type[LLM]] = []

+MLX_CLASSES: List[Type[LLM]] = []
+
 LLM_ENGINES: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
 SUPPORTED_ENGINES: Dict[str, List[Type[LLM]]] = {}

@@ -549,7 +573,7 @@ def _get_meta_path(
             return os.path.join(cache_dir, "__valid_download")
         else:
             return os.path.join(cache_dir, f"__valid_download_{model_hub}")
-    elif model_format in ["ggmlv3", "ggufv2", "gptq", "awq"]:
+    elif model_format in ["ggmlv3", "ggufv2", "gptq", "awq", "mlx"]:
         assert quantization is not None
         if model_hub == "huggingface":
             return os.path.join(cache_dir, f"__valid_download_{quantization}")
@@ -588,7 +612,7 @@ def _skip_download(
                 logger.warning(f"Cache {cache_dir} exists, but it was from {hub}")
                 return True
         return False
-    elif model_format in ["ggmlv3", "ggufv2", "gptq", "awq"]:
+    elif model_format in ["ggmlv3", "ggufv2", "gptq", "awq", "mlx"]:
         assert quantization is not None
         return os.path.exists(
             _get_meta_path(cache_dir, model_format, model_hub, quantization)
@@ -683,7 +707,7 @@ def cache_from_csghub(
     ):
         return cache_dir

-    if llm_spec.model_format in ["pytorch", "gptq", "awq"]:
+    if llm_spec.model_format in ["pytorch", "gptq", "awq", "mlx"]:
         download_dir = retry_download(
             snapshot_download,
             llm_family.model_name,
@@ -751,7 +775,7 @@ def cache_from_modelscope(
     ):
         return cache_dir

-    if llm_spec.model_format in ["pytorch", "gptq", "awq"]:
+    if llm_spec.model_format in ["pytorch", "gptq", "awq", "mlx"]:
         download_dir = retry_download(
             snapshot_download,
             llm_family.model_name,
@@ -820,8 +844,8 @@ def cache_from_huggingface(
     if not IS_NEW_HUGGINGFACE_HUB:
         use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}

-    if llm_spec.model_format in ["pytorch", "gptq", "awq"]:
-        assert isinstance(llm_spec, PytorchLLMSpecV1)
+    if llm_spec.model_format in ["pytorch", "gptq", "awq", "mlx"]:
+        assert isinstance(llm_spec, (PytorchLLMSpecV1, MLXLLMSpecV1))
         download_dir = retry_download(
             huggingface_hub.snapshot_download,
             llm_family.model_name,
@@ -910,7 +934,7 @@ def get_cache_status(
         ]
         return any(revisions)
     # just check meta file for ggml and gptq model
-    elif llm_spec.model_format in ["ggmlv3", "ggufv2", "gptq", "awq"]:
+    elif llm_spec.model_format in ["ggmlv3", "ggufv2", "gptq", "awq", "mlx"]:
         ret = []
         for q in llm_spec.quantizations:
             assert q is not None
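A quick illustration of how the new spec's size validator behaves; this is a sketch only, assuming the package imports cleanly in your environment, with field values copied from the JSON entries above::

    from xinference.model.llm.llm_family import MLXLLMSpecV1

    small = MLXLLMSpecV1(
        model_format="mlx",
        model_size_in_billions="0_5",          # underscore form stays a string...
        quantizations=["4-bit"],
        model_id="Qwen/Qwen2-0.5B-Instruct-MLX",
    )
    large = MLXLLMSpecV1(
        model_format="mlx",
        model_size_in_billions="72",           # ...while a plain numeric string is coerced
        quantizations=["4-bit"],
        model_id="mlx-community/Qwen2-72B-Instruct-4bit",
    )
    print(repr(small.model_size_in_billions))  # '0_5'  (int("0_5") would silently read as 5)
    print(repr(large.model_size_in_billions))  # 72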

xinference/model/llm/llm_family_modelscope.json

Lines changed: 27 additions & 0 deletions
@@ -2921,6 +2921,33 @@
             "model_id": "qwen/Qwen2-72B-Instruct-AWQ",
             "model_hub": "modelscope"
         },
+        {
+            "model_format": "mlx",
+            "model_size_in_billions": "0_5",
+            "quantizations": [
+                "4-bit"
+            ],
+            "model_id": "qwen/Qwen2-0.5B-Instruct-MLX",
+            "model_hub": "modelscope"
+        },
+        {
+            "model_format": "mlx",
+            "model_size_in_billions": "1_5",
+            "quantizations": [
+                "4-bit"
+            ],
+            "model_id": "qwen/Qwen2-1.5B-Instruct-MLX",
+            "model_hub": "modelscope"
+        },
+        {
+            "model_format": "mlx",
+            "model_size_in_billions": 7,
+            "quantizations": [
+                "4-bit"
+            ],
+            "model_id": "qwen/Qwen2-7B-Instruct-MLX",
+            "model_hub": "modelscope"
+        },
         {
             "model_format": "ggufv2",
             "model_size_in_billions": "0_5",

xinference/model/llm/mlx/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
