Submitted by Logon1028 t3_zn1f3j in deeplearning
elbiot t1_j0k91nm wrote
Reply to comment by Logon1028 in Efficient Max Pooling Implementation by Logon1028
I think the result of unravel_index is a tuple, so you can just star-unpack it and use it as indices without having to do anything else with it.
Logon1028 OP t1_j0lxkae wrote
That's what I'm doing currently, but I have to unpack it inside a triple-nested for loop because numpy doesn't accept tuples there, so I don't get the benefit of numpy's vectorization. That's why I was searching for an alternative. I'm not trying to super-optimize this function, but I want all the low-hanging fruit I can get. I want people to be able to use the library to train small models for learning purposes.
elbiot t1_j0m3a67 wrote
Huh?

    idx = np.unravel_index(indices, shape)  # tuple of index arrays, one per axis
    values = arr[idx]                       # same as arr[*idx] on Python 3.11+

No loop required. If you're referring to the same loop you were using to get the argmax, you can just adjust your indices first so they apply to the unstrided array.
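A minimal, self-contained illustration of the suggestion above, using made-up data (arr and flat_indices are hypothetical, not from the thread): np.unravel_index turns flat indices into a tuple of per-axis index arrays, and that tuple can index the array directly.

```python
import numpy as np

# Hypothetical data: a 4x5 array and three flat (raveled) indices into it
arr = np.arange(20).reshape(4, 5) * 10
flat_indices = np.array([3, 7, 19])

# Returns a tuple of index arrays, one per axis: (rows, cols)
idx = np.unravel_index(flat_indices, arr.shape)

# The tuple works directly as a fancy index, no Python loop needed
# (arr[*idx] is the same thing on Python 3.11+)
values = arr[idx]
print(values.tolist())  # [30, 70, 190]
```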
Logon1028 OP t1_j0mbuxc wrote
Yes, but unravel_index has to be applied to EVERY SINGLE ELEMENT of the last axis independently, i.e.:

    # one np.unravel_index call per (depth, x, y) window
    for depth in range(strided_result.shape[0]):
        for x in range(strided_result.shape[1]):
            for y in range(strided_result.shape[2]):
                local_stride_index = np.unravel_index(argmax_arr[depth][x][y], strided_result[depth][x][y].shape)

unravel_index only takes a 1D array as input. To apply it to only the last axis of the 4D array, you have to use a for loop; unravel_index has no axis parameter.
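For what it's worth, np.unravel_index accepts an index array of any shape (not only 1-D), and its shape argument only has to describe a single window. Since every pooling window has the same shape, one call can stand in for the triple loop. This is a sketch under assumptions: the input x, kernel size k, stride s, and building the windows with numpy.lib.stride_tricks.sliding_window_view (NumPy 1.20+) are illustrative, not the OP's code.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 6, 6))  # hypothetical input: (depth, H, W)
k, s = 2, 2                         # hypothetical kernel size and stride

# Strided windows view: (depth, out_h, out_w, k, k)
windows = sliding_window_view(x, (k, k), axis=(1, 2))[:, ::s, ::s]

# Flat argmax inside each window: shape (depth, out_h, out_w)
argmax_arr = windows.reshape(*windows.shape[:3], k * k).argmax(axis=-1)

# One vectorized call: indices may have any shape, and (k, k) is the
# shape of a single window, shared by all windows.
row_in_win, col_in_win = np.unravel_index(argmax_arr, (k, k))
# Both outputs have the same shape as argmax_arr.
```

Adding each window's origin (its output position times s) to row_in_win and col_in_win would then give coordinates in the unstrided input.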
Logon1028 OP t1_j0mcyle wrote
What I ended up doing is using np.indices (multiplied by the stride) to apply a mask to the x and y argmax arrays via elementwise multiplication. Then I used elementwise division and modulus to calculate the input indices myself. The only for loop left in the forward pass is a simple one over the depth of the input. The backward pass still uses a triple for loop, but I can live with that.
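The index arithmetic described above can be sketched roughly as follows. This is a hedged reconstruction, not the OP's actual code: maxpool_forward and maxpool_backward are hypothetical names, the input is assumed to be (depth, H, W), the window origins from np.indices times the stride are added to the within-window offsets, and the backward pass uses np.add.at as a swapped-in way to scatter gradients without the triple loop.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def maxpool_forward(x, k, s):
    """Hypothetical helper: k x k max pooling with stride s over x of shape (depth, H, W).

    Returns the pooled output plus the input-array row/col of every maximum.
    """
    depth = x.shape[0]
    # Strided windows view: (depth, out_h, out_w, k, k)
    windows = sliding_window_view(x, (k, k), axis=(1, 2))[:, ::s, ::s]
    out_h, out_w = windows.shape[1:3]
    # Top-left corner of each window in input coordinates: np.indices * stride
    origin_r, origin_c = np.indices((out_h, out_w)) * s

    out = np.empty((depth, out_h, out_w), dtype=x.dtype)
    rows = np.empty((depth, out_h, out_w), dtype=np.intp)
    cols = np.empty_like(rows)
    for d in range(depth):  # the only Python-level loop: over depth
        flat = windows[d].reshape(out_h, out_w, k * k)
        am = flat.argmax(axis=-1)  # flat offset of the max inside each window
        out[d] = flat.max(axis=-1)
        # Division and modulus split the flat offset into (row, col) within the
        # window; adding the window origin gives input coordinates.
        rows[d] = origin_r + am // k
        cols[d] = origin_c + am % k
    return out, rows, cols


def maxpool_backward(grad_out, rows, cols, input_shape):
    """Hypothetical helper: send each output gradient to the input element that won the max.

    np.add.at accumulates repeated indices, so overlapping windows (s < k)
    that pick the same input element are handled correctly.
    """
    grad_in = np.zeros(input_shape, dtype=grad_out.dtype)
    depth_idx = np.arange(input_shape[0])[:, None, None]  # broadcasts over out_h, out_w
    np.add.at(grad_in, (depth_idx, rows, cols), grad_out)
    return grad_in


x = np.arange(16, dtype=float).reshape(1, 4, 4)
out, rows, cols = maxpool_forward(x, k=2, s=2)
print(out[0].tolist())  # [[5.0, 7.0], [13.0, 15.0]]
```

Storing input coordinates (rather than within-window offsets) in the forward pass is what lets the backward pass become a single scatter; a plain fancy-index assignment would silently drop contributions when overlapping windows share a maximum, which is why np.add.at is used here.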
The model I showed in the previous comment now trains in just under 4 minutes, so I now have roughly a 3x speedup over my original implementation. I think that's where I'm going to leave it.
Thank you for your help. Even though I didn't use all of your suggestions directly, they definitely guided me in the right direction. Unfortunately, my current implementation is FAR more efficient than any of the examples I could find online.