Adapting automatic differentiation to functions with a user-defined argument type

I’m using my own module in which models can be summarized as functions y = f(θ) of parameters θ, where θ has a specific type P. I can’t change that without rewriting the whole module, which I can’t afford at the moment. I’ve written my own Jacobian function, but I now want to use automatic differentiation. The problem is that ForwardDiff’s jacobian only works with functions of vectors. I’ve tried to adapt things as in the following MWE:

using ForwardDiff

mutable struct P
    a::Float64
    b::Float64
end

## convert parameters to vector (fieldnames takes the type in Julia 1.x)
make_vector(θ::P) = [getfield(θ, z) for z in fieldnames(P)]

## toy model
function f(θ::P)
    x = make_vector(θ)
    return sum(x) > 0.0 ? sin.(x) : cos.(x)
end

g(x::Vector) = f(P(x...))
df(θ::P) = ForwardDiff.jacobian(g, make_vector(θ))
x = [1.0, 2.0]
θ = P(x...)
dfdθ = df(θ)

This fails, apparently because of the way jacobian works in ForwardDiff, with the following error:

ERROR: LoadError: MethodError: Cannot `convert` an object of type ForwardDiff.Dual{ForwardDiff.Tag{#g,Float64},Float64,2} to an object of type Float64
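The failure can be reproduced without jacobian: the default constructor of a struct with a Float64 field tries to convert its argument to Float64, and ForwardDiff defines no conversion from a Dual to Float64 (it would discard the derivative part). A minimal sketch, assuming Julia 1.x and a recent ForwardDiff, whose error text may differ from the one above; PFloat is a hypothetical stand-in for P:

```julia
using ForwardDiff

# Hypothetical minimal type with a hard-coded Float64 field, like P in the MWE.
struct PFloat
    a::Float64
end

# A dual number with value 1.0 and a single partial derivative equal to 1.0.
d = ForwardDiff.Dual(1.0, 1.0)

# Constructing PFloat(d) requires converting d to Float64, which has no method,
# so the constructor throws a MethodError; we capture it instead of crashing.
err = try
    PFloat(d)
    nothing
catch e
    e
end
println(typeof(err))
```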

Is there a way to overcome this while, of course, keeping the signature of my models with the P parameters? Many thanks,

Don’t hard-code Float64 in your struct. Parameterize on the type instead.
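A minimal sketch of that suggestion applied to the MWE, assuming Julia 1.x and a current ForwardDiff: P gets a type parameter T<:Real, so P(x...) can hold the Dual numbers that jacobian passes to g, while f keeps its P-typed signature.

```julia
using ForwardDiff

# Parameterize the field type instead of hard-coding Float64, so that P can
# store ForwardDiff.Dual numbers during differentiation.
struct P{T<:Real}
    a::T
    b::T
end

## convert parameters to vector (element type follows T)
make_vector(θ::P) = [getfield(θ, z) for z in fieldnames(typeof(θ))]

## toy model, unchanged: it works for any element type T
function f(θ::P)
    x = make_vector(θ)
    return sum(x) > 0 ? sin.(x) : cos.(x)
end

# Vector-argument wrapper, as ForwardDiff.jacobian expects.
g(x::AbstractVector) = f(P(x...))
df(θ::P) = ForwardDiff.jacobian(g, make_vector(θ))

θ = P(1.0, 2.0)
J = df(θ)   # 2×2 Jacobian of f with respect to (a, b)
println(J)
```

Run it in a fresh session, since it redefines P. Here sum(x) > 0, so f returns sin.(x) and the Jacobian is diagonal with entries cos(1) and cos(2).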


This indeed solved my current problem, so I at least need to remove these constraints from my models. I’ll see whether this is sufficient; I may run into other problems if some of my parameters are, say, Array{Float64} instead, even if that type isn’t hard-coded in the struct. Anyway, thank you very much; I’ll keep you posted if I run into new problems in the more general case.
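For array-valued parameters, one option is to parameterize the array’s element type as well and to flatten and rebuild the parameters by hand. A sketch under the assumption that all fields share a single element type T; the type Q, the helpers to_vector and from_vector, and the toy model are hypothetical:

```julia
using ForwardDiff

# Hypothetical parameter type mixing a scalar and an array-valued field; both
# share the element type T so that Dual numbers can flow through.
struct Q{T<:Real}
    a::T
    w::Vector{T}
end

# Flatten Q into one vector: [a; w...]
to_vector(θ::Q) = [θ.a; θ.w]

# Rebuild a Q from a flat vector, using a template to know the length of w.
from_vector(x::AbstractVector, template::Q) = Q(x[1], x[2:1+length(template.w)])

# Toy model: y = a .* w
model(θ::Q) = θ.a .* θ.w

θ = Q(2.0, [1.0, 3.0])
J = ForwardDiff.jacobian(x -> model(from_vector(x, θ)), to_vector(θ))
println(J)
```

The columns of J correspond to (a, w[1], w[2]); writing the field as Vector{T} rather than Vector{Float64} is what lets from_vector build a Q of Duals.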

The important point is to keep your code generic so that it works with all subtypes of <: Real (including as element types of arrays). This can frequently be accomplished by

  1. simply letting Julia figure out the type (for mapping, collections, etc.),
  2. if you can’t do that, using constructs like similar (for creating arrays), zero, one and oftype (just look at the manual).
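A small illustration of the second point, assuming Julia 1.x and ForwardDiff; the functions cumulative_squares and leaky are hypothetical. A hard-coded buffer such as Vector{Float64}(undef, n) would fail for Dual inputs in the same way as the original struct, whereas similar, zero and oftype take their type from the arguments:

```julia
using ForwardDiff

# similar(x) allocates an array with x's element type (Float64, Dual, ...),
# and zero(eltype(x)) gives an accumulator of that same type.
function cumulative_squares(x::AbstractVector{<:Real})
    y = similar(x)
    acc = zero(eltype(x))
    for i in eachindex(x)
        acc += x[i]^2
        y[i] = acc
    end
    return y
end

# oftype(v, 0.01) converts the literal to v's type, so Float32 stays Float32
# and a Dual stays a Dual. Assumes floating-point-like inputs (not Int).
leaky(v::Real) = v > zero(v) ? v : oftype(v, 0.01) * v

J = ForwardDiff.jacobian(cumulative_squares, [1.0, 2.0, 3.0])
dl = ForwardDiff.derivative(leaky, -2.0)
println(J)
println(dl)
```

Both functions work unchanged for plain Float64 inputs and for the Dual numbers that ForwardDiff passes in.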