Submitted by super_deap t3_11tmpc5 in MachineLearning
Dependent_Ad5120 t1_jdec7kx wrote
Reply to comment by oathbreakerkeeper in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
I don't know. I was using pure fp16, no autocast, and it worked.
oathbreakerkeeper t1_jdgjte0 wrote
Out of curiosity, how do you use pure fp16? I've only ever trained with mixed precision, letting PyTorch handle the fp16 casting from there.
Do you have an example of a github repo that does it?
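For context, "mixed precision" here usually means the standard torch.cuda.amp recipe (autocast plus GradScaler, with the weights kept in fp32). A rough sketch of that pattern, using a dummy batch and a placeholder loss rather than anything from the thread, might look like:
```
import torch
from torch import nn

model = nn.Transformer().cuda()                      # weights stay in fp32
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

src = torch.rand(10, 32, 512).cuda()                 # dummy batch, example shape only
tgt = torch.rand(10, 32, 512).cuda()

optimizer.zero_grad()
with torch.cuda.amp.autocast():                      # eligible ops run in fp16 automatically
    output = model(src, tgt)
    loss = output.float().mean()                     # placeholder loss, just to show the mechanics
scaler.scale(loss).backward()                        # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)                               # unscales the gradients, then steps
scaler.update()
```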
Dependent_Ad5120 t1_je5qfmp wrote
I don't have a GitHub repo for this, but it's pretty simple:
```
import torch
from torch import nn
from torch.backends.cuda import sdp_kernel

model = nn.Transformer().cuda().half()               # cast the model itself to fp16, no autocast
src = tgt = torch.rand(10, 32, 512).cuda().half()    # fp16 inputs (shape is just an example)

# allow only the flash attention backend for scaled_dot_product_attention
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    output = model(src, tgt)
```
That snippet should be enough.
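In case it helps, a full training step in pure fp16 might look roughly like the sketch below: no autocast, no GradScaler, everything just cast to half up front. The dummy batch, the placeholder loss, and the choice of plain SGD are my own assumptions, not from the comment above:
```
import torch
from torch import nn
from torch.backends.cuda import sdp_kernel

model = nn.Transformer().cuda().half()               # weights, activations, gradients all fp16
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

src = torch.rand(10, 32, 512).cuda().half()          # dummy fp16 batch, example shape only
tgt = torch.rand(10, 32, 512).cuda().half()

optimizer.zero_grad()
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    output = model(src, tgt)                         # attention runs through the flash kernel
loss = output.float().mean()                         # placeholder loss, just to show the mechanics
loss.backward()                                      # no loss scaling here; gradients stay in fp16
optimizer.step()
```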