Flux has the clever array type `Zeros`, for which arithmetic operations are implemented as basically no-ops (well, I guess you get the idea).
The main use case is to turn off the bias term, but the problem is that some ops need to reshape it, and then the magic disappears:
```julia
julia> using Flux

julia> z = zeros(1000,1000);

julia> zz = Flux.Zeros(1000,1000);

julia> using BenchmarkTools

julia> @btime z+z;
3.947 ms (2 allocations: 7.63 MiB)

julia> @btime z+zz;
64.287 ns (1 allocation: 32 bytes)

julia> @btime z+reshape(zz,1000,:);
3.987 ms (11 allocations: 7.63 MiB)

julia> typeof(reshape(zz,1000,:))
Base.ReshapedArray{Bool,2,Flux.Zeros{Bool,2},Tuple{}}
```
What is the best way to prevent this from happening?
Implementing `reshape` for `Zeros` is an option, of course, but `reshape` has more than a handful of methods that all specialize on the second argument, so one must implement all of them or else run into method ambiguities, which feels a bit brute-force-ish. I guess another option is to implement methods for `ReshapedArray` of `Zeros`, but this is probably even worse in that respect.
Is there a more elegant way?
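One possible middle ground, sketched below under stated assumptions: as far as I can tell from Base's `reshape` methods in current Julia 1.x, the vararg and `Colon`-accepting methods (e.g. `reshape(A, 1000, :)`) resolve the `:` and then forward to `reshape(parent, dims::Dims)`. If that holds, overloading only that one method for the zeros type might be enough, and since its first argument is strictly more specific than Base's `AbstractArray` methods, it should not create ambiguities. `MyZeros` is a hypothetical stand-in, not Flux's actual `Zeros` (whose definition differs), and the forwarding behaviour is a Base internal rather than a documented interface, so it may change between Julia versions.

```julia
# Hypothetical stand-in for Flux.Zeros (NOT Flux's implementation):
# a lazy array of zeros that stores only its dimensions.
struct MyZeros{T,N} <: AbstractArray{T,N}
    dims::NTuple{N,Int}
end

# Convenience constructor: MyZeros(1000, 1000) is a 1000×1000 Bool array of zeros.
MyZeros(dims::Vararg{Int,N}) where {N} = MyZeros{Bool,N}(dims)

# Minimal AbstractArray interface: size plus a bounds-checked getindex.
Base.size(z::MyZeros) = z.dims
function Base.getindex(z::MyZeros{T,N}, I::Vararg{Int,N}) where {T,N}
    checkbounds(z, I...)
    return zero(T)
end

# No-op addition: after a shape check, return `a` itself (aliased, not copied).
function Base.:+(a::AbstractArray, z::MyZeros)
    size(a) == size(z) ||
        throw(DimensionMismatch("sizes $(size(a)) and $(size(z)) differ"))
    return a
end

# Assumption: Base's vararg/Colon `reshape` methods end up calling
# reshape(parent, dims::Dims), so this single method covers
# reshape(z, 1000, :), reshape(z, (1000, 1000)), vec(z), etc.
function Base.reshape(z::MyZeros{T}, dims::Dims) where {T}
    prod(dims) == length(z) ||
        throw(DimensionMismatch("new dimensions $dims must be consistent with array length $(length(z))"))
    return MyZeros{T,length(dims)}(dims)
end
```

With this, `reshape(z, 1000, :)` stays a `MyZeros` instead of becoming a `Base.ReshapedArray`, so the cheap `+` method keeps applying. Only `+(::AbstractArray, ::MyZeros)` is defined here; adding the mirrored `+(::MyZeros, ::AbstractArray)` would make `z + z` ambiguous unless a third `+(::MyZeros, ::MyZeros)` method is added as well.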