Hey everyone,
I want to use AD to get gradients of a likelihood function. Inside this likelihood function, I need to use a root solver to get the quantiles of a univariate mixture distribution, as there is no closed-form solution available. I can readily call this function and evaluate it, but I can't get gradients from it via ForwardDiff, Tracker, or Zygote.
I assume the problem is the root solver and the custom functions within the likelihood function - there are no methods implemented for these for Dual numbers/TrackedArrays/…, and this is where I am currently stuck. I also haven't found any solutions online. The whole code is given at the end of this post; here is the likelihood function:
# Build a log-likelihood closure over `obs_unif`.
# `obs_unif` is passed to `get_inv_CDF` together with the univariate mixtures;
# presumably it holds uniform draws with one column per dimension (see caller).
# The returned function takes a parameter vector θ and reads θ[1], θ[2:3], θ[4:5].
function set_lik_log(obs_unif::Array{T}) where {T<:Real}
    function get_lik_log(θ)
        # Unpack the parameter vector into means, covariances and weights.
        means = get_μ(θ[1])
        covs = get_Σ(θ[2:3])
        weights = θ[4:5]

        # Two multivariate normal components, one per row of `means`.
        components = [MvNormal(means[1, :], covs[1]), MvNormal(means[2, :], covs[2])]
        marginals = get_univariate_mixture(components, weights)
        joint = Distributions.MixtureModel(components, weights)

        # Map the data through the inverse CDFs of the univariate mixtures.
        obs = get_inv_CDF(marginals, obs_unif)

        # Joint log-density minus the univariate log-densities, row by row.
        lik_log = 0.0
        for t in axes(obs, 1)
            lik_log += logpdf(joint, obs[t, :])
            for (d, marginal) in pairs(marginals)
                lik_log -= logpdf(marginal, obs[t, d])
            end
        end
        return lik_log
    end
    return get_lik_log
end
I can get gradients if I state μ, Σ, π1
without the functions via
μ = [ 1. θ[1] ; -1. -θ[1] ]
Σ = [ [ 1. θ[2] ; θ[2] 1.], [ 1. θ[3] ; θ[3] 1.] ]
π1 = θ[4:5]
and if I move the line
obs = get_inv_CDF(MixModel_univariate, obs_unif )
between set_lik_log
and get_lik_log
, but this would not lead to the correct solution and is also not a solution for models with a larger parameter vector. In the current case, I get errors that either mutation of arrays is not allowed (Zygote) or that there are no methods implemented for Duals (ForwardDiff) or TrackedArrays (Tracker), and I don't know how to proceed from here. In case someone has the time to take a quick look, here is a working example of the code, bar the gradient evaluations:
################################################################################
# Discourse example
using Distributions
using DistributionsAD
using LinearAlgebra
using PDMats
using Roots
using Tracker
using ForwardDiff
using Zygote
################################################################################
#Functions for θ
# Return the 2×2 matrix whose rows are (1, θ) and (-1, -θ).
function get_μ(θ::Number)
    upper = [1.0 θ]
    lower = [-1.0 -θ]
    return vcat(upper, lower)
end
# Build one 2×2 matrix per entry of `ρ`, with unit diagonal and `ρ[i]` on both
# off-diagonals, and return them as a vector of matrices.
#
# The element type is `promote_type(Float64, eltype(ρ))`: plain Float64 / Float32 /
# integer input still yields `Matrix{Float64}`, while wider real types (for example
# BigFloat or dual numbers) are kept instead of being converted to Float64.
# Each matrix is created directly from a literal, so no array is mutated in place.
function get_Σ(ρ::AbstractVector)
    T = promote_type(Float64, eltype(ρ))
    unit = one(T)
    # Literal concatenation promotes `unit` and `r` to the common type T.
    return map(r -> [unit r; r unit], ρ)
end
# Build the univariate mixtures of a multivariate mixture: one `MixtureModel` per
# data dimension, each mixing `Normal(mean_k[d], σ)` over the components k with
# weights `π`.
#
# The result is indexed by dimension, which is how the callers in this file use it
# (`obs[time, state]`, `prob[:, dimension]`). Previously the loop ran over
# components and mixed over the coordinates of a single component's mean, which
# only ran because the example has as many components as dimensions.
function get_univariate_mixture(distr::Vector{<:Distribution}, π)
    # First entry of `params` for every component (the mean vector for MvNormal).
    μ = first.(Distributions.params.(distr))
    # Every variance is fixed to 1, as in the original code.
    # NOTE(review): matches the unit diagonal produced by get_Σ — confirm if
    # covariances with non-unit variance are ever passed in.
    σ = 1.0
    dims = isempty(μ) ? 0 : length(first(μ))
    return [
        Distributions.MixtureModel([Normal(m[d], σ) for m in μ], π) for d in 1:dims
    ]
end
################################################################################
# Functions to find roots
# Return the closure x -> cdf(distr, x) - prob, whose root is the `prob`-quantile
# of `distr`.
#
# The closure accepts any `Real` evaluation point (not only `AbstractFloat`), so it
# can also be called with integers or dual numbers.
function set_quantile(distr::MixtureModel{Univariate}, prob::T) where {T<:Real}
    find_quantile(x::Real) = Distributions.cdf(distr, x) - prob
    return find_quantile
end
# Return the `prob`-quantile of the univariate mixture `distr`, found by bisection
# on cdf(distr, x) - prob.
function get_inv_CDF(distr::MixtureModel{Univariate}, prob::T) where {T<:Real}
    # First parameter of every component (the location for Normal components).
    μ = first.(Distributions.params.(Distributions.components(distr)))
    # Bracket around the smallest and largest component location, so the interval
    # does not depend on the number or ordering of the components.
    # NOTE(review): the fixed margin of 30 assumes component scales of order 1.
    μ_lo, μ_hi = extrema(μ)
    bracket = (μ_lo - 30.0, μ_hi + 30.0)
    # Root function for the requested probability.
    roots = set_quantile(distr, prob)
    return find_zero(roots, bracket, Bisection())
end
#Dispatch method for use in MCMC code with bivariate uniform data
# Column-wise inverse CDF: column d of `prob` is mapped through the quantile
# function of `distr[d]`, and the resulting columns are concatenated horizontally.
function get_inv_CDF(distr::Vector{<:MixtureModel{Univariate}}, prob::Array{T}) where {T<:Real}
    columns = map(eachindex(distr)) do d
        marginal = distr[d]
        return [get_inv_CDF(marginal, p) for p in prob[:, d]]
    end
    return hcat(columns...)
end
################################################################################
# Likelihood function
# Inline forms of μ and Σ (the variants for which gradients can be obtained):
#   μ = [ 1. θ[1] ; -1. -θ[1] ]
#   Σ = [ [ 1. θ[2] ; θ[2] 1.], [ 1. θ[3] ; θ[3] 1.] ]
################################################################################
# Build a log-likelihood closure over `obs_unif`.
# `obs_unif` is passed to `get_inv_CDF` together with the univariate mixtures;
# presumably it holds uniform draws with one column per dimension (see caller).
# The returned function takes a parameter vector θ and reads θ[1], θ[2:3], θ[4:5].
function set_lik_log(obs_unif::Array{T}) where {T<:Real}
    function get_lik_log(θ)
        # Unpack the parameter vector into means, covariances and weights.
        means = get_μ(θ[1])
        covs = get_Σ(θ[2:3])
        weights = θ[4:5]

        # Two multivariate normal components, one per row of `means`.
        components = [MvNormal(means[1, :], covs[1]), MvNormal(means[2, :], covs[2])]
        marginals = get_univariate_mixture(components, weights)
        joint = Distributions.MixtureModel(components, weights)

        # Map the data through the inverse CDFs of the univariate mixtures.
        obs = get_inv_CDF(marginals, obs_unif)

        # Joint log-density minus the univariate log-densities, row by row.
        lik_log = 0.0
        for t in axes(obs, 1)
            lik_log += logpdf(joint, obs[t, :])
            for (d, marginal) in pairs(marginals)
                lik_log -= logpdf(marginal, obs[t, d])
            end
        end
        return lik_log
    end
    return get_lik_log
end
# Evaluating the log-likelihood itself works:
# draw a 100×2 array from Uniform(0,1), build the closure, and call it once.
obs_unif = rand(Uniform(0,1), (100,2) )
ll = set_lik_log(obs_unif)
# Parameter vector; get_lik_log reads θ[1], θ[2:3] and θ[4:5].
θ = [.1, .2, .3, .4, .6]
ll(θ)
# None of the three gradient evaluations below works:
# 1) ForwardDiff
grad = ForwardDiff.gradient( ll, θ )
# 2) Tracker
a2,b2 = Tracker.forward( ll, θ )
# 3) Zygote: build the pullback, then seed it with 1 and take the first entry
value, back = Zygote.pullback(ll, θ)
grad = back(1)[1]