Is there a way to teach Zygote to derive Diagonal * Vector more efficiently?

I’m still getting my head around reverse-mode diff, but is there a way to teach Zygote to do the derivative of Diagonal * Vector without allocating N^2 memory?

using LinearAlgebra, BenchmarkTools, Zygote

v = rand(4096)
D = Diagonal(v)

@btime gradient(α -> norm((α * D) * v), 1)
# 53.308 ms (32915 allocations: 129.41 MiB)

The 129.41 MiB is basically the size of v*v' (4096^2 Float64 values, i.e. 128 MiB), which appears to be computed as a dense matrix in one of the adjoints. However, if I rewrite the exact same operation slightly differently, I get:

@btime gradient(α -> norm((α * D).diag .* v), 1)
# 871.463 μs (32919 allocations: 1.41 MiB)
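
For context, here is a sketch of why a dense matrix can show up in the first version and not the second. It assumes the product ends up in the generic matrix-vector rule; the thread does not confirm which adjoint is actually responsible. Writing z = A y with incoming cotangent z̄ (here A = α * D and y = v), the generic rule and its restriction to A = Diagonal(d) are:

```latex
% Generic rule: \bar{A} is a dense N x N outer product.
\bar{A} = \bar{z}\, y^{H}, \qquad \bar{y} = A^{H} \bar{z}

% Diagonal case, z_i = d_i y_i: only N entries are needed.
\bar{d}_i = \bar{z}_i\, \overline{y_i}, \qquad \bar{y}_i = \overline{d_i}\, \bar{z}_i
```

The diagonal case is the same as the rule for an elementwise product d .* y, which would be consistent with the .diag .* v version staying at O(N) memory.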

So in theory it appears possible. I tried adding something inspired by the adjoint rule for Vector .* Vector,

@adjoint *(x::Diagonal, y::Vector) = x.diag .* y,
  z̄ -> (unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x)))

but this does not work (yields wrong answer, and memory consumption is the same).

Does anyone have suggestions on whether this is possible (it seems it must be?), and if so, how to do it? Many thanks.

That’s probably because x is a Diagonal (an N×N matrix), so z̄ .* conj.(x) broadcasts to a dense matrix? Using x.diag instead seems to work:

@adjoint *(x::Diagonal, y::Vector) = x.diag .* y,
    z̄ -> (Diagonal(unbroadcast(x.diag, z̄ .* conj.(y))), unbroadcast(y, z̄ .* conj.(x.diag)))
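
A quick way to sanity-check the two cotangents in a rule like this, without going through Zygote at all, is to compare them against the dense formulas on a small problem. The following is only a sketch: diag_mul_pullback is a hypothetical helper written for this check (it is not a Zygote function), and the dense reference assumes the generic matrix-vector rule Ā = z̄ * y', ȳ = A' * z̄.

```julia
using LinearAlgebra

# Hypothetical helper (not part of Zygote): pullback of z = Diagonal(d) * y.
# Returns the cotangents for the diagonal d and for y using only length-N vectors.
diag_mul_pullback(d, y, z̄) = (z̄ .* conj.(y), z̄ .* conj.(d))

N = 8
d = rand(ComplexF64, N)
y = rand(ComplexF64, N)
z̄ = rand(ComplexF64, N)

d̄, ȳ = diag_mul_pullback(d, y, z̄)

# Dense reference (assumed generic rule): N×N outer product and adjoint product.
Ā_dense = z̄ * y'
ȳ_dense = Diagonal(d)' * z̄

@assert d̄ ≈ diag(Ā_dense)   # only the diagonal of the dense cotangent is needed
@assert ȳ ≈ ȳ_dense
```

Complex inputs are used so that a missing conj would make the assertions fail. The rule above wraps the first cotangent in Diagonal(...) so that it has the same structure as x; the check here only compares the underlying vectors.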

Ah, I had messed it up a bit. Your solution makes sense, thanks!

Please consider contributing this to Zygote.jl.