diff --git a/python/sglang/srt/layers/attention/blackwell_prefill_attention_backend.py b/python/sglang/srt/layers/attention/blackwell_prefill_attention_backend.py new file mode 100644 index 00000000000..0772a5056d5 --- /dev/null +++ b/python/sglang/srt/layers/attention/blackwell_prefill_attention_backend.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F + +# from sglang.srt.configs.model_config import AttentionArch +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend + +# from sglang.srt.managers.schedule_batch import global_server_args_dict +# from sglang.srt.mem_cache.memory_pool import SWAKVPool +from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + + +@dataclass +class ForwardMetaData: + cu_seqlens_q: Optional[torch.Tensor] = None + cu_seqlens_k: Optional[torch.Tensor] = None + page_table: Optional[torch.Tensor] = None + + +class BlackwellPrefillAttentionBackend(AttentionBackend): + def __init__(self, model_runner: ModelRunner): + from sglang.srt.layers.attention.cute_ops.prefill_attention import ( + flash_attn_varlen_func, + ) + + super().__init__() + self.flash_attn_func = flash_attn_varlen_func + self.page_size = model_runner.page_size + self.device = model_runner.device + self.forward_metadata: Optional[ForwardMetaData] = None + + def init_forward_metadata(self, forward_batch: ForwardBatch): + assert ( + forward_batch.forward_mode.is_extend() + ), "Only support extend (i.e., prefill) batches." 
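+        # Build the varlen metadata consumed by the CuTe prefill kernel:
+        #   - cu_seqlens_k: cumulative total K lengths (cached prefix + extend)
+        #     per request, left-padded with a zero.
+        #   - page_table: per-request KV token indices, sliced to the longest K length.
+        #   - cu_seqlens_q: identical to cu_seqlens_k unless some request carries a
+        #     cached prefix, in which case it is built from the extend lengths only.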
+ + max_seqlen_k = forward_batch.seq_lens_cpu.max().item() + cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(forward_batch.seq_lens, dim=0, dtype=torch.int32), pad=(1, 0) + ) + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, :max_seqlen_k + ] + + if any(forward_batch.extend_prefix_lens_cpu): + extend_seq_lens = forward_batch.extend_seq_lens + cu_seqlens_q = F.pad( + torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), pad=(1, 0) + ) + else: + cu_seqlens_q = cu_seqlens_k + + self.forward_metadata = ForwardMetaData( + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, page_table=page_table + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + raise RuntimeError("Prefill attention should not be captured in a CUDA graph.") + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], + out_cache_loc: Optional[torch.Tensor] = None, + ): + raise RuntimeError("Prefill attention should not be replayed in a CUDA graph.") + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + sinks: Optional[torch.Tensor] = None, + ): + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + k_cache = k_cache.view(-1, self.page_size, layer.tp_k_head_num, layer.head_dim) + v_cache = v_cache.view(-1, self.page_size, layer.tp_v_head_num, layer.head_dim) + + metadata = self.forward_metadata + out = self.flash_attn_func( + q=q.reshape(-1, layer.tp_q_head_num, layer.head_dim), + k=k_cache, + v=v_cache, + cu_seqlens_q=metadata.cu_seqlens_q, + page_table=metadata.page_table, + softcap=layer.logit_cap, + softmax_scale=layer.scaling, + window_size=(layer.sliding_window_size, 0), + causal=True, + learnable_sink=sinks.to(torch.bfloat16) if sinks is not None else None, + )[0] + + return out.view(-1, layer.tp_q_head_num * layer.head_dim) + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + ): + raise NotImplementedError( + "BlackwellPrefillAttentionBackend does not support forward_decode" + ) + + forward = forward_extend + + def support_triton(self): + return False diff --git a/python/sglang/srt/layers/attention/cute_ops/blackwell_helpers.py b/python/sglang/srt/layers/attention/cute_ops/blackwell_helpers.py new file mode 100644 index 00000000000..508671ccd87 --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/blackwell_helpers.py @@ -0,0 +1,761 @@ +# Copyright (c) 2025, Tri Dao. 
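+# Helpers for issuing SM100 (Blackwell) tcgen05 MMA instructions from the CuTe DSL.
+# `gemm` is a thin wrapper over cute.gemm; the `gemm_ptx*` variants construct the
+# SMEM/TMEM operand descriptors manually and emit the MMA loop as inline PTX.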
+from typing import Optional, Tuple +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import tcgen05 +from cutlass.cutlass_dsl import T +from cutlass._mlir.dialects import llvm + +import sglang.srt.layers.attention.cute_ops.mma_sm100_desc as sm100_desc + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: bool | cutlass.Boolean = False, +) -> None: + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + tiled_mma.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + + +def i64_to_i32x2(i: int) -> Tuple[int, int]: + """Convert a 64-bit integer to a tuple of two 32-bit integers.""" + return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF + + +@cute.jit +def gemm_ptx( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | cutlass.Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if cutlass.const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + assert ( + sA_swizzle is not None + ), "sA_swizzle must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else None + sB_layout = sB.layout + idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) + if cutlass.const_expr(not is_ts): + smem_desc_base_a: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + ( + sm100_desc.Major.K + if cutlass.const_expr( + op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K + ) + else sm100_desc.Major.MN + ), + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + ( + sm100_desc.Major.K + if cutlass.const_expr( + op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K + ) + else sm100_desc.Major.MN + ), + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + + if cutlass.const_expr(not is_ts): + smem_desc_start_a_lo = cutlass.Int32( + smem_desc_base_a_lo + ) | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = cutlass.Int32( + smem_desc_base_b_lo + ) | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if cutlass.const_expr(not is_ts): + smem_desc_a_lo = smem_desc_start_a_lo + ( + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + ) + smem_desc_b_lo = smem_desc_start_b_lo + ( + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + ) + # with cute.arch.elect_one(): + # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo) + # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, 
smem_desc_b_lo_correct) + with cute.arch.elect_one(): + if cutlass.const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + smem_desc_a_lo.ir_value(), + smem_desc_b_lo.ir_value(), + cutlass.Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + ".reg .b32 idesc;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + tCrA[None, None, k].iterator.toint().ir_value(), + smem_desc_b_lo.ir_value(), + cutlass.Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_b;\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_loop( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | cutlass.Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if cutlass.const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + assert ( + sA_swizzle is not None + ), "sA_swizzle must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) + if cutlass.const_expr(not is_ts): + smem_desc_base_a: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + ( + sm100_desc.Major.K + if cutlass.const_expr( + op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K + ) + else sm100_desc.Major.MN + ), + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + ( + sm100_desc.Major.K + if cutlass.const_expr( + op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K + ) + else sm100_desc.Major.MN + ), + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + + if cutlass.const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [ + offset_a[k] - offset_a[k - 1] + 
for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [ + offset_b[k] - offset_b[k - 1] + for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2])) + ] + + if cutlass.const_expr(not is_ts): + smem_desc_start_a_lo = cutlass.Int32( + smem_desc_base_a_lo + | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = cutlass.Int32( + smem_desc_base_b_lo + | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = ( + "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" + ) + if cutlass.const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + cutlass.Int32( + cute.arch.make_warp_uniform(smem_desc_start_a_lo) + ).ir_value(), + cutlass.Int32( + cute.arch.make_warp_uniform(smem_desc_start_b_lo) + ).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 smem_desc_a_lo, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 tmem_a, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # 
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: cutlass.Constexpr[int], + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[cutlass.Int32] = None, + zero_init: bool | cutlass.Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if cutlass.const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + assert ( + sA_swizzle is not None + ), "sA_swizzle must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) + if cutlass.const_expr(not is_ts): + smem_desc_base_a: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + ( + sm100_desc.Major.K + if cutlass.const_expr( + op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K + ) + else sm100_desc.Major.MN + ), + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + ( + sm100_desc.Major.K + if cutlass.const_expr( + op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K + ) + else sm100_desc.Major.MN + ), + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + + tCrA_layout = ( + tCrA.layout + if cutlass.const_expr(not is_ts) + else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout) + ) + offset_a = [ + cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [ + offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2])) + ] + offset_b = [ + cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [ + offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2])) + ] + + if cutlass.const_expr(not is_ts): + smem_desc_start_a_lo = cutlass.Int32( + smem_desc_base_a_lo + | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = cutlass.Int32( + smem_desc_base_b_lo + | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = ( + "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" + ) + if cutlass.const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + 
cutlass.Int32(smem_desc_start_a_lo).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + "mov.b32 smem_desc_a_lo, $0;\n\t" + "mov.b32 smem_desc_b_lo, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + input_args = [ + cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + ] + if cutlass.const_expr(mbar_ptr is not None): + assert ( + mbar_phase is not None + ), "mbar_phase must be provided when mbar_ptr is not None" + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(cutlass.Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$3], $4, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + # cutlass.Int32(smem_desc_start_b_lo).ir_value(), + # cutlass.Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + 
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + 1, + ( + cute.size(tCrA.shape[2]) + if cutlass.const_expr(mbar_ptr is None) + else cute.size(tCrA.shape[2]) // 4 * 3 + ), + ) + ) + + mbar_wait_str + + ( + "".join( + ( + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + cute.size(tCrA.shape[2]) // 4 * 3, cute.size(tCrA.shape[2]) + ) + ) + if cutlass.const_expr(mbar_ptr is not None) + else "" + ) + + "}\n", + # "r,r,r", + "r,r,r" if cutlass.const_expr(mbar_ptr is None) else "r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial1( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: cutlass.Constexpr[int], + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA_base_addr_for_desc: cutlass.Int32, + sA_addr_offset_for_desc: cutlass.Constexpr[int], + sA_stage: cutlass.Int32, + sB_base_addr_for_desc: cutlass.Int32, + sB_addr_offset_for_desc: cutlass.Constexpr[int], + sB_stage: cutlass.Int32, + sA_layout: Optional[cute.Layout], + sB_layout: Optional[cute.Layout], + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | cutlass.Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if cutlass.const_expr(not is_ts): + assert ( + sA_layout is not None + ), "sA_layout must be provided when a_src is not TMEM" + assert ( + sA_swizzle is not None + ), "sA_swizzle must be provided when a_src is not TMEM" + idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) + if cutlass.const_expr(not is_ts): + smem_desc_base_a: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + ( + sm100_desc.Major.K + if cutlass.const_expr( + op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K + ) + else sm100_desc.Major.MN + ), + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = cutlass.const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + ( + sm100_desc.Major.K + if cutlass.const_expr( + op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K + ) + else sm100_desc.Major.MN + ), + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + mask = [cutlass.Int32(0)] * 4 + + if cutlass.const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 + for k in range(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in range(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [ + offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2])) + ] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4 + for k in range(cute.size(tCrB.shape[2])) + ] + offset_b_diff 
= [ + offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2])) + ] + + if cutlass.const_expr(not is_ts): + # smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + smem_desc_start_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + else: + smem_desc_start_a_lo = None + # smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + smem_desc_start_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + pred_str = ( + "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" + ) + if cutlass.const_expr(not is_ts): + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + # cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + cutlass.Int32(sA_base_addr_for_desc).ir_value(), + cutlass.Int32(sA_stage).ir_value(), + # cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + cutlass.Int32(sB_base_addr_for_desc).ir_value(), + cutlass.Int32(sB_stage).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + # "mov.b32 smem_desc_a_lo, $0;\n\t" + # f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t" + f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t" + # "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $4, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" 
+ f"mov.b32 tmem_a, $1;\n\t" + f"mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/python/sglang/srt/layers/attention/cute_ops/block_info.py b/python/sglang/srt/layers/attention/cute_ops/block_info.py new file mode 100644 index 00000000000..7cb783e0f1c --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/block_info.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +from typing import Tuple, Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute + +from sglang.srt.layers.attention.cute_ops.seqlen_info import SeqlenInfo + + +@dataclass(frozen=True) +class BlockInfo: + m_block_size: cutlass.Constexpr[int] + n_block_size: cutlass.Constexpr[int] + is_causal: cutlass.Constexpr[bool] + is_local: cutlass.Constexpr[bool] = False + window_size_left: Optional[cutlass.Int32] = None + window_size_right: Optional[cutlass.Int32] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + + @cute.jit + def get_n_block_min_max( + self, seqlen_info: SeqlenInfo, m_block: cutlass.Int32 + ) -> Tuple[cutlass.Int32, cutlass.Int32]: + n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.n_block_size) + if cutlass.const_expr( + self.is_causal or (self.is_local and self.window_size_right is not None) + ): + m_idx_max = (m_block + 1) * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_right = ( + n_idx + if cutlass.const_expr(self.is_causal) + else n_idx + self.window_size_right + ) + n_block_max = min( + n_block_max, cute.ceil_div(n_idx_right, self.n_block_size) + ) + n_block_min = 0 + if cutlass.const_expr(self.is_local and self.window_size_left is not None): + m_idx_min = m_block * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa + n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_left = n_idx - self.window_size_left + n_block_min = cutlass.max(n_idx_left // self.n_block_size, 0) + return n_block_min, n_block_max + + @cute.jit + def get_n_block_min_causal_local_mask( + self, + seqlen_info: SeqlenInfo, + m_block: cutlass.Int32, + n_block_min: cutlass.Int32, + ) -> cutlass.Int32: + """If we have separate iterations with causal or local masking at the start, where do we stop""" + m_idx_min = m_block * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa + n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_right = ( + n_idx + if 
cutlass.const_expr(not self.is_local or self.window_size_right is None) + else n_idx + self.window_size_right + ) + return cutlass.max(n_block_min, n_idx_right // self.n_block_size) + + @cute.jit + def get_n_block_min_before_local_mask( + self, + seqlen_info: SeqlenInfo, + m_block: cutlass.Int32, + n_block_min: cutlass.Int32, + ) -> cutlass.Int32: + """If we have separate iterations with local masking at the end, where do we stop the non-masked iterations""" + if cutlass.const_expr(not self.is_local or self.window_size_left is None): + return n_block_min + else: + m_idx_max = (m_block + 1) * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_left = n_idx - self.window_size_left + return cutlass.max( + n_block_min, cute.ceil_div(n_idx_left, self.n_block_size) + ) diff --git a/python/sglang/srt/layers/attention/cute_ops/fast_math.py b/python/sglang/srt/layers/attention/cute_ops/fast_math.py new file mode 100644 index 00000000000..0dfd14d26c8 --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/fast_math.py @@ -0,0 +1,103 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Uint32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm + + +@cute.jit +def clz(x: Int32) -> Int32: + # for i in cutlass.range_constexpr(32): + # if (1 << (31 - i)) & x: + # return Int32(i) + # return Int32(32) + # Early exit is not supported yet + res = Int32(32) + done = False + for i in cutlass.range(32): + if ((1 << (31 - i)) & x) and not done: + res = Int32(i) + done = True + return res + + +def find_log2(x: Int32) -> Int32: + a: Int32 = Int32(31 - clz(x)) + return a + ((x & (x - 1)) != 0) # Round up, add 1 if not a power of 2. + + +@dsl_user_op +def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)], + "mul.hi.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +class FastDivmod: + def __init__( + self, + divisor: Int32, + multipler: Uint32, + shift_right: Uint32, + *, + loc=None, + ip=None + ): + self.divisor = divisor + self.multiplier = multipler + self.shift_right = shift_right + self._loc = loc + + # called by host + @staticmethod + def create(divisor: Int32, *, loc=None, ip=None) -> "FastDivmod": + """Construct the FastDivmod object, in host code. + This precomputes some values based on the divisor and is computationally expensive. 
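+        The device-side `div` then computes the quotient as
+        umulhi(dividend, multiplier) >> shift_right (division by an invariant
+        integer via multiplication), and `divmod` derives the remainder from it.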
+ """ + p = Uint32(31 + find_log2(divisor)) + divisor_u32 = Uint32(divisor) + multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32) + shift_right = Uint32(p - 32) + return FastDivmod(divisor, multiplier, shift_right, loc=loc, ip=ip) + + @cute.jit + def div(self, dividend: Int32) -> Int32: + return ( + Int32(umulhi(dividend, self.multiplier) >> self.shift_right) + if self.divisor != 1 + else dividend + ) + + def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]: + quotient = self.div(dividend) + remainder = dividend - quotient * self.divisor + return quotient, remainder + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.divisor, self.multiplier, self.shift_right]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.divisor, self.multiplier, self.shift_right], self._values_pos + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return FastDivmod(*(tuple(obj_list)), loc=self._loc) diff --git a/python/sglang/srt/layers/attention/cute_ops/flash_fwd_sm100.py b/python/sglang/srt/layers/attention/cute_ops/flash_fwd_sm100.py new file mode 100644 index 00000000000..14aef96516c --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/flash_fwd_sm100.py @@ -0,0 +1,2557 @@ +# Supported features: +# - BF16 & FP16 dtype +# - noncausal & causal attention +# - MHA, GQA, MQA +# - hdim 64, 96, 128, (192, 128). +# - varlen +# - sliding window +# Unsupported features that will be added later: +# - split-kv (optimizing for inference) +# - more hdim (192, 256) +# Based on the cutlass example and cute-dsl example: +# https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha +# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py + +import enum +import math +from typing import Callable, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils_basic + +from sglang.srt.layers.attention.cute_ops import blackwell_helpers as sm100_utils +from sglang.srt.layers.attention.cute_ops import utils +from sglang.srt.layers.attention.cute_ops.mask import AttentionMask +from sglang.srt.layers.attention.cute_ops.block_info import BlockInfo +from sglang.srt.layers.attention.cute_ops.pack_gqa import PackGQA +from sglang.srt.layers.attention.cute_ops.softmax import SoftmaxSm100 +from sglang.srt.layers.attention.cute_ops.seqlen_info import SeqlenInfo +from sglang.srt.layers.attention.cute_ops.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + StaticPersistentTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + ParamsBase, +) + + +# class NamedBarrierFwd(enum.IntEnum): +# Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() +# WarpSchedulerWG1 = enum.auto() +# WarpSchedulerWG2 = enum.auto() +# WarpSchedulerWG3 = enum.auto() +# PFull = enum.auto() +# PEmpty = enum.auto() + + +class FlashAttentionForwardSm100: + + arch = 100 + + def __init__( + self, + # dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + 
qhead_per_kvhead: cutlass.Constexpr[int] = 1, + is_causal: bool = False, + is_local: bool = False, + pack_gqa: bool = False, + m_block_size: int = 128, + n_block_size: int = 128, + is_persistent: bool = True, + ): + # self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.head_dim_padded = int( + math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of + ) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + self.head_dim_v_padded = int( + math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of + ) + self.same_hdim_kv_padded = self.head_dim_padded == self.head_dim_v_padded + self.check_hdim_oob = head_dim != self.head_dim_padded + self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + self.m_block_size = m_block_size + self.n_block_size = n_block_size + self.q_stage = 2 + assert self.q_stage in [1, 2] + + # 2 Q tile per CTA + self.cta_tiler = ( + self.q_stage * m_block_size, + n_block_size, + self.head_dim_padded, + ) + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim_padded) + self.mma_tiler_pv = (m_block_size, self.head_dim_v_padded, n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + self.cluster_shape_mn = (1, 1) + self.is_persistent = is_persistent + self.is_causal = is_causal + self.is_local = is_local + self.qhead_per_kvhead = qhead_per_kvhead + self.pack_gqa = pack_gqa + if pack_gqa: + assert ( + m_block_size % self.qhead_per_kvhead == 0 + ), "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" + # Does S1 need to wait for S0 to finish + # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) + self.s0_s1_barrier = False + self.overlap_sO_sQ = ( + self.head_dim_padded == 192 and self.head_dim_v_padded >= 64 + ) + if self.overlap_sO_sQ: + assert ( + self.head_dim_padded >= self.head_dim_v_padded + ) # We assume sQ is larger than sO + self.is_persistent = False + + self.softmax0_warp_ids = (0, 1, 2, 3) + self.softmax1_warp_ids = (4, 5, 6, 7) + self.correction_warp_ids = (8, 9, 10, 11) + self.mma_warp_id = 12 + self.load_warp_id = 13 + self.epilogue_warp_ids = (14,) + self.empty_warp_ids = (15,) + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.threads_per_cta = cute.arch.WARP_SIZE * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + self.mma_warp_id, + self.load_warp_id, + *self.epilogue_warp_ids, + *self.empty_warp_ids, + ) + ) + + self.tmem_alloc_sync_bar_id = 1 + + self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 + self.tmem_o_offset = [ + self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded + for i in range(self.q_stage) + ] # e.g., 256, 384 + self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded + assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS + self.tmem_s_to_p_offset = self.n_block_size // 2 + self.tmem_p_offset = [ + self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2) + ] # 0, 128 + + # vec buffer for row_max & row_sum + self.tmem_vec_offset = self.tmem_s_offset + + if self.head_dim_padded < 96: + self.num_regs_softmax = 200 + self.num_regs_correction = 64 + self.num_regs_other = 48 + else: + self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184 + # self.num_regs_softmax = 176 + # self.num_regs_correction = 96 + # self.num_regs_correction = 80 + # self.num_regs_correction = 64 if 
self.is_causal or self.is_local else 80 + self.num_regs_correction = 64 + # self.num_regs_other = 32 + # self.num_regs_other = 64 + # self.num_regs_other = 80 + # self.num_regs_other = 48 + # self.num_regs_other = 96 if self.is_causal or self.is_local else 80 + self.num_regs_other = 64 if self.is_causal or self.is_local else 80 + self.num_regs_empty = 24 + + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + """Set up configurations and parameters for the FMHA kernel operation. + + This method initializes and configures various attributes required for the + execution of the fused multi-head attention kernel, mainly about the pipeline stages: + + - Sets up staging parameters for Q, K, V inputs and accumulator data + - Configures pipeline stages for softmax, correction, and epilogue operations + """ + + self.kv_stage = 4 if self.q_dtype.width == 8 else 3 + self.acc_stage = 1 + self.epi_stage = 2 + # For hdim 192,128, we don't have enough smem to store all 3 stages of KV: + # 128 x 192 x 2 bytes x 3 stages = 144KB, and we need 96KB for Q. + # Instead we store smem as [smem_large, smem_small, smem_large], where smem_large is + # 128 x 192 and smem_small is 128 x 128. We set the stride between the stages to be + # 128 * 160, so that indexing the 0th and 2nd stages will get the right address, + # but for the 1st stage we need to add or subtract (depending on phase) 128 x 64. + self.uneven_kv_smem = ( + self.head_dim_padded == 192 + and self.head_dim_v_padded == 128 + and self.kv_stage == 3 + ) + self.uneven_kv_smem_offset = ( + self.m_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2 + if self.uneven_kv_smem + else 0 + ) + assert self.uneven_kv_smem_offset % 1024 == 0 + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + learnable_sink: Optional[cute.Tensor] = None, + ): + """Execute the Fused Multi-Head Attention operation on the provided tensors. + + This method prepares the input tensors for processing, validates their shapes and types, + configures the computation parameters, and launches the CUDA kernel. + + The method handles: + 1. Tensor layout transformations for specific memory access patterns + 2. Validation of tensor shapes and data types + 3. Initialization of hardware-specific parameters and memory layouts + 4. Configuration of TMA (Tensor Memory Access) operations + 5. Grid and work scheduling computation + 6. 
Kernel launch with appropriate parameters + """ + + # setup static attributes before smem/grid/tma computation + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.o_dtype = mO.element_type + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mQ, mK, mV, mO = [ + cute.make_tensor( + t.iterator, cute.make_layout(t.shape, stride=new_stride(t)) + ) + for t in (mQ, mK, mV, mO) + ] + QO_layout_transpose = ( + [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + ) + mQ, mO = [ + cute.make_tensor( + t.iterator, cute.select(t.layout, mode=QO_layout_transpose) + ) + for t in (mQ, mO) + ] + # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there's cu_seqlens_k or (page_size, d, h_k, num_pages) if there's page_table + KV_layout_transpose = ( + [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + ) + mK, mV = [ + cute.make_tensor( + t.iterator, cute.select(t.layout, mode=KV_layout_transpose) + ) + for t in (mK, mV) + ] + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = ( + cute.make_tensor( + mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose) + ) + if const_expr(mLSE is not None) + else None + ) + # (s, d, h, b) -> (d, s, h, b) + V_layout_transpose = ( + [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] + ) + mV = cute.make_tensor( + mV.iterator, cute.select(mV.layout, mode=V_layout_transpose) + ) + + self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode() + self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode() + self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode() + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO) + + if const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mQ is not supported") + if const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mK is not supported") + if const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): + raise RuntimeError("The layout of mV is not supported") + + # check type consistency + if const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + self._setup_attributes() + self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None + # This can be tuned + self.e2e_freq = 16 + if const_expr( + self.head_dim_padded > 64 + and not self.is_causal + and not self.is_local + and self.pack_gqa + ): + self.e2e_freq = ( + 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10 + ) + + cta_group = tcgen05.CtaGroup.ONE + # the intermediate tensor p is from tmem & mK-major + p_source = tcgen05.OperandSource.TMEM + p_major_mode = tcgen05.OperandMajorMode.K + tiled_mma_qk = sm100_utils_basic.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.qk_acc_dtype, + cta_group, + self.mma_tiler_qk[:2], + ) + tiled_mma_pv = sm100_utils_basic.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.pv_acc_dtype, + cta_group, + self.mma_tiler_pv[:2], + p_source, + ) + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + 
cute.make_layout(self.cluster_shape_mnk), + (tiled_mma_qk.thr_id.shape,), + ) + + self.epi_tile = self.mma_tiler_pv[:2] + + sQ_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_qk, + self.mma_tiler_qk, + self.q_dtype, + self.q_stage, + ) + sK_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_qk, + self.mma_tiler_qk, + self.k_dtype, + self.kv_stage, + ) + tP_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_pv, + self.mma_tiler_pv, + self.q_dtype, + self.acc_stage, + ) + sV_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_pv, + self.mma_tiler_pv, + self.v_dtype, + self.kv_stage, + ) + sO_layout = sm100_utils_basic.make_smem_layout_epi( + self.o_dtype, + self.o_layout, + self.epi_tile, + self.epi_stage, + ) + if const_expr(not self.same_hdim_kv_padded): + # sK and sV are using the same physical smem so we need to adjust the stride so that they line up + stride_sK = const_expr( + max(sK_layout.outer.stride[-1], 0) + ) # take max to turn tuple to Int32 + stride_sV = const_expr(max(sV_layout.outer.stride[-1], 0)) + stage_stride = const_expr( + max(stride_sK, stride_sV) + if not self.uneven_kv_smem + else (stride_sK + stride_sV) // 2 + ) + sK_layout = cute.make_composed_layout( + sK_layout.inner, + 0, + cute.make_layout( + (*sK_layout.outer.shape[:-1], self.kv_stage), + stride=(*sK_layout.outer.stride[:-1], stage_stride), + ), + ) + sV_layout = cute.make_composed_layout( + sV_layout.inner, + 0, + cute.make_layout( + (*sV_layout.outer.shape[:-1], self.kv_stage), + stride=(*sV_layout.outer.stride[:-1], stage_stride), + ), + ) + + if const_expr(self.pack_gqa): + shape_Q_packed = ( + (self.qhead_per_kvhead, mQ.shape[0]), + mQ.shape[1], + mK.shape[2], + *mQ.shape[3:], + ) + stride_Q_packed = ( + (mQ.stride[2], mQ.stride[0]), + mQ.stride[1], + mQ.stride[2] * self.qhead_per_kvhead, + *mQ.stride[3:], + ) + mQ = cute.make_tensor( + mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) + ) + shape_O_packed = ( + (self.qhead_per_kvhead, mO.shape[0]), + mK.shape[1], + mK.shape[2], + *mO.shape[3:], + ) + stride_O_packed = ( + (mO.stride[2], mO.stride[0]), + mO.stride[1], + mO.stride[2] * self.qhead_per_kvhead, + *mO.stride[3:], + ) + mO = cute.make_tensor( + mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) + ) + if const_expr(mLSE is not None): + shape_LSE_packed = ( + (self.qhead_per_kvhead, mLSE.shape[0]), + mK.shape[2], + *mLSE.shape[2:], + ) + stride_LSE_packed = ( + (mLSE.stride[1], mLSE.stride[0]), + mLSE.stride[1] * self.qhead_per_kvhead, + *mLSE.stride[2:], + ) + mLSE = cute.make_tensor( + mLSE.iterator, + cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed), + ) + + # TMA load for Q + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + + tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mQ, + cute.select(sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + + # TMA load for K + tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + mK, + cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + # TMA load for V + tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + mV, + cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, + tiled_mma_pv, + self.cluster_layout_vmnk.shape, + ) + + o_cta_v_layout = cute.composition( + cute.make_identity_layout(mO.shape), self.epi_tile + ) + 
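+        # Epilogue copy for O: with use_tma_O (no varlen Q) a bulk TMA store atom is
+        # built and a single epilogue warp drives it; otherwise warp 15 is promoted
+        # from the "empty" role to a second epilogue warp and O is written back with
+        # a 128-bit tiled universal copy.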
+ # print(sO_layout.outer) + if const_expr(not self.use_tma_O): + self.epilogue_warp_ids = (14, 15) + self.empty_warp_ids = () + self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) + if const_expr(self.use_tma_O): + tma_atom_O, mO = cpasync.make_tiled_tma_atom( + tma_store_op, + mO, + cute.select(sO_layout, mode=[0, 1]), + o_cta_v_layout, + ) + gmem_tiled_copy_O = None + else: + tma_atom_O = None + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.o_dtype.width + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.o_dtype, + num_bits_per_copy=universal_copy_bits, + ) + tO_shape_dim_1 = sO_layout.outer.shape[1][0] // async_copy_elems + tO_layout = cute.make_ordered_layout( + (self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1), + order=(1, 0), + ) + # So that we don't have to check if we overshoot kBlockM when we store O + assert self.m_block_size % tO_layout.shape[0] == 0 + vO_layout = cute.make_layout((1, async_copy_elems)) + gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, tO_layout, vO_layout + ) + + self.tma_copy_q_bytes = cute.size_in_bytes( + self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2]) + ) + self.tma_copy_k_bytes = cute.size_in_bytes( + self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2]) + ) + self.tma_copy_v_bytes = cute.size_in_bytes( + self.v_dtype, cute.select(sV_layout, mode=[0, 1, 2]) + ) + + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler + else: + if const_expr(self.is_causal or self.is_local): + TileScheduler = SingleTileLPTScheduler + else: + TileScheduler = ( + SingleTileScheduler + if const_expr(not self.is_persistent) + else StaticPersistentTileScheduler + ) + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]), + cute.size(mQ.shape[2]), + ( + cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1) + ), + ( + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1] + ), + mQ.shape[1], + mV.shape[ + 0 + ], # Note that this is different from Sm90 since we transpose mV in Sm100 + total_q=( + cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]) + ), + tile_shape_mn=self.cta_tiler[:2], + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ), + element_size=self.k_dtype.width // 8, + is_persistent=self.is_persistent, + lpt=self.is_causal or self.is_local, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + self.tile_scheduler_cls = TileScheduler + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + + self.mbar_load_q_full_offset = 0 + self.mbar_load_q_empty_offset = self.mbar_load_q_full_offset + self.q_stage + self.mbar_load_kv_full_offset = self.mbar_load_q_empty_offset + self.q_stage + self.mbar_load_kv_empty_offset = self.mbar_load_kv_full_offset + self.kv_stage + self.mbar_P_full_O_rescaled_offset = ( + self.mbar_load_kv_empty_offset + self.kv_stage + ) + self.mbar_S_full_offset = self.mbar_P_full_O_rescaled_offset + 2 + self.mbar_O_full_offset = self.mbar_S_full_offset + 2 + self.mbar_softmax_corr_full_offset = self.mbar_O_full_offset + 2 + self.mbar_softmax_corr_empty_offset = self.mbar_softmax_corr_full_offset + 2 + self.mbar_corr_epi_full_offset = ( + 
self.mbar_softmax_corr_empty_offset + self.epi_stage + ) + self.mbar_corr_epi_empty_offset = ( + self.mbar_corr_epi_full_offset + self.epi_stage + ) + self.mbar_s0_s1_sequence_offset = self.mbar_corr_epi_empty_offset + 2 + self.mbar_tmem_dealloc_offset = self.mbar_s0_s1_sequence_offset + 8 + self.mbar_P_full_2_offset = self.mbar_tmem_dealloc_offset + 1 + self.mbar_total = self.mbar_P_full_2_offset + 2 + + sO_size = cute.cosize(sO_layout) if const_expr(not self.overlap_sO_sQ) else 0 + + @cute.struct + class SharedStorage: + # m_barriers for pipelines + mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mbar_total] + # Tmem holding buffer + tmem_holding_buf: Int32 + # Smem tensors + # store row max and row sum + sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] + sO: cute.struct.Align[ + cute.struct.MemRange[self.o_dtype, sO_size], + self.buffer_align_bytes, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + # cute.cosize(sK_layout) is correct even in the case of self.uneven_kv_smem + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + # Right after this, we multiply by log2(e) before applying exp2. + # To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) + # (assigning it to softmax_scale_log2). + LOG2_E = math.log2(math.e) + if const_expr(softcap is None): + softmax_scale_log2 = softmax_scale * LOG2_E + softcap_val = None + else: + softmax_scale_log2 = softcap * LOG2_E + softcap_val = Float32(softmax_scale / softcap) + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + # Launch the kernel synchronously + self.kernel( + tma_tensor_Q, + tma_tensor_K, + tma_tensor_V, + mO, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_O, + softmax_scale_log2, + softcap_val, + window_size_left, + window_size_right, + learnable_sink, + sQ_layout, + sK_layout, + tP_layout, + sV_layout, + sO_layout, + gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, + tile_sched_params, + ).launch( + grid=grid_dim, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, # (s_q, d, h, b) or (total_q, d, h) if there is cu_seqlens_q + mK: cute.Tensor, # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there is cu_seqlens_k or (page_size, d, h_k, num_pages) if there is page_table + mV: cute.Tensor, # (d, s_k, h_k, b_k) or (d, total_k, h_k) if there is cu_seqlens_k or (d, page_size, h_k, num_pages) if there is page_table + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mPageTable: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_O: Optional[cute.CopyAtom], + softmax_scale_log2: Float32, + softcap_val: 
Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + learnable_sink: Optional[cute.Tensor], + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + tP_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + gmem_tiled_copy_O: Optional[cute.TiledCopy], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tile_sched_params: ParamsBase, + ): + """The device kernel implementation of the Fused Multi-Head Attention. + + This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation: + 1. Load warp: Loads Q, K, V data from global memory to shared memory using TMA + 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) + 3. Softmax warps: Compute softmax normalization on attention scores + 4. Correction warps: Apply adjustments to intermediate results + 5. Epilogue warp: Handles final output transformation and storage + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases, and optional attention masking. + """ + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # Prefetch tma descriptor + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + if const_expr(tma_atom_O is not None): + cpasync.prefetch_descriptor(tma_atom_O) + + # Alloc + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + mbar_ptr = storage.mbar_ptr.data_ptr() + if warp_idx == 1: + # Init "full" barrier with number of producers, "empty" barrier with number of consumers + for i in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_load_q_full_offset + i, + len([self.load_warp_id]), + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_load_q_empty_offset + i, + len([self.mma_warp_id]), + ) + if warp_idx == 2: + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_softmax_corr_empty_offset + i, + cute.arch.WARP_SIZE * 4, + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_softmax_corr_full_offset + i, + cute.arch.WARP_SIZE * 4, + ) + if warp_idx == 3: + if const_expr(self.s0_s1_barrier): + for i in cutlass.range_constexpr(8): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_s0_s1_sequence_offset + i, + cute.arch.WARP_SIZE, + ) + if warp_idx == 4: + for i in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_corr_epi_full_offset + i, + cute.arch.WARP_SIZE * len(self.correction_warp_ids), + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_corr_epi_empty_offset + i, + cute.arch.WARP_SIZE * len(self.epilogue_warp_ids), + ) + if warp_idx == 5: + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, + cute.arch.WARP_SIZE + * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids)), + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id]) + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id]) + ) + if warp_idx == 6: + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_P_full_2_offset + i, + cute.arch.WARP_SIZE * len(self.softmax0_warp_ids), + ) + if warp_idx == 7: + cute.arch.mbarrier_init( + 
mbar_ptr + self.mbar_tmem_dealloc_offset, + cute.arch.WARP_SIZE + * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + ) + ), + ) + # Relying on pipeline_kv constructor to call mbarrier_init_fence and sync + pipeline_kv = self.make_and_init_load_kv_pipeline( + mbar_ptr + self.mbar_load_kv_full_offset + ) + + # Generate smem tensor Q/K/V/O + # (MMA, MMA_Q, MMA_D, PIPE) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + # sQ_pi = storage.sQ.get_tensor(sQ_layout) + # (MMA, MMA_K, MMA_D, PIPE) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + # sK_pi = storage.sK.get_tensor(sK_layout) + # (MMA, MMA_K, MMA_D, PIPE) + # Strip swizzle info to reuse smem + sV = cute.make_tensor( + cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer + ) + if const_expr(not self.overlap_sO_sQ): + sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) + else: + sO = cute.make_tensor( + cute.recast_ptr(sQ.iterator, sO_layout.inner), sO_layout.outer + ) + + sScale = storage.sScale.get_tensor( + cute.make_layout(self.q_stage * self.m_block_size * 2) + ) + + thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM + thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM + + qk_acc_shape = thr_mma_qk.partition_shape_C( + (self.mma_tiler_qk[0], self.mma_tiler_qk[1]) + ) + tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape) + # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always + # request 512 columns of tmem, so we know that it starts at 0. + tmem_ptr = cute.make_ptr( + Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16 + ) + tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout) + + pv_acc_shape = thr_mma_pv.partition_shape_C( + (self.mma_tiler_pv[0], self.mma_tiler_pv[1]) + ) + tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape) + + tStSs = tuple( + cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout) + for stage in range(2) + ) + tOtOs = tuple( + cute.make_tensor(tOtO.iterator + self.tmem_o_offset[stage], tOtO.layout) + for stage in range(self.q_stage) + ) + + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + + tOrPs = [ + cute.make_tensor( + tOrP.iterator + + self.qk_acc_dtype.width + // self.q_dtype.width + * self.tmem_p_offset[stage], + tOrP.layout, + ) + for stage in range(2) + ] + + block_info = BlockInfo( + # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1]) + self.cta_tiler[0], + self.cta_tiler[1], + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ), + ) + SeqlenInfoCls = partial( + SeqlenInfo, + seqlen_q_static=( + mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1] + ), + seqlen_k_static=( + mK.shape[0] + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1] + ), + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + ) + AttentionMaskCls = partial( + AttentionMask, + self.m_block_size, + self.n_block_size, + window_size_left=window_size_left, + window_size_right=window_size_right, + qhead_per_kvhead_packgqa=( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ), + ) + TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) + + # /////////////////////////////////////////////////////////////////////////////// + 
# EMPTY + # /////////////////////////////////////////////////////////////////////////////// + if const_expr(len(self.empty_warp_ids) > 0): + if warp_idx == self.empty_warp_ids[0]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.load( + thr_mma_qk, + thr_mma_pv, + mQ, + mK, + mV, + sQ, + sK, + sV, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_kv, + mbar_ptr, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + # Alloc tmem buffer + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + if warp_idx == self.mma_warp_id: + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.sync_warp() + + self.mma( + tiled_mma_qk, + tiled_mma_pv, + sQ, + sK, + sV, + sQ_layout.inner, + sK_layout.inner, + sV_layout.inner, + tStSs, + tOtOs, + tOrPs, + pipeline_kv, + mbar_ptr, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + # if warp_idx == self.mma_warp_id: + # dealloc tmem buffer + cute.arch.relinquish_tmem_alloc_permit() + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_tmem_dealloc_offset, 0) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + # Retrieving tmem ptr and make acc + tmem_ptr = cute.arch.retrieve_tmem_ptr( + Float32, + alignment=16, + ptr_to_buffer_holding_addr=storage.tmem_holding_buf, + ) + cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.epilogue_warp_ids[0] + and warp_idx <= self.epilogue_warp_ids[-1] + ): + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.epilogue_s2g( + mO, + sO, + gmem_tiled_copy_O, + tma_atom_O, + mbar_ptr, + SeqlenInfoCls, + TileSchedulerCls, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Softmax + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx < self.correction_warp_ids[0]: + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + softmax_loop = partial( + self.softmax_loop, + softmax_scale_log2=softmax_scale_log2, + thr_mma_qk=thr_mma_qk, + sScale=sScale, + mLSE=mLSE, + learnable_sink=learnable_sink, + mbar_ptr=mbar_ptr, + block_info=block_info, + SeqlenInfoCls=SeqlenInfoCls, + AttentionMaskCls=AttentionMaskCls, + TileSchedulerCls=TileSchedulerCls, + ) + + if const_expr(not self.s0_s1_barrier): + stage = Int32(0 if warp_idx < self.softmax1_warp_ids[0] else 1) + softmax_loop( + stage=stage, + tStSi=cute.make_tensor( + tStS.iterator + + ( + self.tmem_s_offset[0] + if stage == 0 + else self.tmem_s_offset[1] + ), + tStS.layout, + ), + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + else: + # If there's s0_s1_barrier, it's faster to have 2 WGs having different code + if warp_idx < self.softmax1_warp_ids[0]: + tStSi = cute.make_tensor( + 
tStS.iterator + self.tmem_s_offset[0], tStS.layout + ) + softmax_loop(stage=0, tStSi=tStSi) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + if ( + warp_idx < self.correction_warp_ids[0] + and warp_idx >= self.softmax1_warp_ids[0] + ): + tStSi = cute.make_tensor( + tStS.iterator + self.tmem_s_offset[1], tStS.layout + ) + softmax_loop(stage=1, tStSi=tStSi) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) + self.correction_loop( + thr_mma_qk, + thr_mma_pv, + tStS, + tOtOs, + sScale, + mO, + mLSE, + sO, + learnable_sink, + tma_atom_O, + mbar_ptr, + softmax_scale_log2, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + + return + + @cute.jit + def load( + self, + thr_mma_qk: cute.core.ThrMma, + thr_mma_pv: cute.core.ThrMma, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + mPageTable: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + pipeline_kv: cutlass.pipeline.PipelineAsync, + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + + q_producer_phase = Int32(1) + kv_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.kv_stage + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[None, None, head_idx, batch_idx] + else: + offset = ( + seqlen.offset_q + if const_expr(not self.pack_gqa) + else (0, seqlen.offset_q) + ) + mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) + gQ = cute.local_tile( + mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0) + ) + + head_idx_kv = ( + head_idx // self.qhead_per_kvhead + if const_expr(not self.pack_gqa) + else head_idx + ) + if const_expr(mPageTable is None): + if const_expr(not seqlen.has_cu_seqlens_k): + mK_cur, mV_cur = [ + t[None, None, head_idx_kv, batch_idx] for t in (mK, mV) + ] + else: + mK_cur = cute.domain_offset( + (seqlen.offset_k, 0), mK[None, None, head_idx_kv] + ) + mV_cur = cute.domain_offset( + (0, seqlen.offset_k), mV[None, None, head_idx_kv] + ) + gK = cute.local_tile( + mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0) + ) + gV = cute.local_tile( + mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None) + ) + else: + # Need to keep batch coord None since we'll index into it with page idx + mK_cur, mV_cur = [t[None, None, head_idx_kv, None] for t in (mK, mV)] + gK = cute.local_tile( + mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None) + ) + gV = cute.local_tile( + mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None) + ) + tSgQ = thr_mma_qk.partition_A(gQ) + tSgK = thr_mma_qk.partition_B(gK) + tOgV = thr_mma_pv.partition_B(gV) + tQsQ, tQgQ = cpasync.tma_partition( + tma_atom_Q, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sQ, 0, 3), + 
cute.group_modes(tSgQ, 0, 3), + ) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + + load_Q = partial( + self.load_Q, + tma_atom_Q, + tQgQ, + tQsQ, + mbar_ptr + self.mbar_load_q_full_offset, + mbar_ptr + self.mbar_load_q_empty_offset, + phase=q_producer_phase, + ) + # We have to use mbarrier directly in the load for KV instead of replying on + # pipeline_kv, because we could have different number of TMA bytes for K and V + load_K = partial( + self.load_KV, + tma_atom_K, + tKgK, + tKsK, + mbar_ptr + self.mbar_load_kv_full_offset, + mbar_ptr + self.mbar_load_kv_empty_offset, + K_or_V="K", + ) + load_V = partial( + self.load_KV, + tma_atom_V, + tVgV, + tVsV, + mbar_ptr + self.mbar_load_kv_full_offset, + mbar_ptr + self.mbar_load_kv_empty_offset, + K_or_V="V", + ) + + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 + page_idx = ( + mPageTable[batch_idx, n_block_max - 1] + if const_expr(mPageTable is not None) + else None + ) + load_K( + block=n_block_max - 1, + producer_state=kv_producer_state, + page_idx=page_idx, + ) # K0 + kv_producer_state.advance() + if const_expr(self.q_stage == 2): + load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 + q_producer_phase ^= 1 + load_V( + block=n_block_max - 1, + producer_state=kv_producer_state, + page_idx=page_idx, + ) # V0 + kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 2 - i + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None) + else None + ) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) + load_K( + block=n_block, producer_state=kv_producer_state, page_idx=page_idx + ) # Ki + kv_producer_state.advance() + load_V( + block=n_block, producer_state=kv_producer_state, page_idx=page_idx + ) # Vi + kv_producer_state.advance() + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def mma( + self, + tiled_mma_qk: cute.core.ThrMma, + tiled_mma_pv: cute.core.ThrMma, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sQ_swizzle: cute.Swizzle, + sK_swizzle: cute.Swizzle, + sV_swizzle: cute.Swizzle, + tStSs: tuple[cute.Tensor, cute.Tensor], + tOtOs: tuple[cute.Tensor], + tOrPs: tuple[cute.Tensor, cute.Tensor], + pipeline_kv: cutlass.pipeline.PipelineAsync, + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM + thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM + tSrQ = thr_mma_qk.make_fragment_A(sQ) + tSrK = thr_mma_qk.make_fragment_B(sK) + tOrV = thr_mma_pv.make_fragment_B(sV) + if const_expr(self.q_stage == 2): + tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 1]) + else: + tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 0]) + + qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op + + gemm_Si = [ + partial( + sm100_utils.gemm_ptx_partial, + qk_mma_op, + self.tmem_s_offset[stage], + tSrQs[stage], + sA=sQ[None, None, None, stage], + sA_swizzle=sQ_swizzle, + 
sB_swizzle=sK_swizzle, + zero_init=True, + ) + for stage in range(2) + ] + gemm_Pi = [ + partial( + sm100_utils.gemm_ptx_partial, + pv_mma_op, + self.tmem_o_offset[stage if self.q_stage == 2 else 0], + tOrPs[stage], + sA=None, + sA_swizzle=None, + sB_swizzle=sV_swizzle, + ) + for stage in range(2) + ] + + mma_q_consumer_phase = Int32(0) + mma_kv_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.kv_stage + ) + P_full_O_rescaled_phase = Int32(0) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + + for stage in cutlass.range_constexpr(self.q_stage): + # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) + # 1. wait for Q0 / Q1 + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_load_q_full_offset + stage, + mma_q_consumer_phase, + ) + # 2. wait for K0 + if const_expr(stage == 0): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] + # We don't need to acquire empty S0 / S1. + # For the first iteration, we don't need to wait as we're guaranteed S0 / S1 + # are empty. For subsequent iterations, the wait happened at the end + # of the while loop. + # 3. gemm + # sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) + sK_cur = sK[None, None, None, mma_kv_consumer_state.index] + if const_expr(self.uneven_kv_smem): + sK_cur = self.offset_kv_smem( + sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase + ) + gemm_Si[stage](tCrB=tSrKi, sB=sK_cur) + # 4. release S0 / S1 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + mma_q_consumer_phase ^= 1 + # 5. release K0 + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + # End of GEMM (Q1 * K0 -> S1) + # Note: Q0 & Q1 are still needed in the seqlen_kv loop + # so we need to release them after the seqlen_kv loop + + # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate + O_should_accumulate = False + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 + pipeline_kv.consumer_wait(mma_kv_consumer_state) + mma_kv_release_state = mma_kv_consumer_state.clone() + Vi_index, Vi_phase = ( + mma_kv_consumer_state.index, + mma_kv_consumer_state.phase, + ) + tOrVi = tOrV[None, None, None, Vi_index] + for stage in cutlass.range_constexpr(2): + # 2. acquire corrected O0/O1_partial and P0 / P1 + # For the first iteration in this work tile, waiting for O0/O1_partial + # means that the correction warps has finished reading tO during + # the last iteration of the previous work tile has finished. + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, + P_full_O_rescaled_phase, + ) + # 3. 
gemm + # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) + # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage]( + tCrB=tOrVi, + sB=sV_cur, + zero_init=not O_should_accumulate, + mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_phase=P_full_O_rescaled_phase, + ) + # 4. release accumulated O0_partial / O1_partial + # Don't need to signal O_full to the correction warps anymore since the + # correction warps wait for the softmax warps anyway. By the time the softmax + # warps finished, S_i for the next iteration must have been done, so O_i-1 + # must have been done as well. + # with cute.arch.elect_one(): + # tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # 5. release V(i-1) + if const_expr(stage == 1): + pipeline_kv.consumer_release(mma_kv_release_state) + mma_kv_release_state.advance() + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + + # GEMM_QK0i (Q0 * Ki -> S0) + # 1. wait for Ki + if const_expr(stage == 0): + mma_kv_consumer_state.advance() + pipeline_kv.consumer_wait(mma_kv_consumer_state) + Ki_index, Ki_phase = ( + mma_kv_consumer_state.index, + mma_kv_consumer_state.phase, + ) + # 2. gemm + # Don't need to wait for the softmax warp to have finished reading the previous + # Si, since this gemm is scheduled after the PV gemm, which guaranteed that Si + # has been read and Pi has been written. + # sm100_utils.gemm(tiled_mma_qk, tStS0, tSrQs[0], tSrK[None, None, None, Ki_index], zero_init=True) + sK_cur = sK[None, None, None, Ki_index] + if const_expr(self.uneven_kv_smem): + sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase) + gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur) + # 3. release S0 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + # End of GEMM_QK0i (Q0 * Ki -> S0) + # 4. release Ki + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + P_full_O_rescaled_phase ^= 1 + O_should_accumulate = True + # End of seqlen_kv loop + + # release Q0 & Q1 + with cute.arch.elect_one(): + for stage in cutlass.range_constexpr(self.q_stage): + tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + stage) + + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 + pipeline_kv.consumer_wait(mma_kv_consumer_state) + Vi_index, Vi_phase = ( + mma_kv_consumer_state.index, + mma_kv_consumer_state.phase, + ) + tOrVi = tOrV[None, None, None, Vi_index] + for stage in cutlass.range_constexpr(2): + # 2. acquire corrected Oi_partial and Pi + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, + P_full_O_rescaled_phase, + ) + # 3. gemm + # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) + # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage]( + tCrB=tOrVi, + sB=sV_cur, + zero_init=not O_should_accumulate, + mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_phase=P_full_O_rescaled_phase, + ) + # 4. 
release accumulated O0_partial + # We do need O_full here since for the last tile, by the time the softmax warp + # has signaled to the correction warp, the softmax warp has just finished compute + # the row sum of the current tile. It does not guarantee that the 1st tile + # of the next work tile has been computed yet. + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + P_full_O_rescaled_phase ^= 1 + # 5. release Vi_end + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + # for both softmax0 and softmax1 warp group + @cute.jit + def softmax_loop( + self, + stage: int | Int32, + softmax_scale_log2: Float32, + thr_mma_qk: cute.core.ThrMma, + tStSi: cute.Tensor, + sScale: cute.Tensor, + mLSE: Optional[cute.Tensor], + learnable_sink: Optional[cute.Tensor], + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + ): + """Compute softmax on attention scores from QK matrix multiplication. + + This method handles the softmax computation for either the first or second half of the + attention matrix, depending on the 'stage' parameter. It calculates row-wise maximum + and sum values needed for stable softmax computation, applies optional masking, and + transforms raw attention scores into probability distributions. + + The implementation uses specialized memory access patterns and efficient math operations + for computing exp(x) using exp2 functions. It also coordinates pipeline + synchronization between MMA, correction, and sequence processing stages. 
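+ Note on scaling: the host pre-multiplies the score scale by log2(e) (softmax_scale_log2),
+ so the exponentials are evaluated with exp2. When softcapping is disabled,
+ exp2((s - m) * softmax_scale_log2) equals exp((s - m) * softmax_scale), and exp2 maps
+ directly onto a hardware instruction.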
+ """ + tidx = cute.arch.thread_idx()[0] % ( + cute.arch.WARP_SIZE + # * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids) + * (len(self.softmax0_warp_ids)) + ) + + cS_base = cute.make_identity_tensor( + (self.mma_tiler_qk[0], self.mma_tiler_qk[1]) + ) + tScS = thr_mma_qk.partition_C(cS_base) + + tStS_scale_layout = cute.composition( + tStSi.layout, cute.make_layout((self.m_block_size, 1)) + ) + tStScale = cute.make_tensor(tStSi.iterator, tStS_scale_layout) + tScS_vec_layout = cute.composition( + tScS.layout, cute.make_layout((self.m_block_size, 1)) + ) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + + tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width + tStP_layout = cute.composition( + tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32)) + ) + tStP = cute.make_tensor(tStSi.iterator + self.tmem_s_to_p_offset, tStP_layout) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + Float32, + ) + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi).get_slice(tidx) + tStS_t2r = thr_tmem_load.partition_S(tStSi) + + tmem_store_scale_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), + Float32, + ) + thr_tmem_store_scale = tcgen05.make_tmem_copy( + tmem_store_scale_atom, tStScale + ).get_slice(tidx) + + tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) + tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScS_vec).shape + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), + Float32, + ) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP) + thr_tmem_store = tiled_tmem_store.get_slice(tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + mma_si_consumer_phase = Int32(0) + si_corr_producer_phase = Int32(1) + s0_s1_sequence_phase = Int32(1 if stage == 0 else 0) + + # self.warp_scheduler_barrier_init() + + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + mbar_s0_s1_sequence_offset = self.mbar_s0_s1_sequence_offset + warp_idx_in_wg + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + mask_fn = partial( + mask.apply_mask_sm100, + m_block=m_block * 2 + stage, + thr_mma=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + mask_causal=self.is_causal, + mask_local=self.is_local, + ) + softmax = SoftmaxSm100( + softmax_scale_log2, + rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, + ) + softmax.reset() + + softmax_step = partial( + self.softmax_step, + softmax=softmax, + mbar_ptr=mbar_ptr, + mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset, + thr_mma_qk=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + thr_tmem_store=thr_tmem_store, + thr_tmem_store_scale=thr_tmem_store_scale, + tStS_t2r=tStS_t2r, + tStScale_r2t=tStScale_r2t, + tStP_r2t=tStP_r2t, + sScale=sScale, + stage=stage, + ) + + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, + si_corr_producer_phase, + ) + si_corr_producer_phase ^= 1 + + # 1 masking iter + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block_max - 1, + is_first=True, + 
mask_fn=partial(mask_fn, mask_seqlen=True), + ) + ) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = ( + block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + ) + for n_tile in cutlass.range( + n_block_max - n_block_min_causal_local_mask, unroll=1 + ): + n_block = n_block_max - 1 - n_tile + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = ( + block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + ) + for n_tile in cutlass.range( + n_block_max - n_block_min_before_local_mask, unroll=1 + ): + n_block = n_block_max - n_tile - 1 + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + ) + ) + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): + n_block = n_block_max - 1 - n_tile + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + # Now that we no longer already have the 1st iteration, need mask_seqlen=True here + + # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, Float32) + # tSrScale_r2t[0] = softmax.row_sum[0] + # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) + # cute.arch.fence_view_async_tmem_store() + sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = ( + softmax.row_max[0] + ) + # if tidx == 0: + # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_softmax_corr_full_offset + stage + ) + # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) + + # # Write LSE to gmem + # if const_expr(mLSE is not None): + # acc_O_mn_row_is_zero_or_nan = softmax.row_sum[0] == 0.0 or softmax.row_sum[0] != softmax.row_sum[0] + # scale = ( + # cute.arch.rcp_approx(softmax.row_sum[0] if not acc_O_mn_row_is_zero_or_nan else 1.0) + # ) + # LN2 = math.log(2.0) + # lse = ( + # (softmax.row_max[0] * softmax.scale_log2 + utils.log2f(softmax.row_sum[0])) * LN2 + # if not acc_O_mn_row_is_zero_or_nan else -Float32.inf + # ) + # if const_expr(not seqlen.has_cu_seqlens_q): + # mLSE_cur = mLSE[None, head_idx, batch_idx] + # else: + # mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) + # gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2 + stage,)) + # if tidx < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size: + # gLSE[tidx] = lse + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent 
scheduler loop + + @cute.jit + def softmax_step( + self, + mma_si_consumer_phase: Int32, + si_corr_producer_phase: Int32, + s0_s1_sequence_phase: Int32, + n_block: Int32, + softmax: SoftmaxSm100, + mbar_ptr: cute.Pointer, + mbar_s0_s1_sequence_offset: Int32, + thr_mma_qk: cute.core.ThrMma, + thr_tmem_load: cute.CopyAtom, + thr_tmem_store: cute.CopyAtom, + thr_tmem_store_scale: cute.CopyAtom, + tStS_t2r: cute.Tensor, + tStScale_r2t: cute.Tensor, + tStP_r2t: cute.Tensor, + sScale: cute.Tensor, + stage: int | Int32, + mask_fn: Optional[Callable] = None, + is_first: bool = False, + ) -> tuple[cute.Int32, cute.Int32, cute.Int32]: + """Perform a single step of the softmax computation on a block of attention scores. + + This method processes one block of the attention matrix, computing numerically stable + softmax by first finding the row maximum, subtracting it from all elements, applying + exponential function, and then normalizing by the sum of exponentials. It also handles + optional masking of attention scores. + + The method involves several key operations: + 1. Loading attention scores from tensor memory + 2. Applying optional masking based on position + 3. Computing row-wise maximum values for numerical stability + 4. Transforming scores using exp2(x*scale - max*scale) + 5. Computing row sums for normalization + 6. Coordinating pipeline synchronization between different processing stages + """ + tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width + tScS = thr_mma_qk.partition_C( + cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1])) + ) + tScS_vec_layout = cute.composition( + tScS.layout, cute.make_layout((self.m_block_size, 1)) + ) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + + tScP_layout = cute.composition( + tScS.layout, cute.make_layout((self.m_block_size, tilePlikeFP32)) + ) + tScP = cute.make_tensor(tScS.iterator, tScP_layout) + tScS_t2r_shape = thr_tmem_load.partition_D(tScS).shape + + # Wait for Si + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_S_full_offset + stage, mma_si_consumer_phase + ) + tSrS_t2r = cute.make_fragment(tScS_t2r_shape, self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + if const_expr(mask_fn is not None): + mask_fn(tSrS_t2r, n_block=n_block) + row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first) + + if const_expr(not is_first): + # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScS_vec).shape, Float32) + # tSrScale_r2t[0] = acc_scale + # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) + # cute.arch.fence_view_async_tmem_store() + thread_idx = thr_tmem_load.thr_idx + sScale[thread_idx + stage * self.m_block_size] = acc_scale + # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max) + # Notify correction wg that row_max is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + + # if thread_idx == 0 and stage == 0: cute.print_tensor(tSrS_t2r) + # print(tSrS_t2r) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + # Sequence barrier wait + if const_expr(self.s0_s1_barrier): + cute.arch.mbarrier_wait( + mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase + ) + tSrP_r2t_f32 = cute.make_fragment( + thr_tmem_store.partition_S(tScP).shape, Float32 + ) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout, + ) + # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, 
tSrP_r2t) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + e2e=mask_fn is None and self.head_dim_padded <= 128, + e2e_freq=self.e2e_freq, + ) + # Sequence barrier arrive + if const_expr(self.s0_s1_barrier): + cute.arch.mbarrier_arrive( + mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4 + ) + # print(tSrP_r2t_f32, tStP_r2t) + # cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t) + for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3): + cute.copy( + thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i] + ) + cute.arch.fence_view_async_tmem_store() + # Notify mma warp that P is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + for i in cutlass.range_constexpr( + cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2]) + ): + cute.copy( + thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i] + ) + cute.arch.fence_view_async_tmem_store() + # Notify mma warp that the 2nd half of P is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_2_offset + stage) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, + si_corr_producer_phase, + ) + softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) + # acc_scale = cute.arch.exp2(acc_scale_) + return ( + mma_si_consumer_phase ^ 1, + si_corr_producer_phase ^ 1, + s0_s1_sequence_phase ^ 1, + ) + + @cute.jit + def correction_loop( + self, + thr_mma_qk: cute.core.ThrMma, + thr_mma_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOtOs: tuple[cute.Tensor], + sScale: cute.Tensor, + mO: cute.Tensor, + mLSE: cute.Tensor, + sO: cute.Tensor, + learnable_sink: Optional[cute.Tensor], + tma_atom_O: cute.CopyAtom, + mbar_ptr: cute.Pointer, + softmax_scale_log2: Float32, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + tScS = thr_mma_qk.partition_C( + cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1])) + ) + tStS_scale_layout = cute.composition( + tStS.layout, cute.make_layout((self.m_block_size, 1)) + ) + tStScales = tuple( + cute.make_tensor( + tStS.iterator + self.tmem_vec_offset[stage], tStS_scale_layout + ) + for stage in range(2) + ) + tScS_vec_layout = cute.composition( + tScS.layout, cute.make_layout((self.m_block_size, 1)) + ) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + tmem_load_v_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), + self.qk_acc_dtype, + ) + tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]) + tidx = cute.arch.thread_idx()[0] % ( + cute.arch.WARP_SIZE * len(self.correction_warp_ids) + ) + thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(tidx) + + tStScales_t2r = [ + thr_tmem_load_vec.partition_S(tStScales[stage]) for stage in range(2) + ] + tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScS_vec).shape + + # First iter: no correction is required + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 0) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 1) + + softmax_corr_consumer_phase = Int32(0) + o_corr_consumer_phase = Int32(0) + corr_epi_producer_phase = Int32(1) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + + # Ignore first signal from softmax as no 
correction is required + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 0, + softmax_corr_consumer_phase, + ) + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_softmax_corr_empty_offset + 0 + ) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 1, + softmax_corr_consumer_phase, + ) + softmax_corr_consumer_phase ^= 1 + + tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) + for i in cutlass.range(n_block_max - n_block_min - 1, unroll=1): + for stage in cutlass.range_constexpr(2): + # wait for S0 / S1 + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) + # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) + # cute.arch.fence_view_async_tmem_load() + # scale = tSrScale_t2r[0] + scale = sScale[tidx + stage * self.m_block_size] + should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 + # should_rescale = True + # if tidx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) + # Don't need O_full anymore, since by the time softmax has signaled the correction + # warps, S_i must have been done, so O_i-1 must have been done as well. + # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) + if should_rescale: + self.correction_rescale( + thr_mma_pv, + tOtOs[stage if self.q_stage == 2 else 0], + tidx, + scale, + ) + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage + ) + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage) + ) + softmax_corr_consumer_phase ^= 1 + # o_corr_consumer_phase ^= 1 + # End of seqlen_corr_loop_steps + + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_softmax_corr_empty_offset + 1 + ) + + # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without + # additional sync because the MMA in the top half must have been done. + # Similarly we can write to stage 1 of sO without additional sync. 
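+ # Final normalization for this work tile: each output row is scaled by 1 / row_sum
+ # (via rcp_approx) in correction_epilogue below. With a learnable sink, the denominator
+ # first gains an extra term exp2(sink * log2(e) - row_max * softmax_scale_log2), and LSE
+ # is recovered as (row_max * softmax_scale_log2 + log2(row_sum)) * ln(2).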
+ stats = [None] * self.q_stage + learnable_sink_val = [None] * self.q_stage + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + learnable_sink_val = [sink_val] * self.q_stage + else: # Each thread might have a different sink value due to different q_head + for stage in cutlass.range_constexpr(self.q_stage): + q_head_idx = ( + (self.q_stage * m_block + stage) * self.m_block_size + tidx + ) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx]) + for stage in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) + # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) + # cute.arch.fence_view_async_tmem_load() + # scale = tSrScale_t2r[0] + row_sum = sScale[tidx + stage * self.m_block_size] + if const_expr(mLSE is not None or learnable_sink is not None): + row_max = sScale[ + tidx + stage * self.m_block_size + self.m_block_size * 2 + ] + else: + row_max = None + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage + ) + if const_expr(learnable_sink is not None): + LOG2_E = math.log2(math.e) + row_sum += utils.exp2f( + learnable_sink_val[stage] * LOG2_E + - row_max * softmax_scale_log2 + ) + acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) + scale = cute.arch.rcp_approx( + row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0 + ) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase + ) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_empty_offset + stage, + corr_epi_producer_phase, + ) + self.correction_epilogue( + thr_mma_pv, + tOtOs[stage], + tidx, + scale, + sO[None, None, stage], + ) + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_corr_epi_full_offset + stage + ) + # Signal for the next work tile that O buffers in tmem are already read, so + # mma warp can write to them + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage + ) + # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[None, head_idx, batch_idx] + else: + offset = ( + seqlen.offset_q + if const_expr(not self.pack_gqa) + else (0, seqlen.offset_q) + ) + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) + for stage in cutlass.range_constexpr(self.q_stage): + gLSE = cute.local_tile( + mLSE_cur, + (self.m_block_size,), + (self.q_stage * m_block + stage,), + ) + row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] + # if tidx == 0 and stage <= 1: + # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) + LN2 = math.log(2.0) + lse = ( + (row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + seqlen_q = ( + seqlen.seqlen_q + if const_expr(not self.pack_gqa) + else seqlen.seqlen_q * self.qhead_per_kvhead + ) + if ( + tidx + < seqlen_q + - (self.q_stage * m_block + stage) * self.m_block_size + ): + # This actually just works with PackGQA too + gLSE[tidx] = lse + + o_corr_consumer_phase ^= 1 + softmax_corr_consumer_phase ^= 1 + corr_epi_producer_phase ^= 1 + + # gO_qdhb = cute.local_tile(mO, 
cute.select(self.mma_tiler_pv, mode=[0, 1]), (None, 0, None, None)) + # gO = gO_qdhb[None, None, None, head_idx, batch_idx] + # tOsO, tOgO = cpasync.tma_partition( + # tma_atom_O, + # 0, + # cute.make_layout(1), + # cute.group_modes(sO, 0, 2), + # cute.group_modes(gO, 0, 2), + # ) + # warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + # stage = warp_idx_in_wg + # if stage < self.q_stage: + # # wait from corr, issue tma store on smem + # # 1. wait for O0 / O1 final + # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, corr_epi_producer_phase) + # # 2. copy O0 / O1 to gmem + # cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, self.q_stage * m_block + stage]) + # cute.arch.cp_async_bulk_commit_group() + # # Ensure O0 / O1 buffer is ready to be released + # cute.arch.cp_async_bulk_wait_group(0, read=True) + # cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def correction_rescale( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + thread_idx: Int32, + scale: Float32, + ): + """Rescale intermediate attention results based on softmax normalization factor. + + This method performs a crucial correction step in the attention computation pipeline. + When processing attention in blocks, the softmax normalization factors may change + as new blocks are processed. This method rescales previously computed partial + output values to account for updated normalization factors. + + The implementation uses efficient tensor memory operations to: + 1. Load existing partial attention output from tensor memory + 2. Apply the scaling factor to all elements + 3. 
Store the rescaled results back to tensor memory + """ + cO = cute.make_identity_tensor((self.mma_tiler_pv[0], self.mma_tiler_pv[1])) + tOcO = thr_mma.partition_C(cO) + + corr_tile_size = 16 # tuneable parameter + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + + tOtO_i_layout = cute.composition( + tOtO.layout, cute.make_layout((self.m_block_size, corr_tile_size)) + ) + tOcO_i_layout = cute.composition( + tOcO.layout, cute.make_layout((self.m_block_size, corr_tile_size)) + ) + tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) + + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) + + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i) + tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape + tOtO_r2t = thr_tmem_store.partition_D(tOtO_i) + + frg_count = self.head_dim_v_padded // corr_tile_size + tOrO_frg = cute.make_fragment((tOrO_t2r_shape, frg_count), self.pv_acc_dtype) + for i in cutlass.range_constexpr(frg_count): + tOrO_frg_i = tOrO_frg[None, i] + tTMrO_i_layout = cute.composition( + tOrO_frg_i.layout, cute.make_layout(tOrO_frg.shape[0]) + ) + tTMrO_i = cute.make_tensor(tOrO_frg_i.iterator, tTMrO_i_layout) + tOtO_t2r_i = cute.make_tensor( + tOtO_t2r.iterator + i * corr_tile_size, tOtO_t2r.layout + ) + cute.copy(tiled_tmem_load, tOtO_t2r_i, tTMrO_i) + for j in cutlass.range_constexpr(0, cute.size(tTMrO_i), 2): + tTMrO_i[j], tTMrO_i[j + 1] = cute.arch.mul_packed_f32x2( + (tTMrO_i[j], tTMrO_i[j + 1]), + (scale, scale), + ) + tOtO_r2t_i = cute.make_tensor( + tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout + ) + cute.copy(tiled_tmem_store, tTMrO_i, tOtO_r2t_i) + cute.arch.fence_view_async_tmem_store() + + @cute.jit + def correction_epilogue( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + thread_idx: Int32, + scale: Float32, + sO: cute.Tensor, + ): + """Apply final scaling and transformation to attention output before writing to global memory. + + This correction_epilogue function handles the final processing step for attention output values. + It applies a scaling factor to the accumulated attention results and prepares the + data for efficient transfer back to global memory. + + The method performs: + 1. Loading of accumulated attention results from tensor memory + 2. Application of the final output scaling factor + 3. Type conversion if necessary (typically from higher precision accumulator to output precision) + 4. Reorganization of data for optimal memory access patterns + 5. 
Preparation for efficient TMA store operations + + :param thr_mma: Thread MMA operation for the computation + :type thr_mma: cute.core.ThrMma + :param tOtO: Tensor containing accumulated attention output + :type tOtO: cute.Tensor + :param scale: Final scaling factor to apply to the output + :type scale: Float32 + :param sO: Shared memory tensor for the final output + :type sO: cute.Tensor + """ + + cO = cute.make_identity_tensor((self.mma_tiler_pv[0], self.mma_tiler_pv[1])) + corr_tile_size = 32 * 8 // self.o_dtype.width + tOsO = thr_mma.partition_C(sO) + tOcO = thr_mma.partition_C(cO) + + tOtO_i = cute.logical_divide( + tOtO, cute.make_layout((self.m_block_size, corr_tile_size)) + ) + tOcO_i = cute.logical_divide( + tOcO, cute.make_layout((self.m_block_size, corr_tile_size)) + ) + tOsO_i = cute.logical_divide( + tOsO, cute.make_layout((self.m_block_size, corr_tile_size)) + ) + + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_copy_atom = sm100_utils_basic.get_tmem_load_op( + self.mma_tiler_pv, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=False, + ) + + tiled_tmem_load = tcgen05.make_tmem_copy( + tmem_copy_atom, tOtO_i[(None, None), 0] + ) + + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + smem_copy_atom = sm100_utils_basic.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load + ) + tiled_smem_store = cute.make_tiled_copy( + smem_copy_atom, + layout_tv=tiled_tmem_load.layout_dst_tv_tiled, + tiler_mn=tiled_tmem_load.tiler_mn, + ) + + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) + tOsO_s2r = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) + tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) + + for i in cutlass.range_constexpr(self.head_dim_v_padded // corr_tile_size): + tOtO_t2r_i = tOtO_t2r[None, 0, 0, i] + tOsO_r2s_i = tOsO_s2r[None, 0, 0, i] + tOrO_frg = cute.make_fragment( + tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype + ) + cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg) + for j in cutlass.range_constexpr(0, cute.size(tOrO_frg), 2): + tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), + ) + tSMrO = cute.make_fragment(tOrO_frg.shape, self.o_dtype) + o_vec = tOrO_frg.load() + tSMrO.store(o_vec.to(self.o_dtype)) + cute.copy(tiled_smem_store, tSMrO, tOsO_r2s_i) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + @cute.jit + def epilogue_s2g( + self, + mO: cute.Tensor, + sO: cute.Tensor, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + mbar_ptr: cute.Pointer, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + epi_consumer_phase = Int32(0) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(not seqlen.has_cu_seqlens_q): + mO_cur = mO[None, None, head_idx, batch_idx] + else: + offset = ( + seqlen.offset_q + if const_expr(not self.pack_gqa) + else (0, seqlen.offset_q) + ) + mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx]) + gO = cute.local_tile( + mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0) + ) + if const_expr(self.use_tma_O): + tOsO, tOgO = cpasync.tma_partition( + tma_atom_O, + 0, + cute.make_layout(1), + cute.group_modes(sO, 0, 2), + cute.group_modes(gO, 0, 2), 
+ ) + for stage in cutlass.range_constexpr(self.q_stage): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_full_offset + stage, + epi_consumer_phase, + ) + # 2. copy O0 / O1 to gmem + cute.copy( + tma_atom_O, + tOsO[None, stage], + tOgO[None, self.q_stage * m_block + stage], + ) + cute.arch.cp_async_bulk_commit_group() + for stage in cutlass.range_constexpr(self.q_stage): + # Ensure O0 / O1 buffer is ready to be released + cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_corr_epi_empty_offset + stage + ) + else: + tidx = cute.arch.thread_idx()[0] % ( + cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) + ) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + cO = cute.make_identity_tensor( + (self.m_block_size, self.head_dim_v_padded) + ) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it + assert not self.pack_gqa + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_v_padded, + self.check_hdim_v_oob, + self.qhead_per_kvhead, + ) + for stage in cutlass.range_constexpr(self.q_stage): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_full_offset + stage, + epi_consumer_phase, + ) + # 2. copy O0 / O1 to gmem + # load acc O from smem to rmem for wider vectorization + tOrO = cute.make_fragment_like( + tOsO[None, None, None, 0], self.o_dtype + ) + cute.autovec_copy(tOsO[None, None, None, stage], tOrO) + # copy acc O from rmem to gmem + if const_expr(not self.pack_gqa): + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if ( + t0OcO[0, rest_m, 0][0] + < seqlen.seqlen_q + - (self.q_stage * m_block + stage) * self.m_block_size + - tOcO[0][0] + ): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[ + None, + rest_m, + None, + self.q_stage * m_block + stage, + ], + pred=( + tOpO[None, rest_m, None] + if self.check_hdim_v_oob + else None + ), + ) + else: + pack_gqa.store_O( + mO_cur, + tOrO, + gmem_tiled_copy_O, + tidx, + self.q_stage * m_block + stage, + seqlen.seqlen_q, + ) + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_corr_epi_empty_offset + stage + ) + + # Advance to next tile + epi_consumer_phase ^= 1 + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + def load_Q( + self, + tma_atom: cute.CopyAtom, + tQgQ: cute.Tensor, + tQsQ: cute.Tensor, + mbar_full_ptr: cute.Pointer, + mbar_empty_ptr: cute.Pointer, + block: Int32, + stage: int, + phase: Int32, + ): + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_full_ptr + stage, self.tma_copy_q_bytes + ) + cute.copy( + tma_atom, + tQgQ[None, block], + tQsQ[None, stage], + tma_bar_ptr=mbar_full_ptr + stage, + ) + + @cute.jit + def load_KV( + self, + tma_atom: cute.CopyAtom, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + mbar_full_ptr: cute.Pointer, + mbar_empty_ptr: cute.Pointer, + block: Int32, + producer_state: cutlass.pipeline.PipelineState, + K_or_V: str, + page_idx: Optional[Int32] = None, + ): + assert K_or_V in ("K", "V") + tma_copy_bytes = ( + self.tma_copy_k_bytes + if const_expr(K_or_V == "K") + else 
self.tma_copy_v_bytes + ) + stage, phase = producer_state.index, producer_state.phase + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) + if const_expr(K_or_V == "K" and self.uneven_kv_smem): + # Before this round, the smem location was occupied by V, which is smaller than + # K. So we need to wait for the stage after that (stage 1) to be empty as well. + if stage == 0: + cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_full_ptr + stage, tma_copy_bytes + ) + tXsX_cur = tXsX[None, stage] + if const_expr(self.uneven_kv_smem): + # Since this is the producer_state, the phase starts at 1, so we have to invert it + tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1) + # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0 + tXgX_cur = ( + tXgX[None, block] + if const_expr(page_idx is None) + else tXgX[None, 0, page_idx] + ) + cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage) + + @cute.jit + def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): + if const_expr(self.uneven_kv_smem): + # smem layout is [smem_large, smem_small, smem_large], and the current stride is + # (smem_large + smem_small) // 2. So for stage == 1, move right by offset if + # phase == 0, or left by offset if phase == 1. + offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase) + return cute.make_tensor(sX.iterator + offset, sX.layout) + else: + return sX + + def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): + load_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_kv_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return cutlass.pipeline.PipelineTmaUmma.create( + barrier_storage=load_kv_mbar_ptr, + num_stages=self.kv_stage, + producer_group=load_kv_producer_group, + consumer_group=load_kv_consumer_group, + tx_count=self.tma_copy_k_bytes, + ) + + # @cute.jit + # def warp_scheduler_barrier_init(self): + # warp_group_idx = utils.canonical_warp_group_idx(sync=False) + # if warp_group_idx == 0: + # cute.arch.barrier_arrive( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * 128, + # ) + + # def warp_scheduler_barrier_sync(self): + # cute.arch.barrier( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False), + # number_of_threads=2 * 128 + # ) + + # def warp_scheduler_barrier_arrive(self): + # cur_wg = utils.canonical_warp_group_idx(sync=False) + # next_wg = 1 - cur_wg + # cute.arch.barrier_arrive( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, + # ) diff --git a/python/sglang/srt/layers/attention/cute_ops/mask.py b/python/sglang/srt/layers/attention/cute_ops/mask.py new file mode 100644 index 00000000000..1e858c4bbda --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/mask.py @@ -0,0 +1,277 @@ +# Copyright (c) 2025, Tri Dao. 
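+#
+# A note on the masking arithmetic below (illustrative numbers, not used by the code):
+# causal masking is bottom-right aligned, i.e. query row i may attend to key columns
+# j <= i + (seqlen_k - seqlen_q). Inside a tile that starts at column n_block * n_block_size
+# this becomes an exclusive right limit
+#     col_limit_right = row_idx + causal_row_offset
+# with causal_row_offset = 1 + seqlen_k - n_block * n_block_size - seqlen_q - thr_col_offset.
+# Example: seqlen_q = 4, seqlen_k = 6, n_block = 0, thr_col_offset = 0 and global row
+# row_idx = 2 give causal_row_offset = 1 + 6 - 0 - 4 = 3 and col_limit_right = 5, so key
+# columns 0..4 are kept and columns >= 5 are set to -inf.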
+ +from typing import Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute + +from sglang.srt.layers.attention.cute_ops import utils + + +@dataclass(frozen=True) +class AttentionMask: + m_block_size: cutlass.Constexpr[int] + n_block_size: cutlass.Constexpr[int] + seqlen_q: cutlass.Int32 + seqlen_k: cutlass.Int32 + window_size_left: Optional[cutlass.Int32] = None + window_size_right: Optional[cutlass.Int32] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = ( + 1 # only pass in if we're doing PackGQA + ) + + @cute.jit + def apply_mask( + self, + acc_S: cute.Tensor, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + thr_mma: cute.TiledMma, + mask_seqlen: cutlass.Constexpr[bool], + mask_causal: cutlass.Constexpr[bool], + mask_local: cutlass.Constexpr[bool] = False, + ) -> None: + assert not ( + mask_causal and mask_local + ), "mask_causal and mask_local cannot be both True" + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) + cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS)) + # We use t0ScS as these indices are known at compile time. We then must subtract the + # column limit by the thread column offset. + t0ScS_mn = utils.make_acc_tensor_mn_view(thr_mma.get_slice(0).partition_C(cS)) + thr_col_offset = tScS_mn[0][1] + seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size - thr_col_offset + if cutlass.const_expr(not mask_causal and not mask_local): + if cutlass.const_expr(mask_seqlen): + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + # if t0ScS_mn[0, c][1] >= seqlenk_col_limit: + # acc_S_mn[None, c].fill(-cutlass.Float32.inf) + oob = t0ScS_mn[0, c][1] >= seqlenk_col_limit + for r in cutlass.range( + cute.size(tScS_mn.shape[0]), unroll_full=True + ): + acc_S_mn[r, c] = -cutlass.Float32.inf if oob else acc_S_mn[r, c] + else: # Causal or local + # If PackGQA, we split the work of compute divmod among threads in the same row + threads_per_row = thr_mma.tv_layout_C.shape[0][0] + if cutlass.const_expr(self.qhead_per_kvhead_packgqa != 1): + assert ( + cute.arch.WARP_SIZE % threads_per_row == 0 + ), "threads_per_row must divide WARP_SIZE" + assert cute.size(acc_S_mn.shape[0]) <= threads_per_row + tidx = thr_mma.thr_idx + mma_m_idx = ( + m_block * self.m_block_size + tScS_mn[tidx % threads_per_row, 0][0] + ) // self.qhead_per_kvhead_packgqa + causal_row_offset = ( + 1 + + self.seqlen_k + - n_block * self.n_block_size + - self.seqlen_q + - thr_col_offset + ) + if cutlass.const_expr(mask_causal): + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. + if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): + row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size + else: + row_idx = utils.shuffle_sync( + mma_m_idx, r % threads_per_row, width=threads_per_row + ) + col_limit_right = row_idx + causal_row_offset + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min( + col_limit_right, seqlenk_col_limit + ) + # traverse column index. + for c in cutlass.range( + cute.size(tScS_mn.shape[1]), unroll_full=True + ): + # only consider the column index, so the row index sets to 0. 
+ # if t0ScS_mn[0, c][1] >= col_limit_right: + # acc_S_mn[r, c] = -cutlass.Float32.inf + acc_S_mn[r, c] = ( + -cutlass.Float32.inf + if t0ScS_mn[0, c][1] >= col_limit_right + else acc_S_mn[r, c] + ) + else: # Local + local_row_offset_right = ( + causal_row_offset + self.window_size_right + if cutlass.const_expr(self.window_size_right is not None) + else None + ) + local_row_offset_left = ( + causal_row_offset - 1 - self.window_size_left + if cutlass.const_expr(self.window_size_left is not None) + else None + ) + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): + row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size + else: + row_idx = utils.shuffle_sync( + mma_m_idx, r % threads_per_row, width=threads_per_row + ) + if cutlass.const_expr(self.window_size_right is not None): + col_limit_right = row_idx + local_row_offset_right + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min( + col_limit_right, seqlenk_col_limit + ) + else: + col_limit_right = self.n_block_size + col_limit_left = ( + row_idx + local_row_offset_left + if cutlass.const_expr(self.window_size_left is not None) + else 0 + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left) + # traverse column index. + for c in cutlass.range( + cute.size(tScS_mn.shape[1]), unroll_full=True + ): + col_idx = t0ScS_mn[0, c][1] + # only consider the column index, so the row index sets to 0. + if col_idx >= col_limit_right or col_idx < col_limit_left: + acc_S_mn[r, c] = -cutlass.Float32.inf + + @cute.jit + def apply_mask_sm100( + self, + acc_S: cute.Tensor, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + thr_mma: cute.TiledMma, + thr_tmem_load: cute.TiledCopy, + mask_seqlen: cutlass.Constexpr, + mask_causal: cutlass.Constexpr, + mask_local: cutlass.Constexpr, + ) -> None: + assert not ( + mask_causal and mask_local + ), "mask_causal and mask_local cannot be both True" + cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + tScS = thr_mma.partition_C(cS) + tScS_t2r = thr_tmem_load.partition_D(tScS) + seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size + if cutlass.const_expr(not mask_causal and not mask_local): + if cutlass.const_expr(mask_seqlen): + ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) + if cutlass.const_expr(not ncol % 16 == 0): + for i in cutlass.range(ncol, unroll_full=True): + # if tScS_t2r[i][1] >= seqlenk_col_limit: + # acc_S[i] = -cutlass.Float32.inf + # For some reason the 2 lines above generate really bad SASS + acc_S[i] = ( + -cutlass.Float32.inf + if tScS_t2r[i][1] >= seqlenk_col_limit + else acc_S[i] + ) + else: + # Bit manipulation, compiles down to the R2P instruction + # We know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using + # Ideally we'd move by 32 instead of 16, but mask >> i isn't correct for i == 31 + # (see below). 
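+ # Illustrative example (arbitrary numbers): with seqlenk_col_limit = 21, the s = 0
+ # sub-tile gets col_limit_right_cur = 21, so bits 0..15 of mask are all set and all
+ # 16 of its columns survive; the s = 1 sub-tile gets col_limit_right_cur = 21 - 16 = 5,
+ # mask = 0b11111, so local columns 16..20 survive and columns 21..31 are set to -inf.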
+ for s in cutlass.range(ncol // 16, unroll_full=True): + col_limit_right_s = seqlenk_col_limit - s * 16 + # Don't need to clamp to 32 since the shr.u32 instruction does that already + col_limit_right_cur = cutlass.Uint32(max(col_limit_right_s, 0)) + # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 + mask = cutlass.Uint32((1 << col_limit_right_cur) - 1) + # if tidx == 0: cute.printf("mask = 0x%x, col_limit_right_s = %d, col_limit_right_cur = %d", mask, col_limit_right_s, col_limit_right_cur) + # This needs to be range_constexpr, otherwise the compiler can't generate + # the R2P instruction + for i in cutlass.range_constexpr(16): + # mask >> i does not produce correct result for 0b11..11 >> 31 + # However, if we use utils.shr_u32, the compiler doesn't generate + # the R2P instruction, so it's slower. + # Instead we just move by 16 instead of 32. + mask_i_bit = cutlass.Boolean(mask & (1 << i)) + # mask_i_bit = cutlass.Boolean(utils.shr_u32(mask, i) & 1) + # if tidx == 0: cute.printf("mask_i_bit = %d, after shift = 0x%x, i = %d, s = %d", mask_i_bit, utils.shr_u32(mask, i), i, s) + acc_S[s * 16 + i] = ( + acc_S[s * 16 + i] + if mask_i_bit + else -cutlass.Float32.inf + ) + # This is the equivalent of: + # acc_S[s * 16 + i] = acc_S[s * 16 + i] if col_limit_right_s <= i else -cutlass.Float32.inf + # if tidx == 0: cute.print_tensor(acc_S) + else: # Causal or local + causal_row_offset = ( + 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q + ) + row_idx = tScS_t2r[0][0] + m_block * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa != 1): + row_idx = row_idx // self.qhead_per_kvhead_packgqa + if cutlass.const_expr(mask_causal): + col_limit_right = row_idx + causal_row_offset + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + # if cute.arch.thread_idx()[0] % 32 == 0: + # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) + ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) + if cutlass.const_expr(not ncol % 16 == 0): + for i in cutlass.range(ncol, unroll_full=True): + acc_S[i] = ( + -cutlass.Float32.inf + if tScS_t2r[i][1] >= col_limit_right + else acc_S[i] + ) + else: + # Bit manipulation, compiles down to the R2P instruction + # We know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using + for s in cutlass.range(ncol // 16, unroll_full=True): + col_limit_right_s = col_limit_right - s * 16 + col_limit_right_cur = cutlass.Uint32(max(col_limit_right_s, 0)) + # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 + mask = cutlass.Uint32((1 << col_limit_right_cur) - 1) + # This needs to be range_constexpr, otherwise the compiler can't generate + # the R2P instruction + for i in cutlass.range_constexpr(16): + # mask_i_bit = cutlass.Boolean(utils.shr_u32(mask, i) & 1) + mask_i_bit = cutlass.Boolean(mask & (1 << i)) + acc_S[s * 16 + i] = ( + acc_S[s * 16 + i] + if mask_i_bit + else -cutlass.Float32.inf + ) + # This is the equivalent of: + # acc_S[s * 16 + i] = acc_S[s * 16 + i] if col_limit_right_s <= i else -cutlass.Float32.inf + else: + local_row_offset_right = ( + causal_row_offset + self.window_size_right + if cutlass.const_expr(self.window_size_right is not None) + else None + ) + local_row_offset_left = ( + causal_row_offset - 1 - self.window_size_left + if cutlass.const_expr(self.window_size_left is not 
None) + else None + ) + if cutlass.const_expr(self.window_size_right is not None): + col_limit_right = row_idx + local_row_offset_right + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min( + col_limit_right, seqlenk_col_limit + ) + else: + col_limit_right = self.n_block_size + col_limit_left = ( + row_idx + local_row_offset_left + if cutlass.const_expr(self.window_size_left is not None) + else 0 + ) + # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left) + for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): + col_idx = tScS_t2r[i][1] + acc_S[i] = ( + -cutlass.Float32.inf + if col_idx >= col_limit_right or col_idx < col_limit_left + else acc_S[i] + ) diff --git a/python/sglang/srt/layers/attention/cute_ops/mma_sm100_desc.py b/python/sglang/srt/layers/attention/cute_ops/mma_sm100_desc.py new file mode 100644 index 00000000000..4243e76fb51 --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/mma_sm100_desc.py @@ -0,0 +1,312 @@ +# Copyright (c) 2025, Tri Dao. +# Ported Cutlass code from C++ to Python: +# https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm100_desc.hpp +# https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm100.hpp + +from enum import IntEnum + +import cutlass +import cutlass.cute as cute + +# --------------------------------------------------------------------------- +# Enumerations that match the HW encodings (values MUST stay identical) +# --------------------------------------------------------------------------- + + +class Major(IntEnum): # matrix “layout” in the ISA docs + K = 0 + MN = 1 + + +class ScaleIn(IntEnum): # negate flags + One = 0 + Neg = 1 + + +class Saturate(IntEnum): + False_ = 0 + True_ = 1 + + +class CFormat(IntEnum): # 2-bit field (bits 4-5) + F16 = 0 + F32 = 1 + S32 = 2 + + +class F16F32Format(IntEnum): # 3-bit field (A/B element type) + F16 = 0 + BF16 = 1 + TF32 = 2 + + +class S8Format(IntEnum): + UINT8 = 0 + INT8 = 1 + + +class MXF8F6F4Format(IntEnum): + E4M3 = 0 + E5M2 = 1 + E2M3 = 3 + E3M2 = 4 + E2M1 = 5 + + +class MaxShift(IntEnum): + NoShift = 0 + MaxShift8 = 1 + MaxShift16 = 2 + MaxShift32 = 3 + + +# --------------------------------------------------------------------------- +# CUTLASS-type → encoding helpers +# --------------------------------------------------------------------------- + + +def to_UMMA_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B. 
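+
+ For example (illustrative), to_UMMA_format(cutlass.Float16) returns F16F32Format.F16
+ (encoded as 0) and to_UMMA_format(cutlass.BFloat16) returns F16F32Format.BF16
+ (encoded as 1).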
+ """ + if cutlass_type is cutlass.Int8: + return S8Format.INT8 + # Unsigned 8-bit (if available in your CUTLASS build) + if cutlass_type is cutlass.Uint8: + return S8Format.UINT8 + # FP-16 / BF-16 + if cutlass_type is cutlass.Float16: + return F16F32Format.F16 + if cutlass_type is cutlass.BFloat16: + return F16F32Format.BF16 + # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits) + if cutlass_type is cutlass.TFloat32: + return F16F32Format.TF32 + # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them + if cutlass_type is cutlass.FloatE4M3FN: + return MXF8F6F4Format.E4M3 + if cutlass_type is cutlass.FloatE5M2: + return MXF8F6F4Format.E5M2 + raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") + + +def to_C_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 2-bit accumulator encoding. + """ + if cutlass_type is cutlass.Float16: + return CFormat.F16 + if cutlass_type is cutlass.Float32: + return CFormat.F32 + if cutlass_type is cutlass.Int32: + return CFormat.S32 + raise TypeError( + f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}" + ) + + +# --------------------------------------------------------------------------- +# The constructor – accepts only CUTLASS scalar classes +# --------------------------------------------------------------------------- + + +def make_instr_desc( + a_type, # CUTLASS scalar class, e.g. cutlass.Int8 + b_type, + c_type, + M: int, # 64, 128 or 256 + N: int, # 8 … 256 (multiple of 8) + a_major: Major, + b_major: Major, + a_neg: ScaleIn = ScaleIn.One, + b_neg: ScaleIn = ScaleIn.One, + c_sat: Saturate = Saturate.False_, + is_sparse: bool = False, + max_shift: MaxShift = MaxShift.NoShift, +) -> int: + """ + Build the 32-bit instruction descriptor for Blackwell MMA. + All matrix/accumulator **types must be CUTLASS scalar classes** – + passing integers is forbidden. 
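+
+ Example (illustrative values): a 128x256 BF16 x BF16 -> FP32 MMA with both operands
+ K-major,
+     make_instr_desc(cutlass.BFloat16, cutlass.BFloat16, cutlass.Float32,
+                     M=128, N=256, a_major=Major.K, b_major=Major.K)
+ packs c_format = F32 (1), a_format = b_format = BF16 (1), n_dim = 256 >> 3 = 32 and
+ m_dim = 128 >> 4 = 8 into the 32-bit descriptor.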
+ """ + # --- encode element formats ------------------------------------------------- + a_fmt = int(to_UMMA_format(a_type)) + b_fmt = int(to_UMMA_format(b_type)) + c_fmt = int(to_C_format(c_type)) + + # --- range checks on M/N ----------------------------------------------------- + if M not in (64, 128, 256): + raise ValueError("M must be 64, 128 or 256") + if N < 8 or N > 256 or (N & 7): + raise ValueError("N must be a multiple of 8 in the range 8…256") + + m_dim = M >> 4 # 5-bit field + n_dim = N >> 3 # 6-bit field + + # --- pack the bit-fields ----------------------------------------------------- + desc = 0 + desc |= (0 & 0x3) << 0 # sparse_id2 (always 0 here) + desc |= (int(is_sparse) & 0x1) << 2 # sparse_flag + desc |= (int(c_sat) & 0x1) << 3 # saturate + desc |= (c_fmt & 0x3) << 4 # c_format + desc |= (a_fmt & 0x7) << 7 # a_format + desc |= (b_fmt & 0x7) << 10 # b_format + desc |= (int(a_neg) & 0x1) << 13 # a_negate + desc |= (int(b_neg) & 0x1) << 14 # b_negate + desc |= (int(a_major) & 0x1) << 15 # a_major + desc |= (int(b_major) & 0x1) << 16 # b_major + desc |= (n_dim & 0x3F) << 17 # n_dim (6 bits) + desc |= (m_dim & 0x1F) << 24 # m_dim (5 bits) + desc |= (int(max_shift) & 0x3) << 30 # max_shift (2 bits) + + return desc & 0xFFFF_FFFF # ensure 32-bit result + + +def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp): + return make_instr_desc( + op.a_dtype, + op.b_dtype, + op.acc_dtype, + op.shape_mnk[0], + op.shape_mnk[1], + ( + Major.K + if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K + else Major.MN + ), + ( + Major.K + if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K + else Major.MN + ), + ) + + +class LayoutType(IntEnum): # occupies the top-3 bits [61:64) + SWIZZLE_NONE = 0 # (a.k.a. “INTERLEAVE” in older docs) + SWIZZLE_128B_BASE32B = 1 + SWIZZLE_128B = 2 + SWIZZLE_64B = 4 + SWIZZLE_32B = 6 + # values 3,5,7 are reserved / illegal for UMMA + + +# --------------------------------------------------------------------------- +# Helpers – figure out the SWIZZLE_* family from the tensor layout +# --------------------------------------------------------------------------- + + +def _layout_type(swizzle: cute.Swizzle) -> LayoutType: + # No idea what the right way to get B, M, S is – so we're just parsing it from the __str__ + # Swizzle string has the form "S" + swz_str = str(swizzle) + inside = swz_str[swz_str.index("<") + 1 : swz_str.index(">")] # '3,4,3' + B, M, S = [int(x) for x in inside.split(",")] # [3, 4, 3] + + if M == 4: # Swizzle<*,4,3> + if S != 3: + raise ValueError("Unexpected swizzle shift – want S==3 for M==4") + return { + 0: LayoutType.SWIZZLE_NONE, + 1: LayoutType.SWIZZLE_32B, + 2: LayoutType.SWIZZLE_64B, + 3: LayoutType.SWIZZLE_128B, + }[ + B + ] # KeyError ⇒ invalid B→ raise + if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5) + if (B, S) != (2, 2): + raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B") + return LayoutType.SWIZZLE_128B_BASE32B + + # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout + raise ValueError("Unsupported swizzle triple for UMMA smem descriptor") + + +def make_smem_desc_base( + layout: cute.Layout, swizzle: cute.Swizzle, major: Major +) -> int: + """ + Convert a 2-D *shared-memory* Cute layout into the Blackwell 64-bit + smem-descriptor, without the smem start address. + layout must correspond to layout of an uint128 tensor. 
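+
+ For example (illustrative), an atom built with Swizzle<3,4,3> resolves to
+ SWIZZLE_128B and Swizzle<2,4,3> to SWIZZLE_64B; the byte-offset fields in the
+ descriptor are expressed in units of uint128_t (16 B).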
+ """ + # ------------------------------------------------------------------ meta + layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family + + VERSION = 1 # bits 46–47 + LBO_MODE = 0 # bit 52 + BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0) + + # ---------------------------------------------------------- strides (units: uint128_t = 16 B) + swizzle_atom_mn_size = { + LayoutType.SWIZZLE_NONE: 1, + LayoutType.SWIZZLE_32B: 2, + LayoutType.SWIZZLE_64B: 4, + LayoutType.SWIZZLE_128B: 8, + LayoutType.SWIZZLE_128B_BASE32B: 8, + }[layout_type] + + if major is Major.MN: + swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8 + canonical_layout = cute.logical_divide( + layout, (swizzle_atom_mn_size, swizzle_atom_k_size) + ) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError( + "Not a canonical UMMA_MN Layout: Expected profile failure." + ) + stride_00 = canonical_layout.stride[0][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if stride_10 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_01, stride_11 = ( + canonical_layout.stride[0][1], + canonical_layout.stride[1][1], + ) + if layout_type is LayoutType.SWIZZLE_NONE: + stride_byte_offset, leading_byte_offset = stride_01, stride_11 + else: + stride_byte_offset, leading_byte_offset = stride_11, stride_01 + else: + if layout_type == LayoutType.SWIZZLE_128B_BASE32B: + raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K") + if not cute.size(layout.shape[0]) % 8 == 0: + raise ValueError( + "Not a canonical UMMA_K Layout: Expected MN-size multiple of 8." + ) + canonical_layout = cute.logical_divide(layout, (8, 2)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if stride_00 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_01 = canonical_layout.stride[0][1] + stride_byte_offset, leading_byte_offset = stride_01, stride_10 + + # ------------------------------------------------------------------ pack + desc = 0 + # leading_byte_offset_ [16:30) + desc |= (leading_byte_offset & 0x3FFF) << 16 + # stride_byte_offset_ [32:46) + desc |= (stride_byte_offset & 0x3FFF) << 32 + # version_ [46:48) + desc |= (VERSION & 0x3) << 46 + # base_offset_ [49:52) + desc |= (BASE_OFFSET & 0x7) << 49 + # lbo_mode_ [52:53) + desc |= (LBO_MODE & 0x1) << 52 + # layout_type_ [61:64) + desc |= (int(layout_type) & 0x7) << 61 + + return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width + + +def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: + # 14 bits, remove 4 LSB (bits 0-13 in desc) + return (start_addr.toint() & 0x3FFFF) >> 4 diff --git a/python/sglang/srt/layers/attention/cute_ops/pack_gqa.py b/python/sglang/srt/layers/attention/cute_ops/pack_gqa.py new file mode 100644 index 00000000000..3f67b8a3cfb --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/pack_gqa.py @@ -0,0 +1,191 @@ +# Copyright (c) 2025, Tri Dao. 
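+#
+# PackGQA folds the query-head dimension into the row (sequence) dimension so that
+# qhead_per_kvhead query heads share one KV head. A packed row index
+#     idx = block * m_block_size + row
+# is unpacked as m_idx = idx // qhead_per_kvhead (query position) and
+# h_idx = idx - m_idx * qhead_per_kvhead (head within the KV group).
+# Illustrative example: qhead_per_kvhead = 4 and idx = 10 give (m_idx, h_idx) = (2, 2).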
+ +import cutlass +import cutlass.cute as cute + +from sglang.srt.layers.attention.cute_ops import utils + + +class PackGQA: + def __init__( + self, + m_block_size: cutlass.Constexpr[int], + head_dim_padded: cutlass.Constexpr[int], + check_hdim_oob: cutlass.Constexpr[bool], + qhead_per_kvhead: cutlass.Constexpr[bool], + ): + self.m_block_size = m_block_size + self.head_dim_padded = head_dim_padded + self.check_hdim_oob = check_hdim_oob + self.qhead_per_kvhead = qhead_per_kvhead + + @cute.jit + def compute_ptr( + self, + tensor: cute.Tensor, + cRows: cute.Tensor, + tidx: cutlass.Int32, + block: cutlass.Int32, + threads_per_row: cutlass.Constexpr[int], + num_threads: cutlass.Constexpr[int], + ): + num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row) + tPrPtr = cute.make_fragment(num_ptr_per_thread, cutlass.Int64) + for i in cutlass.range_constexpr(num_ptr_per_thread): + row = i * num_threads + cRows[tidx % threads_per_row][0] + idx = block * self.m_block_size + row + m_idx = idx // self.qhead_per_kvhead + h_idx = idx - m_idx * self.qhead_per_kvhead + tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint() + return tPrPtr + + @cute.jit + def load_Q( + self, + mQ: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + sQ: cute.Tensor, # (m_block_size, head_dim_padded) + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tQsQ = gmem_thr_copy.partition_D(sQ) + tQcQ = gmem_thr_copy.partition_S(cQ) + t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) + tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1]) + tQcQ_row = tQcQ[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert ( + cute.arch.WARP_SIZE % threads_per_row == 0 + ), "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrQPtr = self.compute_ptr( + mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads + ) + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + q_ptr_i64 = utils.shuffle_sync( + tPrQPtr[m // threads_per_row], + m % threads_per_row, + width=threads_per_row, + ) + q_gmem_ptr = cute.make_ptr( + mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0QcQ[0, m, 0][0] + < seqlen * self.qhead_per_kvhead + - block * self.m_block_size + - tQcQ_row[0][0] + ): + mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tQsQ.shape[0][0]) + mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])): + ki = tQcQ[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + mQ_cur_copy[None, ki], + tQsQ[None, m, k], + pred=( + tQpQ[None, m, k] + if cutlass.const_expr(self.check_hdim_oob) + else None + ), + ) + # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + + @cute.jit + def store_LSE( + self, + mLSE: cute.Tensor, # (qhead_per_kvhead, seqlen_q) + tLSErLSE: cute.Tensor, # (m_block_size, head_dim_padded) + tiled_mma: cute.TiledMma, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + thr_mma = tiled_mma.get_slice(tidx) + caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + taccOcO = thr_mma.partition_C(caccO) + taccOcO_row = utils.make_acc_tensor_mn_view(taccOcO)[None, 0] + assert cute.size(tLSErLSE) == cute.size(taccOcO_row) + threads_per_row 
= tiled_mma.tv_layout_C.shape[0][0] + assert ( + cute.arch.WARP_SIZE % threads_per_row == 0 + ), "threads_per_row must divide WARP_SIZE" + assert cute.size(tLSErLSE) <= threads_per_row + num_threads = tiled_mma.size + tPrLSEPtr = self.compute_ptr( + mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads + ) + for m in cutlass.range_constexpr(cute.size(tLSErLSE)): + lse_ptr_i64 = utils.shuffle_sync( + tPrLSEPtr[m // threads_per_row], + m % threads_per_row, + width=threads_per_row, + ) + lse_gmem_ptr = cute.make_ptr( + mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4 + ) + row = block * self.m_block_size + taccOcO_row[m][0] + # Only the thread corresponding to column 0 writes out the lse to gmem + if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead: + mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,)) + mLSE_copy[0] = tLSErLSE[m] + + @cute.jit + def store_O( + self, + mO: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + tOrO: cute.Tensor, # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tOcO = gmem_thr_copy.partition_S(cO) + t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + tOcO_row = tOcO[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert ( + cute.arch.WARP_SIZE % threads_per_row == 0 + ), "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrOPtr = self.compute_ptr( + mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads + ) + for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + o_ptr_i64 = utils.shuffle_sync( + tPrOPtr[m // threads_per_row], + m % threads_per_row, + width=threads_per_row, + ) + o_gmem_ptr = cute.make_ptr( + mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0OcO[0, m, 0][0] + < seqlen * self.qhead_per_kvhead + - block * self.m_block_size + - tOcO_row[0][0] + ): + mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tOrO.shape[0][0]) + mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])): + ki = tOcO[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + tOrO[None, m, k], + mO_cur_copy[None, ki], + pred=( + tOpO[None, m, k] + if cutlass.const_expr(self.check_hdim_oob) + else None + ), + ) diff --git a/python/sglang/srt/layers/attention/cute_ops/prefill_attention.py b/python/sglang/srt/layers/attention/cute_ops/prefill_attention.py new file mode 100644 index 00000000000..4a630a2cc11 --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/prefill_attention.py @@ -0,0 +1,267 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# Modified by SGLang team. +# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0. 
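+#
+# Minimal usage sketch (illustrative shapes; bf16 CUDA tensors, int32 cu_seqlens):
+#     q = torch.randn(55, 8, 128, dtype=torch.bfloat16, device="cuda")
+#     k = torch.randn(448, 8, 128, dtype=torch.bfloat16, device="cuda")
+#     v = torch.randn(448, 8, 128, dtype=torch.bfloat16, device="cuda")
+#     cu_seqlens_q = torch.tensor([0, 11, 23, 55], dtype=torch.int32, device="cuda")
+#     cu_seqlens_k = torch.tensor([0, 256, 384, 448], dtype=torch.int32, device="cuda")
+#     out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens_q,
+#                                  cu_seqlens_k=cu_seqlens_k, causal=True)[0]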
+ +import math +from typing import Optional + +import torch + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack + +from sglang.srt.layers.attention.cute_ops.flash_fwd_sm100 import FlashAttentionForwardSm100 + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, +} + + +def _flash_attn_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + softcap: Optional[float] = None, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + learnable_sink: Optional[torch.Tensor] = None, + # m_block_size: int = 128, + # n_block_size: int = 64, + # num_threads: int = 128, + m_block_size: int = 128, + n_block_size: int = 128, + num_threads: int = 384, + pack_gqa: Optional[bool] = None, + _compute_capability: Optional[int] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + q, k, v = [maybe_contiguous(t) for t in (q, k, v)] + num_head, head_dim = q.shape[-2:] + if cu_seqlens_q is None: + batch_size, seqlen_q = q.shape[:2] + total_q = batch_size * seqlen_q + else: + batch_size = cu_seqlens_q.shape[0] - 1 + seqlen_q = None + total_q = q.shape[0] + if page_table is not None: + assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k" + assert page_table.dtype == torch.int32, "page_table must be int32" + assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension" + max_num_pages_per_seq = page_table.shape[1] + assert page_table.shape == (batch_size, max_num_pages_per_seq) + num_pages, page_size = k.shape[:2] + seqlen_k = num_pages * page_size + else: + num_pages, page_size = None, None + seqlen_k = k.shape[-3] + num_head_kv = k.shape[-2] + head_dim_v = v.shape[-1] + if cu_seqlens_k is None: + if page_table is None: + assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) + assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) + else: + assert k.shape == (num_pages, page_size, num_head_kv, head_dim) + assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v) + else: + assert k.shape == (seqlen_k, num_head_kv, head_dim) + assert v.shape == (seqlen_k, num_head_kv, head_dim_v) + assert cu_seqlens_k.shape == (batch_size + 1,), "cu_seqlens_k must have shape (batch_size + 1,)" + if cu_seqlens_q is not None: + assert cu_seqlens_q.shape == (batch_size + 1,), "cu_seqlens_q must have shape (batch_size + 1,)" + assert seqused_q is None or seqused_q.shape == (batch_size,), "seqused_q must have shape (batch_size,)" + assert seqused_k is None or seqused_k.shape == (batch_size,), "seqused_k must have shape (batch_size,)" + assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" + assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" + for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: + if t is not None: + assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" + assert t.stride(0) == 1, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" + if learnable_sink 
is not None: + assert learnable_sink.shape == (num_head,) + assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" + assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, learnable_sink)), "inputs must be on CUDA device" + assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" + assert head_dim <= 256, "head_dim must be less than or equal to 256" + alignment = 16 // q.element_size() + assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" + assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(head_dim) + if softcap == 0.0: + softcap = None + qhead_per_kvhead = num_head // num_head_kv + if pack_gqa is None: + pack_gqa = qhead_per_kvhead > 1 + + out_torch_dtype = q.dtype + device = q.device + q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) + out = torch.empty(*q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device) + lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) + requires_grad = q.requires_grad or k.requires_grad or v.requires_grad + lse = torch.empty(lse_shape, dtype=torch.float32, device=device) if requires_grad else None + + dtype = torch2cute_dtype_map[q.dtype] + q_tensor, k_tensor, v_tensor, o_tensor = [ + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) + for t in (q, k, v, out) + ] + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) if lse is not None else None + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, additive_sink_tensor = [ + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) + ] + page_table_tensor = from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) if page_table is not None else None + if causal: + window_size_right = 0 + local = window_size_left is not None or window_size_right is not None + if window_size_left is not None or window_size_right is not None: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + else: + causal, local = False, True + compute_capability = torch.cuda.get_device_capability()[0] if _compute_capability is None else _compute_capability + assert compute_capability in [9, 10], "Unsupported compute capability. 
Supported: 9.x, 10.x" + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + if compute_capability == 9: # TODO: tune block size according to hdim + if not causal and not local: + n_block_size = 192 + if compute_capability == 10: + # TODO: fix the varlen case + if pack_gqa and (128 % qhead_per_kvhead != 0) or (cu_seqlens_q is not None or seqused_q is not None): + pack_gqa = False + + compile_key = ( + dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None, + lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, + page_table is not None, + window_size_left is not None, window_size_right is not None, + learnable_sink is not None, + m_block_size, n_block_size, num_threads, pack_gqa, + compute_capability, + ) + if compile_key not in _flash_attn_fwd.compile_cache: + if compute_capability == 9: + assert page_table is None, "paged KV not supported on SM 9.0" + assert learnable_sink is None, "Sm90 doesn't support additive sink" + # fa_fwd = FlashAttentionForwardSm80( + fa_fwd = FlashAttentionForwardSm90( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + is_causal=causal, + is_local=local, + pack_gqa=pack_gqa, + m_block_size=m_block_size, + n_block_size=n_block_size, + # num_stages=1, + num_stages=2, + num_threads=num_threads, + Q_in_regs=False, + ) + elif compute_capability == 10: + assert page_size in [None, 128], "Only page_size=128 is supported for paged KV on SM 10.0" + fa_fwd = FlashAttentionForwardSm100( + head_dim, + head_dim_v, + qhead_per_kvhead=qhead_per_kvhead, + is_causal=causal, + is_local=local, + pack_gqa=pack_gqa, + is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None, + ) + else: + raise ValueError(f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x") + # TODO: check @can_implement + _flash_attn_fwd.compile_cache[compile_key] = cute.compile( + fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + page_table_tensor, + softcap, window_size_left, window_size_right, additive_sink_tensor, + ) + _flash_attn_fwd.compile_cache[compile_key]( + q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + page_table_tensor, + softcap, window_size_left, window_size_right, additive_sink_tensor, + ) + return out, lse + + +_flash_attn_fwd.compile_cache = {} + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = True, + window_size: tuple[Optional[int], Optional[int]] = (None, None), + learnable_sink: Optional[torch.Tensor] = None, + softcap: float = 0.0, + pack_gqa: Optional[bool] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Entry point to Flash Attention. + + Args: + q: + k: + v: + cu_seqlens_q: `(batch_size+1,)` cumsum of seqlens_q, starting with 0. + cu_seqlens_k: `(batch_size+1,)` cumsum of seqlens_k, starting with 0. + seqused_q: not sure what this means. + seqused_k: `(batch_size,)`. how many tokens are in the kv cache for each sequence. 
+ page_table: `(batch_size, max_num_pages)`, list of pages for each sequence. + softmax_scale: `q[i] @ k[j].T * softmax_scale`. If `None`, uses `rsqrt(head_dim)`. + causal: just set to `True`. + window_size: `(left, right)`. Leaving `None, None` means full attention. + learnable_sink: this is actually not learnable, just a Tensor containing the sink constant. + softcap: if not `0`, each logit will become `softcap * tanh(q[i] @ k[j].T * softmax_scale / softcap)`. + pack_gqa: + """ + assert causal, "Only support causal." + + out, lse = _flash_attn_fwd( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + page_table=page_table, + softmax_scale=softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + learnable_sink=learnable_sink, + softcap=softcap, + pack_gqa=pack_gqa, + ) + return out, lse diff --git a/python/sglang/srt/layers/attention/cute_ops/prefill_attention_test.py b/python/sglang/srt/layers/attention/cute_ops/prefill_attention_test.py new file mode 100644 index 00000000000..78f8a0a2a87 --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/prefill_attention_test.py @@ -0,0 +1,322 @@ +"""Tests for prefill_attention.""" + +from typing import Optional +import numpy as np +import torch +import torch.nn.functional as F + +from sglang.srt.layers.attention.cute_ops.prefill_attention import ( + flash_attn_varlen_func, +) + + +def _green(x: str) -> str: + return f"\033[1;32m{x}\033[0m" + + +def _red(x: str) -> str: + return f"\033[1;31m{x}\033[0m" + + +def _yellow(x: str) -> str: + return f"\033[1;33m{x}\033[0m" + + +torch.set_printoptions(precision=3, sci_mode=False, linewidth=120) + +np.set_printoptions( + suppress=True, precision=3, linewidth=120, formatter={"float": "{:>8.3f}".format} +) + + +def _ref_impl( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = True, +) -> torch.Tensor: + head_dim = q.shape[-1] + softmax_scale = head_dim**-0.5 if softmax_scale is None else softmax_scale + + num_seqs = cu_seqlens_q.shape[0] - 1 if cu_seqlens_q is not None else 1 + + out = torch.empty_like(q) + + for i in range(num_seqs): + if cu_seqlens_q is not None: + qo_start = cu_seqlens_q[i].item() + qo_final = cu_seqlens_q[i + 1].item() + else: + qo_start = 0 + qo_final = q.shape[0] + + if cu_seqlens_k is not None: + kv_start = cu_seqlens_k[i].item() + kv_final = cu_seqlens_k[i + 1].item() + else: + kv_start = 0 + kv_final = k.shape[0] + + curr_q = q[qo_start:qo_final, :, :] + curr_k = k[kv_start:kv_final, :, :] + curr_v = v[kv_start:kv_final, :, :] + + qo_len = qo_final - qo_start + kv_len = kv_final - kv_start + + logits = ( + torch.einsum("qhd,khd->qhk", curr_q, curr_k).to(torch.float32) + * softmax_scale + ) + + if causal: + mask = ( + torch.arange( + kv_len - qo_len, kv_len, dtype=torch.int32, device=logits.device + )[:, None] + >= torch.arange(kv_len, dtype=torch.int32, device=logits.device)[ + None, : + ] + ) + + # TODO(hieu): continue here. why is the result wrong whenever qo_len, kv_len are not all the same number?? 
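+ # The mask is bottom-right aligned: query row qi (0-based within this sequence)
+ # corresponds to absolute key position kv_len - qo_len + qi, so it attends to keys
+ # j with j <= kv_len - qo_len + qi. E.g. qo_len = 2, kv_len = 5: the first query
+ # sees keys 0..3 and the second sees keys 0..4.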
+ # print(_green("--> DEBUG:"), f"{i=} {kv_start=:<4d} {kv_final=} {qo_len=} {kv_len=} {curr_q.shape=} {curr_k.shape=} {mask[:, None, :].shape=} {logits.shape=} {cu_seqlens_k=}") + + logits = torch.where( + mask[:, None, :], + logits, + torch.tensor(float("-inf"), dtype=torch.float32, device=logits.device), + ) + + scores = F.softmax(logits, dim=-1).to(curr_v.dtype) + out[qo_start:qo_final, :, :] = torch.einsum("qhv,vhd->qhd", scores, curr_v) + + return out + + +def test_ragged( + qo_lens: tuple[int, ...], + kv_lens: tuple[int, ...], + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + softmax_scale: Optional[float] = None, + causal: bool = True, + init_range: float = 0.5, + dtype: torch.dtype = torch.bfloat16, + seed: int = 31415, +): + torch.manual_seed(seed) + np.random.seed(seed) + + qo_len = sum(qo_lens) + kv_len = sum(kv_lens) + seqlens_q = torch.tensor(list(qo_lens), dtype=torch.int32, device="cuda") + seqlens_k = torch.tensor(list(qo_lens), dtype=torch.int32, device="cuda") + cu_seqlens_q = F.pad( + torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), + pad=(1, 0), + mode="constant", + value=0, + ) + cu_seqlens_k = F.pad( + torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), + pad=(1, 0), + mode="constant", + value=0, + ) + + q = torch.empty( + size=(qo_len, num_qo_heads, head_dim), dtype=dtype, device="cuda" + ).uniform_(-init_range, init_range) + k = torch.empty( + size=(kv_len, num_kv_heads, head_dim), dtype=dtype, device="cuda" + ).uniform_(-init_range, init_range) + v = torch.empty( + size=(kv_len, num_kv_heads, head_dim), dtype=dtype, device="cuda" + ).uniform_(-init_range, init_range) + + out = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + causal=causal, + )[0] + + ref = _ref_impl( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + softmax_scale=softmax_scale, + ) + diff = (out - ref).abs_().max().item() + + print( + _green(f"--> {q.shape=} {k.shape=} {v.shape=}"), + f"{ref.shape=}", + f"{out.shape=}", + ) + print(_green("max_diff: "), f"{diff:<.5f}", flush=True) + + +def test_paged( + qo_lens: tuple[int, ...], + kv_lens: tuple[int, ...], + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int = 128, + num_pages: int = 32 * 1024, + softmax_scale: Optional[float] = None, + init_range: float = 0.5, + dtype: torch.dtype = torch.bfloat16, + seed: int = 31415, +): + torch.manual_seed(seed) + np.random.seed(seed) + + assert len(qo_lens) == len(kv_lens) + assert all(qo_len <= kv_len for qo_len, kv_len in zip(qo_lens, kv_lens)) + + num_seqs = len(qo_lens) + qo_len = sum(qo_lens) + max_num_pages = (max(kv_lens) + page_size - 1) // page_size + + seqlens_q = torch.tensor(list(qo_lens), dtype=torch.int32, device="cuda") + seqlens_k = torch.tensor(list(kv_lens), dtype=torch.int32, device="cuda") + cu_seqlens_q = F.pad( + torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), + pad=(1, 0), + mode="constant", + value=0, + ) + cu_seqlens_k = F.pad( + torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), + pad=(1, 0), + mode="constant", + value=0, + ) + + q = torch.empty( + size=(qo_len, num_qo_heads, head_dim), dtype=dtype, device="cuda" + ).uniform_(-init_range, init_range) + k_cache = torch.empty( + size=(num_pages + 1, page_size, num_kv_heads, head_dim), + dtype=dtype, + device="cuda", + ).uniform_(-init_range, init_range) + v_cache = torch.empty( + size=(num_pages + 1, page_size, num_kv_heads, head_dim), + dtype=dtype, + device="cuda", + ).uniform_(-init_range, init_range) + + page_table 
= np.random.randint( + size=[num_seqs, max_num_pages], low=1, high=num_pages + 1, dtype=np.int32 + ) + page_table = torch.tensor(page_table, dtype=torch.int32, device="cuda") + + page_table = torch.where( + torch.arange(max_num_pages, dtype=torch.int32, device="cuda")[None, :] + < ( + torch.tensor(list(kv_lens), dtype=torch.int32, device="cuda")[:, None] + + page_size + - 1 + ) + // page_size, + page_table, + 0, + ) + + out = flash_attn_varlen_func( + q=q, + k=k_cache, + v=v_cache, + cu_seqlens_q=cu_seqlens_q, + seqused_k=torch.tensor(list(seqlens_k), dtype=torch.int32, device="cuda"), + page_table=page_table, + )[0] + + def _extract_kv(cache: torch.Tensor, should_print: bool = False): + out = [] + for i, kv_len in enumerate(kv_lens): + out.append( + cache[page_table[i], :, :, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] + ) + if should_print: + print(_red(f"--> DEBUG: {i=} "), page_table[i], flush=True) + return torch.concat(out, axis=0) + + k = _extract_kv(k_cache) + v = _extract_kv(v_cache) + ref = _ref_impl( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + softmax_scale=softmax_scale, + ) + + out = torch.split(out, list(qo_lens), dim=0) + ref = torch.split(ref, list(qo_lens), dim=0) + + okay = True + for i in range(len(qo_lens)): + max_diff = (out[i] - ref[i]).abs_().max().item() + print(_yellow(f"max_diff_{i}: "), f"{max_diff:<.5f}", flush=True) + if max_diff > 0.02: + okay = False + + assert okay + + +if __name__ == "__main__": + test_ragged( + qo_lens=(8,), + kv_lens=(8,), + num_qo_heads=1, + num_kv_heads=1, + head_dim=128, + softmax_scale=None, + ) + + test_ragged( + qo_lens=(11, 12, 32), + kv_lens=(256, 128, 64), + num_qo_heads=4, + num_kv_heads=4, + head_dim=128, + softmax_scale=None, + ) + + test_paged( + qo_lens=(8, 17), + kv_lens=(11, 19), + num_qo_heads=2, + num_kv_heads=2, + num_pages=32, + page_size=128, + head_dim=128, + softmax_scale=None, + ) + + test_paged( + qo_lens=(21, 12, 11, 19), + kv_lens=(33, 71, 18, 31), + num_qo_heads=4, + num_kv_heads=4, + num_pages=32, + page_size=128, + head_dim=128, + softmax_scale=None, + ) diff --git a/python/sglang/srt/layers/attention/cute_ops/seqlen_info.py b/python/sglang/srt/layers/attention/cute_ops/seqlen_info.py new file mode 100644 index 00000000000..3a9ca4dbd80 --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/seqlen_info.py @@ -0,0 +1,41 @@ +from typing import Optional + +import cutlass +import cutlass.cute as cute + + +class SeqlenInfo: + def __init__( + self, + batch_idx: cutlass.Int32, + seqlen_q_static: cutlass.Int32, + seqlen_k_static: cutlass.Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + ): + self.offset_q = ( + 0 if cutlass.const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] + ) + self.offset_k = ( + 0 if cutlass.const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] + ) + if cutlass.const_expr(mSeqUsedQ is not None): + self.seqlen_q = mSeqUsedQ[batch_idx] + else: + self.seqlen_q = ( + seqlen_q_static + if cutlass.const_expr(mCuSeqlensQ is None) + else mCuSeqlensQ[batch_idx + 1] - self.offset_q + ) + if cutlass.const_expr(mSeqUsedK is not None): + self.seqlen_k = mSeqUsedK[batch_idx] + else: + self.seqlen_k = ( + seqlen_k_static + if cutlass.const_expr(mCuSeqlensK is None) + else mCuSeqlensK[batch_idx + 1] - self.offset_k + ) + self.has_cu_seqlens_q: int = mCuSeqlensQ is not None + self.has_cu_seqlens_k: int = 
mCuSeqlensK is not None diff --git a/python/sglang/srt/layers/attention/cute_ops/softmax.py b/python/sglang/srt/layers/attention/cute_ops/softmax.py new file mode 100644 index 00000000000..34744351bcb --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/softmax.py @@ -0,0 +1,285 @@ +# Copyright (c) 2025, Tri Dao. + +import math +import operator +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Float32 + +from sglang.srt.layers.attention.cute_ops import utils + + +class Softmax: + def __init__( + self, + scale_log2: Float32, + num_rows: cutlass.Constexpr[int], + arch: cutlass.Constexpr[int] = 80, + ): + self.scale_log2 = scale_log2 + self.row_max = cute.make_fragment(num_rows, Float32) + self.row_sum = cute.make_fragment_like(self.row_max) + self.arch = arch + + def reset(self) -> None: + self.row_max.fill(-Float32.inf) + self.row_sum.fill(0.0) + + def _compute_row_max( + self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch) + + def _compute_row_sum( + self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch) + + @cute.jit + def online_softmax( + self, + acc_S: cute.Tensor, + is_first: cutlass.Constexpr[bool] = False, + check_inf: cutlass.Constexpr[bool] = True, + ) -> cute.Tensor: + """Apply online softmax and return the row_scale to rescale O. + + :param acc_S: acc_S tensor + :type acc_S: cute.Tensor + :param is_first: is first n_block + :type is_first: cutlass.Constexpr + """ + # Change acc_S to M,N layout view. + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) + row_scale = cute.make_fragment_like(self.row_max, Float32) + # Each iteration processes one row of acc_S + for r in cutlass.range(cute.size(self.row_max), unroll_full=True): + acc_S_row = acc_S_mn[r, None].load() # (n_block_size) + row_max_cur = self._compute_row_max( + acc_S_row, + init_val=self.row_max[r] if cutlass.const_expr(not is_first) else None, + ) + row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4) + if cutlass.const_expr(check_inf): + row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur + if cutlass.const_expr(is_first): + row_max_cur_scaled = row_max_cur * self.scale_log2 + acc_S_row_exp = utils.exp2f( + acc_S_row * self.scale_log2 - row_max_cur_scaled + ) + acc_S_row_sum = self._compute_row_sum(acc_S_row_exp) + row_scale[r] = 1.0 + else: + row_max_prev = self.row_max[r] + row_max_cur_scaled = row_max_cur * self.scale_log2 + acc_S_row_exp = utils.exp2f( + acc_S_row * self.scale_log2 - row_max_cur_scaled + ) + # row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled) + row_scale[r] = utils.exp2f( + (row_max_prev - row_max_cur) * self.scale_log2 + ) + acc_S_row_sum = self._compute_row_sum( + acc_S_row_exp, init_val=self.row_sum[r] * row_scale[r] + ) + self.row_max[r] = row_max_cur + self.row_sum[r] = acc_S_row_sum + acc_S_mn[r, None].store(acc_S_row_exp) + return row_scale + + @cute.jit + def finalize(self, final_scale: Float32 = 1.0) -> cute.Tensor: + """Finalize the online softmax by computing the scale and logsumexp.""" + # quad reduction for row_sum as we didn't do it during each iteration of online softmax + self.row_sum.store( + utils.warp_reduce(self.row_sum.load(), operator.add, width=4) + ) + row_scale = cute.make_fragment_like(self.row_max, Float32) + for r in cutlass.range(cute.size(self.row_sum), 
unroll_full=True): + # if row_sum is zero or nan, set acc_O_mn_row to 1.0 + acc_O_mn_row_is_zero_or_nan = ( + self.row_sum[r] == 0.0 or self.row_sum[r] != self.row_sum[r] + ) + row_scale[r] = ( + cute.arch.rcp_approx( + self.row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0 + ) + ) * final_scale + row_sum_cur = self.row_sum[r] + LN2 = math.log(2.0) + self.row_sum[r] = ( + (self.row_max[r] * self.scale_log2 + utils.log2f(row_sum_cur)) * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + return row_scale + + @cute.jit + def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: + """Scale each row of acc_O by the given scale tensor. + :param acc_O: input tensor + :type acc_O: cute.Tensor + :param row_scale: row_scale tensor + :type row_scale: cute.Tensor + """ + acc_O_mn = utils.make_acc_tensor_mn_view(acc_O) + assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0]) + for r in cutlass.range(cute.size(row_scale), unroll_full=True): + acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) + + +class SoftmaxSm100(Softmax): + def __init__( + self, scale_log2: Float32, rescale_threshold: cutlass.Constexpr[float] = 0.0 + ): + super().__init__(scale_log2, num_rows=1, arch=100) + self.rescale_threshold = rescale_threshold + + @cute.jit + def update_row_max( + self, acc_S_row: cute.TensorSSA, is_first: int + ) -> Tuple[Float32, Float32]: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = utils.exp2f(acc_scale_) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 + self.row_max[0] = row_max_new + return row_max_safe, acc_scale + + def update_row_sum( + self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False + ) -> None: + init_val = ( + self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None + ) + # self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale) + self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val) + # tmp = self._compute_row_sum(acc_S_row_exp) + # self.row_sum[0] = self.row_sum[0] * row_scale + tmp + + @cute.jit + def scale_subtract_rowmax( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + ): + assert ( + cute.size(acc_S_row.shape) % 2 == 0 + ), "acc_S_row must have an even number of elements" + row_max_scaled = row_max * self.scale_log2 + for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (-row_max_scaled, -row_max_scaled), + ) + + @cute.jit + def apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + acc_S_row_converted: cute.Tensor, + e2e: cutlass.Constexpr[bool] = False, + e2e_freq: cutlass.Constexpr[int] = 16, + e2e_res: cutlass.Constexpr[int] = 4, + e2e_frg_limit: cutlass.Constexpr[int] = 1, + ): + assert ( + cute.size(acc_S_row.shape) % 2 == 0 + ), "acc_S_row must have an even number of elements" + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = 
cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j]) + # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j]) + if cutlass.const_expr(not e2e): + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + else: + if cutlass.const_expr( + k % e2e_freq < e2e_freq - e2e_res + or j >= frg_cnt - e2e_frg_limit + ): + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2( + acc_S_row_frg[k + 1, j] + ) + else: + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2( + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] + ) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + @cute.jit + def scale_apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + ): + assert ( + cute.size(acc_S_row.shape) % 2 == 0 + ), "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + + # for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + # acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + # (acc_S_row[i], acc_S_row[i + 1]), + # (self.scale_log2, self.scale_log2), + # (minus_row_max_scaled, minus_row_max_scaled), + # ) + # acc_S_row[i] = cute.arch.exp2(acc_S_row[i]) + # acc_S_row[i + 1] = cute.arch.exp2(acc_S_row[i + 1]) + + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( + # cute.arch.fma_packed_f32x2( + # (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + # (self.scale_log2, self.scale_log2), + # (minus_row_max_scaled, minus_row_max_scaled), + # ) + # ) + # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j]) + # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j]) + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) diff --git a/python/sglang/srt/layers/attention/cute_ops/tile_scheduler.py b/python/sglang/srt/layers/attention/cute_ops/tile_scheduler.py new file mode 100644 index 00000000000..c95a55b6a24 --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/tile_scheduler.py @@ -0,0 +1,696 @@ +# Copyright (c) 2025, Tri Dao. 
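+# The schedulers below map a CUDA block index to a (q_block, head, batch) work tile:
+# SingleTileScheduler launches one block per tile, StaticPersistentTileScheduler loops a
+# fixed-size grid over all tiles, and the LPT/varlen variants reorder tiles for L2 reuse
+# and for ragged (cu_seqlens / seqused) batches.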
+ +from typing import Optional, Tuple +from dataclasses import dataclass, fields + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 + +from sglang.srt.layers.attention.cute_ops import utils +from sglang.srt.layers.attention.cute_ops.fast_math import FastDivmod, clz + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [ + f for f in all_fields if not isinstance(f, cutlass.Constexpr) + ] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = { + n: f for n, f in all_fields.items() if isinstance(f, cutlass.Constexpr) + } + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, cutlass.Constexpr) + } + for (name, field), n_items in zip( + non_constexpr_fields.items(), self._values_pos + ): + non_constexpr_fields[name] = cutlass.new_from_mlir_values( + field, values[:n_items] + ) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class TileSchedulerArguments(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + seqlen_k: Int32 + headdim: Int32 + headdim_v: Int32 + total_q: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + element_size: cutlass.Constexpr[int] = 2 + is_persistent: cutlass.Constexpr[bool] = False + lpt: cutlass.Constexpr[bool] = False + + +class SingleTileScheduler: + @dataclass + class Params(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileScheduler.Params": + return SingleTileScheduler.Params( + args.num_block, args.num_head, args.num_batch + ) + + def __init__(self, blk_coord: cute.Coord, *, loc=None, ip=None): + self._blk_coord = blk_coord + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> Params: + return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler": + blk_coord = cute.arch.block_idx() + return SingleTileScheduler(blk_coord, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + return params.num_block, params.num_head, params.num_batch + + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + return cutlass.utils.WorkTileInfo(self._blk_coord, self._is_first_block) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._is_first_block = False + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self._blk_coord]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + 
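+            # Record how many MLIR values this object contributed so that
+            # __new_from_mlir_values__ can slice them back out in order.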
self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self._blk_coord], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class StaticPersistentTileScheduler: + @dataclass + class Params(ParamsBase): + num_block_divmod: FastDivmod + num_head_divmod: FastDivmod + total_blocks: Int32 + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler.Params": + total_blocks = args.num_block * args.num_head * args.num_batch + return StaticPersistentTileScheduler.Params( + FastDivmod.create(args.num_block), + FastDivmod.create(args.num_head), + total_blocks, + ) + + def __init__( + self, + num_block_divmod: FastDivmod, + num_head_divmod: FastDivmod, + total_blocks: Int32, + tile_idx: Int32, + *, + loc=None, + ip=None, + ): + self.num_block_divmod = num_block_divmod + self.num_head_divmod = num_head_divmod + self.total_blocks = total_blocks + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> Params: + return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler": + tile_idx = cute.arch.block_idx()[0] + return StaticPersistentTileScheduler( + params.num_block_divmod, + params.num_head_divmod, + params.total_blocks, + tile_idx, + loc=loc, + ip=ip, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + return (cutlass.min(sm_count, params.total_blocks), Int32(1), Int32(1)) + + # @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + hn_idx, block_idx = self.num_block_divmod.divmod(self._tile_idx) + batch_idx, head_idx = self.num_head_divmod.divmod(hn_idx) + is_valid = self._tile_idx < self.total_blocks + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) + return cutlass.utils.WorkTileInfo( + (Int32(block_idx), Int32(head_idx), Int32(batch_idx)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._tile_idx += cute.arch.grid_dim()[0] + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [ + self.num_block_divmod, + self.num_head_divmod, + self.total_blocks, + self._tile_idx, + ]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [ + self.num_block_divmod, + self.num_head_divmod, + self.total_blocks, + self._tile_idx, + ], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return 
StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileLPTScheduler: + @dataclass + class Params(ParamsBase): + total_blocks: Int32 + num_block_divmod: FastDivmod + num_head_divmod: FastDivmod + l2_minor_divmod: FastDivmod + l2_major_divmod: FastDivmod + l2_minor_residual_divmod: FastDivmod + num_hb_quotient: Int32 + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler.Params": + # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size) + size_one_kv_head = ( + args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size + ) + size_one_head = size_one_kv_head + size_l2 = 50 * 1024 * 1024 # 40 MB for K & V + # Swizzle is the size of each "section". Round swizzle to a power of 2 + # Need to be careful about the case where only one head will fit + # swizzle is how many heads can fit in L2 + # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) + # Seems faster if swizzle if a power of 2 + log2_floor = lambda n: 31 - clz(n) + swizzle = ( + 1 + if size_l2 < size_one_head + else (1 << log2_floor(size_l2 // size_one_head)) + ) + # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + num_hb_quotient = (args.num_head * args.num_batch) // swizzle + num_hb_remainder = (args.num_head * args.num_batch) % swizzle + return SingleTileLPTScheduler.Params( + total_blocks=args.num_block * args.num_head * args.num_batch, + num_block_divmod=FastDivmod.create(args.num_block), + num_head_divmod=FastDivmod.create(args.num_head), + l2_minor_divmod=FastDivmod.create(swizzle), + l2_major_divmod=FastDivmod.create(swizzle * args.num_block), + l2_minor_residual_divmod=FastDivmod.create( + max(num_hb_remainder, 1) + ), # don't divide by 0 + num_hb_quotient=Int32(num_hb_quotient), + ) + + def __init__( + self, + total_blocks: Int32, + num_block_divmod: FastDivmod, + num_head_divmod: FastDivmod, + l2_minor_divmod: FastDivmod, + l2_major_divmod: FastDivmod, + l2_minor_residual_divmod: FastDivmod, + num_hb_quotient: Int32, + tile_idx: Int32, + *, + loc=None, + ip=None, + ): + self.total_blocks = total_blocks + self.num_block_divmod = num_block_divmod + self.num_head_divmod = num_head_divmod + self.l2_minor_divmod = l2_minor_divmod + self.l2_major_divmod = l2_major_divmod + self.l2_minor_residual_divmod = l2_minor_residual_divmod + self.num_hb_quotient = num_hb_quotient + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> Params: + return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + @cute.jit + def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler": + tile_idx = cute.arch.block_idx()[0] + return SingleTileLPTScheduler( + params.total_blocks, + params.num_block_divmod, + params.num_head_divmod, + params.l2_minor_divmod, + params.l2_major_divmod, + params.l2_minor_residual_divmod, + params.num_hb_quotient, + tile_idx, + loc=loc, + ip=ip, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + return (params.total_blocks, Int32(1), Int32(1)) + + @cute.jit + 
def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + # Implement LPT scheduling coordinate calculation + bidhb, l2_mod = self.l2_major_divmod.divmod(self._tile_idx) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + block, bidhb_residual = 0, 0 + if bidhb < self.num_hb_quotient: + block, bidhb_residual = self.l2_minor_divmod.divmod(l2_mod) + else: + block, bidhb_residual = self.l2_minor_residual_divmod.divmod(l2_mod) + bidhb_actual = bidhb * self.l2_minor_divmod.divisor + bidhb_residual + batch_idx, head_idx = self.num_head_divmod.divmod(bidhb_actual) + # Longest-processing-time-first + block = self.num_block_divmod.divisor - 1 - block + is_valid = self._tile_idx < self.total_blocks + return cutlass.utils.WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._tile_idx = self.total_blocks + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [ + self.total_blocks, + self.num_block_divmod, + self.num_head_divmod, + self.l2_minor_divmod, + self.l2_major_divmod, + self.l2_minor_residual_divmod, + self.num_hb_quotient, + self._tile_idx, + ]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [ + self.total_blocks, + self.num_block_divmod, + self.num_head_divmod, + self.l2_minor_divmod, + self.l2_major_divmod, + self.l2_minor_residual_divmod, + self.num_hb_quotient, + self._tile_idx, + ], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileLPTScheduler(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileVarlenScheduler: + @dataclass + class Params(ParamsBase): + num_head: Int32 + num_batch: Int32 + total_q: Int32 + max_kvblock_in_l2: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + lpt: cutlass.Constexpr[bool] = False + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileVarlenScheduler.Params": + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + max_kvblock_in_l2 = size_l2 // ( + (args.headdim + args.headdim_v) + * args.element_size + * args.tile_shape_mn[1] + ) + return SingleTileVarlenScheduler.Params( + num_head=args.num_head, + num_batch=args.num_batch, + total_q=args.total_q, + max_kvblock_in_l2=max_kvblock_in_l2, + tile_shape_mn=args.tile_shape_mn, + mCuSeqlensQ=args.mCuSeqlensQ, + mSeqUsedQ=args.mSeqUsedQ, + qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + lpt=args.lpt, + ) + + def __init__( + self, + num_head: Int32, + num_batch: Int32, + max_kvblock_in_l2: Int32, + tile_idx: Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + tile_shape_mn: cutlass.Constexpr[[int, int]] = (128, 128), + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1, + lpt: cutlass.Constexpr[bool] = False, + *, + 
loc=None, + ip=None, + ): + self.num_head = num_head + self.num_batch = num_batch + self.max_kvblock_in_l2 = max_kvblock_in_l2 + self.mCuSeqlensQ = mCuSeqlensQ + self.mSeqUsedQ = mSeqUsedQ + assert ( + self.mCuSeqlensQ is not None or self.mSeqUsedQ is not None + ), "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" + self.tile_shape_mn = tile_shape_mn + self.qhead_per_kvhead_packgqa = qhead_per_kvhead_packgqa + self.lpt = lpt + self._tile_idx = tile_idx + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> Params: + return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler": + tile_idx = cute.arch.block_idx()[0] + return SingleTileVarlenScheduler( + params.num_head, + params.num_batch, + params.max_kvblock_in_l2, + tile_idx, + mCuSeqlensQ=params.mCuSeqlensQ, + mSeqUsedQ=params.mSeqUsedQ, + tile_shape_mn=params.tile_shape_mn, + qhead_per_kvhead_packgqa=params.qhead_per_kvhead_packgqa, + lpt=params.lpt, + loc=loc, + ip=ip, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + total_blocks_max = ( + params.total_q + params.num_batch * (params.tile_shape_mn[0] - 1) + ) // params.tile_shape_mn[0] + return (total_blocks_max * params.num_head, Int32(1), Int32(1)) + + @cute.jit + def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + batch_idx = lane + bidb_start + if cutlass.const_expr(self.mSeqUsedQ is not None): + seqlen = Int32(0) + if batch_idx < self.num_batch: + seqlen = self.mSeqUsedQ[batch_idx] + else: + assert self.mCuSeqlensQ is not None + cur_cu_seqlen = Int32(0) + if batch_idx <= self.num_batch: + cur_cu_seqlen = self.mCuSeqlensQ[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + seqlen *= self.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(seqlen, self.tile_shape_mn[0]) + if batch_idx < self.num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + lane_idx = cute.arch.lane_idx() + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + # Total number of blocks for the next 31 batches + m_blocks_in_group = cute.arch.shuffle_sync( + num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 + ) + # Same for all lanes + group_end_tile = m_blocks_in_group * self.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group) + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0) + next_tile_idx = self._tile_idx + while group_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= self.num_batch: + batch_idx = Int32(self.num_batch) + group_end_tile = next_tile_idx + 1 + else: + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync( + 
num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 + ) + group_end_tile += m_blocks_in_group * self.num_head + is_valid = False + if batch_idx >= self.num_batch: + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(self.num_batch) + else: + group_start_tile = group_end_tile - m_blocks_in_group * self.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx) + # The next problem to process is the first one that does not have ending tile position + # that is greater than or equal to tile index. + batch_idx_in_group = cute.arch.popc( + cute.arch.vote_ballot_sync( + group_start_tile + num_m_blocks_cumulative * self.num_head + <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_group + num_m_blocks_prev_lane = ( + 0 + if batch_idx_in_group == 0 + else cute.arch.shuffle_sync( + num_m_blocks_cumulative, batch_idx_in_group - 1 + ) + ) + num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) + mh_block = ( + next_tile_idx + - group_start_tile + - num_m_blocks_prev_lane * self.num_head + ) + if cutlass.const_expr(self.lpt): + # This is a version of the SingleTileLPTScheduler, complicated by the fact that + # the seqlen can vary per batch. + # TODO: is there any case where num_m_blocks is 0? + # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here + num_n_blocks = ( + num_m_blocks * self.tile_shape_mn[0] // self.tile_shape_mn[1] + ) + # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) + # Seems faster to have this be a power of 2 + nheads_in_l2 = ( + 16 + if num_n_blocks * 16 <= self.max_kvblock_in_l2 + else ( + 8 + if num_n_blocks * 8 <= self.max_kvblock_in_l2 + else ( + 4 + if num_n_blocks * 4 <= self.max_kvblock_in_l2 + else ( + 2 if num_n_blocks * 2 <= self.max_kvblock_in_l2 else 1 + ) + ) + ) + ) + nheads_in_l2 = min(nheads_in_l2, self.num_head) + mh_in_l2 = nheads_in_l2 * num_m_blocks + section_idx = mh_block // mh_in_l2 + l2_mod = mh_block - section_idx * mh_in_l2 + # Deal with tail section + nheads_in_this_section = ( + nheads_in_l2 + if nheads_in_l2 * (section_idx + 1) <= self.num_head + else self.num_head - section_idx * nheads_in_l2 + ) + block = l2_mod // nheads_in_this_section + head_idx_residual = l2_mod - block * nheads_in_this_section + head_idx = section_idx * nheads_in_l2 + head_idx_residual + block = num_m_blocks - 1 - block + else: + head_idx = mh_block // num_m_blocks + block = mh_block - head_idx * num_m_blocks + is_valid = self._is_first_block and batch_idx < self.num_batch + # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) + return cutlass.utils.WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._is_first_block = False + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [ + self.num_head, + self.num_batch, + self.max_kvblock_in_l2, + self._tile_idx, + self.mCuSeqlensQ, + self.mSeqUsedQ, + ]: + obj_values = 
cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [ + self.num_head, + self.num_batch, + self.max_kvblock_in_l2, + self._tile_idx, + self.mCuSeqlensQ, + self.mSeqUsedQ, + ], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileVarlenScheduler( + *(tuple(obj_list)), + tile_shape_mn=self.tile_shape_mn, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead_packgqa, + lpt=self.lpt, + loc=self._loc, + ) diff --git a/python/sglang/srt/layers/attention/cute_ops/utils.py b/python/sglang/srt/layers/attention/cute_ops/utils.py new file mode 100644 index 00000000000..3ebf06a6b9a --- /dev/null +++ b/python/sglang/srt/layers/attention/cute_ops/utils.py @@ -0,0 +1,597 @@ +# Copyright (c) 2025, Tri Dao. + +import math +from typing import Type, Callable, Optional, Tuple + +import cutlass +import cutlass.cute as cute + +from cutlass import Float32, Int32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import nvvm, llvm, arith, vector +from cutlass.cute.runtime import from_dlpack + + +def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: + return ( + from_dlpack(x, assumed_align=alignment) + .mark_layout_dynamic(leading_dim=leading_dim) + .mark_compact_shape_dynamic( + mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility + ) + ) + + +def make_tiled_copy_A( + copy_atom: cute.CopyAtom, + tiled_mma: cute.TiledMma, + swapAB: cutlass.Constexpr[bool] = False, +) -> cute.TiledCopy: + if cutlass.const_expr(swapAB): + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + + +def make_tiled_copy_B( + copy_atom: cute.CopyAtom, + tiled_mma: cute.TiledMma, + swapAB: cutlass.Constexpr[bool] = False, +) -> cute.TiledCopy: + if cutlass.const_expr(swapAB): + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + + +def mma_make_fragment_A( + smem: cute.Tensor, + thr_mma: cute.core.ThrMma, + swapAB: cutlass.Constexpr[bool] = False, +) -> cute.Tensor: + if cutlass.const_expr(swapAB): + return mma_make_fragment_B(smem, thr_mma) + else: + return thr_mma.make_fragment_A(thr_mma.partition_A(smem)) + + +def mma_make_fragment_B( + smem: cute.Tensor, + thr_mma: cute.core.ThrMma, + swapAB: cutlass.Constexpr[bool] = False, +) -> cute.Tensor: + if cutlass.const_expr(swapAB): + return mma_make_fragment_A(smem, thr_mma) + else: + return thr_mma.make_fragment_B(thr_mma.partition_B(smem)) + + +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric] +) -> cute.CopyAtom: + if cutlass.const_expr(arch < 90): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=2 * element_type.width, + ) + else: + return cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), + element_type, + ) + + +@cute.jit +def warp_reduce( + val: cute.TensorSSA | cute.Numeric, + op: Callable, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.TensorSSA | cute.Numeric: + if cutlass.const_expr(isinstance(val, cute.TensorSSA)): + res = cute.make_fragment(val.shape, val.dtype) + res.store(val) + for i in cutlass.range_constexpr(cute.size(val.shape)): + res[i] = warp_reduce(res[i], op, width) + return res.load() + 
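+        # Scalar (non-TensorSSA) values take the branch below: a butterfly-shuffle
+        # reduction across `width` lanes of the warp.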
else: + for i in cutlass.range_constexpr(int(math.log2(width))): + val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i)) + return val + + +def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout: + """ + For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...). + For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...). + """ + acc_layout_col_major = cute.make_layout(acc_layout.shape) + acc_layout_mn = cute.make_layout( + ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), # MMA_N + *acc_layout_col_major.shape[3:], + ), + stride=( + ( + acc_layout_col_major.stride[0][1], + acc_layout_col_major.stride[1], + ), # MMA_M + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), # MMA_N + *acc_layout_col_major.stride[3:], + ), + ) + return cute.composition(acc_layout, acc_layout_mn) + + +def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout)) + + +@cute.jit +def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: + # For back to back gemm, convert layout of acc0 to gemm 1 accept layout. + # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) + # TODO: Sm90 FP8 + if cutlass.const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 + l = cute.logical_divide( + acc_layout, ((None, None, 2), None, None) + ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N) + rA_mma_view = cute.make_layout( + ( + (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]), + l.shape[1], + (l.shape[0][2][1], l.shape[2]), + ), + stride=( + (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]), + l.stride[1], + (l.stride[0][2][1], l.stride[2]), + ), + ) + else: # Sm80 + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + l = cute.logical_divide(acc_layout, (None, None, 2)) + rA_mma_view = cute.make_layout( + ( + (l.shape[0], l.shape[2][0]), + l.shape[1], + l.shape[2][1], + ), + stride=( + (l.stride[0], l.stride[2][0]), + l.stride[1], + l.stride[2][1], + ), + ) + return rA_mma_view + + +def transpose_view(a: cute.Tensor) -> cute.Tensor: + """Transpose the first two dimensions of a tensor on smem.""" + shape = (a.shape[1], a.shape[0], *a.shape[2:]) + order = (1, 0, *range(2, cute.rank(a))) + return cute.composition(a, cute.make_ordered_layout(shape, order=order)) + + +@dsl_user_op +def exp2f_asm(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "ex2.approx.ftz.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32: + """exp2f calculation for both vector and scalar. 
+ :param x: input value + :type x: cute.TensorSSA or Float32 + :return: exp2 value + :rtype: cute.TensorSSA or Float32 + """ + if cutlass.const_expr(isinstance(x, cute.TensorSSA)): + res = cute.make_fragment(x.shape, Float32) + res.store(x) + for i in cutlass.range_constexpr(cute.size(x.shape)): + res[i] = cute.arch.exp2(res[i]) + return res.load() + else: + return cute.arch.exp2(x) + + +@dsl_user_op +def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "lg2.approx.ftz.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def fmax( + a: float | Float32, + b: float | Float32, + c: float | Float32 | None = None, + *, + loc=None, + ip=None, +) -> Float32: + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + + +@cute.jit +def fmax_reduce( + x: cute.TensorSSA, + init_val: float | Float32 | None = None, + arch: cutlass.Constexpr[int] = 80, +) -> Float32: + if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + # if cutlass.const_expr(init_val is None): + # init_val = -cutlass.Float32.if + # return x.reduce(cute.ReductionOp.MAX, init_val, 0) + res = cute.make_fragment(x.shape, Float32) + res.store(x) + # local_max = [res[0], res[1]] + # for i in cutlass.range_constexpr(2, cute.size(x.shape), 2): + # local_max[0] = fmax(local_max[0], res[i + 0]) + # local_max[1] = fmax(local_max[1], res[i + 1]) + # local_max[0] = fmax(local_max[0], local_max[1]) + # return local_max[0] if cutlass.const_expr(init_val is None) else fmax(local_max[0], init_val) + local_max = [res[0], res[1], res[2], res[3]] + for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + local_max[0] = fmax(local_max[0], res[i + 0]) + local_max[1] = fmax(local_max[1], res[i + 1]) + local_max[2] = fmax(local_max[2], res[i + 2]) + local_max[3] = fmax(local_max[3], res[i + 3]) + local_max[0] = fmax(local_max[0], local_max[1]) + local_max[2] = fmax(local_max[2], local_max[3]) + local_max[0] = fmax(local_max[0], local_max[2]) + return ( + local_max[0] + if cutlass.const_expr(init_val is None) + else fmax(local_max[0], init_val) + ) + else: + # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max + # We instead force the 3-input max. 
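+        # fmax() above forwards the optional third operand to nvvm.fmax, so each step
+        # of the loop below reduces three values per call.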
+ res = cute.make_fragment(x.shape, Float32) + res.store(x) + local_max = [ + ( + fmax(init_val, res[0], res[1]) + if cutlass.const_expr(init_val is not None) + else fmax(res[0], res[1]) + ), + fmax(res[2], res[3]), + fmax(res[4], res[5]), + fmax(res[6], res[7]), + ] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_max[0] = fmax(local_max[0], res[i], res[i + 1]) + local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3]) + local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5]) + local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7]) + local_max[0] = fmax(local_max[0], local_max[1]) + return fmax(local_max[0], local_max[2], local_max[3]) + + +@cute.jit +def fadd_reduce( + x: cute.TensorSSA, + init_val: float | Float32 | None = None, + arch: cutlass.Constexpr[int] = 80, +) -> Float32: + if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + if cutlass.const_expr(init_val is None): + init_val = Float32.zero + return x.reduce(cute.ReductionOp.ADD, init_val, 0) + # res = cute.make_fragment(x.shape, Float32) + # res.store(x) + # local_sum = [res[0], res[1], res[2], res[3]] + # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + # local_sum[0] += res[i + 0] + # local_sum[1] += res[i + 1] + # local_sum[2] += res[i + 2] + # local_sum[3] += res[i + 3] + # local_sum[0] += local_sum[1] + # local_sum[2] += local_sum[3] + # local_sum[0] += local_sum[2] + # return local_sum[0] if cutlass.const_expr(init_val is None) else local_sum[0] + init_val + else: + res = cute.make_fragment(x.shape, Float32) + res.store(x) + local_sum_0 = ( + cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + # cute.arch.add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1])) + if cutlass.const_expr(init_val is not None) + else (res[0], res[1]) + ) + local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_sum[0] = cute.arch.add_packed_f32x2( + local_sum[0], (res[i + 0], res[i + 1]) + ) + local_sum[1] = cute.arch.add_packed_f32x2( + local_sum[1], (res[i + 2], res[i + 3]) + ) + local_sum[2] = cute.arch.add_packed_f32x2( + local_sum[2], (res[i + 4], res[i + 5]) + ) + local_sum[3] = cute.arch.add_packed_f32x2( + local_sum[3], (res[i + 6], res[i + 7]) + ) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + + +@dsl_user_op +def atomic_add_fp32( + a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None +) -> None: + # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + # # cache_hint = cutlass.Int64(0x12F0000000000000) + # llvm.inline_asm( + # None, + # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip)], + # # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], + # "red.global.add.f32 [$0], $1;", + # # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", + # # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", + # "l,f", + # # "l,f,l", + # has_side_effects=True, + # is_align_stack=False, + # asm_dialect=llvm.AsmDialect.AD_ATT, + # ) + nvvm.atomicrmw( + res=T.f32(), + op=nvvm.AtomicOpKind.FADD, + ptr=gmem_ptr.llvm_ptr, + a=Float32(a).ir_value(), + ) + + +@dsl_user_op +def elem_pointer( + x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None +) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, 
loc=loc, ip=ip) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" + tApA = cute.make_fragment( + cute.make_layout( + ( + cute.size(tAcA, mode=[0, 1]), + cute.size(tAcA, mode=[1]), + cute.size(tAcA, mode=[2]), + ), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less( + tAcA[(0, rest_v), 0, rest_k][1], limit + ) + return tApA + + +@dsl_user_op +def cp_async_mbarrier_arrive_shared( + mbar_ptr: cute.Pointer, noinc: bool = False, *, loc=None, ip=None +) -> None: + nvvm.cp_async_mbarrier_arrive_shared( + mbar_ptr.llvm_ptr, + noinc=noinc, + loc=loc, + ip=ip, + ) + + +def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: + warp_group_idx = cute.arch.thread_idx()[0] // 128 + if cutlass.const_expr(sync): + warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx) + return warp_group_idx + + +# @dsl_user_op +# def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean: +# mask = cutlass.Int32(-1) +# return cutlass.Boolean( +# llvm.inline_asm( +# T.i32(), +# [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)], +# ".pred p1, p2;\n" +# "setp.lt.f32 p1, $1, $2;\n" +# "vote.sync.any.pred p2, p1, $3;\n" +# "selp.u32 $0, 1, 0, p2;", +# # "selp.u32 $0, 1, 0, p1;", +# "=r,f,f,r", +# has_side_effects=False, +# is_align_stack=False, +# asm_dialect=llvm.AsmDialect.AD_ATT, +# ) +# ) + + +@cute.jit +def shuffle_sync( + value: cute.Numeric, + offset: cute.typing.Int, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.Numeric: + assert value.width % 32 == 0, "value type must be a multiple of 32 bits" + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + val = cute.make_fragment(1, type(value)) + val[0] = value + val_i32 = cute.recast_tensor(val, cutlass.Int32) + for i in cutlass.range_constexpr(cute.size(val_i32)): + val_i32[i] = cute.arch.shuffle_sync( + val_i32[i], offset, mask_and_clamp=mask_and_clamp + ) + return val[0] + + +@dsl_user_op +def shr_u32( + val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None +) -> cutlass.Uint32: + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shr.s32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def warp_prefix_sum( + val: cutlass.Int32, lane: Optional[cutlass.Int32] = None +) -> cutlass.Int32: + if cutlass.const_expr(lane is None): + lane = cute.arch.lane_idx() + # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 32, val) + for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): + offset = 1 << i + # Very important that we set mask_and_clamp to 0 + partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0) + if lane >= offset: + val += partial_sum + # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and 
cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val) + return val + + +@dsl_user_op +def cvt_f16x2_f32( + a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None +) -> cutlass.Int32: + assert to_dtype in [ + cutlass.BFloat16, + cutlass.Float16, + ], "to_dtype must be BFloat16 or Float16" + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)], + f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def cvt_f16(src: cute.Tensor, dst: cute.Tensor): + assert cute.size(dst.shape) == cute.size( + src.shape + ), "dst and src must have the same size" + assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" + assert dst.element_type in [ + cutlass.BFloat16, + cutlass.Float16, + ], "dst must be BFloat16 or Float16" + assert src.element_type is Float32, "src must be Float32" + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) + for i in cutlass.range_constexpr(cute.size(dst_i32)): + dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) + + +@dsl_user_op +def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + out_f32x2 = llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32()]), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()], + "{\n\t" + ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t" + ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t" + ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t" + "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t" + "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t" + "mov.b64 l1, {f1, f2};\n\t" + "mov.f32 f3, 0f4B400000;\n\t" + "mov.b64 l2, {f3, f3};\n\t" + "add.rm.ftz.f32x2 l7, l1, l2;\n\t" + "sub.rn.ftz.f32x2 l8, l7, l2;\n\t" + "sub.rn.ftz.f32x2 l9, l1, l8;\n\t" + "mov.f32 f7, 0f3D9DF09D;\n\t" + "mov.b64 l6, {f7, f7};\n\t" + "mov.f32 f6, 0f3E6906A4;\n\t" + "mov.b64 l5, {f6, f6};\n\t" + "mov.f32 f5, 0f3F31F519;\n\t" + "mov.b64 l4, {f5, f5};\n\t" + "mov.f32 f4, 0f3F800000;\n\t" + "mov.b64 l3, {f4, f4};\n\t" + "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t" + "mov.b64 {r1, r2}, l7;\n\t" + "mov.b64 {r3, r4}, l10;\n\t" + "shl.b32 r5, r1, 23;\n\t" + "add.s32 r7, r5, r3;\n\t" + "shl.b32 r6, r2, 23;\n\t" + "add.s32 r8, r6, r4;\n\t" + "mov.b32 $0, r7;\n\t" + "mov.b32 $1, r8;\n\t" + "}\n", + "=r,=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip)) + out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip)) + return out0, out1 diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5222bff0a4a..5aa49e0f51c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1537,6 +1537,16 @@ def _get_attention_backend_from_str(self, backend_str: str): ) return FlashAttentionBackend(self) + elif backend_str == "fa-cute": + assert torch.cuda.get_device_capability()[0] == 10, ( + "FlashAttention v4 Backend requires SM>=100" + "With your setup, please use `--attention-backend 
flashinfer` instead." + ) + from sglang.srt.layers.attention.blackwell_prefill_attention_backend import ( + BlackwellPrefillAttentionBackend, + ) + + return BlackwellPrefillAttentionBackend(self) elif backend_str == "cutlass_mla": from sglang.srt.layers.attention.cutlass_mla_backend import ( CutlassMLABackend, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0d6c794e6f4..d56ba56dd79 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -487,6 +487,15 @@ def print_deprecated_warning(message: str): self.disable_cuda_graph = True self.disable_radix_cache = True + if ( + self.attention_backend == "fa-cute" + or self.prefill_attention_backend == "fa-cute" + ): + logger.warning( + "fa-cute only supports a page_size of 128; overriding page_size to 128." + ) + self.page_size = 128 + # Set page size if self.page_size is None: self.page_size = 1 @@ -1330,6 +1339,7 @@ def add_cli_args(parser: argparse.ArgumentParser): # NVIDIA specific "cutlass_mla", "fa3", + "fa-cute", "flashinfer", "flashmla", "trtllm_mla", @@ -2031,7 +2041,7 @@ def check_server_args(self): if self.chunked_prefill_size > 0: assert ( self.chunked_prefill_size % self.page_size == 0 - ), "chunked_prefill_size must be divisible by page_size" + ), f"chunked_prefill_size={self.chunked_prefill_size} must be divisible by page_size={self.page_size}" def check_lora_server_args(self): assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive" @@ -2123,7 +2133,7 @@ def model_specific_adjustments(self): self.attention_backend = "fa3" else: self.attention_backend = "triton" - supported_backends = ["triton", "trtllm_mha", "fa3"] + supported_backends = ["triton", "trtllm_mha", "fa3", "fa-cute"] logger.info( f"Use {self.attention_backend} as attention backend for GptOssForCausalLM" )
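For reference, the snippet below is a minimal usage sketch and not part of the diff: it assumes a Blackwell (SM100) GPU, a placeholder model path, and that sgl.Engine forwards ServerArgs fields such as prefill_attention_backend. Because the new backend handles prefill only, it is selected as the prefill backend so decode stays on the default backend; server_args then forces page_size to 128 as added above.

# Hedged sketch (not part of this diff): exercise the fa-cute prefill backend offline.
import sglang as sgl

llm = sgl.Engine(
    model_path="openai/gpt-oss-20b",        # placeholder; any supported model
    prefill_attention_backend="fa-cute",    # new choice registered in server_args.py
)
print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
llm.shutdown()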