tripple13 t1_jck9593 wrote

Does anyone know why they didn't add FlashAttention directly into the MultiheadAttention modules? Seems it is integrated, awesome!

7

programmerChilli t1_jcnydmw wrote

I think it is used in PyTorch's nn.TransformerEncoder, but a lot of people like implementing their own.

2

mike94025 t1_jcv94un wrote

SDPA is used by F.multi_head_attention_forward (if need_weights=False), which is used by nn.MHA and nn.Transformer* as well as other libraries. (source)

Public service announcement: need_weights defaults to True, and guts performance. (Because allocating and writing the attention-weight tensor defeats the memory-bandwidth advantages of FlashAttention.)
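
For illustration, a minimal sketch of the difference (shapes, sizes, and dtype are made up; the fast path also depends on other conditions like running on a supported GPU):

```python
import torch
import torch.nn as nn

# Hypothetical module sizes, just to show the need_weights flag.
mha = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True).cuda().half()

x = torch.randn(4, 1024, 256, device="cuda", dtype=torch.half)  # (batch, seq, embed)

# need_weights=False: no attention-weight tensor is requested, so the call can
# dispatch to scaled_dot_product_attention (and potentially FlashAttention).
out, attn_weights = mha(x, x, x, need_weights=False)   # attn_weights is None

# need_weights=True (the default): the full attention-weight tensor is
# allocated and written out, which is exactly what FlashAttention avoids.
out_slow, attn_weights_slow = mha(x, x, x, need_weights=True)
```
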

Also, if `key_padding_mask is not None`, performance will suffer (because it is converted into an attention mask, and only the causal attention mask is supported by FlashAttention). Use Nested Tensors for variable-sequence-length batches.
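
A hedged sketch of the nested-tensor alternative (assuming the inference-mode fast path with batch_first=True and need_weights=False applies; the sequence lengths are illustrative):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True).eval()

# Three sequences of different lengths, packed without padding; no
# key_padding_mask is needed because the nested tensor carries the lengths.
seqs = [torch.randn(s, 256) for s in (37, 512, 190)]
x = torch.nested.nested_tensor(seqs)

with torch.inference_mode():
    out, _ = mha(x, x, x, need_weights=False)   # out is also a nested tensor
```
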

1

mike94025 t1_je5ojaw wrote

It is. Follow the call tree into F.multi_head_attention_forward.
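
If you want to check that the fused kernel is actually reachable, one sketch using the PyTorch 2.0 torch.backends.cuda.sdp_kernel context manager (the shapes and dtype here are just an assumption that satisfies FlashAttention's constraints):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim) in fp16 on CUDA, which FlashAttention supports.
q = k = v = torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.half)

# Restrict SDPA to the FlashAttention backend only; the call below will raise
# an error if that backend cannot be used for these inputs.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
```
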

1

tripple13 t1_je5seed wrote

Is that right? I somehow end up here when trying to assess what the F.multi_head_attention call does in the class definition.

But I trust you're right; it would only make sense. I just couldn't identify the calls myself.

1