CUDA gradient of gamma PDF

Quick epilogue in case anyone comes across this. Writing this adjoint turned into a nice learning experience about AD and CUDA. As shown below, the GPU version runs over 200x faster than the vectorized CPU version (about 1.1 ms vs 256 ms for the gradient at N = 10^6).


using Revise, CUDAapi, Flux, BenchmarkTools, CUDAnative, Distributions, Zygote, DiffRules
using SpecialFunctions
import CuArrays
CuArrays.allowscalar(false)

#GPU-vectorized gamma log-pdf; the lgamma(k) term is split out so a custom adjoint can be attached to it below
function vecgammalogpdf(k::TV, θ::TV, x::TV)::TV where TV<:CuArrays.CuVector
  Γs_part = cudagammalogpdf_part.(k,θ,x)

  return Γs_part .- cudaveclgamma(k)
end
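
For reference, the density being split here is the standard shape/scale gamma log-pdf, and its partial derivatives show why only the lgamma term needs a hand-written rule: the digamma function ψ appears in the k-derivative, and there is no ready-made GPU digamma, hence the approximation below. Zygote's existing broadcast machinery handles the rest.

$$
\log f(x; k, \theta) = -\log\Gamma(k) - k\log\theta + (k - 1)\log x - \frac{x}{\theta}
$$

$$
\frac{\partial \log f}{\partial k} = -\psi(k) - \log\theta + \log x,
\qquad
\frac{\partial \log f}{\partial \theta} = -\frac{k}{\theta} + \frac{x}{\theta^2}
$$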

#see https://math.stackexchange.com/questions/1441753/approximating-the-digamma-function
#and https://en.wikipedia.org/wiki/Digamma_function
function cudaapproxdigamma(x::Real)
  adj = 0.0f0
  ψ = x
  #the asymptotic series at the end is only accurate for large x, so first shift the
  #argument up by 3 using the recurrence ψ(x) = ψ(x+1) - 1/x
  for i ∈ 1:3
    adj -= 1f0/ψ
    ψ += 1f0
  end

  ψ = (CUDAnative.log(ψ) - 1f0 / 2f0 / ψ - 1f0 / 12f0 / (ψ * ψ) +
    1f0/120f0 * CUDAnative.pow(ψ, -4) - 1f0/252f0 * CUDAnative.pow(ψ, -6) +
    1f0/240f0 * CUDAnative.pow(ψ, -8) - 5f0/660f0 * CUDAnative.pow(ψ, -10) +
    691f0/32760f0 * CUDAnative.pow(ψ, -12) - 1f0/12f0 * CUDAnative.pow(ψ, -14)) + adj

  return ψ
end
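
As a quick sanity check on the series (not part of the original run), the same approximation can be evaluated on the CPU and compared against SpecialFunctions.digamma. approxdigamma_cpu below is a hypothetical helper that just swaps the CUDAnative calls for their Base equivalents:

#CPU analogue of cudaapproxdigamma, only for checking the series accuracy
#(hypothetical helper; assumes SpecialFunctions is loaded as above)
function approxdigamma_cpu(x::Real)
  adj = 0.0f0
  ψ = Float32(x)
  for i ∈ 1:3
    adj -= 1f0/ψ
    ψ += 1f0
  end

  ψ = (log(ψ) - 1f0 / 2f0 / ψ - 1f0 / 12f0 / (ψ * ψ) +
    1f0/120f0 * ψ^-4 - 1f0/252f0 * ψ^-6 +
    1f0/240f0 * ψ^-8 - 5f0/660f0 * ψ^-10 +
    691f0/32760f0 * ψ^-12 - 1f0/12f0 * ψ^-14) + adj

  return ψ
end

#maximum absolute error over a grid of test points
xs = 0.01f0:0.01f0:10f0
@info "max abs error vs digamma: $(maximum(abs.(approxdigamma_cpu.(xs) .- SpecialFunctions.digamma.(Float64.(xs)))))"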

#computes everything in the gamma log-pdf except the log-gamma term
#adapted from StatsFuns
#https://github.com/JuliaStats/StatsFuns.jl/blob/master/src/distrs/gamma.jl
cudagammalogpdf_part(k,θ,x) = - k * CUDAnative.log(θ) + (k - 1f0) * CUDAnative.log(x) - x / θ

#vectorized log gamma function
cudaveclgamma(k) = (CUDAnative.lgamma).(k)

#define the adjoint
Zygote.@adjoint cudaveclgamma(k) = cudaveclgamma(k), y->(y .* cudaapproxdigamma.(k), )
Zygote.refresh()
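
A small check that the adjoint is actually picked up (a sketch with made-up test values, not part of the original output): the gradient of sum(cudaveclgamma(k)) should match digamma evaluated elementwise at k.

#sanity check of the custom adjoint: d/dk sum(lgamma.(k)) = digamma.(k)
ktest = Float32[0.5, 1.0, 2.5]
cuktest = ktest |> gpu
cug, = gradient(k -> sum(cudaveclgamma(k)), cuktest)
@info "custom adjoint:           $(Array(cug))"
@info "SpecialFunctions.digamma: $(SpecialFunctions.digamma.(Float64.(ktest)))"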

#lightly optimized cpu version for testing
function vecgammalogpdf(k::TV, θ::TV, x::TV)::TV where TV<:AbstractVector

  Γs::TV = -(SpecialFunctions.loggamma).(k) .- k .* log.(θ) .+ (k .- 1f0) .* log.(x) .- x ./ θ

  return Γs
end

slowvecgammalogpdf(k, θ, x) = ((kᵢ, θᵢ, xᵢ)->logpdf(Gamma(kᵢ,θᵢ), xᵢ)).(k, θ, x)

#test it
function testgrad(N::Int)
  τ = abs.(randn(Float32, N)) .+ .001f0
  absμ = abs.(randn(Float32, N)) .+ .001f0

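  #convert the precision-style draws to shape/scale: θ = 1/τ and k = absμ*τ, so the implied mean kθ = absμ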
  θ = 1.0f0 ./ τ
  k = absμ .* τ
  v = rand(Float32, N)

  cuθ = θ |> gpu
  cuk = k |> gpu
  cuv = v |> gpu

  @info "testing function values (N=$N) cpu: $(sum(vecgammalogpdf(k, θ, v))) " *
    "gpu: $(sum(vecgammalogpdf(cuk, cuθ, cuv)))"
  @info "Additional check of value: $(sum(slowvecgammalogpdf(k, θ, v)))"

  @info "cpu version gradient N=$N"
  if N ≤ 10
    gs = gradient((k,θ)->sum(vecgammalogpdf(k, θ, v)),k,θ)
    @info "gs: $gs"
  else
    gs = @btime gradient((k,θ)->sum(vecgammalogpdf(k, θ, $v)),$k,$θ)
  end

  @info "gpu version gradient N=$N"
  if N ≤ 10
    cugs = gradient((cuk,cuθ)->sum(vecgammalogpdf(cuk, cuθ, cuv)),cuk,cuθ)
    @info "cugs: $cugs"
  else
    cugs = @btime gradient((cuk,cuθ)->sum(vecgammalogpdf(cuk, cuθ, $cuv)),$cuk,$cuθ)
  end
end

testgrad(5)
testgrad(10^6)

Output:

[ Info: testing function values (N=5) cpu: -9.063408 gpu: -9.063408
[ Info: Additional check of value: -9.063408
[ Info: cpu version gradient N=5
[ Info: gs: (Float32[12.996181, 4.6498237, 6.403746, -1.7591481, 1.6755874], Float32[1.1360388, -0.0055119246, -0.0006913543, -7.3534327, 0.007893473])
[ Info: gpu version gradient N=5
[ Info: cugs: (Float32[12.9961815, 4.649824, 6.403746, -1.7591481, 1.6755874], Float32[1.1360388, -0.0055119228, -0.0006913617, -7.3534327, 0.007893473])
[ Info: testing function values (N=1000000) cpu: -1.3431859e6 gpu: -1.3431859e6
[ Info: Additional check of value: -1.3431859e6
[ Info: cpu version gradient N=1000000
  255.517 ms (9000248 allocations: 274.67 MiB)
[ Info: gpu version gradient N=1000000
  1.136 ms (1246 allocations: 41.80 KiB)

You can see a nice alternate solution posted here by @xukai92 in reference to https://github.com/FluxML/Flux.jl/issues/383, which ran a bit faster but digs more directly into ForwardDiff and DiffRules.