
todo #1

@lucidrains

Description

  • compressed attention needs to use overlapping segments (see the overlapping-segments sketch after this list)
  • make sure compress and fine block sizes can be different; deal with the importance score as explained in the paper
  • handle < block size sequence lengths, make sure it doesn't break
    • handle any block size < num selected blocks
    • handle no blocks
  • flex attention for starters
    • add for sliding windows
    • fine attention
    • add the mask function for compressed attention, assuming that at some future date attention logits can be extracted
  • replace einx get_at
  • allow for ablating this extra block causal diagonal in fine attention (nevermind, it is necessary for the first block)
  • make it possible to customize the MLP for compressing key / value
  • add attention pool as a type of compression module, even if they said "MLP" (see the attention-pool sketch after this list)
  • figure out relative positions from query to each compressed key, perhaps using the midpoint of the set of keys (given some recent literature, it is probably best not to have relative positions for the compressed pathway)
  • gqa
  • in gqa, allow each query head to select different sets of keys / values; averaging the importance score across grouped query heads seems suboptimal, but allow for both to see if some tradeoff can be made (see the gqa selection sketch after this list)
  • experiments - https://api.wandb.ai/links/lucidrains/d7or7n5n
  • a few tests to make sure the flex and cpu versions line up
  • inference pathways
    • sliding windows
    • block causal
    • make rotary embed torch lib support gqa
    • rotary for fine key / values
    • some kv cache management
      • keep the cache on cpu and load only the selected kv segments onto the gpu for selective attn (see the offloaded cache sketch after this list)
      • memmap
    • assert cache and without yield same result
    • fix each grouped query head seeing different kv segments
    • replace slow get_at in fine attn inference
    • keep a running seq to be compressed, as well as all compressed so far; think about compressed sliding windows for even longer context
    • computing the importance score + loading sparse kv blocks into mem
  • wire up flex fine mask, see which one is faster, play around with BlockMask if slow, then move on
  • offer another type of gating by importance score, but on the fine attention output; default to this one, as the flex fine mask is not compatible
  • figure out whether they used some soft topk or gating with the importance scores; just use one-hot straight-through on the compressed attention probs and gate the selected keys and values (see the straight-through sketch after this list)
  • revise the importance score for different compress vs fine block sizes based on dialogue
  • build out the triton kernel
    • figure out how best to deal with skinny-matrix tl.dot; at this point in time, i'm expanding the dims to the minimum of 16 to carry out tl.dot with two 3d tensors. if any triton / cuda kernel expert has a better suggestion, let me know
    • parallelize across sequence for backwards pass, make it optional
    • forwards
    • backwards
      • dv
      • dk
      • dq
      • swap q and kv loops
      • figure out why dk is intermittently failing
    • take care of gqa
      • forwards
      • backwards
    • fix nan issue at higher batch sizes when fine selection is turned on
    • allow for query_heads_share_selected_kv; autodetect when indices and mask have the query number of heads, and pass in a flag
    • make the block causal diagonal optional and prep an encoder nsa variant
    • take care of block sizes for both m and n less than fine block size
    • flag in function that deduplicates selected indices
    • debug nan issue with grouped query heads (4 / 2) and 4 selected fine blocks
    • make sure triton nsa tolerates any seq length, for generation
    • just move the head dim to always be the first dimension
    • seek a code review from triton experts (can't find any, expertise is still too rare)
  • improvisations
    • generalize to multi-level hierarchical sparse attention
    • add an encoder variant for long context video
    • try adding fused talking heads on the gqa, since they are all loaded in
    • deduplicate some of the computation between fine attention block causal + sliding window causal, or try to carry them out in parallel
    • offer a variant where one does attn softclamping to 40-50 and removes the lse / maximum altogether, saving a ton of complexity (see the softclamp sketch after this list)
    • 2d / 3d versions
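
Below are a few rough sketches for some of the items above. First, the overlapping segments for the compressed pathway: a minimal sketch assuming keys shaped (batch, heads, seq, dim); segment_len, segment_stride, and the mean pooling are placeholders standing in for the real compression module.

```python
import torch.nn.functional as F

def overlapping_segments(keys, segment_len = 32, segment_stride = 16):
    # keys: (batch, heads, seq, dim) -> (batch, heads, num_segments, dim)
    seq = keys.shape[-2]

    # right-pad the sequence so the last window is complete
    if seq < segment_len:
        padding = segment_len - seq
    else:
        padding = -(seq - segment_len) % segment_stride

    keys = F.pad(keys, (0, 0, 0, padding))

    # unfold into overlapping windows: (batch, heads, num_segments, dim, segment_len)
    windows = keys.unfold(2, segment_len, segment_stride)

    # placeholder compression - the real module would be the compress MLP or attention pool
    return windows.mean(dim = -1)
```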
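
An attention pool as the compression module, per the item above: a minimal sketch assuming the keys / values are already chunked into segments of shape (batch, heads, num_segments, segment_len, dim); the single learned logit per token is an assumption.

```python
from torch import nn

class AttentionPoolCompress(nn.Module):
    # compresses each segment of keys (or values) down to a single token
    def __init__(self, dim):
        super().__init__()
        self.to_attn_logits = nn.Linear(dim, 1, bias = False)

    def forward(self, segments):
        # segments: (batch, heads, num_segments, segment_len, dim)
        attn = self.to_attn_logits(segments).softmax(dim = -2)  # weights over segment_len
        return (segments * attn).sum(dim = -2)                  # (batch, heads, num_segments, dim)
```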
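
For the gqa selection item, a minimal sketch of the two selection modes, assuming importance scores of shape (batch, query_heads, q_len, num_blocks); query_heads_share_selected_kv mirrors the flag named in the triton section, everything else is illustrative.

```python
from einops import reduce

def select_fine_blocks(importance, num_grouped_queries, num_selected, query_heads_share_selected_kv = True):
    # importance: (batch, query_heads, q_len, num_blocks)
    if query_heads_share_selected_kv:
        # average importance across the query heads in each group so they share kv blocks
        importance = reduce(importance, 'b (h g) i n -> b h i n', 'mean', g = num_grouped_queries)

    # otherwise each query head keeps its own scores and selects its own kv blocks
    return importance.topk(num_selected, dim = -1).indices
```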
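
For the kv cache management item, a minimal sketch of keeping the fine kv blocks on cpu (pinned memory here, a memmap would work the same way) and only moving the selected blocks onto the gpu at decode time; the class and method names are hypothetical, this is the offloaded cache sketch referenced above.

```python
import torch

class OffloadedFineKVCache:
    # stores per-block fine keys / values on cpu, loads only the selected blocks to gpu
    def __init__(self, device = 'cuda'):
        self.k_blocks = []   # each: (heads, block_size, dim), cpu / pinned
        self.v_blocks = []
        self.device = device

    def append(self, k_block, v_block):
        self.k_blocks.append(k_block.cpu().pin_memory())
        self.v_blocks.append(v_block.cpu().pin_memory())

    def load_selected(self, block_indices):
        # block_indices: iterable of selected block ids from the importance scores
        k = torch.stack([self.k_blocks[i] for i in block_indices])
        v = torch.stack([self.v_blocks[i] for i in block_indices])
        return k.to(self.device, non_blocking = True), v.to(self.device, non_blocking = True)
```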
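
For the one-hot straight-through item, a minimal sketch: topk on the compressed attention probs gives a hard selection in the forward pass while gradients flow through the soft scores, and the gathered fine keys / values are gated by the selected scores. All shapes and names here are assumptions.

```python
import torch

def straight_through_select(compressed_attn_probs, fine_k, fine_v, num_selected = 4):
    # compressed_attn_probs: (b, h, i, num_blocks) -- the importance scores
    # fine_k, fine_v:        (b, h, num_blocks, block_size, d)
    scores = compressed_attn_probs
    sel = scores.topk(num_selected, dim = -1).indices                   # (b, h, i, sel)

    # hard one-hot selection in the forward pass, soft scores in the backward pass
    hard = torch.zeros_like(scores).scatter(-1, sel, 1.)
    gates = hard + scores - scores.detach()

    # gather the selected fine kv blocks for every query position
    i = sel.shape[2]
    block_size, dim = fine_k.shape[-2:]
    idx = sel[..., None, None].expand(-1, -1, -1, -1, block_size, dim)  # (b, h, i, sel, bs, d)

    sel_k = fine_k[:, :, None].expand(-1, -1, i, -1, -1, -1).gather(3, idx)
    sel_v = fine_v[:, :, None].expand(-1, -1, i, -1, -1, -1).gather(3, idx)

    # gate the selected keys / values by the straight-through gate values
    sel_gates = gates.gather(-1, sel)[..., None, None]                  # (b, h, i, sel, 1, 1)
    return sel_k * sel_gates, sel_v * sel_gates
```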
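
Finally, for the softclamp improvisation, a minimal sketch: tanh softclamping bounds the attention logits, so exp can be taken without subtracting a running maximum (fine in fp32; the value of 50. and dropping the lse entirely follow the item above, and causal masking is omitted for brevity).

```python
import torch

def softclamp(t, value = 50.):
    # smoothly clamps logits to (-value, value)
    return (t / value).tanh() * value

def softclamped_attend(q, k, v):
    # q: (b, h, i, d), k / v: (b, h, j, d)
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
    sim = softclamp(sim)

    # logits are bounded, so exp cannot overflow in fp32 and no max / lse bookkeeping is needed
    attn = sim.exp()
    attn = attn / attn.sum(dim = -1, keepdim = True)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
```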
