Submitted by wangyi_fudan t3_y2w87i in MachineLearning
The proof is simple:
attention = softmax(QK^T)V
= softmax(XW_q (XW_k)^T) XW_v
= softmax(XW_q W_k^T X^T) XW_v
let W_k' = W_k W_q^T
attention = softmax(X (XW_k')^T) XW_v
= softmax(X K'^T) V
Now we see that Q = XW_q is replaced by X, removing 1/4 of the parameters in the attention module.
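A quick numerical check of the folding step (a minimal sketch with random square projections for the single-head case; sizes and names are illustrative, not the poster's actual setup):

```python
import torch

torch.manual_seed(0)
n, d = 8, 16                      # toy sequence length and model width (single head, square projections)

X  = torch.randn(n, d)
Wq = torch.randn(d, d)
Wk = torch.randn(d, d)
Wv = torch.randn(d, d)

# Original logits: (X W_q)(X W_k)^T = X W_q W_k^T X^T
logits_orig = (X @ Wq) @ (X @ Wk).T

# Folded version: let W_k' = W_k W_q^T, then X (X W_k')^T gives the same logits
Wk_folded = Wk @ Wq.T
logits_folded = X @ (X @ Wk_folded).T

# The attention outputs are therefore identical as well
attn_orig   = torch.softmax(logits_orig,   dim=-1) @ (X @ Wv)
attn_folded = torch.softmax(logits_folded, dim=-1) @ (X @ Wv)

print(torch.allclose(logits_orig, logits_folded, atol=1e-4))  # expect True
print(torch.allclose(attn_orig, attn_folded, atol=1e-4))      # expect True
```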
I ran a real experiment and found that with 3/4 of the original attention parameters, the loss differs by only 0.01 throughout training and the gap does not grow. So W_q is not strictly necessary, although the original version with 1/4 more parameters is slightly better.
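For reference, a minimal sketch of what the W_q-free variant could look like as a module (illustrative PyTorch, not the poster's actual code; a single head with square projections is assumed):

```python
import torch
import torch.nn as nn

class AttentionNoWq(nn.Module):
    """Single-head attention that folds W_q into the key projection (3 weight matrices instead of 4)."""
    def __init__(self, d_model):
        super().__init__()
        self.wk = nn.Linear(d_model, d_model, bias=False)  # plays the role of W_k' = W_k W_q^T
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)  # output projection
        self.scale = d_model ** -0.5

    def forward(self, x):                                  # x: (batch, seq, d_model)
        k = self.wk(x)                                     # keys; the queries are just x itself
        v = self.wv(x)
        logits = x @ k.transpose(-2, -1) * self.scale      # X (X W_k')^T
        attn = torch.softmax(logits, dim=-1)
        return self.wo(attn @ v)


x = torch.randn(2, 5, 64)
print(AttentionNoWq(64)(x).shape)   # torch.Size([2, 5, 64])
```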
In multi-head attention, however, W_q is necessary. But research has shown that stacking many small single-head attention modules into a very deep model works better than a wider multi-head attention (a single head is enough).
StellaAthena t1_is7iss2 wrote
The proof is even simpler: (xW_q)(xW_k)^T = x(W_qW_k^T)x^T = xWx
The problem is that W_q and W_k are not square matrices. They are d_model by d_head, and so their product is d_model x d_model. In practice d_model >> d_head (e.g., they’re 4096 and 256 respectively in GPT-J). Doing it your way uses a lot more memory and compute.
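To put numbers on that, a back-of-the-envelope sketch using the GPT-J-like sizes mentioned above (per-head weights only, biases ignored):

```python
# Parameter count per head: factored W_q, W_k versus the merged W = W_q W_k^T
d_model, d_head = 4096, 256

factored = 2 * d_model * d_head      # W_q and W_k, each d_model x d_head
merged   = d_model * d_model         # W = W_q W_k^T is d_model x d_model

print(f"factored: {factored:,} params")        # 2,097,152
print(f"merged:   {merged:,} params")          # 16,777,216
print(f"ratio:    {merged / factored:.1f}x")   # 8.0x
# The score computation scales the same way: the factored form contracts over
# d_head, the merged form over d_model.
```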