
Crash with fp16 and bf16 #12

@JohnAlphaIII

Description


The Warp kernel crashes for some input shapes in fp16 and bf16, e.g.:

[B,    C,    T    ]
[2,    2,    32768] -- works
[4,    2,    32768] -- crashes
[2,    4,    32768] -- crashes

[4096, 2048, 256  ] -- works
[4096, 2048, 512  ] -- crashes
[4096, 1024, 512  ] -- works

The error:

RuntimeError: CUDA error: an illegal memory access was encountered

The crash happens during the forward pass, with PyTorch 2.4.1 (also tried 2.3.1), CUDA 12.5, on NVIDIA A100 and H100. The tests crash as well. FP32 works fine. I'd appreciate your help.
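
For reference, a minimal reproduction looks roughly like this (the `warp_ext` / `warp_forward` names are placeholders, since the real import path isn't shown here; only the shape and dtype matter):

```python
import torch

# Placeholder: the actual extension/entry point isn't named above, so
# `warp_forward` stands in for whatever wraps the CUDA kernel.
from warp_ext import warp_forward  # hypothetical import

B, C, T = 4, 2, 32768  # one of the failing shapes; (2, 2, 32768) is fine

x = torch.randn(B, C, T, device="cuda", dtype=torch.float16)  # or torch.bfloat16
y = warp_forward(x)       # crashes here during the forward pass
torch.cuda.synchronize()  # make the asynchronous CUDA error surface immediately
print(tuple(y.shape))
```

Running this with CUDA_LAUNCH_BLOCKING=1 set in the environment makes kernel launches synchronous, so the illegal-access error is reported at the offending launch rather than at a later op.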
