# Help with Jacobian-vector product to get natural gradient

Hi, I am trying to reproduce an algorithm based on natural gradient computations (Natural Gradients in Practice, Salimbeni et al. 2018).
The key computation is that you can get rid of the inverse Fisher information matrix by replacing it with derivatives of the parameter transformations. Here is the relevant extract:

The whole discussion of reverse mode is not relevant here, as forward mode is available in Julia. However, I was wondering if there was a way to compute this quantity in one pass instead of first computing $\frac{\partial \xi}{\partial \theta}$ and then $\frac{\partial \mathcal{L}}{\partial \eta}$.
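
For reference, here is my reading of the identity being discussed, written with column-vector gradients. It is a sketch under assumptions that are not spelled out in the thread: $q$ is an exponential family with natural parameters $\theta$ and expectation parameters $\eta$, $F_\theta$ and $F_\xi$ denote the Fisher information in the respective parametrisations, and $J = \partial\theta/\partial\xi$ is invertible.

$$
\tilde{\nabla}_\xi \mathcal{L}
= F_\xi^{-1} \nabla_\xi \mathcal{L}
= \left(J^\top F_\theta J\right)^{-1} J^\top \nabla_\theta \mathcal{L}
= J^{-1} F_\theta^{-1} \nabla_\theta \mathcal{L}
= \frac{\partial \xi}{\partial \theta} \nabla_\eta \mathcal{L}
$$

The last step uses $F_\theta = \partial\eta/\partial\theta$ for exponential families, so that $\nabla_\theta \mathcal{L} = F_\theta \nabla_\eta \mathcal{L}$ by the chain rule. The right-hand side is a Jacobian-vector product, which is the operation forward mode computes in a single pass.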

For precision, here $\xi$ is an arbitrary representation of the variational parameters, $\theta$ are the natural parameters and $\eta$ are the expectation parameters (first and second moments).

> However I was wondering if there was a way to compute this quantity in one pass instead of computing first $\frac{\partial \xi}{\partial \theta}$ and then $\frac{\partial \mathcal{L}}{\partial \eta}$.

I’m a little confused by this statement, since at no point do you ever actually instantiate $\frac{\partial \xi}{\partial \theta}$. So could you elaborate a little on what you mean by this?

Well, if we take a concrete example where $\xi = (\mu, L)$ and $q = \mathcal{N}(\mu, LL^\top)$, you still need to compute $\frac{\partial (\mu, L)}{\partial \theta}$, right?

I actually wrote a basic script to make it clearer what I am doing:

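    using Flux: destructure
    using ForwardDiff: gradient, jacobian  # assumed source of gradient/jacobian
    using LinearAlgebra

    # μ, L, nSamples, logπ and the parameter-conversion helpers are defined elsewhere.
    # NB: in this script E holds the expectation parameters and η the natural parameters.

    # Gradient of the objective w.r.t. the expectation parameters.
    E, to_expec = destructure(meanvar_to_expec(μ, L))
    dL_dexpec = gradient(E) do E  # assumed opener for this do-block
        μ, L = expec_to_meanvar(to_expec(E)...)
        θ = L * randn(length(μ), nSamples) .+ μ  # samples from q
        sum(logπ, eachcol(θ)) / nSamples + logdet(L)
    end

    # Explicit Jacobian of (μ, L) w.r.t. the natural parameters.
    η, to_nat = destructure(meanvar_to_nat(μ, L))
    dξ_dη = jacobian(η) do η
        vec(nat_to_meanvar(to_nat(η)...))
    end

    ξ, to_meanvar = destructure((μ, L))
    natgrad = dξ_dη * dL_dexpec  # Jacobian-vector product (assumed final step)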

I used `destructure` out of laziness.

Right, so the point here is that you don’t ever compute the jacobian(η) explicitly – instead, to compute the natural gradient w.r.t. the parameters of your preferred parametrisation ξ you do something like

    forward_mode_AD(natural_to_meanvar, theta, dL_dexpec)


which is equivalent to the jacobian-vector product in your code.

Does this clear things up, or am I missing the point?

Ah, I think I get it better: you don’t compute the Jacobian explicitly; instead you pass `dL_dexpec` in and compute the result directly from there (as a Jacobian-vector product).
I cannot find such a function in ForwardDiff.jl, do you know what to use?

Exactly.

> I cannot find such a function in ForwardDiff.jl, do you know what to use?

I’m actually not entirely sure either – I agree that I can’t immediately see it in the public API. Probably best to ask about ForwardDiff + jvps in #autodiff on Slack for a quick answer.

Note that if that doesn’t work you can presumably just use the reverse-mode trick from the paper.

> Forward-mode automatic differentiation libraries are perhaps less common than reverse-mode, but fortunately there is an elegant way to achieve forward-mode automatic differentiation using reverse-mode differentiation twice

Seriously?! Isn’t forward-mode AD orders of magnitude simpler to implement than reverse-mode?!

For jvp f’(x) * h, can’t you just differentiate f(x+t h) wrt t?


Yeah, it’s more of a comment on AD tools that the ML community tend to use, rather than AD generally – JAX is probably the first major bit of ML tooling that has forwards mode.

But yes, I agree that it’s a bit strange.

From @oxinabox on slack:

> Should be easy enough if you just construct the dual numbers yourself, right?
>
>     duals = map(Dual, x, dx)

So the idea would be to write something like

    out_duals = natural_to_meanvar(map(Dual, theta, dL_dexpec))


and then the desired natural gradient should be contained within the dual bits of out_duals.
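
A minimal sketch of that idea, assuming ForwardDiff's `Dual`, `partials` and `jacobian`. The function `f` is a hypothetical stand-in for `natural_to_meanvar`, and `x` and `v` play the roles of `theta` and `dL_dexpec`; none of these come from the original script.

```julia
using ForwardDiff
using ForwardDiff: Dual, partials

# Hypothetical stand-in for natural_to_meanvar: any differentiable vector function works.
f(x) = [x[1] * x[2], sin(x[1]), x[2]^2]

x = [3.0, 4.0]  # point at which to differentiate (role of theta)
v = [1.0, 2.0]  # tangent vector (role of dL_dexpec)

# Seed each input with its tangent component in a single partial slot.
duals = map(Dual, x, v)

# One evaluation of f propagates the tangent through the whole computation.
out_duals = f(duals)

# The first (and only) partial of each output is one entry of J(f, x) * v.
jvp = partials.(out_duals, 1)

# Sanity check against the explicitly materialised Jacobian.
@assert jvp ≈ ForwardDiff.jacobian(f, x) * v
```

The duals here are untagged, which is fine for a plain function like this one. If the function being differentiated calls ForwardDiff internally, going through `ForwardDiff.derivative` is probably the safer route, since it handles tagging for you.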


The derivative of f(t * a + c) wrt t at 0 is the Jacobian vector product J(f, c) * a.

    julia> using ForwardDiff

    julia> foo(x) = [x[1], x[1]*x[3], x[2]^2]
    foo (generic function with 1 method)

    julia> ForwardDiff.derivative(t->foo([1,2,3] * t + [3, 4, 5]), 0)
    3-element Vector{Int64}:
      1
     14
     16

    julia> ForwardDiff.jacobian(foo, [3,4,5]) * [1,2,3]
    3-element Vector{Int64}:
      1
     14
     16
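
Applied to the natural gradient question, the same trick can be sketched on a univariate Gaussian, where the answer is easy to check by hand. Everything below is an illustrative assumption rather than the code from the original post: the parametrisation $\xi = (\mu, \sigma)$, the conversion helpers, and the toy `objective`. The reference value uses the Fisher matrix $\mathrm{diag}(1/\sigma^2, 2/\sigma^2)$ of a Gaussian in $(\mu, \sigma)$.

```julia
using ForwardDiff

# Univariate Gaussian q = N(μ, σ²) with ξ = (μ, σ).
# Natural parameters θ = (μ/σ², -1/(2σ²)); expectation parameters η = (μ, μ² + σ²).
meanstd_to_nat(ξ) = [ξ[1] / ξ[2]^2, -1 / (2 * ξ[2]^2)]
nat_to_meanstd(θ) = [-θ[1] / (2 * θ[2]), sqrt(-1 / (2 * θ[2]))]
meanstd_to_expec(ξ) = [ξ[1], ξ[1]^2 + ξ[2]^2]
expec_to_meanstd(η) = [η[1], sqrt(η[2] - η[1]^2)]

# Toy objective in terms of ξ (hypothetical, chosen only for illustration).
objective(ξ) = -(ξ[1] - 1)^2 - ξ[2]^2 + log(ξ[2])

ξ = [0.5, 2.0]
θ = meanstd_to_nat(ξ)
η = meanstd_to_expec(ξ)

# Step 1: gradient of the objective w.r.t. the expectation parameters.
dL_dη = ForwardDiff.gradient(e -> objective(expec_to_meanstd(e)), η)

# Step 2: a single forward-mode pass gives (∂ξ/∂θ) * dL_dη without forming the Jacobian.
natgrad = ForwardDiff.derivative(t -> nat_to_meanstd(θ .+ t .* dL_dη), 0.0)

# Reference: F_ξ⁻¹ ∇_ξ L with F_ξ = diag(1/σ², 2/σ²).
∇ξ = ForwardDiff.gradient(objective, ξ)
reference = [ξ[2]^2 * ∇ξ[1], ξ[2]^2 / 2 * ∇ξ[2]]

@assert natgrad ≈ reference
```

Step 1 still costs one gradient evaluation, so this is two AD calls in total, but the Jacobian $\partial\xi/\partial\theta$ is never built.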


> For jvp f’(x) * h, can’t you just differentiate f(x+t h) wrt t?

Sorry, I didn’t appreciate this properly when you first wrote it!

@willtebbutt, to address your concern about the problems with such a method, here is the kind of optimiser you need to use to make sure everything is okay:

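    using Flux: Optimise

    # Learning rate that ramps up from 0 towards α as the step count grows.
    struct IncreasingRate
        α::Float64     # Maximum learning rate
        γ::Float64     # Convergence rate to the maximum
        state::IdDict  # Per-parameter step counters
    end

    IncreasingRate(α=1.0, γ=1e-8) = IncreasingRate(α, γ, IdDict())

    function Optimise.apply!(opt::IncreasingRate, x, g)
        t = get!(() -> 0, opt.state, x)  # steps taken so far for this parameter
        opt.state[x] += 1
        return g .* opt.α * (1 - exp(-opt.γ * t))
    end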
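
To get a feel for how quickly this schedule ramps up, here is a quick standalone sketch of the scaling factor used in `apply!`. The name `ramp` is a hypothetical helper for illustration only and is not part of Flux.

```julia
# Hypothetical helper mirroring the factor applied to the gradient in apply!.
ramp(t; α=1.0, γ=1e-8) = α * (1 - exp(-γ * t))

# Print the step scale after various numbers of updates.
for t in (0, 10^6, 10^8, 10^10)
    println("t = ", t, "  step scale = ", ramp(t))
end
```

With the default `γ = 1e-8` the factor only reaches `1 - exp(-1)`, roughly 63% of `α`, after `10^8` updates, so a larger `γ` is probably what you want for shorter runs.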
