- compressed attention needs to use overlapping segments
- make sure compress and fine block sizes can be different, and deal with the importance score as explained in the paper (see the pooling sketch after this group)
- handle sequence lengths < block size, make sure nothing breaks
- handle any block size < num selected blocks
- handle no blocks
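For the differing block sizes above, a minimal sketch of one way to pool per-compressed-block importance scores onto fine (selection) blocks before top-k. It assumes non-overlapping compressed blocks and a fine block size that is a multiple of the compressed block size, which is simpler than the overlapping scheme in the paper; all names and shapes are illustrative.

```python
import torch
import torch.nn.functional as F

def fine_block_importance(compressed_scores, compress_block_size, fine_block_size):
    # compressed_scores: (batch, heads, query len, num compressed blocks)
    assert fine_block_size % compress_block_size == 0
    ratio = fine_block_size // compress_block_size
    b, h, n, num_compressed = compressed_scores.shape
    # pad so the compressed blocks group evenly into fine blocks
    padded = F.pad(compressed_scores, (0, -num_compressed % ratio))
    # a fine block's importance is the sum of the compressed block scores it covers
    return padded.reshape(b, h, n, -1, ratio).sum(dim = -1)

scores = torch.randn(1, 8, 512, 16).softmax(dim = -1)   # e.g. compressed attention probs
fine_scores = fine_block_importance(scores, compress_block_size = 16, fine_block_size = 64)
selected_blocks = fine_scores.topk(4, dim = -1).indices # indices of the selected fine blocks
```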
- flex attention for starters
  - add for sliding windows
  - fine attention
  - add the mask function for compressed attention, assuming at some future date attention logits can be extracted (a mask_mod sketch follows below)
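A minimal sketch of the flex attention `mask_mod`s implied by this group: a causal sliding window mask and a block causal mask for the compressed pathway. It assumes torch >= 2.5 with a CUDA device; the block sizes, window size, and the exact causal offset for compressed tokens are illustrative guesses, not the repo's values.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

seq_len, window_size, compress_block_size = 2048, 128, 16
q = torch.randn(1, 8, seq_len, 64, device = 'cuda')
k = torch.randn(1, 8, seq_len, 64, device = 'cuda')
v = torch.randn(1, 8, seq_len, 64, device = 'cuda')

def sliding_window_causal(b, h, q_idx, kv_idx):
    # causal, and only attend to the last `window_size` tokens
    return (q_idx >= kv_idx) & ((q_idx - kv_idx) <= window_size)

def compressed_causal(b, h, q_idx, kv_idx):
    # compressed token kv_idx summarizes source positions
    # [kv_idx * compress_block_size, (kv_idx + 1) * compress_block_size);
    # only attend once the whole summarized block is in the past
    return q_idx >= ((kv_idx + 1) * compress_block_size - 1)

sliding_mask = create_block_mask(sliding_window_causal, None, None, seq_len, seq_len)
compressed_mask = create_block_mask(compressed_causal, None, None, seq_len, seq_len // compress_block_size)

# sliding window branch; the compressed mask would be used the same way against
# the compressed keys / values produced by the compression module
sliding_out = flex_attention(q, k, v, block_mask = sliding_mask)
```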
- replace einx `get_at`
- ~~allow for ablating this extra block causal diagonal in fine attention~~ nevermind, it is necessary for the first block
- make it possible to customize the MLP for compressing key / value
- add attention pool as a type of compression module, even if they said "MLP" (a minimal sketch follows below)
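A minimal sketch of attention pooling as the compression module: a single learned query attends over the tokens inside each block and emits one compressed token per block. The module and its names are hypothetical, not the repo's API.

```python
import torch
from torch import nn, einsum

class AttentionPoolCompress(nn.Module):
    # hypothetical module: one learned pooling query per compression block
    def __init__(self, dim, block_size):
        super().__init__()
        self.block_size = block_size
        self.scale = dim ** -0.5
        self.pool_query = nn.Parameter(torch.randn(dim))
        self.to_key = nn.Linear(dim, dim, bias = False)

    def forward(self, x):
        # x: (batch, seq, dim), seq assumed divisible by block_size for simplicity
        b, n, d = x.shape
        blocks = x.reshape(b, n // self.block_size, self.block_size, d)
        sim = einsum('d, b w n d -> b w n', self.pool_query, self.to_key(blocks)) * self.scale
        attn = sim.softmax(dim = -1)
        # weighted average of each block's tokens -> one compressed token per block
        return einsum('b w n, b w n d -> b w d', attn, blocks)

compress = AttentionPoolCompress(dim = 64, block_size = 16)
compressed = compress(torch.randn(2, 256, 64))   # (2, 16, 64)
```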
- ~~figure out relative positions from query to each compressed key. perhaps use the midpoint of the set of keys~~ given some recent literature, prob best not to have relative positions for the compressed pathway
- gqa
  - in gqa, allow each query head to select different sets of key / values. averaging the importance score across grouped query heads seems suboptimal, but allow for both to see if some tradeoff can be made (see the sketch below)
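A minimal sketch of the tradeoff in the item above, assuming illustrative shapes: importance scores can either be averaged within each group of query heads so the group shares one selection, or kept per head so every query head selects its own kv blocks.

```python
import torch
from einops import reduce

query_heads, kv_heads = 8, 2                  # 4 query heads per kv head
scores = torch.randn(1, query_heads, 512, 32).softmax(dim = -1)   # (batch, query heads, seq, kv blocks)

# shared selection: average the importance scores within each group of query heads
grouped = reduce(scores, 'b (g h) n k -> b g n k', 'mean', g = kv_heads)
shared_selected = grouped.topk(4, dim = -1).indices     # one selection per kv head

# independent selection: every query head keeps its own top-k blocks
per_head_selected = scores.topk(4, dim = -1).indices
```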
- experiments - https://api.wandb.ai/links/lucidrains/d7or7n5n
- a few tests to make sure the flex and cpu versions line up (see the parity test sketch below)
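A minimal sketch of such a parity test, using plain causal attention as a stand-in: flex attention with a causal block mask against the sdpa reference, compared with `torch.testing.assert_close`. Tolerances and shapes are illustrative.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

q, k, v = (torch.randn(1, 4, 256, 64, device = 'cuda') for _ in range(3))

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, None, None, 256, 256)

out_flex = flex_attention(q, k, v, block_mask = block_mask)
out_ref = F.scaled_dot_product_attention(q, k, v, is_causal = True)

torch.testing.assert_close(out_flex, out_ref, atol = 1e-2, rtol = 1e-2)
```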
- inference pathways
  - sliding windows
  - block causal
- make rotary embed torch lib support gqa
- rotary for fine key / values
- some kv cache management
- keep cache on cpu and load only kv segments into gpu for selective attn (see the gather sketch after this group)
  - memmap
- assert cache and without yield same result
- fix each grouped query head seeing different kv segments
- replace slow `get_at` in fine attn inference
- running seq to be compressed, as well as all compressed
- think about compressed sliding windows for even longer context
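For the kv cache items above, a minimal sketch of keeping the full cache on the cpu (optionally memory mapped, e.g. a numpy memmap wrapped with `torch.from_numpy`) and gathering only the selected fine blocks onto the gpu. The per-block layout and names are illustrative.

```python
import torch

kv_heads, num_blocks, fine_block_size, dim_head = 2, 256, 64, 64

# cache laid out per block so each selected block is a contiguous slice
k_cache = torch.randn(kv_heads, num_blocks, fine_block_size, dim_head)   # stays on cpu
v_cache = torch.randn(kv_heads, num_blocks, fine_block_size, dim_head)

selected = torch.tensor([0, 3, 17, 42])    # block indices chosen by the importance score

# gather just the selected blocks and move them to the gpu for fine attention
k_selected = k_cache[:, selected].to('cuda')
v_selected = v_cache[:, selected].to('cuda')
```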
- computing the importance score + loading sparse kv blocks into mem
- wire up flex fine mask, see which one is faster, play around with `BlockMask` if slow, then move on
- offer another type of gating by importance score, but on the fine attention output
  - default to this one as flex fine mask is not compatible
- ~~figure out whether they used some soft topk or gating with the importance scores~~ just use one hot straight through on compress attention probs and gate the selected keys and values (a straight-through sketch follows below)
- revise the importance score for different compress vs fine block sizes based on dialogue
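A minimal sketch of the one hot straight through mentioned above: the forward pass uses the hard top-k selection built from the compressed attention probabilities, while the backward pass routes gradients through the soft probabilities. Shapes and the number of selected blocks are illustrative.

```python
import torch

logits = torch.randn(1, 8, 512, 32, requires_grad = True)
probs = logits.softmax(dim = -1)             # compressed attention probabilities
num_selected = 4

topk_indices = probs.topk(num_selected, dim = -1).indices
hard = torch.zeros_like(probs).scatter(-1, topk_indices, 1.)

# straight-through: forward value is the hard 0 / 1 gate, gradient flows through probs
gate = hard + probs - probs.detach()

# `gate` would then scale the fine keys / values gathered for the selected blocks,
# so the selection itself receives a gradient signal
```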
- build out the triton kernel
- figure out how best to deal with the skinny matrix `tl.dot` - at this point in time, i'm expanding the dims to the minimum of 16 to carry out `tl.dot` with two 3d tensors. if any triton / cuda kernel expert has a better suggestion, let me know (a simplified sketch of the padding idea follows after this group)
- parallelize across sequence for backwards pass, make it optional
- forwards
- backwards
  - dv
  - dk
  - dq
- swap q and kv loops
- figure out why dk is intermittently failing
- take care of gqa
  - forwards
  - backwards
- fix nan issue at higher batch sizes when fine selection is turned on
- allow for `query_heads_share_selected_kv`
  - autodetect indices and mask having query number of heads and pass in a flag
- make the block causal diagonal optional and prep an encoder nsa variant
- take care of block sizes for both m and n less than fine block size
- flag in function that deduplicates selected indices
- debug nan issue with grouped query heads (4 / 2) and 4 selected fine blocks
- make sure triton nsa tolerates any seq length, for generation
- just move head to always the first dimension
- ~~seek a code review from triton experts~~ can't find any, expertise is still too rare
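For the skinny matrix `tl.dot` item earlier in this group, a simplified 2d sketch of the padding idea: a single query row is zero-padded up to 16 rows (the minimum `tl.dot` dimension) with a masked load, and only row 0 of the result carries real data. The repo's actual kernel operates on 3d blocks, so this is only an illustration, not its implementation.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def skinny_dot_kernel(q_ptr, k_ptr, out_ptr, DIM: tl.constexpr, N: tl.constexpr):
    offs_m = tl.arange(0, 16)          # padded query rows, only row 0 is real
    offs_d = tl.arange(0, DIM)
    offs_n = tl.arange(0, N)

    # load the single query row into a (16, DIM) block, zero-filling rows 1..15
    q = tl.load(
        q_ptr + offs_m[:, None] * 0 + offs_d[None, :],
        mask = offs_m[:, None] == 0,
        other = 0.
    )

    k = tl.load(k_ptr + offs_n[:, None] * DIM + offs_d[None, :])   # (N, DIM)

    sim = tl.dot(q, tl.trans(k))       # (16, N); rows 1..15 are all zeros
    tl.store(out_ptr + offs_n, tl.sum(sim, axis = 0))              # keep row 0

q = torch.randn(64, device = 'cuda')
k = torch.randn(128, 64, device = 'cuda')
out = torch.empty(128, device = 'cuda')
skinny_dot_kernel[(1,)](q, k, out, DIM = 64, N = 128)
```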
- improvisations
- generalize to multi-level hierarchical sparse attention
- add an encoder variant for long context video
- try adding a fused talking heads on the gqa, since they are all loaded in
- deduplicate some of the computation between fine attention block causal + sliding window causal, or try to carry them out in parallel
- offer a variant that softclamps the attention logits to 40-50 and removes the lse / running maximum altogether, saving a ton of complexity (see the sketch after this list)
- 2d / 3d versions
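A minimal sketch of the softclamping idea above: bound the attention logits to (-value, value) with a scaled tanh, which is what would allow dropping the running maximum / logsumexp bookkeeping. The 40-50 range comes from the item itself; everything else is illustrative.

```python
import torch

def softclamp(logits, value = 50.):
    # bounds logits smoothly to (-value, value)
    return (logits / value).tanh() * value

sim = torch.randn(1, 8, 1024, 1024) * 30.      # raw attention logits
attn = softclamp(sim).softmax(dim = -1)
```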