Hi everyone! I'm learning about Deep Equilibrium (DEQ) models. The following code is copied from @ChrisRackauckas's reply in a topic (sorry, I forgot which one). Its purpose is to train a DEQ model to fit y = 2x.
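For intuition before the code: a DEQ layer outputs the fixed point u* satisfying f(u* + x) = u*. Here is a package-free toy sketch of that idea (the scalar affine f and all numbers are made up for illustration, not the trained network below):

```julia
# Toy illustration (no packages): a DEQ layer outputs the u* solving
# f(u* + x) = u*. Here f is a hypothetical scalar affine map.
f(z) = 0.5 * z + 1.0

# Naive Picard iteration u <- f(u + x); it converges here because f is a
# contraction (slope 0.5 < 1).
function fixed_point(f, x; iters = 100)
    u = 0.0
    for _ in 1:iters
        u = f(u + x)
    end
    return u
end

u_star = fixed_point(f, 2.0)
# At the fixed point, the residual f(u + x) - u (what dudt_ below encodes)
# is essentially zero.
residual = f(u_star + 2.0) - u_star
```

The real code below does the same thing, except f is a neural network and the fixed point is found by a steady-state ODE solve instead of naive iteration.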
using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
using CUDA
using NonlinearSolve
CUDA.allowscalar(false)
# FastChain with initial_params should also work
# But it's more stable to use Chain and destructure
# ann = FastChain(
#   FastDense(1, 2, relu),
#   FastDense(2, 1, tanh))
# p1 = initial_params(ann)
ann = Chain(Dense(1, 2), Dense(2, 1)) |> gpu
p,re = Flux.destructure(ann)
tspan = (0.0f0, 1.0f0)
function solve_ss(x)
    xg = gpu(x)
    z = re(p)(xg) |> gpu
    function dudt_(u, _p, t)
        # Solving the equation f(u) - u = du = 0
        # Key question: Is there any difference between
        # re(_p)(x) and re(_p)(u+x)?
        re(_p)(u+xg) - u
    end
    ss = SteadyStateProblem(ODEProblem(dudt_, gpu(z), tspan, p))
    x = solve(ss, DynamicSS(Tsit5()), u0 = z, abstol = 1e-2, reltol = 1e-2).u
  
    # ss = NonlinearProblem(dudt_, gpu(z), p)
    # x = solve(ss, NewtonRaphson(), tol = 1e-6).u
    
end
# Let's run a DEQ model on linear regression for y = 2x
X = [1;2;3;4;5;6;7;8;9;10]
Y = [2;4;6;8;10;12;14;16;18;20]
data = Flux.Data.DataLoader(gpu.(collect.((X, Y))), batchsize=1,shuffle=true)
opt = ADAM(0.05)
function loss(x, y)
  ŷ = solve_ss(x)
  sum(abs2,y .- ŷ)
end
epochs = 100
for i in 1:epochs
    Flux.train!(loss, Flux.params(p), data, opt)
    println(solve_ss([-5])) # Print model prediction
end
I tried to run this code, but got the following warnings:
Warning: Instability detected. Aborting
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/6gTvi/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/6gTvi/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/7GnZA/src/integrator_interface.jl:325
┌ Warning: First function call produced NaNs. Exiting.
If I instead try to obtain the equilibrium by root-finding on the nonlinear equation (the commented-out NonlinearProblem lines), I get the following error:
ERROR: LoadError: MethodError: no method matching (::var"#dudt_#8"{CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}})(::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
Closest candidates are:
  (::var"#dudt_#8")(::Any, ::Any, ::Any)
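For what it's worth, the MethodError itself looks like a pure arity mismatch: the solver calls the residual with two arguments, while dudt_ is defined ODE-style with three (u, p, t). A minimal package-free reproduction of that mismatch:

```julia
# Reproduce the arity mismatch without CUDA or NonlinearSolve: dudt_ is
# defined ODE-style as f(u, p, t), but the nonlinear solver calls f(u, p).
dudt_(u, p, t) = p .* u .- u

err = try
    dudt_([1.0], [2.0])   # two-argument call, as in the error above
    nothing
catch e
    e
end
# err isa MethodError, matching "no method matching (::var"#dudt_#8")(...)"
```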
How should I adjust the code to make it work?
In addition, I would like to ask the key question raised by @ChrisRackauckas: is there any difference between `re(_p)(x)` and `re(_p)(u+x)`?
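To make the question concrete, here is a toy comparison with a hypothetical affine f (made-up numbers, not the trained network): the two residuals define different equations, so in general they have different solutions:

```julia
# Compare the two candidate residuals, with f(z) = 0.5z + 1 and x = 2.
f(z) = 0.5 * z + 1.0
x = 2.0

# Residual f(x) - u = 0: u does not feed back into f, so the "equilibrium"
# is just one forward pass, u* = f(x).
u_explicit = f(x)        # 2.0

# Residual f(u + x) - u = 0: u feeds back, giving a genuinely implicit
# layer. Solving 0.5(u + x) + 1 = u by hand gives u* = x + 2.
u_implicit = x + 2.0     # 4.0
# Check it really is a fixed point of u -> f(u + x):
@assert f(u_implicit + x) - u_implicit == 0.0
```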
Thanks. 