What are the differences between allowed parameter types in DiffEqFlux.jl versus DifferentialEquations.jl?
The documentation for DifferentialEquations states the following:
"Note that the type for the parameters p can be anything: you can use arrays, static arrays, named tuples, etc. to enclose your parameters in a way that is sensible for your problem."
For DiffEqFlux, though, this doesn't seem to be the case. Based on this issue (and my own testing), it appears that automatic differentiation with DiffEqFlux only works when p is an N-dimensional array. In particular, making p an array of arrays or an array of mutable structs seems to cause errors when trying to compute gradients via Zygote or ForwardDiff.
I couldn't find anything in the DiffEqFlux documentation directly stating that p is restricted to being an N-dimensional array. I was wondering if someone could confirm this or clarify how to make additional types of p work with DiffEqFlux.
If it's helpful, here's an MWE (adapted from this issue post):
Software versions:
Julia: v1.7.2
DifferentialEquations: v7.1.0
DiffEqFlux: v1.45.3
Zygote: v0.6.38
ForwardDiff: v0.10.25
using Flux, DiffEqFlux, OrdinaryDiffEq, Zygote, ForwardDiff
# ODE functions
function f1!(dx, x, p, t)
    dx[1] = p[1, 1]
    dx[2] = p[2, 1]
end
p1 = [1. 5.; 5. 1.]
function f2!(dx, x, p, t)
    dx[1] = p[1][1]
    dx[2] = p[2][1]
end
p2 = [[1.], [5.]]
function f3!(dx, x, p, t)
    dx[1] = p.one
    dx[2] = p.two
end
mutable struct two_p
    one::Float64
    two::Float64
end
# Define length(), getindex(), and iterate() to remove (some) Zygote errors
Base.length(X::two_p) = 2
Base.getindex(X::two_p, i::Int) = begin
    if i == 1
        return X.one
    elseif i == 2
        return X.two
    else
        # Throw the error; constructing it alone would silently return it
        throw(BoundsError(X, i))
    end
end
Base.iterate(X::two_p, state=1) = begin
    if 1 <= state <= 2
        return (X[state], state+1)
    elseif state > 2
        return nothing
    else
        error("invalid iteration state: $state")
    end
end
p3 = two_p(1.,5.)
# ODE Problem setup
x0 = [1., 2.]
tspan = (0., 2.)
prob1 = ODEProblem(f1!, x0, tspan, p1)
prob2 = ODEProblem(f2!, x0, tspan, p2)
prob3 = ODEProblem(f3!, x0, tspan, p3)
function predict_adjoint1(p)
    Array(concrete_solve(prob1, Tsit5(), x0, p))
end
function predict_adjoint2(p)
    Array(concrete_solve(prob2, Tsit5(), x0, p))
end
function predict_adjoint3(p)
    Array(concrete_solve(prob3, Tsit5(), x0, p))
end
# Loss functions
function loss_adjoint1(p)
    prediction = predict_adjoint1(p)
    loss = sum(abs2, prediction[:, end] .- 1)
    loss
end
function loss_adjoint2(p)
    prediction = predict_adjoint2(p)
    loss = sum(abs2, prediction[:, end] .- 1)
    loss
end
function loss_adjoint3(p)
    prediction = predict_adjoint3(p)
    loss = sum(abs2, prediction[:, end] .- 1)
    loss
end
# Compute Gradients
Zygote.gradient(loss_adjoint1,p1) # Works
ForwardDiff.gradient(loss_adjoint1,p1) # Works
Zygote.gradient(loss_adjoint2,p2) # ERROR: MethodError: no method matching Float64(::Vector{Float64})
ForwardDiff.gradient(loss_adjoint2,p2) # ERROR: MethodError: no method matching one(::Type{Vector{Float64}})
Zygote.gradient(loss_adjoint3,p3) # ERROR: type Array has no field one
ForwardDiff.gradient(loss_adjoint3,p3) # ERROR: MethodError: no method matching gradient(::typeof(loss_adjoint3), ::two_p)
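A possible workaround, sketched under the assumption that a flat numeric array really is the only form of p the gradient code accepts: keep the nested layout for readability, but pass a flat vector to the solver and rebuild the nested view inside the ODE function. The names flatten_params, unflatten_params, LENS, and f2_flat! are hypothetical helpers written for this illustration; they are not part of DiffEqFlux or DifferentialEquations. The sketch uses only Base Julia, and it has not been confirmed that gradients through the solver behave as hoped.

```julia
# Hypothetical helpers (not part of any SciML package).

# Concatenate an array of arrays into one flat vector.
flatten_params(p::Vector{<:AbstractVector}) = reduce(vcat, p)

# Rebuild an array-of-arrays view from a flat vector, given each block's length.
function unflatten_params(flat::AbstractVector, lens)
    stops = cumsum(lens)
    starts = stops .- lens .+ 1
    return [view(flat, a:b) for (a, b) in zip(starts, stops)]
end

p2 = [[1.0], [5.0]]
const LENS = length.(p2)          # block lengths, fixed for the problem
p2_flat = flatten_params(p2)      # flat Vector{Float64}

# Same dynamics as f2!, but p arrives as a flat vector.
function f2_flat!(dx, x, p, t)
    q = unflatten_params(p, LENS)
    dx[1] = q[1][1]
    dx[2] = q[2][1]
end
```

With this layout, the problem would be built as ODEProblem(f2_flat!, x0, tspan, p2_flat), so the differentiation tools only ever see a plain vector, as in the f1! case that already works. A struct such as two_p could presumably be handled the same way by converting its fields to a vector first.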