Commit 278d3bd

Merge branch 'develop' into mm_structred_output
2 parents: 2557839 + ce1f353

107 files changed (+5504 / -3372 lines)


.github/workflows/approve.yml

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,9 @@ on:
     - develop
     - 'release/*'
 
+env:
+  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
 jobs:
   Approval:
     name: Approval

.github/workflows/ci_gcu.yml

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,8 @@ concurrency:
 
 jobs:
   CI_GCU:
-    runs-on: [self-hosted, GCU-S60-8Card]
+    runs-on:
+      group: GCU
     steps:
      - name: Print current runner name
        run: |

.github/workflows/ci_iluvatar.yml

Lines changed: 2 additions & 1 deletion
@@ -11,7 +11,8 @@ concurrency:
 
 jobs:
   CI_ILUVATAR:
-    runs-on: [self-hosted, IXUCA]
+    runs-on:
+      group: IXUCA
    steps:
      - name: Print current runner name
        run: |

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -167,3 +167,6 @@ build
 .ccls-cache
 
 third_party
+
+custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
+custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h

custom_ops/gpu_ops/sample_kernels/rejection_top_p_sampling.cu

Lines changed: 4 additions & 0 deletions
@@ -29,7 +29,11 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
 
   // need_batch_random
   if (seed == -1) {
+#ifdef PADDLE_WITH_COREX
+    auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(probs.place()));
+#else
     phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(probs.place()));
+#endif
     auto gen_cuda = dev_ctx->GetGenerator();
     auto seed_offset = gen_cuda->IncrementOffset(32 * batch_size);
     philox_seed = seed_offset.first;

custom_ops/gpu_ops/sample_kernels/sampling.cuh

Lines changed: 69 additions & 4 deletions
@@ -212,9 +212,15 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
       prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0;
       valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
     }
+#ifdef PADDLE_WITH_COREX
+    float aggregate_local =
+        BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
+            .Sum(prob_greater_than_threshold);
+#else
     float aggregate_local =
         BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
             .Sum<VEC_SIZE>(prob_greater_than_threshold);
+#endif
     if (tx == 0) {
       temp_storage->block_aggregate.value = aggregate_local;
     }
@@ -226,8 +232,13 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
       DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>(
           prob_greater_than_threshold, inclusive_cdf, temp_storage);
     } else {
+#ifdef PADDLE_WITH_COREX
+      BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
+          .InclusiveSum(prob_greater_than_threshold, inclusive_cdf);
+#else
       BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
          .InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
+#endif
 
       __syncthreads();
     }
@@ -239,11 +250,21 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
 
     bool greater_than_u_diff[VEC_SIZE];
 #ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED
-    BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
-        .SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
+#ifdef PADDLE_WITH_COREX
+    BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
+        .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
+#else
+    BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
+        .SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
+#endif
 #else
-    BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
-        .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
+#ifdef PADDLE_WITH_COREX
+    BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
+        .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
+#else
+    BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
+        .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
+#endif
 #endif
     __syncthreads();
 
@@ -355,18 +376,30 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
           (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
     }
 
+#ifdef PADDLE_WITH_COREX
+    aggregate_gt_pivot_0 +=
+        BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
+            .Sum(probs_gt_pivot_0);
+#else
     aggregate_gt_pivot_0 +=
         BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
            .Sum<VEC_SIZE>(probs_gt_pivot_0);
+#endif
     if (tx == 0) {
       temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
     }
     __syncthreads();
     aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
 
+#ifdef PADDLE_WITH_COREX
+    aggregate_gt_pivot_1 +=
+        BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
+            .Sum(probs_gt_pivot_1);
+#else
     aggregate_gt_pivot_1 +=
         BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
            .Sum<VEC_SIZE>(probs_gt_pivot_1);
+#endif
     if (tx == 0) {
       temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
     }
@@ -466,16 +499,26 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
       probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
     }
 
+#ifdef PADDLE_WITH_COREX
+    aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+                                .Sum(probs_gt_pivot_0);
+#else
     aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
                                .Sum<VEC_SIZE>(probs_gt_pivot_0);
+#endif
     if (tx == 0) {
       temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
     }
     __syncthreads();
     aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
 
+#ifdef PADDLE_WITH_COREX
+    aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+                                .Sum(probs_gt_pivot_1);
+#else
     aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
                                .Sum<VEC_SIZE>(probs_gt_pivot_1);
+#endif
     if (tx == 0) {
       temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
     }
@@ -521,9 +564,15 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
       in_data_[j] = in_data_vec[j];
     }
+#ifdef PADDLE_WITH_COREX
+    max_val = max(
+        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                     .Reduce(in_data_, cub::Max()));
+#else
     max_val = max(
         max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                     .Reduce<VEC_SIZE>(in_data_, cub::Max()));
+#endif
     __syncthreads();
   }
   if (tx == 0) {
@@ -605,7 +654,11 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   const uint32_t row_idx = bx;
   const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
+#ifdef PADDLE_WITH_COREX
+  double pivot = std::numeric_limits<float>::infinity(), normalizer = 1;
+#else
   double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
+#endif
   vec_t<float, VEC_SIZE> probs_vec;
   if (k < d) {
     extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
@@ -659,14 +712,26 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
       }
     }
 
+#ifdef PADDLE_WITH_COREX
+      aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                  temp_storage.block_prim.reduce_value_count)
+                                  .Sum(probs_gt_pivot_0_pair);
+#else
       aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                   temp_storage.block_prim.reduce_value_count)
                                  .Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
+#endif
       __syncthreads();
 
+#ifdef PADDLE_WITH_COREX
+      aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                  temp_storage.block_prim.reduce_value_count)
+                                  .Sum(probs_gt_pivot_1_pair);
+#else
       aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                   temp_storage.block_prim.reduce_value_count)
                                  .Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
+#endif
       __syncthreads();
     }
     min_gt_low =
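
Every hunk in sampling.cuh applies the same pattern: under PADDLE_WITH_COREX (the Iluvatar COREX build) the CUB-style block primitives are called without the explicit <VEC_SIZE> template argument, so ITEMS_PER_THREAD is deduced from the array parameter, while the stock CUDA build keeps the explicit form. A minimal sketch of that pattern, assuming stock CUB on the CUDA side; the helper name BlockSum is illustrative and not part of the commit:

// Minimal sketch (not from the commit) of the per-backend block reduce.
#include <cub/block/block_reduce.cuh>

template <int BLOCK_THREADS, int VEC_SIZE>
__device__ float BlockSum(
    float (&vals)[VEC_SIZE],
    typename cub::BlockReduce<float, BLOCK_THREADS>::TempStorage& tmp) {
#ifdef PADDLE_WITH_COREX
  // COREX branch: ITEMS_PER_THREAD is deduced from the array reference.
  return cub::BlockReduce<float, BLOCK_THREADS>(tmp).Sum(vals);
#else
  // CUDA branch: pass the per-thread item count explicitly, as the diff does.
  return cub::BlockReduce<float, BLOCK_THREADS>(tmp).template Sum<VEC_SIZE>(vals);
#endif
}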

custom_ops/gpu_ops/sample_kernels/utils.cuh

Lines changed: 4 additions & 0 deletions
@@ -258,9 +258,13 @@ inline std::pair<int, int> GetCudaComputeCapability() {
 
 /******************* math *******************/
 __forceinline__ __device__ float ptx_rcp(float x) {
+#ifdef PADDLE_WITH_COREX
+  return __ivcorex_rcpf(x);
+#else
   float y;
   asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
   return y;
+#endif
 }
 
 template <typename T1, typename T2>
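
For context, rcp.approx.ftz.f32 is the NVIDIA PTX instruction for a fast approximate reciprocal with flush-to-zero behavior, so the inline asm cannot build on the COREX toolchain; the commit substitutes the __ivcorex_rcpf intrinsic there. A small usage sketch, assuming utils.cuh is included; the kernel below is hypothetical, not part of the commit:

// Illustrative kernel using ptx_rcp; name and launch shape are assumptions.
__global__ void ScaleByInverse(float* data, float denom, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  const float inv = ptx_rcp(denom);  // approximate 1/denom on either backend
  if (i < n) data[i] *= inv;
}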
Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cute/algorithm/copy.hpp"
+#include "cute/atom/mma_atom.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+
+#include "cutlass/cutlass.h"
+#include "cutlass/layout/layout.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/pipeline/pipeline.hpp"
+
+using namespace cute;
+
+template <int kStages, class GemmType, class OutputType, class SmemLayoutA,
+          class SmemLayoutB, class SmemLayoutC>
+struct SharedStorage {
+  union {
+    struct {
+      cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
+      cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
+    };
+    cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
+  };
+
+  struct {
+    typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
+  };
+};
+
+template<int kBlockM_, int kBlockN_, int kBlockK_,
+         int kNWarps_, int kStages_,
+         int kTiles_, int M_,
+         int TokenPackSize_,
+         int TAIL_N_ = 0,
+         int kClusterM_ = 1,
+         typename elem_type=cutlass::float_e4m3_t,
+         typename OutputType = cutlass::bfloat16_t>
+struct Kernel_traits {
+  using Element = elem_type;
+  using ElementAccum = float;
+  using ElementOutput = OutputType;
+  static_assert(cutlass::sizeof_bits_v<Element> == 8);
+
+  static constexpr int kNWarps = kNWarps_;
+  static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
+  static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
+  static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
+
+  static_assert(kNWarps_ == 12 || kNWarps_ == 16);
+
+  static constexpr int kBlockM = kBlockM_;
+  static constexpr int kBlockN = kBlockN_;
+  static constexpr int kBlockK = kBlockK_;
+  static constexpr int kTiles = kTiles_;
+  static constexpr int TokenPackSize = TokenPackSize_;
+  static constexpr int M = M_;
+  static constexpr int TAIL_N = TAIL_N_;
+
+  using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
+  using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;
+
+  static constexpr int kClusterM = kClusterM_;
+  using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
+
+  static constexpr int kStages = kStages_;
+  static_assert(kStages > 1);
+
+  using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
+
+  using TiledMma = decltype(cute::make_tiled_mma(
+      cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
+      AtomLayoutMNK{}));
+
+  using TiledMma_TAIL = decltype(cute::make_tiled_mma(
+      cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK_TAIL>(),
+      AtomLayoutMNK{}));
+
+  using SmemLayoutAtomA = decltype(
+      cutlass::gemm::collective::detail::rs_smem_selector<
+          GMMA::Major::K, Element, Int<kBlockM>, Int<kBlockK / 2>>());
+
+  using SmemLayoutA = decltype(
+      tile_to_shape(SmemLayoutAtomA{},
+                    make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{})));
+
+  using SmemLayoutAtomB = decltype(
+      cutlass::gemm::collective::detail::rs_smem_selector<
+          GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
+          decltype(cute::get<2>(TileShape_MNK{}))>());
+
+  using SmemLayoutB = decltype(
+      tile_to_shape(SmemLayoutAtomB{},
+                    make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
+
+  using SmemLayoutAtomB_TAIL = decltype(
+      cutlass::gemm::collective::detail::rs_smem_selector<
+          GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
+          decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());
+
+  using SmemLayoutB_TAIL = decltype(
+      tile_to_shape(SmemLayoutAtomB_TAIL{},
+                    make_shape(
+                        shape<1>(TileShape_MNK_TAIL{}),
+                        shape<2>(TileShape_MNK_TAIL{}),
+                        Int<kStages>{})
+                    ));
+
+  using SmemLayoutAtomC = decltype(
+      cutlass::gemm::collective::detail::rs_smem_selector<
+          GMMA::Major::K, ElementOutput,
+          decltype(cute::get<0>(TileShape_MNK{})),
+          decltype(cute::get<1>(TileShape_MNK{}))>());
+
+  using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
+
+  using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
+  using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>;
+
+  using SharedStorage = SharedStorage<
+      kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC>;
+
+  using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
+  using PipelineState = typename cutlass::PipelineState<kStages>;
+
+
+  static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
+  static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
+  // static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
+  static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
+  using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
+  using TiledCopyCThrLayout = decltype(cute::make_layout(
+      cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
+      LayoutRight{}));
+  using TiledCopyCValLayout = decltype(cute::make_layout(
+      cute::make_shape(_1{}, Int<kNumVecElem>{}),
+      LayoutRight{}));
+  using TiledCopyC = decltype(make_tiled_copy(
+      TiledCopyCAtom{},
+      TiledCopyCThrLayout{}, // Thr layout
+      TiledCopyCValLayout{}  // Val layout
+      ));
+};
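
This new header bundles the compile-time configuration for a warp-specialized FP8 GEMM on SM90: block tile shape, pipeline stage count, the TMA mainloop pipeline, and shared-memory layouts chosen via CUTLASS's rs_smem_selector. A hypothetical instantiation, only to show the derived constants; the parameter values below are assumptions for illustration and do not appear in the commit:

// Hypothetical instantiation (assumes the header above is included).
using Traits = Kernel_traits</*kBlockM_=*/64, /*kBlockN_=*/128, /*kBlockK_=*/128,
                             /*kNWarps_=*/12, /*kStages_=*/2,
                             /*kTiles_=*/1, /*M_=*/4096,
                             /*TokenPackSize_=*/0, /*TAIL_N_=*/64>;

static_assert(Traits::kNThreads == 384);      // 12 warps x 32 threads
static_assert(Traits::NumMmaThreads == 256);  // one 128-thread warp group produces, two consume
static_assert(Traits::kNumVecElem == 8);      // 128-bit vectors of bf16 output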
