Error in training Deep Equilibrium Models (DEQ)

I read the blog post “Composability in Julia: Implementing Deep Equilibrium Models via Neural ODEs”.

I tried to reproduce its code to train on the MNIST dataset, but training fails with the following error:

ERROR: MethodError: no method matching needs_concrete_A(::Nothing)
Closest candidates are:
  needs_concrete_A(::LinearSolve.AbstractFactorization) at julia\packages\LinearSolve\tyXtB\src\LinearSolve.jl:34        
  needs_concrete_A(::LinearSolve.AbstractKrylovSubspaceMethod) at julia\packages\LinearSolve\tyXtB\src\LinearSolve.jl:35 
  needs_concrete_A(::LinearSolve.AbstractSolveFunction) at julia\packages\LinearSolve\tyXtB\src\LinearSolve.jl:36     

I reduced the original code to the following minimal example:

using Flux, Statistics
using DiffEqSensitivity, OrdinaryDiffEq, SteadyStateDiffEq

struct DeepEquilibriumNetwork{M,P,RE,A,K}
    model::M
    p::P
    re::RE
    args::A
    kwargs::K
end

Flux.@functor DeepEquilibriumNetwork

function DeepEquilibriumNetwork(model, args...; kwargs...)
    p, re = Flux.destructure(model)
    return DeepEquilibriumNetwork(model, p, re, args, kwargs)
end

Flux.trainable(deq::DeepEquilibriumNetwork) = (deq.p,)

function (deq::DeepEquilibriumNetwork)(x::AbstractArray{T},
                                       p = deq.p) where {T}
    z = deq.re(p)(x)
    # Solving the equation f(u) - u = du = 0
    dudt(u, _p, t) = deq.re(_p)(u .+ x) .- u
    ssprob = SteadyStateProblem(ODEProblem(dudt, z, (zero(T), one(T)), p))
    return solve(ssprob, deq.args...; u0 = z, deq.kwargs...).u
end

function Net()
    return Chain(
        DeepEquilibriumNetwork(Chain(Dense(100, 500, tanh), Dense(500, 100)),
                               DynamicSS(Tsit5(), abstol = 1f-2, reltol = 1f-2)),
    )
end

function loss_function()
    return mean(abs, deq(rand(100,1)) .- rand(100,1))
end

deq = Net();
data = Iterators.repeated((), 10);
Flux.train!(loss_function, Flux.params(deq), data, ADAM(0.01))

I also found that if I make the DeepEquilibriumNetwork ten times smaller, it trains without giving an error:

function Net()
    return Chain(
        DeepEquilibriumNetwork(Chain(Dense(10, 50, tanh), Dense(50, 10)),
                               DynamicSS(Tsit5(), abstol = 1f-2, reltol = 1f-2)),
    )
end

I am using Julia v1.7. Thank you for any help!

You should provide a minimal working example. Without it, no one can help you.

You should use the SciML/DeepEquilibriumNetworks.jl package on GitHub ("Deep Equilibrium Networks (but faster!!!)") instead of the code in the blog post. Updating the blog post has been on my to-do list for quite some time.

Thank you for the advice! I just modified it.

Thank you!

I am now able to run the code with the package.
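
For readers following the thread: the comment in the original code, "Solving the equation f(u) - u = du = 0", describes a fixed-point problem. The layer looks for a state u such that f(u + x) = u, and the SteadyStateProblem finds it by integrating du/dt = f(u + x) - u until du vanishes. Below is a minimal sketch of the same idea in plain Julia, without any SciML packages. It swaps the ODE-based steady-state solve for simple fixed-point iteration, which is not what the blog post or DeepEquilibriumNetworks.jl does. The names `W`, `b`, `f`, and `fixed_point` are hypothetical, and the weights are deliberately scaled small so that the iteration contracts.

```julia
using LinearAlgebra  # for norm

# Hypothetical toy "layer": an affine map followed by tanh.
# W is scaled down (assumption) so that f is a contraction and the
# iteration below converges; a trained network gives no such guarantee.
W = 0.1 .* [0.5 -0.2 0.1; 0.3 0.4 -0.6; -0.1 0.2 0.7]
b = [0.1, -0.2, 0.3]
f(v) = tanh.(W * v .+ b)

# Find u with f(u .+ x) == u by repeatedly applying the map.
function fixed_point(f, x; tol = 1e-8, maxiter = 1000)
    u = f(x)  # same initial guess as `z = deq.re(p)(x)` in the post
    for i in 1:maxiter
        u_next = f(u .+ x)
        residual = norm(u_next .- u)  # norm of du = f(u + x) - u
        u = u_next
        if residual < tol
            return (u, i)
        end
    end
    error("fixed-point iteration did not converge in $maxiter iterations")
end

x = [1.0, -2.0, 0.5]
u_star, iters = fixed_point(f, x)
```

At the returned `u_star`, the right-hand side `f(u_star .+ x) .- u_star` is close to zero, which is the steady-state condition the original `dudt` encodes.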