Submitted by super_deap t3_11tmpc5 in MachineLearning
mike94025 t1_jcmho8t wrote
Reply to comment by Competitive-Rub-1958 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
Don't call flash_sdp directly. That locks you into particular hardware and makes your models non-portable. Use either F.scaled_dot_product_attention() or nn.MultiheadAttention. In either case, PyTorch picks the right implementation based on the hardware you have and the constraints on the inputs. Ideally, the constraints will be relaxed in the future, and/or new kernels will support other operating points in an optimized manner, and the kernel picker can then dispatch to those implementations.
See the kernel-picker logic that dispatches based on input characteristics in the source code, and/or the SDPA tutorial here => https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
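For readers who want to see what the portable entry point looks like in practice, here is a minimal sketch. The tensor shapes, the causal flag, and the `reference_attention` helper are illustrative assumptions, not something from the thread; the helper only exists to show that whichever kernel gets picked, the result matches plain softmax attention.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Fused CUDA kernels expect half precision; stay in float32 on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32

# Illustrative shape: (batch, heads, sequence length, head dim).
q = torch.randn(2, 4, 16, 32, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# No kernel is named here; PyTorch dispatches based on hardware and inputs.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)


def reference_attention(q, k, v):
    """Hypothetical helper: unfused causal attention computed in float32."""
    scores = (q.float() @ k.float().transpose(-2, -1)) / math.sqrt(q.size(-1))
    seq_len = q.size(-2)
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v.float()


ref = reference_attention(q, k, v)
max_err = (out.float() - ref).abs().max().item()
print(f"device={device} max abs error vs reference: {max_err:.2e}")
```

The same script runs unchanged on a CPU-only machine or a GPU, which is the portability argument made above.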
Competitive-Rub-1958 t1_jcn8bti wrote
cool. I just wanted to make it explicit to make sure I'm running `FlashAttention`. Perhaps there's an easy way to check that?
mike94025 t1_jcv83hu wrote
Yes - use the backend context manager to disable all other backends. If the call succeeds, you're running the one you want; if that backend can't handle your inputs, you'll get an error, since all other backends are disabled.
The SDPA context manager is intended to facilitate debugging (for perf or correctness), and is not (and should not be) required for normal operational usage.
Check out the SDPA tutorial at https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#explicit-dispatcher-control
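A rough sketch of that check, assuming the PyTorch 2.0 context manager `torch.backends.cuda.sdp_kernel` with its `enable_flash`, `enable_math`, and `enable_mem_efficient` flags (later releases deprecate it in favor of `torch.nn.attention.sdpa_kernel`). The `flash_can_run` helper name and the tensor shapes are made up for illustration, and the check is skipped on machines without a GPU.

```python
import torch
import torch.nn.functional as F


def flash_can_run(q, k, v):
    """Hypothetical helper: True/False on CUDA, None when there is no GPU."""
    if not q.is_cuda:
        return None
    try:
        # Only flash stays enabled, so the call either uses it or raises.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=False
        ):
            F.scaled_dot_product_attention(q, k, v)
        return True
    except RuntimeError:
        # No enabled kernel accepted these inputs on this hardware.
        return False


device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Illustrative shape: (batch, heads, sequence length, head dim).
q = torch.randn(2, 8, 128, 64, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

result = flash_can_run(q, k, v)
print(f"flash usable for these inputs: {result}")
```

This is meant as a one-off debugging probe; once it reports what you expect, the context manager can be dropped again.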
Competitive-Rub-1958 t1_jd40cwb wrote
would that mean for forcing MHA to use it, I should wrap the ctxmanager around the line where I forward through it?
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
        x = x + self.attn_head(x, x, x, need_weights=False)[0]
because that doesn't really seem to work :(
mike94025 t1_je5mfa8 wrote
This doesn't force it. It says that flash is enabled, and so are the others. To force it, you have to disable all other kernels. Then it’s flash or bust.
You can find more in our blog which got published today and the SDPA tutorial. Both are linked here https://www.linkedin.com/posts/michael-gschwind-3704222_pytorch-activity-7046773418288955393-gOSh
PS: the context manager can be used anywhere outside the call as well, including around the call to model.forward.
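Putting the last two points together, a hedged sketch of what "flash or bust" around a whole model call might look like. The `Block` module, its sizes, and the fallback handling are hypothetical; the snippet again assumes the PyTorch 2.0 `torch.backends.cuda.sdp_kernel` context manager and simply runs the default path when no GPU is present.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical residual self-attention block, mirroring the question."""

    def __init__(self, dim=64, heads=4):
        super().__init__()
        self.attn_head = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        # need_weights=False lets MHA use the fused SDPA path.
        return x + self.attn_head(x, x, x, need_weights=False)[0]


device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

model = Block().to(device=device, dtype=dtype)
x = torch.randn(2, 128, 64, device=device, dtype=dtype)  # (batch, seq, dim)

if device == "cuda":
    try:
        # Wrap the outer call; every other kernel is explicitly disabled.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=False
        ):
            y = model(x)
        status = "flash"
    except RuntimeError:
        # Flash could not serve this call; rerun with the default picker.
        y = model(x)
        status = "fallback"
else:
    y = model(x)
    status = "default"

print(status, tuple(y.shape))
```

Compared with the snippet in the question, the only change that matters is passing `enable_math=False` and `enable_mem_efficient=False`; where the `with` block sits relative to the attention call does not.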