Will Flux/Zygote compute gradients sparsely?

In short, I want to train an MLP by sparsely updating its parameters. I do this on the forward pass by randomly subsampling the columns of my parameter matrices. E.g. for a simple 2-layer model on MNIST with input data of shape (batch_size, 784), my layers are matrices of 784x512 and 512x10. I (sort of) randomly subsample a small proportion of the columns: say for the first layer's computation I subsample 50 columns, then my matrix multiply produces an output of shape (batch_size, 50) instead of (batch_size, 512). However, in order to pass this to the next layer, I initialize a vector of zeros(512) and set the 50 subsampled indices to match the output, then pass that to the next layer. So the goal is to only update those 50 subsampled columns and not the entire original matrix.
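For concreteness, here is a minimal sketch of that forward step in plain Julia, ignoring AD for now; all the names are illustrative and the column choice is purely random here:

using Random

X  = rand(Float32, 32, 784)        # (batch_size, 784) input, as above
W1 = randn(Float32, 784, 512)      # first-layer parameter matrix
cols = randperm(512)[1:50]         # 50 subsampled column indices

small = X * W1[:, cols]            # (32, 50) result instead of (32, 512)

H = zeros(Float32, 32, 512)        # widen back to what layer 2 expects
H[:, cols] .= small                # only the sampled columns are non-zero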

This results in massive savings in computation, since the matrix multiplications are much smaller. So my question is: will Flux/Zygote also do this subsampling on the backward pass, and therefore save me the computation I saved on the forward pass?

Given that Zygote isn’t spectacular with lots of indexing and BLAS really wants contiguous arrays, my hunch would be that any performance gain will be dwarfed by increased allocation or hitting suboptimal code paths.

However, this is a question best answered by a benchmark: give both a try and see what happens! @showgrad and @code_adjoint will be useful for interrogating what the AD system is doing.
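For example (a rough sketch; both macros just print extra information and don't change the result):

using Zygote

f(W, x) = sum(abs2, W * x)

Zygote.@code_adjoint f(rand(2, 2), rand(2))   # inspect the adjoint code Zygote generates for f

gradient(rand(2, 2), rand(2)) do W, x
    sum(abs2, Zygote.@showgrad(W * x))        # prints this intermediate's gradient during the backward pass
end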


I implemented it, but discovered that Zygote cannot handle mutating arrays, which is what I need to do for this sub-sampling routine. So I need to calculate the gradients manually.

I'd assumed the code was already Zygote-friendly, but if you still want to try, there are a few ways to get around the array mutation limitation:

  1. Use Zygote.Buffer. This will allow direct mutation, but may be a little awkward depending on your use case.
  2. Write a custom AD rule that handles the mutation. This is the most involved approach, but it allows full control over the operations performed and thus may deliver the best performance. See https://juliadiff.org/ChainRulesCore.jl/stable/writing_good_rules.html for more, in particular the section on patterns that need rules in Zygote.jl.
  3. See if there are existing AD-compatible functions that might fit.

Here, it seems 3. might be the way to go. NNlib exposes a scatter function that basically does all of this in one line. Here's a quick code snippet to demonstrate:

# given size(matmul_output) == (50, N) and size(sampled_cols) == (50,), where N is the batch size;
# tuple.(sampled_cols, (1:N)') broadcasts to a 50×N array of (row, batch) destination indices
NNlib.scatter(+, matmul_output, tuple.(sampled_cols, (1:N)'), dstsize=(512, N))

For more on scatter, see the REPL help (?NNlib.scatter) or the NNlib docs.


This sounds like channel pruning. In Flux the dimensions are the other way around, so this is removing all but 50 of the rows of the weight matrix. My guess is that just slicing these matrices will work fine here.

If I understand right this is the step where you wanted to mutate an array, to get back to full size for the next layer. But doing so means you lose half the benefit, since the very next matrix multiplication then involves a lot of zeros. If you keep 10% of each layer like this, you save 90%, but if you trim the columns of the next layer too, you save (in a long chain) 99%.

Maybe someone has a nicer library for this, but my quick function would look something like this. It’s important that trim happens inside the gradient call, to get gradients for the original weights.

julia> trim(d::Dense, pre, post) = Dense(d.weight[post, pre], d.bias isa AbstractVector && d.bias[post], d.σ);

julia> trim(Dense(10,20), 1:3, 11:15)
Dense(3, 5)         # 20 parameters

julia> function trim(c::Chain, mids...)
        l = 0  # leaves input size unchanged
        t = map(c.layers) do d
          trim(d, get(mids, l, :), get(mids, l+=1, :))
        end
        Chain(t...)
       end;

julia> m = Chain(Dense(784, 512, tanh), Dense(512, 10));

julia> trim(m, 1:50)
Chain(
  Dense(784, 50, tanh),                 # 39_250 parameters
  Dense(50, 10),                        # 510 parameters
)                   # Total: 4 arrays, 39_760 parameters, 155.562 KiB.

julia> x = rand(Float32, 784, 32);

julia> g = gradient(params(m)) do
         m2 = trim(m, 1:3:512)  # keep every 3rd row of m[1].weight
         y = m2(x)
         sum(abs2, y)
       end
Grads(...)

julia> g[m[2].weight]
10×512 Matrix{Float32}:
 -32.1707   0.0  0.0  -5.32201   0.0  0.0  …  0.0  0.0  -0.0301256  0.0  0.0   7.36754  0.0
   5.38624  0.0  0.0   0.589724  0.0  0.0     0.0  0.0   0.63017    0.0  0.0  -1.70225  0.0
   9.69729  0.0  0.0   1.40005   0.0  0.0     0.0  0.0  -1.59133    0.0  0.0  -2.6516   0.0
...

julia> g[m[1].weight]
512×784 Matrix{Float32}:
 -0.13919    -0.133159   -0.125593    …  -0.0448132  -0.12223    -0.0714781  -0.124712
  0.0         0.0         0.0             0.0         0.0         0.0         0.0
  0.0         0.0         0.0             0.0         0.0         0.0         0.0
  0.379954    0.316367    0.297065        0.533378    0.312041    0.392742    0.608988
  0.0         0.0         0.0             0.0         0.0         0.0         0.0

I think Zygote.Buffer solves my issues, actually.
@mcabbott Yes, subsampling the subsequent layer to match the output of the first layer would result in more savings, but in this case I only want to subsample the columns of each parameter matrix, and since the last layer only has 10 columns, it doesn't save much there. I'm testing on just a 2-layer MLP for now, but once I get it working I will use more layers. For each layer's parameter matrix I will subsample a different set of columns (adaptively, based on the data); that's why I need to map the output vector back to the expected length.

So for example with a 3-layer MLP for MNIST:
Layer 1 parameter matrix W1: 784 x 512
Layer 2 W2: 512 x 250
Layer 3 W3: 250 x 10 (10 output classes, softmax)

For each input, I have an algorithm generate a set S of column indices, which are the best columns to sample from each matrix. So for input X, I decide to subsample e.g. 50 columns from W1, then run the smaller matmul using the 784x50 subset of W1, getting a length-50 output vector. The next layer expects a length-512 vector, so I generate a length-512 zero vector, and since each element of my 50-element vector is the inner product between the input and one of the subsampled columns, I map that value to the corresponding index in the zeroed 512-vector using the indices in S.

To be concrete, let's say instead of 50 subsampled columns it's only 3, and the sampled indices are S = [5, 90, 498], so my output vector is e.g. l2 = [0.4, -0.1, -0.9] and I need to map it back into a 512-dimensional vector. So I generate a vector z = zeros(512) and then set z[5] = l2[1], z[90] = l2[2], and z[498] = l2[3].
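With Zygote.Buffer, that mapping can be written in a differentiable way. A minimal sketch with those numbers (the helper name expand_to is just for illustration):

using Zygote

function expand_to(l2, S, n)
    z = Zygote.Buffer(l2, n)        # mutable while being differentiated
    for i in 1:n
        z[i] = 0                    # Buffer starts uninitialized, so zero it first
    end
    for (k, s) in enumerate(S)
        z[s] = l2[k]                # place each inner product at its sampled index
    end
    copy(z)                         # copy turns the Buffer back into an ordinary array
end

S  = [5, 90, 498]
l2 = [0.4, -0.1, -0.9]
z  = expand_to(l2, S, 512)                                   # sparse length-512 vector
g  = gradient(v -> sum(abs2, expand_to(v, S, 512)), l2)[1]   # == 2 .* l2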

Now I have a sparse 512-element output vector from layer 1, and I again use my algorithm to generate a set of columns S to subsample from W2 in layer 2, and do the same thing.
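Ignoring AD again, just to show the shapes, the whole chain looks something like this sketch, where choose_columns is a random placeholder for the actual adaptive selection algorithm:

using Random

# random placeholder for the adaptive column-selection step (not the real algorithm)
choose_columns(W, h; k=50) = sort!(randperm(size(W, 2))[1:min(k, size(W, 2))])

function sparse_forward(Ws, x)                  # Ws stored as (in × out), x a single input vector
    h = x
    for W in Ws
        S   = choose_columns(W, h)              # which columns of this layer to keep
        out = transpose(W[:, S]) * h            # small matmul over the kept columns only
        h   = zeros(eltype(W), size(W, 2))      # full-width zero vector for the next layer
        h[S] .= out                             # scatter the results back in
    end
    return h
end

Ws = [randn(Float32, 784, 512), randn(Float32, 512, 250), randn(Float32, 250, 10)]
sparse_forward(Ws, rand(Float32, 784))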

This is based on the SLIDE paper (arXiv:1903.03129, "SLIDE: In Defense of Smart Algorithms over Hardware Acceleration for Large-Scale Deep Learning Systems"), which shows that this kind of sparse approximate training can make training an MLP orders of magnitude faster, to the point that training on a CPU becomes comparable to conventional training on a GPU.

But why must you do this mapping back to full size at all?

I can predict exactly what a matmul with all those zeros will do, and never actually do it. This trimming step helps even in my 2-layer example, without changing the output at all.
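Concretely, with the 3-index example from above (the names here are just for illustration, with W2 stored as 512x250 as in your list):

W2 = randn(Float32, 512, 250)        # layer-2 parameters, 512 in × 250 out
S  = [5, 90, 498]                    # sampled columns of W1
y  = Float32[0.4, -0.1, -0.9]        # dense output of the trimmed layer 1

z = zeros(Float32, 512); z[S] .= y   # the zero-padded version

transpose(W2) * z ≈ transpose(W2[S, :]) * y   # true: rows of W2 that meet zeros never contribute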

If there are 3 layers, this step does not interfere with the choices you make for which rows to ignore in the 2nd layer (hence columns of the 3rd). That’s the next argument in mids... in my code.
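For instance, with a 3-layer chain of the sizes listed above and the trim function from before (a sketch, sizes only; the printed form will differ):

m3 = Chain(Dense(784, 512, tanh), Dense(512, 250, tanh), Dense(250, 10));

trim(m3, 1:50, 1:25)   # roughly Chain(Dense(784, 50, tanh), Dense(50, 25, tanh), Dense(25, 10))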

I think you still have rows & columns backwards here, for Flux’s conventions.


Oh, you're totally right, of course: in the next layer I can subsample both the columns and the rows. And yes, I may be mixing up columns and rows; I'm coming from Python / PyTorch. Thanks for the help!
