Error when training Deep Equilibrium Models (DEQ)

I read the blog post “Composability in Julia: Implementing Deep Equilibrium Models via Neural ODEs”.

I tried to reproduce its code for training on the MNIST dataset, but training fails with the following error:

```
ERROR: MethodError: no method matching needs_concrete_A(::Nothing)
Closest candidates are:
  needs_concrete_A(::LinearSolve.AbstractFactorization) at julia\packages\LinearSolve\tyXtB\src\LinearSolve.jl:34
  needs_concrete_A(::LinearSolve.AbstractKrylovSubspaceMethod) at julia\packages\LinearSolve\tyXtB\src\LinearSolve.jl:35
  needs_concrete_A(::LinearSolve.AbstractSolveFunction) at julia\packages\LinearSolve\tyXtB\src\LinearSolve.jl:36
```
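For context, the LinearSolve frames in this trace presumably come from the backward pass. Assuming DiffEqSensitivity's default sensitivity algorithm for a `SteadyStateProblem` is `SteadyStateAdjoint`, the gradient is not obtained by backpropagating through the solver iterations; instead it solves one linear system with the Jacobian at the equilibrium. The sketch below shows that implicit-gradient computation for a hypothetical toy layer `tanh.(W * u)` (the names `layer`, `equilibrium`, `loss` and `implicit_grad` are illustrative, not from the blog post or any package). It uses plain fixed-point iteration and Julia's `\` instead of LinearSolve, so it illustrates the math, not the library's code path.

```julia
using LinearAlgebra

# Hypothetical toy DEQ layer: f(u; W) = tanh.(W * u).
# With a small enough W the map z -> f(z + x) is a contraction,
# so a unique equilibrium z* = f(z* + x) exists.
layer(W, u) = tanh.(W * u)

# Find the equilibrium by plain fixed-point iteration (a stand-in for DynamicSS).
function equilibrium(W, x; tol = 1e-14, maxiters = 10_000)
    z = layer(W, x)                      # same initial guess as the listing: f(x)
    for _ in 1:maxiters
        znew = layer(W, z .+ x)
        norm(znew .- z) <= tol && return znew
        z = znew
    end
    return z
end

# Scalar loss of the equilibrium: L = sum(z*).
loss(W, x) = sum(equilibrium(W, x))

# Implicit (adjoint) gradient of L with respect to W.
# Differentiating z* = f(z* + x; W) gives (I - J) dz = (∂f/∂W) dW,
# with J = ∂f/∂u at u = z* + x. The adjoint vector λ solves
# (I - J)' λ = ∂L/∂z; this linear solve is what a linear solver is needed for.
function implicit_grad(W, x)
    z = equilibrium(W, x)
    u = z .+ x
    s = 1 .- layer(W, u) .^ 2            # tanh'(W * u), elementwise
    J = Diagonal(s) * W                  # Jacobian of f with respect to u
    g = ones(length(z))                  # ∂L/∂z for L = sum(z)
    λ = (I - J)' \ g                     # adjoint linear solve
    return (λ .* s) * u'                 # ∂L/∂W[i, j] = λ[i] * s[i] * u[j]
end

# Usage: compare one entry with a central finite difference.
W = [0.1 * sin(i + 2j) for i in 1:4, j in 1:4]   # small weights => contraction
x = [0.1, -0.2, 0.3, 0.4]
G = implicit_grad(W, x)
h = 1e-5
E = zeros(4, 4); E[2, 3] = 1.0
fd = (loss(W .+ h .* E, x) - loss(W .- h .* E, x)) / (2h)
println((G[2, 3], fd))   # the two values should agree to several digits
```

Note that the linear system has the size of the hidden state, so its cost (and the choice of solver) grows with the width of the layer.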

I reduced the original code to the following example:

```julia
using Flux, Statistics
using DiffEqSensitivity, OrdinaryDiffEq, SteadyStateDiffEq

struct DeepEquilibriumNetwork{M,P,RE,A,K}
    model::M
    p::P
    re::RE
    args::A
    kwargs::K
end

Flux.@functor DeepEquilibriumNetwork

function DeepEquilibriumNetwork(model, args...; kwargs...)
    p, re = Flux.destructure(model)
    return DeepEquilibriumNetwork(model, p, re, args, kwargs)
end

Flux.trainable(deq::DeepEquilibriumNetwork) = (deq.p,)

function (deq::DeepEquilibriumNetwork)(x::AbstractArray{T},
                                       p = deq.p) where {T}
    z = deq.re(p)(x)
    # Solving the equation f(u) - u = du = 0
    dudt(u, _p, t) = deq.re(_p)(u .+ x) .- u
    ssprob = SteadyStateProblem(ODEProblem(dudt, z, (zero(T), one(T)), p))
    return solve(ssprob, deq.args...; u0 = z, deq.kwargs...).u
end

function Net()
    return Chain(
        DeepEquilibriumNetwork(Chain(Dense(100, 500, tanh), Dense(500, 100)),
                               DynamicSS(Tsit5(), abstol = 1f-2, reltol = 1f-2)),
    )
end

function loss_function()
    return mean(abs, deq(rand(100,1)) .- rand(100,1))
end

deq = Net();
data = Iterators.repeated((), 10);
Flux.train!(loss_function, Flux.params(deq), data, ADAM(0.01))
```
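The `SteadyStateProblem` above looks for a state `z` with `du/dt = f(z + x) - z = 0` by integrating the ODE until it stops moving. A minimal hand-rolled sketch of that idea follows, assuming forward Euler in place of `Tsit5()` and a hypothetical toy layer `tanh.(W * v)` in place of the Flux `Chain`; the helper name `euler_steady_state` and its parameters are illustrative, not part of DiffEq or the blog post.

```julia
using LinearAlgebra: norm

# Hypothetical stand-in for DynamicSS: integrate du/dt = f(u + x) - u with
# forward Euler (not Tsit5) until the norm of the derivative drops below abstol.
function euler_steady_state(f, x; u0 = f(x), dt = 0.5, abstol = 1e-8, maxsteps = 100_000)
    u = u0                                # same initial guess as the listing: f(x)
    for _ in 1:maxsteps
        du = f(u .+ x) .- u               # same right-hand side as `dudt` above
        norm(du) <= abstol && return u    # steady state reached
        u = u .+ dt .* du                 # one explicit Euler step
    end
    return u                              # may not have converged within maxsteps
end

W = [0.1 * cos(i - j) for i in 1:5, j in 1:5]   # small weights, so f is contractive
f(v) = tanh.(W * v)
x = collect(range(-1.0, 1.0; length = 5))
z = euler_steady_state(f, x)
println(norm(f(z .+ x) .- z))   # residual of z = f(z + x); at most abstol
```

Because the stopping test is on `du` itself, the returned state satisfies the equilibrium equation to within `abstol`, which is the same criterion the loose `abstol = 1f-2` in the listing controls.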

I just found that if I make the DeepEquilibriumNetwork 10 times smaller (Dense(10, 50) and Dense(50, 10) instead of Dense(100, 500) and Dense(500, 100)), it trains without the error:

```julia
function Net()
    return Chain(
        DeepEquilibriumNetwork(Chain(Dense(10, 50, tanh), Dense(50, 10)),
                               DynamicSS(Tsit5(), abstol = 1f-2, reltol = 1f-2)),
    )
end
```

I am using Julia v1.7. Thank you for any help!

You should provide a minimal working example; without one, no one can help you.

You should use SciML/DeepEquilibriumNetworks.jl (“Deep Equilibrium Networks (but faster!!!)”, on GitHub) instead of the code in the blog post. Updating the blog post has been on my to-do list for quite some time.

Thank you for the advice! I have just edited my post to add one.

Thank you!

I am now able to run the code using the package.