Submitted by super_deap t3_11tmpc5 in MachineLearning
Competitive-Rub-1958 t1_jd40cwb wrote
Reply to comment by mike94025 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
Would that mean that, to force MHA to use it, I should wrap the context manager around the line where I forward through it?
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
    x = x + self.attn_head(x, x, x, need_weights=False)[0]
because that doesn't really seem to work :(
mike94025 t1_je5mfa8 wrote
This doesn't force it. It only says that flash is enabled, and so are the others. To force it, you have to disable all the other kernels. Then it's flash or bust.
You can find more in our blog, which was published today, and in the SDPA tutorial. Both are linked here: https://www.linkedin.com/posts/michael-gschwind-3704222_pytorch-activity-7046773418288955393-gOSh
PS: the context manager can also be used anywhere further out in the call stack, including around the call to model.forward.
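A minimal sketch of what "disable all other kernels" might look like, based on the advice above. The module, shapes, and dtype here are illustrative assumptions (the flash kernel needs a CUDA device and fp16/bf16 inputs); the same context manager could just as well wrap a call to model.forward.

import torch
import torch.nn as nn

# Hypothetical module and input, just to show the flags and placement.
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True).cuda().half()
x = torch.randn(2, 4096, 512, device="cuda", dtype=torch.float16)

# Enable only the flash kernel; disable the math and memory-efficient backends.
# If flash can't handle the inputs, SDPA errors out instead of silently
# falling back -- "flash or bust".
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False,
):
    # need_weights=False lets MHA route through scaled_dot_product_attention.
    out, _ = mha(x, x, x, need_weights=False)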