Automatic differentiation of loglikelihood function. Am I doing it right?

Hi everyone!

I’m a statistician with extensive R experience, taking my first steps with Julia. This summer I spent my free time reimplementing the R package gamlss as a personal exercise, and since Julia has really caught my interest, I’d like to learn the language by working on a project that reimplements gamlss in Julia.

I’m just getting started, and one of the first things I’m trying to build is an interface that mirrors R’s family concept but generalizes immediately to distributions with multiple parameters, each modeled through its own regression model.

The idea is to create a Family struct that takes various arguments, including a series of functions such as lpdf(). This lpdf() function needs to be differentiated with respect to the model parameters contained in an array par. For now I’ve made par as general as possible (given my current knowledge): an “array of arrays,” where each element is a vector of length n (the sample size). This lets me evaluate the log-likelihood of the i-th observation x_i at its own parameter value \theta_i. It therefore seemed natural to write the lpdf() function inside Family using broadcasting (the dot syntax) to perform these operations.

The trickiest part, however, is figuring out how to compute the gradient and Hessian of the log-likelihood for the i-th observation as efficiently as possible. I need them at the observation level because they are the building blocks of the chain rule that yields the gradients and Hessians with respect to the regression model parameters, but that’s a problem I’ll tackle later.
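(To sketch where this is headed, in the simplest case: each distribution parameter has a linear predictor \eta_i = z_i^\top \beta with link \theta_i = g(\eta_i), so the chain rule gives \partial \ell_i / \partial \beta = \frac{\partial \ell_i}{\partial \theta_i} \, g'(\eta_i) \, z_i, and the observation-level derivatives \partial \ell_i / \partial \theta_i are exactly the ingredients I’m after here.)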

The solution I’ve arrived at uses a double for loop and Zygote.jl, as you can see in the code below. It takes the data x, the parameter array par (structured as described above), and the Family object, from which I extract the lpdf() function and differentiate it for each observation.

Since I believe the best way to learn is by doing, making mistakes, and being corrected by people who know more than you do, I was wondering if you could share your thoughts on this implementation. Is this a sensible way to approach this kind of problem in Julia?

Thanks!

using SpecialFunctions
using Zygote
using BenchmarkTools

struct Family
    family::String              # name of family
    vartype::String             # "continuous" or "discrete"?
    varlower                    # lower bound for variable
    varupper                    # upper bound for variable
    npar::Int                   # number of parameters
    parnames::Vector{String}    # names of parameters
    pardescr::Vector{String}    # description of parameters
    parlower::Vector{Float64}   # lower bounds of parameters
    parupper::Vector{Float64}   # upper bounds of parameters
    pdf::Function               # probability density (mass) function
    lpdf::Function              # log pdf/pmf
    cdf::Function               # cumulative distribution function
    lcdf::Function              # log cdf
    qf::Function                # quantile function
    E::Function                 # Expected value
    V::Function                 # Variance
    A::Function                 # Asymmetry
    K::Function                 # Kurtosis 
end

gaussian1 = Family(
    "gaussian1",
    "continuous",
    -Inf,
    Inf,
    2,
    ["μ", "σ²"],
    ["mean", "variance"],
    [-Inf, 0.0],
    [Inf, Inf],
    (x, par) -> exp.(-0.5 .* log.(2π*par[2]) .- 0.5 .* (x .- par[1]).^2 ./ par[2]),
    (x, par) -> -0.5 .* log.(2π*par[2]) .- 0.5 .* (x .- par[1]).^2 ./ par[2],
    (x, par) -> 0.5 .* (1 .+ erf.((x .- par[1]) ./ sqrt.(2*par[2]))),
    (x, par) -> log.(0.5 .* (1 .+ erf.((x .- par[1]) ./ sqrt.(2*par[2])))),
    (prob, par) -> par[1] .+ sqrt.(2*par[2]) .* erfinv.(2*prob .- 1),
    par -> par[1],
    par -> par[2],
    par -> 0,
    par -> 3
)


function grad_loglik(family::Family, x, par)
    n = size(x, 1)
    k = size(par, 1)
    g = zeros(n, k)
    par_vec = zeros(k)
    # Use Zygote to compute gradient
    for i in 1:n
        for j in 1:k
            par_vec[j] = par[j][i]                 
        end
        g[i, :] = gradient(p -> family.lpdf(x[i], p), par_vec)[1]
    end
    return g
end

function hess_loglik(family::Family, x, par)
    n = size(x, 1)
    k = size(par, 1)
    h = zeros(k, k, n)
    par_vec = zeros(k)
    # Use Zygote to compute hessian
    for i in 1:n
        for j in 1:k
            par_vec[j] = par[j][i]                 
        end
        h[:, :, i] = hessian(p -> family.lpdf(x[i], p), par_vec)
    end
    return h
end


n = 2
x = [1.0, 2.0]
mu = x
s2 = abs.(x)
par = [mu, s2]

@btime grad_loglik(gaussian1, x, par)
@btime hess_loglik(gaussian1, x, par)

I’m not any sort of advanced Julia user myself, but I think this can be greatly improved. The following is loosely based on the design of Distributions.jl.

First, functions are almost never stored in structs. You can do it, but usually it’s more idiomatic to define a function with different methods for each distribution:

import Statistics
abstract type Distribution end

struct Normal{T<:Real} <: Distribution
   loc::T
   scale::T
end

family(::Normal) = "gaussian1"

struct Uniform{T<:Real} <: Distribution
   lo::T
   hi::T
end

family(::Uniform) = "uniform"

# === Interface ===

varlower(::Normal) = -Inf
varupper(::Normal) = Inf

# One function can have multiple methods
varlower(d::Uniform) = d.lo
varupper(d::Uniform) = d.hi

lpdf(d::Normal, x::Real) = -0.5 * (x - d.loc)^2 / d.scale^2 - log(d.scale) - 0.5 * log(2*pi)
lpdf(d::Uniform, x::Real) = (d.lo <= x <= d.hi) ? -log(d.hi - d.lo) : -Inf

# This function works for ALL Distributions that have a `lpdf` function
pdf(d::Distribution, x) = exp(lpdf(d, x))

Statistics.mean(d::Normal) = d.loc
Statistics.var(d::Normal) = d.scale^2
Statistics.skewness(d::Normal) = 0.0

To compute gradients and Hessians, we need to extract parameters from our distribution:

# Distribution to params
to_params(d::Normal) = [d.loc, d.scale]
to_params(d::Uniform) = [d.lo, d.hi]

# Params to distribution
from_params(::Normal, par::AbstractVector{<:Real}) = Normal(par...)
from_params(::Uniform, par::AbstractVector{<:Real}) = Uniform(par...)

Now computing gradients and Hessians is easy:

using ForwardDiff: gradient  # Zygote.gradient would also work here

function grad_loglik(_d::Distribution, data::AbstractVector, par::AbstractVector{<:Real})
   full_loglik(par::AbstractVector{<:Real}) = begin
      d = from_params(_d, par)
      sum(x -> lpdf(d, x), data)
      # same as:
      # sum(lpdf(d, x) for x in data)
   end
   gradient(full_loglik, par)
end

Note that the above function automatically works for any Distribution that implements from_params(::Distribution, ::AbstractVector{<:Real}) and lpdf(::Distribution, ::Any), so you can leave it as-is, define new distributions, and you’ll automatically gain the ability to autodiff them.
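For instance, here is a minimal sketch with a made-up one-parameter exponential distribution (density \lambda e^{-\lambda x} for x \ge 0; this isn’t part of the code above, just an illustration):

# A new distribution only needs lpdf, to_params and from_params to plug in.
struct Exponential{T<:Real} <: Distribution
   rate::T
end

lpdf(d::Exponential, x::Real) = log(d.rate) - d.rate * x
to_params(d::Exponential) = [d.rate]
from_params(::Exponential, par::AbstractVector{<:Real}) = Exponential(par...)

# grad_loglik works unchanged: d/dλ Σᵢ (log λ - λ xᵢ) = n/λ - Σᵢ xᵢ
grad_loglik(Exponential(1.0), [0.3, 1.2, 0.7], [2.0])  # ≈ [-0.7] with ForwardDiff (3/2 - 2.2)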

To compute per-sample gradients, compute the Jacobian of a function that returns per-sample log-PDFs:

julia> let
        data = randn(10)
        ForwardDiff.jacobian(
          par -> lpdf.(Ref(Normal(par...)), data),
          [0.3, 1.4] # params
        )
       end
10×2 Matrix{Float64}:
  0.500023   -0.364254
 -0.537884   -0.309238
 -0.469201   -0.406076
 -0.697264   -0.0336369
  0.289679   -0.596806
 -0.222195   -0.645167
 -0.262489   -0.617825
  0.318568   -0.572206
 -0.610699   -0.192151
 -0.0406918  -0.711968

Of course, this is just a rough idea and can be improved in various ways.


Personally, I don’t like having this from_params(_d, par) function. I think I need it because I can only differentiate with respect to vectors, but I actually want to differentiate wrt a Distribution. I want to do this:

julia> ForwardDiff.gradient(d->loss(d, randn(5)), Normal(0.6, 0.9))
ERROR: MethodError: no method matching gradient(::var"#3#4", ::Normal{Float64})
The function `gradient` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  gradient(::Any, ::Real)
   @ ForwardDiff ~/.julia/packages/ForwardDiff/Wq9Wb/src/gradient.jl:46
  gradient(::F, ::AbstractArray, ::ForwardDiff.GradientConfig{T}, ::Val{CHK}) where {F, T, CHK}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/Wq9Wb/src/gradient.jl:16
  gradient(::F, ::AbstractArray, ::ForwardDiff.GradientConfig{T}) where {F, T}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/Wq9Wb/src/gradient.jl:16
  ...

This works in JAX, by the way, because it has the concept of PyTrees and can automatically (!) extract vectors of parameters from structs/classes and rebuild objects from such vectors. It’s extremely easy to use because it lets you differentiate through dicts, neural networks and basically anything you register as a PyTree. AFAIK this is how Equinox (https://docs.kidger.site/equinox/) works: you make your class a subclass of equinox.Module - and boom, now you can autodiff wrt its parameters, no extra setup required!

In Julia, you can do something like the approach discussed in “How to optimize a neural network with a ‘traditional’ optimizer?”, but I basically never got it to work, so I stick to from_params, as shown here…


I agree with @ForceBru’s general recommendation to use methods instead of function fields.

I’d also like to point out that the obligation to differentiate arrays is specific to forward-mode AD packages, in this case ForwardDiff.jl. If you switch to a reverse-mode AD package (Enzyme.jl or Mooncake.jl, for example) you’ll be able to compute gradients of any structured object (see below). However, you will still need arrays if you want to compute Jacobians or Hessians, because AD needs a way to list all the parameters involved in the input/output (which will make up the columns/rows of the resulting matrix).

using Enzyme

struct Normal
    μ::Float64
    σ::Float64
end

logpdf(x::Number, n::Normal) = -log(n.σ) - (x - n.μ)^2 / 2n.σ^2
logpdf(x::AbstractVector, n::AbstractVector{Normal}) = mapreduce(logpdf, +, x, n)

julia> distributions = [Normal(i, i / 10) for i in 1:5];

julia> observations = float.(2:10);

julia> logpdf(observations, distributions)
-66.45512183336737

julia> Enzyme.gradient(Enzyme.Reverse, logpdf, observations, distributions)
([-99.99999999999999, -24.999999999999996, -11.11111111111111, -6.249999999999999, -4.0, 0.0, 0.0, 0.0, 0.0], Normal[Normal(99.99999999999999, 989.9999999999995), Normal(24.999999999999996, 119.99999999999994), Normal(11.11111111111111, 33.7037037037037), Normal(6.249999999999999, 13.124999999999993), Normal(4.0, 6.0)])

Regarding the original question from @giovannitinervia9, the double-loop approach works but is very inefficient, especially with Zygote.jl, which doesn’t handle for loops well. A vectorized formulation would be much better. I don’t think the Jacobian suggested above is quite appropriate though, because it seems to me that each sample’s computation is independent of the others? In that case, taking the gradient of the full likelihood as I did above automatically yields sample-wise gradients, since \nabla_{\theta_i} \sum_j f(\theta_j, x_j) = \nabla_1 f(\theta_i, x_i).
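To make this concrete, here is a tiny sketch (ForwardDiff, \sigma fixed at 1, made-up data, not code from the posts above) showing that the gradient of the summed log-likelihood with respect to the per-observation means coincides with the diagonal of the per-sample Jacobian:

using ForwardDiff

xs = [1.0, 2.0, 3.0]
full_ll(μ) = sum(-0.5 .* (xs .- μ) .^ 2)         # Σᵢ f(μᵢ, xᵢ)
per_sample_ll(μ) = -0.5 .* (xs .- μ) .^ 2        # [f(μ₁, x₁), …, f(μₙ, xₙ)]

ForwardDiff.gradient(full_ll, [0.5, 1.5, 2.5])        # [0.5, 0.5, 0.5]
ForwardDiff.jacobian(per_sample_ll, [0.5, 1.5, 2.5])  # diagonal matrix with the same entries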

The Hessian is a bit more tricky due to the vectorization constraint, but you can use ComponentArrays.jl to convert between both representations.
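For example (just a rough sketch of the idea, not code from above): a ComponentVector is still an AbstractVector, so ForwardDiff can take a Hessian over it while the entries stay addressable by name.

using ComponentArrays, ForwardDiff

θ = ComponentVector(μ = 0.0, σ = 1.0)
ll(p) = -log(p.σ) - 0.5 * (0.5 - p.μ)^2 / p.σ^2   # log-density of a single observation x = 0.5

ForwardDiff.hessian(ll, θ)   # 2×2 matrix; rows/columns follow the (μ, σ) layout of θ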

Here’s a generic and efficient way to differentiate wrt arbitrary structs in Julia:

julia> using StaticArrays

julia> using ForwardDiff: gradient

julia> using AccessorsExtra

# if you want a vector back:
julia> gradient_vec(f, x, o=RecursiveOfType(Number)) =
               gradient(SVector(getall(x, o))) do vec
                       f(setall(x, o, vec))
               end

# if you want an object back:
julia> gradient_obj(f, x, o=RecursiveOfType(Number)) =
               setall(x, o, gradient_vec(f, x, o))

Applying to lpdf(::Normal, x) defined by @ForceBru above:

julia> gradient_vec(d -> lpdf(d, 0.5), Normal(0., 1.))
2-element SVector{2, Float64} with indices SOneTo(2):
  0.5
 -0.75

julia> gradient_obj(d -> lpdf(d, 0.5), Normal(0., 1.))
Normal{Float64}(0.5, -0.75)

It’s generally zero-cost:

julia> using BenchmarkTools

julia> @btime gradient_vec(d -> lpdf(d, 0.5), Normal(0., 1.))
  1.042 ns (0 allocations: 0 bytes)

julia> @btime gradient_obj(d -> lpdf(d, 0.5), Normal(0., 1.))
  1.083 ns (0 allocations: 0 bytes)

Also discussed in my recent JuliaCon’25 talk.


Thank you very much for your answers! I’m trying to understand how the Distributions.jl package works, and I’ve been reading the code for the implementations of the Normal and Poisson distributions. However, I can’t see where the PDF or its logarithm is actually implemented for either of them.

If you run @edit logpdf(Normal(0, 1), 0.5) it will open the file at the line where this is implemented. However, in this case it is a bit obscure: a macro delegates the work to the corresponding function in StatsFuns.jl. For the Normal distribution, for example, the delegation occurs here, which causes logpdf(d::Normal, x) to call StatsFuns.normlogpdf(d.μ, d.σ, x).
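A quick way to check the delegation in the REPL (assuming both packages are installed):

julia> using Distributions, StatsFuns

julia> logpdf(Normal(0, 1), 0.5) == StatsFuns.normlogpdf(0.0, 1.0, 0.5)
true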
