Submitted by super_deap t3_11tmpc5 in MachineLearning
programmerChilli t1_jcnydmw wrote
Reply to comment by tripple13 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
I think it is used in PyTorch's nn.TransformerEncoder, but a lot of people like implementing their own.
mike94025 t1_jcv94un wrote
SDPA is used by F.multi_head_attention_forward (if need_weights=False) which is used by nn.MHA and nn.Transformer* as well as other libraries. (source)
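For anyone who wants to call the kernel directly rather than going through nn.MHA, here's a minimal sketch of F.scaled_dot_product_attention (the shapes are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, num_heads, seq_len, head_dim)
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)

# SDPA picks a fused backend (Flash Attention, memory-efficient, or math)
# automatically; is_causal=True applies the causal mask inside the kernel.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

The output has the same shape as the query, and no attention-weight tensor is ever materialized.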
Public service announcement: need_weights defaults to True, and guts performance. (Because allocating and writing the attention weight tensor defeats the memory BW advantages of flash attention.)
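Concretely, passing need_weights=False looks like this (dimensions are made up for illustration):

```python
import torch
import torch.nn as nn

embed_dim, num_heads, batch, seq_len = 64, 4, 2, 16
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(batch, seq_len, embed_dim)

# need_weights defaults to True; pass False so the attention-weight
# tensor is never allocated and the fused SDPA path can be used.
out, attn_weights = mha(x, x, x, need_weights=False)
# attn_weights comes back as None in this case.
```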
Also, if `key_padding_mask is not None`, performance will suffer, because it is converted into an attention mask, and only the causal attention mask is supported by Flash Attention. Use Nested Tensors for variable-sequence-length batches instead.
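A nested tensor holds a ragged batch without padding, so no key_padding_mask is needed downstream. A minimal sketch (sequence lengths chosen arbitrarily):

```python
import torch

# Two sequences of different lengths (3 and 5 tokens, 64-dim embeddings)
a = torch.randn(3, 64)
b = torch.randn(5, 64)

# Pack them into one nested tensor instead of padding to length 5
nt = torch.nested.nested_tensor([a, b])

# The original ragged sequences are recoverable via unbind()
seqs = nt.unbind()
```

nn.MHA and nn.Transformer* can consume nested tensors on their fast path, which lets them skip computation on the padding positions entirely.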