Commit 22fe695

【Inference Optimize】Support automatic generation of marlin kernel (PaddlePaddle#3149)
* Support automatic generation of marlin kernel
1 parent b71cbb4 commit 22fe695

7 files changed: 125 additions, 376 deletions

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -156,6 +156,9 @@ nohup.out
 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cutlass
 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute

+#marlin_kernel
+custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_*.cu
+
 # buff
 custom_ops/tmp*

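The new ignore pattern covers the kernel sources that are now produced at build time rather than checked in. A small sketch (not part of the commit) of the filenames the generator added later in this commit emits for the currently enabled scalar types and dtypes, following the naming rule in generate_new_kernels():

import itertools

# Enabled types, as listed in the generator script below.
SCALAR_TYPES = ["MARLIN_NAMESPACE_NAME::kU4", "MARLIN_NAMESPACE_NAME::kU4B8"]
DTYPES = ["fp16", "bf16"]

for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
    # scalar_type[23:] strips the "MARLIN_NAMESPACE_NAME::" prefix (23 chars).
    print(f"kernel_{dtype}_{scalar_type[23:].lower()}.cu")

# Output:
# kernel_fp16_ku4.cu
# kernel_bf16_ku4.cu
# kernel_fp16_ku4b8.cu
# kernel_bf16_ku4b8.cu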
New file: marlin kernel generation script

Lines changed: 121 additions & 0 deletions

@@ -0,0 +1,121 @@
# adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/moe/marlin_moe_wna16/generate_kernels.py

import glob
import itertools
import os
import subprocess

import jinja2

FILE_HEAD = """
// auto generated by generate.py
// clang-format off

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
""".strip()

TEMPLATE = (
    "template __global__ void Marlin<"
    "{{scalar_t}}, "
    "{{w_type_id}}, "
    "{{threads}}, "
    "{{thread_m_blocks}}, "
    "{{thread_n_blocks}}, "
    "{{thread_k_blocks}}, "
    "{{'true' if m_block_size_8 else 'false'}}, "
    "{{stages}}, "
    "{{group_blocks}}, "
    "{{'true' if is_zp_float else 'false'}}>"
    "( MARLIN_KERNEL_PARAMS );"
)

# int8 with zero point case (MARLIN_NAMESPACE_NAME::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
    "MARLIN_NAMESPACE_NAME::kU4",
    "MARLIN_NAMESPACE_NAME::kU4B8",
    # "MARLIN_NAMESPACE_NAME::kU8B128", "MARLIN_NAMESPACE_NAME::kFE4M3fn",
    # "MARLIN_NAMESPACE_NAME::kFE2M1f"
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]

THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
#   = 0 : act order case
#   = -1 : channelwise quantization
#   > 0 : group_size = 16 * group_blocks
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]


def remove_old_kernels():
    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
        subprocess.call(["rm", "-f", filename])


def generate_new_kernels():
    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        all_template_str_list = []

        for group_blocks, m_blocks, thread_configs in itertools.product(GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):

            # act order case only supports gptq-int4 and gptq-int8
            if group_blocks == 0 and scalar_type not in [
                "MARLIN_NAMESPACE_NAME::kU4B8",
                "MARLIN_NAMESPACE_NAME::kU8B128",
            ]:
                continue
            if thread_configs[2] == 256:
                # for small batch (m_blocks <= 1), we only need (128, 128, 256)
                # for large batch (m_blocks > 1), we only need (64, 256, 256)
                if m_blocks <= 1 and thread_configs[0] != 128:
                    continue
                if m_blocks > 1 and thread_configs[0] != 64:
                    continue

            # we only support channelwise quantization and group_size == 128
            # for fp8
            if scalar_type == "MARLIN_NAMESPACE_NAME::kFE4M3fn" and group_blocks not in [-1, 8]:
                continue
            # nvfp4 only supports group_size == 16
            if scalar_type == "MARLIN_NAMESPACE_NAME::kFE2M1f" and group_blocks not in [1, 2]:
                continue
            # other quantization methods don't support group_size = 16
            if scalar_type != "MARLIN_NAMESPACE_NAME::kFE2M1f" and group_blocks == 1:
                continue

            k_blocks = thread_configs[0] // 16
            n_blocks = thread_configs[1] // 16
            threads = thread_configs[2]

            c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

            template_str = jinja2.Template(TEMPLATE).render(
                scalar_t=c_dtype,
                w_type_id=scalar_type + ".id()",
                threads=threads,
                thread_m_blocks=max(m_blocks, 1),
                thread_n_blocks=n_blocks,
                thread_k_blocks=k_blocks,
                m_block_size_8=m_blocks == 0.5,
                stages="pipe_stages",
                group_blocks=group_blocks,
                is_zp_float=False,
            )

            all_template_str_list.append(template_str)

        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
        filename = f"kernel_{dtype}_{scalar_type[23:].lower()}.cu"

        with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
            f.write(file_content)


if __name__ == "__main__":
    remove_old_kernels()
    generate_new_kernels()

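Each generated kernel_*.cu file is FILE_HEAD followed by one rendered TEMPLATE line per surviving parameter combination, closed with a "}". A minimal sketch (not part of the commit) of a single render, reusing the same template string joined onto fewer source lines, for one example combination: bf16, int4, channelwise, the (64, 256, 256) thread config. The expected output is shown in the trailing comment.

import jinja2

TEMPLATE = (
    "template __global__ void Marlin<"
    "{{scalar_t}}, {{w_type_id}}, {{threads}}, {{thread_m_blocks}}, "
    "{{thread_n_blocks}}, {{thread_k_blocks}}, "
    "{{'true' if m_block_size_8 else 'false'}}, {{stages}}, {{group_blocks}}, "
    "{{'true' if is_zp_float else 'false'}}>( MARLIN_KERNEL_PARAMS );"
)

print(
    jinja2.Template(TEMPLATE).render(
        scalar_t="nv_bfloat16",
        w_type_id="MARLIN_NAMESPACE_NAME::kU4.id()",
        threads=256,
        thread_m_blocks=2,
        thread_n_blocks=16,   # 256 // 16
        thread_k_blocks=4,    # 64 // 16
        m_block_size_8=False,
        stages="pipe_stages",
        group_blocks=-1,      # channelwise quantization
        is_zp_float=False,
    )
)
# Prints (as a single line):
# template __global__ void Marlin<nv_bfloat16, MARLIN_NAMESPACE_NAME::kU4.id(),
#   256, 2, 16, 4, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );
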
custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4.cu

Lines changed: 0 additions & 89 deletions
This file was deleted.

custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4b8.cu

Lines changed: 0 additions & 89 deletions
This file was deleted.
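
Hand-written instantiation files like the two deleted above are now produced by the generator, presumably as a pre-build step. A minimal sketch of how a build script might invoke it; the module name generate_kernels.py and the invocation point are assumptions, not shown in this excerpt:

import subprocess
import sys

# Hypothetical pre-build hook: regenerate the marlin kernel_*.cu sources before
# compiling the MoE custom ops. The script name here is an assumption.
subprocess.check_call(
    [sys.executable, "generate_kernels.py"],
    cwd="custom_ops/gpu_ops/moe/moe_wna16_marlin_utils",
)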
