Submitted by super_deap t3_11tmpc5 in MachineLearning
oathbreakerkeeper t1_jd0lu2p wrote
Reply to comment by mike94025 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
Am I looking in the wrong place? It seems like the torch 2.0 code still requires training == False in order to use FlashAttention.
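For context, the flag in question is the standard nn.Module training flag, so the path being described is only reachable in eval mode. A minimal sketch of that setup (the encoder and sizes are made up for illustration):

```
import torch
import torch.nn as nn

# Hypothetical encoder just to illustrate the eval-mode requirement; sizes are arbitrary.
layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True).cuda()
encoder = nn.TransformerEncoder(layer, num_layers=2).cuda()

encoder.eval()                      # sets training = False on every submodule
with torch.inference_mode():
    x = torch.rand(2, 1024, 512, device="cuda")   # (batch, seq, d_model)
    y = encoder(x)                  # eligible for the fused path only with training == False
```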
Dependent_Ad5120 t1_jd3m0ce wrote
try fp16, that doesn't require training=False apparently.
oathbreakerkeeper t1_jd43931 wrote
I'm using AMP mixed precision, which should be using fp16. It still requires training==False.
But the torch code also disables flash attention if autocast is enabled; I'm not sure how to deal with that one.
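For reference, "AMP mixed precision" here means the usual autocast-plus-GradScaler setup, roughly along these lines (the tiny model and data are placeholders just to keep the sketch self-contained):

```
import torch
import torch.nn as nn

model = nn.Linear(512, 512).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.rand(8, 512, device="cuda")
optimizer.zero_grad()
# The forward pass runs under autocast -- the condition being discussed above.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(x).square().mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```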
Dependent_Ad5120 t1_jdec7kx wrote
I don't know. I was using pure fp16, no autocast and it works.
oathbreakerkeeper t1_jdgjte0 wrote
How do you use pure fp16, out of curiosity? I've only ever trained with mixed precision, letting PyTorch handle the fp16 stuff from there.
Do you have an example of a github repo that does it?
Dependent_Ad5120 t1_je5qfmp wrote
I don't have a github repo for this, but it is pretty simple:
```
import torch
import torch.nn as nn
from torch.backends.cuda import sdp_kernel

# Pure fp16: cast the model and the inputs with .half(); no autocast anywhere.
model = nn.Transformer().cuda().half()
src = torch.rand(10, 2, 512, device="cuda").half()  # example shapes: (seq, batch, d_model)
tgt = torch.rand(10, 2, 512, device="cuda").half()
# Enable only the flash attention kernel for scaled dot product attention.
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    output = model(src, tgt)
```
These few lines should be enough.
mike94025 t1_je5nrdi wrote
You’re looking in the wrong place. What you’re looking at is the BT gen 1 fastpath, not the BT gen 2 custom kernels.
You need to look at F.multi_head_attention_forward().
For now, the fastpath still services inference, pending a full rewrite of activation.py that will hopefully happen in a future release. (There’s always a tension between refactoring and introducing new features under a time- and staffing-constrained problem formulation.)
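For anyone following along, the custom kernels can also be exercised directly through F.scaled_dot_product_attention, which is the op the new backends sit behind. A rough sketch with made-up shapes (fp16, CUDA):

```
import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

# Made-up (batch, heads, seq_len, head_dim) tensors in fp16 on CUDA.
q = torch.rand(2, 8, 1024, 64, dtype=torch.float16, device="cuda")
k = torch.rand(2, 8, 1024, 64, dtype=torch.float16, device="cuda")
v = torch.rand(2, 8, 1024, 64, dtype=torch.float16, device="cuda")
# With only the flash kernel enabled, this call should dispatch to flash attention
# (or error out if flash attention can't handle the inputs).
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```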