mike94025 t1_je5nrdi wrote
Reply to comment by oathbreakerkeeper in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
You're looking in the wrong place. What you're looking at is the BT gen1 fastpath, not the BT gen2 custom kernels.
You need to look at F.multi_head_attention_forward().
For now, the fastpath still services inference; activation.py needs a full rewrite that will hopefully happen in a future release. (There's always a tension between refactoring and introducing new features under a time- and staffing-constrained problem formulation.)
mike94025 t1_je5mfa8 wrote
Reply to comment by Competitive-Rub-1958 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
This doesn't force it. It says that flash is enabled, and so are the others. To force it, you have to disable all other kernels. Then it's flash or bust.
You can find more in our blog, which was published today, and in the SDPA tutorial. Both are linked here: https://www.linkedin.com/posts/michael-gschwind-3704222_pytorch-activity-7046773418288955393-gOSh
PS: the context manager can be used anywhere outside the call as well, including around the call to model.forward.
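For illustration, a minimal sketch of "disable all other kernels" using the PT 2.0 backend context manager, wrapped around a model.forward call (the model and input shapes here are placeholders, not something from this thread):

```python
import torch

# Placeholder model and input; the point is where the context manager goes.
# dropout=0.0 so the fused kernels qualify (dropout must be 0 for flash at PT 2.0).
model = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8, dropout=0.0,
                                         batch_first=True).cuda().half()
x = torch.randn(8, 1024, 512, device="cuda", dtype=torch.float16)

# Disable math and mem_efficient so SDPA must run flash -- or raise an error.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                    enable_mem_efficient=False):
    out = model(x)  # the context manager also works around the call to model.forward
```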
mike94025 t1_jcx5xvg wrote
Reply to comment by crude2refined in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
Data type?
SDPA currently has 3 kernels, selected by a kernel picker:
- sdpa_math
- sdpa_flash
- sdpa_mem_eff
The kernel picker picks the best one given your constraints:
- Math is the trusted kernel from the equation in the paper.
- Flash only works for FP16 and BF16, and on SM80 (e.g., A100).
- The mem_efficient kernel works on older architecture levels and supports FP32, but the upside is limited due to the lack of compute capacity for FP32; FP16 or BF16 should help. Also, there are requirements on alignment, dropout values, etc., to qualify for the high-perf SDPA implementations. Dropout is required to be 0 as of PT 2.0.
Also, different kernels parallelize across different dimensions, so B=1 will not work with all of those kernels.
In a nutshell, performance comes at the price of generality, and GPUs are finicky to get that performance, so inputs must adhere to those constraints, and parallelization strategies matter for different combinations of dimensions.
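As a hedged illustration of inputs that satisfy those constraints on an A100-class GPU (the shapes are arbitrary examples, not requirements beyond the dtype/alignment/dropout rules above):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim): batch > 1, FP16, aligned head_dim, dropout_p = 0.0
q = torch.randn(4, 16, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(4, 16, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(4, 16, 1024, 64, device="cuda", dtype=torch.float16)

# The kernel picker chooses among sdpa_flash / sdpa_mem_eff / sdpa_math for these inputs.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
```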
mike94025 t1_jcv94un wrote
Reply to comment by programmerChilli in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
SDPA is used by F.multi_head_attention_forward (if need_weights=False) which is used by nn.MHA and nn.Transformer* as well as other libraries. (source)
Public service announcement: need_weights defaults to True, and guts performance. (Because allocating and writing the attention weight tensor defeats the memory BW advantages of flash attention.)
Also, if `key_padding_mask is not None`, performance will suffer (because this is converted into an attention mask, and only the causal attention mask is supported by Flash Attention). Use Nested Tensors for variable-sequence-length batches.
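A short sketch of the fast vs. slow calling patterns on nn.MultiheadAttention (module hyperparameters and shapes are illustrative):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True).cuda().half()
x = torch.randn(8, 1024, 512, device="cuda", dtype=torch.float16)

# Fast: need_weights=False lets F.multi_head_attention_forward dispatch to SDPA.
out, _ = mha(x, x, x, need_weights=False)

# Slow: the default need_weights=True materializes the attention-weight tensor,
# defeating the memory-bandwidth advantage of flash attention.
out_slow, attn_weights = mha(x, x, x)
```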
mike94025 t1_jcv83hu wrote
Reply to comment by Competitive-Rub-1958 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
Yes - use the backend context manager to disable all other backends to confirm that you're running the one you want. (If that kernel can't be used, you'll get an error rather than a silent fallback, since all other backends are disabled.)
SDPA context manager is intended to facilitate debug (for perf or correctness), and is not (and should not be) required for normal operational usage.
Check out the SDPA tutorial at https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#explicit-dispatcher-control
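For example, a debug-only sketch that isolates one backend at a time and checks a forced flash run against the trusted math kernel (shapes and tolerances are illustrative for FP16):

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16) for _ in range(3))

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True,
                                    enable_mem_efficient=False):
    ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # reference: math kernel

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # flash, or an error

torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
```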
mike94025 t1_jcv7ltl wrote
Reply to comment by royalemate357 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
You might look into https://github.com/pytorch/pytorch/pull/95793.
mike94025 t1_jcn7ksu wrote
Reply to comment by royalemate357 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
Works for all. You need a compiler backend that can code-gen for your target, and a frontend for the optimizer that can process the IR.
Alternatively, you need a backend for Triton (or another already supported optimizer) that can codegen for your target architecture.
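From the user side, a rough sketch of picking a registered compiler backend (the backend names your install exposes may differ; only "inductor" is assumed here):

```python
import torch
import torch._dynamo as dynamo

print(dynamo.list_backends())        # compiler backends registered in this build

model = torch.nn.Linear(128, 128)
compiled = torch.compile(model, backend="inductor")  # Inductor code-gens via Triton on supported GPUs
out = compiled(torch.randn(4, 128))
```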
mike94025 t1_jcmlddm wrote
Reply to comment by cthorrez in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
Better Transformer supports both today. Some optimizations are still inference-only (in particular, support for variable-sequence-length Nested Tensors), and the inference fastpath is a bit siloed, but nothing that a future PyTorch update could not fix.
mike94025 t1_jcmho8t wrote
Reply to comment by Competitive-Rub-1958 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
Don't call flash_sdp directly. That way you're locked into particular hardware and create non-portable models. You can either use F.scaled_dot_product_attention(), or you can use nn.MultiheadAttention. In either case, it will pick the right implementation based on the hardware you have and the constraints. Ideally, the constraints would be weakened in the future, and/or new kernels might support other operating points in an optimized manner, and then the kernel picker can dispatch to that implementation.
See the kernel-picker logic that dispatches based on input characteristics in the source code, and/or the SDPA tutorial here => https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
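A minimal sketch of the portable pattern, with no kernel pinned (shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

def attention(q, k, v):
    # No backend is named here; the kernel picker dispatches to flash / mem_efficient / math
    # based on device, dtype, and the other input constraints.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The same code runs on CPU (math kernel) and on CUDA (fused kernels where the inputs qualify).
out = attention(*(torch.randn(2, 8, 128, 64) for _ in range(3)))
```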
mike94025 t1_jcmepd6 wrote
Reply to comment by cthorrez in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
With Better Transformer, why not both?
mike94025 t1_jcly3mi wrote
Reply to comment by Competitive-Rub-1958 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
The documentation was not updated. Yes, you can use flash attention for training.
The first version included only forward(), as we were resolving some issues with backward(). The docstring will be updated.
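A small training sketch through SDPA, showing the backward pass (shapes and the loss are placeholders):

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(4, 8, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
           for _ in range(3))

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
loss = out.float().pow(2).mean()   # placeholder loss
loss.backward()                    # backward() through the fused attention kernel
print(q.grad.shape)
```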
mike94025 t1_je5ojaw wrote
Reply to comment by tripple13 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
It is. Follow the call tree into F.multi_head_attention_forward.