[sgl-kernel] feat: Support sm120 cutlass fp8 gemm kernel #9403
@@ -1168,6 +1168,344 @@ void sm100_fp8_dispatch_shape(
    const c10::optional<torch::Tensor>& bias) {
  return sm100_fp8_dispatch_bias<OutType>(out, a, b, scales_a, scales_b, bias);
}

template <
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename EpilogueScheduleType,
    typename TileSchedulerType = void,
    bool WithBias = false>
struct DeviceGemmFp8RowwiseSm120 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
  using TileShape = CTAShape;
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using ElementComputeEpilogue = float;
  using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

  using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      OutElementType,
      OutElementType,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Compute0 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, float, float, cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementType>::value;

  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementType>::value;

  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutElementType>::value;

  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = AlignmentC;

  using Compute1MulAdd = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiply_add, OutElementType, float, cutlass::FloatRoundStyle::round_to_nearest>;
  using Compute1Mul = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, OutElementType, float, cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute = typename std::conditional_t<
      WithBias,
      cutlass::epilogue::fusion::Sm90EVT<Compute1MulAdd, ScaleA, EVTCompute0, Bias>,
      cutlass::epilogue::fusion::Sm90EVT<Compute1Mul, ScaleA, EVTCompute0>>;
  using ArgumentType = typename EVTCompute::Arguments;

  // MMA type
  using ElementAccumulator = AccumElementType;

  // Epilogue types
  using ElementCompute = float;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm120,
      cutlass::arch::OpClassTensorOp,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      OutElementType,
      LayoutD,
      AlignmentD,
      EpilogueScheduleType,
      EVTCompute>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm120,
      cutlass::arch::OpClassTensorOp,
      ElementType,
      LayoutA,
      AlignmentA,
      ElementType,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduleType>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    static_assert(
        std::is_same_v<Descriptor, ScaleA> || std::is_same_v<Descriptor, ScaleB> || std::is_same_v<Descriptor, Bias>);
    return Arguments{data_ptr};
  }

 public:
  static ArgumentType prepare_args(
      torch::Tensor const& a_scales,
      torch::Tensor const& b_scales,
      std::optional<torch::Tensor> const& bias = std::nullopt) {
    auto a_args = args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};

    if constexpr (WithBias) {
      auto bias_args = args_from_tensor<Bias, OutElementType>(bias.value());
      return ArgumentType{a_args, evt0_args, bias_args, {}};
    } else {
      return ArgumentType{a_args, evt0_args, {}};
    }
  }
};

template <typename GemmType, bool WithBias>
typename GemmType::Gemm::Arguments prepare_sm120_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using Gemm = typename GemmType::Gemm;
  using ElementT = typename Gemm::ElementA;
  using ElementC = typename Gemm::ElementC;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;
  using GemmKernel = typename Gemm::GemmKernel;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = StrideC;
  using StrideAux = StrideC;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
  StrideAux aux_stride = stride_d;

  typename GemmKernel::MainloopArguments mainloop_args{ptr_a, stride_a, ptr_b, stride_b};

  typename GemmKernel::ProblemShape prob_shape = {m, n, k, 1};
  cutlass::KernelHardwareInfo hw_info;
  typename GemmKernel::TileSchedulerArguments scheduler = {};

  auto ptr_c = static_cast<ElementOutput*>(out.data_ptr());

  auto prepare_epilogue_args = [&](const c10::optional<torch::Tensor>& bias = c10::nullopt) {
    if constexpr (WithBias) {
      TORCH_CHECK(bias.has_value(), "Bias tensor is required but not provided.");
      return typename GemmKernel::EpilogueArguments{
          GemmType::prepare_args(scales_a, scales_b, bias.value()), ptr_c, stride_c, ptr_c, stride_d};
    } else {
      return typename GemmKernel::EpilogueArguments{
          GemmType::prepare_args(scales_a, scales_b), ptr_c, stride_c, ptr_c, stride_d};
    }
  };

  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemm,
      prob_shape,
      mainloop_args,
      prepare_epilogue_args(bias),
      hw_info,
      scheduler};
  return args;
}

template <typename Gemm, bool WithBias>
void launch_sm120_fp8_scaled_mm(
    torch::Tensor& out,
    torch::Tensor const& a,
    torch::Tensor const& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm120_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);

  typename Gemm::Gemm gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess);
  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess);
}

template <typename OutType>
void sm120_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using CTAShapeDefault = Shape<_256, _128, _64>;
  using ClusterShapeDefault = Shape<_1, _1, _1>;

  using CTAShape256 = Shape<_128, _128, _128>;
  using ClusterShape256 = Shape<_1, _1, _1>;

  using CTAShape64 = Shape<_128, _64, _128>;
  using ClusterShape64 = Shape<_1, _1, _1>;
Review comment: The cluster shapes here are all <_1, _1, _1>. I recommend using more appropriate cluster shapes for these tile configurations.
Reply: There is no programmatic multicast on this arch (sm120); only cluster shape <_1, _1, _1> is supported.
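A minimal sketch of how this constraint could be made explicit inside DeviceGemmFp8RowwiseSm120 (a hypothetical guard based on the discussion above, not part of this PR):

  // Reject unsupported cluster shapes at compile time: SM120 has no
  // programmatic multicast, so only a 1x1x1 cluster is valid.
  static_assert(
      cute::size(ClusterShape{}) == 1,
      "SM120 does not support programmatic multicast; ClusterShape must be <_1, _1, _1>.");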
  using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileSchedulerType = void;

  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;

  using BiasGemmDefault = DeviceGemmFp8RowwiseSm120<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShapeDefault,
      ClusterShapeDefault,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      true>;
  using BiasGemm256 = DeviceGemmFp8RowwiseSm120<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape256,
      ClusterShape256,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      true>;
  using BiasGemm64 = DeviceGemmFp8RowwiseSm120<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape64,
      ClusterShape64,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      true>;

  using GemmDefault = DeviceGemmFp8RowwiseSm120<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShapeDefault,
      ClusterShapeDefault,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      false>;
  using Gemm256 = DeviceGemmFp8RowwiseSm120<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape256,
      ClusterShape256,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      false>;
  using Gemm64 = DeviceGemmFp8RowwiseSm120<
      ElementInput,
      ElementOutput,
      AccumElementType,
      CTAShape64,
      ClusterShape64,
      MainloopScheduleType,
      EpilogueScheduleType,
      TileSchedulerType,
      false>;

  uint32_t const m = a.size(0);
  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));

  if (bias) {
    if (mp2 <= 64) {
      return launch_sm120_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 256) {
      return launch_sm120_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
    } else {
      return launch_sm120_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
    }
  } else {
    if (mp2 <= 64) {
      return launch_sm120_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 256) {
      return launch_sm120_fp8_scaled_mm<Gemm256, false>(out, a, b, scales_a, scales_b, bias);
    } else {
      return launch_sm120_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
    }
  }
Review comment: The dispatch logic here is less granular than for other architectures.
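For illustration only, a finer-grained dispatch in that style might add extra M buckets; BiasGemm16 and Gemm16 below are hypothetical tile configurations that would need CTAShape/ClusterShape definitions alongside those above, not types defined in this PR:

  // Hypothetical extra bucket for very small M (sketch only):
  if (bias) {
    if (mp2 <= 16) {
      return launch_sm120_fp8_scaled_mm<BiasGemm16, true>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 64) {
      return launch_sm120_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
    }
    // ... remaining buckets unchanged ...
  }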
}
template <typename OutType>
void sm120_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  return sm120_fp8_dispatch_bias<OutType>(out, a, b, scales_a, scales_b, bias);
}
#endif

torch::Tensor fp8_scaled_mm(
@@ -1212,7 +1550,14 @@ torch::Tensor fp8_scaled_mm(
   auto sm_version = getSMVersion();

 #if defined CUDA_VERSION && CUDA_VERSION >= 12080
-  if (sm_version >= 100) {
+  if (sm_version >= 120) {
+    if (out_dtype == torch::kBFloat16) {
+      sm120_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    } else {
+      sm120_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
+    }
+    return out;
+  } else if (sm_version >= 100) {
     if (out_dtype == torch::kBFloat16) {
       sm100_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
     } else {
Review comment: This struct DeviceGemmFp8RowwiseSm120 and its helper functions prepare_sm120_fp8_args and launch_sm120_fp8_scaled_mm are almost identical to their sm100 counterparts. This introduces significant code duplication. To improve maintainability, consider refactoring this into a generic template that is parameterized by the architecture tag (e.g., cutlass::arch::Sm120). For example, for the struct:
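A minimal sketch of what such an arch-parameterized struct could look like (DeviceGemmFp8Rowwise and ArchTag are assumed names, not code from this PR; the elided members are unchanged from the struct above):

template <
    typename ArchTag,  // e.g. cutlass::arch::Sm100 or cutlass::arch::Sm120
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename EpilogueScheduleType,
    typename TileSchedulerType = void,
    bool WithBias = false>
struct DeviceGemmFp8Rowwise {
  // ... all members as in DeviceGemmFp8RowwiseSm120, except that the two
  // hard-coded cutlass::arch::Sm120 arguments are replaced by ArchTag:
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      cutlass::arch::OpClassTensorOp,
      // ... remaining epilogue parameters unchanged ...
      >::CollectiveOp;
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      cutlass::arch::OpClassTensorOp,
      // ... remaining mainloop parameters unchanged ...
      >::CollectiveOp;
};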
The helper functions prepare_*_fp8_args and launch_*_fp8_scaled_mm are already generic enough and don't need to be duplicated for each architecture. You could rename them to remove the sm120 prefix and use them for all compatible architectures.