Monte Carlo problem (distributed computing, DifferentialEquations pkg) with ForwardDiff AD for parameter estimation

I am trying to use a MonteCarloProblem (distributed computing) from the DifferentialEquations package for parameter estimation, with gradients computed by ForwardDiff.

It works when I define the objective function directly in my code, but it fails when I wrap it in another function that returns the objective function; see the minimal example below. Any ideas where the problem lies and how to solve it?

using Distributed
addprocs([("dhcp229", 2)])
@everywhere using DifferentialEquations, LinearAlgebra
tspan = (0.0,1.0)
@everywhere ode_f = (u,p,t)->p[1]*u
prob = ODEProblem(ode_f,0.5,tspan)

@everywhere function prob_func(prob,i,repeat,x) # x is the new parameter vector
    p_new = [x...]
    x0 = prob.u0
    tspan = prob.tspan
    f = ode_f

    return ODEProblem(f, convert.(eltype(p_new), x0),convert.(eltype(p_new), tspan),convert.(eltype(p_new), p_new))
end

monte_prob = MonteCarloProblem(prob,prob_func=(prob,i,repeat) -> prob_func(prob,i,repeat,[1.]))
ndata = 10
xdata = range(tspan[1], stop=tspan[2],length=ndata)
ydata = ones(ndata)

@everywhere function f(x)
    MCprob = MonteCarloProblem(monte_prob.prob, output_func=monte_prob.output_func, prob_func = (prob,i,repeat) -> prob_func(prob,i,repeat,x), reduction = monte_prob.reduction, u_init = monte_prob.u_init)

    sol = solve(MCprob; num_monte=1, saveat=xdata)

    norm(ydata-sol[1].u)

end

function wrapper_fcn()
    return function f(x)
        MCprob = MonteCarloProblem(monte_prob.prob, output_func=monte_prob.output_func, prob_func = (prob,i,repeat) -> prob_func(prob,i,repeat,x), reduction = monte_prob.reduction, u_init = monte_prob.u_init) 

        sol = solve(MCprob; num_monte=1, saveat=xdata)

        norm(ydata-sol[1].u)

    end
end

##

f2 = wrapper_fcn()
x0 = [1.]
f2(x0)

using ForwardDiff
g = x->ForwardDiff.gradient(f, x)
g2 = x->ForwardDiff.gradient(f2, x)
g(x0)  # THIS WORKS!
g2(x0) # THIS DOES NOT WORK!

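This is a tough one. The problem is that ForwardDiff marks its dual numbers with a tag type that depends on the function being differentiated. In the normal case that function type lives in Main (Main.f2), but once it is serialized and sent to the other computer for distributed parallelism, the type is recreated in the Serialization.__deserialized_types__ module, so it gets a different tag. That is presumably also why f works: it is defined with @everywhere, so the same function type already exists on every worker, whereas the closure returned by wrapper_fcn is only defined on the main process.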
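
To make the tag dependence concrete, here is a minimal single-process sketch. The functions `f` and `g` below are hypothetical stand-ins and not the objective from the question. `ForwardDiff.Tag` is not exported, so treat that call as an implementation detail that may change between versions. The second half uses the documented option of passing `nothing` to `GradientConfig`, which builds a config that is not tied to a particular function type. I have not checked whether this is enough to fix the distributed case, and it weakens ForwardDiff's protection against perturbation confusion.

```julia
using ForwardDiff

# Two hypothetical functions with identical bodies still have distinct types.
f(x) = sum(abs2, x)
g(x) = sum(abs2, x)

# The tag is parameterized by the function type and the element type,
# so different function types give different tag types.
tag_f = ForwardDiff.Tag(f, Float64)
tag_g = ForwardDiff.Tag(g, Float64)
same_tag = typeof(tag_f) == typeof(tag_g)

# Passing `nothing` instead of the target function gives a config that can be
# reused with any function, at the cost of weaker tag checking.
x = [1.0, 2.0]
cfg = ForwardDiff.GradientConfig(nothing, x)
grad = ForwardDiff.gradient(f, x, cfg)

println("tags match: ", same_tag)
println("gradient:   ", grad)
```

Here `same_tag` is `false` even though `f` and `g` compute the same thing, which is the same kind of mismatch that appears when a closure type is recreated under a different module on a worker.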
