【New Feature】Centralized support for w4afp8 quantized inference #3456

Open · wants to merge 10 commits into base: develop
6 changes: 4 additions & 2 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -132,7 +132,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor &input, const paddle::Tensor &gating_output,
const paddle::optional<paddle::Tensor> &gating_correction_bias,
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
const bool group_moe, const bool topk_only_mode);
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode);

std::vector<paddle::Tensor>
MoETopKSelectKernel(const paddle::Tensor &gating_logits,
@@ -192,6 +192,7 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const paddle::optional<paddle::Tensor>& input_row_sum,
const std::string& quant_method, const bool used_in_ep_low_latency);

paddle::Tensor MoeExpertFFNWint2Func(
@@ -877,7 +878,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("moe_expert_dispatch", &MoeExpertDispatch, py::arg("input"),
py::arg("gating_output"), py::arg("gating_correction_bias"),
py::arg("w4a8_in_scale"), py::arg("moe_topk"), py::arg("group_moe"),
py::arg("topk_only_mode"), "moe expert dispatch function");
py::arg("moe_quant_type"), py::arg("topk_only_mode"),
"moe expert dispatch function");

/**
* moe/fused_moe/ep_moe_prefill_func.cu
16 changes: 16 additions & 0 deletions custom_ops/gpu_ops/helper.h
@@ -557,3 +557,19 @@ inline int GetSMVersion() {
return sm_version;

}

template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};

template<typename T, typename ReductionOp, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
if (threadIdx.x == 0) { result_broadcast = result; }
__syncthreads();
return result_broadcast;
}
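
A minimal usage sketch of the new `SumOp`/`BlockAllReduce` helpers — a toy kernel in which each block sums one row and every thread observes the result. The kernel itself (`row_sum_kernel`) is hypothetical and assumes the helpers above plus `<cub/cub.cuh>` are in scope:

```cpp
// Toy example (not part of the PR): block-wide row sum using the helpers above.
// Assumes <cub/cub.cuh>, SumOp, and BlockAllReduce are visible at this point.
template <int kBlockSize>
__global__ void row_sum_kernel(const float* __restrict__ in,
                               float* __restrict__ out, int cols) {
  const float* row = in + blockIdx.x * static_cast<size_t>(cols);
  float local = 0.0f;
  for (int i = threadIdx.x; i < cols; i += kBlockSize) {
    local += row[i];  // each thread accumulates a strided slice of the row
  }
  // Every thread receives the block-wide sum; BlockAllReduce broadcasts the
  // thread-0 result through shared memory and __syncthreads().
  const float total = BlockAllReduce<float, SumOp<float>, kBlockSize>(local);
  if (threadIdx.x == 0) out[blockIdx.x] = total;
}

// Launch with blockDim.x == kBlockSize, e.g.:
//   row_sum_kernel<256><<<num_rows, 256, 0, stream>>>(d_in, d_out, cols);
```

This mirrors how `initialize_moe_routing_kernel` below uses the helper to produce `input_row_sum` for the FP8 path.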
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/moe/fused_moe_helper.h
@@ -250,7 +250,7 @@ template <typename T, typename NvType> class MoeHelper {

initialize_moe_routing_kernelLauncher<T>::run(
input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
expanded_source_row_to_expanded_dest_row, nullptr, num_rows, num_rows,
hidden_size, k, stream);

const int64_t expanded_active_expert_rows = k * num_rows;
134 changes: 87 additions & 47 deletions custom_ops/gpu_ops/moe/fused_moe_op.h
@@ -959,67 +959,88 @@ static void run(const T* input,
// to row 0 in the original matrix. Thus, to know where to read in the source
// matrix, we simply take the modulus of the expanded index.
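// Example: with num_rows = 4 and k = 2, expanded rows 2 and 6 both read
// source row 2, since 2 % 4 == 6 % 4 == 2.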

template <typename T, int VecSize, typename OutT=T>
template <typename T, int VecSize, typename OutT=T, int kNumThreads=256>
__global__ void initialize_moe_routing_kernel(
const T* unpermuted_input,
OutT* permuted_output,
const int* expanded_dest_row_to_expanded_source_row,
const int *expert_idx_per_token,
const float *w4a8_in_scale,
int* expanded_source_row_to_expanded_dest_row,
float *input_row_sum,
const int64_t num_rows,
const int64_t active_rows,
const int64_t cols,
const int64_t num_rows_k) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;

// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way
// reduction and unpermuting. I need the reverse map for that reduction to
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
// thread block will be responsible for all k summations.
const int expanded_dest_row = blockIdx.x + blockIdx.y * gridDim.x;
if (expanded_dest_row >= num_rows_k) return;
const int expanded_source_row =
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
if (threadIdx.x == 0) {
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
expanded_dest_row;
}

if (expanded_dest_row < active_rows) {

const int expert_idx = expert_idx_per_token[expanded_dest_row];
const float scale = w4a8_in_scale ? w4a8_in_scale[expert_idx] : -1;
const int source_row = expanded_source_row % num_rows;

const T* source_row_ptr = unpermuted_input + source_row * cols;
OutT *dest_row_ptr = permuted_output + expanded_dest_row * cols;

for (int tid = threadIdx.x * VecSize; tid < cols;
tid += blockDim.x * VecSize) {
// dest_row_ptr[tid] = source_row_ptr[tid];
Load<T, VecSize>(&source_row_ptr[tid], &src_vec);

if constexpr (std::is_same<OutT, int8_t>::value) {
using StoreT = AlignedVector<OutT, VecSize>;
StoreT dest_vec;
const float max_bound = 127.f;
const float min_bound = -127.f;
for (int j = 0; j < VecSize; j++) {
float quant_value =
max_bound * scale * static_cast<float>(src_vec[j]);
quant_value = quant_value > max_bound ? max_bound : quant_value;
quant_value = quant_value < min_bound ? min_bound : quant_value;
dest_vec[j] = static_cast<int8_t>(round(quant_value));
if (expanded_dest_row < active_rows) {
float local_sum = 0.0f;
const int expert_idx = expert_idx_per_token[expanded_dest_row];
const float scale = w4a8_in_scale ? w4a8_in_scale[expert_idx] : -1.0f;
const int source_row = expanded_source_row % num_rows;

const T* source_row_ptr = unpermuted_input + source_row * cols;
OutT *dest_row_ptr = permuted_output + expanded_dest_row * cols;

for (int tid = threadIdx.x * VecSize; tid < cols; tid += blockDim.x * VecSize) {
// dest_row_ptr[tid] = source_row_ptr[tid];
Load<T, VecSize>(&source_row_ptr[tid], &src_vec);

if constexpr (std::is_same<OutT, int8_t>::value) {
using StoreT = AlignedVector<OutT, VecSize>;
StoreT dest_vec;
const float max_bound = 127.f;
const float min_bound = -127.f;
for (int j = 0; j < VecSize; j++) {
float quant_value = max_bound * scale * static_cast<float>(src_vec[j]);
quant_value = quant_value > max_bound ? max_bound : quant_value;
quant_value = quant_value < min_bound ? min_bound : quant_value;
dest_vec[j] = static_cast<int8_t>(round(quant_value));
}
Store<OutT, VecSize>(dest_vec, &dest_row_ptr[tid]);
} else if constexpr (std::is_same<OutT, phi::dtype::float8_e4m3fn>::value) {
using StoreT = AlignedVector<OutT, VecSize>;
StoreT dest_vec;
const float max_bound = 448.f;
const float min_bound = -448.f;
for (int j = 0; j < VecSize; j++) {
float quant_value =
max_bound * scale * static_cast<float>(src_vec[j]);
quant_value = quant_value > max_bound ? max_bound : quant_value;
quant_value = quant_value < min_bound ? min_bound : quant_value;
dest_vec[j] = static_cast<phi::dtype::float8_e4m3fn>(quant_value);
local_sum += static_cast<float>(dest_vec[j]);
}
Store<phi::dtype::float8_e4m3fn, VecSize>(dest_vec, &dest_row_ptr[tid]);
} else {
Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
}
}

if constexpr (std::is_same<OutT, phi::dtype::float8_e4m3fn>::value) {
local_sum = BlockAllReduce<float, SumOp<float>, kNumThreads>(local_sum);
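// Assumption (not stated in the PR): the -7/512 factor folds a per-row
// correction term for the w4afp8 weight encoding into the row sum consumed
// by the downstream grouped GEMM.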
constexpr float sum_scale = -7.0f / 512.0f;
if (threadIdx.x == 0) {
input_row_sum[expanded_dest_row] = local_sum * sum_scale;
}
}
Store<OutT, VecSize>(dest_vec, &dest_row_ptr[tid]);
} else {
Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
}
}
}
}

template <typename T, typename OutT = T>
Expand All @@ -1032,6 +1053,7 @@ static void run(
const int *expert_idx_per_token,
const float *w4a8_in_scale,
int* expanded_source_row_to_expanded_dest_row,
float * input_row_sum,
const int64_t num_rows,
const int64_t active_rows,
const int64_t cols,
Expand All @@ -1040,6 +1062,22 @@ static void run(
const int threads = std::min(cols, int64_t(1024));
constexpr int max_pack_size = 16 / sizeof(T);
const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k);
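// Note (assumption): the FP8 path below pins the launch to 256 threads so
// that blockDim matches the kernel's kNumThreads=256 template default used
// by BlockAllReduce for the row-sum reduction.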
if (std::is_same<OutT, phi::dtype::float8_e4m3fn>::value) {
initialize_moe_routing_kernel<T, max_pack_size, OutT>
<<<config_initialize.block_per_grid, 256, 0, stream>>>(
unpermuted_input,
permuted_output,
expanded_dest_row_to_expanded_source_row,
expert_idx_per_token,
w4a8_in_scale,
expanded_source_row_to_expanded_dest_row,
input_row_sum,
num_rows,
k * active_rows,
cols,
num_rows * k);
return;
}
if (cols % max_pack_size == 0) {
initialize_moe_routing_kernel<T, max_pack_size, OutT>
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
Expand All @@ -1049,6 +1087,7 @@ static void run(
expert_idx_per_token,
w4a8_in_scale,
expanded_source_row_to_expanded_dest_row,
input_row_sum,
num_rows,
k * active_rows,
cols,
Expand All @@ -1062,6 +1101,7 @@ static void run(
expert_idx_per_token,
w4a8_in_scale,
expanded_source_row_to_expanded_dest_row,
input_row_sum,
num_rows,
k * active_rows,
cols,
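
For clarity, a host-side reference of the per-row math the new FP8 branch performs: scale each element into the e4m3 range, clamp to ±448, and accumulate a row sum that is then multiplied by `-7/512`. This is an illustrative sketch, not code from the PR; `quantize_row_fp8_ref` is a hypothetical name, and plain `float` stands in for `phi::dtype::float8_e4m3fn` (so it omits the e4m3 rounding that the device-side cast applies before summing):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical host-side mirror of the kernel's FP8 path (illustration only).
// Returns the value the kernel would write to input_row_sum for this row.
float quantize_row_fp8_ref(const std::vector<float>& row, float in_scale,
                           std::vector<float>& quantized_out) {
  constexpr float kMaxBound = 448.f;           // e4m3 finite maximum
  constexpr float kMinBound = -448.f;
  constexpr float kSumScale = -7.0f / 512.0f;  // same constant as the kernel
  float row_sum = 0.0f;
  quantized_out.resize(row.size());
  for (std::size_t j = 0; j < row.size(); ++j) {
    float q = kMaxBound * in_scale * row[j];          // scale into e4m3 range
    q = std::min(std::max(q, kMinBound), kMaxBound);  // clamp to ±448
    quantized_out[j] = q;  // the device casts to float8_e4m3fn here
    row_sum += q;          // the device sums the post-cast value
  }
  return row_sum * kSumScale;
}
```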
46 changes: 34 additions & 12 deletions custom_ops/gpu_ops/moe/moe_dispatch.cu
@@ -33,7 +33,8 @@ void MoeDispatchKernel(
const int hidden_size, const int expert_num, paddle::Tensor *permute_input,
paddle::Tensor *tokens_expert_prefix_sum,
paddle::Tensor *permute_indices_per_token, paddle::Tensor *topk_weight,
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token,
paddle::Tensor *input_row_sum) {
using namespace phi;

typedef PDTraits<T> traits_;
@@ -113,16 +114,27 @@
permuted_rows_, moe_topk * num_rows, false, stream);

if (w4a8_in_scale) {
initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
if (permute_input->dtype() == paddle::DataType::INT8) {
initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
permute_indices_per_token->data<int32_t>(), input_row_sum->data<float>(),
num_rows, num_rows,
hidden_size, moe_topk, stream);
} else {
initialize_moe_routing_kernelLauncher<data_t, phi::dtype::float8_e4m3fn>::run(
input.data<data_t>(), permute_input->data<phi::dtype::float8_e4m3fn>(), permuted_rows_,
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(), input_row_sum->data<float>(),
num_rows, num_rows,
hidden_size, moe_topk, stream);
}
} else {
initialize_moe_routing_kernelLauncher<data_t>::run(
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
expert_idx_per_token->data<int32_t>(), nullptr,
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
permute_indices_per_token->data<int32_t>(), input_row_sum->data<float>(),
num_rows, num_rows,
hidden_size, moe_topk, stream);
}

@@ -135,7 +147,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor &input, const paddle::Tensor &gating_output,
const paddle::optional<paddle::Tensor> &gating_correction_bias,
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
const bool group_moe, const bool topk_only_mode) {
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode) {
const auto input_type = input.dtype();
auto place = input.place();
int token_rows = 0;
@@ -151,8 +163,16 @@
const int num_rows = token_rows;
const int hidden_size = input.dims()[input_dims.size() - 1];

auto permute_input_dtype =
w4a8_in_scale ? paddle::DataType::INT8 : input_type;
auto input_row_sum = GetEmptyTensor({moe_topk * num_rows}, paddle::DataType::FLOAT32, place);

auto permute_input_dtype = input_type;
if (w4a8_in_scale) {
if (moe_quant_type == "w4a8") {
permute_input_dtype = paddle::DataType::INT8;
} else {
permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
}
}

auto permute_input = GetEmptyTensor({moe_topk * num_rows, hidden_size},
permute_input_dtype, place);
@@ -176,19 +196,20 @@
input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk,
group_moe, topk_only_mode, num_rows, hidden_size, expert_num,
&permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token,
&topk_weight, &topk_idx, &expert_idx_per_token);
&topk_weight, &topk_idx, &expert_idx_per_token, &input_row_sum);
break;
case paddle::DataType::FLOAT16:
MoeDispatchKernel<paddle::DataType::FLOAT16>(
input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk,
group_moe, topk_only_mode, num_rows, hidden_size, expert_num,
&permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token,
&topk_weight, &topk_idx, &expert_idx_per_token);
&topk_weight, &topk_idx, &expert_idx_per_token, &input_row_sum);
break;
default:
PD_THROW("Unsupported data type for MoeDispatchKernel");
}
return {permute_input,
input_row_sum,
tokens_expert_prefix_sum,
permute_indices_per_token,
topk_weight,
@@ -214,6 +235,7 @@ std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
const int permuted_rows = num_rows == -1 ? -1 : moe_topk * num_rows;

return {{permuted_rows, hidden_size},
{permuted_rows},
{expert_num},
{moe_topk, num_rows},
{num_rows, moe_topk},
@@ -226,7 +248,7 @@
const paddle::DataType &gating_output_dtype,
const paddle::optional<paddle::DataType> &bias_type,
const int moe_topk) {
return {input_dtype, paddle::DataType::INT64, paddle::DataType::INT32,
return {input_dtype, paddle::DataType::FLOAT32, paddle::DataType::INT64, paddle::DataType::INT32,
paddle::DataType::FLOAT32, paddle::DataType::INT32, paddle::DataType::INT32};
}

@@ -282,10 +304,10 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch)
.Inputs({"input", "gating_output",
paddle::Optional("gating_correction_bias"),
paddle::Optional("w4a8_in_scale")})
.Outputs({"permute_input", "tokens_expert_prefix_sum",
.Outputs({"permute_input", "input_row_sum", "tokens_expert_prefix_sum",
"permute_indices_per_token", "topk_weight", "topk_idx",
"expert_idx_per_token"})
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
.Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
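
Net effect on callers: `moe_expert_dispatch` gains a `moe_quant_type` attribute and returns one extra tensor, `input_row_sum`, at index 1. A hedged sketch of a call site against the updated signature (argument values are illustrative; using `"w4afp8"` as the non-`"w4a8"` quant type is inferred from the PR title):

```cpp
// Hypothetical call site (illustration only; argument values made up).
auto outs = MoeExpertDispatch(input, gating_output, gating_correction_bias,
                              w4a8_in_scale,
                              /*moe_topk=*/8,
                              /*group_moe=*/false,
                              /*moe_quant_type=*/"w4afp8",  // anything but "w4a8"
                                                            // selects FLOAT8_E4M3FN
                              /*topk_only_mode=*/false);
const auto& permute_input             = outs[0];
const auto& input_row_sum             = outs[1];  // new output added by this PR
const auto& tokens_expert_prefix_sum  = outs[2];
const auto& permute_indices_per_token = outs[3];
const auto& topk_weight               = outs[4];
const auto& topk_idx                  = outs[5];
const auto& expert_idx_per_token      = outs[6];
```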