I’m not any sort of advanced Julia user myself, but I think this can be greatly improved. The following is loosely based on the design of Distributions.jl.
First, functions are almost never stored in structs. You can do it, but usually it’s more idiomatic to define a function with different methods for each distribution:
import Statistics
abstract type Distribution end
struct Normal{T<:Real} <: Distribution
    loc::T
    scale::T
end
family(::Normal) = "gaussian1"
struct Uniform{T<:Real} <: Distribution
    lo::T
    hi::T
end
family(::Uniform) = "uniform"
# === Interface ===
varlower(::Normal) = -Inf
varupper(::Normal) = Inf
# One function can have multiple methods
varlower(d::Uniform) = d.lo
varupper(d::Uniform) = d.hi
lpdf(d::Normal, x::Real) = -0.5 * (x - d.loc)^2 / d.scale^2 - log(d.scale) - 0.5 * log(2*pi)
lpdf(d::Uniform, x::Real) = (d.lo <= x <= d.hi) ? -log(d.hi - d.lo) : -Inf
# This function works for ALL Distributions that have a `lpdf` function
pdf(d::Distribution, x) = exp(lpdf(d, x))
Statistics.mean(d::Normal) = d.loc
Statistics.var(d::Normal) = d.scale^2
Statistics.skewness(d::Normal) = 0.0
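With just these definitions everything dispatches on the concrete type; a quick usage sketch (the values in the comments follow directly from the formulas above):

d1 = Normal(0.3, 1.4)
d2 = Uniform(0.0, 2.0)

pdf(d1, 0.3)                 # calls the Normal lpdf method
pdf(d2, 1.0)                 # 0.5, calls the Uniform lpdf method
varlower(d2), varupper(d2)   # (0.0, 2.0)
Statistics.mean(d1)          # 0.3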
To compute gradients and Hessians, we need to extract parameters from our distribution:
# Distribution to params
to_params(d::Normal) = [d.loc, d.scale]
to_params(d::Uniform) = [d.lo, d.hi]
# Params to distribution
from_params(::Normal, par::AbstractVector{<:Real}) = Normal(par...)
from_params(::Uniform, par::AbstractVector{<:Real}) = Uniform(par...)
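These are inverses of each other; note that the first argument of from_params is only used for dispatch:

to_params(Normal(0.3, 1.4))                # [0.3, 1.4]
from_params(Normal(0.0, 1.0), [0.3, 1.4])  # Normal{Float64}(0.3, 1.4)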
Now computing gradients and Hessians is easy:
import ForwardDiff

function grad_loglik(_d::Distribution, data::AbstractVector, par::AbstractVector{<:Real})
    full_loglik(par::AbstractVector{<:Real}) = begin
        d = from_params(_d, par)
        sum(x -> lpdf(d, x), data)
        # same as:
        # sum(lpdf(d, x) for x in data)
    end
    ForwardDiff.gradient(full_loglik, par)
end
Note that the above function automatically works for all Distributions that implement from_params(::Distribution, ::AbstractVector{<:Real}) and lpdf(::Distribution, ::Any), so you can leave it as-is, define new distributions, and you’ll automatically gain the ability to autodiff them.
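The Hessian follows the exact same pattern with ForwardDiff.hessian; here is a sketch (hess_loglik is just a name I’m picking for illustration):

function hess_loglik(_d::Distribution, data::AbstractVector, par::AbstractVector{<:Real})
    full_loglik(par::AbstractVector{<:Real}) = begin
        d = from_params(_d, par)
        sum(x -> lpdf(d, x), data)
    end
    ForwardDiff.hessian(full_loglik, par)
end

# for a Normal with two parameters:
# grad_loglik(Normal(0.0, 1.0), randn(100), [0.3, 1.4])  # 2-element gradient vector
# hess_loglik(Normal(0.0, 1.0), randn(100), [0.3, 1.4])  # 2×2 Hessian matrix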
To compute per-sample gradients, compute the Jacobian of a function that returns per-sample log-PDFs:
julia> let
           data = randn(10)
           ForwardDiff.jacobian(
               par -> lpdf.(Ref(Normal(par...)), data),
               [0.3, 1.4]  # params
           )
       end
10×2 Matrix{Float64}:
0.500023 -0.364254
-0.537884 -0.309238
-0.469201 -0.406076
-0.697264 -0.0336369
0.289679 -0.596806
-0.222195 -0.645167
-0.262489 -0.617825
0.318568 -0.572206
-0.610699 -0.192151
-0.0406918 -0.711968
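Same as before, this can be written once for any Distribution through to_params/from_params (a sketch, persample_grads is just an illustrative name):

persample_grads(d::Distribution, data::AbstractVector) =
    ForwardDiff.jacobian(par -> lpdf.(Ref(from_params(d, par)), data), to_params(d))

# persample_grads(Normal(0.3, 1.4), randn(10)) gives the same kind of 10×2 matrix as above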
Of course, this is just a rough idea and can be improved in various ways.
Personally, I don’t like having this from_params(_d, par) function. I think I need it because I can only differentiate with respect to vectors, but I actually want to differentiate wrt a Distribution. I want to do this:
julia> ForwardDiff.gradient(d->loss(d, randn(5)), Normal(0.6, 0.9))
ERROR: MethodError: no method matching gradient(::var"#3#4", ::Normal{Float64})
The function `gradient` exists, but no method is defined for this combination of argument types.
Closest candidates are:
gradient(::Any, ::Real)
@ ForwardDiff ~/.julia/packages/ForwardDiff/Wq9Wb/src/gradient.jl:46
gradient(::F, ::AbstractArray, ::ForwardDiff.GradientConfig{T}, ::Val{CHK}) where {F, T, CHK}
@ ForwardDiff ~/.julia/packages/ForwardDiff/Wq9Wb/src/gradient.jl:16
gradient(::F, ::AbstractArray, ::ForwardDiff.GradientConfig{T}) where {F, T}
@ ForwardDiff ~/.julia/packages/ForwardDiff/Wq9Wb/src/gradient.jl:16
...
This works in JAX, by the way, because it has a concept of PyTrees and can automatically (!) extract vectors of parameters from structs/classes and rebuild objects from such vectors. It’s extremely easy to use because this lets you differentiate through dicts, neural networks and basically anything you register as a PyTree. AFAIK this is how patrick-kidger/equinox (https://docs.kidger.site/equinox/) works: you make your class a subclass of equinox.Module, and boom, now you can autodiff wrt its parameters, no extra setup required!
In Julia, you can do something like the approach from the “How to optimize a neural network with a ‘traditional’ optimizer?” thread, but I basically never got it to work and stick to from_params, as shown here…
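For completeness, here is roughly what emulating “differentiate wrt a Distribution” on top of the same to_params/from_params pair could look like; just a sketch, and gradient_wrt is a made-up name:

gradient_wrt(f, d::Distribution) =
    ForwardDiff.gradient(par -> f(from_params(d, par)), to_params(d))

# gradient of a log-likelihood "wrt" Normal(0.6, 0.9), returned as a parameter vector:
gradient_wrt(Normal(0.6, 0.9)) do dist
    sum(x -> lpdf(dist, x), randn(5))
end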