How to estimate likelihood functions with normalizing flows in Julia

Hi-
I am interested in using normalizing flows in Julia to approximate intractable likelihood functions. The basic idea is to train a neural network to learn the relationship between parameters and samples from a model. Once the neural network is trained, it can be saved and used later for inference, for example with Turing.jl. I created a simple example in Python to demonstrate how to learn the likelihood function of a lognormal distribution: I generate 10,000 parameter vectors and, for each parameter vector, draw a single random sample from the lognormal distribution. Here is a comparison of the true and estimated densities for two different parameter vectors.

[Screenshot: true vs. estimated lognormal densities for two parameter vectors]
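(As an aside, the end goal would be to use the saved network as a likelihood inside Turing.jl. Here is a minimal sketch of that step, assuming a hypothetical wrapper `learned_logpdf(μ, σ, x)` around the trained network; every name in it is an illustrative assumption, not working code from this post.)

```julia
# Hypothetical sketch: plug a trained density estimator into Turing.jl.
# `learned_logpdf(μ, σ, x)` is an assumed wrapper that returns the learned
# log-density of x given the parameters.
using Turing

@model function lognormal_model(data)
    # same priors used to generate the training parameters
    μ ~ Gamma(1.0, 0.5)
    σ ~ Gamma(1.0, 0.5)
    for x in data
        # add the learned log-likelihood term to the log joint
        Turing.@addlogprob! learned_logpdf(μ, σ, x)
    end
end

# chain = sample(lognormal_model(data), NUTS(), 1_000)
```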

My question is: how can this be done in Julia? I think this would be of broad interest to the community, given that models in many domains do not have tractable likelihood functions. I tried modifying the FFJORD example from DiffEqFlux.jl, but I am not sure how to properly set up the neural network. Below I have pasted a draft of the code. Any help would be greatly appreciated.

```julia
###########################################################################################################
#                                           load packages
###########################################################################################################
cd(@__DIR__)
using Pkg 
Pkg.activate(".")
using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationFlux
using OptimizationOptimJL, Distributions
using Random
Random.seed!(3411)
###########################################################################################################
#                                           generate training data
###########################################################################################################
n_parms = 10_000
# training parameters
train_parms = map(x -> Float32.(rand(Gamma(1.0, .5), 2)), 1:n_parms)
# training samples
samples = map(p -> Float32(rand(LogNormal(p...))), train_parms)
training_data = [reduce(hcat, train_parms); samples']  # 3 × n_parms matrix; rows are μ, σ, x
###########################################################################################################
#                                           setup network
###########################################################################################################
nn = Flux.Chain(
    # inputs are parameters μ and σ, and distribution sample
    Flux.Dense(3, 10, tanh),
    Flux.Dense(10, 10, tanh),
    Flux.Dense(10, 1, tanh),
) |> f32
tspan = (0.0f0, 50.0f0)

ffjord_mdl = DiffEqFlux.FFJORD(nn, tspan, Tsit5())

function loss(θ)
    logpx, λ₁, λ₂ = ffjord_mdl(training_data, θ)
    -mean(logpx)
end

function cb(p, l)::Bool
    # the current loss is passed in as `l`; no need to recompute it
    @info "Training" loss = l
    false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p)

res1 = Optimization.solve(optprob,
                          ADAM(0.1),
                          maxiters = 100, callback=cb)

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2,
                          Optim.LBFGS(),
                          allow_f_increases=false, callback=cb)
###########################################################################################################
#                                           evaluate and plot
###########################################################################################################
using Plots

test_parms = [1.5, .5]
xs = [.05:.05:20;]
true_density = pdf.(LogNormal(test_parms...), xs)
# is there a better way to get the estimated density?
# (the model expects the same 3-dimensional (μ, σ, x) layout used in training)
est_density = map(xs) do x
    input = reshape(Float32.([test_parms; x]), 3, 1)
    logpx, = ffjord_mdl(input, res2.u, monte_carlo=false)
    exp(only(logpx))
end
# plot the true and estimated densities
plot(xs, true_density)
plot!(xs, est_density)
```
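Two details of the draft above are worth flagging (these are my best guesses, not confirmed fixes). First, FFJORD uses the network as the vector field dz/dt = f(z, t), so the output dimension must match the input dimension; with 3-dimensional inputs the final layer should map back to 3, and it is usually left linear so the field is unbounded. A sketch under that assumption:

```julia
# Sketch of a dimension-consistent FFJORD network (an assumption, not a
# confirmed fix): the network parameterizes dz/dt = f(z, t), so for
# 3-dimensional (μ, σ, x) inputs the output must also be 3-dimensional.
nn = Flux.Chain(
    Flux.Dense(3, 10, tanh),
    Flux.Dense(10, 10, tanh),
    Flux.Dense(10, 3),   # output dim = input dim; no activation on the field
) |> f32
```

Second, if I understand the setup correctly, a flow trained on stacked (μ, σ, x) vectors learns the joint density p(μ, σ, x), so recovering the conditional likelihood p(x | μ, σ) would mean dividing by the known Gamma density of the parameters (or training a conditional flow instead).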

What did you get from the code you have? Is it the same neural network architecture with the same loss and the same hyperparameters?

Thanks for your reply. Here are the changes I made: I changed the training data from a 1 × 10,000 array to a 3 × 10,000 array, where each column holds the parameters μ and σ together with x, a sample from the lognormal distribution given those parameters. I figured that I needed to change the input layer from one to three nodes:

```julia
nn = Flux.Chain(
    # inputs are parameters μ and σ, and distribution sample
    Flux.Dense(3, 10, tanh),
    Flux.Dense(10, 10, tanh),
    Flux.Dense(10, 1, tanh),
) |> f32
```

I increased tspan to accommodate the larger range of x. I did not make any other changes. Here is the error I get:

`DimensionMismatch: A has dimensions (10,3) but B has dimensions (1,1000)`

Full error:

```
ERROR: DimensionMismatch: A has dimensions (10,3) but B has dimensions (1,1000)
Stacktrace:
  [1] gemm_wrapper!(C::Matrix{Float32}, tA::Char, tB::Char, A::Matrix{Float32}, B::Matrix{Float32}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra ~/julia-1.8.3/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:646
  [2] mul!
    @ ~/julia-1.8.3/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:161 [inlined]
  [3] mul!
    @ ~/julia-1.8.3/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
  [4] *
    @ ~/julia-1.8.3/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:148 [inlined]
  [5] rrule
    @ ~/.julia/packages/ChainRules/ajkp7/src/rulesets/Base/arraymath.jl:64 [inlined]
  [6] rrule
    @ ~/.julia/packages/ChainRulesCore/C73ay/src/rules.jl:134 [inlined]
  [7] chain_rrule
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:218 [inlined]
  [8] macro expansion
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0 [inlined]
  [9] _pullback
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:9 [inlined]
 [10] _pullback
    @ ~/.julia/packages/Flux/kq9Et/src/layers/basic.jl:172 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/Flux/kq9Et/src/layers/basic.jl:53 [inlined]
 [12] _pullback
    @ ~/.julia/packages/Flux/kq9Et/src/layers/basic.jl:53 [inlined]
 [13] _pullback(::Zygote.Context{false}, ::typeof(Flux._applychain), ::Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [14] _pullback
    @ ~/.julia/packages/Flux/kq9Et/src/layers/basic.jl:51 [inlined]
 [15] pullback(f::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, cx::Zygote.Context{false}, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:44
 [16] pullback
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:42 [inlined]
 [17] ffjord(u::Matrix{Float32}, p::Vector{Float32}, t::Float32, re::Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, e::Matrix{Float32}, st::Nothing; regularize::Bool, monte_carlo::Bool)
    @ DiffEqFlux ~/.julia/packages/DiffEqFlux/2IJEZ/src/ffjord.jl:237
 [18] ffjord_
    @ ~/.julia/packages/DiffEqFlux/2IJEZ/src/ffjord.jl:282 [inlined]
 [19] ODEFunction
    @ ~/.julia/packages/SciMLBase/QqtZA/src/scimlfunctions.jl:2096 [inlined]
 [20] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Matrix{Float32}, Nothing, Float32, Vector{Float32}, Float32, Float32, Float32, Float32, Vector{Matrix{Float32}}, ODESolution{Float32, 3, Vector{Matrix{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Matrix{Float32}}}, ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Vector{Matrix{Float32}}, Vector{Float32}, Vector{Vector{Matrix{Float32}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, 
Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, OrdinaryDiffEq.DEOptions{Float32, Float32, Float32, Float32, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Bool, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Matrix{Float32}, Float32, Nothing, OrdinaryDiffEq.DefaultInit}, cache::OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/CFzwI/src/perform_step/low_order_rk_perform_step.jl:672
 [21] __init(prob::ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{Int64}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:save_noise,), Tuple{Bool}}})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/CFzwI/src/solve.jl:493
 [22] #__solve#566
    @ ~/.julia/packages/OrdinaryDiffEq/CFzwI/src/solve.jl:5 [inlined]
 [23] #solve_call#21
    @ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:494 [inlined]
 [24] solve_up(prob::ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, sensealg::Nothing, u0::Matrix{Float32}, p::Vector{Float32}, args::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}; kwargs::Base.Pairs{Symbol, Bool, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:save_noise, :save_start, :save_end), Tuple{Bool, Bool, Bool}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:856
 [25] #solve#26
    @ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:823 [inlined]
 [26] _concrete_solve_adjoint(::ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::InterpolatingAdjoint{0, true, Val{:central}, Nothing}, ::Matrix{Float32}, ::Vector{Float32}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, saveat::Vector{Float32}, save_idxs::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/DInxI/src/concrete_solve.jl:279
 [27] _concrete_solve_adjoint
    @ ~/.julia/packages/SciMLSensitivity/DInxI/src/concrete_solve.jl:231 [inlined]
 [28] #_solve_adjoint#50
    @ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:1263 [inlined]
 [29] _solve_adjoint
    @ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:1232 [inlined]
 [30] #rrule#48
    @ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:1216 [inlined]
 [31] rrule
    @ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:1212 [inlined]
 [32] rrule
    @ ~/.julia/packages/ChainRulesCore/C73ay/src/rules.jl:134 [inlined]
 [33] chain_rrule
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:218 [inlined]
 [34] macro expansion
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0 [inlined]
 [35] _pullback
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:9 [inlined]
 [36] _apply
    @ ./boot.jl:816 [inlined]
 [37] adjoint
    @ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
 [38] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [39] _pullback
    @ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:823 [inlined]
 [40] _pullback(::Zygote.Context{false}, ::DiffEqBase.var"##solve#26", ::InterpolatingAdjoint{0, true, Val{:central}, Nothing}, ::Nothing, ::Nothing, ::Val{true}, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(solve), ::ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [41] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
 [42] adjoint
    @ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
 [43] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [44] _pullback
    @ ~/.julia/packages/DiffEqBase/wXM5P/src/solve.jl:813 [inlined]
 [45] _pullback(::Zygote.Context{false}, ::CommonSolve.var"#solve##kw", ::NamedTuple{(:sensealg,), Tuple{InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}, ::typeof(solve), ::ODEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, SciMLBase.AutoSpecialize, DiffEqFlux.var"#ffjord_#72"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [46] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
 [47] adjoint
    @ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
 [48] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [49] _pullback
    @ ~/.julia/packages/DiffEqFlux/2IJEZ/src/ffjord.jl:297 [inlined]
 [50] _pullback(::Zygote.Context{false}, ::DiffEqFlux.var"##forward_ffjord#67", ::Bool, ::Bool, ::typeof(DiffEqFlux.forward_ffjord), ::FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Matrix{Float32}, ::Vector{Float32}, ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [51] _pullback (repeats 2 times)
    @ ~/.julia/packages/DiffEqFlux/2IJEZ/src/ffjord.jl:278 [inlined]
 [52] _pullback(::Zygote.Context{false}, ::typeof(DiffEqFlux.forward_ffjord), ::FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Matrix{Float32}, ::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [53] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
 [54] adjoint
    @ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
 [55] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [56] _pullback
    @ ~/.julia/packages/DiffEqFlux/2IJEZ/src/ffjord.jl:276 [inlined]
--- the last 5 lines are repeated 1 more time ---
 [62] _pullback(::Zygote.Context{false}, ::FFJORD{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Vector{Float32}, Nothing, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Matrix{Float32}, ::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [63] _pullback
    @ ~/.julia/dev/normalizing_flow/sandbox.jl:35 [inlined]
 [64] _pullback(ctx::Zygote.Context{false}, f::typeof(loss), args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [65] _pullback
    @ ~/.julia/dev/normalizing_flow/sandbox.jl:46 [inlined]
 [66] _pullback(::Zygote.Context{false}, ::var"#73#74", ::Vector{Float32}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [67] _apply
    @ ./boot.jl:816 [inlined]
 [68] adjoint
    @ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
 [69] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [70] _pullback
    @ ~/.julia/packages/SciMLBase/QqtZA/src/scimlfunctions.jl:3580 [inlined]
 [71] _pullback(::Zygote.Context{false}, ::OptimizationFunction{true, Optimization.AutoZygote, var"#73#74", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::Vector{Float32}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [72] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
 [73] adjoint
    @ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
 [74] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [75] _pullback
    @ ~/.julia/packages/Optimization/o00ZS/src/function/zygote.jl:30 [inlined]
 [76] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, var"#73#74", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [77] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
 [78] adjoint
    @ ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:203 [inlined]
 [79] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [80] _pullback
    @ ~/.julia/packages/Optimization/o00ZS/src/function/zygote.jl:34 [inlined]
 [81] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#158#167"{Tuple{}, Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, var"#73#74", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}}, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [82] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:44
 [83] pullback
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:42 [inlined]
 [84] gradient(f::Function, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:96
 [85] (::Optimization.var"#157#166"{Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, var"#73#74", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}})(::Vector{Float32}, ::Vector{Float32})
    @ Optimization ~/.julia/packages/Optimization/o00ZS/src/function/zygote.jl:32
 [86] macro expansion
    @ ~/.julia/packages/OptimizationFlux/zHcx5/src/OptimizationFlux.jl:32 [inlined]
 [87] macro expansion
    @ ~/.julia/packages/Optimization/o00ZS/src/utils.jl:37 [inlined]
 [88] __solve(prob::OptimizationProblem{true, OptimizationFunction{true, Optimization.AutoZygote, var"#73#74", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::Adam, data::Base.Iterators.Cycle{Tuple{Optimization.NullData}}; maxiters::Int64, callback::Function, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OptimizationFlux ~/.julia/packages/OptimizationFlux/zHcx5/src/OptimizationFlux.jl:31
 [89] #solve#540
    @ ~/.julia/packages/SciMLBase/QqtZA/src/solve.jl:84 [inlined]
```

My guess is that the training data, the architecture, or both are not correctly specified. Another guess is that the optimizer might need to be changed to accommodate the change in the architecture, but honestly I don't have much of an idea.
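Before changing anything else, a quick shape check might localize the problem (a sketch using the variable names from the draft above):

```julia
# The error reports the first Dense weight as (10, 3) and the incoming
# batch as (1, 1000), so whatever reaches the network seems to have one
# row rather than three. Check the pieces individually:
@show size(training_data)   # expected: (3, 10_000)
@show size(nn[1].weight)    # (10, 3) for Dense(3, 10, tanh)
nn(training_data)           # should run without a DimensionMismatch
```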

Open an issue; that'll make it easier to pull in the other contributors who have been working on CNFs but don't tend to post here.

I just looked at the Python script and see it's not a one-to-one comparison, i.e. that script isn't a CNF from what I can tell, so there are no hyperparameters etc. known to work on this case; it'll take some exploring.

Thanks. Here is the issue for reference.

You are right: the details of the algorithm are hidden from the user, but it should be using a normalizing flow of some variety. I included some details in the GitHub issue. Thanks again for looking into this.

The SimulatedNeuralMoments.jl package does something similar. It fits a net that takes a vector of statistics as input and outputs the parameters that generated the sample from which the statistics were computed. With a real sample, you compute the same statistics, feed them into the net, and obtain an estimate. The package uses Turing to do MCMC. The SV example shows how an auxiliary model can be used to provide statistics.

I believe the main difference with respect to the article mentioned in your link is in the choice of statistics. It might be reasonably easy to use the statistics from the referenced paper. KDEs and histograms could be used if simpler statistics are not informative enough, though that might become computationally intensive.
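For concreteness, here is a generic sketch of the statistics-to-parameters idea described above; this is not SimulatedNeuralMoments.jl's actual API, and names like `stats` and `real_sample` are illustrative assumptions.

```julia
# Generic sketch: train a net mapping summary statistics of simulated
# samples to the parameters that generated them (illustrative only).
using Flux, Distributions, Statistics

# summary statistics of a sample
stats(x) = Float32[mean(x), std(x), mean(log.(x)), std(log.(x))]

# simulate training pairs
parms = [rand(Gamma(1.0, 0.5), 2) for _ in 1:10_000]
X = reduce(hcat, [stats(rand(LogNormal(p...), 100)) for p in parms])
Y = reduce(hcat, [Float32.(p) for p in parms])

net = Chain(Dense(4, 32, tanh), Dense(32, 2))
opt_state = Flux.setup(Adam(1e-3), net)
for epoch in 1:200
    Flux.train!((m, x, y) -> Flux.Losses.mse(m(x), y), net, [(X, Y)], opt_state)
end

# with a real sample: compute the same statistics, feed them in, and read
# off the parameter estimate
# estimate = net(stats(real_sample))
```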


Thank you for the suggestion. From what I can tell, there are two major differences: (1) normalizing flows use raw samples rather than summary statistics, KDEs, etc., and (2) the output of the trained neural network is a probability density: `density = trained_nn([μ, σ, x])` for μ, σ, and x within the training domain.

Thanks. I will look into it more carefully. This paper seems helpful for understanding the nuts and bolts behind this: http://proceedings.mlr.press/v89/papamakarios19a/papamakarios19a.pdf
