Quick epilogue in case anyone comes across this: writing this adjoint turned into a nice learning experience about AD and CUDA. As shown below, the GPU version runs about 200x faster than the vectorized CPU version.
using Revise, CUDAapi, Flux, BenchmarkTools, CUDAnative, Distributions, Zygote, Flux, DiffRules
using SpecialFunctions
import CuArrays
# Disallow scalar (element-by-element) indexing into CuArrays, so that any code path
# that would silently fall back to slow host-side iteration raises an error instead.
CuArrays.allowscalar(false)
# GPU vectorized gamma log-density (shape `k`, scale `θ`), evaluated element-wise at `x`.
# The -lgamma(k) term comes from `cudaveclgamma`, which has a hand-written Zygote
# adjoint (defined below), so it is obtained by calling that function directly rather
# than being folded into the broadcast of the remaining terms.
function vecgammalogpdf(k::TV, θ::TV, x::TV)::TV where TV<:CuArrays.CuVector
    partial = cudagammalogpdf_part.(k, θ, x)
    lgammas = cudaveclgamma(k)
    return partial .- lgammas
end
#see https://math.stackexchange.com/questions/1441753/approximating-the-digamma-function
#and https://en.wikipedia.org/wiki/Digamma_function
"""
    cudaapproxdigamma(x::Real)

Approximate the digamma function ψ(x) = d/dx log Γ(x) using only operations that can
run inside a CUDAnative kernel. It is used by the `cudaveclgamma` adjoint.

The asymptotic expansion

    ψ(z) ≈ log z - 1/(2z) - Σₙ B₂ₙ / (2n z²ⁿ)    (n = 1…7)

is accurate only for large z. The argument is therefore first shifted up by three
using the recurrence ψ(x) = ψ(x + 1) - 1/x. The series is then evaluated with
Horner's rule in t = 1/z², which avoids six separate `pow` calls.

The coefficients are Float32 literals, so the result is roughly Float32-accurate even
for Float64 input. The function is intended for x > 0: no reflection formula is
applied for negative arguments.
"""
function cudaapproxdigamma(x::Real)
    # Promote once so the accumulator and the shifted argument keep a single type.
    z = float(x)
    # Recurrence shift: ψ(x) = ψ(x + 3) - 1/x - 1/(x + 1) - 1/(x + 2).
    adj = zero(z)
    for _ in 1:3
        adj -= inv(z)
        z += one(z)
    end
    t = inv(z * z)
    # Σ B₂ₙ/(2n z²ⁿ) = t/12 - t²/120 + t³/252 - t⁴/240 + t⁵/132 - 691t⁶/32760 + t⁷/12
    series = t * (1f0/12f0 - t * (1f0/120f0 - t * (1f0/252f0 - t * (1f0/240f0 -
             t * (1f0/132f0 - t * (691f0/32760f0 - t * (1f0/12f0)))))))
    return CUDAnative.log(z) - 0.5f0 / z - series + adj
end
# Every term of the gamma log-density except -lgamma(k):
#     -k*log(θ) + (k - 1)*log(x) - x/θ      (shape k, scale θ)
# Adapted from StatsFuns:
# https://github.com/JuliaStats/StatsFuns.jl/blob/master/src/distrs/gamma.jl
# Uses CUDAnative.log because it is broadcast over CuVectors by the GPU vecgammalogpdf.
function cudagammalogpdf_part(k, θ, x)
    scaleterm = -k * CUDAnative.log(θ)
    shapeterm = (k - 1f0) * CUDAnative.log(x)
    return scaleterm + shapeterm - x / θ
end
# Element-wise log-gamma of `k` using CUDAnative's device `lgamma`, so it can be
# broadcast over a CuArray. Its derivative is supplied by the custom adjoint below.
cudaveclgamma(k) = broadcast(CUDAnative.lgamma, k)
# Custom Zygote adjoint for cudaveclgamma: d/dk lgamma(k) = digamma(k). The pullback
# therefore scales the incoming sensitivity Δ element-wise by the approximate
# digamma of `k`.
Zygote.@adjoint function cudaveclgamma(k)
    back(Δ) = (Δ .* cudaapproxdigamma.(k),)
    return cudaveclgamma(k), back
end
# Have Zygote pick up the adjoint defined above.
Zygote.refresh()
# Lightly optimized CPU version, used to test the GPU path.
# Full gamma log-density (shape k, scale θ) evaluated element-wise at x as one broadcast:
#     -loggamma(k) - k*log(θ) + (k - 1)*log(x) - x/θ
function vecgammalogpdf(k::TV, θ::TV, x::TV)::TV where TV<:AbstractVector
    Γs::TV = @. (-SpecialFunctions.loggamma(k) - k * log(θ) +
                 (k - 1f0) * log(x) - x / θ)
    return Γs
end
# Reference (slow) implementation: element-wise Distributions `logpdf` of Gamma(kᵢ, θᵢ)
# at xᵢ, used to cross-check the fast versions.
slowvecgammalogpdf(k, θ, x) = broadcast(k, θ, x) do kᵢ, θᵢ, xᵢ
    logpdf(Gamma(kᵢ, θᵢ), xᵢ)
end
#test it
"""
    testgrad(N::Int)

Compare the CPU and GPU `vecgammalogpdf` implementations on `N` random inputs.

1. Log the summed function values on both devices, plus a reference value from
   `slowvecgammalogpdf`.
2. Compute gradients with respect to `k` and `θ` on both devices. For `N ≤ 10` the
   gradients are printed; otherwise they are benchmarked with `@btime`.
3. Log the largest element-wise discrepancy between the CPU and GPU gradients.

Returns `(gs, cugs)`, the CPU and GPU gradient tuples.
"""
function testgrad(N::Int)
    τ = abs.(randn(Float32, N)) .+ .001f0
    absμ = abs.(randn(Float32, N)) .+ .001f0
    θ = 1.0f0 ./ τ
    k = absμ .* τ
    # rand(Float32, N) samples [0, 1) and can return exactly 0, where log(x) = -Inf.
    # At N = 10^6 that happens in roughly one run in nine. Flipping the sample
    # to (0, 1] keeps every x strictly positive.
    v = 1f0 .- rand(Float32, N)
    cuθ = θ |> gpu
    cuk = k |> gpu
    cuv = v |> gpu
    @info "testing function values (N=$N) cpu: $(sum(vecgammalogpdf(k, θ, v))) " *
        "gpu: $(sum(vecgammalogpdf(cuk, cuθ, cuv)))"
    @info "Additional check of value: $(sum(slowvecgammalogpdf(k, θ, v)))"
    @info "cpu version gradient N=$N"
    if N ≤ 10
        gs = gradient((k,θ)->sum(vecgammalogpdf(k, θ, v)),k,θ)
        @info "gs: $gs"
    else
        gs = @btime gradient((k,θ)->sum(vecgammalogpdf(k, θ, $v)),$k,$θ)
    end
    @info "gpu version gradient N=$N"
    if N ≤ 10
        cugs = gradient((cuk,cuθ)->sum(vecgammalogpdf(cuk, cuθ, cuv)),cuk,cuθ)
        @info "cugs: $cugs"
    else
        # GPU kernels launch asynchronously. Synchronize so that @btime measures the
        # completed backward pass rather than only the launch overhead.
        cugs = @btime CuArrays.@sync gradient(
            (cuk,cuθ)->sum(vecgammalogpdf(cuk, cuθ, $cuv)),$cuk,$cuθ)
    end
    # Large-N gradients are too long to print, so summarize how well they agree:
    # the largest element-wise discrepancy, relative to |cpu| wherever |cpu| > 1.
    maxdiff(a, b) = maximum(abs.(a .- Array(b)) ./ max.(abs.(a), 1f0))
    N > 0 && @info "max cpu/gpu gradient discrepancy N=$N " *
        "k: $(maxdiff(gs[1], cugs[1])) θ: $(maxdiff(gs[2], cugs[2]))"
    return gs, cugs
end
# The small case prints the gradients; the large case benchmarks them.
for N in (5, 10^6)
    testgrad(N)
end
Output:
[ Info: testing function values (N=5) cpu: -9.063408 gpu: -9.063408
[ Info: Additional check of value: -9.063408
[ Info: cpu version gradient N=5
[ Info: gs: (Float32[12.996181, 4.6498237, 6.403746, -1.7591481, 1.6755874], Float32[1.1360388, -0.0055119246, -0.0006913543, -7.3534327, 0.007893473])
[ Info: gpu version gradient N=5
[ Info: cugs: (Float32[12.9961815, 4.649824, 6.403746, -1.7591481, 1.6755874], Float32[1.1360388, -0.0055119228, -0.0006913617, -7.3534327, 0.007893473])
[ Info: testing function values (N=1000000) cpu: -1.3431859e6 gpu: -1.3431859e6
[ Info: Additional check of value: -1.3431859e6
[ Info: cpu version gradient N=1000000
255.517 ms (9000248 allocations: 274.67 MiB)
[ Info: gpu version gradient N=1000000
1.136 ms (1246 allocations: 41.80 KiB)
@xukai92 posted a nice alternative solution here, in reference to https://github.com/FluxML/Flux.jl/issues/383. It ran a bit faster, but it explicitly digs into ForwardDiff and DiffRules.