It’s just a mathematical fact about ODEs that if f(u,p,t)
is not deterministic then u' = f(u,p,t)
does not necessarily have a well-defined solution. For example, think about a case with an adaptive dt
where:
function f(u,p,t)
p[] = p[] + 1
u + p[]
end
i.e. each time you call f
you increment p
. If you have a fixed time step integrator with dt
, then this solves the ODE f(u,p,t) = u + i
for t in (t,t+dt)
. Notice then that changing dt
actually changes the ODE, and thus it changes not just what the ODE solver spits out, but “the analytical solution” of the ODE. In other words, there is no convergent definition to the solution to this ODE.
This then comes into play with adaptive time stepping. If you have an implicit state of this sort inside of an ODE definition, changing the tolerance may not just change the solution, but also change the analytical solution of the ODE that is being defined. Such a case does not have a well-defined convergent solution. In a mathematical sense, such a process is not even an ODE, which is why ODE solvers will many times fail to handle such a process correctly.
I mentioned this (and some other related cases) in the JuliaCon talk on debugging ODE solutions:
If you think about how ODE adaptivity works (as I show an animation in there), you’ll see that rejection sampling will break many assumptions of a network state and can throw things into the undefined area.
I know some people have thrown a stateful RNN into the RHS of an ODE solver and claimed it worked, but… I have no idea how that go threw reviews because it would have precisely this problem. If you train the ODE and then change the tolerance, such a “network” or “layer” would actually change its solution, not converge to something, as it’s not an ODE. There was a follow-up paper that showed this occurs on a lot of neural ODEs used in ML, and with this background it shouldn’t be surprising why that would be the case.
For implicit layer things, I want to find a better way to throw an appropriate error for this case, though it’s a little bit difficult in general since there are some correct ways to do it. For example, norm layers with a state are fine, if they use the same state throughout the whole ODE and only update after the ODE is solved.
Yes, and in the sense of Lux/Optimisers.jl I don’t think the “look” will change. You do have to note that, like ODE solvers, not all optimizers are essentially “one-step”. Optimisers.jl sticks to optimization algorithms like ADAM that require one call to the objective function before moving forwards: this is kind of a standard assumption of machine learning libraries.
Some optimization algorithms are not like this. For example, BFGS with line searching (which is what Optim.jl uses) will have multiple objective calls for a step. This is the reason why some algorithms (like BFGS) are not stable with stochastic loss functions, and changing state within the objective presents itself as a stochastic loss because you cannot guarantee the loss value at a given point is always the same (in fact, if there’s state that generally won’t be true). There are variations of BFGS which are robust to stochastic loss functions and such, but that’s a whole topic.
So tl;dr, for optimizers that support state of this form it will be fine, but not all optimizers actually support that kind of behavior. We need to setup a trait system to make it easier to query such details.