diff --git a/sgl-kernel/include/utils.h b/sgl-kernel/include/utils.h index 56f32276426..5cab0786c4d 100644 --- a/sgl-kernel/include/utils.h +++ b/sgl-kernel/include/utils.h @@ -331,9 +331,11 @@ inline bool getEnvEnablePDL() { #ifndef USE_ROCM #define WARP_SIZE 32 #else -#include -#include -#define WARP_SIZE C10_WARP_SIZE +#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) +#define WARP_SIZE 64 +#else +#define WARP_SIZE 32 +#endif #endif #ifdef USE_ROCM