Combining Lux.jl with Optimization.jl

Hi there,
while trying to put together a state-of-the-art neural differential equations tutorial, I stumbled upon the recent additions Lux.jl and Optimization.jl.

It seems they don’t really work together, because Optimization.jl requires the objective function to be stateless or to have implicit state. (EDIT: The documentation calls u the state; however, this is not generic state, but the state that is necessarily updated by the gradient-based method.) The most widespread example of stateful computation is probably random number generation, where the random number generator is the state. There are other useful kinds of state as well, for example in recurrent neural networks.
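To make the mismatch concrete, here is a minimal pure-Julia sketch of the explicit-state convention Lux.jl uses, where a layer maps (x, ps, st) to (y, st_new). The layer `noisy_layer` is hypothetical (not part of Lux) and keeps its random number generator in the state. An Optimization.jl objective f(u, p) returns only the loss, so there is no return slot for st_new.

```julia
using Random

# Hypothetical layer following the Lux-style convention (x, ps, st) -> (y, st_new).
# The RNG lives in the state; we advance a copy instead of mutating the caller's RNG.
function noisy_layer(x, ps, st)
    rng = copy(st.rng)
    y = ps.w .* x .+ 0.1 .* randn(rng, length(x))
    return y, (rng = rng,)
end

ps = (w = 2.0,)
st = (rng = MersenneTwister(42),)
x = [1.0, 2.0]

y1, st1 = noisy_layer(x, ps, st)        # first call
y2, st2 = noisy_layer(x, ps, st1)       # threading the new state gives fresh noise
y1_again, _ = noisy_layer(x, ps, st)    # reusing the old state reproduces y1
```

Dropping st1 and always passing the original st, which is effectively what an objective without a state output does, freezes the noise.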

I am wondering what this means for the Optimization.jl interface. Are there plans for an interface version 2.0 that allows explicit state?
Or will there be an Optimization2.jl that truly becomes the one package to go to for optimization tasks (which already seems to be the goal of Optimization.jl)?

Practically, I am now falling back to Lux.Training, which seems to be the most generic interface for training explicitly stateful functions with automatic differentiation. (AbstractDifferentiation.jl also does not work with even simple Lux.jl examples, so it is not an option at this stage either.) If someone knows another optimization meta-package that supports explicitly stateful functions, I would be very interested to know.

There’s also p, which is for state that is not differentiated and for hyperparameters. Optimization.jl + Lux.jl is used for most of the neural differential equation work these days; see the tutorials for how state is handled.

Hi Chris,
I am spending my second full day reading documentation and haven’t found an example of how to get back the updated Lux state st so that I can pass it to the next Lux model call as the updated state.

I’ve found this recent update to the UDE paper example, universal_differential_equations/scenario_1.jl at master · ChrisRackauckas/universal_differential_equations · GitHub, which uses both Lux and Optimization, but as far as I understand, it simply ignores the Lux state st.

If there already is a tutorial on how to correctly handle Lux-style state st (for instance, updating an explicit rng), I would be very thankful if you could point me to it.

What’s your use case for st there? It’s generally incompatible with accurate definitions of ODEs with adaptivity, so one should generally throw an error if it’s used in a neural network ODE definition, since then it’s not actually an ODE (or even an equation with a well-defined result).

I have two use cases, actually:

  1. I would like to use Optimization.jl to optimize a Lux.jl model. This includes handling the Lux state correctly. No ODEs involved.
  2. I would like to replicate the official Lux NeuralODE example. It currently uses a Zygote pullback for optimization, which I would like to replace with Optimization.jl.

For this, use p to hold st in a Ref?
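A minimal sketch of the idea, without pulling in Optimization.jl: `toy_descent` is a hypothetical stand-in that, like Optimization.jl, calls the objective as f(u, p) and only ever updates u. Because p is a Ref, the objective can write the updated state back into it, and the caller can read the final state after the run. The call counter is a placeholder for real Lux state such as an rng.

```julia
# Hypothetical stand-in for an optimizer using Optimization.jl's f(u, p) convention:
# it evaluates the objective and a gradient, but only ever updates u.
function toy_descent(f, grad, u, p; lr = 0.1, iters = 50)
    for _ in 1:iters
        f(u, p)                      # forward pass (this is where state would advance)
        u = u .- lr .* grad(u, p)
    end
    return u
end

# Placeholder "model state": a call counter standing in for e.g. an RNG.
st0 = (ncalls = 0,)

function loss(u, p)
    st = p[]                         # read the current state from the Ref
    p[] = (ncalls = st.ncalls + 1,)  # write the updated state back into the Ref
    return sum(abs2, u .- 1)
end
loss_grad(u, p) = 2 .* (u .- 1)

p = Ref(st0)
u_opt = toy_descent(loss, loss_grad, [0.0, 0.0], p)
p[]    # updated state after optimization: (ncalls = 50,)
```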

In that case, st is properly ignored just by being discarded in dudt. It’s basically the same as what’s done in

https://diffeqflux.sciml.ai/dev/examples/neural_ode/

Unfortunately, as far as I can see, neither works:

  1. If I make the optimizer's p hold the Lux st, there is no way to update it or to extract the updated st from Optimization.jl. Hence p is not generic Lux-like state (e.g. an rng), but more like a set of constants.

  2. I took a look and found that the decisive difference is the definition of dudt:

The SciML example ignores the state in the inner dudt function by explicitly adding st=st.

The Lux.jl example explicitly handles the state by not doing so:

```julia
function (n::NeuralODE)(x, ps, st)
    function dudt(u, p, t)
        u_, st = n.model(u, p, st)
        return u_
    end
[...]
```

So far, this rather seems to confirm that Zygote indeed cannot be replaced by Optimization.jl yet.

No, you just make p a Ref as mentioned above. Did you give that a try?

Yes, but that’s potentially incorrect, so it should be changed (@avikpal). Best would be to check and throw an error if the state updates.

Sorry, I somehow misread that; “Ref” clicked now 👍. Sure, this restores state by making it implicit. Got it. There is still no way to have explicit state, but this is really good to know. Do you know whether it is safe to use such implicit state within Optimization.jl?

EDIT: Thank you very much Chris for all your immediate support, thank you for recommending this workaround.

Very impressive; so ODEs really are not compatible with custom updatable state. I have no clue what is going on behind the scenes that breaks this. Is there an issue that describes the problem?

It’s just a mathematical fact about ODEs that if f(u,p,t) is not deterministic, then u' = f(u,p,t) does not necessarily have a well-defined solution. For example, think about a case with an adaptive dt where:

```julia
function f(u,p,t)
  p[] = p[] + 1
  u + p[]
end
```

That is, each time you call f you increment p[]. With a fixed-time-step integrator of step size dt that calls f once per step (e.g. explicit Euler), on the i-th step this solves the ODE u' = u + i for t in (t_i, t_i + dt). Notice then that changing dt actually changes the ODE, and thus it changes not just what the ODE solver outputs, but “the analytical solution” of the ODE. In other words, there is no convergent definition of the solution to this ODE.
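The dt-dependence can be checked numerically. The sketch below assumes a hand-written fixed-step explicit Euler loop (`euler`, hypothetical, standing in for a real ODE solver) and integrates the stateful f from above over [0, 1], with a fresh counter for each solve. The pure right-hand side u' = u + 1 converges to e - 1 as dt shrinks, whereas the stateful one grows roughly in proportion to the number of steps, so there is no limit to converge to.

```julia
# Fixed-step explicit Euler: calls f exactly once per step.
function euler(f, u0, p, T, dt)
    u = u0
    for i in 0:round(Int, T / dt)-1
        u += dt * f(u, p, i * dt)
    end
    return u
end

# The stateful right-hand side from above: every call increments p[].
stateful_f(u, p, t) = (p[] += 1; u + p[])

# A pure right-hand side for comparison: u' = u + 1, so u(1) = e - 1 for u(0) = 0.
pure_f(u, p, t) = u + 1.0

for dt in (0.1, 0.01, 0.001)
    u_state = euler(stateful_f, 0.0, Ref(0), 1.0, dt)  # fresh counter per solve
    u_pure = euler(pure_f, 0.0, nothing, 1.0, dt)
    println("dt = $dt: stateful u(1) = $u_state, pure u(1) = $u_pure")
end
```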

This then comes into play with adaptive time stepping. If you have an implicit state of this sort inside of an ODE definition, changing the tolerance may not just change the solution, but also change the analytical solution of the ODE that is being defined. Such a case does not have a well-defined convergent solution. In a mathematical sense, such a process is not even an ODE, which is why ODE solvers will often fail to handle such a process correctly.

I mentioned this (and some other related cases) in the JuliaCon talk on debugging ODE solutions.

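If you think about how ODE adaptivity works (I show an animation of it in the talk), you’ll see that step rejection (the solver trying a step, rejecting it as too inaccurate, and retrying with a smaller dt after f has already been called) will break many assumptions about a network’s state and can push things into undefined territory.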

I know some people have thrown a stateful RNN into the RHS of an ODE solver and claimed it worked, but… I have no idea how that got through review, because it would have precisely this problem. If you train the ODE and then change the tolerance, such a “network” or “layer” would actually change its solution rather than converge to something, as it’s not an ODE. There was a follow-up paper that showed this occurs in a lot of neural ODEs used in ML, and with this background it shouldn’t be surprising why that would be the case.

For implicit layers, I want to find a better way to throw an appropriate error in this case, though it’s somewhat difficult in general since there are some correct ways to use state. For example, normalization layers with state are fine if they use the same state throughout the whole ODE solve and only update it after the ODE is solved.
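A minimal sketch of the “same state throughout the solve, update afterwards” pattern, using a hypothetical fixed-step Euler loop in place of a real solver. The right-hand side only reads the state, so u' = -u + st[] is an ordinary ODE during each solve and the results converge as dt shrinks. The running-average update afterwards is a made-up rule for illustration, not what any particular normalization layer does.

```julia
# Fixed-step explicit Euler (a stand-in for a real ODE solver).
function euler(f, u0, p, T, dt)
    u = u0
    for i in 0:round(Int, T / dt)-1
        u += dt * f(u, p, i * dt)
    end
    return u
end

# Hypothetical normalization-like state: a running statistic kept in st[].
# The RHS only reads it, so u' = -u + st[] is a well-defined ODE during a solve.
rhs(u, st, t) = -u + st[]

st = Ref(1.0)
# These approach 1 - exp(-1) as dt shrinks, because st[] is frozen during each solve.
sols = [euler(rhs, 0.0, st, 1.0, dt) for dt in (0.1, 0.01, 0.001)]

# Update the state only after the solve (hypothetical running-average rule).
st[] = 0.9 * st[] + 0.1 * sols[end]
```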

Yes, and in the sense of Lux/Optimisers.jl I don’t think the “look” will change. Note, however, that, like ODE solvers, not all optimizers are essentially “one-step”. Optimisers.jl sticks to optimization algorithms like ADAM that require one call to the objective function before moving forward: this is more or less a standard assumption of machine learning libraries.

Some optimization algorithms are not like this. For example, BFGS with a line search (which is what Optim.jl uses) makes multiple objective calls per step. This is why some algorithms (like BFGS) are not stable with stochastic loss functions, and changing state within the objective presents itself as a stochastic loss, because you cannot guarantee that the loss value at a given point is always the same (in fact, with state that generally won’t be true). There are variations of BFGS that are robust to stochastic loss functions and such, but that’s a whole topic.
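To see why a line search interacts badly with state, here is a sketch that counts objective calls. It uses plain gradient descent with Armijo backtracking, a simpler stand-in and not the line search Optim.jl actually uses with BFGS, on a hypothetical quadratic objective. Any state advanced inside the objective would advance once per trial step, not once per iteration.

```julia
# Gradient descent with Armijo backtracking: a simpler stand-in for the line
# searches used with BFGS. Counts how often the objective is evaluated.
function descent_with_linesearch(f, grad, x; iters = 10, α0 = 1.0, c = 1e-4, ρ = 0.5)
    ncalls = 0
    for _ in 1:iters
        d = -grad(x)
        fx = f(x); ncalls += 1
        α = α0
        # Shrink α until the sufficient-decrease condition holds;
        # every trial step is another objective call.
        while true
            ftrial = f(x .+ α .* d); ncalls += 1
            ftrial <= fx - c * α * sum(abs2, d) && break
            α *= ρ
        end
        x = x .+ α .* d
    end
    return x, ncalls
end

quad(x) = 10 * sum(abs2, x .- 3)    # hypothetical test objective
quad_grad(x) = 20 .* (x .- 3)

x, ncalls = descent_with_linesearch(quad, quad_grad, [0.0, 0.0])
# ncalls is 60 here: 6 objective calls per step for 10 steps.
```

An Adam-style method, by contrast, evaluates the loss once per step, so a state update inside the objective lines up with the optimizer's iterations.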

So, tl;dr: for optimizers that support state of this form it will be fine, but not all optimizers actually support that kind of behavior. We need to set up a trait system to make it easier to query such details.
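As a purely hypothetical sketch of what such a trait system could look like (none of these names exist in Optimization.jl), a Holy-trait style query lets code check whether an optimizer evaluates the objective only once per step before accepting a stateful loss:

```julia
# Hypothetical trait describing how often an optimizer calls the objective per step.
abstract type ObjectiveCallStyle end
struct OneCallPerStep <: ObjectiveCallStyle end
struct MultipleCallsPerStep <: ObjectiveCallStyle end

# Hypothetical optimizer types standing in for Adam and BFGS with a line search.
struct MyAdam end
struct MyBFGS end

# Conservative default: assume several objective calls per step unless declared otherwise.
objective_call_style(::Type) = MultipleCallsPerStep()
objective_call_style(::Type{MyAdam}) = OneCallPerStep()

supports_stateful_objective(opt) = objective_call_style(typeof(opt)) isa OneCallPerStep

# Error early instead of silently optimizing a loss that behaves like a stochastic one.
function check_stateful(opt)
    supports_stateful_objective(opt) ||
        error("$(nameof(typeof(opt))) may call the objective several times per step; " *
              "a stateful loss is not safe here")
    return true
end
```

Defaulting to the conservative answer means a new optimizer has to opt in explicitly before stateful objectives are accepted.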

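The workaround with the Ref worked.

For others it might be helpful to see what this means in practice. The model and data below are placeholders, and package APIs may have changed since this was written:

```julia
using Lux, Optimization, OptimizationOptimisers, ComponentArrays, Random, Zygote

rng = Random.default_rng()
# Placeholder model and data; substitute your own.
model = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))
x = rand(rng, 2, 64)
y = sum(x; dims = 1)

ps, st = Lux.setup(rng, model)

# Returns the loss together with the updated Lux state.
function loss_function(x, y, ps, st)
    y_pred, st = Lux.apply(model, x, ps, st)
    sum(abs2, y .- y_pred), st
end

ps_trained, st_trained = let st = Ref(st), x = x, y = y
    optprob = Optimization.OptimizationProblem(
        Optimization.OptimizationFunction(
            # Optimization.jl calls the objective as f(u, p); p is unused here.
            function (ps, constants)
                # Stash the updated Lux state in the Ref so it survives between calls.
                loss, st[] = loss_function(x, y, ps, st[])
                loss
            end,
            Optimization.AutoZygote(),
        ),
        ComponentArrays.ComponentVector{Float64}(ps),
    )

    solution = Optimization.solve(
        optprob,
        OptimizationOptimisers.Adam(0.1),  # spelled ADAM in older Optimisers.jl releases
        maxiters = 500,
    )

    solution.u, st[]
end
```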

In the Lux example, the state should get updated. It won’t be type-stable due to the closure, but that is a different issue. The DiffEqFlux code currently uses st=st, which means the local st gets updated, and nothing is reflected in the higher scope (Lux compatible layers by Abhishek-1Bhatt · Pull Request #750 · SciML/DiffEqFlux.jl · GitHub).
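The scoping difference can be reproduced without Lux. In the sketch below, `toy_model` is a hypothetical stand-in for a Lux layer whose state counts calls. When the inner function assigns to the captured st, Julia rebinds the enclosing variable (boxing it, which is also where the type instability comes from); with the keyword default st = st, st is a fresh local on every call, so updates never reach the enclosing scope.

```julia
# Hypothetical stand-in for a Lux layer: returns output and a state whose
# counter is incremented on every call.
toy_model(u, p, st) = (p .* u, (ncalls = st.ncalls + 1,))

# Lux-example style: assigning to the captured `st` rebinds the enclosing variable,
# so the state advances on every RHS call (Julia boxes `st`, hurting type stability).
function make_dudt_rebinding(st)
    function dudt(u, p, t)
        u_, st = toy_model(u, p, st)
        return u_
    end
    current_st() = st
    return dudt, current_st
end

# DiffEqFlux style: the keyword default `st = st` makes `st` local to each call,
# so the assignment in the body is discarded and the enclosing `st` never changes.
function make_dudt_local(st)
    function dudt(u, p, t; st = st)
        u_, st = toy_model(u, p, st)
        return u_
    end
    current_st() = st
    return dudt, current_st
end

d1, st1 = make_dudt_rebinding((ncalls = 0,))
d2, st2 = make_dudt_local((ncalls = 0,))
for t in (0.0, 0.1, 0.2)
    d1([1.0], [2.0], t)
    d2([1.0], [2.0], t)
end
st1()   # (ncalls = 3,)
st2()   # (ncalls = 0,)
```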

To add to the point about st: in most cases, when using it with neural ODEs or any differential equation, the state should not change (changing the state will most likely cause undesirable dynamics). The only exceptions would be cases like VariationalDropout, where the state is updated (exclusively) on the first call and preserved for the entire dynamics (precisely to ensure “stability”).
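A minimal pure-Julia sketch of the “update only on the first call” pattern. `DropState` and `vdrop` are hypothetical and only mimic the idea behind VariationalDropout (they are not Lux's implementation): the dropout mask is sampled the first time the right-hand side is called and then reused, so every later call during the same solve sees the same deterministic function.

```julia
using Random

# Hypothetical state for a variational-dropout-like layer.
mutable struct DropState
    rng::MersenneTwister
    mask::Union{Nothing,BitVector}
end

# Sample the mask only if none exists yet (the first call of a solve), then reuse it.
function vdrop(u, st::DropState; p = 0.5)
    if st.mask === nothing
        st.mask = rand(st.rng, length(u)) .> p
    end
    return u .* st.mask ./ (1 - p)
end

# Right-hand side of a toy ODE using the layer.
dudt(u, st, t) = -vdrop(u, st)

st = DropState(MersenneTwister(1), nothing)
u = ones(4)
a = dudt(u, st, 0.0)
b = dudt(u, st, 0.5)   # same mask as the first call, so the RHS stays deterministic
mask_used = copy(st.mask)
st.mask = nothing      # reset before the next trajectory, so a new mask is drawn then
```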
