Will Flux/Zygote compute gradients sparsely?

This sounds like channel pruning. In Flux the dimensions are the other way around (a Dense weight is stored as out × in), so this removes all but 50 of the rows of the weight matrix. My guess is that just slicing these matrices will work fine here.
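As a quick sanity check of that layout, here is a toy example in plain Julia (no Flux; the sizes and the kept indices are made up). Since a Dense layer computes W * x .+ b with W stored as out × in, keeping some rows of W and the matching entries of b gives the same values as those outputs of the full layer.

```julia
using LinearAlgebra

# A Dense layer computes y = W * x .+ b, with W stored as (out × in).
W = randn(Float32, 8, 5)    # 8 outputs, 5 inputs (toy sizes)
b = randn(Float32, 8)
x = randn(Float32, 5, 3)    # batch of 3 inputs

keep = [1, 4, 6]            # hypothetical output channels to keep

y_full    = W * x .+ b
y_trimmed = W[keep, :] * x .+ b[keep]   # 3×3 instead of 8×3

y_trimmed ≈ y_full[keep, :]  # true: sliced rows reproduce those outputs
```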

If I understand correctly, this is the step where you wanted to mutate an array, to pad back to full size for the next layer. But doing so loses half the benefit, since the very next matrix multiplication then involves a lot of zeros. If you keep 10% of each layer's rows like this, you save 90% of the weights; if you also trim the matching columns of the next layer, each inner weight matrix shrinks to 10% × 10% = 1% of its size, so in a long chain you save about 99%.
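A small plain-Julia sketch of both points, with made-up sizes and biases omitted: padding the trimmed activations back out with zeros gives the same result as dropping the matching columns of the next weight matrix, and the weight counts show where the 90% vs 99% comes from.

```julia
using LinearAlgebra

# Toy two-layer chain: h = tanh.(W1 * x), y = W2 * h (biases omitted).
n_in, n_hid, n_out = 6, 10, 4
W1 = randn(n_hid, n_in)
W2 = randn(n_out, n_hid)
x  = randn(n_in)
keep = [2, 5, 9]                   # hypothetical kept hidden units

h_small = tanh.(W1[keep, :] * x)   # trimmed layer 1: only the kept rows

# Option 1: scatter back into a full-size zero vector, then use all of W2.
h_padded = zeros(n_hid)
h_padded[keep] = h_small
y_padded = W2 * h_padded           # most of this multiply is by zeros

# Option 2: also drop the matching columns of W2.
y_trimmed = W2[:, keep] * h_small

y_padded ≈ y_trimmed               # true: the zeros contributed nothing

# Weight counts for one inner 1000×1000 layer, keeping 100 of 1000 units.
n, k = 1000, 100
full          = n * n   # 1_000_000 weights
rows_only     = k * n   #   100_000 weights: 10% kept, 90% saved
rows_and_cols = k * k   #    10_000 weights:  1% kept, 99% saved
```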

Maybe someone has a nicer library for this, but my quick function would look something like this. It's important that trim happens inside the gradient call, so that you get gradients with respect to the original weights.
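To illustrate what differentiating through the slice gives you, here is a single-layer toy sketch in plain Julia rather than Zygote (loss and fd_grad are illustrative names, not Flux or Zygote API). The gradient with respect to the full W has the same size as W, with zeros in the rows that were sliced away, which is the same pattern as the zero rows in the gradients further down. The hand-written gradient is checked against central finite differences.

```julia
using LinearAlgebra

# Loss on a trimmed layer: L(W) = sum(abs2, W[keep, :] * x).
W    = randn(5, 3)
x    = randn(3)
keep = [1, 4]

loss(W) = sum(abs2, W[keep, :] * x)

# Analytic gradient: for the kept rows it is 2 * (W[keep, :] * x) * x';
# rows that were sliced away never reach the loss, so their gradient is zero.
G = zeros(size(W))
G[keep, :] = 2 .* (W[keep, :] * x) * x'

# Central finite differences with respect to every entry of the full W.
function fd_grad(f, W; ε = 1e-6)
    g = similar(W)
    for i in eachindex(W)
        Wp = copy(W); Wp[i] += ε
        Wm = copy(W); Wm[i] -= ε
        g[i] = (f(Wp) - f(Wm)) / (2ε)
    end
    return g
end

isapprox(G, fd_grad(loss, W); atol = 1e-5)   # true
all(iszero, G[[2, 3, 5], :])                # true: dropped rows get zeros
```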

julia> trim(d::Dense, pre, post) = Dense(d.weight[post, pre], d.bias isa AbstractVector && d.bias[post], d.σ);

julia> trim(Dense(10,20), 1:3, 11:15)
Dense(3, 5)         # 20 parameters
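The bias argument above relies on a short-circuit idiom. As far as I know, recent Flux versions store false in the bias field of a layer built with bias=false (worth checking against your version), and a && b returns b whenever a is true. Pulled out as a hypothetical standalone function, trim_bias:

```julia
# Same expression as in trim(d::Dense, ...), as a standalone function.
trim_bias(b, post) = b isa AbstractVector && b[post]

trim_bias([10, 20, 30, 40, 50], 2:4)   # [20, 30, 40]: a vector bias is sliced
trim_bias(false, 2:4)                  # false: "no bias" passes through
```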

julia> function trim(c::Chain, mids...)
        l = 0  # get(mids, 0, :) returns :, so the input size is unchanged
        t = map(c.layers) do d
          trim(d, get(mids, l, :), get(mids, l+=1, :))
        end
        Chain(t...)
       end;
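The get(mids, l, :) bookkeeping can be checked without Flux. This hypothetical helper, layer_slices, uses the same counter l inside a closure but just returns the (pre, post) pair for each layer. Since get on a tuple returns the default : for an out-of-range index such as 0, the first layer's input and the last layer's output stay full size.

```julia
# Only the (pre, post) index bookkeeping of trim(c::Chain, mids...).
function layer_slices(nlayers, mids...)
    l = 0
    map(1:nlayers) do _
        # pre uses the old l, then l += 1 picks the next entry for post
        (get(mids, l, :), get(mids, l += 1, :))
    end
end

layer_slices(2, 1:50)         # layer 1: (:, 1:50), layer 2: (1:50, :)
layer_slices(3, 1:50, 1:20)   # (:, 1:50), (1:50, 1:20), (1:20, :)
```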

julia> m = Chain(Dense(784, 512, tanh), Dense(512, 10));

julia> trim(m, 1:50)
Chain(
  Dense(784, 50, tanh),                 # 39_250 parameters
  Dense(50, 10),                        # 510 parameters
)                   # Total: 4 arrays, 39_760 parameters, 155.562 KiB.
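The parameter counts printed above follow from out × in weights plus out biases per layer. A quick arithmetic check, using a hypothetical helper dense_params and including the untrimmed model for comparison:

```julia
# Dense layer with bias: n_out * n_in weights plus n_out biases.
dense_params(n_in, n_out) = n_out * n_in + n_out

dense_params(3, 5)                               # 20
dense_params(784, 50) + dense_params(50, 10)     # 39_760 (trimmed m)
dense_params(784, 512) + dense_params(512, 10)   # 407_050 (original m)
```

So trim(m, 1:50) keeps a little under 10% of the original parameters.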

julia> x = rand(Float32, 784, 32);

julia> g = gradient(params(m)) do
         m2 = trim(m, 1:3:512)  # keep every 3rd row of m[1].weight and every 3rd column of m[2].weight
         y = m2(x)
         sum(abs2, y)
       end
Grads(...)

julia> g[m[2].weight]
10×512 Matrix{Float32}:
 -32.1707   0.0  0.0  -5.32201   0.0  0.0  …  0.0  0.0  -0.0301256  0.0  0.0   7.36754  0.0
   5.38624  0.0  0.0   0.589724  0.0  0.0     0.0  0.0   0.63017    0.0  0.0  -1.70225  0.0
   9.69729  0.0  0.0   1.40005   0.0  0.0     0.0  0.0  -1.59133    0.0  0.0  -2.6516   0.0
...

julia> g[m[1].weight]
512×784 Matrix{Float32}:
 -0.13919    -0.133159   -0.125593    …  -0.0448132  -0.12223    -0.0714781  -0.124712
  0.0         0.0         0.0             0.0         0.0         0.0         0.0
  0.0         0.0         0.0             0.0         0.0         0.0         0.0
  0.379954    0.316367    0.297065        0.533378    0.312041    0.392742    0.608988
  0.0         0.0         0.0             0.0         0.0         0.0         0.0