MrAcurite t1_j2h2ei1 wrote
Reply to comment by currentscurrents in [D] Is there any research into using neural networks to discover classical algorithms? by currentscurrents
You can teach a neural network to solve, say, mazes in a 10x10 grid, but then you'd need to train it again to solve them in a 20x20 grid, and there would be a size at which the same model would simply cease to work. Dijkstra's algorithm, by contrast, might slow down on bigger grids, but it will never fail to find the exit if an exit exists.
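To make the contrast concrete, here is a minimal Python sketch of Dijkstra's algorithm on a grid maze. It assumes open cells are 0, walls are 1, and moves go to the four neighbours at cost 1; the function name `dijkstra_maze` and the grid encoding are illustrative, not from the thread. Nothing in it depends on the grid size, so the same code handles 10x10 and 20x20 mazes alike:

```python
import heapq
from typing import List, Optional, Tuple

Cell = Tuple[int, int]

def dijkstra_maze(grid: List[List[int]], start: Cell, goal: Cell) -> Optional[int]:
    """Return the shortest path length from start to goal, or None if the goal is unreachable.

    grid[r][c] == 0 is an open cell, 1 is a wall; moves are 4-directional with cost 1.
    """
    rows, cols = len(grid), len(grid[0])
    dist = {start: 0}
    heap = [(0, start)]
    while heap:
        d, (r, c) = heapq.heappop(heap)
        if (r, c) == goal:
            return d
        if d > dist.get((r, c), float("inf")):
            continue  # stale heap entry: a shorter route to this cell was already found
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nr, nc = r + dr, c + dc
            if 0 <= nr < rows and 0 <= nc < cols and grid[nr][nc] == 0:
                nd = d + 1
                if nd < dist.get((nr, nc), float("inf")):
                    dist[(nr, nc)] = nd
                    heapq.heappush(heap, (nd, (nr, nc)))
    return None  # heap exhausted without reaching the goal: no exit exists

small = [[0] * 10 for _ in range(10)]
big = [[0] * 20 for _ in range(20)]
print(dijkstra_maze(small, (0, 0), (9, 9)))    # 18
print(dijkstra_maze(big, (0, 0), (19, 19)))    # 38
```

With unit costs this behaves like breadth-first search; the heap only pays off if cells have different traversal costs.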
You might be able to train a model to find new strategies in a specific case, analyze it, and then code your understanding of it yourself, kinda like using a Monte Carlo approach to find a numerical answer to a problem before trying an analytic one. But you're not going to be able to pull an algorithm out of the parameters directly.
currentscurrents OP t1_j2hdsvv wrote
Someone else posted this example, which is kind of what I was interested in. They trained a neural network on a toy problem, addition mod 113, and were then able to work out the algorithm it used to compute it.
>The algorithm learned to do modular addition can be fully reverse engineered. The algorithm is roughly:
>Map inputs x,y→ cos(wx),cos(wy),sin(wx),sin(wy) with a Discrete Fourier Transform, for some frequency w.
>Multiply and rearrange to get cos(w(x+y))=cos(wx)cos(wy)−sin(wx)sin(wy) and sin(w(x+y))=cos(wx)sin(wy)+sin(wx)cos(wy)
>By choosing a frequency w=2πk/n we get period dividing n, so this is a function of x + y (mod n)
>Map to the output logits z with cos(w(x+y))cos(wz)+sin(w(x+y))sin(wz)=cos(w(x+y−z)) - this has the highest logit at z≡x+y(mod n), so softmax gives the right answer.
>To emphasise, this algorithm was purely learned by gradient descent! I did not predict or understand this algorithm in advance and did nothing to encourage the model to learn this way of doing modular addition. I only discovered it by reverse engineering the weights.
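The quoted steps can be reproduced by hand. The sketch below is written from the description above, not from the trained model's weights: it assumes a prime modulus n = 113 and sums the logits over a few frequency indices `ks`, whose values here are arbitrary illustrative picks rather than the frequencies the network actually learned (`fourier_mod_add` is likewise a hypothetical name). Taking the argmax stands in for the softmax, since softmax preserves the ordering of the logits.

```python
import math

def fourier_mod_add(x: int, y: int, n: int = 113, ks=(14, 35, 41)) -> int:
    """Return (x + y) mod n using only sines, cosines and their products.

    ks are illustrative frequency indices, not the ones the trained network used.
    Assumes n is prime and no k is a multiple of n, so every frequency peaks only at the correct z.
    """
    logits = [0.0] * n
    for k in ks:
        w = 2 * math.pi * k / n  # n is a whole number of periods, so only values mod n matter
        # Step 1: embed each input as (cos, sin) at frequency w.
        cx, sx = math.cos(w * x), math.sin(w * x)
        cy, sy = math.cos(w * y), math.sin(w * y)
        # Step 2: angle-addition identities give cos(w(x+y)) and sin(w(x+y)).
        c_xy = cx * cy - sx * sy
        s_xy = cx * sy + sx * cy
        # Step 3: the logit for each candidate z is cos(w(x+y-z)), which peaks at z = (x+y) mod n.
        for z in range(n):
            logits[z] += c_xy * math.cos(w * z) + s_xy * math.sin(w * z)
    # argmax over the logits (a softmax would not change which z wins)
    return max(range(n), key=lambda z: logits[z])

print(fourier_mod_add(100, 50))  # 37
```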
This is a very different way to do modular addition, but it makes sense for the network. Sine and cosine are periodic waves that repeat every 2π/w, so if you choose the frequency so that n is a whole number of periods (w = 2πk/n), you can implement the non-differentiable modular addition function using only differentiable functions.
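A quick check of the periodicity point, assuming w = 2πk/n with n = 113 and an arbitrarily chosen k (the helper `wave_embedding` is hypothetical): an input that wraps around the modulus lands on the same point of the wave as its reduced value, while the embedding itself stays a smooth function of m.

```python
import math

def wave_embedding(m: float, n: int = 113, k: int = 1):
    """Embed m as (cos(w*m), sin(w*m)) with w = 2*pi*k/n; k is an illustrative choice."""
    w = 2 * math.pi * k / n
    return math.cos(w * m), math.sin(w * m)

# 100 + 50 = 150 wraps around to 37 (mod 113), and the embedding agrees:
a = wave_embedding(100 + 50)
b = wave_embedding((100 + 50) % 113)
print(all(math.isclose(u, v, abs_tol=1e-9) for u, v in zip(a, b)))  # True
```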
Extracting this algorithm is useful for generalization: while the original network only worked for mod 113, with the extracted algorithm we can plug in any modulus n and pick the frequency w = 2πk/n to match. Of course, this is a toy example and there are much faster ways to do modular addition, but maybe the approach could work for more complex problems too.
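As a sketch of that generalization, here is the same idea for an arbitrary modulus n, written with complex exponentials so that one multiplication e^(iwx)·e^(iwy) = e^(iw(x+y)) does the work of both angle-addition identities. It assumes a single frequency k = 1, so w = 2π/n; because gcd(1, n) = 1, the peak at z ≡ x + y (mod n) is unique for every n. The name `mod_add_any` is hypothetical.

```python
import cmath
import math

def mod_add_any(x: int, y: int, n: int) -> int:
    """Return (x + y) mod n via the Fourier trick, for any modulus n >= 1."""
    # Single frequency k = 1: exactly one period fits in n, so the peak is unique.
    w = 2 * math.pi / n
    # e^{iwx} * e^{iwy} = e^{iw(x+y)}: one complex multiply packs both angle-addition identities.
    s = cmath.exp(1j * w * x) * cmath.exp(1j * w * y)
    # Re(e^{iw(x+y)} * e^{-iwz}) = cos(w(x + y - z)), maximised at z = (x + y) mod n.
    return max(range(n), key=lambda z: (s * cmath.exp(-1j * w * z)).real)

print(mod_add_any(5, 4, 7))       # 2
print(mod_add_any(999, 3, 1000))  # 2
```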
Competitive-Rub-1958 t1_j2hy4kg wrote
Incidentally, that maze task has already been solved (https://twitter.com/arpitbansal297/status/1580922302543167488?cxt=HHwWgICgyaidyPArAAAA). Their models can generalize out of distribution (OOD) to novel, unseen mazes of arbitrary size, as long as they compute for more iterations at test time!
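One intuition for why more test-time iterations help, offered as a toy illustration rather than the linked work's model: a fixed local update rule, applied repeatedly like a weight-tied recurrent layer, spreads information one cell per step, so a bigger maze just needs more steps. The rule below is hand-coded (a reachability flood fill outward from the exit), not learned, and `propagate_reachability` is a hypothetical name; open cells are 0 and walls are 1, as in a plain grid encoding.

```python
from typing import List, Tuple

def propagate_reachability(grid: List[List[int]], goal: Tuple[int, int], iterations: int) -> List[List[bool]]:
    """Apply one local update rule `iterations` times; returns which cells can reach the goal.

    After t iterations, exactly the open cells within t steps of the goal are marked True.
    """
    rows, cols = len(grid), len(grid[0])
    reach = [[False] * cols for _ in range(rows)]
    reach[goal[0]][goal[1]] = True
    for _ in range(iterations):
        new = [row[:] for row in reach]  # synchronous update: read old state, write new state
        for r in range(rows):
            for c in range(cols):
                if grid[r][c] == 0 and not reach[r][c]:
                    # Same rule at every cell: become reachable if any open neighbour already is.
                    for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                        nr, nc = r + dr, c + dc
                        if 0 <= nr < rows and 0 <= nc < cols and reach[nr][nc]:
                            new[r][c] = True
                            break
        reach = new
    return reach

maze = [[0] * 20 for _ in range(20)]
print(propagate_reachability(maze, (19, 19), 10)[0][0])  # False
print(propagate_reachability(maze, (19, 19), 38)[0][0])  # True
```

On a 20x20 open grid the start cell (0, 0) is 38 steps from the exit at (19, 19), so it is only marked reachable once the rule has run at least 38 times; a 40x40 maze would need more iterations of the same rule, not a new rule.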
MrAcurite t1_j2igfva wrote
Interesting. I'll have to add that paper to my reading list.