Submitted by beautyofdeduction t3_10uuslf in deeplearning

Hey y'all, the transformer model that I'm training has:

  • Keras param count is 22 million: 6 encoder blocks, each of which has 8 heads with a head_size of 64
  • sequence of length 6250
  • batch size 1

It consistently OOMs on any GPU with less than 40 GB of VRAM, so it only fits on something like an RTX A6000 (48 GB). I've tried on both Google Colab and Lambda Labs.

22M params plus activations, the most expensive of which has size 6250 * 6250 (~39M floats). So that comes down to ~61M floats in total, i.e. about 500 MB at 8 bytes per float. I cannot wrap my head around what is causing the VRAM OOM!
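
The estimate above can be reproduced in a few lines of Python. This is a minimal sketch that assumes, as the ~500 MB figure implies, 8 bytes per float; like the post, it counts only the weights and a single 6250 × 6250 matrix, which is the assumption the replies go on to question.

```python
# OP's back-of-envelope estimate: weights + ONE n x n attention matrix.
# Assumption: 8 bytes per float (float64), implied by "~62M floats, i.e. 500 MB".
PARAMS = 22_000_000
SEQ_LEN = 6250
BYTES_PER_FLOAT = 8


def naive_estimate(params: int, seq_len: int, bytes_per_float: int) -> tuple[int, int]:
    """Return (number of floats, bytes) for weights plus a single seq_len x seq_len matrix."""
    floats = params + seq_len * seq_len
    return floats, floats * bytes_per_float


floats, nbytes = naive_estimate(PARAMS, SEQ_LEN, BYTES_PER_FLOAT)
print(f"{floats / 1e6:.2f}M floats -> {nbytes / 1e6:.1f} MB")
```

This gives about 61M floats and roughly 490 MB, matching the ~500 MB in the post.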

I must be missing something, but I can't see it. Please help me out. How much VRAM usage did you see with your transformers? Any thoughts are appreciated!

Comments

neuralbeans t1_j7f68rv wrote

Parameters are a tiny portion of the values stored on the GPU. The number of activations grows quadratically with sequence length.
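
To make the quadratic growth concrete, here is a small sketch of how a single n × n attention-score matrix (one head, one layer) scales with sequence length, assuming float32, i.e. 4 bytes per value:

```python
def score_matrix_bytes(seq_len: int, bytes_per_float: int = 4) -> int:
    """Bytes for one seq_len x seq_len attention-score matrix (one head, one layer)."""
    return seq_len * seq_len * bytes_per_float


for n in (512, 1024, 2048, 6250):
    # doubling n quadruples the size of the matrix
    print(f"n = {n:>5}: {score_matrix_bytes(n) / 2**20:8.1f} MiB")
```

Going from 512 to 6250 tokens makes this one matrix roughly 149 times larger, since (6250 / 512)² ≈ 149.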

beautyofdeduction OP t1_j7hkb7q wrote

Yes, that's true. But even adding that in (6250*6250 ~= 40M floats), we are still nowhere near 40 GB.

neuralbeans t1_j7jdiqz wrote

A sequence length of 6250 is massive! It's not just one 6250*6250 matrix: you're multiplying the query and key vectors together for every pair of sequence items, and this is done for every attention head (in parallel), in every encoder layer. I think you're seriously underestimating the problem.

What transformer is this that accepts a sequence length of 6250?
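
A rough sketch of that multiplication, using the numbers from the post (8 heads, 6 encoder blocks, batch size 1). It assumes float32 storage, Keras's default float type. `copies` is a hypothetical knob for how many n × n tensors per head the framework keeps alive for backprop (for example the softmax output and a dropout mask); the real number depends on the implementation.

```python
SEQ_LEN = 6250
NUM_HEADS = 8
NUM_LAYERS = 6


def attention_bytes(seq_len: int, heads: int, layers: int, batch: int = 1,
                    bytes_per_float: int = 4, copies: int = 1) -> int:
    """Bytes for all (batch, heads, seq_len, seq_len) attention tensors across layers.

    `copies` is an ASSUMPTION: how many n x n tensors per head are kept for backprop.
    """
    return batch * layers * heads * seq_len * seq_len * bytes_per_float * copies


scores_only = attention_bytes(SEQ_LEN, NUM_HEADS, NUM_LAYERS)
two_copies = attention_bytes(SEQ_LEN, NUM_HEADS, NUM_LAYERS, copies=2)
print(f"scores only:     {scores_only / 1e9:.2f} GB")
print(f"with two copies: {two_copies / 1e9:.2f} GB")
```

Even with a single copy this is 7.5 GB, about 15 times a 500 MB estimate, before counting weights, gradients or any per-token activations.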

beautyofdeduction OP t1_j7jqohn wrote

I wish I could send you my GitHub. But the original Attention Is All You Need paper trained on sequences of length 25000 on multiple K80s (as stated by the authors), which have only 12 GB of VRAM each. Yes, they used multiple GPUs, but afaik each GPU needs to be able to handle its own batch. Or maybe not? Again, I wish I could show you my code.

BellyDancerUrgot t1_j7ec7oj wrote

~83 GB I think, not 500 MB

beautyofdeduction OP t1_j7epm3u wrote

Can you elaborate?

BellyDancerUrgot t1_j7eq93o wrote

Each Float64 is 4 bytes. U said u have 22M parameters.

Also, besides ur params and activations, u still have gradients, and the sequence is mapped separately for each attention head, so multiply that by 8 as well.

For context, I think DeepLabv3, which iirc is a model with 58M parameters, was trained on 8 V100s.

Edit: I clearly had a brain stroke while writing the first part, so ignore it.

beautyofdeduction OP t1_j7eqr8c wrote

8 Bytes * 22M = 0.176 GB?
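
Bytes per value depend on the dtype. A quick check with NumPy (Keras's default float type is float32, so 4 bytes per value is the more likely figure for this model unless the dtype was changed):

```python
import numpy as np

PARAMS = 22_000_000


def param_bytes(params: int, dtype) -> int:
    """Bytes needed to store `params` values of the given NumPy dtype."""
    return params * np.dtype(dtype).itemsize


for dtype in (np.float16, np.float32, np.float64):
    name = np.dtype(dtype).name
    size = np.dtype(dtype).itemsize
    print(f"{name}: {size} bytes/value -> {param_bytes(PARAMS, dtype) / 1e9:.3f} GB")
```

So 0.176 GB is the float64 figure; in float32 the same 22M weights need 0.088 GB.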

BellyDancerUrgot t1_j7f0u7u wrote

Okay yeah, idk wtf I was typing. Yes, 0.176 GB for just the parameters. U still have to account for the dense representations of long sequences (8 times over, once per head), plus activations and gradients, all multiplied by the number of layers. I read a formula for approximating this somewhere online. I think activations take up way more memory than the model itself.

The memory requirement is roughly in line with most mid-size transformer models, I think.
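
The bookkeeping described above can be sketched as a rough estimator. This is not the formula the comment mentions (it isn't given); it's an illustrative count under stated assumptions: float32, d_model = 8 × 64 = 512, a hypothetical FFN width d_ff = 2048 (not stated in the post), a hypothetical six (n × d_model) tensors kept per layer, and two n × n tensors per head per layer. Optimizer state and framework overhead are ignored.

```python
from dataclasses import dataclass


@dataclass
class EncoderConfig:
    params: int = 22_000_000
    seq_len: int = 6250
    d_model: int = 512        # 8 heads x 64 head_size
    num_heads: int = 8
    num_layers: int = 6
    d_ff: int = 2048          # ASSUMPTION: FFN width is not given in the post
    batch: int = 1
    bytes_per_float: int = 4  # float32


def rough_training_bytes(cfg: EncoderConfig, attn_copies: int = 2,
                         token_tensors: int = 6) -> dict:
    """Very rough memory count for one training step (no optimizer state).

    attn_copies and token_tensors are ASSUMPTIONS about what the framework
    keeps alive for backprop; real numbers depend on the implementation.
    """
    weights = cfg.params * cfg.bytes_per_float
    grads = weights  # one gradient value per weight
    # n x n attention tensors, per head, per layer
    attn_floats = cfg.batch * cfg.num_heads * cfg.seq_len ** 2 * attn_copies
    # per-token tensors (Q, K, V, outputs, residuals, ...) plus the FFN hidden layer
    token_floats = cfg.batch * cfg.seq_len * (token_tensors * cfg.d_model + cfg.d_ff)
    activations = cfg.num_layers * (attn_floats + token_floats) * cfg.bytes_per_float
    return {
        "weights": weights,
        "grads": grads,
        "activations": activations,
        "total": weights + grads + activations,
    }


for name, nbytes in rough_training_bytes(EncoderConfig()).items():
    print(f"{name:>11}: {nbytes / 1e9:6.2f} GB")
```

Under these assumptions, activations come to about 15.8 GB while the weights plus their gradients stay under 0.2 GB, and the attention term is the part that grows with the square of the sequence length.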

beautyofdeduction OP t1_j7hkq74 wrote

That context of how much memory other models use up is helpful. Thanks for taking the time to respond.

Long_Two_6176 t1_j7gc9rz wrote

Remember also that intermediate computations, not just the parameter count, cost GPU memory. Check your intermediate tensor sizes.
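
One way to see intermediate tensor sizes without a GPU is to run a single multi-head self-attention forward pass in NumPy at a small sequence length and inspect `nbytes`. This is a plain NumPy stand-in with random weights, not Keras's `MultiHeadAttention` layer, and it skips masking, dropout and the output projection.

```python
import math

import numpy as np


def attention_intermediates(seq_len: int, d_model: int = 512, num_heads: int = 8,
                            dtype=np.float32, seed: int = 0) -> dict:
    """Run one multi-head self-attention forward pass and return the byte size
    of each intermediate tensor it creates."""
    rng = np.random.default_rng(seed)
    head_size = d_model // num_heads
    x = rng.standard_normal((seq_len, d_model)).astype(dtype)
    w_q, w_k, w_v = ((rng.standard_normal((d_model, d_model)) / math.sqrt(d_model)).astype(dtype)
                     for _ in range(3))

    def split_heads(t):
        # (seq, d_model) -> (heads, seq, head_size)
        return t.reshape(seq_len, num_heads, head_size).transpose(1, 0, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    scores = q @ k.transpose(0, 2, 1) / math.sqrt(head_size)  # (heads, seq, seq)
    scores -= scores.max(axis=-1, keepdims=True)              # softmax stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    out = probs @ v                                           # (heads, seq, head_size)
    tensors = {"x": x, "q": q, "k": k, "v": v, "scores": scores, "probs": probs, "out": out}
    return {name: t.nbytes for name, t in tensors.items()}


for name, nbytes in attention_intermediates(seq_len=512).items():
    print(f"{name:>6}: {nbytes / 2**20:7.2f} MiB")
```

At n = 512 the `scores` and `probs` tensors are already 8 MiB each for one layer, versus 1 MiB for `q`; at n = 6250 the same pair would be 1.25 GB each per layer.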

ia3leonid t1_j7hgcoq wrote

Gradients are also stored and take as much memory as the weights, and activations need their own gradients during backprop. Some optimisers add even more: Adam, for example, also tracks statistics for each weight.
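
To illustrate the optimiser point, here is a minimal NumPy version of the Adam update (a sketch for illustration, not Keras's implementation). It keeps a first-moment array `m` and a second-moment array `v` with the same shape as every weight array, so its state alone is twice the size of the weights.

```python
import numpy as np


class TinyAdam:
    """Minimal Adam optimiser, only to show how much per-weight state it keeps."""

    def __init__(self, params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.params = params
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        # one first-moment (m) and one second-moment (v) array per weight array
        self.m = [np.zeros_like(p) for p in params]
        self.v = [np.zeros_like(p) for p in params]
        self.t = 0

    def step(self, grads):
        self.t += 1
        for p, g, m, v in zip(self.params, grads, self.m, self.v):
            m[:] = self.beta1 * m + (1 - self.beta1) * g
            v[:] = self.beta2 * v + (1 - self.beta2) * g * g
            m_hat = m / (1 - self.beta1 ** self.t)  # bias correction
            v_hat = v / (1 - self.beta2 ** self.t)
            p -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)  # in-place weight update

    def state_bytes(self):
        return sum(a.nbytes for a in self.m + self.v)


# Usage: a 512 x 512 weight matrix plus a bias, float32
params = [np.zeros((512, 512), dtype=np.float32), np.zeros(512, dtype=np.float32)]
grads = [np.ones_like(p) for p in params]  # stand-in gradients, same size as the weights
opt = TinyAdam(params)
opt.step(grads)

weight_bytes = sum(p.nbytes for p in params)
grad_bytes = sum(g.nbytes for g in grads)
print(weight_bytes, grad_bytes, opt.state_bytes())
```

For the 22M float32 weights in the post, that works out to roughly 88 MB of weights, 88 MB of gradients and 176 MB of Adam state, about 352 MB in total.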