Submitted by super_deap t3_11tmpc5 in MachineLearning
tripple13 t1_jck9593 wrote
Does anyone know why they didn't add flash attention directly into the MultiheadAttention modules?
Edit: Seems to be integrated, awesome!
programmerChilli t1_jcnydmw wrote
I think it is used in PyTorch's nn.TransformerEncoder, but a lot of people like implementing their own.
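A minimal sketch (editor's illustration; the sizes are made up, not from the thread) of using the built-in encoder instead of a hand-rolled attention block. In eval mode under inference_mode, the stack is eligible to dispatch to the fused SDPA kernels internally:

```python
import torch
import torch.nn as nn

# Illustrative sizes; batch_first=True is one of the fast-path requirements.
layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=6).eval()

x = torch.randn(4, 128, 512)  # (batch, seq, d_model)
with torch.inference_mode():
    y = encoder(x)  # can take the fused SDPA path internally when eligible
```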
mike94025 t1_jcv94un wrote
SDPA is used by F.multi_head_attention_forward (if need_weights=False) which is used by nn.MHA and nn.Transformer* as well as other libraries. (source)
Public service announcement: need_weights defaults to True, and guts performance. (Because allocating and writing the attention weight tensor defeats the memory BW advantages of flash attention.)
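A hedged sketch of the point above (illustrative shapes; assumes a CUDA device with fp16, which is what the flash kernel expects):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True).cuda().half()
x = torch.randn(4, 1024, 512, device="cuda", dtype=torch.float16)

# need_weights defaults to True; pass False so the (batch, heads, seq, seq)
# attention-weight tensor is never materialized and the memory-BW advantage is kept.
out, attn_weights = mha(x, x, x, need_weights=False)  # attn_weights is None
```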
Also, if `key_padding_mask is not None`, performance will suffer (because this is converted into an attention mask, and only the causal attention mask is supported by Flash Attention). Use Nested Tensors for variable sequence length batches.
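A sketch of the nested-tensor alternative (editor's illustration, not from the thread; assumes the MHA fast-path conditions hold: self-attention, batch_first=True, eval mode, need_weights=False, no grad):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True).eval()

# Two sequences of different lengths, packed without padding, so no
# key_padding_mask (and no mask-to-attn_mask conversion) is needed.
seqs = [torch.randn(10, 256), torch.randn(7, 256)]
nt = torch.nested.nested_tensor(seqs)

with torch.inference_mode():  # nested-tensor input requires the inference fast path
    out, _ = mha(nt, nt, nt, need_weights=False)
```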
mike94025 t1_je5ojaw wrote
It is. Follow the call tree into F.multi_head_attention_forward.
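The call tree bottoms out in F.scaled_dot_product_attention, which you can also call directly. A sketch (illustrative shapes; assumes CUDA + fp16) that restricts SDPA to the flash backend so you get an error instead of a silent fallback if your inputs aren't eligible:

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim); fp16 on CUDA so the flash kernel is eligible.
q = torch.randn(4, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Disable the math and mem-efficient fallbacks to confirm flash attention is used.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```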