Hi-
I am interested in using normalizing flows to approximate intractable likelihood functions. I found the following example in the DiffEqFlux.jl documentation: Continuous Normalizing Flows · DiffEqFlux.jl
Because the example is so simple (a one-dimensional normal target), I expected the learned density to match the true density closely. However, the scatter plot of estimated density against true density shows a very poor fit. What went wrong?
Version Info
Julia 1.8.4
(normalizing_flow) pkg> st
Status `~/.julia/dev/normalizing_flow/Project.toml`
[aae7a2af] DiffEqFlux v1.53.0
[0c46a032] DifferentialEquations v7.6.0
[b4f34e82] Distances v0.10.7
[31c24e10] Distributions v0.25.79
[587475ba] Flux v0.13.10
[7f7a1694] Optimization v3.10.0
[253f991c] OptimizationFlux v0.1.2
[36348300] OptimizationOptimJL v0.1.5
Code
###########################################################################################################
# load packages
###########################################################################################################
cd(@__DIR__)
using Pkg
Pkg.activate("")
using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationFlux
using OptimizationOptimJL, Distributions
using Random
Random.seed!(3411)
###########################################################################################################
# setup network
###########################################################################################################
nn = Flux.Chain(
    Flux.Dense(1, 3, tanh),
    Flux.Dense(3, 1, tanh),
) |> f32
tspan = (0.0f0, 1.0f0)
ffjord_mdl = FFJORD(nn, tspan, Tsit5())
# Training
data_dist = Normal(6.0f0, 0.7f0)
train_data = Float32.(rand(data_dist, 1, 100))
function loss(θ)
    logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ)
    -mean(logpx)
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p)
res1 = Optimization.solve(optprob,
                          ADAM(0.1),
                          maxiters = 100)
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2,
                          Optim.LBFGS(),
                          allow_f_increases=false)
###########################################################################################################
# evaluate and plot
###########################################################################################################
using Distances
actual_pdf = pdf.(data_dist, train_data)
learned_pdf = exp.(ffjord_mdl(train_data, res2.u)[1])
train_dis = totalvariation(learned_pdf, actual_pdf) / size(train_data, 2)
# Data Generation
ffjord_dist = FFJORDDistribution(FFJORD(nn, tspan, Tsit5(); p=res2.u))
new_data = rand(ffjord_dist, 100)
using Plots
scatter(actual_pdf', learned_pdf', xlabel="true density", ylabel="estimated density",
        leg=false, grid=false)
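For reference, here is a minimal sketch of what I expected a well-trained flow to reproduce. It does not use DiffEqFlux at all: it assumes a standard normal base distribution and an affine map z = (x - μ) / σ as a stand-in for the learned flow, and checks with the change-of-variables formula that this recovers the Normal(6, 0.7) log-density exactly. The function names (`std_normal_logpdf`, `normal_logpdf`, `affine_flow_logpdf`) are hypothetical helpers written for this illustration, not part of any package.

```julia
# Plain Julia, no packages required. All helper names are hypothetical.

# Log-density of the assumed base distribution, a standard normal.
std_normal_logpdf(z) = -0.5 * (z^2 + log(2π))

# Log-density of Normal(μ, σ), written out directly as the reference value.
normal_logpdf(x, μ, σ) = -0.5 * (((x - μ) / σ)^2 + log(2π)) - log(σ)

# Stand-in for a perfect flow: z = (x - μ) / σ, so dz/dx = 1/σ and
# log p(x) = log p_base(z) + log|dz/dx| = log p_base(z) - log(σ).
function affine_flow_logpdf(x, μ, σ)
    z = (x - μ) / σ
    return std_normal_logpdf(z) - log(σ)
end

μ, σ = 6.0, 0.7                      # same target as data_dist in the script
xs = range(4.0, 8.0; length = 9)     # a few evaluation points around the mean

exact  = normal_logpdf.(xs, μ, σ)
flowed = affine_flow_logpdf.(xs, μ, σ)

# Largest disagreement between the two log-densities; should be at
# floating-point round-off level.
max_err = maximum(abs.(exact .- flowed))
println("max abs difference in log-density: ", max_err)
```

If the trained model behaved like this affine map, the points in the scatter plot of estimated against true density would lie on the diagonal.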