Here’s a quick attempt:
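```julia
# outers1: the question's loop, with @inbounds, @views and in-place .=
function outers1(states, tanh_term)
    number_of_states = size(states, 1)
    # both inputs must have the same number of rows (first dimension)
    size(tanh_term, 1) == number_of_states ||
        throw(DimensionMismatch("states and tanh_term must have the same number of rows"))
    dP_dw = zeros(Float64, (number_of_states, size(states, 2), size(tanh_term, 2)))
    @inbounds for a = 1:number_of_states
        # outer product of row a of each input, written into the slice dP_dw[a, :, :]
        @views dP_dw[a, :, :] .= states[a, :] .* tanh_term[a, :]'
    end
    dP_dw
end
```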
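```julia
using Einsum, Test, BenchmarkTools   # BenchmarkTools provides @btime
# @einsum writes the loops for you; @vielsum is its multithreaded variant
outers2(states, tanh_term) = @einsum dP_dw[a,s,t] := states[a, s] * tanh_term[a, t]
outers3(states, tanh_term) = @vielsum dP_dw[a,s,t] := states[a, s] * tanh_term[a, t]
# pure broadcasting: tanh_term is reshaped to size (rows, 1, cols)
outers4(states, tanh_term) = states .* reshape(tanh_term, size(states, 1), 1, :)
N = 50; states, tanh_term = randn(N,N), randn(N,N);
@test outers1(states, tanh_term) ≈ outers2(states, tanh_term) ≈ outers3(states, tanh_term) ≈ outers4(states, tanh_term)
```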
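The broadcasting one-liner `outers4` can look cryptic. A small sketch with hand-picked matrices (the names `st`, `tt`, `t3`, `P` are just for illustration) shows the shapes involved: reshaping the second input to size `(rows, 1, cols)` lets broadcasting pair `states[a, s]` with `tanh_term[a, t]` along a new third axis.

```julia
# Small illustrative inputs: 2 rows ("a"), 2 columns ("s") and 3 columns ("t")
st = [1.0 2.0; 3.0 4.0]                  # plays the role of states, size (2, 2)
tt = [10.0 20.0 30.0; 40.0 50.0 60.0]    # plays the role of tanh_term, size (2, 3)

t3 = reshape(tt, size(st, 1), 1, :)      # size (2, 1, 3): column t moves to the 3rd axis
P  = st .* t3                            # st acts as (2, 2, 1), so P has size (2, 2, 3)

@show size(t3) size(P)                   # (2, 1, 3) and (2, 2, 3)
# Every entry is the product from the einsum expression P[a,s,t] = st[a,s] * tt[a,t]
@show P[2, 1, 3] == st[2, 1] * tt[2, 3]  # 3.0 * 60.0 == 180.0, so true
```

No data is copied by `reshape` here; it only reinterprets the existing matrix with a singleton middle dimension.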
Results (`@btime` is from BenchmarkTools.jl):

```julia
julia> @btime outers0($states, $tanh_term); # as in question
  364.486 μs (202 allocations: 1.96 MiB)
julia> @btime outers1($states, $tanh_term); # with @inbounds, @views, and .=
  189.673 μs (2 allocations: 976.64 KiB)
julia> @btime outers2($states, $tanh_term);
  62.388 μs (2 allocations: 976.64 KiB)
julia> @btime outers3($states, $tanh_term);
  52.462 μs (62 allocations: 984.86 KiB)
julia> @btime outers4($states, $tanh_term);
  63.387 μs (4 allocations: 976.75 KiB)
```
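If you look at `@macroexpand1 @einsum dP_dw[a,s,t] := states[a, s] * tanh_term[a, t]`, the main change is that it orders the loops with `a` innermost. Julia arrays are column-major, so with `a` (the first index of every array here) in the innermost loop, consecutive iterations touch adjacent memory. `outers1` instead loops over `a` outermost and writes into the strided slice `dP_dw[a, :, :]`.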