I read the blog post “Composability in Julia: Implementing Deep Equilibrium Models via Neural ODEs”.
I tried to reproduce its code for training on the MNIST dataset, but training fails with the following error:
ERROR: MethodError: no method matching needs_concrete_A(::Nothing)
Closest candidates are:
needs_concrete_A(::LinearSolve.AbstractFactorization) at julia\packages\LinearSolve\tyXtB\src\LinearSolve.jl:34
needs_concrete_A(::LinearSolve.AbstractKrylovSubspaceMethod) at julia\packages\LinearSolve\tyXtB\src\LinearSolve.jl:35
needs_concrete_A(::LinearSolve.AbstractSolveFunction) at julia\packages\LinearSolve\tyXtB\src\LinearSolve.jl:36
I reduced the original code to the following minimal example:
using Flux, Statistics
using DiffEqSensitivity, OrdinaryDiffEq, SteadyStateDiffEq
"""
    DeepEquilibriumNetwork(model, p, re, args, kwargs)

Layer whose forward pass solves for a steady state `u` of
`du/dt = model(u .+ x) .- u` for the input `x`, using `solve` with the stored
`args`/`kwargs`.

# Fields
- `model`: the wrapped Flux model.
- `p`: flat parameter vector of `model` (the field exposed via `Flux.trainable`).
- `re`: callable that rebuilds the model from a flat parameter vector.
- `args`: positional arguments forwarded to `solve` (e.g. a steady-state algorithm).
- `kwargs`: keyword arguments forwarded to `solve`.
"""
struct DeepEquilibriumNetwork{Model,Params,Rebuild,Args,Kwargs}
    model::Model
    p::Params
    re::Rebuild
    args::Args
    kwargs::Kwargs
end

# Register the type with Flux's functor machinery; which fields are actually
# optimised is narrowed by the `Flux.trainable` method for this type.
Flux.@functor DeepEquilibriumNetwork
"""
    DeepEquilibriumNetwork(model, args...; kwargs...)

Wrap `model` as a deep-equilibrium layer. The model is flattened with
`Flux.destructure` into a parameter vector and a rebuild function. `args` and
`kwargs` are stored verbatim and later forwarded to `solve`.
"""
DeepEquilibriumNetwork(model, args...; kwargs...) =
    # `Flux.destructure` yields `(flat_params, rebuild)`; splat them into the
    # `p` and `re` slots of the field-wise constructor.
    DeepEquilibriumNetwork(model, Flux.destructure(model)..., args, kwargs)
# Report only the flat parameter vector `p` as trainable. `model`, `re`,
# `args` and `kwargs` are not collected by `Flux.params`.
function Flux.trainable(deq::DeepEquilibriumNetwork)
    trainables = (deq.p,)
    return trainables
end
"""
    (deq::DeepEquilibriumNetwork)(x::AbstractArray{T}, p = deq.p)

Return the steady state `u` of `du/dt = f(u .+ x) .- u`, where `f` is the
model rebuilt from parameters `p`. The solve starts from one forward pass
`f(x)` and is performed by `solve(..., deq.args...; deq.kwargs...)`.
"""
function (deq::DeepEquilibriumNetwork)(x::AbstractArray{T}, p = deq.p) where {T}
    # Initial guess for the equilibrium: a single application of the model.
    u_init = deq.re(p)(x)

    # Residual whose zero is the fixed point u = f(u + x). It rebuilds the model
    # from whatever parameters the solver passes in, not the captured `p`.
    residual(u, θ, t) = deq.re(θ)(u .+ x) .- u

    tspan = (zero(T), one(T))
    problem = SteadyStateProblem(ODEProblem(residual, u_init, tspan, p))
    sol = solve(problem, deq.args...; u0 = u_init, deq.kwargs...)
    return sol.u
end
"""
    Net(n::Integer = 100, hidden::Integer = 5n; kwargs...)

Build a `Chain` holding one `DeepEquilibriumNetwork`. Its inner model is
`Chain(Dense(n, hidden, tanh), Dense(hidden, n))`, and it is solved with
`DynamicSS(Tsit5(), abstol = 1f-2, reltol = 1f-2)`.

`Net()` reproduces the 100 → 500 → 100 network and `Net(10)` the 10 → 50 → 10
variant. The inner model must map width `n` back to `n`, because its output is
fed back in as `u .+ x`.

Extra keyword arguments are stored in the layer and forwarded to `solve`. One
example is a sensitivity algorithm: `Net(; sensealg = ...)`.

NOTE(review): with `n = 100`, training fails with
`MethodError: no method matching needs_concrete_A(::Nothing)`, while `n = 10`
works. This looks like the steady-state adjoint querying
`LinearSolve.needs_concrete_A` on a default `linsolve = nothing` once the state
exceeds some size threshold. Two possible workarounds, both unverified (confirm
against the installed DiffEqSensitivity version):
- pass an explicit `sensealg = SteadyStateAdjoint(linsolve = <LinearSolve algorithm>)`
  through `kwargs`;
- upgrade DiffEqSensitivity.
"""
function Net(n::Integer = 100, hidden::Integer = 5n; kwargs...)
    inner = Chain(Dense(n, hidden, tanh), Dense(hidden, n))
    solver = DynamicSS(Tsit5(); abstol = 1f-2, reltol = 1f-2)
    return Chain(DeepEquilibriumNetwork(inner, solver; kwargs...))
end
"""
    loss_function(n::Integer = 100)

Return the mean absolute difference between the output of the global model
`deq` on a random `n × 1` input and an independent random `n × 1` target.

`n` must equal the input/output width of the network bound to `deq`: 100 for
`Net()` as originally defined, 10 for the smaller variant. Hard-coding 100 made
the smaller network fail with a dimension mismatch.

Inputs are drawn as `Float32`. This presumably matches the eltype of Flux's
default-initialised `Dense` parameters (confirm for the installed Flux version),
so the ODE solve and its adjoint are not promoted to `Float64`.
"""
function loss_function(n::Integer = 100)
    x = rand(Float32, n, 1)
    target = rand(Float32, n, 1)
    return mean(abs, deq(x) .- target)
end
# Build the network; `deq` is a global that `loss_function` reads directly.
deq = Net();
# Ten elements, each an empty tuple. Presumably (per Flux's `train!` splatting
# each datum into the loss) every step calls `loss_function()` with no arguments.
data = Iterators.repeated((), 10);
# ADAM with learning rate 0.01 on the parameters collected from `deq`.
# This call is where the reported `needs_concrete_A(::Nothing)` MethodError is
# raised during training.
Flux.train!(loss_function, Flux.params(deq), data, ADAM(0.01))
I also found that if I make the DeepEquilibriumNetwork 10 times smaller, as below, it trains without any error…
# Smaller variant of `Net()`: a single DeepEquilibriumNetwork whose inner model
# maps width 10 → 50 → 10. It uses the same `DynamicSS(Tsit5())` steady-state
# solver with abstol = reltol = 1f-2.
function Net()
    inner = Chain(Dense(10, 50, tanh), Dense(50, 10))
    solver = DynamicSS(Tsit5(); abstol = 1f-2, reltol = 1f-2)
    layer = DeepEquilibriumNetwork(inner, solver)
    return Chain(layer)
end
I am using Julia v1.7. Thank you for any help!