Zygote/Flux not computing gradient with view/slice of tensor properly for MLP

I am working on an implementation of an ML algorithm called SLIDE (sub-linear deep learning engine) from an academic paper. I’ve included a link to a gist of the code: https://gist.github.com/outlace/c81a15ab77c1c9cebb56003545ee5512

Basically, the algorithm tries to make deep learning scale sub-linearly with respect to the size of each layer's parameter matrix. Normally, multiplying an n x n parameter matrix by an input vector has a time complexity of O(n^2). SLIDE instead proposes a node sub-sampling scheme: given an n x m matrix, where each row (or column, depending on how you set up the matmul) is a ‘node’ in the neural network, you sub-sample a small number s of nodes (rows) and then do the matrix multiplication with this much smaller s x m matrix, where s << n. If you increase the size of the matrix, the number of sub-sampled rows s grows sub-linearly, so your overall processing time grows sub-linearly as you scale up the matrix.
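Here is a minimal plain-Julia sketch of the sub-sampled matmul. The sizes are arbitrary, and the rows are picked at random purely for illustration; SLIDE itself chooses them by hashing. Multiplying only the sampled rows reproduces the corresponding entries of the full product, up to floating-point rounding. It costs s·m multiply-adds instead of n·m.

```julia
using Random

# Illustrative sizes: n nodes (rows), m inputs, s sampled nodes with s << n
n, m, s = 5_000, 784, 50
W = randn(n, m)          # one row per node
x = randn(m)             # input vector

S = randperm(n)[1:s]     # sampled node indices (random here, for illustration only)

y_full = W * x                   # n * m multiply-adds
y_sub  = view(W, S, :) * x       # s * m multiply-adds, and no copy of W

# The sampled product matches the selected entries of the full product
println("full cost ≈ ", n * m, " multiply-adds; sampled cost ≈ ", s * m)
```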

Ok, so to implement this, I generate a list of indices of the nodes (rows) I want to use. I then do the forward pass of my simple 2-layer fully connected neural network with these smaller matrices, taken as a view/slice of the original parameter matrices during training. But when Zygote/Flux does the backward pass to compute gradients, it appears to use the full n x m matrix rather than my view/slice. The wall time grows linearly as I increase the n dimension, whereas it should stay roughly constant, since I only sample about 50 nodes regardless of the size of the original parameter matrix.

Is there any way to get Zygote/Flux to do the backward pass using a view/slice so that I can reap the performance benefit? Or perhaps there’s an issue with my implementation. Any help is greatly appreciated.


Here’s a minimal working example that illustrates my problem:

using Zygote, BenchmarkTools

function test1(x, W)
    return sum(x * W)
end

function test2(x, W, S)
    return sum(x * W[:, S])   # W[:, S] copies only the sampled columns
end

dim1, dim2 = 784, 3000 # vary dim2 from 1000 to 2000 to 3000 to estimate time scaling
x = randn(1, dim1);
W1 = randn(dim1, dim2);
W2 = randn(dim1, dim2);
S = rand(1:dim2, 50); # pick 50 random indices to subsample the parameter matrix

@btime test1($x, $W1) # scales approximately linearly with increasing dim2
@btime test2($x, $W2, $S) # runs in constant time with increasing dim2
@btime Zygote.gradient(w -> test1($x, w), $W1) # 1.554 ms -2x-> 4.073 ms -3x-> 7.634 ms  :scales approx linearly
@btime Zygote.gradient(w -> test2($x, w, $S), $W2) # 321.709 μs -2x-> 488.125 μs -3x-> 810.333 μs  :scales approx linearly

So the gradient for test2, which uses a view/slice of the original parameter matrix, does run much faster than the full-matrix version. However, it still scales approximately linearly with dim2. I expected it to take approximately constant time, since the forward pass is constant time (forward and backward propagation should have the same time complexity). What’s going on?


It looks like it is a memory bottleneck. The forward pass for test2 runs in constant time with constant memory allocation, whereas Zygote's backward pass for test2 allocates memory that scales linearly with dim2.

This is not just a Zygote issue: I tried the same thing in PyTorch, and it performs worse both in absolute terms and in how it scales.
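The following is a hand-written sketch of what a reverse-mode rule for `W[:, S]` generally has to do. It is an illustration, not Zygote's actual source, and `getindex_pullback` is a hypothetical name. The gradient with respect to `W` has the same shape as `W`, so it is allocated at full dim1 × dim2 size even though only the columns in `S` receive nonzero values. That would explain memory that grows with dim2.

```julia
# Hypothetical pullback for Y = W[:, S]: scatter the upstream gradient ΔY back into
# an array shaped like W. zeros(...) allocates all dim1 × dim2 entries.
function getindex_pullback(W::AbstractMatrix, S::AbstractVector{<:Integer}, ΔY::AbstractMatrix)
    ΔW = zeros(eltype(ΔY), size(W))        # O(dim1 * dim2) memory, whatever length(S) is
    for (j, col) in enumerate(S)
        @views ΔW[:, col] .+= ΔY[:, j]     # accumulate, so repeated indices add up
    end
    return ΔW
end

dim1, dim2 = 784, 3000
W = randn(dim1, dim2)
S = rand(1:dim2, 50)
x = randn(1, dim1)

# For f(W) = sum(x * W[:, S]), the upstream gradient w.r.t. Y = W[:, S] is x' * ones(1, 50)
ΔY = x' * ones(1, length(S))
ΔW = getindex_pullback(W, S, ΔY)

println("gradient entries allocated: ", length(ΔW),
        ", columns actually touched: ", length(unique(S)))
```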


Zygote seems to materialize a dense gradient for W, so the best way to avoid the linear growth here is to pass only the region of interest to gradient. For example:

using Flux
using BenchmarkTools
using Random

function grad_update_full!(W, x, S, opt)
    # Differentiate w.r.t. the whole W: the gradient is a dense array the size of W
    ∂f∂W, = gradient(w -> sum(view(w, S, :) * x), W)
    # @show ∂f∂W
    Flux.Optimise.update!(opt, W, ∂f∂W)
    return W
end

function grad_update_partial!(W, x, S, opt)
    # Differentiate w.r.t. the view only: the gradient is length(S) × size(W, 2)
    Wₛ = view(W, S, :)
    ∂f∂Wₛ, = gradient(w -> sum(w * x), Wₛ)
    # @show ∂f∂Wₛ
    Flux.Optimise.update!(opt, Wₛ, ∂f∂Wₛ)   # writes through the view into W
    return W
end

## Sanity check
let indim = 5, outdim = 10, sample_ncol = 3
    W = randn(outdim, indim)
    x = randn(indim, 1)
    S = randperm(outdim)[1:sample_ncol]   # unique indices; duplicates would make the two updates differ
    opt = Descent(0.1)

    w1 = grad_update_full!(copy(W), x, S, opt)
    w2 = grad_update_partial!(copy(W), x, S, opt)
    @assert w1 == w2
end

## Benchmark comparison
let indim = 512, sample_ncol = 64
    x = randn(indim, 1)
    opt = Descent(0.1)

    for i in 8:16
        outdim = 2^i
        W = randn(outdim, indim)
        S = rand(1:outdim, sample_ncol)

        @show outdim
        @btime grad_update_full!(w, $x, $S, $opt) setup=(w = copy($W))
        @btime grad_update_partial!(w, $x, $S, $opt) setup=(w = copy($W))
    end
end

Which results in:

outdim = 256
  466.603 μs (27 allocations: 1.26 MiB)
  246.108 μs (30 allocations: 518.19 KiB)

outdim = 512
  848.467 μs (27 allocations: 2.26 MiB)
  298.342 μs (30 allocations: 518.19 KiB)

outdim = 1024
  1.487 ms (27 allocations: 4.26 MiB)
  370.419 μs (30 allocations: 518.19 KiB)

outdim = 2048
  2.803 ms (27 allocations: 8.26 MiB)
  554.691 μs (30 allocations: 518.19 KiB)

outdim = 4096
  5.356 ms (27 allocations: 16.26 MiB)
  696.186 μs (30 allocations: 518.19 KiB)

outdim = 8192
  15.307 ms (27 allocations: 32.26 MiB)
  752.438 μs (30 allocations: 518.19 KiB)

outdim = 16384
  30.165 ms (27 allocations: 64.26 MiB)
  799.389 μs (30 allocations: 518.19 KiB)

outdim = 32768
  59.796 ms (27 allocations: 128.26 MiB)
  854.441 μs (30 allocations: 518.19 KiB)

outdim = 65536
  119.296 ms (27 allocations: 256.26 MiB)
  850.507 μs (30 allocations: 518.19 KiB)

While runtime and memory usage increase linearly with the full gradient, taking the partial gradient (whether with a view or a materialized slice) does not.
There is some runtime overhead, but that’s more likely due to update! being O(n) as well as system-related factors like cache.
Hope this helps!


Very neat, thanks a lot!

I can replicate this by passing a view directly to gradient, but for some reason it doesn’t seem to work when I use the Flux.params(...) approach. Do you know how to get it to work with Params?

The model you’re calling params on needs to have views as its parameters instead of the full arrays, which likely means you’ll have to reconstruct the model before every call to gradient. Alternatively, you can define your own mutable (callable) layers and swap out the parameters every iteration, e.g. m = MyDense(view(X, S1)); ... m.weight = view(X, S2).
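Here is a minimal plain-Julia sketch of the mutable-layer idea. `MyDense` and its `weight` field are hypothetical, and the Flux/Zygote/Params part is left out. The layer holds a view that is rebound before each step. Because a view aliases its parent, any in-place update made through the view lands in the full matrix.

```julia
using Random

# Hypothetical mutable, callable layer whose weight is a view into a larger parent matrix
mutable struct MyDense{M<:AbstractMatrix}
    weight::M
end

(m::MyDense)(x) = m.weight * x   # forward pass uses only the currently selected rows

X = randn(1000, 20)              # full parameter matrix (1000 nodes, 20 inputs)
x = randn(20)

S1 = randperm(1000)[1:8]
m = MyDense(view(X, S1, :))
y1 = m(x)                        # uses rows S1 only

# Swap in a different subset of rows before the next step (same view type, so this is allowed)
S2 = randperm(1000)[1:8]
m.weight = view(X, S2, :)
X_before = copy(X)
y2 = m(x)

# An in-place update through the view writes into the parent X (rows S2 only)
m.weight .-= 0.1 .* (ones(8) * x')
```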

I think what I’m trying to do might be impossible in Zygote then, because I want to extend that minimal example to something like this:

using Flux  # for relu
function test3(x, W1, W2, S)
    layer1 = x * view(W1, :, S) .|> relu   # 1 × length(S) activations
    S2 = rand(1:size(W2)[2], 50)
    layer2 = layer1 * view(W2, S, S2) |> sum
    return layer2
end

I generate the first index set S based on the input x, which I can do outside the model call, so the solution you gave works. For the second matrix, I want to do the same thing and generate a different list of indices to take a view with. But S2 is generated based on the output of the first layer, so I cannot pass in both views of the parameters without running the model. I guess I could run the model forward, save S and S2, and then use those to pass views for the backward pass. But then I’d be doing 2 forward passes and 1 backward pass instead of 1 forward and 1 backward. That might still be worth it compared to Zygote allocating a full dense matrix on every backward pass.

You can still create S2 and view(W2,S,S2) beforehand, because they only depend on the size of the output. At least I assume that’s the intent. rand(1:size(W2)[2],50) is even easier to compute beforehand, since it only depends on the size of W2 and not on the output at all.

The only scenario I can think of where you’d be forced to create the view during the forward pass is if the indices depend on the values in the previous layer's output. Since it seems you only look at the size and not the values, the most you’d have to do is calculate sizes up-front. You might only need to do that once, if the number of rows/columns sampled at each layer is constant. The Utility Functions page of the Flux documentation can help with this calculation for complex layers, but based on your previous gist even that may be overkill.

Yes, unfortunately my real algorithm needs to generate indices based on the values of the previous layer's output. I made up the test1 and test2 functions to be the simplest analogy to my algorithm, but I see I didn’t fully succeed. In my gist above I use the function sample_nodes to generate the list of indices. It takes the input vector (either the input data or the output vector of the previous layer) and runs it through a randomized hash function. It then uses the hash to look up a set of indices stored in a hash table. For a given input vector to a layer, sample_nodes finds the set of nodes (i.e. rows in the parameter matrix) that have a large inner product with the input vector and then sub-samples those rows, since nodes with small inner products don’t contribute much to the output. That way you get sparse activations and sparse updates, and the neural network runs in sub-linear time. It has been shown that this doesn’t sacrifice accuracy, at least for fully connected nets.
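To make the hashing step concrete, here is a minimal sketch of a single SimHash (signed random projection) table. SimHash is a standard LSH family for angular similarity, and SLIDE supports several hash families. This is not the sample_nodes function from the gist. `SimHashTable`, `simhash`, `build_table` and `query` are hypothetical names, and the real algorithm uses several tables plus its own bucket-sampling policy. Rows that land in the input's bucket tend to point in a similar direction to it, which serves as the proxy for "large inner product".

```julia
using LinearAlgebra, Random

# Hypothetical single-table SimHash index (K ≤ 64 hyperplanes → K-bit codes)
struct SimHashTable
    planes::Matrix{Float64}                 # K × m random hyperplanes
    buckets::Dict{UInt64, Vector{Int}}      # code => node (row) indices
end

# Bit k of the code is set when v lies on the positive side of hyperplane k
function simhash(planes::AbstractMatrix, v::AbstractVector)
    code = UInt64(0)
    for k in 1:size(planes, 1)
        if dot(view(planes, k, :), v) > 0
            code |= UInt64(1) << (k - 1)
        end
    end
    return code
end

# Hash every row (node) of W into its bucket
function build_table(W::AbstractMatrix, K::Int; rng = Random.default_rng())
    planes = randn(rng, K, size(W, 2))
    buckets = Dict{UInt64, Vector{Int}}()
    for i in 1:size(W, 1)
        push!(get!(() -> Int[], buckets, simhash(planes, view(W, i, :))), i)
    end
    return SimHashTable(planes, buckets)
end

# Candidate nodes for input x: all rows that share x's bucket
query(t::SimHashTable, x::AbstractVector) = get(t.buckets, simhash(t.planes, x), Int[])

W = randn(5000, 128)
table = build_table(W, 8)
S = query(table, view(W, 42, :))   # querying with row 42 itself must return a bucket containing 42
```

In a full SLIDE-style layer, S would then be sub-sampled further and used as the row index set for the forward and backward pass.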

Makes sense, I wasn’t sure if you were implementing the full SLIDE algorithm or some adaptation of it.
In general, I wouldn’t expect an LSH-style scheme like the one SLIDE uses to play very well with array-centric AD like Zygote. You could try writing custom rules, but there’s no guarantee that would solve the materialization issue.

A more straightforward, if more involved, route would be to write the gradient calculation and backprop routines by hand. This is how the original implementation does it, and the code doesn’t look too bad, since the layer computation outside of LSH is pretty simple. Ditching AD would also let you turn up the optimization dial. The SLIDE papers make liberal use of threading, SIMD and other techniques, all of which should be possible to replicate in Julia. See SimpleChains.jl (JuliaSIMD/SimpleChains.jl on GitHub) for an example of how much of a difference clever use of SIMD alone can make for MLPs.
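Here is a minimal sketch of a hand-written forward/backward pass for one sampled dense layer, in plain Julia. It is not the reference implementation's code. `sampled_forward` and `sampled_backward!` are hypothetical names, and plain SGD stands in for whatever optimizer you use. The gradient only ever exists for the s sampled rows, so nothing of size n × m is allocated. The upstream gradient `δ = ones(s)` corresponds to the `sum(...)` loss used in the examples above.

```julia
using Random

# Forward pass of one sampled dense layer: y = W[S, :] * x + b[S].
# Only the rows in S are read, so the cost is O(length(S) * size(W, 2)).
sampled_forward(W, b, x, S) = view(W, S, :) * x .+ view(b, S)

# Hand-written backward pass. δ = ∂L/∂y (length(S)). Applies plain SGD with step η
# to the sampled rows in place and returns ∂L/∂x for the previous layer.
# S is assumed to contain no duplicate indices.
function sampled_backward!(W, b, x, S, δ, η)
    WS = view(W, S, :)
    δx = WS' * δ                 # ∂L/∂x = W[S, :]' * δ, computed before the weights change
    WS .-= η .* (δ * x')         # ∂L/∂W[S, :] = δ * x' (outer product, s × m)
    bS = view(b, S)
    bS .-= η .* δ                # ∂L/∂b[S] = δ
    return δx
end

n, m, s = 10_000, 64, 32
W = randn(n, m)
b = zeros(n)
x = randn(m)
S = randperm(n)[1:s]             # unique indices

W0 = copy(W)
y = sampled_forward(W, b, x, S)
δ = ones(s)                      # upstream gradient for the loss L = sum(y)
δx = sampled_backward!(W, b, x, S, δ, 0.1)
```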

Thanks for all the help, I really appreciate it! For now I am just trying to get a prototype working. I care about the time complexity being sub-linear, but I don’t need to spend time optimizing away the constant overhead right now. I will see whether some of the other AD packages happen to manage memory better. Otherwise I will just use the 3-pass approach: run the model forward to get the sampled nodes, then use them to pass in views of the parameter matrices as you showed above. It works, and the forward run is quite cheap, so it adds little overhead and still runs in near-constant time. Once I have a working prototype I can start optimizing. Part of that will probably be computing the gradients manually, which is fine since I don’t anticipate needing to change the architecture significantly once it works.

I figured out a solution, just in case you’re curious. Basically, I store the parameters and a view into the parameters separately. I then update the view inside Zygote.ignore, so the index sampling isn't recorded for gradients, and it works. Something like this:

[incomplete code]

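using Flux, Zygote

# Holds the full parameter matrix W and a view Wview that is rebound on each call
mutable struct Layer
    W::Matrix{Float64}
    Wview::AbstractMatrix{Float64}
end

function test4(x, ms)
    # ms is a vector of layers, x is some input (row) vector
    Zygote.ignore() do
        # index sampling is not differentiated; only the views are swapped
        S = rand(1:size(ms[1].W)[2], 50)
        ms[1].Wview = view(ms[1].W, :, S)
        ms[2].Wview = view(ms[2].W, S, :)
    end
    y = x * ms[1].Wview
    y = Flux.relu.(y)
    y = y * ms[2].Wview
    return y
end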


Hi @outlace, I’m currently working on my own implementation of SLIDE in Julia. Have you been able to finish your implementation? It might be interesting to compare notes :smiley:


Very cool! I’ll throw mine up as a GitHub gist when I get a chance. I have a working prototype and am currently trying to implement the follow-up paper, MONGOOSE. I tried to get SLIDE to train a simple 2-layer neural network on MNIST, and it works. It runs faster than normal dense training thanks to the subsampling of the matrix rows, but it takes more iterations to reach the same level of accuracy as a regular network, so it's kind of disappointing. MNIST is probably not a good use case for SLIDE, though. It probably only works well for inherently sparse data/layers, which is what the authors used it for: the massive classification task where the last layer is inherently sparse. How about you?


I haven’t gotten around to reading the MONGOOSE paper yet, but it sounds promising. Given your results on MNIST, though, I’m a little less excited about this whole approach, to be honest. I expected slightly worse results when SLIDE is applied to smaller networks, since SLIDE's advantage lies in how it scales with layer size. But the fact that it also seems to converge worse is slightly surprising, mostly because the convergence behaviour on the Delicious-200K and Amazon-670K datasets was almost identical for SLIDE and TensorFlow.

Anyway, my main motivation for implementing this is to try to extend the framework beyond fully connected layers. Layers like convolutional layers can be considered fully connected layers with weight tying and a subset of the weights permanently fixed at zero, so I figured this could be relatively easy. This is of course a very crude approach, so the implementation would only serve as a proof of concept. Have you maybe tried this out yourself?
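Here is a small plain-Julia check of that equivalence for the 1-D case. `conv_as_dense` and `conv1d_valid` are hypothetical helper names. Like most deep learning libraries, they compute a cross-correlation, here with no padding and stride 1. The convolution becomes a banded matrix in which every row repeats the same kernel weights (weight tying) and every entry outside the band is zero.

```julia
# Hypothetical helper: dense matrix equivalent of a 1-D "valid" cross-correlation
# with kernel k on inputs of length n. Every row reuses the same weights; the rest stay zero.
function conv_as_dense(k::AbstractVector, n::Int)
    K = length(k)
    out = n - K + 1
    A = zeros(eltype(k), out, n)
    for i in 1:out
        A[i, i:i+K-1] .= k       # row i sees inputs i .. i+K-1
    end
    return A
end

# Direct sliding-window computation for comparison
conv1d_valid(k, x) = [sum(k .* x[i:i+length(k)-1]) for i in 1:length(x)-length(k)+1]

k = [1.0, -2.0, 0.5]
x = collect(1.0:8.0)
A = conv_as_dense(k, length(x))   # 6 × 8 banded matrix with 18 nonzeros
```

In the SLIDE framing, sampling "nodes" of such a layer would mean sampling rows of A. Because the rows share weights, their gradients all accumulate into the same few kernel parameters.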


Well, there is still more fiddling to be done to maybe improve its performance. For example, I wonder if manually computing gradients would be faster, as Zygote has a hard time with the matrix subsampling. I am also primarily interested in implementing SLIDE for computer vision applications, so that I can do training and inference on a CPU. I was planning on implementing one of the new MLP-based vision networks like MLP-Mixer (https://papers.nips.cc/paper/2021/file/cba0a4ee5ccd02fda0fe3f9a3e7b89fe-Paper.pdf) using SLIDE to see if that would work.

I uploaded my experiments to this git repo: https://github.com/outlace/SLIDE-Pose-Estimation

So I was able to get Zygote to compute gradients, but Flux’s built-in optimizers don’t seem to work with my SLIDE implementation, so I had to write an ADAM optimizer from scratch; it’s in the optim.jl file.
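On the optimizer side, here is a minimal sketch of a row-sparse Adam step in plain Julia, using the standard Adam update with bias correction. It is not the optim.jl from the repo, and `SparseAdam` and `sparse_adam_step!` are hypothetical names. It is a "lazy" variant: only the sampled rows and their moment estimates are touched, so the moments of unsampled rows do not decay. That is one reasonable choice among several.

```julia
# Hypothetical row-sparse ("lazy") Adam: a step touches only the rows in S,
# so it costs O(length(S) * size(W, 2)).
mutable struct SparseAdam
    η::Float64
    β1::Float64
    β2::Float64
    ϵ::Float64
    t::Int                  # global step counter used for bias correction
    m::Matrix{Float64}      # first-moment estimates, same shape as W
    v::Matrix{Float64}      # second-moment estimates, same shape as W
end

SparseAdam(W::AbstractMatrix; η = 1e-3, β1 = 0.9, β2 = 0.999, ϵ = 1e-8) =
    SparseAdam(η, β1, β2, ϵ, 0, zeros(size(W)), zeros(size(W)))

# gS is the gradient for rows S only (length(S) × size(W, 2)); S must not contain duplicates.
function sparse_adam_step!(opt::SparseAdam, W::AbstractMatrix, S::AbstractVector{<:Integer}, gS::AbstractMatrix)
    opt.t += 1
    mS, vS, WS = view(opt.m, S, :), view(opt.v, S, :), view(W, S, :)
    @. mS = opt.β1 * mS + (1 - opt.β1) * gS
    @. vS = opt.β2 * vS + (1 - opt.β2) * gS^2
    c1 = 1 - opt.β1^opt.t   # bias-correction terms
    c2 = 1 - opt.β2^opt.t
    @. WS -= opt.η * (mS / c1) / (sqrt(vS / c2) + opt.ϵ)
    return W
end

W = randn(100, 4)
W0 = copy(W)
opt = SparseAdam(W; η = 0.01)
S = [3, 7, 42]
gS = [1.0 -2.0 0.5 3.0; -1.0 4.0 -0.25 2.0; 0.5 0.5 -0.5 -8.0]
sparse_adam_step!(opt, W, S, gS)   # first step moves each sampled entry by ≈ η * sign(g)
```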

I’m not sure whether circumventing Zygote would help much, but I have a feeling that it would boost performance noticeably. I read some of the documentation on creating custom automatic differentiation rules, and a basic setup doesn’t seem very complex. Of course, it would have to be merged with a SLIDE implementation, which is a lot more of a challenge :smiley:

I’d love to take a look at your implementation, but it seems like the repo you linked is set to private.

As for having to write custom optimizers, there might be some trickery involved in handling sparse gradients, but this shouldn’t be a problem, at least conceptually. When I go through your code, I’ll try to find out whether that could be avoided.


Whoops, you’re right, it was set to private. Fixed.

It most certainly would, though given the relative simplicity of the network in SLIDE it may be easier to do manual gradient calculations like the reference implementation does.