Avoid getting a ReshapedArray when calling reshape?

Flux has the clever array type Zeros, for which arithmetic operations are implemented as basically no-ops (well, I guess you get the idea).

The main use case is to turn off the bias term, but the problem is that some ops need to reshape it, and then the magic disappears:

julia> using Flux

julia> z = zeros(1000,1000);

julia> zz = Flux.Zeros(1000,1000);

julia> using BenchmarkTools

julia> @btime z+z;
  3.947 ms (2 allocations: 7.63 MiB)

julia> @btime z+zz;
  64.287 ns (1 allocation: 32 bytes)

julia> @btime z+reshape(zz,1000,:);
  3.987 ms (11 allocations: 7.63 MiB)

julia> typeof(reshape(zz,1000,:))
Base.ReshapedArray{Bool,2,Flux.Zeros{Bool,2},Tuple{}}

What is the best way to prevent this from happening?

Implementing reshape for Zeros is an option, of course. But reshape has more than a handful of methods, all of which specialize on the second argument, so one must implement all of them or else there will be ambiguities, and this feels a bit brute-force-ish. I guess another option is to implement methods for ReshapedArray of Zeros, but this is probably even worse in this respect.
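For concreteness, here is a minimal sketch of the first option. It uses a hypothetical stand-in type, MyZeros, rather than Flux.Zeros, so that it runs without Flux. It defines only a single reshape method on the fully resolved Dims tuple, on the assumption that Base's other reshape methods (varargs, Colon, etc.) normalize their arguments and forward to that one. That forwarding is an implementation detail of Base, so whether this stays ambiguity-free on every Julia version is not something the sketch guarantees.

```julia
# Hypothetical stand-in for Flux.Zeros (not Flux's actual definition):
# an array that stores only its size and reads as zero everywhere.
struct MyZeros{T,N} <: AbstractArray{T,N}
    dims::NTuple{N,Int}
end
MyZeros(dims::Int...) = MyZeros{Bool,length(dims)}(dims)

Base.size(z::MyZeros) = z.dims
Base.getindex(z::MyZeros{T,N}, I::Vararg{Int,N}) where {T,N} = zero(T)

# One reshape method on the resolved Dims tuple. Assumption: calls such as
# reshape(z, 1000, :) are normalized by Base and end up here.
function Base.reshape(z::MyZeros{T}, dims::Dims{N}) where {T,N}
    prod(dims) == length(z) ||
        throw(DimensionMismatch("new dimensions $dims must be consistent with array length $(length(z))"))
    return MyZeros{T,N}(dims)
end

zz = MyZeros(1000, 1000)
r = reshape(zz, 1000, :)      # varargs + Colon form, as in the session above
@assert r isa MyZeros         # no ReshapedArray wrapper
@assert size(r) == (1000, 1000)
```

If the assumption holds, the special arithmetic methods defined for the zero type keep applying after a reshape, because the result is still the same array type.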

Is there a more elegant way?
