Submitted by thomasahle t3_118gie9 in MachineLearning
Cross entropy on logits is a common simplification that fuses softmax + cross-entropy loss into something like:
    def label_cross_entropy_on_logits(x, labels):
        return (-x.select(labels) + x.logsumexp(axis=1)).sum(axis=0)

where x.select(labels) = x[range(batch_size), labels].
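For concreteness, here is a runnable PyTorch version of the same function (a sketch: the batch size of 32 and the 10 classes are just placeholder values), checked against torch.nn.functional.cross_entropy:

    import torch
    import torch.nn.functional as F

    def label_cross_entropy_on_logits(x, labels):
        # per sample: -x[label] + logsumexp(x), summed over the batch
        picked = x[torch.arange(x.shape[0]), labels]  # this is x.select(labels) above
        return (-picked + x.logsumexp(dim=1)).sum()

    x = torch.randn(32, 10)
    labels = torch.randint(0, 10, (32,))
    assert torch.allclose(label_cross_entropy_on_logits(x, labels),
                          F.cross_entropy(x, labels, reduction="sum"))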
I was thinking about how the logsumexp term looks like a regularization term, and wondered what would happen if I just replaced it with x.norm(axis=1). It seemed to work just as well as the original, so I thought: why not just enforce unit norm?
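Spelled out in the same notation, that intermediate variant (norm instead of logsumexp, before enforcing unit norm) would presumably look like this (my reconstruction of the step described above, not code from the post):

    def label_cross_entropy_with_norm(x, labels):
        # logsumexp replaced by the L2 norm of the logits
        return (-x.select(labels) + x.norm(axis=1)).sum(axis=0)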
I changed my code to
    def label_cross_entropy_on_logits(x, labels):
        return -(x.select(labels) / x.norm(axis=1)).sum(axis=0)
and my training sped up dramatically, and my test loss decreased.
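In case anyone wants to try it, here is a self-contained PyTorch sketch of the normalized loss (the function name and the eps guard against zero-norm rows are my additions, not part of the original code):

    import torch

    def normalized_label_loss(x, labels, eps=1e-8):
        picked = x[torch.arange(x.shape[0]), labels]   # x.select(labels)
        return -(picked / (x.norm(dim=1) + eps)).sum()

    x = torch.randn(32, 10, requires_grad=True)
    labels = torch.randint(0, 10, (32,))
    loss = normalized_label_loss(x, labels)
    loss.backward()  # gradients flow through the normalization as usual

Note that the loss is scale-invariant in x, so at inference you still just take the argmax of the raw logits.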
I'm sure this is a standard approach to categorical loss, but I haven't seen it before, and would love to get some references.
I found this old post: https://www.reddit.com/r/MachineLearning/comments/k6ff4w/unit_normalization_crossentropy_loss_outperforms/ which references Logit Normalization: https://arxiv.org/pdf/2205.09310.pdf. However, it seems those papers all apply a normalization and then still use softmax+CE on top. What seems to work for me is simply replacing softmax+CE with normalization.
ChuckSeven t1_j9iyuc2 wrote
Hmm, not sure, but I think that if you don't exponentiate, you cannot fit n targets into a d-dimensional space when n > d if you want there to exist a vector v for each target such that the outcome is a one-hot distribution (or 0 loss).
Basically, if you have 10 targets but only a 2-dimensional space, you need enough non-linearity in the projection to your target space such that for each target there exists a 2d vector which gives 0 loss.
edit: MNIST only has 10 classes, so you are probably fine. Furthermore, softmax of the dot product "cares exponentially more" about the angle of the prediction vector than about its scale. If you use the norm, I'd think you only care about the angle, which likely leads to different representations. Whether those improve performance depends heavily on how much your model relies on scale to learn certain predictions. Maybe in the case of MNIST, relying on scale worsens performance (a wild guess: maybe it makes predictions "more certain" simply because more pixels are set to 1).
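A quick numerical sketch of the scale point (illustrative code, not from the comment): scaling the logits sharpens the softmax toward a one-hot distribution, while the unit-normalized vector is unchanged, so the normalized loss ignores the overall scale of the logits.

    import torch

    x = torch.tensor([[2.0, 1.0, 0.5]])
    for s in (1.0, 5.0, 25.0):
        probs = torch.softmax(s * x, dim=1)                 # sharpens as s grows
        unit = (s * x) / (s * x).norm(dim=1, keepdim=True)  # identical for every s
        print(s, probs, unit)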