Parameter Types in DiffEqFlux.jl versus DifferentialEquations.jl

What are the differences between allowed parameter types in DiffEqFlux.jl versus DifferentialEquations.jl?

The documentation for DifferentialEquations states the following:

Note that the type for the parameters p can be anything: you can use arrays, static arrays, named tuples, etc. to enclose your parameters in a way that is sensible for your problem.

For DiffEqFlux, though, this doesn’t seem to be the case. Based on this issue (and my own testing), it appears that automatic differentiation with DiffEqFlux only works when p is an N-dimensional array of numbers. In particular, making p an array of arrays or an array of mutable structs seems to cause errors when trying to compute gradients via Zygote or ForwardDiff.

I couldn’t find anything in the DiffEqFlux documentation directly stating that p was restricted to being an N-dimensional array. I was wondering if someone could confirm this or clarify how to make additional types of p work with DiffEqFlux.

If it’s helpful, here’s a MWE (adapted from this issue post):

Software versions:

Julia: v1.7.2
DifferentialEquations: v7.1.0
DiffEqFlux: v1.45.3
Zygote: v0.6.38
ForwardDiff: v0.10.25

using Flux, DiffEqFlux, OrdinaryDiffEq, ForwardDiff, Zygote

# ODE functions

function f1!(dx, x, p, t)
  dx[1] = p[1, 1]
  dx[2] = p[2, 1]
end
p1 = [1. 5.; 5. 1.]


function f2!(dx, x, p, t)
  dx[1] = p[1][1]
  dx[2] = p[2][1]
end
p2 = [[1.], [5.]]


function f3!(dx, x, p, t)
  dx[1] = p.one
  dx[2] = p.two
end

mutable struct two_p
  one::Float64
  two::Float64
end

# Define length(), getindex(), and iterate() to remove (some) Zygote errors
Base.length(X::two_p) = 2
Base.getindex(X::two_p, i::Int) = begin
    if i == 1
        return X.one
    elseif i == 2
        return X.two
    else
        throw(BoundsError(X, i))
    end
end
Base.iterate(X::two_p, state=1) = begin
    if 1 <= state <= 2
        return (X[state], state+1)
    elseif state > 2
        return nothing
    else
        throw(BoundsError(X, state))
    end
end

p3 = two_p(1.,5.)


# ODE Problem setup

x0 = [1., 2.]
tspan = (0., 2.)
prob1 = ODEProblem(f1!, x0, tspan, p1)
prob2 = ODEProblem(f2!, x0, tspan, p2)
prob3 = ODEProblem(f3!, x0, tspan, p3)


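# concrete_solve(prob, alg, u0, p) is the older API; newer releases use solve(prob, alg; u0 = u0, p = p)
function predict_adjoint1(p) 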
  Array(concrete_solve(prob1, Tsit5(), x0, p))
end

function predict_adjoint2(p) 
  Array(concrete_solve(prob2, Tsit5(), x0, p))
end

function predict_adjoint3(p) 
  Array(concrete_solve(prob3, Tsit5(), x0, p))
end

# Loss functions

function loss_adjoint1(p)
  prediction = predict_adjoint1(p)
  loss = sum(abs2, prediction[:,end] .-1)
  loss
end

function loss_adjoint2(p)
  prediction = predict_adjoint2(p)
  loss = sum(abs2, prediction[:,end] .-1)
  loss
end

function loss_adjoint3(p)
  prediction = predict_adjoint3(p)
  loss = sum(abs2, prediction[:,end] .-1)
  loss
end


# Compute Gradients

Zygote.gradient(loss_adjoint1,p1) # Works
ForwardDiff.gradient(loss_adjoint1,p1) # Works

Zygote.gradient(loss_adjoint2,p2) # ERROR: MethodError: no method matching Float64(::Vector{Float64})
ForwardDiff.gradient(loss_adjoint2,p2) # ERROR: MethodError: no method matching one(::Type{Vector{Float64}})

Zygote.gradient(loss_adjoint3,p3) # ERROR: type Array has no field one 
ForwardDiff.gradient(loss_adjoint3,p3) # ERROR: MethodError: no method matching gradient(::typeof(loss_adjoint3), ::two_p)
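One workaround for the `p2` case, sketched here under the assumption that only the numeric values need gradients, is to keep the parameters in a single flat `Vector{Float64}` and rebuild the nested layout only where it is needed. The helpers `flatten_params` and `unflatten_params` and the function `f2_flat!` are hypothetical names, not part of DiffEqFlux.

```julia
# Hypothetical helper (not a DiffEqFlux API): concatenate a vector of numeric
# vectors into one flat vector and record each piece's length.
function flatten_params(p::AbstractVector{<:AbstractVector{<:Real}})
    lens = length.(p)
    return reduce(vcat, p), lens
end

# Hypothetical inverse: rebuild the nested layout as views into the flat vector.
function unflatten_params(flat::AbstractVector, lens::AbstractVector{<:Integer})
    stops = cumsum(lens)
    starts = stops .- lens .+ 1
    return [view(flat, a:b) for (a, b) in zip(starts, stops)]
end

# f2! rewritten to index the flat vector: p2[1][1] -> p[1], p2[2][1] -> p[2]
function f2_flat!(dx, x, p, t)
    dx[1] = p[1]
    dx[2] = p[2]
    return nothing
end

p2 = [[1.0], [5.0]]
p2_flat, p2_lens = flatten_params(p2)   # p2_flat == [1.0, 5.0], p2_lens == [1, 1]
```

In the MWE, `ODEProblem(f2_flat!, x0, tspan, p2_flat)` would then give AD a plain numeric array, the same situation as `p1`.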

Related?


If I remember correctly, DiffEqFlux works with ComponentArrays for the parameters. I don’t think a struct would work, and I am also pretty sure that a struct won’t work as the parameters in DifferentialEquations.jl either. I would have to think about it more, but I believe the gradient tape has difficulty tracing back through operations on the struct.


Yep, ComponentArrays should work here. If it doesn’t, please open an issue!
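A minimal sketch of that suggestion for the `p3` case, assuming ComponentArrays.jl is installed: a `ComponentArray` is a flat numeric `AbstractVector` that also supports named field access, so `f3!` can stay unchanged while AD tools see an ordinary array. Whether a particular DiffEqFlux version differentiates through it without issues is not checked here.

```julia
using ComponentArrays

# Same right-hand side as f3! in the question: parameters are read by name.
function f3!(dx, x, p, t)
    dx[1] = p.one
    dx[2] = p.two
    return nothing
end

# A ComponentArray takes the place of the mutable struct two_p: it is a flat
# Vector{Float64} underneath, with named access layered on top.
p3 = ComponentArray(one = 1.0, two = 5.0)

dx = zeros(2)
f3!(dx, [1.0, 2.0], p3, 0.0)   # dx now holds [1.0, 5.0]
```

In the MWE this would be passed as `ODEProblem(f3!, x0, tspan, p3)`, with the gradient taken with respect to `p3`.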


Sorry to insist, but is this the proposed solution for the issue referenced above too? If so, I’ll try that. Thanks!

Thanks for this information; this is helpful to know! It makes sense that it’s not possible to set p equal to a struct.

I guess what’s confused me is that DifferentialEquations.jl works perfectly fine with p being a Vector{Any} containing a mix of scalars, arrays, structs, and even functions. But taking the gradient when p is a Vector{Any} doesn’t seem to work.

As an example, let’s say we have the following code:

using DifferentialEquations, DiffEqFlux, Zygote, ForwardDiff

function f!(du,u,p,t)
    du[1] = -p[1]'*u
    du[2] = (p[2].a + p[2].b)*u[2]
    du[3] = p[3](u,t)
    return nothing
end

struct mystruct
    a
    b
end

function control(u,t)
    return -exp(-t)*u[3]
end


u0 = [10,15,20]
p = [[1;2;3], mystruct(-1,-2), control]
tspan = (0.0,10.0)

prob = ODEProblem(f!,u0, tspan, p)

sol = solve(prob, Tsit5()) # Solves without errors

This code runs without errors, which is great! The parameter vector p contains an array, a struct, and even a function, and everything works perfectly fine.

However, let’s say we want to define a loss function with respect to the first entry of p (i.e., the entry [1;2;3]) and take the gradient:

function loss(p1)
    sol = solve(prob, Tsit5(), p=[p1, mystruct(-1,-2), control])
    return sum(abs2, sol)
end

grad(p) = Zygote.gradient(loss, p)

p2 = [4;5;6]
grad(p2) # ERROR: MethodError: no method matching Int64(::Vector{Int64})

Even though DifferentialEquations handled the ODE solving just fine, Zygote crashes when taking the gradient for the first entry of p. Using ForwardDiff also results in an error:

gradF(p) = ForwardDiff.gradient(loss,p)
gradF(p2) # ERROR: TypeError: in typeassert, expected Float64, got a value of type ForwardDiff.Dual{Nothing, Float64, 3}
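The `TypeError` above is consistent with a common pitfall: the solver's state arrays are built from `u0` with plain `Float64` (or `Int`) elements, and `ForwardDiff.Dual` numbers cannot be stored in them. A minimal sketch of the usual workaround, assuming OrdinaryDiffEq.jl and ForwardDiff.jl are installed: keep only numbers in `p` (the struct values and control law are inlined as constants here, which is an assumption about what should stay fixed) and convert `u0` to the parameters' element type inside the loss.

```julia
using OrdinaryDiffEq, ForwardDiff

# Only the numeric vector k is the ODE parameter. The struct fields and the
# control law from the question are fixed constants here:
# (a + b) = (-1) + (-2) = -3 from mystruct(-1, -2), and control(u, t) = -exp(-t) * u[3].
function rhs!(du, u, k, t)
    du[1] = -(k' * u)
    du[2] = -3.0 * u[2]
    du[3] = -exp(-t) * u[3]
    return nothing
end

# Float64 initial state (the question uses the Int vector [10, 15, 20]).
prob = ODEProblem(rhs!, [10.0, 15.0, 20.0], (0.0, 10.0), [1.0, 2.0, 3.0])

function loss(k)
    # Convert u0 to k's element type so the solver's state arrays can hold
    # ForwardDiff.Dual numbers when ForwardDiff passes a Dual-valued k.
    u0 = convert.(eltype(k), prob.u0)
    sol = solve(prob, Tsit5(); u0 = u0, p = k)
    return sum(abs2, Array(sol))
end

g = ForwardDiff.gradient(loss, [4.0, 5.0, 6.0])
```

Newer DiffEqBase versions may promote `u0` automatically when `p` holds Dual numbers; in that case the explicit conversion is simply redundant.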

Since DifferentialEquations works with a p of type Vector{Any}, it’s difficult to tell whether the Zygote/ForwardDiff errors are due to the type of p or some other problem with the function f!.
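One way to separate the two concerns, sketched under the assumption that only the vector `[1;2;3]` needs a gradient, is to keep the struct and the control function out of `p` entirely by capturing them in a closure, so that `p` becomes a plain numeric vector. `make_rhs` is a hypothetical helper name.

```julia
struct mystruct
    a
    b
end

control(u, t) = -exp(-t) * u[3]

# Hypothetical helper: capture the pieces that should NOT be differentiated
# (the struct and the control function) in a closure, so that the ODE
# parameter p can be a plain numeric vector (the former p[1]).
function make_rhs(s, ctrl)
    return function (du, u, p, t)
        du[1] = -(p' * u)
        du[2] = (s.a + s.b) * u[2]
        du[3] = ctrl(u, t)
        return nothing
    end
end

rhs! = make_rhs(mystruct(-1.0, -2.0), control)

# Evaluate the right-hand side once at t = 0 with p = [1, 2, 3]
du = zeros(3)
rhs!(du, [10.0, 15.0, 20.0], [1.0, 2.0, 3.0], 0.0)   # du == [-100.0, -45.0, -20.0]
```

With `ODEProblem(rhs!, u0, tspan, [1.0, 2.0, 3.0])`, the parameter is an ordinary `Vector{Float64}`, so if Zygote or ForwardDiff still fail, the container type of `p` can be ruled out as the cause.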

(Granted, I could be doing something wrong–I’d love to know if I’m missing something here)

I added much better error messages in this PR:

That should give a lot more clarity here.

Note that the current interface is roughly “requires AbstractArray”. In reality there is a bit more generality than that, and there could be more still with a SciMLParameters interface, which just hasn’t been written down and fully described yet. I plan to create this interface package rather soon, though.
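As a rough rule of thumb consistent with the examples in this thread (my reading, not a documented interface), the parameter container that differentiates cleanly is an `AbstractArray` whose elements are numbers; `p2` fails even though it is an `Array`, because its elements are themselves arrays. The predicate `ad_friendly_params` below is hypothetical and only illustrates that distinction.

```julia
# Hypothetical predicate, not part of SciMLBase or DiffEqFlux: true when p is
# an array of real numbers, the case the thread reports as differentiable.
ad_friendly_params(p) = p isa AbstractArray{<:Real}

mutable struct two_p
    one::Float64
    two::Float64
end

p1 = [1.0 5.0; 5.0 1.0]                        # Matrix{Float64}
p2 = [[1.0], [5.0]]                            # Vector{Vector{Float64}}
p3 = two_p(1.0, 5.0)                           # mutable struct
p4 = Any[[1, 2, 3], two_p(-1.0, -2.0), sin]    # mixed Vector{Any}

results = map(ad_friendly_params, (p1, p2, p3, p4))   # (true, false, false, false)
```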
