[sgl-kernel] feat: Support sm120 cutlass fp8 gemm kernel #9403

Open · wants to merge 2 commits into base: main
347 changes: 346 additions & 1 deletion sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
@@ -1168,6 +1168,344 @@ void sm100_fp8_dispatch_shape(
const c10::optional<torch::Tensor>& bias) {
return sm100_fp8_dispatch_bias<OutType>(out, a, b, scales_a, scales_b, bias);
}

// Rowwise-scaled FP8 GEMM builder for sm120; structurally mirrors the sm100 path above.
template <
typename ElementType,
typename OutElementType,
typename AccumElementType,
typename CTAShape,
typename ClusterShape,
typename MainloopScheduleType,
typename EpilogueScheduleType,
typename TileSchedulerType = void,
bool WithBias = false>
struct DeviceGemmFp8RowwiseSm120 {
Contributor review comment (severity: medium)

This struct DeviceGemmFp8RowwiseSm120 and its helper functions prepare_sm120_fp8_args and launch_sm120_fp8_scaled_mm are almost identical to their sm100 counterparts. This introduces significant code duplication.

To improve maintainability, consider refactoring this into a generic template that is parameterized by the architecture tag (e.g., cutlass::arch::Sm120).

For example, for the struct:

template <
    typename Arch,
    typename ElementType,
    // ... other template params
>
struct DeviceGemmFp8Rowwise {
  // ... generic implementation using Arch in CollectiveEpilogue and CollectiveMainloop
};

// Then define type aliases for each architecture
using DeviceGemmFp8RowwiseSm120 = DeviceGemmFp8Rowwise<cutlass::arch::Sm120, ...>;

The helper functions prepare_*_fp8_args and launch_*_fp8_scaled_mm are already generic enough and don't need to be duplicated for each architecture. You could rename them to remove the sm120 prefix and use them for all compatible architectures.
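
As a rough sketch of that direction (the names launch_fp8_scaled_mm and prepare_fp8_args below are illustrative placeholders, not functions that exist in this PR), the shared launcher could look like:

template <typename GemmType, bool WithBias>
void launch_fp8_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  // Same body as launch_sm120_fp8_scaled_mm below; the architecture tag lives
  // entirely inside GemmType, so nothing here has to be duplicated per arch.
  auto args = prepare_fp8_args<GemmType, WithBias>(out, a, b, scales_a, scales_b, bias);

  typename GemmType::Gemm gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto workspace = torch::empty(
      workspace_size, torch::TensorOptions().dtype(torch::kUInt8).device(a.device()));
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  TORCH_CHECK(gemm_op.can_implement(args) == cutlass::Status::kSuccess);
  TORCH_CHECK(gemm_op.run(args, workspace.data_ptr(), stream) == cutlass::Status::kSuccess);
}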

static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
using TileShape = CTAShape;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

using ElementComputeEpilogue = float;
using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast<
0,
TileShape,
ElementComputeEpilogue,
ElementComputeEpilogue,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
TileShape,
ElementComputeEpilogue,
ElementComputeEpilogue,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
TileShape,
OutElementType,
OutElementType,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

using Compute0 = cutlass::epilogue::fusion::
Sm90Compute<cutlass::multiplies, float, float, cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementType>::value;

using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementType>::value;

using ElementC = void;
using LayoutC = cutlass::layout::RowMajor;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutElementType>::value;

using LayoutD = cutlass::layout::RowMajor;
static constexpr int AlignmentD = AlignmentC;

using Compute1MulAdd = cutlass::epilogue::fusion::
Sm90Compute<cutlass::multiply_add, OutElementType, float, cutlass::FloatRoundStyle::round_to_nearest>;
using Compute1Mul = cutlass::epilogue::fusion::
Sm90Compute<cutlass::multiplies, OutElementType, float, cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute = typename std::conditional_t<
WithBias,
cutlass::epilogue::fusion::Sm90EVT<Compute1MulAdd, ScaleA, EVTCompute0, Bias>,
cutlass::epilogue::fusion::Sm90EVT<Compute1Mul, ScaleA, EVTCompute0>>;
using ArgumentType = typename EVTCompute::Arguments;
// MMA type
using ElementAccumulator = AccumElementType;

// Epilogue types
using ElementCompute = float;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm120,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
AlignmentC,
OutElementType,
LayoutD,
AlignmentD,
EpilogueScheduleType,
EVTCompute>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm120,
cutlass::arch::OpClassTensorOp,
ElementType,
LayoutA,
AlignmentA,
ElementType,
LayoutB,
AlignmentB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
static_assert(
std::is_same_v<Descriptor, ScaleA> || std::is_same_v<Descriptor, ScaleB> || std::is_same_v<Descriptor, Bias>);
return Arguments{data_ptr};
}

public:
static ArgumentType prepare_args(
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias = std::nullopt) {
auto a_args = args_from_tensor<ScaleA, float>(a_scales);
auto b_args = args_from_tensor<ScaleB, float>(b_scales);

typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};

if constexpr (WithBias) {
auto bias_args = args_from_tensor<Bias, OutElementType>(bias.value());
return ArgumentType{a_args, evt0_args, bias_args, {}};
} else {
return ArgumentType{a_args, evt0_args, {}};
}
}
};

// Assemble the CUTLASS GemmUniversal arguments (problem shape, packed strides, EVT epilogue args) from the PyTorch tensors.
template <typename GemmType, bool WithBias>
typename GemmType::Gemm::Arguments prepare_sm120_fp8_args(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using Gemm = typename GemmType::Gemm;
using ElementT = typename Gemm::ElementA;
using ElementC = typename Gemm::ElementC;
using ElementOutput = typename Gemm::ElementD;
using ElementComputeEpilogue = float;
using GemmKernel = typename Gemm::GemmKernel;

using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = StrideC;
using StrideAux = StrideC;

int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);

ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());

StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
StrideAux aux_stride = stride_d;

typename GemmKernel::MainloopArguments mainloop_args{ptr_a, stride_a, ptr_b, stride_b};

typename GemmKernel::ProblemShape prob_shape = {m, n, k, 1};
cutlass::KernelHardwareInfo hw_info;
typename GemmKernel::TileSchedulerArguments scheduler = {};

auto ptr_c = static_cast<ElementOutput*>(out.data_ptr());

auto prepare_epilogue_args = [&](const c10::optional<torch::Tensor>& bias = c10::nullopt) {
if constexpr (WithBias) {
TORCH_CHECK(bias.has_value(), "Bias tensor is required but not provided.");
return typename GemmKernel::EpilogueArguments{
GemmType::prepare_args(scales_a, scales_b, bias.value()), ptr_c, stride_c, ptr_c, stride_d};
} else {
return typename GemmKernel::EpilogueArguments{
GemmType::prepare_args(scales_a, scales_b), ptr_c, stride_c, ptr_c, stride_d};
}
};

typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape,
mainloop_args,
prepare_epilogue_args(bias),
hw_info,
scheduler};
return args;
}

template <typename Gemm, bool WithBias>
void launch_sm120_fp8_scaled_mm(
torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
auto args = prepare_sm120_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);

typename Gemm::Gemm gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto can_implement = gemm_op.can_implement(args);
TORCH_CHECK(can_implement == cutlass::Status::kSuccess);
auto status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess);
}

// Select a tile configuration based on M (rounded up to the next power of two, minimum 16) and on whether a bias is fused.
template <typename OutType>
void sm120_fp8_dispatch_bias(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using CTAShapeDefault = Shape<_256, _128, _64>;
using ClusterShapeDefault = Shape<_1, _1, _1>;

using CTAShape256 = Shape<_128, _128, _128>;
using ClusterShape256 = Shape<_1, _1, _1>;

using CTAShape64 = Shape<_128, _64, _128>;
using ClusterShape64 = Shape<_1, _1, _1>;
Contributor review comment (severity: high)

The ClusterShape for all configurations is set to Shape<_1, _1, _1>. This effectively disables thread block clustering, which is a key performance feature on Hopper and newer architectures. This is likely to lead to suboptimal performance.

For sm100, cluster shapes like Shape<_2, _2, _1> and Shape<_2, _1, _1> are used. For sm90, shapes like Shape<_1, _8, _1> are used.

I recommend using more appropriate cluster shapes for sm120 and tuning them for optimal performance. You could start with values similar to those for sm100 as a baseline.

  using CTAShapeDefault = Shape<_256, _128, _64>;
  using ClusterShapeDefault = Shape<_2, _2, _1>;

  using CTAShape256 = Shape<_128, _128, _128>;
  using ClusterShape256 = Shape<_2, _1, _1>;

  using CTAShape64 = Shape<_128, _64, _128>;
  using ClusterShape64 = Shape<_1, _1, _1>;

Author reply:

There is no programmatic multicast on sm120; this architecture only supports cluster shape Shape<_1, _1, _1>.
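
A compile-time guard inside DeviceGemmFp8RowwiseSm120 could make this constraint explicit (an illustrative sketch, not code from this PR):

// Illustrative guard: fail the build if a non-1x1x1 cluster shape is ever
// requested for the sm120 path, since the hardware has no programmatic multicast.
static_assert(
    std::is_same_v<ClusterShape, cute::Shape<cute::_1, cute::_1, cute::_1>>,
    "sm120 supports only a 1x1x1 cluster shape (no programmatic multicast)");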


using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileSchedulerType = void;

using ElementInput = cutlass::float_e4m3_t;
using ElementOutput = OutType;
using AccumElementType = float;

using BiasGemmDefault = DeviceGemmFp8RowwiseSm120<
ElementInput,
ElementOutput,
AccumElementType,
CTAShapeDefault,
ClusterShapeDefault,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;
using BiasGemm256 = DeviceGemmFp8RowwiseSm120<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape256,
ClusterShape256,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;
using BiasGemm64 = DeviceGemmFp8RowwiseSm120<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape64,
ClusterShape64,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;

using GemmDefault = DeviceGemmFp8RowwiseSm120<
ElementInput,
ElementOutput,
AccumElementType,
CTAShapeDefault,
ClusterShapeDefault,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;
using Gemm256 = DeviceGemmFp8RowwiseSm120<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape256,
ClusterShape256,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;
using Gemm64 = DeviceGemmFp8RowwiseSm120<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape64,
ClusterShape64,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;

uint32_t const m = a.size(0);
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));

if (bias) {
if (mp2 <= 64) {
return launch_sm120_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 256) {
return launch_sm120_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
} else {
return launch_sm120_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
}
} else {
if (mp2 <= 64) {
return launch_sm120_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 256) {
return launch_sm120_fp8_scaled_mm<Gemm256, false>(out, a, b, scales_a, scales_b, bias);
} else {
return launch_sm120_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
}
}
Contributor review comment (severity: medium)

The dispatch logic here is less granular than for other architectures like sm100, which has a special case for mp2 <= 16. This implementation starts with mp2 <= 64. For small batch sizes (e.g., M=1), this could lead to using a suboptimal kernel configuration. Consider adding a more fine-grained dispatch for smaller M values to improve performance, similar to what's done for sm100 and sm90.
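
For illustration, the extra tier could mirror the sm100 structure (BiasGemm16 and Gemm16 below are hypothetical configurations whose tile shape would still need tuning on sm120):

  // Sketch only: add an mp2 <= 16 tier ahead of the existing cases; the
  // hypothetical Gemm16 / BiasGemm16 aliases would use a small-M tile shape.
  if (bias) {
    if (mp2 <= 16) {
      return launch_sm120_fp8_scaled_mm<BiasGemm16, true>(out, a, b, scales_a, scales_b, bias);
    } else if (mp2 <= 64) {
      return launch_sm120_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
    }
    // ... remaining tiers unchanged ...
  } else {
    if (mp2 <= 16) {
      return launch_sm120_fp8_scaled_mm<Gemm16, false>(out, a, b, scales_a, scales_b, bias);
    }
    // ... remaining tiers unchanged ...
  }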

}

template <typename OutType>
void sm120_fp8_dispatch_shape(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
return sm120_fp8_dispatch_bias<OutType>(out, a, b, scales_a, scales_b, bias);
}
#endif

torch::Tensor fp8_scaled_mm(
@@ -1212,7 +1550,14 @@ torch::Tensor fp8_scaled_mm(
auto sm_version = getSMVersion();

#if defined CUDA_VERSION && CUDA_VERSION >= 12080
-  if (sm_version >= 100) {
+  if (sm_version >= 120) {
if (out_dtype == torch::kBFloat16) {
sm120_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
sm120_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
return out;
} else if (sm_version >= 100) {
if (out_dtype == torch::kBFloat16) {
sm100_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {