So in short, I want to train an MLP by sparsely updating its parameters. I do this on the forward pass by randomly subsampling the columns of my parameter matrices. E.g. for a simple 2-layer model on MNIST with input data of size (batch_size, 784), my layers are matrices of 784x512 and 512x10. But I will (sort of) randomly subsample a small proportion of the columns. For example, for the first layer computation, let’s say I subsample 50 columns: then my matrix multiply results in an output of size (batch_size, 50) instead of (batch_size, 512). However, in order to pass this to the next layer, I initialize a vector of zeros(512), set the 50 subsampled indices to match the output, and pass this to the next layer. So the goal is to only update those 50 subsampled columns and not the entire original matrix.
This results in massive savings in computation, since the matrix multiplications are much smaller. OK, so my question is: will Flux/Zygote also do this subsampling on the backward pass, and therefore save me the computation I saved on the forward pass?
Given that Zygote isn’t spectacular with lots of indexing and BLAS really wants contiguous arrays, my hunch would be that any performance gain will be dwarfed by increased allocation or hitting suboptimal code paths.
However, this is a question best answered by a benchmark: give both a try and see what happens! @showgrad and @code_adjoint will be useful for interrogating what the AD system is doing.
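Part of the answer can be seen without a profiler by looking at the gradient of the full matrix. A minimal check, assuming Zygote is installed and with S as a made-up set of sampled columns: the pullback of the small matmul works on the 784 x 3 slice, but the gradient returned for W is a dense array the size of the whole matrix, zero outside the sampled columns. So the backward pass still allocates and fills a full-size array per layer, which is the kind of overhead worth timing (e.g. with BenchmarkTools’ @btime) against the plain dense version.

```julia
using Zygote

# Sizes from the question (rows are samples); S is a hypothetical set of sampled columns.
x = randn(Float32, 32, 784)
W = randn(Float32, 784, 512)
S = [5, 90, 498]

# A loss that only ever reads the sampled columns of W.
loss(W) = sum(abs2, x * W[:, S])

gW = Zygote.gradient(loss, W)[1]

size(gW)                        # same as size(W): a dense 784 x 512 array
count(!iszero, eachcol(gW))     # only the sampled columns are nonzero
```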
I assumed the code was Zygote-friendly already, but if you still want to try, there are a few ways to get around Zygote’s array mutation limitation:
1. Use Zygote.Buffer. This allows direct mutation, but may be a little awkward depending on your use case.
2. Write a custom AD rule that handles the mutation. This is the most involved approach, but it gives full control over the operations performed and thus may deliver the best performance. See “Writing Good Rules” in the ChainRules documentation for more.
3. See if there are existing AD-compatible functions that might fit.
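For the Buffer route, a minimal sketch of the zero-padding step might look like this. pad_columns is a hypothetical helper in the question’s (batch, features) layout; it only uses Buffer construction, indexed writes and copy, the same pattern as the Buffer example in the Zygote docs. Since a Buffer starts out uninitialised (like similar), every column is written with zeros before the sampled ones are filled in.

```julia
using Zygote

# Hypothetical helper: place the columns of y (batch x k) at column positions S
# of a batch x n zero matrix, using Zygote.Buffer so the writes are differentiable.
function pad_columns(y, S, n)
    buf = Zygote.Buffer(y, size(y, 1), n)   # uninitialised storage, like similar(y, size(y, 1), n)
    z = zeros(eltype(y), size(y, 1))
    for j in 1:n
        buf[:, j] = z                       # clear every column first
    end
    for k in eachindex(S)
        buf[:, S[k]] = y[:, k]              # then overwrite the sampled columns
    end
    return copy(buf)                        # turn the Buffer into an ordinary array
end
```

Used as h = pad_columns(x * W1[:, S], S, 512) inside the loss, this keeps the forward pass differentiable; whether it is fast enough is still a question for a benchmark.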
Here, it seems option 3 might be the way to go. NNlib exposes a scatter function that basically does all of this in one line. Here’s a quick code snippet to demonstrate:
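# given size(matmul_output) == (50, N) and size(sampled_cols) == (50,), where N is the batch size
using NNlib
idx = CartesianIndex.(sampled_cols, (1:N)')   # (50, N) array of destination (row, column) indices
NNlib.scatter(+, matmul_output, idx, dstsize=(512, N))
For more on CartesianIndex (note that CartesianIndices only accepts ranges, not an arbitrary index vector like sampled_cols), see the REPL help or the manual.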
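A slightly fuller sketch of the scatter approach, in Flux’s (features, batch) layout with weights stored as (out, in); the sizes and S are made up, and it assumes NNlib and Zygote are installed. The index array is built outside the differentiated function, since it only carries integer positions, and the gradient check confirms that W still receives a full-size gradient that is zero outside the sampled rows.

```julia
using NNlib, Zygote

W = randn(Float32, 512, 784)    # (out, in)
x = randn(Float32, 784, 4)      # (features, batch)
S = [5, 90, 498]                # hypothetical sampled output units

# (length(S), batch) array of destination positions (unit, sample) in the 512 x batch output
idx = CartesianIndex.(S, (1:size(x, 2))')

# Compute only the sampled rows, then scatter them into a 512 x batch array
# (with op = +, the unfilled entries start at zero).
scatter_layer(W, x, S, idx) =
    NNlib.scatter(+, W[S, :] * x, idx; dstsize = (size(W, 1), size(x, 2)))

out = scatter_layer(W, x, S, idx)
gW = Zygote.gradient(W -> sum(scatter_layer(W, x, S, idx)), W)[1]
```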
This sounds like channel pruning. In Flux the dimensions are the other way around, so this is removing all but 50 of the rows of the weight matrix. My guess is that just slicing these matrices will work fine here.
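To make the layout point concrete, a small comparison assuming a recent Flux (0.13 or later, for the Dense(784 => 512) constructor) and a made-up S: keeping rows S of a Flux weight gives the same numbers as keeping columns S in the question’s transposed layout.

```julia
using Flux

d = Dense(784 => 512)               # Flux stores weight as (out, in) = 512 x 784
S = [5, 90, 498]                    # hypothetical sampled hidden units

x_flux = randn(Float32, 784, 8)     # Flux layout: (features, batch)
rows = d.weight[S, :] * x_flux      # keep rows S of the Flux weight -> (3, 8)

# The question's layout is the transpose: x is (batch, 784) and W is 784 x 512,
# so the same computation keeps columns S instead.
x_q = permutedims(x_flux)           # (8, 784)
W_q = permutedims(d.weight)         # 784 x 512
cols = x_q * W_q[:, S]              # (8, 3)
```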
If I understand right this is the step where you wanted to mutate an array, to get back to full size for the next layer. But doing so means you lose half the benefit, since the very next matrix multiplication then involves a lot of zeros. If you keep 10% of each layer like this, you save 90%, but if you trim the columns of the next layer too, you save (in a long chain) 99%.
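The arithmetic behind those percentages, counting multiply-adds for one layer with n_in inputs, n_out outputs, batch size B, and a fraction p of units kept in every layer (these symbols are introduced here for illustration):

```latex
\begin{aligned}
\text{dense:}\quad & n_{\text{in}}\, n_{\text{out}}\, B \\
\text{trim outputs only:}\quad & n_{\text{in}}\,(p\, n_{\text{out}})\, B = p \cdot n_{\text{in}}\, n_{\text{out}}\, B \\
\text{trim outputs and inputs:}\quad & (p\, n_{\text{in}})(p\, n_{\text{out}})\, B = p^{2} \cdot n_{\text{in}}\, n_{\text{out}}\, B
\end{aligned}
```

With p = 0.1 this gives 10% and 1% of the dense cost. The first layer’s input (the 784 pixels) is never trimmed, which is why the 99% figure only holds for layers in the middle of a long chain.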
Maybe someone has a nicer library for this, but my quick function would look something like this. It’s important that trim happens inside the gradient call, to get gradients for the original weights.
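A minimal sketch of that idea, assuming two weight matrices stored Flux-style as (out, in); trim here is a hypothetical helper written for illustration, not the function from the reply, and tanh stands in for whatever activation is used. Because the slicing happens inside the function passed to gradient, the gradients come back with the shape of the original weights, zero outside the kept units.

```julia
using Zygote

# Hypothetical helper: keep hidden units S between two layers (weights are (out, in)).
# Rows S of W1/b1 produce the kept units; columns S of W2 are the only inputs that see them.
trim(W1, b1, W2, S) = (W1[S, :], b1[S], W2[:, S])

W1, b1 = randn(Float32, 512, 784), zeros(Float32, 512)
W2, b2 = randn(Float32, 10, 512), zeros(Float32, 10)
x = randn(Float32, 784, 8)
S = [5, 90, 498]

function loss(W1, b1, W2, b2)
    w1, c1, w2 = trim(W1, b1, W2, S)   # trimming happens inside the differentiated function
    h = tanh.(w1 * x .+ c1)            # (3, 8) instead of (512, 8)
    return sum(abs2, w2 * h .+ b2)     # second matmul is 10 x 3 instead of 10 x 512
end

gW1, gb1, gW2, gb2 = Zygote.gradient(loss, W1, b1, W2, b2)
```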
I think Zygote.Buffer solves my issues, actually. @mcabbott Yes, subsampling the subsequent layer to match the output of the first layer would result in more savings, but in this case I only want to subsample the columns of each parameter matrix, and since the last layer only has 10 columns, it doesn’t save much. I’m testing on just a 2-layer MLP for now, but once I get it working I will use more layers. For each layer’s parameter matrix I will subsample a different set of columns (adaptively, based on the data); that’s why I need to map the output vector back to the expected length.
So for example with a 3-layer MLP for MNIST:
Layer 1 parameter matrix W1: 784 x 512
Layer 2 W2: 512 x 250
Layer 3 W3: 250 x 10 (10 output classes, softmax)
For each input, I have an algorithm generate a set S of column indices: the best columns to sample from each matrix. So for input X, I might decide to subsample e.g. 50 columns from W1, then run this smaller matmul using the 784x50 subset of W1, getting a length-50 output vector. The next layer expects a length-512 vector, so I generate a length-512 zero vector, and since each element of my 50-element vector represents the inner product between the input and one of the subsampled columns, I map that value to the corresponding index in my zeroed 512-vector using the indices in S.
To be concrete, let’s say that instead of 50 subsampled columns it’s only 3, and the sampled indices are S = [5, 90, 498], so my output vector is e.g. l2 = [0.4, -0.1, -0.9] and I need to map this back into a 512-dimensional vector. So I generate a vector z = zeros(512) and then set z[5] = l2[1], z[90] = l2[2] and z[498] = l2[3].
Now I have a sparse length-512 output vector from layer 1, and I again use my algorithm to generate a list of columns S to subsample from W2 in layer 2, and do the same thing.
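Putting the steps together, the per-layer procedure might be sketched like this for a single input vector in the question’s layout (W1 is 784 x 512, W2 is 512 x 250, W3 is 250 x 10). choose_columns is a hypothetical placeholder that samples random columns rather than the adaptive algorithm described, and activations are left out. Because z is filled by ordinary indexed assignment, this is only for checking the forward logic; inside a Zygote gradient, that write would have to go through Zygote.Buffer or NNlib.scatter instead.

```julia
using Random

# Hypothetical stand-in for the adaptive selection algorithm: k random, sorted column indices.
choose_columns(W, h, k) = sort(randperm(size(W, 2))[1:k])

# Question layout: each W maps a length-size(W, 1) vector to a length-size(W, 2) vector via W' * h.
# Activation functions are omitted for brevity.
function sparse_forward(x, Ws, ks)
    h = x
    for (W, k) in zip(Ws, ks)
        S = choose_columns(W, h, k)
        v = W[:, S]' * h                   # length-k vector: inner products with the sampled columns
        z = zeros(eltype(v), size(W, 2))   # zero vector of the length the next layer expects
        z[S] = v                           # for S = [5, 90, 498]: z[5] = v[1], z[90] = v[2], z[498] = v[3]
        h = z
    end
    return h
end

Ws = (randn(784, 512), randn(512, 250), randn(250, 10))
x = randn(784)
out = sparse_forward(x, Ws, (50, 25, 10))  # last layer keeps all 10 columns
```

Passing the full widths as ks, e.g. (512, 250, 10), selects every column, so the result should match the dense product, which makes a handy sanity check.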