I am trying to implement the code given in Continuous Normalizing Flows · DiffEqFlux.jl, with mini batching.
using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationFlux,
OptimizationOptimJL, Distributions, Optimisers
using IterTools: ncycle
nn = Flux.Chain(
Flux.Dense(1, 3, tanh),
Flux.Dense(3, 1, tanh),
) |> f32
tspan = (0.0f0, 10.0f0)
ffjord_mdl = FFJORD(nn, tspan, Tsit5())
data_dist = Normal(6.0f0, 0.7f0)
train_data = Float32.(rand(data_dist, 1, 100))
function loss(p, batch)
logpx, λ₁, λ₂ = ffjord_mdl(batch, p)
-mean(logpx)
end
train_loader = Flux.Data.DataLoader(train_data, batchsize=10)
adtype = Optimization.AutoZygote()
optfun = OptimizationFunction((x, p, batch) -> loss(x, batch), adtype)
optprob = Optimization.OptimizationProblem(optfun, ffjord_mdl.p)
res1 = Optimization.solve(optprob,
Optimisers.ADAM(0.1),
ncycle(train_loader, 10))
I am getting this error.
ERROR: MethodError: no method matching (::var"#112#113")(::Vector{Float32}, ::SciMLBase.NullParameters, ::Float32, ::Float32, ::Float32, ::Float32, ::Float32, ::Float32, ::Float32, ::Float32, ::Float32, ::Float32)
Closest candidates are:
(::var"#112#113")(::Any, ::Any, ::Any)
What am I doing wrong here?
Thank You