What to do when the DEQ model does not converge?

Hi everyone! I am learning about deep equilibrium (DEQ) models. The following code is copied from @ChrisRackauckas's reply in another topic (sorry, I forgot which one). Its purpose is to train a DEQ model to fit y = 2x.

using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
using CUDA
using NonlinearSolve
CUDA.allowscalar(false)

# FastChain with initial_params should also work
# But it's more stable to use Chain and destructure
# ann = FastChain(
#   FastDense(1, 2, relu),
#   FastDense(2, 1, tanh))
# p1 = initial_params(ann)

ann = Chain(Dense(1, 2), Dense(2, 1)) |> gpu
p,re = Flux.destructure(ann)
tspan = (0.0f0, 1.0f0)

function solve_ss(x)
    xg = gpu(x)
    z = re(p)(xg) |> gpu
    function dudt_(u, _p, t)
        # Solving the equation f(u) - u = du = 0
        # Key question: Is there any difference between
        # re(_p)(x) and re(_p)(u+x)?
        re(_p)(u+xg) - u
    end
    # Integrate the ODE to steady state (stop when du ≈ 0)
    ss = SteadyStateProblem(ODEProblem(dudt_, gpu(z), tspan, p))
    x = solve(ss, DynamicSS(Tsit5()), u0 = z, abstol = 1e-2, reltol = 1e-2).u

    # Alternative: find the root of f(u) - u = 0 directly
    # ss = NonlinearProblem(dudt_, gpu(z), p)
    # x = solve(ss, NewtonRaphson(), tol = 1e-6).u
    
end

# Let's run a DEQ model on linear regression for y = 2x
X = [1;2;3;4;5;6;7;8;9;10]
Y = [2;4;6;8;10;12;14;16;18;20]
data = Flux.Data.DataLoader(gpu.(collect.((X, Y))), batchsize=1, shuffle=true)
opt = ADAM(0.05)

function loss(x, y)
  ŷ = solve_ss(x)
  sum(abs2,y .- ŷ)
end

epochs = 100
for i in 1:epochs
    Flux.train!(loss, Flux.params(p), data, opt)
    println(solve_ss([-5])) # Print the model's prediction at x = -5 (target: -10)
end

I tried to run this code, but got the following warnings:

Warning: Instability detected. Aborting
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/6gTvi/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/6gTvi/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/7GnZA/src/integrator_interface.jl:325
┌ Warning: First function call produced NaNs. Exiting.

If I instead use a root-finding method on the nonlinear equations to obtain the equilibrium (the commented-out NonlinearProblem lines above), I get the following error:

ERROR: LoadError: MethodError: no method matching (::var"#dudt_#8"{CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}})(::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
Closest candidates are:
  (::var"#dudt_#8")(::Any, ::Any, ::Any)

How should I adjust the code to make it work?

In addition, I would like to ask about the key question raised by @ChrisRackauckas: is there any difference between re(_p)(x) and re(_p)(u+x)?

Thanks. :grinning:


Yeah… we mentioned in the blog post that one issue with the DEQ approach is that not all dynamical systems have steady states. x_{n+1} = f(x_n) or u' = f(u) could just blow up to infinity as n \to \infty or t \to \infty, and that is the case for you here. What you need to do is ensure that your system has a steady state to begin with. The easy way to do that is to force some sufficient conditions for steady-state behavior. With ODEs, if f'(u_{ss}) < 0 near some u_{ss} for which f(u_{ss}) = 0, then u_{ss} is a locally stable steady state (there are similar contraction results that hold on discrete systems as well). Thus what you can do is try to make the RHS have a smaller derivative to start with, say:

    function dudt_(u, _p, t)
        re(_p)(u) / 100 - u
    end

would have a higher chance of having steady-state behavior for a randomly initialized neural network than not having the normalization. Then of course the training may compensate by multiplying by 100, but basically all you need is to make sure it starts: if the first solve is unstable, then it can never train, so it's really just forcing there to be a starting point.
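For example, dropped into the solve_ss from the original post, the tempering could look like this (a sketch; the factor of 100 is arbitrary, and the Float32 tolerances just match the GPU element type):

    function solve_ss(x)
        xg = gpu(x)
        z = re(p)(xg) |> gpu
        function dudt_(u, _p, t)
            # Tempered RHS: shrinking the network output makes the -u term
            # dominate near initialization, so a steady state is more likely
            re(_p)(u .+ xg) ./ 100f0 .- u
        end
        ss = SteadyStateProblem(ODEProblem(dudt_, gpu(z), tspan, p))
        solve(ss, DynamicSS(Tsit5()), abstol = 1f-2, reltol = 1f-2).u
    end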

What do you mean by that?

If you don’t have the u in there, then the steady state is just u = NN(x), so you don’t need the numerical solver and can just call the NN.
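In code terms (using the names from the original post):

    # If the RHS were re(_p)(xg) - u (no u inside the network), then
    # du = 0 gives u = re(_p)(xg) in closed form, so no solver is needed:
    u_ss = re(p)(xg)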


Even if f'(u_{ss}) < 0 and f(u_{ss}) = 0, there is no guarantee that the solution of the IVP

u' = f(u), \quad u(0) = u_0

will converge to u_{ss} as t \to \infty. So you may need to pay some attention to the initial data u_0.

The condition f'(u_{ss}) < 0 is a big deal. There is no reason to think that a solution of f(u) = 0 is a stable steady-state solution without checking f'.

Yes indeed, that’s why I said “higher chance” of going to a steady state. There are still many other reasons why it might not converge, but tempering the derivative of the f part (since the RHS is f(u) - u) is one way to make it more likely to have a large basin of attraction. Let me expand for the OP.

One simple way to see this is to take the simplest nonlinear ODE: u' = u^2 - u. If 0 < u(0) < 1, then u^2 < u and thus u → 0 (and for u(0) < 0 we have u' > 0, so the solution also moves toward 0). If u(0) > 1, then u' > 0 and u'' > 0, and so you get blowup to infinity really quickly. My suggestion is to temper the nonlinearity, which in this case gives u' = u^2 / 100 - u; that increases the basin of attraction (the set of starting values that go to zero) to any u(0) < 100.
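You can check this numerically with a quick sketch (using OrdinaryDiffEq, which the code above already loads):

    using OrdinaryDiffEq

    f(u, p, t) = u^2 / 100 - u

    # Inside the basin of attraction (u(0) < 100): decays to the
    # stable steady state at 0
    sol_in = solve(ODEProblem(f, 50.0, (0.0, 20.0)), Tsit5())
    @show sol_in.u[end]    # ≈ 0

    # Outside the basin (u(0) > 100): u' > 0 everywhere, and the
    # solver aborts with the same instability warning as above
    sol_out = solve(ODEProblem(f, 150.0, (0.0, 20.0)), Tsit5())
    @show sol_out.retcode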


Thank you very much for your explanation! The initial value does have an effect on convergence, and tempering the nonlinearity is intuitively feasible. On the other hand, is it worth discussing how to give the model better initial parameters? From the perspective of fixed-point iteration, whether the iteration converges to a fixed point comes down to whether the spectral radius of the corresponding iteration matrix is less than 1, so we would only need to pick initial parameters from the set of model parameters satisfying that constraint. Of course, finding such parameters is not trivial, but techniques like homotopy algorithms may be helpful.
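For instance, a rough check of that condition at initialization might look like this (a sketch on the CPU; ForwardDiff and LinearAlgebra are assumptions of mine, not part of the original code):

    using Flux, ForwardDiff, LinearAlgebra

    ann_cpu = Chain(Dense(1, 2), Dense(2, 1))   # CPU copy of the model
    x0 = Float32[1.0]
    u0 = ann_cpu(x0)                            # initial guess, as in solve_ss

    # Jacobian of the fixed-point map u ↦ NN(u + x) at u0; the iteration
    # is locally contractive when the spectral radius is below 1
    J = ForwardDiff.jacobian(u -> ann_cpu(u .+ x0), u0)
    @show maximum(abs, eigvals(J))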

What do you mean by that?

I mean if I change

    ss = SteadyStateProblem(ODEProblem(dudt_, gpu(z), tspan, p))
    x = solve(ss, DynamicSS(Tsit5()), u0 = z, abstol = 1f-2, reltol = 1f-2).u

to

    ss = NonlinearProblem(dudt_, gpu(z), p)
    x = solve(ss, NewtonRaphson(), maxiters=1e6).u

I get the method-matching error shown above.

Yeah, I think there’s more work to be done on finding better initializations for this kind of thing.

NonlinearProblem wants a function that is either out-of-place, f(u, p), or in-place, f(du, u, p).
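Concretely, inside solve_ss that could look something like this (a sketch reusing xg and z from the original code):

    # NonlinearProblem expects f(u, p) rather than the ODE form f(u, p, t)
    f_ss(u, _p) = re(_p)(u .+ xg) .- u

    ss = NonlinearProblem(f_ss, gpu(z), p)
    x = solve(ss, NewtonRaphson()).u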

Okay, let me try!
Thanks