CUDA gradient of gamma PDF

I’m having trouble computing an AD gradient of the gamma log-PDF using a GPU array. The goal is to use this as part of a Flux project. I am happy to open an issue, but I am new enough to GPU coding that I want to make sure I am not doing something silly.

Here is an MWE:

using Revise, Flux, CuArrays, BenchmarkTools, CUDAnative, CUDAapi
CuArrays.allowscalar(false)

function vecgammalogpdf(k::TV, θ::TV, x::TV)::TV where TV<:CuArrays.CuVector

  Γs = -(CUDAnative.lgamma).(k) .- k .* log.(θ) .+ (k .- 1f0) .* log.(x) .- x ./ θ

  return Γs
end

function testvecgammalogpdf(N)
  τ = abs.(randn(Float32, N)) .+ .001f0
  absμ = abs.(randn(Float32, N)) .+ .001f0

  θ = 1.0f0 ./ τ
  k = absμ .* τ
  v = rand(Float32, N)

  cuθ = θ |> gpu
  cuk = k |> gpu
  cuv = v |> gpu

  @info "Make sure the function works:"
  println(sum(vecgammalogpdf(cuk, cuθ, cuv)))

  @info "compute the gradient"
  gs = gradient(()->sum(vecgammalogpdf(cuk, cuθ, cuv)), Flux.params(cuk, cuθ, cuv))
  println(gs)
end

N = 10^5
sleep(0.5)
testvecgammalogpdf(N)

When I run this, I get:

[ Info: Make sure the function works:
-134438.14
[ Info: compute the gradient
ERROR: LoadError: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] assertscalar(::String) at C:\Users\Clinton\.julia\packages\GPUArrays\QDGmr\src\host\indexing.jl:41
 [3] getindex(::CuArray{Float32,1,Nothing}, ::Int64) at C:\Users\Clinton\.julia\packages\GPUArrays\QDGmr\src\host\indexing.jl:86
 [4] _broadcast_getindex at .\broadcast.jl:597 [inlined]
 [5] _getindex at .\broadcast.jl:628 [inlined]
 [6] _broadcast_getindex at .\broadcast.jl:603 [inlined]
 [7] getindex at .\broadcast.jl:564 [inlined]
 [8] copy at .\broadcast.jl:854 [inlined]
 [9] materialize at .\broadcast.jl:820 [inlined]
 [10] broadcast_forward at C:\Users\Clinton\.julia\packages\Zygote\wkc82\src\lib\broadcast.jl:173 [inlined]
 [11] adjoint at C:\Users\Clinton\.julia\packages\Zygote\wkc82\src\lib\broadcast.jl:189 [inlined]
 [12] _pullback at C:\Users\Clinton\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [13] adjoint at C:\Users\Clinton\.julia\packages\Zygote\wkc82\src\lib\lib.jl:167 [inlined]
 [14] _pullback at C:\Users\Clinton\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [15] broadcasted at .\broadcast.jl:1232 [inlined]
 [16] vecgammalogpdf at C:\Users\Clinton\Dropbox\Projects\Capacity\mwe.jl:693 [inlined]
 [17] _pullback(::Zygote.Context, ::typeof(vecgammalogpdf), ::CuArray{Float32,1,Nothing}, ::CuArray{Float32,1,Nothing}, ::CuArray{Float32,1,Nothing}) at C:\Users\Clinton\.julia\packages\Zygote\wkc82\src\compiler\interface2.jl:0
 [18] #62 at C:\Users\Clinton\Dropbox\Projects\Capacity\mwe.jl:716 [inlined]
 [19] _pullback(::Zygote.Context, ::var"#62#63"{CuArray{Float32,1,Nothing},CuArray{Float32,1,Nothing},CuArray{Float32,1,Nothing}}) at C:\Users\Clinton\.julia\packages\Zygote\wkc82\src\compiler\interface2.jl:0
 [20] pullback(::Function, ::Zygote.Params) at C:\Users\Clinton\.julia\packages\Zygote\wkc82\src\compiler\interface.jl:103
 [21] gradient(::Function, ::Zygote.Params) at C:\Users\Clinton\.julia\packages\Zygote\wkc82\src\compiler\interface.jl:44
 [22] testvecgammalogpdf(::Int64; tol::Float64) at C:\Users\Clinton\Dropbox\Projects\Capacity\mwe.jl:716
 [23] testvecgammalogpdf(::Int64) at C:\Users\Clinton\Dropbox\Projects\Capacity\mwe.jl:701
 [24] top-level scope at C:\Users\Clinton\Dropbox\Projects\Capacity\mwe.jl:722
in expression starting at C:\Users\Clinton\Dropbox\Projects\Capacity\mwe.jl:722

Without the call to CUDAnative.lgamma everything works swimmingly, so it’s just that call that seems to mess things up. Also, even if I allowscalar, I get MethodError: no method matching lgamma(::ForwardDiff.Dual{Nothing,Float32,1}), which suggests Zygote’s broadcast machinery is pushing ForwardDiff dual numbers through the call and CUDAnative.lgamma has no method for them.

Any help getting this to work would be appreciated!

I am guessing that the problem is that Zygote doesn’t know how to take the adjoint of the CUDAnative.lgamma function?

Maybe? I’m only familiar with the term in this context from a quick scan of the Zygote documentation. Is this a scenario where a custom adjoint would be needed?

Yes, I think so. Basically, the “adjoint” corresponds (roughly) to the derivative. I don’t think Zygote knows how to differentiate CUDAnative.lgamma.
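
For CUDAnative.lgamma the derivative is the digamma function, so in principle the missing rule can be supplied by hand. As a rough illustration of the pattern (a CPU-only sketch using SpecialFunctions.loggamma and SpecialFunctions.digamma rather than the CUDA kernels), a custom adjoint for a vectorized call looks something like:

using Zygote, SpecialFunctions

#wrap the broadcasted loggamma so a rule can be attached to the whole vector operation
veclgamma(x) = SpecialFunctions.loggamma.(x)

#d/dx loggamma(x) = digamma(x), so the pullback scales the incoming cotangent elementwise
Zygote.@adjoint veclgamma(x) = veclgamma(x), ȳ -> (ȳ .* SpecialFunctions.digamma.(x),)

#returns (digamma.(x),) as the gradient of the sum
gradient(x -> sum(veclgamma(x)), Float32[1.0, 2.0, 3.0])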

Quick epilogue if anyone comes across this. Writing this adjoint turned into a nice learning experience about AD and CUDA. As shown below, the GPU version runs about 200x faster than the vectorized CPU version.


using Revise, CUDAapi, Flux, BenchmarkTools, CUDAnative, Distributions, Zygote, DiffRules
using SpecialFunctions
import CuArrays
CuArrays.allowscalar(false)

#GPU vectorized log gamma pdf
function vecgammalogpdf(k::TV, θ::TV, x::TV)::TV where TV<:CuArrays.CuVector
  Γs_part = cudagammalogpdf_part.(k,θ,x)

  return Γs_part .- cudaveclgamma(k)
end

#see https://math.stackexchange.com/questions/1441753/approximating-the-digamma-function
#and https://en.wikipedia.org/wiki/Digamma_function
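#the recurrence ψ(x) = ψ(x+1) - 1/x lets us shift the argument up by 3 before applying the asymptotic series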
function cudaapproxdigamma(x::Real)
  adj = 0.0f0
  ψ = x
  #the polynomial at the end is only accurate for large x, but fortunately we can transform the problem as follows
  for i ∈ 1:3
    adj -= 1f0/ψ
    ψ += 1f0
  end

  ψ = (CUDAnative.log(ψ) - 1f0 / 2f0 / ψ - 1f0 / 12f0 / (ψ * ψ) +
    1f0/120f0 * CUDAnative.pow(ψ, -4) - 1f0/252f0 * CUDAnative.pow(ψ, -6) +
    1f0/240f0 * CUDAnative.pow(ψ, -8) - 5f0/660f0 * CUDAnative.pow(ψ, -10) +
    691f0/32760f0 * CUDAnative.pow(ψ, -12) - 1f0/12f0 * CUDAnative.pow(ψ, -14)) + adj

  return ψ
end

#does everything w.r.t computing the gamma pdf except compute the log gamma function
#adapted from StatsFuns
#https://github.com/JuliaStats/StatsFuns.jl/blob/master/src/distrs/gamma.jl
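#log f(x; k, θ) = (k - 1)log(x) - x/θ - k*log(θ) - lgamma(k); the -lgamma(k) term is split out into cudaveclgamma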
cudagammalogpdf_part(k,θ,x) = - k * CUDAnative.log(θ) + (k - 1f0) * CUDAnative.log(x) - x / θ

#vectorized log gamma function
cudaveclgamma(k) = (CUDAnative.lgamma).(k)

#define the adjoint: d/dk lgamma(k) = digamma(k), so the pullback scales the incoming cotangent by the approximate digamma
Zygote.@adjoint cudaveclgamma(k) = cudaveclgamma(k), y->(y .* cudaapproxdigamma.(k), )
Zygote.refresh()

#lightly optimized cpu version for testing
function vecgammalogpdf(k::TV, θ::TV, x::TV)::TV where TV<:AbstractVector

  Γs::TV = -(SpecialFunctions.loggamma).(k) .- k .* log.(θ) .+ (k .- 1f0) .* log.(x) .- x ./ θ

  return Γs
end

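#reference implementation: elementwise logpdf from Distributions, used only to double-check values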
slowvecgammalogpdf(k, θ, x) = ((kᵢ, θᵢ, xᵢ)->logpdf(Gamma(kᵢ,θᵢ), xᵢ)).(k, θ, x)

#test it
function testgrad(N::Int)
  τ = abs.(randn(Float32, N)) .+ .001f0
  absμ = abs.(randn(Float32, N)) .+ .001f0

  θ = 1.0f0 ./ τ
  k = absμ .* τ
  v = rand(Float32, N)

  cuθ = θ |> gpu
  cuk = k |> gpu
  cuv = v |> gpu

  @info "testing function values (N=$N) cpu: $(sum(vecgammalogpdf(k, θ, v))) " *
    "gpu: $(sum(vecgammalogpdf(cuk, cuθ, cuv)))"
  @info "Additional check of value: $(sum(slowvecgammalogpdf(k, θ, v)))"

  @info "cpu version gradient N=$N"
  if N ≤ 10
    gs = gradient((k,θ)->sum(vecgammalogpdf(k, θ, v)),k,θ)
    @info "gs: $gs"
  else
    gs = @btime gradient((k,θ)->sum(vecgammalogpdf(k, θ, $v)),$k,$θ)
  end

  @info "gpu version gradient N=$N"
  if N ≤ 10
    cugs = gradient((cuk,cuθ)->sum(vecgammalogpdf(cuk, cuθ, cuv)),cuk,cuθ)
    @info "cugs: $cugs"
  else
    cugs = @btime gradient((cuk,cuθ)->sum(vecgammalogpdf(cuk, cuθ, $cuv)),$cuk,$cuθ)
  end
end

testgrad(5)
testgrad(10^6)

Output:

[ Info: testing function values (N=5) cpu: -9.063408 gpu: -9.063408
[ Info: Additional check of value: -9.063408
[ Info: cpu version gradient N=5
[ Info: gs: (Float32[12.996181, 4.6498237, 6.403746, -1.7591481, 1.6755874], Float32[1.1360388, -0.0055119246, -0.0006913543, -7.3534327, 0.007893473])
[ Info: gpu version gradient N=5
[ Info: cugs: (Float32[12.9961815, 4.649824, 6.403746, -1.7591481, 1.6755874], Float32[1.1360388, -0.0055119228, -0.0006913617, -7.3534327, 0.007893473])
[ Info: testing function values (N=1000000) cpu: -1.3431859e6 gpu: -1.3431859e6
[ Info: Additional check of value: -1.3431859e6
[ Info: cpu version gradient N=1000000
  255.517 ms (9000248 allocations: 274.67 MiB)
[ Info: gpu version gradient N=1000000
  1.136 ms (1246 allocations: 41.80 KiB)

You can see a nice alternate solution posted here by @xukai92 in reference to https://github.com/FluxML/Flux.jl/issues/383; it ran a bit faster but digs more directly into ForwardDiff and DiffRules.
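
To give a flavor of what that route involves (this is only an illustrative, untested sketch, not the linked solution verbatim): Zygote broadcasts over CuArrays in forward mode with ForwardDiff dual numbers, so one can instead teach CUDAnative.lgamma to accept duals directly, reusing cudaapproxdigamma from above for the derivative part. Note this pirates a method on CUDAnative.lgamma.

import CUDAnative, ForwardDiff

#sketch: pair the lgamma primal with a digamma-scaled partial so Zygote's
#forward-mode broadcast can push duals through the original fused kernel
function CUDAnative.lgamma(d::ForwardDiff.Dual{T}) where T
  x = ForwardDiff.value(d)
  return ForwardDiff.Dual{T}(CUDAnative.lgamma(x), cudaapproxdigamma(x) * ForwardDiff.partials(d))
end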
