I’m working on implementing variational sequential Monte Carlo (VSMC), as described in the VSMC paper (Naesseth et al.).
A simplified version of the set of functions I’m using are:
using Optimisers, Flux, ForwardDiff, Distributions
using Random      # Xoshiro
using Statistics  # mean
# ... (remaining project-specific imports)
# samples particles
# Pseudo-code stub: draws particle samples (and their proposal log-probabilities)
# from the proposal distribution using the supplied RNG.
# NOTE(review): per the stack trace (frame [8]: setindex! -> convert(Float64, ::Dual)),
# the real implementation preallocates Float64 arrays internally. Those buffers
# must be widened (e.g. via promote_type on the model's parameter eltype) so they
# can hold ForwardDiff.Dual values when forward_model is being differentiated.
function sim_state(..., rng=rng)
# random sample using rng: e.g. rand(rng, MvNormal(...))
# outputs a tuple
#return (samples = ...,
# logprob = ...
# )
end
# forward model whose parameters I want to optimize
# Forward model whose parameters are optimized. Use the `in => out` Pair form
# consistently for both layers (`Dense(20, 5)` is the older positional form).
ffnn = Chain(Dense(5 => 20, relu), Dense(20 => 5))
# additional helper functions here
# ....
# finally the function I am optimizing
"""
    sim_vsmc(forward_model, y, X, n_particles, sig2w, sig2l;
             rng_num=1, lower_bound_only=false)

Run the particle filter for the VSMC objective. When `lower_bound_only=true`,
return only the ELBO estimate (a scalar); otherwise return a named tuple with
the particle weights, log-weights, trajectories, the sampled trajectory index
`b_t`, and the selected trajectory `u`.
"""
function sim_vsmc(
    forward_model,
    y::Vector,
    X::Matrix,
    n_particles::Int64,
    sig2w::Vector{Float64},
    sig2l::Float64;
    rng_num = 1,
    lower_bound_only = false
)
    rng = Xoshiro(rng_num)
    T = length(y)
    u = sim_state(forward_model, sig2w, sig2l, n_particles; rng = rng)
    logz1 = samp_logprob(1, u.state_samp, y[1], X[1, :], sig2w, sig2l, n_particles)
    # FIX for `MethodError: no method matching Float64(::ForwardDiff.Dual...)`:
    # the buffers below were hard-coded to Float64 via `zeros(T, n_particles)`,
    # so assigning ForwardDiff.Dual values (produced when `forward_model` holds
    # dual-number parameters during ForwardDiff.gradient) failed at setindex!.
    # Widen the eltype from what sim_state actually returned. NOTE: sim_state
    # must be made generic the same way — the stack trace shows the first
    # failure is inside its own Float64 preallocation.
    S = promote_type(Float64, eltype(u.state_samp), eltype(u.state_logprob))
    logpw = zeros(S, T, n_particles)
    pw = zeros(S, T, n_particles)
    traj_u = zeros(S, T, length([sig2w; sig2l]), n_particles)
    logpw[1, :] = (sum(logz1; dims = 1) |> drop_dim) .- u.state_logprob
    pw[1, :] = logpw[1, :] |> safe_weights
    traj_u[1, :, :] = u.state_samp
    for t in 2:T
        ancestor_idxs = resample_stratified(rng, norm_weights(pw[t-1, :]), n_particles)
        u = sim_state(forward_model, sig2w, sig2l, n_particles;
                      prev_states = traj_u[t-1, :, ancestor_idxs], rng = rng)
        traj_u[t, :, :] = u.state_samp
        # Re-wire earlier trajectory history to the resampled ancestor indices.
        traj_u[1:t-1, :, :] = traj_u[1:t-1, :, ancestor_idxs]
        logzt = samp_logprob(t, u.state_samp, y[t], X[t, :], sig2w, sig2l, n_particles;
                             prev_states = traj_u[t-1, :, :])
        logpw[t, :] = (sum(logzt; dims = 1) |> drop_dim) .- u.state_logprob
        pw[t, :] = logpw[t, :] |> safe_weights
    end
    if lower_bound_only
        # ELBO estimate: sum over time of log mean unnormalised weights.
        elbo_hat = log.(mean(pw; dims = 2)) |> sum
        return elbo_hat
    end
    # Draw a single trajectory index from the final-step weights.
    b_t = resample_stratified(rng, norm_weights(pw[end, :]), 1)[1]
    return (
        pw = pw,
        logpw = logpw,
        traj_u = traj_u,
        b_t = b_t,
        u = traj_u[:, :, b_t],
    )
end
The following fails:
# Flatten the model parameters so ForwardDiff can differentiate w.r.t. a vector.
flat, re = destructure(ffnn)
# `ADAM()` is the deprecated Flux spelling; Optimisers.jl names it `Adam`.
st = Optimisers.setup(Optimisers.Adam(), flat)
# Differentiate the ELBO w.r.t. the flat parameter vector. This requires every
# array written to inside sim_vsmc / sim_state to have a generic eltype
# (ForwardDiff feeds Dual numbers through `re(v)`), otherwise setindex! into a
# Float64 buffer raises the MethodError shown below.
grads = ForwardDiff.gradient(flat) do v
    m = re(v)
    sim_vsmc(m, y, X, 10, ones(4), 0.1; lower_bound_only = true)
end
with the following error:
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12})
Closest candidates are:
(::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
@ Base rounding.jl:207
(::Type{T})(::T) where T<:Number
@ Core boot.jl:792
(::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
@ Base char.jl:50
...
Stacktrace:
[1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12})
@ Base ./number.jl:7
[2] setindex!
@ ./array.jl:971 [inlined]
[3] macro expansion
@ ./multidimensional.jl:932 [inlined]
[4] macro expansion
@ ./cartesian.jl:64 [inlined]
[5] _unsafe_setindex!
@ ./multidimensional.jl:927 [inlined]
[6] _setindex!
@ ./multidimensional.jl:916 [inlined]
[7] setindex!
@ ./abstractarray.jl:1399 [inlined]
[8] sim_state(forward_model::Chain{Tuple{Dense{typeof(relu), Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}}, Dense{typeof(identity), Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}}}}, sig2w::Vector{Float64}, sig2l::Float64, n_particles::Int64; prev_states::Nothing, rng::Xoshiro)
@ Main ./REPL[11]:23
[9] sim_state
@ ./REPL[11]:1 [inlined]
[10] sim_vsmc(forward_model::Chain{Tuple{Dense{typeof(relu), Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}}, Dense{typeof(identity), Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}}}}, y::Vector{Float64}, X::Matrix{Float64}, n_particles::Int64, sig2w::Vector{Float64}, sig2l::Float64; rng_num::Int64, lower_bound_only::Bool)
@ Main ./REPL[70]:14
[11] (::var"#35#36")(v::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}})
@ Main ./REPL[71]:3
[12] chunk_mode_gradient(f::var"#35#36", x::Vector{Float32}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}})
@ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:123
[13] gradient(f::Function, x::Vector{Float32}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}}, ::Val{true})
@ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:21
[14] gradient(f::Function, x::Vector{Float32}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#35#36", Float32}, Float32, 12}}})
@ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:17
[15] gradient(f::Function, x::Vector{Float32})
@ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:17
[16] top-level scope
@ REPL[71]:1
Not entirely sure what to make of this error message. Does anyone know how to resolve this?
Apologies for missing bits of code, trying to keep this as minimal as need be.