# Differentiating Jacobian-vector product for sliced score matching?

Hi, I’m trying to implement score matching in Julia. Essentially, one needs to differentiate through the expression:

\mathbb{E}_{v \sim \mathcal{N}(0,\mathbf{I})} \mathbb{E}_{x \sim p(x)} \left( v^{\top} \nabla_x f(x; \theta) v + \frac{1}{2} {\lVert v^{\top} f(x; \theta) \rVert}^2 \right)

or equivalently:

\mathbb{E}_{x \sim p(x)} \left( \mathrm{tr}\left( \nabla_x f(x; \theta)\right) + \frac{1}{2} {\lVert f(x; \theta) \rVert}^2 \right)
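As a quick sketch of why the two forms agree, assuming v has zero mean and identity covariance (so \mathbb{E}_v[v v^{\top}] = \mathbf{I}, which holds for a standard normal), and writing A = \nabla_x f(x; \theta) and f = f(x; \theta) as shorthand:

```latex
\mathbb{E}_v\left[ v^{\top} A v \right]
  = \mathbb{E}_v\left[ \mathrm{tr}\left( A v v^{\top} \right) \right]
  = \mathrm{tr}\left( A \, \mathbb{E}_v\left[ v v^{\top} \right] \right)
  = \mathrm{tr}(A)

\mathbb{E}_v\left[ {\lVert v^{\top} f \rVert}^2 \right]
  = \mathbb{E}_v\left[ f^{\top} v v^{\top} f \right]
  = f^{\top} \mathbb{E}_v\left[ v v^{\top} \right] f
  = {\lVert f \rVert}^2
```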

It appears that people compute \nabla_x f(x; \theta) v using forward-mode Jacobian-vector products and then differentiate the whole expression using reverse-mode AD.

Is there a way to do something similar in Julia?

If you’re using Zygote for reverse mode, you can tell it to differentiate a specific function with ForwardDiff: see the Utilities page of the Zygote docs.
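For reference, a minimal sketch of what I believe that utility looks like in use, assuming it is `Zygote.forwarddiff` (the toy function `y -> y^2` and the variable name `grad` are just for illustration). As far as I understand the docs, gradients for anything the inner function closes over are dropped, which matters if the network parameters are captured by a closure.

```julia
using Zygote

# Zygote runs reverse mode on the outer function, but is told to
# differentiate the inner block with forward mode (via ForwardDiff).
grad = Zygote.gradient(5.0) do x
    Zygote.forwarddiff(x) do y
        y^2
    end
end
```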

Not sure if it would work for your problem but you could check out FastDifferentiation.jl. It’s a package I’m developing which is in the process of being registered.


Hi all,

Thanks for the replies. I’m expecting to use Flux, so I think I’m tied to Zygote and co. for now. But it seems that Jacobians currently don’t play very nicely with Zygote. For example, the following examples don’t work:

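using Zygote
using Flux
using LinearAlgebra  # provides tr

function main()
    model = Chain(
        Dense(2  => 20, tanh),
        Dense(20 => 2))

    θ, restructure = Flux.destructure(model)

    x = randn(Float32, 2)
    # Outer reverse-mode gradient with respect to the flattened parameters.
    gradient(θ) do θ′
        model′ = restructure(θ′)
        # Inner forward-mode Jacobian with respect to the input.
        _, J = Zygote.forward_jacobian(x) do x′
            model′(x′)
        end
        tr(J) + sum(J.^2)/2
    end
end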

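using Zygote
using Flux
using ForwardDiff
using LinearAlgebra  # provides tr

function main()
    model = Chain(
        Dense(2  => 20, tanh),
        Dense(20 => 2))

    θ, restructure = Flux.destructure(model)

    x = randn(Float32, 2)
    # Outer reverse-mode gradient with respect to the flattened parameters.
    gradient(θ) do θ′
        model′ = restructure(θ′)
        # Inner Jacobian with respect to the input, computed by ForwardDiff.
        J = ForwardDiff.jacobian(x) do x′
            model′(x′)
        end
        tr(J) + sum(J.^2)/2
    end
end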


I wonder if anybody has successfully used nested differentiation with Jacobians?

What fails in the examples?

For Zygote.forward_jacobian:

ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float32}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
 error(s::String)
@ Base ./error.jl:35
 _throw_mutation_error(f::Function, args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:88
 (::Zygote.var"#551#552"{Matrix{Float32}})(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:100
 (::Zygote.var"#2643#back#553"{Zygote.var"#551#552"{Matrix{Float32}}})(Δ::Nothing)
 Pullback
@ ~/.julia/packages/Zygote/HTsWj/src/lib/forward.jl:31 [inlined]
 (::Zygote.Pullback{Tuple{typeof(Zygote.forward_jacobian), var"#156#158"{Vector{Float32}, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}}, Vector{Float32}, Val{2}}, Any})(Δ::Tuple{Nothing, Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/HTsWj/src/compiler/interface2.jl:0
 Pullback
@ ~/.julia/packages/Zygote/HTsWj/src/lib/forward.jl:44 [inlined]
 Pullback
@ ~/.julia/packages/Zygote/HTsWj/src/lib/forward.jl:42 [inlined]
 Pullback


For ForwardDiff:

┌ Warning: ForwardDiff.jacobian(f, x) within Zygote cannot track gradients with respect to f,
│ and f appears to be a closure, or a struct with fields (according to issingletontype(typeof(f))).
│ typeof(f) = var"#168#170"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}
└ @ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/forward.jl:150
(nothing,)


The ForwardDiff case seems to be a known issue, as the warning message makes clear.

I’m not sure but the first error looks like something you could overcome by using Lux instead, where parameters are never implicitly mutated?

Interesting. I’ll give it a shot.

It’s either that or an internal limitation of Zygote, in which case God have mercy on your soul. Every time there’s an issue with higher-order derivatives, my policy is to run for the hills.

Unrelated, but sum(abs2, J) / 2 will be faster than sum(J.^2)/2 since it avoids allocating the temporary array J.^2. In this case it’s probably peanuts compared to the (nested) autodiff call, but in other situations it might help.
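A tiny sketch of the two spellings on a made-up matrix (the values of `J` and the names `a` and `b` are arbitrary), just to show they compute the same quantity for real-valued input:

```julia
# Arbitrary example matrix standing in for the Jacobian.
J = [1.0 -2.0; 0.5 3.0]

a = sum(J .^ 2) / 2    # materializes the temporary array J .^ 2, then sums it
b = sum(abs2, J) / 2   # applies abs2 to each element inside the reduction
```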


Oops, seems like it’s the worst-case scenario!

using Zygote
using Lux
using ForwardDiff
using Random
using LinearAlgebra

function main()
    rng   = Random.default_rng()
    model = Chain(
        Dense(2, 20, tanh),
        Dense(20, 2))

    θ, st = Lux.setup(rng, model)

    x = randn(rng, Float32, 2)
    # Outer reverse-mode gradient with respect to the explicit parameters θ.
    gradient(θ) do θ′
        # Inner forward-mode Jacobian with respect to the input.
        _, J = Zygote.forward_jacobian(x) do x′
            y, _ = Lux.apply(model, x′, θ′, st)
            y
        end
        tr(J) + sum(J.^2)/2
    end
end


output:

ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float32}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
 error(s::String)
@ Base ./error.jl:35
 _throw_mutation_error(f::Function, args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:88
 (::Zygote.var"#551#552"{Matrix{Float32}})(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/HTsWj/src/lib/array.jl:100
 (::Zygote.var"#2643#back#553"{Zygote.var"#551#552"{Matrix{Float32}}})(Δ::Nothing)
 Pullback
@ ~/.julia/packages/Zygote/HTsWj/src/lib/forward.jl:31 [inlined]


You need a non-mutating implementation of the Jacobian. Zygote’s Jacobian implementation (and I think ForwardDiff’s may do the same) creates a matrix and then populates it in place. Use a mapreduce + hcat/vcat implementation instead, which is non-mutating. AbstractDifferentiation.jacobian with a Zygote backend, for example, is non-mutating, so it should be more amenable to higher-order AD.
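To make the map + hcat idea concrete, here is a rough sketch that builds the Jacobian one column at a time without writing into a preallocated matrix. It uses ForwardDiff directional derivatives for the columns; the helper name `jacobian_by_columns` and the test function `f` are made up for illustration, and I have not checked whether Zygote can actually differentiate through this particular version.

```julia
using ForwardDiff

# Column i of the Jacobian is the directional derivative of f along the
# i-th basis vector. The columns are joined with hcat instead of being
# written into a preallocated matrix.
function jacobian_by_columns(f, x)
    n = length(x)
    T = eltype(x)
    basis(i) = [j == i ? one(T) : zero(T) for j in 1:n]
    cols = map(1:n) do i
        ForwardDiff.derivative(α -> f(x + α * basis(i)), zero(T))
    end
    return reduce(hcat, cols)
end

# Hypothetical test function R^2 -> R^2.
f(x) = [x[1]^2 + x[2], sin(x[2])]
J = jacobian_by_columns(f, [1.0, 2.0])
```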


Typo: in the equivalent expression, there is no E_{v∼N(0,I)}.

If your x is large (seems so) and you only want a stochastic estimate of the gradient, you are probably better off using the first form with a trace estimator. This will be cheaper than computing the entire n \times n Jacobian and then taking its trace. If you use Rademacher-distributed vectors v instead of Gaussian ones, you will also get a lower variance in the trace estimator (ref). Rademacher v also has mean 0 and standard deviation 1, so the expectation of your second term should be unaffected.
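To illustrate the trace-estimator point, here is a small sketch on a fixed matrix `A` standing in for the Jacobian. The matrix, the helper names `quadform`, `rademacher`, and `trace_estimate`, and the sample count are all made up for illustration; in practice v^\top A v would come from a Jacobian-vector product rather than an explicit `A`.

```julia
using LinearAlgebra
using Random

# One-sample estimate of tr(A) from a probe vector v.
quadform(A, v) = dot(v, A * v)

# Rademacher probe: each entry is -1 or +1 with equal probability.
rademacher(rng, n) = rand(rng, [-1.0, 1.0], n)

# Monte Carlo average of v' * A * v over k probes drawn by `sampler`.
trace_estimate(A, sampler, k) = sum(quadform(A, sampler()) for _ in 1:k) / k

rng = MersenneTwister(1)
A = [2.0 0.5; -0.3 1.0]   # arbitrary example with tr(A) = 3
est_rademacher = trace_estimate(A, () -> rademacher(rng, 2), 1_000)
est_gaussian   = trace_estimate(A, () -> randn(rng, 2), 1_000)
```

For Rademacher v the diagonal terms contribute exactly tr(A) in every sample, since each v_i^2 = 1, so only the off-diagonal entries of A add noise.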


Hi,

Thanks, that’s very useful! However, it seems that not all combinations of forward mode and reverse mode work:

using Zygote
using Flux

tr(J) + sum(abs2, model(x))/2
end

function main()

    model = Chain(
        Dense(2, 20, softplus),
        Dense(20, 2))

    x = randn(2)
    θ, restructure = Flux.destructure(model)

end


This works. But (forward, zygote) doesn’t work, and fails with the same mutation error. Is this expected behavior?

Don’t use HigherOrderBackend for this, just pass in the backends you want.

function main()
    model = Chain(
        Dense(2, 20, softplus),
        Dense(20, 2),
    )
    x = randn(2)
    θ, restructure = Flux.destructure(model)
end


On my machine, FD over Zygote and FD over FD work.

I see. Excuse me if I am wrong about this (I don’t know much about AD in general), but isn’t Zygote over FD the combination that is supposed to be faster here?

I think so, but time it. Zygote over Zygote and Zygote over FD fail because of two different Zygote bugs. ReverseDiff over FD works, though.

using ReverseDiff

function main()
    model = Chain(
        Dense(2, 20, softplus),
        Dense(20, 2),
    )
    x = randn(2)
    θ, restructure = Flux.destructure(model)
end


That said, I think you will get more of a speedup from not computing the entire Jacobian than from anything you gain by switching between AD modes. A full dense Jacobian is expensive to compute no matter how you choose to do so. RD and FD should be asymptotically comparable when computing a full n \times n Jacobian, but FD will probably be faster for small n.


Hello,

The term v^\top \nabla_x f(x) v can be rewritten as a directional derivative \frac{\mathrm{d}}{\mathrm{d} \alpha}(v^\top f(x+\alpha v))\Big|_{\alpha = 0}: we only need the derivative of the scalar function \alpha \mapsto v^\top f(x + \alpha v) for \alpha = 0.
I don’t know how to implement it properly with Zygote though.
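For what it’s worth, here is a rough sketch of that identity using ForwardDiff rather than Zygote. The weights `W`, the toy function `f`, and the vectors `x` and `v` are hypothetical stand-ins for the real score network and data, and I have not tested whether this nests inside a Zygote gradient with respect to θ.

```julia
using ForwardDiff
using LinearAlgebra

# Hypothetical stand-in for the score network f(x; θ) with fixed weights.
W = [1.0 2.0; -0.5 0.3]
f(x) = tanh.(W * x)

x = [0.2, -0.1]
v = [1.0, -1.0]

# Scalar function of α; its derivative at α = 0 equals v' * ∇ₓf(x) * v.
g(α) = dot(v, f(x + α * v))
vJv_directional = ForwardDiff.derivative(g, 0.0)

# Reference value from the full Jacobian (only sensible for small n).
J = ForwardDiff.jacobian(f, x)
vJv_full = dot(v, J * v)
```

The directional version costs a single forward-mode pass regardless of the dimension of x, whereas the reference computation builds all n columns of J.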
