elbiot wrote:

I think unravel_index returns a tuple, so you can just star-unpack it and use it as indices without having to do anything else with it.
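
As a minimal illustration of the tuple point (the array below is made up): for a single flat index, `np.unravel_index` returns a tuple with one index per axis, and that tuple can index the array directly. `arr[idx]` works on any Python version; the star-unpacked form `arr[*idx]` is equivalent but needs Python 3.11 or newer.

```python
import numpy as np

# Hypothetical 3x4 array; any shape works the same way.
arr = np.array([[3, 7, 1, 0],
                [2, 9, 4, 5],
                [8, 6, 0, 1]])

flat = np.argmax(arr)                    # flat index into arr.ravel()
idx = np.unravel_index(flat, arr.shape)  # tuple with one index per axis

print(tuple(int(i) for i in idx))  # (1, 1)
print(arr[idx])                    # 9
```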

Logon1028 (OP) wrote:

That's what I am doing currently. But I have to unpack it inside a triple-nested for loop because NumPy doesn't accept tuples, so I don't get the benefit of NumPy's vectorization. That's why I was looking for an alternative. I'm not trying to super-optimize this function, but I want all the low-hanging fruit I can get. I want people to be able to use the library to train small models for learning purposes.

elbiot wrote:

Huh?

    idx = np.unravel_index(indices, shape)  # tuple with one index array per axis
    values = arr[idx]  # same as arr[*idx] on Python 3.11+

No loop required. If you're referring to the same loop you were using to get the argmax, you can just adjust your indices first so they apply to the unstrided array.
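
A sketch of the no-loop version, assuming a max-pool layer whose strided windows come from NumPy's `sliding_window_view` (the names `x`, `k`, `s`, `windows` and the random input are hypothetical, not taken from the thread). Because `np.unravel_index` accepts an array of flat indices of any shape, the per-window argmax can be unraveled in one call and combined with index grids for the leading axes to pull out every pooled value at once:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Hypothetical max-pool setup: input (depth, H, W), 2x2 window, stride 2.
rng = np.random.default_rng(0)
x = rng.standard_normal((3, 6, 6))
k, s = 2, 2

# Strided windows: shape (depth, out_h, out_w, k, k).
windows = sliding_window_view(x, (k, k), axis=(1, 2))[:, ::s, ::s]

# Flat argmax inside each window: shape (depth, out_h, out_w).
argmax_arr = windows.reshape(*windows.shape[:3], k * k).argmax(axis=-1)

# unravel_index works elementwise on an array of any shape and returns
# one index array per window axis (row and column inside the window).
win_row, win_col = np.unravel_index(argmax_arr, (k, k))

# Index grids for the leading axes, so each output position reads its own window.
d, i, j = np.indices(argmax_arr.shape)
values = windows[d, i, j, win_row, win_col]  # no Python loop

print(values.shape)  # (3, 3, 3)
```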

Logon1028 (OP) wrote:

Yes, but unravel_index has to be applied to EVERY SINGLE ELEMENT of the last axis independently, i.e.:

    # One unravel_index call per output position; the target shape is that of a single window.
    for depth in range(strided_result.shape[0]):
        for x in range(strided_result.shape[1]):
            for y in range(strided_result.shape[2]):
                local_stride_index = np.unravel_index(argmax_arr[depth, x, y], strided_result.shape[3:])

unravel_index only takes a 1-D array as input. In order to apply it to only the last axis of the 4D array, you have to use a for loop. unravel_index has no axis parameter.
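
On the axis question: `np.unravel_index` is applied elementwise, so it accepts an integer array of any shape, and each array it returns has that same shape; no axis parameter is needed. A quick check with hypothetical data (random flat indices into 3x3 windows), comparing the triple loop against a single call:

```python
import numpy as np

# Hypothetical flat argmax indices for a (depth, out_h, out_w) grid of 3x3 windows.
rng = np.random.default_rng(1)
window_shape = (3, 3)
argmax_arr = rng.integers(0, 9, size=(2, 4, 4))

# Loop version: unravel one element at a time.
loop_rows = np.empty_like(argmax_arr)
loop_cols = np.empty_like(argmax_arr)
for depth in range(argmax_arr.shape[0]):
    for x in range(argmax_arr.shape[1]):
        for y in range(argmax_arr.shape[2]):
            loop_rows[depth, x, y], loop_cols[depth, x, y] = np.unravel_index(
                argmax_arr[depth, x, y], window_shape)

# Vectorized version: one call on the whole 3-D array.
rows, cols = np.unravel_index(argmax_arr, window_shape)

print(np.array_equal(rows, loop_rows) and np.array_equal(cols, loop_cols))  # True
```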

Logon1028 (OP) wrote:

What I ended up doing was using np.indices (multiplied by the stride) to apply a mask to the x and y argmax arrays using an elementwise multiplication. Then I used elementwise integer division and modulus to calculate the input indices myself. The only for loop left in the forward pass is a simple one over the depth of the input. The backward pass still uses a triple for loop, but I can live with that.
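
The exact code isn't shown in the thread, so the following is only a sketch of the same idea under assumed names and shapes (square k×k windows, stride s, argmax taken over each flattened window, windows built with `sliding_window_view`). Here the window origins from `np.indices` are added to the in-window offsets rather than combined by multiplication, and depth is handled by broadcasting instead of a loop:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Hypothetical setup: input (depth, H, W), 2x2 window, stride 2.
rng = np.random.default_rng(2)
x = rng.standard_normal((2, 6, 6))
k, s = 2, 2

windows = sliding_window_view(x, (k, k), axis=(1, 2))[:, ::s, ::s]
out_h, out_w = windows.shape[1:3]
argmax_arr = windows.reshape(*windows.shape[:3], k * k).argmax(axis=-1)

# Top-left corner of every window in input coordinates: output index * stride.
origin_x, origin_y = np.indices((out_h, out_w)) * s

# Integer division and modulus turn the flat in-window index into (row, col),
# which is then offset by the window origin (broadcast over depth).
# For a non-square (kh, kw) window, both the divisor and the modulus would be kw.
in_x = origin_x + argmax_arr // k
in_y = origin_y + argmax_arr % k

# The input values at those coordinates should be the pooled maxima.
depth_idx = np.arange(x.shape[0])[:, None, None]
pooled = x[depth_idx, in_x, in_y]
```

The same (depth, in_x, in_y) triples identify which input positions receive gradient in a backward pass.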

The model I showed in the previous comment now trains in just under 4 minutes, which is roughly a 3x speedup over my original implementation. I think that is where I am going to leave it.

Thank you for your help. Even though I didn't use all of your suggestions directly, they definitely pointed me in the right direction. Unfortunately, my current implementation is FAR more efficient than any of the examples I could find online.