Open
Description
The warp kernel crashes for some input shapes in fp16 and bf16. For example:
[B, C, T]
[2, 2, 32768] -- works
[4, 2, 32768] -- doesn't
[2, 4, 32768] -- doesn't
[4096, 2048, 256] -- works
[4096, 2048, 512] -- doesn't
[4096, 1024, 512] -- works
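To help spot a pattern, here is a minimal sketch that tabulates the total element count and the 2-byte (fp16/bf16) buffer size for each shape listed above. The names `CASES`, `BYTES_PER_ELEMENT`, and `summarize` are illustrative only and are not part of the library; the shapes and outcomes are copied from this report.

```python
from math import prod

# Shapes and outcomes copied from the report above, in [B, C, T] order.
CASES = [
    ((2, 2, 32768), True),
    ((4, 2, 32768), False),
    ((2, 4, 32768), False),
    ((4096, 2048, 256), True),
    ((4096, 2048, 512), False),
    ((4096, 1024, 512), True),
]

BYTES_PER_ELEMENT = 2  # fp16 and bf16 both use 2 bytes per element


def summarize(cases):
    """Return (shape, works, numel, nbytes) for each reported case."""
    rows = []
    for shape, works in cases:
        numel = prod(shape)
        rows.append((shape, works, numel, numel * BYTES_PER_ELEMENT))
    return rows


if __name__ == "__main__":
    for shape, works, numel, nbytes in summarize(CASES):
        status = "works" if works else "crashes"
        print(f"{list(shape)!s:<22} {status:<8} numel={numel:>12} bytes={nbytes:>12}")
```

For these six cases the element counts are 2^17 (works), 2^18 (crashes, twice), 2^31 (works, twice), and 2^32 (crashes), so total size alone does not separate the working cases from the crashing ones.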
The error:
RuntimeError: CUDA error: an illegal memory access was encountered
The crash happens during the forward pass. Environment: PyTorch 2.4.1 (also tried 2.3.1), CUDA 12.5, NVIDIA A100 and H100. The tests crash as well. FP32 works fine. I would appreciate your help.