We are happy to announce support for the OpenAI Whisper model (ASR task) in Kernl.

We focused on high quality transcription in a latency sensitive scenario, meaning:

  • whisper-large-v2 weights
  • beam search 5 (as recommended in the related paper)

We measured a 2.3x speedup on an Nvidia A100 GPU (2.4x on a 3090 RTX) compared to the Hugging Face implementation using FP16 mixed precision when transcribing the librispeech test set (over 2600 examples). For now, the OpenAI implementation is not yet PyTorch 2.0 compliant.

In the post below, we discuss what worked (CUDA Graph), our tricks (to significantly reduce memory footprint), and what did not pay off (Flash attention and some other custom Triton kernels).

Unsung hero: CUDA graphs

CUDA graphs provide most of the speedup. Compared to vanilla PyTorch 2.0 (“reduce-overhead” mode), we keep the memory footprint limited in cases where vanilla PyTorch 2.0 may raise an OOM exception.

memory footprint

Experiments were run on a 3090 RTX with 24 GB of GPU memory. As a reminder, PyTorch 2.0 focuses on training, not inference, which may explain why it OOMs rapidly in this case.

In its early days, many practitioners were surprised by the performance of PyTorch eager mode compared to TensorFlow 1.x compiled models: they were on par! Python brought its flexibility and ease of debugging without implying any significant performance cost.

This is mostly because GPUs are latency hiding hardware: when PyTorch launches an operation on the GPU, it sends instructions from the host (CPU) to a queue (the CUDA stream), which allows PyTorch to continue executing the Python script without waiting for the CUDA kernel to finish its work. This strategy effectively hides most of the Python overhead, in particular when there are computationally expensive operations like convolutions or matrix multiplications.

With each new generation of GPUs being much faster than its predecessor, this strategy could not last forever. According to one PyTorch maintainer, it is an “existential problem” (dev podcast, around 8 min 30 s).

In inference mode, especially in latency-sensitive scenarios where batch size tends to be low, there is often little computation to perform (relative to what modern GPUs can do), making it even harder to effectively hide the Python overhead. This is accentuated in the case of generative models like Whisper, because each decoder call focuses on generating a single token, and part of the computation is cached for the next token.

This is a typical situation where CUDA graph is very helpful.

The main idea behind CUDA graph is that we can replace a series of instructions sent from the host (CPU) to the device (GPU) with a single call referring to a graph of instructions stored on the GPU. Check also this twitter thread for more explanations.

CUDA graph first observes (captures) the inference of a model for specific input shapes and then replays it without going through most of the Python code.

One constraint is that it will replay the exact same operations with the exact same arguments.

For instance, memory addresses used by kernels are captured and therefore need to be static. For input tensors, this means we need to allocate some GPU memory and copy the inputs there before the capture, then copy every subsequent input to the very same location.

The second constraint is that dynamic shapes are not supported by CUDA graph because it captures everything. We could have our own machinery in front of the model, but PyTorch 2.0 offers the right tooling to manage that point out of the box.

Basically, dynamo offers a mechanism which checks whether the model has already been captured for specific input shapes (and some other states) and captures it if that is not yet the case. You just have to provide a function which converts to CUDA graphs and you are done.
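The bookkeeping described above (capture once per input shape, copy new inputs into the same static buffer, replay) can be sketched in plain Python. This is only an illustration of the pattern, not Kernl or dynamo code: `CapturedGraph` and `ShapeCache` are hypothetical names, lists stand in for GPU buffers, and the "replay" simply calls a Python function where a real CUDA graph would relaunch recorded kernels (in PyTorch, the capture itself would typically go through `torch.cuda.CUDAGraph`).

```python
from typing import Callable, Dict, List

Fn = Callable[[List[float]], List[float]]


class CapturedGraph:
    """Hypothetical stand-in for a captured CUDA graph with static buffers."""

    def __init__(self, fn: Fn, example: List[float]) -> None:
        # Buffers are allocated once, at capture time, and never replaced.
        self.static_input: List[float] = list(example)
        self.static_output: List[float] = [0.0] * len(example)
        self._fn = fn

    def replay(self, new_input: List[float]) -> List[float]:
        # Copy the new values into the SAME input buffer (in-place slice
        # assignment), mimicking the "same memory address" constraint.
        self.static_input[:] = new_input
        # A real replay relaunches recorded kernels; here we just call fn.
        self.static_output[:] = self._fn(self.static_input)
        return self.static_output


class ShapeCache:
    """Capture one graph per input shape, then replay it on later calls."""

    def __init__(self, fn: Fn) -> None:
        self._fn = fn
        self._graphs: Dict[int, CapturedGraph] = {}
        self.captures = 0  # number of (slow) captures performed so far

    def __call__(self, x: List[float]) -> List[float]:
        key = len(x)  # the "shape" of a 1-D input
        graph = self._graphs.get(key)
        if graph is None:  # unseen shape: pay the capture cost once
            graph = CapturedGraph(self._fn, x)
            self._graphs[key] = graph
            self.captures += 1
        # Return a copy so callers do not alias the static output buffer.
        return list(graph.replay(x))


run = ShapeCache(lambda xs: [2.0 * v for v in xs])
run([1.0, 2.0])       # first call with length 2: capture, then replay
run([3.0, 4.0])       # same shape: replay only
run([1.0, 2.0, 3.0])  # new shape: a second capture
print(run.captures)
```

Every new shape adds another entry with its own buffers, which is why a generative model that sees many sequence lengths pays both a warmup cost and a memory cost.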

Out of the box, PyTorch 2.0 provides a “reduce-overhead” mode which applies CUDA graph to the model. Unfortunately, for now, it raises an OOM with Whisper large or medium because it reserves some CUDA memory for each input shape. Therefore, for a generative model, it rapidly fills up the GPU memory, in particular because of the K/V cache which can be huge.

We have worked around this constraint by building our own layer on top of the memory pool of PyTorch.

Basically, a PyTorch tensor is made of 2 parts, a CUDA allocated memory represented by PyTorch as a “storage”, and a bunch of metadata associated with it. Among the metadata there is a CUDA memory address, the tensor shape plus its strides, its dtype and... a memory offset.

Our idea is to create a very large tensor and share its storage between several input tensors, using the offset metadata. With this solution, we avoid reserving memory per input tensor shape: the same reserved memory is shared by the different input shapes used by the different CUDA graphs.
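To make the "one storage, many views" idea concrete, here is a minimal sketch using only the Python standard library. It is an analogy rather than Kernl's implementation: `SharedPool` is a hypothetical helper, an `array` plays the role of the large pre-allocated storage, and `memoryview` slices play the role of tensors defined by an offset and a length.

```python
from array import array


class SharedPool:
    """One large buffer; each 'tensor' is a view at a given offset (hypothetical)."""

    def __init__(self, n_elements: int) -> None:
        # Allocated once, like the very large tensor described above.
        self._storage = array("f", [0.0] * n_elements)
        self._view = memoryview(self._storage)

    def view(self, offset: int, length: int) -> memoryview:
        """Return a view over storage[offset:offset + length] without copying."""
        if offset < 0 or length < 0 or offset + length > len(self._storage):
            raise ValueError("view does not fit in the pool")
        return self._view[offset:offset + length]


pool = SharedPool(16)
short_input = pool.view(0, 4)  # input buffer for a graph captured with length 4
long_input = pool.view(0, 8)   # a graph with length 8 reuses the same memory

short_input[0] = 1.5           # write through one view...
print(long_input[0])           # ...and read it back through the other
```

Because both views point into the same storage, adding a new input shape only adds a small piece of metadata, not a new allocation. The trade-off is that only one of the overlapping views holds meaningful data at a time, which is fine when graphs are replayed one after another.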

As shown in the table above, it significantly reduces the memory overhead.

What about custom (Triton) kernels for attention?

TL;DR: we tried, they work, and we got up to 2 times faster than eager PyTorch for cross attention, but they bring close to nothing in e2e latency, mostly because the improvement is not big enough to matter 🙁

Below, we follow the convention of naming Q, K and V, the 3 tensors used in the attention of transformer models.

Whisper is based on a classic transformer architecture, with an encoder and a decoder.

Two characteristics of this model are of interest:

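  • With the K/V cache, each decoder call handles a single token, so in cross-attention Q has a sequence length of 1 while K and V span the 1500 encoder positions: the attention scores always have shape [batch, #heads, 1, 1500].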
  • The model was trained on 30-second audio files and their associated transcripts. Because audio files are short, the sequence to generate is usually short, fewer than 50 tokens most of the time.

Because of these characteristics, optimizing attention has a low reward. In particular, the now common trick “replace attention with flash attention” is counterproductive:

  • self-attention: sequences are very short, so quadratic complexity is less of an issue;
  • cross-attention: using flash-attention leads to a 2 times slower inference on this part of the model.

We have tried to work on the second point and thought we could make cross attention faster.

The usual attention implementation (self and cross) relies on a series of operations: matmul (Q x K^t) -> rescale -> SoftMax -> matmul (SoftMax output x V). Intermediate output tensors have a shape which usually scales quadratically with input sequence length. They are saved to and reloaded from GPU main memory, and memory bandwidth is a very scarce resource in GPUs.
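For reference, the four steps can be written out for a single head in plain Python. This is a didactic sketch, not the Hugging Face or Kernl code; it assumes the usual 1/sqrt(d) rescaling, and `naive_attention` is a name chosen for this example. Note how `scores` and `probs` are full intermediate matrices with one entry per (query, key) pair.

```python
import math
from typing import List

Matrix = List[List[float]]


def naive_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^t / sqrt(d)) V, materializing every intermediate result."""
    scale = 1.0 / math.sqrt(len(q[0]))  # assumption: standard 1/sqrt(d) rescale

    # Steps 1 and 2: matmul (Q x K^t) then rescale.
    # `scores` holds len(q) * len(k) values.
    scores = [
        [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        for q_row in q
    ]

    # Step 3: row-wise SoftMax (the max is subtracted for numerical stability).
    probs = []
    for row in scores:
        row_max = max(row)
        exps = [math.exp(s - row_max) for s in row]
        total = sum(exps)
        probs.append([e / total for e in exps])

    # Step 4: matmul (SoftMax output x V).
    d_v = len(v[0])
    return [
        [sum(p * v_row[j] for p, v_row in zip(p_row, v)) for j in range(d_v)]
        for p_row in probs
    ]


# One query attending to two identical keys: the output is the mean of V's rows.
out = naive_attention([[1.0, 0.0]], [[1.0, 0.0], [1.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]])
print(out)
```

On a GPU, each of these steps is typically a separate kernel, so `scores` and `probs` make a round trip through main memory between kernels.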

To optimize speed, flash attention fuses operations: basically, the first matmul works on a small part of Q and K, and SoftMax is applied to it directly without saving intermediate results to GPU main memory. The same goes for the second matmul. Because we don't go back and forth through GPU main memory, flash attention usually runs much faster than a naïve implementation of attention.
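The trick that makes this fusion possible is a SoftMax that can be computed block by block while keeping only a running maximum, a running sum and a running output. Below is a minimal pure-Python sketch for a single query row. It illustrates the idea only (no tiling over Q, no GPU memory hierarchy); `streaming_attention_row` and the `block` parameter are names chosen for this example, and the 1/sqrt(d) rescale is assumed.

```python
import math
from typing import List, Sequence


def streaming_attention_row(
    q: Sequence[float],
    k: Sequence[Sequence[float]],
    v: Sequence[Sequence[float]],
    block: int = 2,
) -> List[float]:
    """Attention output for one query, reading K and V one block at a time."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf    # largest score seen so far
    running_sum = 0.0          # SoftMax denominator, relative to running_max
    acc = [0.0] * len(v[0])    # un-normalized output, relative to running_max

    for start in range(0, len(k), block):
        k_blk = k[start:start + block]
        v_blk = v[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k_row)) for k_row in k_blk]

        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the previous maximum.
        # On the first block this is exp(-inf), which is 0.0.
        correction = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]

        running_sum = running_sum * correction + sum(weights)
        acc = [
            a * correction + sum(w * v_row[j] for w, v_row in zip(weights, v_blk))
            for j, a in enumerate(acc)
        ]
        running_max = new_max

    return [a / running_sum for a in acc]


q = [1.0, 0.5]
k = [[0.1, 0.2], [0.3, -0.4], [1.0, 1.0], [-0.5, 0.25], [0.0, 2.0]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
print(streaming_attention_row(q, k, v, block=2))
```

The full row of scores is never stored: only one block of scores plus three small running values are alive at any time, which is what allows a fused kernel to keep them in fast on-chip memory.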

The jobs are parallelized along different axes: batch and attention head in the original flash attention, to which the Triton author added a third one, tokens, i.e. the third dimension of Q (this important trick is now also part of the flash attention CUDA implementation).

In the Whisper latency sensitive case, this doesn’t work well. The batch size is low and the sequence length (third dimension of the Q tensor) is... 1! So, even if each job is done very efficiently, our GPU occupancy is low, and most of its streaming multiprocessors are idle. At the end of the day, the FA kernel is up to 2 times slower than the eager PyTorch implementation (depending on batch size and model size).
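A back-of-the-envelope count shows how few independent jobs such a kernel gets to launch. The numbers below are assumptions made for illustration, not measurements from this post: 20 attention heads (whisper-large), an effective batch of 5 (one audio file decoded with beam search 5), and a tile of 128 query tokens per job; `launched_jobs` is a hypothetical helper.

```python
import math


def launched_jobs(batch: int, heads: int, q_len: int, block_q: int = 128) -> int:
    """Number of independent jobs when parallelizing over batch, heads and Q tokens."""
    return batch * heads * math.ceil(q_len / block_q)


# Decoder step with K/V cache: a single query token per sequence.
print(launched_jobs(batch=5, heads=20, q_len=1))

# Same batch and heads, but 1500 query tokens (a long-sequence workload).
print(launched_jobs(batch=5, heads=20, q_len=1500))
```

With a single query token the third axis contributes nothing, so the job count is just batch x heads. A GPU with on the order of a hundred multiprocessors, each able to host several jobs at once, is then left largely idle.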

Try 1: the very simple kernel

We noted that there is little computation to do and that we were memory-bandwidth bound. It means that most of the time we wait for data to be transferred from main memory to shared memory.

We leveraged that fact in a very simple kernel with 2 optimizations:

  • after finishing the rescale of the QK^t matmul, we perform the SoftMax computation in parallel with loading the V tensor for the final matmul. The SoftMax computation finishes before the end of the V loading, so basically it costs us nothing;
  • to achieve the best performance, we also changed the memory layout of the V tensor so that we get coalesced access, which lowers the pressure on the memory bandwidth and increases instruction throughput (coalesced access lets you load up to 128 bytes in a single instruction, so you need fewer of them, which leaves room to do other things).
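Why the memory layout matters can be seen with a simplified model that counts how many 128-byte segments a group of 32 threads touches when each thread loads one FP16 value. The model ignores real hardware details (sector sizes, caches) and does not reproduce the actual layout used in Kernl; `segments_touched` and the 64-element row length are assumptions made for this example.

```python
def segments_touched(
    n_threads: int,
    stride_bytes: int,
    elem_bytes: int = 2,       # FP16
    segment_bytes: int = 128,  # size of one aligned memory segment
) -> int:
    """Count distinct aligned segments read when thread t loads at t * stride_bytes."""
    segments = set()
    for t in range(n_threads):
        first_byte = t * stride_bytes
        last_byte = first_byte + elem_bytes - 1
        segments.update(
            range(first_byte // segment_bytes, last_byte // segment_bytes + 1)
        )
    return len(segments)


# Neighbouring threads read neighbouring FP16 values: one segment is enough.
print(segments_touched(n_threads=32, stride_bytes=2))

# Neighbouring threads read values one row apart (64 FP16 values = 128 bytes).
print(segments_touched(n_threads=32, stride_bytes=64 * 2))
```

In the first case the 32 loads fall inside a single segment. In the second each thread lands in its own segment, so far more memory traffic is needed to fetch the same 64 useful bytes.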

Altogether, this cross attention was up to 2x faster than eager. It brought between 5 and 20% in the end-to-end benchmark, depending on model size and batch size. Cool, but far from being a game changer, and it requires a modification specific to the Whisper model (memory layout of V), which is not in the spirit of the Kernl library. We decided to search for another way of doing things (we kept the code in the library for possible future use cases).

Try 2: Skinny Flash Attention

Our second try is based on the very same trick as Flash Attention (parallel SoftMax) but is designed for tall and skinny tensors, and is inspired by the split-k strategy in GEMM (a close cousin of the matmul). The main idea is to add a new parallelization axis over the 3rd dimension of the K tensor. The next steps are in the same spirit as flash attention, with the difference that we need a new reduction operation between the outputs of the different jobs. At kernel level, it provides a 5-10% speedup compared to the eager implementation on this setup. We kept that kernel to ease the next feature we are working on (quantization), but the effect on end-to-end latency is below 5% (still, it exists 😅).
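The split-and-reduce idea can be sketched in plain Python for a single query row. This is not the Triton kernel from Kernl, only an illustration under the usual 1/sqrt(d) rescaling assumption: each chunk of K/V is processed independently (these would be the parallel jobs) and returns its local maximum, local sum and un-normalized output, then a final reduction merges the partial results. All function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Partial = Tuple[float, float, List[float]]  # (local max, local sum, local output)


def partial_attention(q, k_chunk, v_chunk, scale: float) -> Partial:
    """One job: attention statistics for one chunk of K and V."""
    scores = [scale * sum(a * b for a, b in zip(q, k_row)) for k_row in k_chunk]
    local_max = max(scores)
    weights = [math.exp(s - local_max) for s in scores]
    out = [
        sum(w * v_row[j] for w, v_row in zip(weights, v_chunk))
        for j in range(len(v_chunk[0]))
    ]
    return local_max, sum(weights), out


def merge(partials: Sequence[Partial]) -> List[float]:
    """Reduction across jobs: rescale every partial result to a common maximum."""
    global_max = max(m for m, _, _ in partials)
    total = 0.0
    out = [0.0] * len(partials[0][2])
    for local_max, local_sum, local_out in partials:
        correction = math.exp(local_max - global_max)
        total += local_sum * correction
        out = [o + x * correction for o, x in zip(out, local_out)]
    return [o / total for o in out]


def split_k_attention(q, k, v, n_splits: int) -> List[float]:
    """Split K and V along the sequence axis, process chunks, then reduce."""
    scale = 1.0 / math.sqrt(len(q))
    chunk = math.ceil(len(k) / n_splits)
    partials = [
        partial_attention(q, k[i:i + chunk], v[i:i + chunk], scale)
        for i in range(0, len(k), chunk)
    ]
    return merge(partials)


q = [1.0, 0.5]
k = [[0.1, 0.2], [0.3, -0.4], [1.0, 1.0], [-0.5, 0.25], [0.0, 2.0]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
print(split_k_attention(q, k, v, n_splits=1))
print(split_k_attention(q, k, v, n_splits=3))  # same values up to rounding
```

The extra reduction is the price paid for the additional parallel axis: with a single query token and 1500 keys, splitting over keys is the only way to create more jobs than batch x heads.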

Some thoughts about PyTorch 2.0, Triton and making things much faster

Playing with PyTorch 2.0 since this summer has convinced us that the major update to be released very soon will be a game changer for the ML field.

For inference (but also for training), the parallel with the PyTorch vs TensorFlow story is obvious to us.

The traditional way to deploy a model is to export it to ONNX, then to the TensorRT plan format. Each step requires its own tooling, its own mental model, and may raise some issues. The most annoying thing is that you need Microsoft or Nvidia support to get the best performance, and sometimes model support takes time. For instance, T5, a model released in 2019, is not yet correctly supported on TensorRT, in particular the K/V cache is missing (soon it will be according to TensorRT maintainers, but I wrote the very same thing almost 1 year ago and then 4 months ago so… I don’t know).

PyTorch 2.0 makes the graph capture step easy: it has been designed to work even if not everything is PyTorch compliant. With its Python first philosophy, it provides flexibility and debuggability.

Several years ago, some said that by design PyTorch can’t be as performant as TensorFlow because of its eager execution model: compilation has to be faster. The same could be said of ONNX Runtime or TensorRT: they are written in C++, they have less overhead, etc. But at the end of the day, it's always “ease of use” which is decisive. Ease of use because of Python, but also because of the transparency of the process: Triton makes understanding and debugging kernels much easier than the closed source TensorRT Myelin engine calling the closed source cuBLAS library.

And of course, as with TensorFlow, there will be many use cases where dedicated tools are the best choice, starting with situations where you can’t deploy a Python interpreter.

The second lesson: Triton is easier to start with than CUDA, but you probably can’t write or debug highly performant code without being able to, at least, read and debug PTX/SASS instructions. We realized that when we had some performance issues... The good news is that PTX is understandable, and with some effort you will probably spot unexpected generated code if there is any. Moreover, CUDA probably requires the same care when you really focus on performance.

We had plenty of issues with Triton. For example, cosmetic changes in the code may cause a segfault. At some point you end up with an intuition of what kinds of patterns to follow to make things work, in particular when there are for loops and dot operations. A new version of Triton has recently been released after a full rewrite of its backend. Our little tests showed some improvement in stability, but we have not yet fully switched.

As in my previous post, I highly recommend that readers start playing with the Triton library. I repeat it here: it’s fun (at least when it doesn’t segfault) and helps you make sense of a large part of what is happening in ML engineering. I am quite convinced many flash attention like kernels are still to be written.


Two important things to note about the project described here:

  • CUDA graphs require us to capture a graph per input tensor shape, there is a non-negligible warmup time. We measure around 10mn on 2 different machines / GPUs (down from 50mn in our previous Kernl version). One user reported with the new version a bit more than 20mn of warmup time. We are aware of obvious ways to decrease it significantly.
  • The context here is latency sensitive optimization. In a throughput sensitive one, just increasing the batch size will bring you most of the speedup. Otherwise, more aggressive optimizations like quantization are required (not yet released on Kernl).


JackDT t1_j7toqp1 wrote

This is amazing. The warmup time is annoying but 2.5X is actually fast enough to use for a live stream now. I'll wait 10 minutes if I need to!


netw0rkf10w t1_j7u86nu wrote

The work is amazing and the post is very informative. Thanks!


zzzthelastuser t1_j7ulu8h wrote

> CUDA graphs require us to capture a graph per input tensor shape, there is a non-negligible warmup time. We measure around 10mn on 2 different machines / GPUs (down from 50mn in our previous Kernl version). One user reported with the new version a bit more than 20mn of warmup time. We are aware of obvious ways to decrease it significantly.

Dumb question, but what's mn? millineconds?


pommedeterresautee OP t1_j7uml76 wrote

lol unfortunately no, minutes :(


mLalush t1_j7vcqph wrote

Love your write ups /u/pommedeterresautee . Especially the fact that they're written with human beings in mind. I mean that as a compliment, seeing as the vast majority of stuff concerning cuda and low level optimization is impenetrable.

I periodically check to see whether the documentation and tutorial sections have been expanded. My advice is put some real effort and focus in to examples and tutorials. It is key for an optimization/acceleration library. 10x-ing the users of a library like this is much more likely to come from spending 10 out of every 100 developer hours writing tutorials, as opposed to spending 8 or 9 of those tutorial-writing hours on developing new features that only a small minority understand how to use and apply.


blackkettle t1_j7tyq1r wrote

This is very interesting, thanks for sharing! Do you have any more detail on RTF vs Accuracy curves? Also did you run this on any other data sets? Librispeech - even the “other” pieces is very clean, simple data from an acoustic and linguistic standpoint.

It would be really interesting to see how well this holds on noisy spontaneous speech like conversations.


pommedeterresautee OP t1_j7u0p8z wrote

Using CG doesn't affect the output quality.

What works with Whisper will still work with CG+Whisper.


blackkettle t1_j7u2kd0 wrote

Probably my question was not well-formulated. I'm just curious about what the RTF vs Accuracy tradeoff looks like. I'm not questioning whether it works, I'm just curious what the actual performance looks like.

You report on memory usage and beam sizes, as well as relative speedup, but it would be interesting to also see WER performance, as well as the actual absolute RTFs.


whata_wonderful_day t1_j7ubutx wrote

His point is that it's identical. They didn't use quantization or anything that would hurt performance. The whisper paper has a lot of the details you're asking for


blackkettle t1_j7ud34i wrote

Are you talking about this paper:


maybe I missed it but I can't find any place in that paper where they talk about the trade-offs with respect to real time factor and decoding strategies. RTF vs acc curves for CPU vs GPU for STT typically vary not in terms of absolute performance but in terms of where along the RTF curve you achieve a particular accuracy. That impacts what kinds of tasks you can expect to use the model for, and how you can expect to scale it to real world applications. So far this has been the weakest point for all the Whisper related work (still better off with espnet, k2, speechbrain, etc). This information would be interesting to see if they have it.


uzibart t1_j7uiryq wrote

whisper-cpp comparison?


pommedeterresautee OP t1_j7uk761 wrote

I just discovered the project

As written in another comment, there is no way for a (recent) CPU (even ARM ones) to be as fast as a (recent) GPU on such a big model (they list no GPU support in the limitations).

That being said, the project looks super cool, tks for the pointer (I ordered a M2 Max, lots of fun to come :-) )


stevevaius t1_ja1vftm wrote

Very interesting. For a noob is there any simple notebook that shows how to load a sound file and run model on it at Google colab?


pommedeterresautee OP t1_ja26tgi wrote

Our work targets GPUs with compute capability >= 8.0 (A10, A100, 3090 RTX, etc.). On Colab you will likely get a T4 or similar (7.5). Your best bet is to copy-paste the code related to CUDA graph from the Kernl library and use it with PyTorch 2.0 nightly.


stevevaius t1_ja27q2v wrote

Thanks. For a simple uploading a wav file and transcribe it, is there any implementation on colab? Sorry to bother you. I am working on whisper.cpp but large model is not fast on streaming. Looking to solve this issue by faster methods.


lpatks t1_j7u536l wrote

Interesting. Do you have any good resources for learning about PTX/SASS instructions? I've played around with Triton a bit, but it isn't even clear to me where I would see this output.


master3243 t1_j8bg2wl wrote

It would be amazing if this supports the whisper package directly instead of whisper in transformers from huggingface. (I know this isn't just for whisper but I do really need to speedup whisper)


pommedeterresautee OP t1_j8ciq4d wrote

As written at top of the post, unfortunately, the way openAI designed Whisper makes it non compliant with PyTorch 2.0

People at OpenAI said they will rework the package when PyTorch 2.0 is released. Then we will be able to optimize it.


SnooHesitations8849 t1_j7tjdla wrote

A C++ implementation on CPU would be on par with python's implementation on GPU. Just mind-blowing how much you can gain from using C++. But for sure, C++ is way harder to code.


pommedeterresautee OP t1_j7tk4fx wrote

On large DL models like Whisper large, CPU is never on par with GPUs because CPU is latency oriented hardware and GPU is throughput oriented. The only ways large models are run on CPUs is by reducing the number of operations to perform like by sparsification or pruning.

Moreover, PyTorch is mostly C++ with a Python layer over it (for now at least, PyTorch 2.0 may be a start of change in this architecture). The Python layer brings most of the PyTorch latency.

And then, even a C++ engine launching operations on GPU can not be on par with CUDA graphs (most of the time at least), because you still have to send one instruction at a time, and there is still some latency overhead associated with running things that way, just much less than with Python. With CUDA graphs there is almost none at all. There is a second thing not discussed here: the graph of instructions is optimized.

The main drawback of CG is the memory overhead: you need at least to double the space taken for input tensors. On generative models with K/V cache, it matters, as explained in this post. Plus you need to copy input tensors, which offsets a very small part of the gains (at least that's what we saw in our tests on Whisper and Bert / Roberta).

That is why TensorRT (a big C++ piece) for instance supports CUDA graphs.

Still, TBH, as you pointed out, the most important thing is that ... it's easier to build and run :-)


programmerChilli t1_j7toust wrote

> The Python layer brings most of the PyTorch latency.

This actually isn't true - I believe most of the per-operator latency come from C++.


pommedeterresautee OP t1_j7tp663 wrote

I guess you better know than me :-)

Which part? The dispatcher thing or it's spread on several steps?


clauwen t1_j7txja0 wrote

You really have to wonder why everybody uses torch and tf2 then. Stupid trillion dollar companies, could just run everything on cpu, if they could only hire C++ devs. Billions of dollars down the drain, really unlucky.


Wrandraall t1_j7u76m1 wrote

Training =/= inference anyway. Whatever can be reached with CPU inference time, training still benefit by using GPUs from parallelization and caching