I read the blog post “Composability in Julia: Implementing Deep Equilibrium Models via Neural ODEs”.
I tried to reproduce its code for training on the MNIST dataset, but it throws an error during training:
ERROR: MethodError: no method matching needs_concrete_A(::Nothing)
Closest candidates are:
  needs_concrete_A(::LinearSolve.AbstractFactorization) at julia\packages\LinearSolve\tyXtB\src\LinearSolve.jl:34        
  needs_concrete_A(::LinearSolve.AbstractKrylovSubspaceMethod) at julia\packages\LinearSolve\tyXtB\src\LinearSolve.jl:35 
  needs_concrete_A(::LinearSolve.AbstractSolveFunction) at julia\packages\LinearSolve\tyXtB\src\LinearSolve.jl:36     
I simplified the original code to the following minimal example:
using Flux, Statistics
using DiffEqSensitivity, OrdinaryDiffEq, SteadyStateDiffEq
"""
    DeepEquilibriumNetwork{M, P, RE, A, K}

Container for a model together with everything needed to solve for its
steady state when the network is called.
"""
struct DeepEquilibriumNetwork{M, P, RE, A, K}
    # The model exactly as it was handed to the constructor.
    model::M
    # Flat parameters obtained from `Flux.destructure(model)`.
    p::P
    # Rebuilder obtained from `Flux.destructure(model)`; `re(p)` yields a callable model.
    re::RE
    # Positional arguments splatted into `solve` on every forward pass.
    args::A
    # Keyword arguments splatted into `solve` on every forward pass.
    kwargs::K
end

# Register the struct with Flux's functor machinery.
Flux.@functor DeepEquilibriumNetwork
"""
    DeepEquilibriumNetwork(model, args...; kwargs...)

Build a `DeepEquilibriumNetwork` around `model`.

The model is destructured once with `Flux.destructure`; the resulting flat
parameters and rebuilder are stored next to the original model. `args` and
`kwargs` are stored untouched and later forwarded to `solve`.
"""
function DeepEquilibriumNetwork(model, args...; kwargs...)
    flat_params, rebuild = Flux.destructure(model)
    return DeepEquilibriumNetwork(model, flat_params, rebuild, args, kwargs)
end
# Expose only the flat parameter field `p` as trainable.
function Flux.trainable(deq::DeepEquilibriumNetwork)
    return (deq.p,)
end
# Forward pass: solve a steady-state problem for input `x`.
#
# `p` defaults to the stored flat parameters `deq.p`. The model is rebuilt from
# the parameters via `deq.re`, applied once to `x` to get the starting state, and
# the steady-state solution's `u` field is returned.
function (deq::DeepEquilibriumNetwork)(x::AbstractArray{T}, p = deq.p) where {T}
    # One application of the rebuilt model serves as the initial state.
    u_init = deq.re(p)(x)

    # Residual of the fixed-point equation: f(u) - u = du = 0.
    # The model is rebuilt from `θ` (not the outer `p`) inside the closure.
    residual(u, θ, t) = deq.re(θ)(u .+ x) .- u

    tspan = (zero(T), one(T))
    ode = ODEProblem(residual, u_init, tspan, p)
    steady = SteadyStateProblem(ode)

    # Stored positional/keyword arguments are forwarded to the solver call.
    sol = solve(steady, deq.args...; u0 = u_init, deq.kwargs...)
    return sol.u
end
# Build the network: a `Chain` holding a single `DeepEquilibriumNetwork` that
# wraps a 100 -> 500 -> 100 dense model and a `DynamicSS(Tsit5())` solver.
#
# NOTE(review): the reported `needs_concrete_A(::Nothing)` MethodError suggests
# that a `nothing` linear solver reaches LinearSolve; presumably passing an
# explicit `sensealg` with a concrete `linsolve` as a keyword here (keywords are
# forwarded to `solve`) would avoid it. Confirm against the DiffEqSensitivity docs.
function Net()
    inner = Chain(Dense(100, 500, tanh), Dense(500, 100))
    solver = DynamicSS(Tsit5(); abstol = 1.0f-2, reltol = 1.0f-2)
    return Chain(DeepEquilibriumNetwork(inner, solver))
end
# Mean absolute difference between the output of the global `deq` on a random
# 100x1 input and a random 100x1 target. Takes no arguments.
function loss_function()
    input = rand(100, 1)
    prediction = deq(input)
    # The target is drawn after the model call, keeping the original order.
    target = rand(100, 1)
    return mean(abs, prediction .- target)
end
# Instantiate the network; `loss_function` reads this global.
deq = Net()

# Ten empty tuples, so `loss_function` is called with no arguments ten times.
data = Iterators.repeated((), 10)

ps = Flux.params(deq)
opt = ADAM(0.01)
Flux.train!(loss_function, ps, data, opt)
I just found that if I make the DeepEquilibriumNetwork 10 times smaller, it trains without giving an error:
# Smaller variant of the network: a `Chain` holding a single
# `DeepEquilibriumNetwork` that wraps a 10 -> 50 -> 10 dense model and a
# `DynamicSS(Tsit5())` solver.
function Net()
    inner = Chain(Dense(10, 50, tanh), Dense(50, 10))
    solver = DynamicSS(Tsit5(); abstol = 1.0f-2, reltol = 1.0f-2)
    return Chain(DeepEquilibriumNetwork(inner, solver))
end
I am using Julia v1.7. Thank you for any help!