mike94025

mike94025 t1_je5nrdi wrote

You’re looking in the wrong place. What you’re looking at is the BT gen1 fastpath, not the BT gen2 custom kernels.

You need to look at F.multi_head_attention_forward().

For now, the fastpath still services inference, pending a full rewrite of activation.py that will hopefully land in a future release. (There’s always a tension between refactoring and introducing new features under a time- and staffing-constrained problem formulation.)

1

mike94025 t1_je5mfa8 wrote

This doesn't force it. It says that flash is enabled, and so are the others. To force it, you have to disable all other kernels. Then it’s flash or bust.

You can find more in our blog which got published today and the SDPA tutorial. Both are linked here https://www.linkedin.com/posts/michael-gschwind-3704222_pytorch-activity-7046773418288955393-gOSh

PS: the context manager can be used anywhere outside the call as well, including around the call to model.forward.
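To illustrate the point about enabling versus forcing, here is a minimal sketch assuming the PyTorch 2.0-era `torch.backends.cuda.sdp_kernel` context manager. The helper name `flash_only_attention` and the tensor shapes are made up for illustration. Later releases deprecate this context manager and add further backends, so treat the exact flags as an assumption tied to that version.

```python
import torch
import torch.nn.functional as F

def flash_only_attention(q, k, v):
    # Assumption: PyTorch 2.0-style API. Disabling math and mem_efficient
    # leaves flash as the only candidate, so unsupported inputs raise
    # instead of silently falling back to another kernel.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        return F.scaled_dot_product_attention(q, k, v)

if torch.cuda.is_available():
    # fp16 on a CUDA device; whether flash accepts this depends on the GPU.
    q, k, v = (
        torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )
    try:
        print("flash ran:", flash_only_attention(q, k, v).shape)
    except RuntimeError as err:
        print("flash rejected these inputs:", err)
else:
    print("No CUDA device available; nothing to force.")
```

The same `with` block could instead wrap a call to `model.forward`, as noted above.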

2

mike94025 t1_jcx5xvg wrote

Data type?

SDPA currently has 3 kernels, with a kernel picker choosing among them.

  • sdpa_math
  • sdpa_flash
  • sdpa_mem_eff

The kernel picker picks the best one given your constraints:

  • Math is the trusted kernel from the equation in the paper.
  • Flash only works for FP16 and BF16, and on SM80 (e.g., A100).
  • mem_efficient kernel works on older architecture levels, and supports FP32, but the upside is limited due to lack of compute capacity for FP32. FP16 or BF16 should help. Also, there are requirements on alignment, dropout values, etc. to qualify for the high-perf SDPA implementations. Dropout is required to be 0 as of PT 2.0.

Also, different kernels parallelize across different dimensions, so B=1 will not work with all of those kernels.

In a nutshell, performance comes at the price of generality, and GPUs are finicky to get the performance, so inputs must adhere to those constraints, and parallelization strategies matter for different combinations of dimensions.
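The constraints listed above can be sketched as a toy picker in plain Python. This is not PyTorch's dispatcher: `SdpaInputs`, `pick_kernel`, the priority order (flash, then mem_efficient, then math), and reading "SM80" as "compute capability 80 or newer" are all assumptions for illustration. Alignment and batch-size checks are left out.

```python
from dataclasses import dataclass

@dataclass
class SdpaInputs:
    dtype: str              # "fp16", "bf16", or "fp32"
    sm_arch: int            # CUDA compute capability, e.g. 80 for A100
    dropout_p: float = 0.0

def pick_kernel(x, enabled=("flash", "mem_efficient", "math")):
    """Hypothetical picker: first enabled kernel whose constraints hold."""
    half_precision = x.dtype in ("fp16", "bf16")
    qualifies = {
        # Flash: FP16/BF16 only, and SM80-class hardware (assumed >= 80).
        "flash": half_precision and x.sm_arch >= 80,
        # mem_efficient: older architectures and FP32 are fine,
        # but dropout must be 0 (per the PT 2.0 note above).
        "mem_efficient": x.dropout_p == 0.0,
        # Math: the general fallback with no constraints.
        "math": True,
    }
    for name in ("flash", "mem_efficient", "math"):
        if name in enabled and qualifies[name]:
            return name
    raise RuntimeError("no enabled kernel supports these inputs")

print(pick_kernel(SdpaInputs("fp16", 80)))
print(pick_kernel(SdpaInputs("fp32", 70)))
print(pick_kernel(SdpaInputs("fp32", 70, dropout_p=0.1)))
```

Restricting `enabled` to a single kernel makes the toy picker raise whenever that kernel does not qualify, which mirrors the "flash or bust" behavior of disabling every other backend.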

5

mike94025 t1_jcv94un wrote

SDPA is used by F.multi_head_attention_forward (if need_weights=False) which is used by nn.MHA and nn.Transformer* as well as other libraries. (source)

Public service announcement: need_weights defaults to True, and guts performance. (Because allocating and writing the attention weight tensor defeats the memory BW advantages of flash attention.)

Also, if `key_padding_mask is not None` performance will suffer (because this is converted into an attention mask, and only the causal attention mask is supported by Flash Attention). Use Nested Tensors for variable sequence length batches.
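As a small sketch of the `need_weights` advice, the call below passes `need_weights=False` to `nn.MultiheadAttention` and leaves `key_padding_mask` unset. The layer sizes and tensor shapes are arbitrary choices for illustration, and the module is left in training mode so the call goes through `F.multi_head_attention_forward` rather than the inference fastpath.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary illustrative sizes: 64-dim embeddings split across 4 heads.
mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
x = torch.randn(2, 16, 64)  # (batch, sequence, embedding)

# need_weights=False skips materializing the attention weight tensor,
# so the second return value is None instead of a weights tensor.
out, weights = mha(x, x, x, need_weights=False)

print(out.shape)
print(weights)
```

Whether a fused kernel is actually used still depends on the hardware, dtype, and the other constraints described in this thread.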

1

mike94025 t1_jcv83hu wrote

Yes - use the backend context manager to disable all other backends to see that you're running the one you want. (Otherwise, since all other backends are disabled, you'll get an error.)

SDPA context manager is intended to facilitate debug (for perf or correctness), and is not (and should not be) required for normal operational usage.

Check out the SDPA tutorial at https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#explicit-dispatcher-control

1

mike94025 t1_jcmho8t wrote

Don't call flash_sdp directly. That way you're locked into particular hardware and create non-portable models. You can either use F.scaled_dot_product_attention(), or use nn.MultiheadAttention. In either case it will pick the right implementation based on the hardware you have, and the constraints. Ideally, the constraints would be weakened in the future, and/or new kernels might support other operating points in an optimized manner, and then the kernel picker can dispatch to that implementation.

See the kernel-picker logic that dispatches based on input characteristics in the source code, and/or the SDPA tutorial here => https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
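A minimal sketch of the portable route: call `F.scaled_dot_product_attention()` and let it choose the implementation, then compare against the textbook formula softmax(QK^T / sqrt(d)) V. The shapes are arbitrary, and the explicit reference computation is only there as a sanity check.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Arbitrary illustrative shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# Portable call: the implementation is chosen for the current hardware.
out = F.scaled_dot_product_attention(q, k, v)

# Reference: the plain attention equation, written out explicitly.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
reference = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(out, reference, atol=1e-4))
```

The same code runs unchanged on CPU or GPU, which is the portability argument made above.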

2