Skip to content

Commit 5d61e95

Browse files
Alcanderianthyecust
authored andcommitted
[feat] interface for platforms abstraction (sgl-project#4928)
1 parent 0259f4e commit 5d61e95

File tree

1 file changed

+371
-0
lines changed

1 file changed

+371
-0
lines changed
Lines changed: 371 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,371 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# Adapted from
4+
# https://github.com/vllm-project/vllm/blob/v0.8.2/vllm/platforms/interface.py
5+
6+
import enum
7+
import platform
8+
import random
9+
from platform import uname
10+
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
11+
12+
if TYPE_CHECKING:
13+
from sglang.srt.server_args import ServerArgs
14+
from sglang.srt.configs.model_config import ModelConfig
15+
16+
import logging
17+
18+
import numpy as np
19+
import torch
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
def in_wsl() -> bool:
    """Return True when any ``uname`` field mentions "microsoft" (case-insensitive)."""
    # Reference: https://github.com/microsoft/WSL/issues/4071
    uname_text = " ".join(uname())
    return "microsoft" in uname_text.lower()
27+
28+
29+
class PlatformEnum(enum.Enum):
    """Kinds of platform; each member gets an automatically assigned value."""

    CUDA = enum.auto()
    ROCM = enum.auto()
    HPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    # Out-of-tree platform.
    OOT = enum.auto()
    # No platform has been specified.
    UNSPECIFIED = enum.auto()
37+
38+
39+
class CpuArchEnum(enum.Enum):
    """CPU architecture families; each member gets an automatically assigned value."""

    X86 = enum.auto()
    ARM = enum.auto()
    POWERPC = enum.auto()
    # A machine string was reported but matched none of the families above.
    OTHER = enum.auto()
    # No machine string was reported at all.
    UNKNOWN = enum.auto()
45+
46+
47+
class DeviceCapability(NamedTuple):
    """A ``(major, minor)`` device capability pair."""

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Render the capability as the dotted string ``"<major>.<minor>"``."""
        major, minor = self
        return f"{major}.{minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.
        """
        major, minor = self
        assert 0 <= minor < 10
        return major * 10 + minor
62+
63+
64+
class Platform:
    """
    Base interface for a hardware platform.

    Concrete platforms subclass this class, set ``_enum`` and the device
    attributes, and override the classmethods that raise
    ``NotImplementedError`` here.
    """

    _enum: PlatformEnum

    # Real device name of current platform.
    device_name: str

    # For specifying torch device for cuda alike platform's capability.
    device_type: str

    # The torch.distributed backend on current platform
    torch_distributed_backend: str

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    torch_compile_backend: str = "inductor"

    # NOTE: the list defaults below are class-level objects shared by every
    # subclass that does not override them; subclasses should assign a new
    # list instead of mutating these in place.

    # An empty list disables the check in `verify_quantization`.
    supported_quantization: list[str] = []

    # An empty list disables the check in `verify_speculative_algorithm`.
    supported_speculative_algorithm: list[str] = []

    # Use first element as default dtype.
    # An empty list disables the check in `check_and_update_model_dtype`.
    supported_dtype: list[str] = []

    # Use first element as default backend
    # NOTE: "attntion" is a misspelling of "attention"; the name is kept
    # because it is part of the public interface.
    supported_attntion_backend: list[str] = []

    # Use first element as default backend
    supported_sampling_backend: list[str] = []

    # Use first element as default backend
    supported_lora_backend: list[str] = []

    def is_cuda(self) -> bool:
        """Return True if this platform is ``PlatformEnum.CUDA``."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """Return True if this platform is ``PlatformEnum.ROCM``."""
        return self._enum == PlatformEnum.ROCM

    def is_hpu(self) -> bool:
        """Return True if this platform is ``PlatformEnum.HPU``."""
        return self._enum == PlatformEnum.HPU

    def is_xpu(self) -> bool:
        """Return True if this platform is ``PlatformEnum.XPU``."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """Return True if this platform is ``PlatformEnum.CPU``."""
        return self._enum == PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        """Return True if this platform is ``PlatformEnum.OOT``."""
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Stateless version of :func:`torch.cuda.is_available`."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """
        Stateless version of :func:`torch.cuda.get_device_capability`.

        The base implementation returns ``None``.
        """
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: Union[Tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The ``capability`` argument can either be:

        - A tuple ``(major, minor)``.
        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)

        Returns ``False`` when :meth:`get_device_capability` reports ``None``.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            # DeviceCapability is a NamedTuple, so this is a tuple comparison.
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def get_device_module(cls) -> Any:
        """Get `torch.device_module` like `torch.cuda` of current platform."""
        raise NotImplementedError

    @classmethod
    def get_device_sku(cls, device_id: int = 0) -> str:
        """Get the SKU name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_core_count(cls, device_id: int = 0) -> int:
        """Get the core count of a device, e.g. SMs of CUDA, CUs of ROCM."""
        raise NotImplementedError

    @classmethod
    def get_device_count(cls) -> int:
        """Get device count on current platform"""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(
        cls, device_id: int = 0, distributed: bool = False
    ) -> float:
        """
        Get total memory for device_type:device_id device in gigabytes.
        """
        raise NotImplementedError

    @classmethod
    def get_device_available_memory(
        cls,
        device_id: int = 0,
        distributed: bool = False,
        empty_cache: bool = True,
    ) -> float:
        """
        Get available memory for device_type:device_id device in gigabytes.
        When distributed is True, the available memory is the minimum available memory of all GPUs.
        """
        raise NotImplementedError

    @classmethod
    def supports_overlap_scheduler(cls) -> bool:
        """
        Check if the current platform supports overlap scheduler
        """
        raise NotImplementedError

    @classmethod
    def seed_everything(cls, seed: Optional[int] = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Does nothing when ``seed`` is ``None``.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    # NOTE: `ServerArgs` and `ModelConfig` are imported only under
    # `TYPE_CHECKING`, so they are written as quoted forward references below;
    # a bare name would be evaluated when the method is defined and fail.

    @classmethod
    def check_and_update_server_args(cls, server_args: "ServerArgs") -> None:
        """
        Check and update the server arguments for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        """
        pass

    @classmethod
    def check_and_update_model_dtype(
        cls, model_config: "ModelConfig", dtype: str
    ) -> str:
        """
        Check and update the model's dtype for the current platform.

        Returns ``dtype`` unchanged when ``supported_dtype`` is empty or
        contains it; otherwise logs a warning and returns the first entry of
        ``supported_dtype``.
        """
        if cls.supported_dtype and dtype not in cls.supported_dtype:
            logger.warning(
                f"dtype {dtype} is currently not supported in "
                f"{cls.device_name}. use {cls.supported_dtype[0]} instead"
            )
            return cls.supported_dtype[0]
        return dtype

    @classmethod
    def check_and_update_attntion_backend(
        cls, model_config: "ModelConfig", backend: str
    ) -> str:
        """
        Check and update the attention backend for the current platform.
        """
        raise NotImplementedError

    @classmethod
    def check_and_update_sampling_backend(cls, backend: str) -> str:
        """
        Check and update the sampling backend for the current platform.
        """
        raise NotImplementedError

    @classmethod
    def check_and_update_lora_backend(cls, backend: str) -> str:
        """
        Check and update the lora backend for the current platform.
        """
        raise NotImplementedError

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        Raises:
            ValueError: if ``supported_quantization`` is non-empty and does
                not contain ``quant``.
        """
        if cls.supported_quantization and quant not in cls.supported_quantization:
            raise ValueError(
                f"{quant} quantization is currently not supported in "
                f"{cls.device_name}."
            )

    @classmethod
    def verify_speculative_algorithm(cls, algo: str) -> None:
        """
        Verify whether the speculative algorithm is supported by the current platform.

        Raises:
            ValueError: if ``supported_speculative_algorithm`` is non-empty
                and does not contain ``algo``.
        """
        if (
            cls.supported_speculative_algorithm
            and algo not in cls.supported_speculative_algorithm
        ):
            raise ValueError(
                f"speculative algorithm {algo} is currently not supported in "
                f"{cls.device_name}."
            )

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Determine the CPU architecture of the current system.
        Returns CpuArchEnum indicating the architecture type.
        """
        machine = platform.machine().lower()

        if machine in ("x86_64", "amd64", "i386", "i686"):
            return CpuArchEnum.X86
        if machine.startswith(("arm", "aarch")):
            return CpuArchEnum.ARM
        if machine.startswith("ppc"):
            return CpuArchEnum.POWERPC

        # An empty machine string means the architecture could not be read.
        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform."""
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning(
                "Using 'pin_memory=False' as WSL is detected. "
                "This may slow down the performance."
            )
            return False
        return True

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        raise NotImplementedError

    @classmethod
    def supports_fp8(cls) -> bool:
        """Return whether the platform supports FP8; the base default is False."""
        return False

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """
        Returns the preferred FP8 type on the current platform.
        """
        return torch.float8_e4m3fn

    @classmethod
    def fp8_min_max(cls) -> Tuple[float, float]:
        """
        Returns ``(-max, max)`` where ``max`` is the largest finite value of
        the preferred FP8 type (see :meth:`fp8_dtype`).
        """
        fp8_max = torch.finfo(cls.fp8_dtype()).max
        return (-fp8_max, fp8_max)

    @classmethod
    def is_triton_avaliable(cls) -> bool:
        """
        Report whether Triton is available on the current platform.

        NOTE: "avaliable" is a misspelling kept because the name is part of
        the public interface.
        """
        raise NotImplementedError

    @classmethod
    def init_environments(cls) -> None:
        """
        Init environments on current platform.

        - Init platform specific env vars.
        - Init platform specific patches.
        """
        pass
367+
368+
369+
class UnspecifiedPlatform(Platform):
    """Platform tagged ``PlatformEnum.UNSPECIFIED`` with an empty ``device_type``."""

    device_type = ""
    _enum = PlatformEnum.UNSPECIFIED

0 commit comments

Comments
 (0)