Network ODE Model; Type error with ForwardDiff Dual

Hello,

I am new to Turing and am trying to use it for my research on network ODE models.
I have started with a simple model (network diffusion):

$$\frac{d\mathbf{u}}{dt} = -\rho \mathbf{L} \mathbf{u}$$

and am trying to infer the value of $\rho$.
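Gradient-based samplers such as NUTS need the sensitivity $\mathbf{s} = \partial\mathbf{u}/\partial\rho$. For this linear model, assuming $\mathbf{L}$ is constant and $\mathbf{u}_0$ does not depend on $\rho$, both the solution and its sensitivity have closed forms (a worked derivation for reference):

$$
\mathbf{u}(t) = e^{-\rho t \mathbf{L}}\,\mathbf{u}_0,
\qquad
\frac{d\mathbf{s}}{dt} = -\mathbf{L}\mathbf{u} - \rho\mathbf{L}\mathbf{s},\quad \mathbf{s}(0) = \mathbf{0}
\;\Longrightarrow\;
\mathbf{s}(t) = -t\,\mathbf{L}\,e^{-\rho t \mathbf{L}}\,\mathbf{u}_0 .
$$

Forward-mode AD effectively integrates $\mathbf{u}$ and $\mathbf{s}$ together by storing each state entry as a dual number $(u_i, s_i)$, so the state itself, and therefore $\mathbf{u}_0$, has to be able to hold `Dual` values.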

Below is the code for the ODE problem:

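```julia
using DifferentialEquations

# L: the 5×5 sparse graph Laplacian of the network, assumed to be defined already
NetworkDiffusion(u, p, t) = -p * L * u
u0 = [0.9, 0.1, 0.1, 0.1, 0.1];
p = 1.0;
prob = ODEProblem(NetworkDiffusion, u0, (0.0, 2.0), p);
sol = solve(prob, saveat = 0.1);
data = Array(sol)
```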

(Here `L` is a 5×5 `SparseArrays.SparseMatrixCSC{Int64,Int64}` with 25 stored entries.)
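To get a feel for the model without a solver, here is a minimal sketch that assumes a hypothetical `L` (the post only gives its type): the Laplacian of the complete graph on 5 nodes, which has exactly 25 nonzero entries. Because the ODE is linear, its solution is `exp(-ρ t L) * u0`, and because each column of an undirected graph's Laplacian sums to zero, `sum(u)` is conserved while the state relaxes toward the mean of `u0`. The name `diffusion_exact` is illustrative only.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical network: the complete graph on 5 nodes (the post does not give L).
n = 5
A = ones(Int, n, n) - I                            # adjacency matrix, no self-loops
L = sparse(Diagonal(vec(sum(A, dims = 2))) - A)    # graph Laplacian L = D - A

# du/dt = -ρ L u is linear, so u(t) = exp(-ρ t L) u0 (dense matrix exponential).
diffusion_exact(ρ, t, u0) = exp(-ρ * t * Matrix(L)) * u0

u0 = [0.9, 0.1, 0.1, 0.1, 0.1]
uT = diffusion_exact(1.0, 2.0, u0)

println(sum(u0), " ", sum(uT))   # total "mass" is conserved
println(uT)                      # every entry is close to the mean of u0 (0.26)
```

A dense `exp` is fine at this size; for large networks you would use an ODE solver as in the post.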

and for the Turing model:

```julia
using Turing

Turing.setadbackend(:forwarddiff)
@model function fit(data, problem)
	σ ~ InverseGamma(2, 3) # ~ is the tilde character
	p ~ truncated(Normal(1.0, 1.0), 0.0, 2.5)

	prob = remake(problem, p = p)
	predicted = solve(prob, saveat = 0.1)

	for i = 1:length(predicted)
		data[:, i] ~ MvNormal(predicted[i], σ)
	end
end

model = fit(data, prob);
chain = sample(model, NUTS(0.65), 1000)
```

With this code, I receive the following error:

```
TypeError: in typeassert, expected Float64, got a value of type ForwardDiff.Dual{Nothing,Float64,2}
```

(I can share more of the error message if needed).

I have been unable to identify what the problem is. There was a similar post in Oct 2020, but it doesn’t seem that the solutions presented there would apply in this case (I may be wrong). I would be very grateful for any help on how to fix the error!

Thanks,
Pavan

I am not sure if it would help, but could you try passing a solver as an argument to the `solve` function?


@ChrisRackauckas or @mohamed82008 might be able to help.


Thanks both. Passing a solver argument did not fix it.

However, I found a solution from @ChrisRackauckas last year. This seems to have fixed it:

```julia
Turing.setadbackend(:forwarddiff)
@model function fit(data, func)
	σ ~ InverseGamma(2, 3) # ~ is the tilde character
	p ~ truncated(Normal(1.25, 1.0), 0.0, 2.5)

	# promote the global u0 to p's type (a ForwardDiff.Dual during AD)
	prob = ODEProblem(func, eltype(p).(u0), (0.0, 2.0), p)
	predicted = solve(prob, Tsit5(), saveat = 0.1)

	for i = 1:length(predicted)
		data[:, i] ~ MvNormal(predicted[i], σ)
	end
end
```

It’s unclear to me why this works, but the following does not:

```julia
# p is the global Float64 here, so eltype(p).(u0) is still a Vector{Float64}
problem = ODEProblem(NetworkDiffusion, eltype(p).(u0), (0.0, 2.0), p);

Turing.setadbackend(:forwarddiff)
@model function fit(data, problem)
	σ ~ InverseGamma(2, 3)
	p ~ truncated(Normal(1.25, 1.0), 0.0, 2.5)

	prob = remake(problem, p = p)
	predicted = solve(prob, Tsit5(), saveat = 0.1)

	for i = 1:length(predicted)
		data[:, i] ~ MvNormal(predicted[i], σ)
	end
end
```

In practice, is there any difference between using `remake` and calling `ODEProblem` again?

Thanks,
Pavan


Which line causes this error? It looks familiar, and I believe I solved it in my case by adding `::Type{T} = Float64` to the model definition, so in your case:

```julia
@model function fit(data, problem, ::Type{T} = Float64) where {T}
```

Then you need to add `::T` to the variable that seems to be causing the problem. In your case that might be `prob`, since that is the main difference between the working and the failing version of your code. So you would get:

```julia
prob::T = remake(problem, p=p)
```

It’s kind of a wild guess, but it’s easy enough to try, so it can’t hurt.

Fixed in DiffEqBase v6.57.1. Update in about 2 hours and it should just work.


Oh wow, thanks! Does this also address the `remake` vs `ODEProblem` issue?
And presumably updating the environment with the package manager will suffice?

Yes, yes


Does `eltype(p).(u0)` work because `p` must be of type `Float64`, or because it must have the same type as `u0`? Is there a way to invert models with non-`Float64` parameters in `p` (e.g. nonhomogeneous equations)?

The initial condition must be a dual number in order to push forward the derivatives with respect to the parameters onto the state. But if you differentiate with respect to the parameters, then `p` is a `Dual` while `u0` isn’t, so there is a consistency check that upconverts `u0` in this case. We missed the specific case where `p` is a scalar, though (we had always assumed it was an `AbstractArray`, tuple, or named tuple).
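The mechanism can be reproduced outside Turing and DifferentialEquations. The sketch below is a toy, not DiffEqBase's actual code: a hypothetical fixed-step Euler integrator (`euler_final`) keeps the state in a buffer copied from `u0`, much as a solver's caches inherit the element type of `u0`. `L` is again a hypothetical complete-graph Laplacian. Differentiating with respect to a scalar `p` with ForwardDiff makes `p` a `Dual`, and writing `Dual` values into a `Float64` buffer throws (a different exception than the post's `typeassert`, but the same root cause); promoting with `eltype(p).(u0)`, as in the workaround earlier in the thread, fixes it.

```julia
using ForwardDiff, LinearAlgebra

# Hypothetical stand-in for the network Laplacian: complete graph on 5 nodes.
L = 5I - ones(Int, 5, 5)

# Toy fixed-step explicit Euler for du/dt = -p * L * u (not a DiffEq solver).
# The state buffer is copied from u0, so it inherits eltype(u0).
function euler_final(p, u0, L; dt = 0.01, nsteps = 200)
    u = copy(u0)
    for _ in 1:nsteps
        u .= u .- dt .* p .* (L * u)   # in-place write: values must convert to eltype(u)
    end
    return u
end

u0 = [0.9, 0.1, 0.1, 0.1, 0.1]

# Without promotion: p is a Dual during differentiation but u stays a
# Vector{Float64}, and converting a Dual to Float64 throws.
naive(p) = euler_final(p, u0, L)[1]

# With promotion: eltype(p).(u0) turns u0 into Duals when p is a Dual
# (and leaves it as Float64 when p is a Float64).
promoted(p) = euler_final(p, eltype(p).(u0), L)[1]

failed = try
    ForwardDiff.derivative(naive, 1.0)
    false
catch err
    println("naive version failed with ", typeof(err))
    true
end

dudp = ForwardDiff.derivative(promoted, 1.0)
println("d u₁(T)/dp = ", dudp)
```

For this toy `L`, the Euler update shrinks the deviation from the mean by `1 - 5 * dt * p` each step, which gives a closed form to check `dudp` against.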