pommedeterresautee OP t1_itu4mr6 wrote
Reply to comment by ganzzahl in [P] Up to 12X faster GPU inference on Bert, T5 and other transformers with OpenAI Triton kernels by pommedeterresautee
>Why is it that we don't see any projects with similar speedups using custom CUDA kernels or custom ONNX operators?
To be honest, we had the very same question :-)
CUDA is powerful... and verbose. To target several generations of hardware you need deep knowledge of their characteristics. I have many times followed PRs where people from Microsoft implement some new model, and it often takes them a month or more. On TensorRT I suppose it's even harder as they generate code, but hey, it's a black box. For best perf, CUDA code could be good, but you need nvcc to generate the right set of PTX instructions to reach peak perf, which is not always the case from what I saw.
Fortunately, the people at Nvidia working on Cutlass are trying to make those things easier by taking care of the lowest level of CUDA implementations. The lib is not, right now, what you would call easy to grasp, but you really learn a lot by working with it (much more than by starting from scratch, as you see the right way to implement stuff).
There are several reasons why you don't see more Triton:
- many people work with it, but not in OSS (Anthropic, OpenAI, etc.). You can guess from the issues and repo stars that the language has been growing faster and faster over the last few months
- the educational material ... could be smoother: the first tutorial (adding 2 vectors) is boringly simple, in the matmul one there is a block you need to stare at for long minutes to understand what it does, and for fused attention it took us days to understand each line... and to realize that it was not really the Flash Attention paper (one of us implemented the paper, the other worked from the Triton example, and we argued for days about everything until we realized that the two were not parallelized at the same level...).
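To make the "not parallelized at the same level" point a bit more concrete, here is a rough sketch in plain Python (not Triton, and not the code from either the paper or the tutorial). It assumes one common way to split the work: each independent unit of work owns a block of queries and streams over the keys/values in chunks, rescaling a running softmax as it goes. Whether that is exactly the split that caused the confusion above is an assumption; the function names and the `block_m` / `block_n` parameters are made up for illustration.

```python
import math

def attention_reference(q, k, v):
    """Plain softmax(q @ k^T / sqrt(d)) @ v on lists of lists."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out

def attention_one_query_block(q_block, k, v, block_n):
    """Work of one hypothetical 'program': a fixed block of queries streams
    over K/V in chunks of block_n rows, keeping a running max and running sum
    so the full score matrix is never materialised."""
    scale = 1.0 / math.sqrt(len(q_block[0]))
    out = []
    for qi in q_block:
        running_max = -math.inf
        running_sum = 0.0
        acc = [0.0] * len(v[0])
        for start in range(0, len(k), block_n):
            k_blk = k[start:start + block_n]
            v_blk = v[start:start + block_n]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale what was accumulated under the old max (exp(-inf) == 0.0
            # on the first chunk, so the empty accumulator stays empty).
            correction = math.exp(running_max - new_max)
            weights = [math.exp(s - new_max) for s in scores]
            running_sum = running_sum * correction + sum(weights)
            acc = [a * correction + sum(w * vj[c] for w, vj in zip(weights, v_blk))
                   for c, a in enumerate(acc)]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out

def attention_blocked(q, k, v, block_m=2, block_n=2):
    """Query blocks are independent of each other; on a GPU each one could be
    handled by a separate program instead of this sequential loop."""
    out = []
    for start in range(0, len(q), block_m):
        out.extend(attention_one_query_block(q[start:start + block_m], k, v, block_n))
    return out

if __name__ == "__main__":
    # Toy inputs: 5 queries, 7 keys/values, head dim 4, value dim 3.
    q = [[math.sin(3 * i + j) for j in range(4)] for i in range(5)]
    k = [[math.cos(4 * i + j) for j in range(4)] for i in range(7)]
    v = [[math.sin(2 * i - j) for j in range(3)] for i in range(7)]
    ref = attention_reference(q, k, v)
    blk = attention_blocked(q, k, v)
    print(max(abs(a - b) for r, s in zip(ref, blk) for a, b in zip(r, s)))
```

If the outer loop instead runs over key/value blocks and the inner one over query blocks, the math is the same but the partial results per query have to be written out and re-read between iterations, so two implementations can both be "correct" and still look very different line by line.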
Things will change: PyTorch has chosen the Triton language as the default for compiling GPU models in a future PyTorch version (I guess version 1.14, not sure). More about it here -> https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747
There are certainly other reasons (like big corps can't rely on other big corps' tech without some guarantees, etc.) but I think the ones above are very important explanations.
To be honest, we were very surprised by the speedups ourselves; beating TensorRT on long sequences was definitely far beyond our objectives. Even crazier when you consider that we still have room for more speedups... (e.g. we haven't yet tuned block sizes on some kernels, etc.)
Let's see where it takes us...
ganzzahl t1_itwsfkh wrote
This is perhaps an entirely dumb question that I will be able to answer for myself after I read through the Triton docs, but I'll ask anyway: Could one implement custom ONNX operators using Triton, or can it only be used in a Python environment?