Using nlsolve inside Turing.jl - ForwardDiff.Dual type error

Hi, I’m trying to assess if Turing is a good fit for a particular dynamic pricing-type problem I have,
and I’m wondering whether it’s possible to use nlsolve inside a model. I’m sure it is, but I can’t figure out how.
The code below returns ERROR: TypeError: in typeassert, expected Float64, got a value of type ForwardDiff.Dual{Nothing, Float64, 3}
from inside the nlsolve call.

I found the following example, but frankly I didn’t understand it.
Also, I tried to add the @adjoint definition from this issue,
but I don’t think that’s the problem.
I did notice that this question returns the exact same error type, but it’s not obvious to me how to fix it (perhaps by moving the function passed to nlsolve into a separate, standalone function and adding some type annotations?).
Here’s the attempt:

using Random
using Zygote
using NLsolve
using ForwardDiff
using Distributions
using StatsPlots
using Turing
Random.seed!(42);
# dummy supply function
function supplyFunc(p, baseSupply)
    p > -1 ? baseSupply + (p + 1)^2 : 0
end
# dummy demand function
function demandFunc(p, baseDemand)
    p < 20 ? baseDemand - (p + 1)^2 : 0
end
# estimate supply and demand from prices, single datapoint to start with
@model supplyDemandFromPrices(price) = begin
    sigma ~ truncated(Normal(0, 100), 0, Inf)
    supply ~ truncated(Normal(0, 100), 0, Inf)
    demand ~ truncated(Normal(0, 100), 0, Inf)
    # The offending call - try to estimate the clearing price
    clearingPrice = nlsolve(x -> [supplyFunc(x[1], supply) - demandFunc(x[1], demand)], zeros(1), method=:anderson, m=10, autodiff=:forward).zero[1]
    # assume the observed price is normally distributed around the 'market' price
    price ~ Normal(clearingPrice, sigma)
end;
modelP = supplyDemandFromPrices(6.0)
chainP = sample(modelP, NUTS(0.65), 100)# throws an error

Not related, but I think that Turing’s TruncatedNormal(...) might be faster than truncated(Normal(...), ...).


There are a couple of issues with the above:

  1. When calling nlsolve you provide zeros(1) as the initial point. This creates a Vector{Float64}, which nlsolve re-uses and likely also uses to infer the element type of its internal buffers (concrete type information makes things faster). When you then call sample with NUTS, Turing tries to compute the gradient of the model. There are several AD approaches in Julia, and ForwardDiff.jl is one of them. ForwardDiff.jl uses a special Dual number type to keep track of the derivatives w.r.t. each input, so when we differentiate the model w.r.t. its parameters, supply and demand will be Dual{..., Float64, ...} rather than plain Float64. This in turn means that you eventually call nlsolve with an initial value of type Vector{Float64} while the output of your residual function is of type Dual (since supply and demand are), which is not what nlsolve expects. Hence the error. Solution: use zeros(eltype(supply), 1) to construct an initial value with the correct element type (see the sketch after this list).
  2. It doesn’t seem like nlsolve supports ForwardDiff.jl differentiation through it, so even if we address (1) we’ll still run into errors. This is an issue for NLsolve.jl though, not Turing.jl.
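
For concreteness, here’s roughly what fix (1) looks like inside the model. It resolves the type error, but as (2) explains, ForwardDiff.jl still can’t differentiate through nlsolve, which is why the full example below switches to Zygote.jl:

# Sketch only: derive the initial point's element type from a model parameter,
# so that under ForwardDiff.jl the buffer is a Vector of Duals rather than a
# Vector{Float64} (under plain evaluation it is still a Vector{Float64}).
x_init = zeros(eltype(supply), 1)
clearingPrice = nlsolve(
    x -> [supplyFunc(x[1], supply) - demandFunc(x[1], demand)],
    x_init,
    method=:anderson, m=10,
).zero[1]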

With that being said, there’s a way to address (2) using reverse-mode AD: https://github.com/JuliaNLSolvers/NLsolve.jl/issues/205#issuecomment-865856826. It’s easy to define adjoints (rules for computing gradients in reverse-mode) using ChainRulesCore.jl, and these are picked up by Zygote.jl, i.e. rules defined with ChainRulesCore.jl propagate to Zygote.jl. You can do the same for other reverse-mode AD backends, but then you’d have to look at the specific one you want to use. I’m going to use Zygote.jl, though, as it’s the go-to for reverse-mode these days.

module TuringTest

using Random
using Zygote
using NLsolve
using ForwardDiff
using Distributions
using Turing

# New dependencies.
using ChainRulesCore
using IterativeSolvers
using LinearMaps

Random.seed!(42);

# https://github.com/JuliaNLSolvers/NLsolve.jl/issues/205#issuecomment-865856826
# Required to make it possible to differentiate through `nlsolve`.
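# The pullback below uses the implicit function theorem: at a root x we have
# f(x) = 0, so dx/dθ = -(∂f/∂x)⁻¹ ∂f/∂θ for any parameters θ closed over by `f`.
# Concretely, the transposed Jacobian system Jᵀ Δfx = -Δx is solved matrix-free
# with GMRES, and Δfx is then pulled back through `f` to get the gradient w.r.t. θ.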
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        Δx = Δresult[].zero
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2] # w.r.t. x
        # solve JT*Δfx = -Δx
        L = LinearMap(JT, length(x0))
        Δfx = gmres(L, -Δx)
        ∂f = f_pullback(Δfx)[1] # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end

# Switch to Zygote.jl because this supports differentiation through `nlsolve`
# thanks to the `rrule` we defined above.
Turing.setadbackend(:zygote)

# dummy supply function
supplyFunc(p,baseSupply) = p>-1 ? baseSupply+(p+1)^2 : 0

# dummy demand function
demandFunc(p,baseDemand) = p<20 ? baseDemand-(p+1)^2 : 0

# estimate supply and demand from prices, single datapoint to start with
@model supplyDemandFromPrices(price) = begin
    sigma ~ truncated(Normal(0, 100), 0, Inf)
    supply ~ truncated(Normal(0, 100), 0, Inf)
    demand ~ truncated(Normal(0, 100), 0, Inf)
    # The offending call- try to estimate clearing price
    clearingPrice = nlsolve(
        x->[supplyFunc(x[1],supply)-demandFunc(x[1],demand)],
        zeros(1),
        method=:anderson,
        m=10,
        autodiff=:forward
    ).zero[1]
    #  assume observed price is normally distributed around the 'market' price 
    price ~ Normal(clearingPrice,sigma)

    # Useful for debugging.
    return (; sigma, supply, demand, clearingPrice, price)
end;

model = supplyDemandFromPrices(6.0)

# Execute the model once without any differentiation just to make
# sure it works.
sigma, supply, demand, clearingPrice, price = model()

# Now we sample.
chains = sample(model, NUTS(), 1000)

end # module TuringTest

Now, this may take a while to sample from because

  1. nlsolve and the computations required for its adjoint aren’t necessarily cheap. They’re not prohibitively expensive, but we’re going to call them a lot. NUTS has a max tree depth setting which effectively specifies the maximum number of steps (and thus gradient calls) per iteration. By default this is 10, which means we never take more than 2^10 = 1024 steps per iteration; that is still a lot of steps per iteration, though (see the sketch after this list).
  2. We’re using an iterative solver for a non-linear problem and so the answer is not exact. Inexact gradients will confuse the sampler, likely leading to saturation of the max tree depth, i.e. we’ll probably be taking as many steps as allowed (1024, as mentioned above) on every iteration.
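
If the tree depth cap does become the bottleneck, you can lower it to bound the per-iteration cost. A minimal sketch, assuming Turing’s NUTS constructor accepts a max_depth keyword argument:

# Cap the doubling depth so each iteration performs at most 2^8 = 256 steps
# (and gradient calls). `max_depth` is assumed to be a supported keyword here.
chains = sample(model, NUTS(0.65; max_depth=8), 1000)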