Error while reproducing example of continuous normalizing flows documentation

I am executing an exact copy of the code from the continuous normalizing flows (CNF) documentation: Continuous Normalizing Flows · DiffEqFlux.jl

However, I get an error in the line res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.01); maxiters = 20, callback = cb), namely “ERROR: type OptimizationState has no field layer_1”.

Julia version is 1.10.5, running on Debian Linux (x86-64).

Are you on the latest release of all the packages? This was temporarily broken: Forward Mode for Neural ODEs / Structured Parameters broken · Issue #1099 · SciML/SciMLSensitivity.jl · GitHub

Thanks a lot for the answer! That seemed to be part of the problem. After import Pkg; Pkg.update(), I now get a slightly different error: “ERROR: Exception while generating log record in module Main at …
exception = type OptimizationState has no field layer_1
Stacktrace: …”

And when I replace the line @info "FFJORD Training" loss=loss(p) with println(l), it executes correctly, which was not the case before. I am not sure why the @info ... line produces the error, though.

Another tangential question, now that I am writing with you: is there a reason why ps, st = Lux.setup(...) cannot directly return ps as a ComponentArray, or why one would not want to make that work? In the documentation, the code

ps, st = Lux.setup(Xoshiro(0), model)
ps = ComponentArray(ps)

appears very often and the line ps = ComponentArray(ps) seems a bit like boilerplate code. (But that is just a naive question / suggestion. Let me emphasize that the whole Lux library is really great work! = )

Try using

function cb(p, l)
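    # p here is an OptimizationState; the current parameters live in p.u,
    # so the loss is evaluated on p.u rather than on p itself.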
    @info "FFJORD Training" loss=loss(p.u)
    return false
end

I have a PR for updating the docs but it is held up by a patch that needs to be merged in an upstream library, so it will take some time for my updates to be merged.

appears very often and the line ps = ComponentArray(ps) seems a bit like boilerplate code.

That is a good question; the main answer is to handle the general case. NamedTuples are a strict superset of ComponentArrays in terms of what they can represent. A ComponentArray is great for representing a nested structure as a flat vector, but that comes with a few tradeoffs.

The biggest one is that all the elements are promoted to a uniform type (you could store a union type, but that is a performance disaster). A recent issue came up here: FourierNeuralOperator and ComponentArrays error · Issue #29 · LuxDL/NeuralOperators.jl · GitHub, where some of the layers store complex numbers as parameters while others store reals. That is not a problem for NamedTuples, but when you construct a ComponentArray everything becomes a complex number, and then the layers that expect real parameters stop working.
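A minimal sketch of that promotion behavior (the field names here are made up purely for illustration):

using ComponentArrays

# A NamedTuple can hold parameters of different element types side by side.
nt = (; real_layer = rand(Float64, 3), complex_layer = rand(ComplexF64, 3))
eltype(nt.real_layer)     # Float64
eltype(nt.complex_layer)  # ComplexF64

# A ComponentArray stores everything in one flat vector, so the element
# types are promoted to a common type: here everything becomes ComplexF64.
ca = ComponentArray(nt)
eltype(ca)                # ComplexF64
eltype(ca.real_layer)     # ComplexF64 as well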

The other common issue is that ComponentArrays cannot have shared parameters. For example, if we have nt = (; a = arr1, b = arr2, c = arr1), then nt.a === nt.c. Convert it to a ComponentArray and you won’t get ca.a === ca.c.
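Concretely, a small sketch of the aliasing being lost (arr1 and arr2 are just placeholder arrays):

using ComponentArrays

arr1 = rand(3); arr2 = rand(3)
nt = (; a = arr1, b = arr2, c = arr1)
nt.a === nt.c   # true: a and c are the same array
ca = ComponentArray(nt)
ca.a === ca.c   # false: a and c now live in separate slices of the flat vector
ca.a .= 0.0     # mutating ca.a no longer affects ca.c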

So if we return a NamedTuple and a user converts it to a ComponentArray, we expect the user to be aware of the shortcomings. But if we were to return a ComponentArray, then the user could not handle those two cases without messing with Lux internals.


Thanks a lot! Using @info "FFJORD Training" loss=loss(p.u) made it work. Great that you are planning to update this in the documentation already.

And thanks a lot for the detailed answer regarding ComponentArrays. Especially the need for shared parameters in some applications seems like a very good point. In that case, another possibility for removing the ps = ComponentArray(ps) line would be to make Optimization.OptimizationProblem(optf, ps) accept ps as a NamedTuple. For instance, in

optprob = Optimization.OptimizationProblem(optf, ps)
Optimization.solve(optprob, ...)

and other contexts that require ps to be a ComponentArray, the functions could first perform a check (or use multiple dispatch), and if ps is still a NamedTuple, internally convert it with ps = ComponentArray(ps) before running as usual. (It seems like one could implement this comparatively quickly, but of course it goes beyond Lux alone.)
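Something along these lines, purely as a hypothetical sketch of the dispatch idea (make_problem is an invented helper name, not an existing Optimization.jl API):

using ComponentArrays, Optimization

# Hypothetical convenience layer: dispatch on the parameter type and
# convert NamedTuples to ComponentArrays before building the problem.
make_problem(optf, ps::ComponentArray; kwargs...) =
    Optimization.OptimizationProblem(optf, ps; kwargs...)
make_problem(optf, ps::NamedTuple; kwargs...) =
    Optimization.OptimizationProblem(optf, ComponentArray(ps); kwargs...)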