Any faster way of computing small gradients?

I’ve tested multiple packages to compute the gradient of a log-likelihood (huge sum of the same function applied to different data points):

const AV = AbstractVector{T} where T

normal_pdf(x::Real, mean::Real, var::Real) =
    exp(-(x - mean)^2 / (2var)) / sqrt(2π * var)

function mixture_loglikelihood(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights, means, stds = params[1:K], params[K+1:2K], params[2K+1:end]

    sum(
        sum(
            weight * normal_pdf(x, mean, std^2)
            for (weight, mean, std) in zip(weights, means, stds)
        ) |> log
        for x in data
    )
end

Note: I know I shouldn’t be using traditional gradient-based optimization for estimating mixture models - this is just an example.

Benchmark

In this benchmark, the parameter is a vector of 3N_COMPONENTS == 12 values, so the gradient is really small (12-dimensional).

I benchmarked ForwardDiff, ReverseDiff, Zygote and Enzyme like this:

import Pkg; Pkg.status()

import Random
import ForwardDiff, ReverseDiff, Zygote
using BenchmarkTools

const AV = AbstractVector{T} where T

normal_pdf(x::Real, mean::Real, var::Real) =
    exp(-(x - mean)^2 / (2var)) / sqrt(2π * var)

function mixture_loglikelihood(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights, means, stds = params[1:K], params[K+1:2K], params[2K+1:end]

    sum(
        sum(
            weight * normal_pdf(x, mean, std^2)
            for (weight, mean, std) in zip(weights, means, stds)
        ) |> log
        for x in data
    )
end

SEED = 42
N_SAMPLES = 500
N_COMPONENTS = 4

rnd = Random.MersenneTwister(SEED)
data = randn(rnd, N_SAMPLES)
params0 = [rand(rnd, N_COMPONENTS); randn(rnd, N_COMPONENTS); 2rand(rnd, N_COMPONENTS)]
objective = params -> mixture_loglikelihood(params, data)

@info "Settings" SEED N_SAMPLES N_COMPONENTS length(params0)

@info "Computing gradient w/ ForwardDiff"
let
    grad_storage = similar(params0)
    cfg_grad = ForwardDiff.GradientConfig(objective, params0, ForwardDiff.Chunk{length(params0)}())

    # 1. Compile
    ForwardDiff.gradient!(grad_storage, objective, params0, cfg_grad)
    # 2. Benchmark
    trial = @benchmark ForwardDiff.gradient!($grad_storage, $objective, $params0, $cfg_grad)
    show(stdout, MIME("text/plain"), trial)
    println()
    @show grad_storage
end

@info "Computing gradient w/ ReverseDiff"
let
    grad_storage = similar(params0)
    objective_tape = ReverseDiff.GradientTape(objective, params0) |> ReverseDiff.compile

    # 1. Compile
    ReverseDiff.gradient!(grad_storage, objective_tape, params0)
    # 2. Benchmark
    trial = @benchmark ReverseDiff.gradient!($grad_storage, $objective_tape, $params0)
    show(stdout, MIME("text/plain"), trial)
    println()
    @show grad_storage
end

@info "Computing gradient w/ Zygote reverse"
let
    # 1. Compile
    grad_storage = Zygote.gradient(objective, params0)
    # 2. Benchmark
    trial = @benchmark Zygote.gradient($objective, $params0)
    show(stdout, MIME("text/plain"), trial)
    println()
    @show grad_storage
end

Note: Enzyme didn’t work at all, so it’s not included here.

Benchmark results

  • ForwardDiff
    • min: 226.686 μs
    • mean: 243.161 μs ± 355.714 μs
    • max: 13.218 ms
  • ReverseDiff
    • min: 1.719 ms
    • mean: 1.770 ms ± 82.430 μs (7 times slower than ForwardDiff’s mean)
    • max: 2.440 ms
  • Zygote
    • min: 351.711 ms (1446 times slower than ForwardDiff!)
    • mean: 371.142 ms ± 15.917 ms
    • max: 405.815 ms

Best mean time: ForwardDiff (243.161 μs)

Benchmark results with matrix operations

Since Zygote (and presumably ReverseDiff) is optimized to work with matrices and vectors, I converted the sums in mixture_loglikelihood to broadcasted matrix-vector operations:

function mixture_loglikelihood(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights, means, stds = params[1:K], params[K+1:2K], params[2K+1:end]

    mat = normal_pdf.(data, means', stds' .^2) # (N, K)
    sum(
        mat .* weights', dims=2
    ) .|> log |> sum
end

Now Zygote is the fastest.

  • ForwardDiff
    • min: 249.830 μs
    • mean: 508.371 μs ± 861.326 μs
    • max: 13.453 ms
  • ReverseDiff
    • min: 1.636 ms
    • mean: 1.675 ms ± 60.849 μs (3 times slower than ForwardDiff)
    • max: 2.302 ms
  • Zygote
    • min: 170.113 μs
    • mean: 274.045 μs ± 782.079 μs (1.8 times faster than ForwardDiff)
    • max: 18.369 ms

Best mean time: Zygote (274.045 μs)

I also pre-generated gradient-computing functions with Symbolics.jl, but that turned out to be about 1.5 times slower than ForwardDiff: Benchmark of Julia autodiff · GitHub (the full reproducible code is there as well).
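For reference, the Symbolics version was built roughly like this (a sketch of the approach, reusing mixture_loglikelihood, params0 and data from above; the gist has the exact code):

import Symbolics

# Trace the objective through a vector of symbolic parameters.
Symbolics.@variables p[1:12]
sym_params = collect(p)                            # Vector{Num}, one entry per parameter
sym_obj    = mixture_loglikelihood(sym_params, data)

# Differentiate symbolically, then compile out-of-place and in-place gradient functions.
sym_grad = Symbolics.gradient(sym_obj, sym_params)
grad_oop, grad_ip! = Symbolics.build_function(sym_grad, sym_params; expression=Val{false})

grad_storage = similar(params0)
grad_ip!(grad_storage, params0)                    # fills grad_storage in place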


What else can I do to speed up computation of small (much less than 30-dimensional) gradients? I thought Symbolics would completely destroy the competitors, but it came third, so maybe I’m doing something wrong?

Direct Symbolics will unroll loops into a scalar form which can reduce the amount of SIMD, along with the fact that ForwardDiff.jl has some manual SIMD. For small dense forward mode, it’s really hard to beat ForwardDiff. Symbolics will only win when you start getting to sparse cases, where sparse automatic differentiation can require some redundant computations depending on the sparsity pattern but sparse AD can hit the optimal growth scaling.

To beat it, you’d need to get Enzyme working. To get Enzyme working, you’d need to do some program finagling. Specifically, I think the issue is that it’s allocating the view, and so you’d have to avoid that or wait for better GC support from Enzyme. Enzyme is continuing to improve its support of the Julia runtime, but for now the most clear way to use it is to just write code that completely avoids the runtime.

Wow, I was assuming that, since Symbolics’ generated functions already “know” how to compute the gradient, they should always automatically be faster… I mean, ForwardDiff has to run the function with dual numbers, so it’s figuring out the derivative while it’s executing the function’s operations, but Symbolics produces functions that calculate the derivative directly.

Well, benchmarks show that, indeed, ForwardDiff is fast anyway. However…


…I wrote the same benchmark using JAX, and it blew all Julia code away (see updated gist):

Setup:

  • 500 data samples
  • 4 mixture components: 4 weights, 4 means, 4 standard deviations => 12 parameters total

Mean time (10000 samples with 1 evaluation):

  1. JAX (Python): 61.308 μs ± 5.368 μs (I originally reported 119.072 μs ± 16.123 μs, but that benchmarked the wrong thing, oops)
  2. Zygote.jl: 271.229 μs ± 714.442 μs
  3. ForwardDiff.jl: 515.443 μs ± 665.059 μs
  4. Symbolics.jl (precomputed gradient, mutating function): 809.821 μs ± 273.313 μs
  5. ReverseDiff.jl: 1.633 ms ± 74.832 μs

It seems like JAX is at least 4.4 (!) times faster than anything in Julia and is more consistent (smaller standard deviation).


JAX wins. Fatality.

Isn’t JAX using multithreading by default? If so, compare with that in Julia as well. What data type is JAX using? Could it be defaulting to Float32?

See, e.g., this thread for similar benchmarks.

Yes, it’s defaulting to Float32, but I specifically used jax.config.update("jax_enable_x64", True) to enable Float64. The gradients are Float64, indeed:

In [40]: jax_code.the_grad(jax_code.params0)
Out[40]: 
DeviceArray([289.73084956, 199.27559525, 236.68945778, 292.06123402,
              -9.42979939,  26.72229565,  -1.91803555,  37.9874909 ,
             -24.09562015, -13.93568733, -38.00044666,  12.87712892],            dtype=float64)

I ran Julia code with julia-1.8 --threads=4 code.jl, but didn’t see any difference whatsoever: CPU usage didn’t spike and the timings remained the same as in single-threaded mode.

I guess I should manually enable threading somewhere?

When I benchmark JAX:

In [48]: %timeit jax_code.the_grad(jax_code.params0)
60.3 µs ± 464 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

…I see that only 1 CPU core is being used. I’m sampling CPU usage every 0.5 s, and the benchmark takes several seconds, so I’m probably not missing a sudden usage spike.

In Julia, sum is not threaded by default; you’d need to use something like ThreadsX.sum or write a threaded loop.


I used both just for fun:

function mixture_loglikelihood(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]

    mat = normal_pdf.(data, means', stds' .^2) # (N, K)
    mat .= mat .* weights'

    mat_summed = zeros(eltype(mat), size(mat, 1))
    Threads.@threads for n in eachindex(mat_summed)
        for k in axes(mat, 2)
            mat_summed[n] += mat[n, k]
        end
    end

    ThreadsX.sum(
        log(one_sum) for one_sum in mat_summed
    )
    
    # ThreadsX.sum doesn't support `dims`
    # ThreadsX.sum(mat .* weights', dims=2) .|> log |> ThreadsX.sum
end

Now ForwardDiff takes 481.755 μs, which isn’t that much faster, but ReverseDiff errors out:

[ Info: Computing gradient w/ ReverseDiff
ERROR: LoadError: UndefRefError: access to undefined reference
Stacktrace:
  [1] getindex
    @ ./array.jl:924 [inlined]
  [2] iterate
    @ ./array.jl:898 [inlined]
  [3] iterate
    @ ./generator.jl:44 [inlined]
  [4] collect_to!(dest::Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{}}}, itr::Base.Generator{Vector{ReverseDiff.AbstractInstruction}, ReverseDiff.var"#665#667"}, offs::Int64, st::Int64)
    @ Base ./array.jl:845
  [5] collect_to_with_first!
    @ ./array.jl:823 [inlined]
  [6] collect(itr::Base.Generator{Vector{ReverseDiff.AbstractInstruction}, ReverseDiff.var"#665#667"})
    @ Base ./array.jl:797
  [7] CompiledTape
    @ ~/.julia/packages/ReverseDiff/5MMPp/src/api/tape.jl:102 [inlined]
  [8] compile
    @ ~/.julia/packages/ReverseDiff/5MMPp/src/api/tape.jl:147 [inlined]
  [9] |>(x::ReverseDiff.GradientTape{var"#38#39", ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, f::typeof(ReverseDiff.compile))
    @ Base ./operators.jl:911
 [10] top-level scope
    @ ~/test/autodiff_bench/code.jl:148
 [11] include(fname::String)
    @ Base.MainInclude ./client.jl:476
 [12] top-level scope
    @ REPL[1]:1
in expression starting at /Users/forcebru/test/autodiff_bench/code.jl:146

Looks like gradient tapes and threads don’t mix.

Zygote can’t do it either:

[ Info: Computing gradient w/ Zygote reverse
ERROR: LoadError: Compiling Tuple{typeof(Base.Threads.threading_run), var"#506#threadsfor_fun#44"{var"#506#threadsfor_fun#42#45"{Vector{Float64}, Matrix{Float64}, Base.OneTo{Int64}}}, Bool}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/reverse.jl:121
  [3] #Primal#23
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/reverse.jl:205 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/reverse.jl:322
  [5] _generate_pullback_via_decomposition(T::Type)
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/emit.jl:101
  [6] #s2812#1078
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:28 [inlined]
  [7] var"#s2812#1078"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:582
  [9] macro expansion
    @ ./threadingconstructs.jl:89 [inlined]
 [10] _pullback
    @ ~/test/autodiff_bench/code.jl:37 [inlined]
 [11] _pullback(::Zygote.Context, ::typeof(mixture_loglikelihood), ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
 [12] _pullback
    @ ~/test/autodiff_bench/code.jl:108 [inlined]
 [13] _pullback(ctx::Zygote.Context, f::var"#47#48", args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
 [14] _pullback(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:34
 [15] pullback(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:40
 [16] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:75
 [17] top-level scope
    @ ~/test/autodiff_bench/code.jl:162
 [18] include(fname::String)
    @ Base.MainInclude ./client.jl:476
 [19] top-level scope
    @ REPL[1]:1
in expression starting at /Users/forcebru/test/autodiff_bench/code.jl:160

Also, Zygote doesn’t support mutation like mat_summed[n] += mat[n, k], but how else can I sum along the 2nd dimension if ThreadsX doesn’t support sum(..., dims=2)?

So it doesn’t look like adding threading does anything, unfortunately.
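(For the record, one way to get a threaded row-sum without the manual mutation loop might be something like the following, reusing normal_pdf and AV from above, though I haven’t checked whether any of the AD packages can differentiate through it:)

import ThreadsX

function mixture_loglikelihood_threadsx(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]

    mat = normal_pdf.(data, means', stds' .^ 2) .* weights'   # (N, K)
    # sum each row, take the log, and add everything up, splitting the rows across threads
    ThreadsX.sum(1:size(mat, 1)) do n
        log(sum(@view mat[n, :]))
    end
end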

Tullio ought to be good at this, and will be multi-threaded by default.

(Note that I named your broadcasting version mixture_loglikelihood_array, and avoided the non-const global data in objective = params -> mixture_loglikelihood(params, data) by passing data as a second argument instead.)

julia> g1 = @btime Zygote.gradient(mixture_loglikelihood, $params0, $data);  # original, but no globals
  min 425.496 ms, mean 470.545 ms (1273985 allocations, 55.66 MiB)

julia> g2 = @btime Zygote.gradient(mixture_loglikelihood_array, $params0, $data);  # broadcasting
  min 78.583 μs, mean 707.912 μs (169 allocations, 225.77 KiB)

julia> using Tullio, ForwardDiff

julia> function mixture_loglikelihood_tullio(params::AV{<:Real}, data::AV{<:Real})
           K = length(params) ÷ 3
           weights, means, stds = params[1:K], params[K+1:2K], params[2K+1:end]

           @tullio tmp[r] := normal_pdf(data[r], means[c], stds[c]^2) * weights[c] grad=Dual    
           sum(log, tmp)
       end;

julia> g3 = @btime Zygote.gradient(mixture_loglikelihood_tullio, $params0, $data);
  min 87.166 μs, mean 131.886 μs (31 allocations, 15.83 KiB)

julia> g1[2] ≈ g2[2] ≈ g3[2]
true

# ForwardDiff instead:

julia> @btime ForwardDiff.gradient(p -> mixture_loglikelihood(p, $data), $params0);
  min 162.584 μs, mean 396.510 μs (1514 allocations, 83.59 KiB)

julia> @btime ForwardDiff.gradient(p -> mixture_loglikelihood_array(p, $data), $params0);
  min 118.833 μs, mean 1.543 ms (21 allocations, 513.39 KiB)

julia> @btime ForwardDiff.gradient(p -> mixture_loglikelihood_tullio(p, $data), $params0);
  min 122.875 μs, mean 278.139 μs (11 allocations, 56.11 KiB)

# Time just the forward pass, too:

julia> @btime mixture_loglikelihood($params0, $data);
  min 76.750 μs, mean 187.267 μs (1508 allocations, 31.75 KiB)

julia> @btime mixture_loglikelihood_array($params0, $data);
  min 21.708 μs, mean 105.647 μs (11 allocations, 39.98 KiB)

julia> @btime mixture_loglikelihood_tullio($params0, $data); 
  min 27.292 μs, mean 28.266 μs (4 allocations, 4.34 KiB)

It’s a little faster if you add LoopVectorization, but how much may depend a lot on your computer. (These times are on an M1 mac.)

julia> using LoopVectorization

julia> function mixture_loglikelihood_tullio_lv(params::AV{<:Real}, data::AV{<:Real})
           K = length(params) ÷ 3
           weights, means, stds = params[1:K], params[K+1:2K], params[2K+1:end]
           # NB we are running this macro after loading LV so that it sees it
           @tullio tmp[r] := normal_pdf(data[r], means[c], stds[c]^2) * weights[c] grad=Dual    
           sum(log, tmp)
       end;

julia> @btime mixture_loglikelihood_tullio_lv($params0, $data);  # just the forward pass, twice as quick
  min 11.625 μs, mean 12.167 μs (4 allocations, 4.34 KiB)

julia> g4 = @btime Zygote.gradient(mixture_loglikelihood_tullio_lv, $params0, $data);
  min 67.916 μs, mean 114.190 μs (31 allocations, 15.83 KiB)

julia> @btime ForwardDiff.gradient(p -> mixture_loglikelihood_tullio_lv(p, $data), $params0);
  min 126.458 μs, mean 298.009 μs (11 allocations, 56.11 KiB)

Indeed, I’m now getting much better performance:

Autodiff     | basic matmul             | Tullio + LoopVectorization
JAX          | 61.308 μs ± 5.368 μs     | N/A
Zygote       | 271.229 μs ± 714.442 μs  | 149.678 μs ± 258.898 μs
ForwardDiff  | 515.443 μs ± 665.059 μs  | 278.379 μs ± 334.067 μs

Code:

normal_pdf(x::Real, mean::Real, var::Real) =
    exp(-(x - mean)^2 / (2var)) / sqrt(2π * var)

function mixture_loglikelihood(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]

    Tullio.@tullio tmp[n] := weights[k] * normal_pdf(data[n], means[k], stds[k]^2) grad=Dual

    sum(log, tmp)
end

…even though Tullio still uses only one CPU core (I think I talked about this elsewhere, and it turned out I needed to tune some settings or something like that).
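If I’m reading Tullio’s README right, the threshold at which it starts threading can be set per call, so something like this might force threading even for this small problem (untested sketch; Julia still needs to be started with multiple threads, and normal_pdf/AV are as above):

function mixture_loglikelihood_threaded(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]

    # Hypothetical tweak, not benchmarked: `threads=` sets Tullio's threshold for
    # dividing work across threads (threads=false disables threading entirely).
    Tullio.@tullio tmp[n] := weights[k] * normal_pdf(data[n], means[k], stds[k]^2) grad=Dual threads=64

    sum(log, tmp)
end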

Still a far cry from JAX’s 61.308 μs, though. I must’ve benchmarked JAX incorrectly: maybe it’s caching something, so it doesn’t actually recompute the gradient every time??? It can’t be that fast! I don’t even know any JAX: I just read some docs, slapped on a bunch of @jax.jit - and suddenly I’m beating Julia code that’s equipped with a whole bunch of packages and that I’ve spent the entire day optimizing.


I get another factor 2 or 3 by writing normal_pdf out where @tullio can see it & hence derive a symbolic gradient. (Above, grad=Dual uses ForwardDiff within the loop to compute gradients.)

julia> function mixture_loglikelihood_tullio_2(params::AV{<:Real}, data::AV{<:Real})
           K = length(params) ÷ 3
           weights, means, stds = params[1:K], params[K+1:2K], params[2K+1:end]  # could use eachcol but not the bottleneck
           @tullio tmp[r] :=  exp(-(data[r] - means[c])^2 / (2*stds[c]^2)) / (sqrt(2π) * stds[c]) * weights[c]  # allowing symbolic grad of normal_pdf 
           sum(log, tmp)
       end;

julia> g5 = @btime Zygote.gradient(mixture_loglikelihood_tullio_2, $params0, $data);
  min 26.500 μs, mean 68.806 μs (31 allocations, 15.83 KiB)

julia> g5[2] ≈ g1[2]
true

Tullio will use Threads.nthreads() threads if the arrays are big enough. You do need to start Julia with multiple threads, e.g. julia -t4.


jax.jit is doing a lot of heavy lifting here, as evidenced by a multi-order-of-magnitude slowdown after removing it. For the interested, here’s a gist with all the dumped compiler output from JAX and XLA. The LLVM bitcode (pre- and post-optimization) will likely be of most interest for this crowd, but the HLO is also nice because it’s higher level.


It would be great to add the benchmark for other packages too, for example Yota.jl.

So the reason Enzyme doesn’t work on that fully yet isn’t GC-related: some part of it is type-unstable, and there’s a part of the Julia generics calling convention that isn’t yet handled (hence that error, which is at least nicer than a segfault).

Would you consider making a type stable version of the code?
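For example, part of the instability comes from closing over the non-const global data; that piece could be avoided with something like this (a sketch, not benchmarked):

# Two standard ways to avoid capturing a non-const global in the objective:
const DATA = data                                   # 1. make the captured binding const
stable_objective(p) = mixture_loglikelihood(p, DATA)

objective_let = let d = data                        # 2. capture a local copy by value
    p -> mixture_loglikelihood(p, d)
end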

Given this function:

function mixture_loglikelihood(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]

    mat = normal_pdf.(data, means', stds' .^2) # (N, K)
    sum(
        mat .* weights', dims=2
    ) .|> log |> sum
end

…and differentiating like this:

objective = params -> mixture_loglikelihood(params, data)
_, (_, grad_storage) = Yota.grad(objective, params0)

…Yota produces this error:

[ Info: Computing gradient w/ Yota
ERROR: LoadError: MethodError: no method matching length(::Type{Val{2}})
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
  length(::Union{ArrayInterfaceCore.BidiagonalIndex, ArrayInterfaceCore.TridiagonalIndex}) at ~/.julia/packages/ArrayInterfaceCore/7kMjZ/src/ArrayInterfaceCore.jl:594
  length(::Union{LinearAlgebra.Adjoint{T, <:Union{StaticArraysCore.StaticArray{Tuple{var"#s2"}, T, 1} where var"#s2", StaticArraysCore.StaticArray{Tuple{var"#s3", var"#s4"}, T, 2} where {var"#s3", var"#s4"}}}, LinearAlgebra.Diagonal{T, <:StaticArraysCore.StaticArray{Tuple{var"#s13"}, T, 1} where var"#s13"}, LinearAlgebra.Hermitian{T, <:StaticArraysCore.StaticArray{Tuple{var"#s10", var"#s11"}, T, 2} where {var"#s10", var"#s11"}}, LinearAlgebra.LowerTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s18", var"#s19"}, T, 2} where {var"#s18", var"#s19"}}, LinearAlgebra.Symmetric{T, <:StaticArraysCore.StaticArray{Tuple{var"#s7", var"#s8"}, T, 2} where {var"#s7", var"#s8"}}, LinearAlgebra.Transpose{T, <:Union{StaticArraysCore.StaticArray{Tuple{var"#s2"}, T, 1} where var"#s2", StaticArraysCore.StaticArray{Tuple{var"#s3", var"#s4"}, T, 2} where {var"#s3", var"#s4"}}}, LinearAlgebra.UnitLowerTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s24", var"#s25"}, T, 2} where {var"#s24", var"#s25"}}, LinearAlgebra.UnitUpperTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s21", var"#s22"}, T, 2} where {var"#s21", var"#s22"}}, LinearAlgebra.UpperTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s15", var"#s16"}, T, 2} where {var"#s15", var"#s16"}}, StaticArraysCore.StaticArray{Tuple{var"#s25"}, T, 1} where var"#s25", StaticArraysCore.StaticArray{Tuple{var"#s1", var"#s3"}, T, 2} where {var"#s1", var"#s3"}, StaticArraysCore.StaticArray{<:Tuple, T}} where T) at ~/.julia/packages/StaticArrays/8Dz3j/src/abstractarray.jl:1
  ...
Stacktrace:
  [1] unzip(tuples::Tuple{DataType, ChainRules.var"#apply_type_pullback#42"{Tuple{Int64}}})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/rulesets.jl:92
  [2] bcast_rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(Core.apply_type), ::Type, ::Vararg{Any}; kw::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:49
  [3] bcast_rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(Core.apply_type), ::Type, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:48
  [4] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
  [5] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179
  [6] record_or_recurse!(::Umlaut.Tracer{Yota.BcastGradCtx}, ::Function, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:85
  [7] trace!(::Umlaut.Tracer{Yota.BcastGradCtx}, ::Core.CodeInfo, ::Umlaut.Variable, ::Vararg{Umlaut.Variable})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:220
  [8] trace(::Function, ::Vector{Float64}, ::Vararg{Any}; ctx::Yota.BcastGradCtx, fargtypes::Tuple{typeof(normal_pdf), Tuple{DataType, DataType, DataType}}, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:346
  [9] make_rrule(::typeof(Base.Broadcast.broadcasted), ::Function, ::Vector{Float64}, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:136
 [10] rrule_via_ad(::Yota.YotaRuleConfig, ::Function, ::Function, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:170
 [11] rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(normal_pdf), ::Vector{Float64}, ::LinearAlgebra.Adjoint{Float64, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(Base.literal_pow), Tuple{Base.RefValue{typeof(^)}, LinearAlgebra.Adjoint{Float64, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}}, Base.RefValue{Val{2}}}})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/rulesets.jl:98
 [12] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
 [13] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179
 [14] record_primitive!(::Umlaut.Tape{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:49
 [15] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:193
 [16] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:220
 [17] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:202
 [18] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Umlaut.Variable, ::Vararg{Umlaut.Variable})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:220
 [19] trace(f::Function, args::Vector{Float64}; ctx::Yota.GradCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:346
 [20] #gradtape#90
    @ ~/.julia/packages/Yota/VCIzN/src/grad.jl:243 [inlined]
 [21] grad(f::var"#12#13", args::Vector{Float64}; seed::Int64)
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:314
 [22] grad(f::var"#12#13", args::Vector{Float64})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:306
 [23] top-level scope
    @ ~/test/autodiff_bench/code.jl:118
 [24] include(fname::String)
    @ Base.MainInclude ./client.jl:476
 [25] top-level scope
    @ REPL[2]:1
in expression starting at /Users/forcebru/test/autodiff_bench/code.jl:116

When I use the version with Tullio:

function mixture_loglikelihood(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]

    Tullio.@tullio tmp[n] := weights[k] * normal_pdf(data[n], means[k], stds[k]^2) grad=Dual

    sum(log, tmp)
end

…Yota produces this error:

ERROR: LoadError: No deriative rule found for op %78 = convert(%3, %76)::Float64, try defining it using 

	ChainRulesCore.rrule(::typeof(convert), ::DataType, ::Float64) = ...

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:170
  [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:211
  [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Int64)
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:222
  [5] #gradtape#90
    @ ~/.julia/packages/Yota/VCIzN/src/grad.jl:244 [inlined]
  [6] grad(f::var"#21#22", args::Vector{Float64}; seed::Int64)
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:314
  [7] grad(f::var"#21#22", args::Vector{Float64})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:306
  [8] top-level scope
    @ ~/test/autodiff_bench/code.jl:114
  [9] include(fname::String)
    @ Base.MainInclude ./client.jl:476
 [10] top-level scope
    @ REPL[2]:1
in expression starting at /Users/forcebru/test/autodiff_bench/code.jl:112

Both errors point to code within Yota, so it seems like it doesn’t work…


Looks like Yota has trouble figuring out broadcasting. Here Zygote works fine:

julia> xs = randn(200);

julia> Zygote.gradient(mu->sum(log, normal_pdf.(xs, mu, 1.0)), 1.0)
(-169.05854117272128,)

But Yota fails:

julia> Yota.grad(mu->sum(log, normal_pdf.(xs, mu, 1.0)), 1.0)
ERROR: MethodError: no method matching length(::Type{Val{2}})
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
  length(::Union{ArrayInterfaceCore.BidiagonalIndex, ArrayInterfaceCore.TridiagonalIndex}) at ~/.julia/packages/ArrayInterfaceCore/7kMjZ/src/ArrayInterfaceCore.jl:594
  length(::Union{LinearAlgebra.Adjoint{T, <:Union{StaticArraysCore.StaticArray{Tuple{var"#s2"}, T, 1} where var"#s2", StaticArraysCore.StaticArray{Tuple{var"#s3", var"#s4"}, T, 2} where {var"#s3", var"#s4"}}}, LinearAlgebra.Diagonal{T, <:StaticArraysCore.StaticArray{Tuple{var"#s13"}, T, 1} where var"#s13"}, LinearAlgebra.Hermitian{T, <:StaticArraysCore.StaticArray{Tuple{var"#s10", var"#s11"}, T, 2} where {var"#s10", var"#s11"}}, LinearAlgebra.LowerTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s18", var"#s19"}, T, 2} where {var"#s18", var"#s19"}}, LinearAlgebra.Symmetric{T, <:StaticArraysCore.StaticArray{Tuple{var"#s7", var"#s8"}, T, 2} where {var"#s7", var"#s8"}}, LinearAlgebra.Transpose{T, <:Union{StaticArraysCore.StaticArray{Tuple{var"#s2"}, T, 1} where var"#s2", StaticArraysCore.StaticArray{Tuple{var"#s3", var"#s4"}, T, 2} where {var"#s3", var"#s4"}}}, LinearAlgebra.UnitLowerTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s24", var"#s25"}, T, 2} where {var"#s24", var"#s25"}}, LinearAlgebra.UnitUpperTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s21", var"#s22"}, T, 2} where {var"#s21", var"#s22"}}, LinearAlgebra.UpperTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s15", var"#s16"}, T, 2} where {var"#s15", var"#s16"}}, StaticArraysCore.StaticArray{Tuple{var"#s25"}, T, 1} where var"#s25", StaticArraysCore.StaticArray{Tuple{var"#s1", var"#s3"}, T, 2} where {var"#s1", var"#s3"}, StaticArraysCore.StaticArray{<:Tuple, T}} where T) at ~/.julia/packages/StaticArrays/8Dz3j/src/abstractarray.jl:1
  ...
Stacktrace:
  [1] unzip(tuples::Tuple{DataType, ChainRules.var"#apply_type_pullback#42"{Tuple{Int64}}})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/rulesets.jl:92
  [2] bcast_rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(Core.apply_type), ::Type, ::Vararg{Any}; kw::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:49
  [3] bcast_rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(Core.apply_type), ::Type, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:48
  [4] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
  [5] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179
  [6] record_or_recurse!(::Umlaut.Tracer{Yota.BcastGradCtx}, ::Function, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:85
  [7] trace!(::Umlaut.Tracer{Yota.BcastGradCtx}, ::Core.CodeInfo, ::Umlaut.Variable, ::Vararg{Umlaut.Variable})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:220
  [8] trace(::Function, ::Vector{Float64}, ::Vararg{Any}; ctx::Yota.BcastGradCtx, fargtypes::Tuple{typeof(normal_pdf), Tuple{DataType, DataType, DataType}}, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:346
  [9] make_rrule(::typeof(Base.Broadcast.broadcasted), ::Function, ::Vector{Float64}, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:136
 [10] rrule_via_ad(::Yota.YotaRuleConfig, ::Function, ::Function, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:170
 [11] rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(normal_pdf), ::Vector{Float64}, ::Float64, ::Float64)
    @ Yota ~/.julia/packages/Yota/VCIzN/src/rulesets.jl:98
 [12] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
 [13] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179
 [14] record_primitive!(::Umlaut.Tape{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:49
 [15] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:193
 [16] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Umlaut.Variable, ::Vararg{Umlaut.Variable})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:220
 [17] trace(f::Function, args::Float64; ctx::Yota.GradCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:346
 [18] #gradtape#90
    @ ~/.julia/packages/Yota/VCIzN/src/grad.jl:243 [inlined]
 [19] grad(f::var"#85#86", args::Float64; seed::Int64)
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:314
 [20] grad(f::var"#85#86", args::Float64)
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:306
 [21] top-level scope
    @ REPL[28]:1

As a user, I have no clue what ::Type{Val{2}} even is and where it came from. I don’t think I have it in my code.

Here’s an Enzyme-compatible version (specifically I removed the type instabilities that appeared to come from both sum and your closure over data).

Results on my machine (using the single-threaded code from the start of the discussion):

  • Enzyme forward (chunk=len(params0)): 101 μs
  • Enzyme forward (chunk=1): 596 μs
  • Enzyme reverse: 81 μs
  • ForwardDiff: 95 μs
  • ReverseDiff: 535 μs
  • Zygote reverse: 314,489 μs (314 ms)

import Random
import Enzyme
import ForwardDiff, Zygote, ReverseDiff
const AV = AbstractVector{T} where T

# ===== Set up objective function =====
normal_pdf(x::Real, mean::Real, var::Real) =
    exp(-(x - mean)^2 / (2var)) / sqrt(2π * var)
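# NB: in the two likelihood versions below, `std` (not `std^2`) is passed as the
# variance argument, unlike the original post; both versions here agree with each
# other, so the timing comparison is still apples-to-apples.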

function mixture_loglikelihoodE(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights, means, stds = @views params[1:K], params[(K+1):2K], params[(2K+1):end]

    result = 0.
    for x in data
        mid = 0.
        for (weight, mean, std) in zip(weights, means, stds)
            mid += weight * normal_pdf(x, mean, std)
        end
        result += log(mid)
    end
    return result
end

function mixture_loglikelihood(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3
    weights, means, stds = @views params[1:K], params[(K+1):2K], params[(2K+1):end]

    sum(
        sum(
            weight * normal_pdf(x, mean, std)
            for (weight, mean, std) in zip(weights, means, stds)
        ) |> log
        for x in data
    )
end
rnd = Random.MersenneTwister(42)
data = randn(rnd, 500)
params0 = [[0.2, 0.8]; [-1, 1]; [.3, .6]]

# using InteractiveUtils
# @show InteractiveUtils.@code_typed mixture_loglikelihood(params0, data)
using BenchmarkTools
grad_storage0 = similar(params0)
Enzyme.autodiff(Enzyme.Reverse, mixture_loglikelihoodE, Enzyme.Duplicated(params0, grad_storage0), data)
@show grad_storage0


SEED = 42
N_SAMPLES = 500
N_COMPONENTS = 4

rnd = Random.MersenneTwister(SEED)
data = randn(rnd, N_SAMPLES)
params0 = [rand(rnd, N_COMPONENTS); randn(rnd, N_COMPONENTS); 2rand(rnd, N_COMPONENTS)]
objective = params -> mixture_loglikelihood(params, data)

@info "Settings" SEED N_SAMPLES N_COMPONENTS length(params0)

@info "Computing gradient w/ Enzyme forward (chunk=len(params0))"
let
    batchdup=Enzyme.BatchDuplicated(params0, Enzyme.onehot(params0))
    res = Enzyme.autodiff(Enzyme.Forward, mixture_loglikelihoodE, Enzyme.BatchDuplicatedNoNeed, batchdup, data)
    @show res
	trial = @benchmark Enzyme.autodiff(Enzyme.Forward, mixture_loglikelihoodE, Enzyme.BatchDuplicatedNoNeed, $batchdup, $data)
    show(stdout, MIME("text/plain"), trial)
    println()
end

@info "Computing gradient w/ Enzyme forward (chunk=1)"
let
    shadow=Enzyme.onehot(params0)
	len = length(params0)
	trial = @benchmark ntuple($len) do i
        Enzyme.autodiff(Enzyme.Forward, mixture_loglikelihoodE, Enzyme.DuplicatedNoNeed, Enzyme.Duplicated($params0, $shadow[i]), $data)[1]
    end
    show(stdout, MIME("text/plain"), trial)
    println()
	@show ntuple(len) do i
        Enzyme.autodiff(Enzyme.Forward, mixture_loglikelihoodE, Enzyme.DuplicatedNoNeed, Enzyme.Duplicated(params0, shadow[i]), data)[1]
    end
end

@info "Computing gradient w/ Enzyme reverse"
let
    dup=Enzyme.Duplicated(params0, zero(params0))
	trial = @benchmark Enzyme.autodiff(Enzyme.Reverse, $mixture_loglikelihoodE, $dup, $data)
    show(stdout, MIME("text/plain"), trial)
    println()
    # Enzyme reverse +='s the derivative in, so to see the result, we just call it once, with storage zero init
	grad_storage = zero(params0)
    Enzyme.autodiff(Enzyme.Reverse, mixture_loglikelihoodE, Enzyme.Duplicated(params0, grad_storage), data)
	@show grad_storage
end

@info "Computing gradient w/ ForwardDiff"
let
    grad_storage = similar(params0)
    cfg_grad = ForwardDiff.GradientConfig(objective, params0, ForwardDiff.Chunk{length(params0)}())

    # 1. Compile
    ForwardDiff.gradient!(grad_storage, objective, params0, cfg_grad)
    # 2. Benchmark
    trial = @benchmark ForwardDiff.gradient!($grad_storage, $objective, $params0, $cfg_grad)
    show(stdout, MIME("text/plain"), trial)
    println()
    @show grad_storage
end

@info "Computing gradient w/ ReverseDiff"
let
    grad_storage = similar(params0)
    objective_tape = ReverseDiff.GradientTape(objective, params0) |> ReverseDiff.compile

    # 1. Compile
    ReverseDiff.gradient!(grad_storage, objective_tape, params0)
    # 2. Benchmark
    trial = @benchmark ReverseDiff.gradient!($grad_storage, $objective_tape, $params0)
    show(stdout, MIME("text/plain"), trial)
    println()
    @show grad_storage
end

@info "Computing gradient w/ Zygote reverse"
let
    # 1. Compile
    grad_storage = Zygote.gradient(objective, params0)
    # 2. Benchmark
    trial = @benchmark Zygote.gradient($objective, $params0)
    show(stdout, MIME("text/plain"), trial)
    println()
    @show grad_storage
end

The results:


grad_storage0 = [885.130441033103, 403.71738974172405, -30.19822063378844, -259.9966440668968, 188.9477590261449, 81.34970783748037]
┌ Info: Settings
│   SEED = 42
│   N_SAMPLES = 500
│   N_COMPONENTS = 4
└   length(params0) = 12
[ Info: Computing gradient w/ Enzyme forward (chunk=len(params0))
res = ((275.6684047875707, 194.6371294031169, 259.4876528035026, 283.24055693669914, -8.818020642745301, 30.951111171059114, 1.2523661647362552, 33.091964748936775, -9.415544882191226, -3.2309482131545586, -13.904291598097785, -17.077741525198277),)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   99.462 μs … 136.153 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     100.522 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   100.916 μs ±   1.333 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

      ▃██▅▂▅▇▆▄▁▁▄▂▁          ▁▁                                ▂
  ▂▄▇▇███████████████▇▇▇█▆▇▆█████▇▇▆▇▇▄▄▅▅▂▅▅▆▇█▇▆▇▆▇▅▆▅▄▃▄▅▂▂▃ █
  99.5 μs       Histogram: log(frequency) by time        108 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
[ Info: Computing gradient w/ Enzyme forward (chunk=1)
BenchmarkTools.Trial: 8380 samples with 1 evaluation.
 Range (min … max):  584.811 μs … 635.361 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     595.691 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   595.065 μs ±   4.866 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

   ▁▁    ▁▅▇▅▂▁  ▁▂▂  ▂▂▃▆█▆▃▂▂▂▂▃▂▂▂▃▄▂▁                       ▂
  ▆██▆▅▃▂███████████████████████████████████▇█▆▇▇▇▆▇▅▅▅▃▄▃▂▂▄▂█ █
  585 μs        Histogram: log(frequency) by time        612 μs <

 Memory estimate: 464 bytes, allocs estimate: 14.
ntuple(len) do i
    #= /mnt/Data/git/Enzyme.jl/loglik.jl:80 =#
    (Enzyme.autodiff(Enzyme.Forward, mixture_loglikelihoodE, Enzyme.DuplicatedNoNeed, Enzyme.Duplicated(params0, shadow[i]), data))[1]
end = (275.6684047875707, 194.6371294031169, 259.4876528035026, 283.24055693669914, -8.8180206427453, 30.951111171059118, 1.2523661647362547, 33.091964748936775, -9.415544882191224, -3.2309482131545577, -13.904291598097782, -17.077741525198256)
[ Info: Computing gradient w/ Enzyme reverse
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  80.132 μs … 112.892 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     80.852 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   81.043 μs ±   1.048 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▁▄▄▄▃▆██▇▅▂▁                                                 ▂
  █████████████▆▄▄▅▇▇▆▄▄▅▃▅▄▅▇▆▇▇██▇▇▇█▇▇▅▅▆▆▅▃▄▅▄▁▅▃▆▄▇▆▇▇▇▇█ █
  80.1 μs       Histogram: log(frequency) by time      86.2 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
grad_storage = [275.66840478757064, 194.6371294031169, 259.4876528035028, 283.2405569366992, -8.818020642745294, 30.951111171059107, 1.2523661647362523, 33.09196474893673, -9.415544882191213, -3.2309482131545604, -13.904291598097785, -17.07774152519828]
[ Info: Computing gradient w/ ForwardDiff
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   92.012 μs …   3.378 ms  ┊ GC (min … max): 0.00% … 96.38%
 Time  (median):      94.861 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   105.155 μs ± 169.621 μs  ┊ GC (mean ± σ):  8.46% ±  5.08%

   ▅▇██▇▆▅▂▁▁▂▁▁                                                ▂
  ███████████████▇▆▆▆▇▆▅▇▄▇▆▆▆▆▅▄▄▄▅▄▅▄▃▃▄▃▃▁▁▃▄▅▆▇▇███▇▆▇▆█▇██ █
  92 μs         Histogram: log(frequency) by time        130 μs <

 Memory estimate: 125.69 KiB, allocs estimate: 1505.
grad_storage = [275.6684047875707, 194.6371294031169, 259.4876528035026, 283.24055693669914, -8.8180206427453, 30.9511111710591, 1.2523661647362565, 33.09196474893678, -9.415544882191226, -3.2309482131545586, -13.90429159809778, -17.07774152519827]
[ Info: Computing gradient w/ ReverseDiff
BenchmarkTools.Trial: 9190 samples with 1 evaluation.
 Range (min … max):  518.569 μs … 779.315 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     535.405 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   540.688 μs ±  13.340 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

            ▇▅▅█▇▅▆▇▄▆▅▁                                         
  ▁▁▁▂▂▂▄▅▇▇██████████████▆▅▅▄▃▄▃▂▂▂▂▃▂▄▄▅▆▇████▇██▇▆▅▅▄▄▃▃▂▂▂▂ ▄
  519 μs           Histogram: frequency by time          568 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
grad_storage = [275.66840478757064, 194.6371294031169, 259.4876528035028, 283.2405569366992, -8.818020642745294, 30.951111171059107, 1.252366164736251, 33.09196474893673, -9.415544882191199, -3.230948213154556, -13.90429159809779, -17.077741525198284]
[ Info: Computing gradient w/ Zygote reverse
BenchmarkTools.Trial: 16 samples with 1 evaluation.
 Range (min … max):  296.436 ms … 349.718 ms  ┊ GC (min … max):  7.88% … 19.52%
 Time  (median):     314.489 ms               ┊ GC (median):    12.04%
 Time  (mean ± σ):   319.931 ms ±  15.012 ms  ┊ GC (mean ± σ):  13.60% ±  3.63%

  ▁        ▁▁▁ ▁    ▁█ ▁            ▁▁  ▁▁       ▁▁           ▁  
  █▁▁▁▁▁▁▁▁███▁█▁▁▁▁██▁█▁▁▁▁▁▁▁▁▁▁▁▁██▁▁██▁▁▁▁▁▁▁██▁▁▁▁▁▁▁▁▁▁▁█ ▁
  296 ms           Histogram: frequency by time          350 ms <

 Memory estimate: 74.39 MiB, allocs estimate: 1481813.
grad_storage = ([275.66840478757064, 194.6371294031169, 259.4876528035028, 283.2405569366992, -8.818020642745294, 30.951111171059107, 1.252366164736251, 33.09196474893673, -9.415544882191213, -3.2309482131545577, -13.904291598097789, -17.077741525198288],)

Getting rid of the view gives a bit more of a speed boost as well (see below):

  • Enzyme forward (chunk=len(params0)): 76 μs
  • Enzyme forward (chunk=1): 548 μs
  • Enzyme reverse: 76 μs

function mixture_loglikelihoodE2(params::AV{<:Real}, data::AV{<:Real})::Real
    K = length(params) ÷ 3

    result = 0.
    for x in data
        mid = 0.
        for i in 1:K
            mid += @inbounds params[i] * normal_pdf(x, params[i+K], params[i+2K])
        end
        result += log(mid)
    end
    return result
end
[ Info: Computing gradient w/ Enzyme forward (chunk=len(params0))
res = ((275.6684047875707, 194.6371294031169, 259.4876528035026, 283.24055693669914, -8.8180206427453, 30.951111171059107, 1.2523661647362556, 33.09196474893678, -9.415544882191226, -3.2309482131545577, -13.904291598097783, -17.07774152519828),)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  75.421 μs … 97.202 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     75.991 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   76.247 μs ±  1.247 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▅██▅▆█▇▄▂▂▂▃▂▂▂▃▂         ▁▁                                ▂
  ██████████████████▇▇▆▅█▇▆▇███▆▇▇▇▆▆▇▆▆▆▄▅▃▅▅▅▆▆▇▇▆▇▆▆▆▅▄▅▄▅ █
  75.4 μs      Histogram: log(frequency) by time      82.3 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
[ Info: Computing gradient w/ Enzyme forward (chunk=1)
BenchmarkTools.Trial: 9082 samples with 1 evaluation.
 Range (min … max):  543.590 μs … 606.032 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     548.140 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   549.017 μs ±   3.913 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

     ▁▄▃          ▄█▆▁                                           
  ▁▃▆███▇▃▂▂▁▁▁▂▅█████▆▄▄▄▃▄▄▃▃▃▂▂▂▂▂▃▃▃▄▄▄▄▃▃▃▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁ ▃
  544 μs           Histogram: frequency by time          559 μs <

 Memory estimate: 464 bytes, allocs estimate: 14.
ntuple(len) do i
    #= /mnt/Data/git/Enzyme.jl/loglik.jl:94 =#
    (Enzyme.autodiff(Enzyme.Forward, mixture_loglikelihoodE2, Enzyme.DuplicatedNoNeed, Enzyme.Duplicated(params0, shadow[i]), data))[1]
end = (275.6684047875707, 194.6371294031169, 259.4876528035026, 283.24055693669914, -8.8180206427453, 30.951111171059118, 1.2523661647362547, 33.091964748936775, -9.415544882191224, -3.2309482131545577, -13.904291598097782, -17.077741525198256)
[ Info: Computing gradient w/ Enzyme reverse
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  75.612 μs … 119.052 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     76.271 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   76.474 μs ±   1.424 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▂▇▇▁▆█ ▁              ▁▁▁                                    ▁
  █████████▅▅▄▆▅▅▄▅▆▅█▇▆███▇▆▇▆▆▅▆▅▄▅▅▅▅▄▅▇█▆▆▆▅▄▅▅▄▄▄▄▄▄▄▃▄▅▆ █
  75.6 μs       Histogram: log(frequency) by time      83.9 μs <

ForwardDiff continues to be an absolute beast in these benchmarks. As I understand it, Enzyme (and the somewhat mythical Diffractor) is all the hype nowadays, but ForwardDiff almost beats it here (81 μs, 101 μs and 76 μs for Enzyme vs 95 μs for ForwardDiff).

I guess Enzyme is still a work-in-progress and is probably tailored towards functions with a lot of parameters (like neural networks), so it should become even faster.

Changing K to be a compile-time constant using Val also gets a slight further reduction:

[ Info: Computing gradient w/ Enzyme reverse
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  65.041 μs … 101.112 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     65.461 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   65.654 μs ±   1.163 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▇█▅▄█▇▄▄▄                                                    ▂
  █████████▇▅▇█▆▅▃▄▃▅▅▄▄▅▅▅▇▅▆███▇▇▇▆▆▆▅▆▅▅▄▄▄▄▅▄▄▅▅▅▅▆▇▇▆▆▆▅▃ █
  65 μs         Histogram: log(frequency) by time      71.2 μs <
function mixture_loglikelihoodE3(params::AV{<:Real}, data::AV{<:Real}, ::Val{K})::Real where {K}
    result = 0.
    for x in data
        mid = 0.
        for i in 1:K
            mid += @inbounds params[i] * normal_pdf(x, params[i+K], params[i+2K])
        end
        result += log(mid)
    end
    return result
end
...
K = Val(length(params0) ÷ 3)
trial = @benchmark Enzyme.autodiff(Enzyme.Reverse, $mixture_loglikelihoodE3, $dup, $data, $K)

tl;dr: forward mode’s cost is proportional to the number of inputs but has low AD overhead; reverse mode’s cost is roughly constant with respect to the number of inputs but has higher overhead. Each tool/algorithm combo has different pros/cons and overheads. This benchmark (low input count) naturally benefits from forward-mode AD’s low overhead, since the small input parameter count doesn’t substantially impact the runtime.

So the three styles of Enzyme benchmarks I showed (Enzyme forward (chunk=len(params0)), Enzyme forward (chunk=1), Enzyme reverse) are all distinct algorithms.

Specifically, forward-mode AD (anything labeled Enzyme forward, and also ForwardDiff.jl) gives you the derivative of all outputs with respect to a single input. This is most useful if you either have a lot of outputs or the number of inputs is small.

Reverse mode (Enzyme reverse, Zygote, ReverseDiff.jl) gives you the derivative of a single output with respect to all inputs. Thus it is beneficial when the number of inputs is large. There is, however, often some cost to doing reverse mode. Hence, when deciding between forward and reverse mode, consider the number of inputs vs outputs: if the input count is big, probably use reverse mode, and if it is small, use forward mode.

Finally, vector mode (aka batching) does multiple forward or reverse sweeps in a single call: for example, forward mode with respect to multiple inputs, or reverse mode with respect to multiple outputs. This is what the Enzyme forward chunk=length(params0) example does, and it is also the default in ForwardDiff.jl. In essence it cuts a constant factor off the forward/reverse-mode times, respectively.
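On the ForwardDiff.jl side, the chunk size is exactly this vector-mode width; a quick sketch (reusing objective and params0 from earlier in the thread):

using ForwardDiff

# Chunk size = how many input directions are carried through the function per sweep.
cfg_1  = ForwardDiff.GradientConfig(objective, params0, ForwardDiff.Chunk{1}())   # 12 sweeps for 12 params
cfg_12 = ForwardDiff.GradientConfig(objective, params0, ForwardDiff.Chunk{12}())  # one sweep carrying 12 duals

ForwardDiff.gradient(objective, params0, cfg_1)   # same gradient either way,
ForwardDiff.gradient(objective, params0, cfg_12)  # different constant factor in runtime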

Because there are relatively few parameters, this problem lends itself nicely to forward-mode AD, hence ForwardDiff.jl doing well. Of course this is a gradient (i.e. multiple inputs, though few), so reverse mode may provide a benefit, but that’s dependent on how much reverse-mode overhead the tool has. In the case of Zygote above, that overhead is huge and cannot be overcome, at least for this few inputs. Enzyme, which has a much, much smaller overhead, is able to overcome the reverse-mode cost and even beat forward mode.

Depending on what you’re doing, a different tool may be applicable to each scenario, and Enzyme is great (its new forward/vector modes are coming along, though notably still in progress, among lots of other things like GC support), but it’s important to realize what is algorithmically appropriate each time. Also, regardless, these benchmarks are tiny, which potentially means that not much of the time is actually spent in the AD itself, making differences hard to see.

I will also add that we’re continuing to work on some really cool stuff that will add additional theoretically nontrivial improvements on top of the existing minimal overhead Enzyme adds, but the current Enzyme work is split between this style theory improvements and expanding the scope of code supported by Enzyme (in both Julia and all other LLVM-based languages like C/C++, Fortran, Rust, Swift, Python, JAX, etc that Enzyme differentiates through).


To demonstrate the algorithmic difference between forward-mode and reverse-mode AD in this particular example: if I change N_COMPONENTS to 40 instead of 4, the gap between the two widens.

[ Info: Computing gradient w/ Enzyme reverse
BenchmarkTools.Trial: 6680 samples with 1 evaluation.
 Range (min … max):  739.314 μs … 780.625 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     747.044 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   747.460 μs ±   4.172 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

         ▁█▇▅▄▂        ▄▆▇█▆▄▅▃▁▁                                
  ▁▁▁▁▂▄▅██████▇▆▅▆▅▆▆████████████▇▇▆▇▆▆▆▇██▇▆▆▅▆▅▅▄▃▃▃▂▂▂▂▂▂▂▂ ▄
  739 μs           Histogram: frequency by time          758 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
[ Info: Computing gradient w/ ForwardDiff
BenchmarkTools.Trial: 812 samples with 1 evaluation.
 Range (min … max):  5.848 ms …   8.651 ms  ┊ GC (min … max): 0.00% … 28.20%
 Time  (median):     6.118 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   6.162 ms ± 274.420 μs  ┊ GC (mean ± σ):  0.48% ±  3.09%
 Memory estimate: 563.19 KiB, allocs estimate: 1505