Gradient of neural network parameters used inside an SMC simulation

I’m working on implementing variational sequential Monte Carlo, as per this paper.

A simplified version of the set of functions I’m using is:

using 
    Optimisers, Flux, ForwardDiff, Distributions, ...; 


# samples particles
function sim_state(..., rng=rng)
    # random sample using rng: e.g. rand(rng, MvNormal(...))
    # outputs a named tuple
    #return (state_samp = ...,
    #        state_logprob = ...
    #    )
end

# forward model whose parameters I want to optimize
ffnn = Chain(Dense(5 => 20, relu), Dense(20 => 5))

# additional helper functions here
# ....

# finally the function I am optimizing
function sim_vsmc(
    forward_model,
    y::Vector,
    X::Matrix,
    n_particles::Int64,
    sig2w::Vector{Float64},
    sig2l::Float64;
    rng_num = 1,
    lower_bound_only=false
)

    rng = Xoshiro(rng_num)
    T = length(y)
    u = sim_state(forward_model, sig2w, sig2l, n_particles; rng=rng)
    logz1 = samp_logprob(1, u.state_samp, y[1], X[1, :], sig2w, sig2l, n_particles)

    logpw = zeros(T, n_particles)
    pw = zeros(T, n_particles)
    traj_u = zeros(T, length([sig2w; sig2l]), n_particles)

    logpw[1, :] = (sum(logz1; dims=1) |> drop_dim) .- u.state_logprob
    pw[1, :] = logpw[1, :] |> safe_weights
    traj_u[1, :, :] = u.state_samp

    for t in 2:T
        ancestor_idxs = resample_stratified(rng, norm_weights(pw[t-1, :]), n_particles)
        u = sim_state(forward_model, sig2w, sig2l, n_particles; prev_states=traj_u[t-1, :, ancestor_idxs], rng=rng)
        traj_u[t, :, :] = u.state_samp
        traj_u[1:t-1, :, :] = traj_u[1:t-1, :, ancestor_idxs]
        logzt = samp_logprob(t, u.state_samp, y[t], X[t, :], sig2w, sig2l, n_particles; prev_states=traj_u[t-1, :, :])
        logpw[t, :] = (sum(logzt; dims=1) |> drop_dim) .- u.state_logprob
        pw[t, :] = logpw[t, :] |> safe_weights
    end

    if lower_bound_only
        elbo_hat = log.(mean(pw; dims=2)) |> sum
        return elbo_hat
    end

    b_t = resample_stratified(rng, norm_weights(pw[end, :]), 1)[1]

    return (
        pw = pw, 
        logpw = logpw,
        traj_u = traj_u,
        b_t = b_t,
        u = traj_u[:, :, b_t]
    )
end

The following fails:

flat, re = destructure(ffnn)
st = Optimisers.setup(ADAM(), flat)
grads = ForwardDiff.gradient(flat) do v
    m = re(v)
    sim_vsmc(m, y,  X, 10, ones(4), 0.1; lower_bound_only=true)
end

with the following error:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
   @ Base char.jl:50
  ...

Stacktrace:
  [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12})
    @ Base ./number.jl:7
  [2] setindex!
    @ ./array.jl:971 [inlined]
  [3] macro expansion
    @ ./multidimensional.jl:932 [inlined]
  [4] macro expansion
    @ ./cartesian.jl:64 [inlined]
  [5] _unsafe_setindex!
    @ ./multidimensional.jl:927 [inlined]
  [6] _setindex!
    @ ./multidimensional.jl:916 [inlined]
  [7] setindex!
    @ ./abstractarray.jl:1399 [inlined]
  [8] sim_state(forward_model::Chain{Tuple{Dense{typeof(relu), Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}}, Dense{typeof(identity), Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}}}}, sig2w::Vector{Float64}, sig2l::Float64, n_particles::Int64; prev_states::Nothing, rng::Xoshiro)
    @ Main ./REPL[11]:23
  [9] sim_state
    @ ./REPL[11]:1 [inlined]
 [10] sim_vsmc(forward_model::Chain{Tuple{Dense{typeof(relu), Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}}, Dense{typeof(identity), Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}}}}, y::Vector{Float64}, X::Matrix{Float64}, n_particles::Int64, sig2w::Vector{Float64}, sig2l::Float64; rng_num::Int64, lower_bound_only::Bool)
    @ Main ./REPL[70]:14
 [11] (::var"#35#36")(v::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}})
    @ Main ./REPL[71]:3
 [12] chunk_mode_gradient(f::var"#35#36", x::Vector{Float32}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:123
 [13] gradient(f::Function, x::Vector{Float32}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}}, ::Val{true})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:21
 [14] gradient(f::Function, x::Vector{Float32}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:17
 [15] gradient(f::Function, x::Vector{Float32})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:17
 [16] top-level scope
    @ REPL[71]:1

Not entirely sure what to make of this error message. Does anyone know how to resolve this?

Apologies for the missing bits of code; I’m trying to keep this as minimal as possible.

The issue is that differentiation with ForwardDiff works with numbers of a special type called Dual, which are not just Float64. As a consequence, every container created in your function should be able to accommodate them, at least if its contents are involved in the differentiation.
I’m not sure whether that is the only issue, because I can’t run your code (and the error seems to happen in sim_state anyway), but whenever you create arrays with zeros(n, m), their default element type is Float64, which prevents them from holding Dual numbers. It is better to parameterize your array types using promotion of the input number types.
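To illustrate the point with a toy example (hypothetical code, not taken from the thread): a buffer whose element type follows the input can hold Dual numbers, whereas a plain zeros(n) buffer is hard-coded to Float64 and produces exactly this kind of MethodError.

using ForwardDiff

# Hypothetical toy function that fills a preallocated buffer.
function cumulative_squares(x::AbstractVector)
    out = zeros(eltype(x), length(x))  # element type follows the input, so Duals fit
    # out = zeros(length(x))           # hard-coded Float64: MethodError under ForwardDiff
    acc = zero(eltype(x))
    for i in eachindex(x)
        acc += x[i]^2
        out[i] = acc
    end
    return out
end

ForwardDiff.gradient(x -> sum(cumulative_squares(x)), [1.0, 2.0, 3.0])  # runs fine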

As a side note, differentiating through a random function is not as straightforward as it sounds. You might be interested in StochasticAD.jl if your program has discrete randomness, for example.
This paper is also a great resource: [1906.10652] Monte Carlo Gradient Estimation in Machine Learning.


@gdalle Thank you for the response and information. I’ve narrowed down the source of the error. Here’s a version of the sim_state function that I now use:

function sim_state(forward_model, sig2w, sig2l, n_particles; prev_states=nothing, rng=nothing)
    m = length(sig2w) + length(sig2l)
    D =  Diagonal([sig2w; sig2l])
    state_samp = zeros(m, n_particles)
 
    for n in 1:n_particles
        state_samp[:, n] = forward_model(prev_states[:, n]) .+ rand(rng, MvNormal(zeros(m), D))
    end

    return  state_samp
end

If I comment out the line that generates the new samples:

# inside the for loop in sim_state
state_samp[:, n] = forward_model(prev_states[:, n]) .+ rand(rng, MvNormal(zeros(m), D))

then I’m able to take gradients.

In particular, the offending part that triggers the error is the forward pass of the network, forward_model(prev_states[:, n]): once the model’s parameters are Dual numbers, its output is Dual-valued, and writing it into the Float64 buffer state_samp is what fails.
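For reference, here is one way the buffer issue could be avoided (a sketch only, reusing the posted signature and untested against the rest of the model): build the matrix from the per-particle columns so that the element type is inferred from the network output, instead of preallocating a Float64 array.

# Sketch: same interface as the posted sim_state, without a hard-coded Float64 buffer.
function sim_state(forward_model, sig2w, sig2l, n_particles; prev_states=nothing, rng=nothing)
    m = length(sig2w) + length(sig2l)
    D = Diagonal([sig2w; sig2l])
    cols = map(1:n_particles) do n
        # Dual-valued whenever forward_model carries Dual parameters
        forward_model(prev_states[:, n]) .+ rand(rng, MvNormal(zeros(m), D))
    end
    return reduce(hcat, cols)  # element type is determined by the columns
end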

If you post a complete MWE I might be able to take a look. As a side note, any reason to use ForwardDiff here? In a neural network setting you usually have many parameters and only one output, so reverse mode autodiff seems more appropriate.

@gdalle thanks again! I don’t have any good reason to use ForwardDiff; I was simply trying to get something up and running. I changed over to ReverseDiff based on your recommendation, for now doing the simplest (and most inefficient) thing. I also get an error with this AD type, but it tells me how to fix it: using ReverseDiff.value(forward_model(prev_states[:, n])) allows

p, re = Flux.destructure(ffnn);
f = (w) -> sim_vsmc(w, re, y, X, 10, ones(4), 0.1; lower_bound_only=true)
ReverseDiff.gradient(f, p)

to run without errors. Many thanks!
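For completeness, a sketch of how this gradient could feed the Optimisers.jl state set up in the original post (variable names follow the thread; Optimisers.Adam is assumed as the current spelling of ADAM, and the sign flip assumes the lower bound is being maximised):

p, re = Flux.destructure(ffnn)
st = Optimisers.setup(Optimisers.Adam(), p)

# One descent step on the negative ELBO estimate, i.e. ascent on the lower bound.
g = ReverseDiff.gradient(w -> -sim_vsmc(w, re, y, X, 10, ones(4), 0.1; lower_bound_only=true), p)
st, p = Optimisers.update!(st, p, g)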