Automatic Differentiation with custom functions within likelihood function

Hey everyone,

I want to use AD to get gradients of a likelihood function. Inside this likelihood function, I need to use a root solver to get the quantiles of an univariate Mixture distribution, as there is no closed form solution available. I can readily call this function and evaluate it, but I wont get gradients from it, neither via ForwardDiff, Tracker or Zygote.

I assume the problem is the root solver and the custom functions within the likelihood function - there are no methods implemented for these for Dual Numbers/TrackedArrays/…, and this is were I am currently stuck. I also dont find any other solutions online.The whole code is stated at the end of this post, here is the likelihood function:

function set_lik_log(obs_unif::Array{T}) where {T<:Real}
    function get_lik_log(θ)
        μ = get_μ( θ[1] )
        Σ = get_Σ( θ[2:3] )
        π1 = θ[4:5]

        distr = [MvNormal( μ[1,:], Σ[1] ), MvNormal(μ[2,:],  Σ[2] ) ]
        MixModel_univariate = get_univariate_mixture(distr, π1)
        MixModel_bivariate = Distributions.MixtureModel( distr, π1 )
        obs = get_inv_CDF(MixModel_univariate, obs_unif )

        lik_log = 0.0
        for time in 1:size(obs,1)
            lik_log += logpdf(MixModel_bivariate , obs[time,:] )
            for state in eachindex(MixModel_univariate)
                lik_log -= logpdf(MixModel_univariate[state] , obs[time,state] )
            end
        end
        return lik_log
    end
end

I can get gradients if I state μ, Σ, π1 without the functions via

        μ = [ 1. θ[1] ; -1. -θ[1] ]
        Σ = [ [ 1. θ[2] ; θ[2] 1.], [ 1. θ[3] ; θ[3] 1.] ]
        π1 = θ[4:5]

and if I push the line

obs = get_inv_CDF(MixModel_univariate, obs_unif )

between set_lik_log and get_lik_log, but this would not lead to the correct solution and is also not a solution for models with a larger parameter vector. In the current case, I get either errors that mutation of arrays is not allowed (Zygote) or that there are no methods implemented for Duals (ForwardDiff) or TrackedArrays (Tracker), and I dont know how to go onwards from here. In case someone might have the time to quickly look at it, here is a working example of the code, bar the gradient evaluations:

################################################################################
# Discourse example

using Distributions
using DistributionsAD
using LinearAlgebra
using PDMats
using Roots

using Tracker
using ForwardDiff
using Zygote

################################################################################
#Functions for θ
function get_μ(θ::Number)
    return [ 1. θ ; -1. -θ ]
end

function get_Σ(ρ::AbstractVector)
    Σ = Matrix{Float64}[] #Vector of Matrices
    for iter in eachindex(ρ)
        Σ_temp = zeros(2,2)                                     #Initiate Matrix
        Σ_temp[tril!(trues(size(Σ_temp)), -1)] .= ρ[iter]       #Assign covariance
        Σ_temp += I                                             #Assign unit variance
        Σ_temp = Symmetric( Σ_temp, :L)                         #Make Matrix symmetric
        push!(Σ, Σ_temp)
    end
    return Σ
end

function get_univariate_mixture(distr::Vector{<:Distribution}, π)
    μ = first.( Distributions.params.( (distr) ) )
    σ = 1.0 #For all states, each variance set to 1 ?
    return [ Distributions.MixtureModel(map(x -> Normal(x, σ), μ[state]), π ) for state in eachindex(μ) ]
end

################################################################################
# Functions to find roots
function set_quantile(distr::MixtureModel{Univariate}, prob::T) where {T<:Real}
    function find_quantile(x::T) where {T<:AbstractFloat}
        return Distributions.cdf(distr, x)  - prob
    end
end

function get_inv_CDF(distr::MixtureModel{Univariate}, prob::T) where {T<:Real}
    #Assign initial bracket
    μ = first.( Distributions.params.( Distributions.components(distr) ) )
    bracket = (μ[1] - 30.0, μ[2] + 30.0 ) ##REDO properly
    #Assign root closure
    roots = set_quantile(distr, prob)
    #Find root
    return find_zero(roots, bracket, Bisection() )

end
#Dispatch method for use in MCMC code with bivariate uniform data
function get_inv_CDF(distr::Vector{<:MixtureModel{Univariate}}, prob::Array{T}) where {T<:Real}
    obs = [ get_inv_CDF.(distr[dimension], prob[: , dimension] ) for dimension in eachindex(distr) ]
    return hcat(obs...)
end

################################################################################[ 1. θ[1] ; -1. -θ[1] ] #
################################################################################[ [ 1. θ[2] ; θ[2] 1.], [ 1. θ[3] ; θ[3] 1.] ] #
function set_lik_log(obs_unif::Array{T}) where {T<:Real}
    function get_lik_log(θ)
        μ = get_μ( θ[1] )
        Σ = get_Σ( θ[2:3] )
        π1 = θ[4:5]

        distr = [MvNormal( μ[1,:], Σ[1] ), MvNormal(μ[2,:],  Σ[2] ) ]
        MixModel_univariate = get_univariate_mixture(distr, π1)
        MixModel_bivariate = Distributions.MixtureModel( distr, π1 )
        obs = get_inv_CDF(MixModel_univariate, obs_unif )

        lik_log = 0.0
        for time in 1:size(obs,1)
            lik_log += logpdf(MixModel_bivariate , obs[time,:] )
            for state in eachindex(MixModel_univariate)
                lik_log -= logpdf(MixModel_univariate[state] , obs[time,state] )
            end
        end
        return lik_log
    end
end

#working
obs_unif = rand(Uniform(0,1), (100,2) )
ll = set_lik_log(obs_unif)
θ = [.1, .2, .3, .4, .6]
ll(θ)

#neither is working
#1
grad = ForwardDiff.gradient( ll, θ  )
#2
a2,b2 = Tracker.forward( ll, θ )
#3
value, back = Zygote.pullback(ll, θ)
grad = back(1)[1]

my changes are destined to make this work with ForwardDiff.
change get_Σ to this:

function get_Σ(ρ::AbstractVector)
    Σ = Matrix{eltype(ρ)}[] #Vector of Matrices, propagate duals by adding explicitly the type of ρ
    for iter in eachindex(ρ)
        Σ_temp = zeros(eltype(ρ),2,2)                                     #Initiate Matrix,same as Σ
        Σ_temp[tril!(trues(size(Σ_temp)), -1)] .= ρ[iter]       #Assign covariance
        Σ_temp += I                                             #Assign unit variance
        Σ_temp = Symmetric( Σ_temp, :L)                         #Make Matrix symmetric
        push!(Σ, Σ_temp)
    end
    return Σ
end

change set_quantile to this

function set_quantile(distr::MixtureModel{Univariate}, prob::T) where {T<:Real}
    function find_quantile(x::T) where {T<:Real} #dual propagation occurs only on functions that accept reals
        return Distributions.cdf(distr, x)  - prob
    end
end

ForwardDiff “worked” as it throws a result, but is a vector of NaNs

You may find this discussion relevant, has a lot of nice suggestions:

First of all, thank you both a lot for your comments, and your comments in the two functions! Is there maybe another change you made to make it “work” for ForwardDiff in the get_inv_CDF function? It still throws an error for me in the line:

    #Find root
    return find_zero(roots, bracket, Bisection() )

I am working under Julia 1.2 and updated the packages above before running the code.

Wow, thank you for the link! I am glad this seems to work with ForwardDiff! I will have a look at it today for my problem.

i implemented my custom bisection method, because the other one uses prevfloat somewhere ( i just copied that implementation from a matlab forum and can be improved. with that, a gradient is given, but maybe you should check with a finite differenciation package (DiffEqDiffTools.jl?)

using Distributions
using DistributionsAD
using LinearAlgebra
using PDMats
using Roots

#using Tracker
using ForwardDiff
#using Zygote

function simple_bisection(f,bracket)
    xl,xu = bracket
    tol=1e-08
    xr = NaN
    fxu = f(xu)
    xnew = Vector{typeof(fxu)}(undef,0)
    if fxu*f(xl)<0
    push!(xnew,zero(fxu))
    for i=2:1000
    xr=(xu+xl)/2;
        if f(xu)*f(xr)<0
            xl=xr;
        else
            xu=xr;
        end
     
        if f(xl)*f(xr)<0
            xu=xr;
        else
            xl=xr;
        end
    
    
        push!(xnew,xr)
        if abs((xnew[i]-xnew[i-1])/xnew[i])<tol
            break
        end
    end
    return xr
    else
    throw("bisection not possible, brackets are of same sign")
    end
end
################################################################################
#Functions for θ
function get_μ(θ::Number)
    return [ 1. θ ; -1. -θ ]
end

function get_Σ(ρ::AbstractVector)
    Σ = Matrix{eltype(ρ)}[] #Vector of Matrices
    for iter in eachindex(ρ)
        Σ_temp = zeros(eltype(ρ),2,2)                                     #Initiate Matrix
        Σ_temp[tril!(trues(size(Σ_temp)), -1)] .= ρ[iter]       #Assign covariance
        Σ_temp += I                                             #Assign unit variance
        Σ_temp = Symmetric( Σ_temp, :L)                         #Make Matrix symmetric
        push!(Σ, Σ_temp)
    end
    return Σ
end

function get_univariate_mixture(distr::Vector{<:Distribution}, π)
    μ = first.( Distributions.params.( (distr) ) )
    σ = 1.0 #For all states, each variance set to 1 ?
    return [ Distributions.MixtureModel(map(x -> Normal(x, σ), μ[state]), π ) for state in eachindex(μ) ]
end

################################################################################
# Functions to find roots
function set_quantile(distr::MixtureModel{Univariate}, prob::T) where {T<:Real}
    function find_quantile(x::T) where {T<:Real}
        return Distributions.cdf(distr, x)  - prob
    end
end

function get_inv_CDF(distr::MixtureModel{Univariate}, prob::T) where {T<:Real}
    #Assign initial bracket
    μ = first.( Distributions.params.( Distributions.components(distr) ) )
    bracket = (μ[1] - 30.0, μ[2] + 30.0 ) ##REDO properly
    #Assign root closure
    roots = set_quantile(distr, prob)
    #Find root
    return simple_bisection(roots,bracket)
    #res = Distributions.cdf(distr, x)(distr, prob) 
    #@show res
    #return res
    #return find_zero(roots, bracket)

end
#Dispatch method for use in MCMC code with bivariate uniform data
function get_inv_CDF(distr::Vector{<:MixtureModel{Univariate}}, prob::Array{T}) where {T<:Real}
    obs = [ get_inv_CDF.(distr[dimension], prob[: , dimension] ) for dimension in eachindex(distr) ]
    return hcat(obs...)
end

################################################################################[ 1. θ[1] ; -1. -θ[1] ] #
################################################################################[ [ 1. θ[2] ; θ[2] 1.], [ 1. θ[3] ; θ[3] 1.] ] #
function set_lik_log(obs_unif::Array{T}) where {T<:Real}
    function get_lik_log(θ)
        μ = get_μ( θ[1] )
        Σ = get_Σ( θ[2:3] )
        π1 = θ[4:5]

        distr = [MvNormal( μ[1,:], Σ[1] ), MvNormal(μ[2,:],  Σ[2] ) ]
        MixModel_univariate = get_univariate_mixture(distr, π1)
        MixModel_bivariate = Distributions.MixtureModel( distr, π1 )
        obs = get_inv_CDF(MixModel_univariate, obs_unif )

        lik_log = 0.0
        for time in 1:size(obs,1)
            lik_log += logpdf(MixModel_bivariate , obs[time,:] )
            for state in eachindex(MixModel_univariate)
                lik_log -= logpdf(MixModel_univariate[state] , obs[time,state] )
            end
        end
        return lik_log
    end
end

#working
obs_unif = rand(Uniform(0,1), (100,2) )
ll = set_lik_log(obs_unif)
θ = [.1, .2, .3, .4, .6]
grad = ForwardDiff.gradient( ll, θ  )

also, there exists Distributions.quantile_bisect, but it didn’t work for me when propagating duals.
The gradient in that case gives numbers around these values:

5-element Array{Float64,1}:
  -65.83409982000504
   -9.127450339548174
  -14.703667317669348
  -43.357436197191085
 -137.76170920187263

Thank you! Defining

#Add support for root function for Tracker and ForwardDiff
import Base: prevfloat, nextfloat
Base.prevfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(prevfloat(d.value), prevfloat.(d.partials)...)
Base.nextfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(nextfloat(d.value), nextfloat.(d.partials)...)

Base.prevfloat(d::Tracker.TrackedReal) = prevfloat(d.data) 
Base.nextfloat(d::Tracker.TrackedReal) = nextfloat(d.data) 

would make the code work for Tracker and ForwardDiff, but I think going forward your solution with a custom root solver might be preferable if this can easily be done.