Gradient through a Categorical distribution is slow and allocates heavily

I am trying to compute gradients efficiently for a function that defines a conditional distribution over a discrete space. Specifically, I want the gradient of the following function:

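```julia
import Flux
import Zygote
using Distributions
using BenchmarkTools

# Log-probability of category `a` under Categorical(softmax(W' * x))
function logprob(W, x, a)
    probs = Flux.softmax(W' * x)
    d = Categorical(probs)
    return logpdf(d, a)
end

x = collect(Float64, 1:100)
w = zeros(Float64, (100, 2))
# Zygote.gradient takes the function first; `$` avoids benchmarking global-variable access
@btime Zygote.gradient(W -> logprob(W, $x, 1), $w)
# running stats: 150.143 μs (608 allocations: 29.55 KiB)
```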
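For reference, the quantity `logprob` evaluates can be written out explicitly. With $z = W^\top x$ and $p = \operatorname{softmax}(z)$, `logpdf(Categorical(p), a)` is simply the log of the $a$-th probability:

$$
\log p_a = z_a - \log \sum_{j} \exp(z_j), \qquad z = W^\top x .
$$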

If I remove the Categorical distribution and index the probability vector directly, the computation is about two orders of magnitude faster:

```julia
# Same log-probability, computed by indexing the softmax output directly
function logprob2(W, x, a)
    probs = Flux.softmax(W' * x)
    return log(probs[a])
end

@btime Zygote.gradient(W -> logprob2(W, $x, 1), $w)
# running stats: 1.031 μs (26 allocations: 4.53 KiB)
```
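Both functions compute the same quantity, so the gradient Zygote should return has a closed form: $\nabla_W \log p_a = x\,(e_a - p)^\top$, where $e_a$ is the $a$-th unit vector. Below is a minimal sketch in plain Julia (no Flux, Zygote, or Distributions) that computes the value and this gradient by hand. The helper names `softmax_plain` and `logprob_and_grad` are hypothetical, introduced only for illustration; the sketch is meant as a reference for checking either version, not as a replacement for them.

```julia
# Hypothetical helpers (not part of Flux or Distributions): a hand-written
# reference for the value and gradient of log(softmax(W' * x)[a]).

# Numerically stable softmax: subtract the maximum before exponentiating.
function softmax_plain(z::AbstractVector{<:Real})
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

# Returns (log p_a, ∇_W log p_a) where p = softmax(W' * x).
# Closed form: ∂ log p_a / ∂W = x * (e_a - p)', with e_a the a-th unit vector.
function logprob_and_grad(W::AbstractMatrix, x::AbstractVector, a::Integer)
    p = softmax_plain(W' * x)
    g = -p          # start from -p ...
    g[a] += 1       # ... and add the unit vector e_a
    return log(p[a]), x * g'
end

x = collect(Float64, 1:100)
w = zeros(Float64, (100, 2))
lp, grad = logprob_and_grad(w, x, 1)
# With W = 0 both classes have probability 0.5, so lp == log(0.5),
# grad[:, 1] == 0.5 .* x and grad[:, 2] == -0.5 .* x.
```

Comparing `grad` against `Zygote.gradient(W -> logprob(W, x, 1), w)[1]` is a quick way to confirm that both versions differentiate the same function.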

My questions are: what is causing the slowdown, and is it possible to speed this up while still using Distributions.jl functions?