I am trying to get performant gradient computations for a function that defines a conditional distribution over a discrete space. More specifically, I want the gradient of the following function:
import Flux
import Zygote
using Distributions
using BenchmarkTools
function logprob(W, x, a)
    probs = Flux.softmax(W' * x)  # vector of class probabilities
    d = Categorical(probs)
    return logpdf(d, a)
end
x = collect(Float64, 1:100)
w = zeros(Float64, (100,2))
@btime Zygote.gradient(w -> logprob(w, x, 1), w)
# running stats: 150.143 μs (608 allocations: 29.55 KiB)
If I remove the Categorical distribution and index into the probability vector directly, the computation time drops by roughly two orders of magnitude:
function logprob2(W, x, a)
    probs = Flux.softmax(W' * x)
    return log(probs[a])
end
@btime Zygote.gradient(w -> logprob2(w, x, 1), w)
# running stats: 1.031 μs (26 allocations: 4.53 KiB)
My questions are: what is causing the slowdown, and is it possible to speed things up while still leveraging the Distributions.jl functions?
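For reference, here is a variant I am considering. It is only a sketch based on the assumption that part of the overhead comes from the argument validation in the Categorical constructor (and from Zygote differentiating through it), and it assumes the check_args keyword is available for Categorical in the Distributions.jl version I am using. I have not profiled it carefully:

function logprob3(W, x, a)
    probs = Flux.softmax(W' * x)
    # check_args=false skips the constructor's validation that `probs`
    # is a probability vector; softmax should already guarantee that
    d = Categorical(probs; check_args=false)
    return logpdf(d, a)
end

@btime Zygote.gradient(w -> logprob3(w, x, 1), w)

If the remaining gap is due to Zygote pulling back through the Distributions.jl internals rather than the argument checks, this would presumably not close it, which is part of what I am hoping to understand.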