There are a couple of issues with the above:
1. When calling `nlsolve` you provide `zeros(1)` as the initial point. This creates a `Vector{Float64}`, which `nlsolve` will then reuse, and which it will likely also use to infer the `eltype` of whatever it does internally, since type information makes things faster. Then, when you call `sample` with `NUTS`, Turing will try to compute the gradient of the model. There are different approaches to this in Julia, and ForwardDiff.jl is one of them. ForwardDiff.jl uses a special `Dual` number type to keep track of the derivatives w.r.t. each input. Hence, when we try to differentiate the model w.r.t. its parameters, `supply` and `demand` will be `Dual{..., Float64, ...}` rather than simply `Float64`. This in turn means that you will eventually call `nlsolve` with an initial value of type `Vector{Float64}` while the output will be of type `Dual` (since `supply` and `demand` are of this type), which is not what `nlsolve` expects. Hence the errors. Solution: use `zeros(eltype(supply), 1)` to construct an initial value with the correct element type.
2. It doesn't seem like `nlsolve` supports ForwardDiff.jl differentiation *through* it, so even if we address (1) we'll still run into errors. This is an issue for NLsolve.jl though, not Turing.jl.
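To make the failure mode in (1) concrete, here is a minimal standalone sketch (not involving `nlsolve`; `bad` and `good` are hypothetical helpers, just for illustration) of writing into a buffer whose element type doesn't match the `Dual` inputs:

```julia
using ForwardDiff

# A Float64 buffer cannot store the Dual numbers ForwardDiff passes in:
bad(supply) = (x0 = zeros(1); x0[1] = supply^2; x0[1])
# ForwardDiff.derivative(bad, 2.0)  # errors: a Dual cannot be converted to Float64

# Matching the buffer's element type to the input fixes it:
good(supply) = (x0 = zeros(eltype(supply), 1); x0[1] = supply^2; x0[1])
ForwardDiff.derivative(good, 2.0)  # 4.0
```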
With that being said, there's a way to address (2) using reverse-mode AD: Reverse differentiation through nlsolve · Issue #205 · JuliaNLSolvers/NLsolve.jl · GitHub. It's easy to define adjoints (rules for computing gradients in reverse mode) using ChainRulesCore.jl, and these are picked up by Zygote.jl, i.e. rules defined with ChainRulesCore.jl propagate to Zygote.jl. The same can be done for other reverse-mode AD backends, but you'd have to look at the specific one you want to use. I'm going to use Zygote.jl here, as it's the go-to for reverse-mode these days.
```julia
module TuringTest

using Random
using Zygote
using NLsolve
using ForwardDiff
using Distributions
using Turing

# New dependencies.
using ChainRulesCore
using IterativeSolvers
using LinearMaps

Random.seed!(42);

# https://github.com/JuliaNLSolvers/NLsolve.jl/issues/205#issuecomment-865856826
# Required to make it possible to differentiate through `nlsolve`.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        Δx = Δresult[].zero
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2]  # w.r.t. x
        # Solve JT * Δfx = -Δx.
        L = LinearMap(JT, length(x0))
        Δfx = gmres(L, -Δx)
        ∂f = f_pullback(Δfx)[1]  # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end

# Switch to Zygote.jl because it supports differentiation through `nlsolve`
# thanks to the `rrule` we defined above.
Turing.setadbackend(:zygote)

# Dummy supply function.
supplyFunc(p, baseSupply) = p > -1 ? baseSupply + (p + 1)^2 : 0

# Dummy demand function.
demandFunc(p, baseDemand) = p < 20 ? baseDemand - (p + 1)^2 : 0

# Estimate supply and demand from prices; a single datapoint to start with.
@model supplyDemandFromPrices(price) = begin
    sigma ~ truncated(Normal(0, 100), 0, Inf)
    supply ~ truncated(Normal(0, 100), 0, Inf)
    demand ~ truncated(Normal(0, 100), 0, Inf)

    # The (previously) offending call: try to estimate the clearing price.
    clearingPrice = nlsolve(
        x -> [supplyFunc(x[1], supply) - demandFunc(x[1], demand)],
        zeros(1),
        method = :anderson,
        m = 10,
        autodiff = :forward,
    ).zero[1]

    # Assume the observed price is normally distributed around the "market" price.
    price ~ Normal(clearingPrice, sigma)

    # Useful for debugging.
    return (; sigma, supply, demand, clearingPrice, price)
end;

model = supplyDemandFromPrices(6.0)

# Execute the model once without any differentiation just to make
# sure it works.
sigma, supply, demand, clearingPrice, price = model()

# Now we sample.
chains = sample(model, NUTS(), 1000)

end # module
```
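As a quick sanity check of the adjoint (a sketch that repeats the `rrule` from above so it runs standalone; the toy problem is my own), we can differentiate through `nlsolve` on a problem with a known root: x(p) = sqrt(p) solves x^2 - p = 0, and has derivative 1 / (2 sqrt(p)):

```julia
using ChainRulesCore, IterativeSolvers, LinearMaps, NLsolve, Zygote

# Same rrule as in the module above.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        Δx = Δresult[].zero
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2]        # pullback w.r.t. x, i.e. v -> Jᵀv
        L = LinearMap(JT, length(x0))
        Δfx = gmres(L, -Δx)             # solve Jᵀ Δfx = -Δx
        ∂f = f_pullback(Δfx)[1]         # w.r.t. f's closed-over variables
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end

# Root of x^2 - p is sqrt(p); its derivative w.r.t. p is 1 / (2 sqrt(p)).
g, = Zygote.gradient(p -> nlsolve(x -> [x[1]^2 - p], [1.0]).zero[1], 2.0)
g ≈ 1 / (2 * sqrt(2))  # should hold up to solver tolerance
```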
Now, this may take a while to sample from, because `nlsolve` and the computations required for its adjoint aren't necessarily cheap. Not prohibitively expensive, of course, but we're going to have to call them a lot. `NUTS` has a max tree depth setting which effectively specifies the maximum number of steps (and thus gradient calls) we'll take. By default this is 10, which in turn means that we're never going to take more than 2^10 = 1024 steps per iteration. That is still a lot of steps per iteration though.
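If sampling is too slow, one knob to try (a sketch; `max_depth` is the keyword on Turing's `NUTS` constructor, and `demo` is a toy model purely for illustration) is lowering that cap:

```julia
using Turing

# A trivial model, just to have something to sample from.
@model demo() = begin
    x ~ Normal()
end

# Lowering max_depth caps each iteration at 2^max_depth leapfrog steps
# (and hence gradient calls); here at most 2^5 = 32.
chain = sample(demo(), NUTS(100, 0.65; max_depth = 5), 200; progress = false)
```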
Also note that we're using an iterative solver for a nonlinear problem, so the answer is not exact. Inexact gradients will confuse the sampler, likely leading to saturation of the max tree depth, i.e. we'll probably be taking as many steps as allowed (1024, as mentioned above) on every iteration.
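If the inexactness becomes a problem, one mitigation (a sketch; `ftol` is a standard `nlsolve` keyword, and the tolerance-vs-gradient-accuracy link is my assumption here) is to tighten the inner solver's tolerance so the root, and hence the implicit gradient, is closer to exact:

```julia
using NLsolve

f(x) = [x[1]^2 - 2.0]

# Looser vs. tighter residual tolerance on the same problem.
loose = nlsolve(f, [1.0]; ftol = 1e-4)
tight = nlsolve(f, [1.0]; ftol = 1e-12)

# The tighter solve typically leaves a much smaller residual, which
# translates into a more accurate root (and adjoint) at the cost of
# extra inner iterations.
abs(f(tight.zero)[1]) <= abs(f(loose.zero)[1])
```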