Why doesn't this normalizing flow example work?

Hi-

I am interested in using normalizing flows to approximate intractable likelihood functions. I found the following example using DiffEqFlux.jl: Continuous Normalizing Flows · DiffEqFlux.jl

Given that this is such a simple example, I anticipated good performance. However, the scatter plot of true vs. estimated density shows extremely poor agreement. What went wrong?

[scatter plot: true density vs. estimated density, showing poor agreement]

Version Info

Julia 1.8.4

(normalizing_flow) pkg> st
Status `~/.julia/dev/normalizing_flow/Project.toml`
  [aae7a2af] DiffEqFlux v1.53.0
  [0c46a032] DifferentialEquations v7.6.0
  [b4f34e82] Distances v0.10.7
  [31c24e10] Distributions v0.25.79
  [587475ba] Flux v0.13.10
  [7f7a1694] Optimization v3.10.0
  [253f991c] OptimizationFlux v0.1.2
  [36348300] OptimizationOptimJL v0.1.5

Code

###########################################################################################################
#                                           load packages
###########################################################################################################
cd(@__DIR__)
using Pkg 
Pkg.activate("")
using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationFlux
using OptimizationOptimJL, Distributions
using Random
Random.seed!(3411)
###########################################################################################################
#                                           setup network
###########################################################################################################
nn = Flux.Chain(
    Flux.Dense(1, 3, tanh),
    Flux.Dense(3, 1, tanh),
) |> f32
tspan = (0.0f0, 1.0f0)

ffjord_mdl = FFJORD(nn, tspan, Tsit5())

# Training
data_dist = Normal(6.0f0, 0.7f0)
train_data = Float32.(rand(data_dist, 1, 100))

function loss(θ)
    logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ)
    -mean(logpx)  # negative log-likelihood of the training data
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p)

res1 = Optimization.solve(optprob,
                          ADAM(0.1),
                          maxiters = 100)

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2,
                          Optim.LBFGS(),
                          allow_f_increases=false)
###########################################################################################################
#                                           evaluate and plot
###########################################################################################################
using Distances

actual_pdf = pdf.(data_dist, train_data)
learned_pdf = exp.(ffjord_mdl(train_data, res2.u)[1])
train_dis = totalvariation(learned_pdf, actual_pdf) / size(train_data, 2)

# Data Generation
ffjord_dist = FFJORDDistribution(FFJORD(nn, tspan, Tsit5(); p=res2.u))
new_data = rand(ffjord_dist, 100)

using Plots
scatter(actual_pdf', learned_pdf', xlabel="true density", ylabel="estimated density",
    leg=false, grid=false)

It needs many more than 100 iterations to converge in practice, IIRC, and probably a larger network. I think I have an issue open to go back and fix up that example: I forget who wrote it, but it was a small test of the method, not actually a good example.
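For reference, a "larger network, more iterations" variant of the script might look like the sketch below. The widths, learning rate, and iteration count are illustrative guesses, not tuned values, and dropping the tanh on the output layer is an assumption on my part (a bounded final activation restricts the vector field the CNF can represent), not a confirmed fix:

```julia
using Flux

# Wider network; note the linear (no tanh) output layer.
nn = Flux.Chain(
    Flux.Dense(1, 16, tanh),
    Flux.Dense(16, 16, tanh),
    Flux.Dense(16, 1),   # linear output so the dynamics are unbounded
) |> f32

# Then, as in the original script, but with a longer ADAM run:
#   ffjord_mdl = FFJORD(nn, tspan, Tsit5())
#   optprob    = Optimization.OptimizationProblem(optf, ffjord_mdl.p)
#   res1       = Optimization.solve(optprob, ADAM(0.01), maxiters = 1000)
```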

Unfortunately, increasing the training set and the network size did not seem to help. I opened an issue here