Using ForwardDiff with Custom Struct

Hi guys,

I have a log-likelihood function, ll, of a parameter vector that I need to differentiate with automatic differentiation (AD). I also need to apply some transformations inside ll, and as long as I keep the input in vector form, that works fine.

I would like to make this more general by filling the input into a custom mutable struct, my model (mod), so that I can easily use AD for different model distributions. I tried to define that struct as generically as possible, but I keep getting errors. The problem is that I cannot put a ForwardDiff.Dual into an array of Floats, which is the type automatically assigned to it in my mutable struct. I tried to extend the function that fills the model, but I don't believe that is the issue, since everything works fine if I use the function without filling the struct.

Here is a short example; I would be very grateful for any advice:

```julia
using Parameters, Distributions
using OrderedCollections   # provides OrderedDict
using ForwardDiff, Zygote

mutable struct MyDistribution
    distribution    :: Any
###!! CAUSE OF THE ERROR !! :
    param           :: OrderedDict{Symbol, Any } #Parameter of Distribution
###!!
    prior           :: OrderedDict{Symbol, Vector{<:Distribution} } #Prior of Parameter of Distribution
end

# A scalar parameter (as passed in the loop of the functor below) is returned unchanged
shape_param(θₜ::Real, prior) = θₜ

function shape_param(θₜ::Vector{R}, prior) where {R<:Real}
    return length(θₜ) == 1 ? θₜ : reshape( θₜ, size(prior) ) #θₜ[1]
end

#shape_param(θₜ::Vector{ ForwardDiff.Dual{T,V,N} }, param)  where {T,V,N} = shape_param(ForwardDiff.value(θₜ), param)

#Functor to fill this struct - just as example to receive the same error as in original code
function (distr::MyDistribution)(θ_new::AbstractVector)
    @unpack param, prior = distr
    for state in 1:2
        param[:λ][state] = shape_param(θ_new[state], prior[:λ][state] ) #HERE IS MY ERROR LINE ORIGIN
    end
    @pack! distr = param
    return nothing
end

#Likelihood function of interest
function get_likelihood(distr::MyDistribution, data)
    function ll(θ_new::AbstractVector)
        #update MyDistribution
        distr(θ_new) #FIRST ERROR Line in AD
        #now call the logpdf
        lp = logpdf( distr.distribution(θ_new[1]), data)
        return lp
    end
    return ll
end

#Calling these functions works as intended
distr = MyDistribution(Poisson,
                       OrderedDict(:λ => [ 1., 2. ] ),
                       OrderedDict(:λ => [ Gamma(2,2), Gamma(2,2) ] )
                      )
θ_new = [5., 6.]
data = 10
#Check - working
ll = get_likelihood(distr, data )
ll(θ_new)

#Unfortunately, cannot use AD
ForwardDiff.gradient(ll, θ_new) #TypeError: in typeassert, expected Float64, got ForwardDiff.Dual{Nothing,Float64,2}
Zygote.gradient(ll, θ_new) #Mutating arrays is not supported
```

Calling these functions works as intended, but unfortunately I cannot use autodiff because of this line:

```julia
param[:λ][state] = shape_param(θ_new[state], prior[:λ][state] )
```

I can autodiff through the function shape_param, but cannot assign Duals to the struct. Does anyone have an idea how I could change that?
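One possible workaround, sketched here under stated assumptions rather than offered as a tested fix for the original model: since the param field is typed OrderedDict{Symbol, Any}, the dictionary itself can hold a vector of Duals; only the existing Vector{Float64} stored inside it cannot. Rebuilding the whole λ vector from θ_new, instead of writing into it element by element, lets its element type follow the input. This is a reduced version of the example above: it assumes OrderedCollections for OrderedDict, drops Parameters.jl, and keeps only a scalar method of shape_param.

```julia
using ForwardDiff, Distributions, OrderedCollections

mutable struct MyDistribution
    distribution :: Any
    param        :: OrderedDict{Symbol, Any}                   # values may have any type
    prior        :: OrderedDict{Symbol, Vector{<:Distribution}}
end

# Simplified stand-in for shape_param: a scalar parameter is passed through unchanged
shape_param(θₜ::Real, prior) = θₜ

# Rebuild λ from θ_new instead of writing into the stored Vector{Float64}.
# The comprehension's element type is Float64 in a plain call and
# ForwardDiff.Dual inside ForwardDiff.gradient; the Any-valued dict accepts both.
function (distr::MyDistribution)(θ_new::AbstractVector)
    prior = distr.prior
    distr.param[:λ] = [shape_param(θ_new[state], prior[:λ][state]) for state in 1:2]
    return nothing
end

function get_likelihood(distr::MyDistribution, data)
    function ll(θ_new::AbstractVector)
        distr(θ_new)                                       # update the model
        return logpdf(distr.distribution(θ_new[1]), data)
    end
    return ll
end

distr = MyDistribution(Poisson,
                       OrderedDict(:λ => [1.0, 2.0]),
                       OrderedDict(:λ => [Gamma(2, 2), Gamma(2, 2)]))
ll = get_likelihood(distr, 10)

# Only θ_new[1] enters the logpdf, so the gradient is [10/5 - 1, 0] = [1, 0]
g = ForwardDiff.gradient(ll, [5.0, 6.0])
```

After the gradient call, distr.param[:λ] still holds Dual values until ll is called again with plain Float64 input, so passing the parameters through the call instead of storing them in the struct may be the cleaner design in the long run.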

BR,


Does the type assert occur in your code? A type assertion like that prevents ForwardDiff from using dual numbers. In Zygote's case, it might be a missing gradient definition (see this discussion) for some data structure. For example, you can't differentiate through code that uses FieldVectors; maybe this is the case for one of your data structures too. You could try replacing your OrderedDicts with ordinary Dicts.
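To illustrate the point about type assertions, here is a minimal sketch: a field declared as Vector{Float64} forces a conversion that a Dual cannot undergo, while a type parameter lets the same struct hold either Float64 or Dual values. The names FixedParams, FlexParams and loglik are hypothetical, and loglik is a stand-in objective, not a real likelihood.

```julia
using ForwardDiff

# Element type fixed to Float64: constructing it from Dual numbers fails
struct FixedParams
    λ::Vector{Float64}
end

# Element type as a type parameter: T is inferred from the vector passed in
struct FlexParams{T<:Real}
    λ::Vector{T}
end

# Stand-in objective that only reads the field
loglik(p) = sum(abs2, p.λ)

# Differentiating through FixedParams throws, because the Duals cannot be converted to Float64
fixed_fails = try
    ForwardDiff.gradient(θ -> loglik(FixedParams(θ)), [5.0, 6.0])
    false
catch
    true
end

# Inside gradient a FlexParams{<:ForwardDiff.Dual} is built, so this works
g = ForwardDiff.gradient(θ -> loglik(FlexParams(θ)), [5.0, 6.0])   # [10.0, 12.0]
```

The same idea applies to a mutable model struct: parameterizing the container types on the number type avoids hard-coding Float64.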


Also, let me draw your attention to https://github.com/TuringLang/DistributionsAD.jl.


Thanks for both of your answers! DistributionsAD is great, but unfortunately it does not make my code work here: the functions themselves work fine with ForwardDiff, I just cannot find a way to fill the model with Dual numbers. The rest actually works fine.

Regarding Zygote, I have given up on it whenever I need logpdfs of various probability distributions, as I usually get the "Mutating arrays is not supported" error in many of those cases. I would very much like to use reverse-mode AD, since I have more and more parameters to differentiate, but I have not found a workaround.
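For the Zygote case, the "Mutating arrays is not supported" error typically comes from filling a preallocated array inside the differentiated function. A minimal sketch, assuming a log-transformed Poisson rate as the "transformation" and using a hand-written Poisson log-likelihood kernel (without the constant -log(x!), which does not affect the gradient) so it does not depend on Distributions; all function names are hypothetical. The in-place version works with ForwardDiff because it allocates with eltype(θ), but Zygote rejects its setindex!; building the transformed parameters by broadcasting avoids the mutation.

```julia
using ForwardDiff, Zygote

# Poisson log-likelihood up to the constant -log(x!), which does not affect the gradient
poisson_kernel(λ, x) = x * log(λ) - λ

# In-place version: fills a preallocated array whose element type follows θ (so Duals fit)
function ll_mutating(θ, data)
    λ = zeros(eltype(θ), length(θ))
    for i in eachindex(θ)
        λ[i] = exp(θ[i])               # transform to the positive scale
    end
    return poisson_kernel(λ[1], data)
end

# Non-mutating version: the transformed parameters are created by broadcasting
ll_functional(θ, data) = poisson_kernel(exp.(θ)[1], data)

θ0 = [log(5.0), 0.0]

# Forward mode copes with the in-place loop
g_fd = ForwardDiff.gradient(θ -> ll_mutating(θ, 10), θ0)

# Reverse mode with Zygote throws on the setindex! in the loop
zygote_fails = try
    Zygote.gradient(θ -> ll_mutating(θ, 10), θ0)
    false
catch
    true
end

# Reverse mode works once the mutation is gone; d/dθ₁ = 10 - exp(θ₁) ≈ 5
g_zy = Zygote.gradient(θ -> ll_functional(θ, 10), θ0)[1]
```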
