`Zygote.gradient` is 54000 TIMES slower than `jax.gradient`

In this function the parameter Wctx is modified (it is reassigned to its reshape), and even if we assign the reshape to a new name, Enzyme says no; the only way I found is to pass it as DuplicatedNoNeed (which is some kind of context annotation):

function loss(
	Wmid, Wctx,
	nvocab, ndim,
	tok_mid, tok_ctx, x
)

    Wmid = reshape(Wmid, nvocab, ndim)
    Wctx = reshape(Wctx, nvocab, ndim)
    nll ::Float32 = 0.0f0
    @inbounds @fastmath for c in eachindex(x)
        Xij = x[c]
        dotprod = dot(@view(Wmid[tok_mid[c],:]),@view(Wctx[tok_ctx[c],:])) 
        nll += Xij * log(logistic_sigmoid(dotprod))
        nll += (1 - Xij) * log(1 - logistic_sigmoid(dotprod))
    end
    nll = nll / length(x)
	nll
end

meaning we need a way to do something like:

dweights_mid = Enzyme.make_zero(weights_mid)
dweights_ctx = Enzyme.make_zero(weights_ctx)
Enzyme.autodiff(
    Enzyme.Reverse, loss,
    Enzyme.Duplicated(weights_mid, dweights_mid),
    Enzyme.DuplicatedNoNeed(weights_ctx, dweights_ctx),
    Enzyme.Const(nvocab), Enzyme.Const(ndim),
    Enzyme.Const(tok_mid), Enzyme.Const(tok_ctx), Enzyme.Const(x),
)

or, even better, without having to allocate dweights_ctx = Enzyme.make_zero(weights_ctx) at all.
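
One way to avoid allocating the shadow by hand would be the higher-level Enzyme.gradient, which builds the shadows internally. A minimal sketch, assuming Enzyme 0.13's annotation-aware gradient and that weights_ctx can actually be passed as Const (which, as discussed further down, depends on not rebinding it inside loss):

grads = Enzyme.gradient(
    Enzyme.Reverse, loss,
    weights_mid,
    Enzyme.Const(weights_ctx),
    Enzyme.Const(nvocab), Enzyme.Const(ndim),
    Enzyme.Const(tok_mid), Enzyme.Const(tok_ctx), Enzyme.Const(x),
)
dweights_mid = grads[1]  # entries corresponding to Const arguments come back as `nothing`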

This is not exactly equivalent but it works with DI:

    gradient!(
        loss,
        dweights_mid,
        backend,
        weights_mid,
        Cache(weights_ctx),
        Constant(nvocab),
        Constant(ndim),
        Constant(tok_mid),
        Constant(tok_ctx),
        Constant(x),
    )

Usual caveats apply: the native Enzyme interface will probably be faster and/or less buggy.
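
If you compute the gradient repeatedly (e.g. in a training loop), you can also prepare once and reuse the preparation. A minimal sketch, assuming a DI version where Cache is accepted by the Enzyme backend:

backend = AutoEnzyme()
# preparation is backend-specific and can be reused as long as input sizes don't change
prep = prepare_gradient(
    loss, backend, weights_mid,
    Cache(weights_ctx), Constant(nvocab), Constant(ndim),
    Constant(tok_mid), Constant(tok_ctx), Constant(x),
)
gradient!(
    loss, dweights_mid, prep, backend, weights_mid,
    Cache(weights_ctx), Constant(nvocab), Constant(ndim),
    Constant(tok_mid), Constant(tok_ctx), Constant(x),
)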

2 Likes

Oh ok, I did not follow the improvements; I thought this was doable only with symbolic backends and ForwardDiff.

Until yesterday, yes ^^
The Cache context is more tricky than the Constant one, so expect rough edges, but I’m starting to make it work with Enzyme and other backends, as shown in this table. There are still a few bugs in reverse mode and performance needs to be optimized, but it should provide a much needed feature rather soon.

2 Likes

Thank you for that, I just saw the paper on sparse tracing, it’s really great! DI will be everywhere soon, I’m sure!

1 Like

I tried Cache

import Random
using DifferentiationInterface,LinearAlgebra,BenchmarkTools
import Enzyme

@fastmath logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))
function loss(
	Wmid, Wctx,
	nvocab, ndim,
	tok_mid, tok_ctx, x
)

    Wmid = reshape(Wmid, nvocab, ndim)
    Wctx = reshape(Wctx, nvocab, ndim)
    nll ::Float32 = 0.0f0
    @inbounds @fastmath for c in eachindex(x)
        Xij = x[c]
        dotprod = dot(@view(Wmid[tok_mid[c],:]),@view(Wctx[tok_ctx[c],:])) 
        nll += Xij * log(logistic_sigmoid(dotprod))
        nll += (1 - Xij) * log(1 - logistic_sigmoid(dotprod))
    end
    nll = nll / length(x)
	nll
end

function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer, backend)
	ntrues = nsamples ÷ 2

	tok_mid = rand(rng, 1:nvocab, nsamples)
	tok_ctx = rand(rng, 1:nvocab, nsamples)
	x = [trues(ntrues); falses(nsamples - ntrues)]

	weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
	weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)


	# @info "Computing gradient..." backend
	dweights_mid = Enzyme.make_zero(weights_mid)
    # dweights_ctx = Enzyme.make_zero(weights_ctx)
    gradient!(
        loss, dweights_mid, backend, weights_mid,
        Cache(weights_ctx), Constant(nvocab), Constant(ndim),
        Constant(tok_mid), Constant(tok_ctx), Constant(x),
    )
    # Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Duplicated(weights_mid, dweights_mid), Enzyme.DuplicatedNoNeed(weights_ctx, dweights_ctx), Enzyme.Const(nvocab), Enzyme.Const(ndim), Enzyme.Const(tok_mid), Enzyme.Const(tok_ctx), Enzyme.Const(x))
	dweights_mid
end

I get this error

ERROR: MethodError: no method matching _translate(::AutoEnzyme{…}, ::EnzymeCore.ReverseMode{…}, ::Val{…}, ::Cache{…})
The function `_translate` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  _translate(::AutoEnzyme, ::EnzymeCore.Mode, ::Val{B}, ::DifferentiationInterface.FunctionContext) where B
   @ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\utils.jl:56
  _translate(::AutoEnzyme, ::EnzymeCore.Mode, ::Val{B}, ::Union{DifferentiationInterface.BackendContext, Constant}) where B
   @ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\utils.jl:50

Stacktrace:
  [1] (::DifferentiationInterfaceEnzymeExt.var"#3#4"{1, AutoEnzyme{…}, EnzymeCore.ReverseMode{…}})(c::Cache{Vector{…}})
    @ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\utils.jl:66
  [2] map
    @ .\tuple.jl:358 [inlined]
  [3] translate
    @ C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\utils.jl:65 [inlined]
  [4] gradient!(::typeof(loss), ::Vector{…}, ::DifferentiationInterfaceEnzymeExt.EnzymeGradientPrep{…}, ::AutoEnzyme{…}, ::Vector{…}, ::Cache{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
    @ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\reverse_onearg.jl:266
  [5] gradient!(::typeof(loss), ::Vector{…}, ::AutoEnzyme{…}, ::Vector{…}, ::Cache{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
    @ DifferentiationInterface C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\src\fallbacks\no_prep.jl:55
  [6] train(rng::Random.Xoshiro, nvocab::Int64, nsamples::Int64, ndim::Int64, backend::AutoEnzyme{Nothing, Nothing})
    @ Main c:\Users\yolha\Desktop\bench_py_jl\benchmark.jl:39
  [7] var"##core#237"(rng#235::Random.Xoshiro, bck#236::AutoEnzyme{Nothing, Nothing})
    @ Main C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:598
  [8] var"##sample#238"(::Tuple{Random.Xoshiro, AutoEnzyme{Nothing, Nothing}}, __params::BenchmarkTools.Parameters)
    @ Main C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:607
  [9] _lineartrial(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; maxevals::Int64, kwargs::@Kwargs{})
    @ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:186
 [10] _lineartrial(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters)
    @ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:181
 [11] #invokelatest#2
    @ .\essentials.jl:1055 [inlined]
 [12] invokelatest
    @ .\essentials.jl:1052 [inlined]
 [13] #lineartrial#46
    @ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:51 [inlined]
 [14] lineartrial
    @ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:50 [inlined]
 [15] tune!(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, verbose::Bool, pad::String, kwargs::@Kwargs{})  
    @ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:299
 [16] tune!
    @ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:288 [inlined]
 [17] tune!(b::BenchmarkTools.Benchmark)
    @ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:288
 [18] top-level scope
    @ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:728
Some type information was truncated. Use `show(err)` to see complete types.

ERROR: MethodError: no method matching gradient!(::typeof(loss), ::AutoEnzyme{…}, ::Vector{…}, ::Vector{…}, ::Cache{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
The function `gradient!` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  gradient!(::F, ::Any, ::ADTypes.AbstractADType, ::Any, ::Context...) where {F, C}
   @ DifferentiationInterface C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\src\fallbacks\no_prep.jl:51
  gradient!(::F, ::Any, ::DifferentiationInterfaceEnzymeExt.EnzymeGradientPrep, ::AutoEnzyme{<:Union{Nothing, EnzymeCore.ReverseMode}}, ::Any, ::Context...) where {F, C}     
   @ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\reverse_onearg.jl:254
  gradient!(::Any, ::Any, ::DifferentiationInterface.NoGradientPrep, ::AutoZygote, ::Any, ::Union{DifferentiationInterface.BackendContext, Constant, DifferentiationInterface.FunctionContext}...) where C
   @ DifferentiationInterfaceZygoteExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceZygoteExt\DifferentiationInterfaceZygoteExt.jl:147
  ...

Stacktrace:
  [1] train(rng::Random.Xoshiro, nvocab::Int64, nsamples::Int64, ndim::Int64, backend::AutoEnzyme{Nothing, Nothing})
    @ Main c:\Users\yolha\Desktop\bench_py_jl\benchmark.jl:39
  [2] var"##core#237"(rng#235::Random.Xoshiro, bck#236::AutoEnzyme{Nothing, Nothing})
    @ Main C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:598
  [3] var"##sample#238"(::Tuple{Random.Xoshiro, AutoEnzyme{Nothing, Nothing}}, __params::BenchmarkTools.Parameters)
    @ Main C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:607
  [4] _lineartrial(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; maxevals::Int64, kwargs::@Kwargs{})
    @ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:186
  [5] _lineartrial(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters)
    @ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:181
  [6] #invokelatest#2
    @ .\essentials.jl:1055 [inlined]
  [7] invokelatest
    @ .\essentials.jl:1052 [inlined]
  [8] #lineartrial#46
    @ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:51 [inlined]
  [9] lineartrial
    @ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:50 [inlined]
 [10] tune!(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, verbose::Bool, pad::String, kwargs::@Kwargs{})  
    @ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:299
 [11] tune!
    @ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:288 [inlined]
 [12] tune!(b::BenchmarkTools.Benchmark)
    @ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:288
 [13] top-level scope
    @ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:728
Some type information was truncated. Use `show(err)` to see complete types.

Should I open an issue? I can try to reduce this to a smaller example (same error on DI#main).

1 Like

I had a look at the actual IR the JAX example generates (see Whether the hlo generated by jax through jax.jit.lower() is optimized by xla · jax-ml/jax · Discussion #15899 · GitHub if anyone wants to do this yourself), and it looks extremely close to the operations used here: a couple of dot products, a reduction for the mean, and fused elementwise broadcasts. So it’s not surprising that a Julia translation has similar performance.

2 Likes

MLIR may add a little multithreading here and there in JAX, since it generates the kernels, no? Never mind, you would have seen that in the IR, sorry. So the 6x is probably from the C++ → Python round trips and the timing, or just that we ran the benchmarks on different computers lol

Benchmark time!

Loss function        AD backend   1st grad    Subsequent grads (mean)
loss_mcabbott        Zygote         4.4 s       2.6 ms
loss_mcabbott        Mooncake      82.3 s       4.2 ms
loss_yolhan_mannes   Zygote       345.3 s      78.5 ms
loss_yolhan_mannes   Mooncake      63.3 s       3.7 ms
loss_mcabbott_mean   Zygote         4.2 s      79.97 ms
loss_mcabbott_mean   Mooncake      79.4 s       4.2 ms

I ran the entire code from scratch (julia code.jl) for each implementation of the loss function. Below, “startup time” means “time-to-first-gradient”.

Observations:

  • Mooncake’s startup times aren’t great.
  • Mooncake’s other timings are impressive.
  • Accumulation like nll += stuff in loss_yolhan_mannes seems to hurt startup times a lot.

Loss functions

function loss_mcabbott_mean(
	Wmid::AM{T}, Wctx::AM{T},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
	Wmidrows = @inbounds eachrow(Wmid)[tok_mid]
	Wctxrows = @inbounds eachrow(Wctx)[tok_ctx]

	-mean([
		let
			pij = logistic_sigmoid(dot(wi, wj))
			Xij * log(pij) + (1 - Xij) * log(1 - pij)
		end
		for (wi, wj, Xij) in zip(Wmidrows, Wctxrows, x)
	])
end

function loss_mcabbott(
	Wmid::AM{T}, Wctx::AM{T},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
	tmp = sum(
		Wmid[tok_mid, :] .* Wctx[tok_ctx, :], dims=2
	) |> vec
	-mean(
		@. x * log(logistic_sigmoid(tmp)) + (1 - x) * log(1 - logistic_sigmoid(tmp))
	)
end

function loss_yolhan_mannes(
	Wmid::AM{T}, Wctx::AM{T},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
	nll = zero(T)
	@inbounds @fastmath for c in eachindex(x)
		Xij = x[c]
		dotprod = dot(@view(Wmid[tok_mid[c], :]), @view Wctx[tok_ctx[c], :])
		sigm = logistic_sigmoid(dotprod)
		nll += Xij * log(sigm) + (1 - Xij) * log(1 - sigm)
	end
	nll / length(x)
end

Full code
using Statistics, LinearAlgebra; import Random
using DifferentiationInterface
import Zygote, Mooncake

const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}

logistic_sigmoid(x::Real) = 1 / (1 + exp(-x))

function loss_mcabbott_mean(
	Wmid::AM{T}, Wctx::AM{T},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
	Wmidrows = @inbounds eachrow(Wmid)[tok_mid]
	Wctxrows = @inbounds eachrow(Wctx)[tok_ctx]

	-mean([
		let
			pij = logistic_sigmoid(dot(wi, wj))
			Xij * log(pij) + (1 - Xij) * log(1 - pij)
		end
		for (wi, wj, Xij) in zip(Wmidrows, Wctxrows, x)
	])
end

function loss_mcabbott(
	Wmid::AM{T}, Wctx::AM{T},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
	tmp = sum(
		Wmid[tok_mid, :] .* Wctx[tok_ctx, :], dims=2
	) |> vec
	-mean(
		@. x * log(logistic_sigmoid(tmp)) + (1 - x) * log(1 - logistic_sigmoid(tmp))
	)
end

function loss_yolhan_mannes(
	Wmid::AM{T}, Wctx::AM{T},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
	nll = zero(T)
	@inbounds @fastmath for c in eachindex(x)
		Xij = x[c]
		dotprod = dot(@view(Wmid[tok_mid[c], :]), @view Wctx[tok_ctx[c], :])
		sigm = logistic_sigmoid(dotprod)
		nll += Xij * log(sigm) + (1 - Xij) * log(1 - sigm)
	end
	nll / length(x)
end

function train(rng, loss, nvocab::Integer, nsamples::Integer, ndim::Integer, backend, quiet::Bool)
	ntrues = nsamples ÷ 2

	tok_mid = rand(rng, 1:nvocab, nsamples)
	tok_ctx = rand(rng, 1:nvocab, nsamples)
	x = [trues(ntrues); falses(nsamples - ntrues)]

	weights_mid = 0.1f0 .* randn(rng, Float32, nvocab, ndim)
	weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab, ndim)
	quiet || @info "Number of parameters:" size(weights_mid) size(weights_ctx) total=(length(weights_mid) + length(weights_ctx))

	dweights_mid = similar(weights_mid)
	quiet || @info "Computing gradient..." backend
	if !quiet
		@timev gradient!(
			loss,
			dweights_mid, backend,
			weights_mid, Constant(weights_ctx),
			Constant(tok_mid), Constant(tok_ctx), Constant(x)
		)
	else
		gradient!(
			loss,
			dweights_mid, backend,
			weights_mid, Constant(weights_ctx),
			Constant(tok_mid), Constant(tok_ctx), Constant(x)
		)
	end
	dweights_mid
end

using BenchmarkTools

function main(loss)
	@info "CODE: $loss; backend: Zygote"
	backend = AutoZygote()
	grad = train(Random.Xoshiro(5), loss, 100277, 100, 2, backend, false)
	@info "Gradient info:" size(grad) sum(grad)

	@show mean(
		@benchmark let
			train(Random.Xoshiro(5), $loss, 100277, 100, 2, $backend, true)
		end
	)

	println("\n\n")
	@info "CODE: $loss; backend: Mooncake"
	backend = AutoMooncake(config=nothing)
	grad = train(Random.Xoshiro(5), loss, 100277, 100, 2, backend, false)
	@info "Gradient info:" size(grad) sum(grad)

	@show mean(
		@benchmark let
			train(Random.Xoshiro(5), $loss, 100277, 100, 2, $backend, true)
		end
	)
end

main(loss_mcabbott_mean)

Full output
[ Info: CODE: loss_mcabbott; backend: Zygote
┌ Info: Number of parameters:
│   size(weights_mid) = (100277, 2)
│   size(weights_ctx) = (100277, 2)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoZygote()
  4.433219 seconds (8.67 M allocations: 440.679 MiB, 1.83% gc time, 99.98% compilation time)
elapsed time (ns):  4.433218935e9
gc time (ns):       81026940
bytes allocated:    462085232
pool allocs:        8661685
non-pool GC allocs: 332
malloc() calls:     9061
free() calls:       7113
minor collections:  2
full collections:   0
┌ Info: Gradient info:
│   size(grad) = (100277, 2)
└   sum(grad) = -0.018946059f0
mean(@benchmark(let
            train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
        end)) = TrialEstimate(2.613 ms)


[ Info: CODE: loss_mcabbott; backend: Mooncake
┌ Info: Number of parameters:
│   size(weights_mid) = (100277, 2)
│   size(weights_ctx) = (100277, 2)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoMooncake{Nothing}(nothing)
 82.335135 seconds (164.97 M allocations: 8.256 GiB, 2.64% gc time, 84.53% compilation time: <1% of which was recompilation)
elapsed time (ns):  8.2335135303e10
gc time (ns):       2172342058
bytes allocated:    8865191920
pool allocs:        164814503
non-pool GC allocs: 5073
malloc() calls:     155178
free() calls:       154675
minor collections:  37
full collections:   2
┌ Info: Gradient info:
│   size(grad) = (100277, 2)
└   sum(grad) = -0.018946059f0
mean(@benchmark(let
            train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
        end)) = TrialEstimate(4.221 ms)


[ Info: CODE: loss_yolhan_mannes; backend: Zygote
┌ Info: Number of parameters:
│   size(weights_mid) = (100277, 2)
│   size(weights_ctx) = (100277, 2)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoZygote()
345.317342 seconds (8.39 M allocations: 719.310 MiB, 0.04% gc time, 99.96% compilation time)
elapsed time (ns):  3.45317341779e11
gc time (ns):       152304804
bytes allocated:    754251240
pool allocs:        8381193
non-pool GC allocs: 381
malloc() calls:     6019
free() calls:       8507
minor collections:  27
full collections:   0
┌ Info: Gradient info:
│   size(grad) = (100277, 2)
└   sum(grad) = 0.018946057f0
mean(@benchmark(let
            train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
        end)) = TrialEstimate(78.547 ms)



[ Info: CODE: loss_yolhan_mannes; backend: Mooncake
┌ Info: Number of parameters:
│   size(weights_mid) = (100277, 2)
│   size(weights_ctx) = (100277, 2)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoMooncake{Nothing}(nothing)
 63.335114 seconds (127.02 M allocations: 6.374 GiB, 2.67% gc time, 82.60% compilation time: <1% of which was recompilation)
elapsed time (ns):  6.3335113863e10
gc time (ns):       1690957639
bytes allocated:    6844269736
pool allocs:        126886207
non-pool GC allocs: 3808
malloc() calls:     126373
free() calls:       123244
minor collections:  28
full collections:   2
┌ Info: Gradient info:
│   size(grad) = (100277, 2)
└   sum(grad) = 0.018946059f0
mean(@benchmark(let
            train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
        end)) = TrialEstimate(3.659 ms)

[ Info: CODE: loss_mcabbott_mean; backend: Zygote
┌ Info: Number of parameters:
│   size(weights_mid) = (100277, 2)
│   size(weights_ctx) = (100277, 2)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoZygote()
  4.200216 seconds (9.54 M allocations: 483.351 MiB, 2.87% gc time, 95.91% compilation time)
elapsed time (ns):  4.2002158319999995e9
gc time (ns):       120481781
bytes allocated:    506830280
pool allocs:        9530497
non-pool GC allocs: 284
malloc() calls:     6166
free() calls:       8660
minor collections:  3
full collections:   0
┌ Info: Gradient info:
│   size(grad) = (100277, 2)
└   sum(grad) = -0.018946057f0
mean(@benchmark(let
            train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
        end)) = TrialEstimate(79.968 ms)



[ Info: CODE: loss_mcabbott_mean; backend: Mooncake
┌ Info: Number of parameters:
│   size(weights_mid) = (100277, 2)
│   size(weights_ctx) = (100277, 2)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoMooncake{Nothing}(nothing)
 79.360692 seconds (154.90 M allocations: 7.758 GiB, 2.89% gc time, 84.79% compilation time: <1% of which was recompilation)
elapsed time (ns):  7.9360692289e10
gc time (ns):       2289888095
bytes allocated:    8329889720
pool allocs:        154749525
non-pool GC allocs: 4811
malloc() calls:     146434
free() calls:       140564
minor collections:  34
full collections:   3
┌ Info: Gradient info:
│   size(grad) = (100277, 2)
└   sum(grad) = -0.018946059f0
mean(@benchmark(let
            train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
        end)) = TrialEstimate(4.223 ms)
4 Likes

Oops, the cache translation isn’t released yet. It should work on main though :thinking:

Maybe I needed to run a Pkg gc first, I will try.

I tried again and Cache should work on DI#main. This version is being registered as we speak
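
Until the new release hits the registry, something like this should pull the development version; a sketch, assuming the package lives in the DifferentiationInterface subdirectory of the JuliaDiff/DifferentiationInterface.jl monorepo:

import Pkg
# DI#main from the monorepo; the URL and subdir layout are assumptions
Pkg.add(url="https://github.com/JuliaDiff/DifferentiationInterface.jl",
        subdir="DifferentiationInterface", rev="main")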

Working! 1.061 ms with that :slight_smile: Yeah, since it’s an extension I had to gc --all and restart Julia, typical things.

Reusing the same variable name for different objects of potentially different types is an antipattern in Julia, and seems to be the reason you had to add an extra Duplicated to make Enzyme work. If I use distinct names I can use Enzyme.Const(weights_ctx) no problem:

using Statistics
import Random
using LinearAlgebra, BenchmarkTools
import Enzyme

const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}
logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))

function loss(
    Wmid, Wctx,
    nvocab, ndim,
    tok_mid, tok_ctx, x
)

    Wmid_reshaped = reshape(Wmid, nvocab, ndim)
    Wctx_reshaped = reshape(Wctx, nvocab, ndim)
    nll::Float32 = 0.0f0
    @inbounds for c in eachindex(x)
        Xij = x[c]
        dotprod = dot(@view(Wmid_reshaped[tok_mid[c], :]), @view(Wctx_reshaped[tok_ctx[c], :]))
        nll += Xij * log(logistic_sigmoid(dotprod))
        nll += (1 - Xij) * log(1 - logistic_sigmoid(dotprod))
    end
    nll = nll / length(x)
    nll
end

function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer)
    ntrues = nsamples ÷ 2

    tok_mid = rand(rng, 1:nvocab, nsamples)
    tok_ctx = rand(rng, 1:nvocab, nsamples)
    x = [trues(ntrues); falses(nsamples - ntrues)]

    weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
    weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)

    dweights_mid = Enzyme.make_zero(weights_mid)

    Enzyme.autodiff(
        Enzyme.Reverse,
        loss,
        Enzyme.Duplicated(weights_mid, dweights_mid),
        Enzyme.Const(weights_ctx),
        Enzyme.Const(nvocab),
        Enzyme.Const(ndim),
        Enzyme.Const(tok_mid),
        Enzyme.Const(tok_ctx),
        Enzyme.Const(x),
    )
    dweights_mid
end

rng = Random.Xoshiro(5)
@btime train($rng, 100277, 100, 2)

Output:

julia> include("ad.jl");
  1.226 ms (27 allocations: 3.84 MiB)
4 Likes

Oh ok, my bad, then no need for that. It was interesting though.

Actually, on my computer, Enzyme crashes if I just change the names:

function loss(
	Wmid, Wctx,
	nvocab, ndim,
	tok_mid, tok_ctx, x
)

    Wmid2 = reshape(Wmid, nvocab, ndim)
    Wctx2 = reshape(Wctx, nvocab, ndim)
    nll ::Float32 = 0.0f0
    @inbounds @fastmath for c in eachindex(x)
        Xij = x[c]
        dotprod = dot(@view(Wmid2[tok_mid[c],:]),@view(Wctx2[tok_ctx[c],:])) 
        nll += Xij * log(logistic_sigmoid(dotprod))
        nll += (1 - Xij) * log(1 - logistic_sigmoid(dotprod))
    end
    nll = nll / length(x)
	nll
end
dweights_mid = Enzyme.make_zero(weights_mid)
@time Enzyme.autodiff(
    Enzyme.Reverse, loss,
    Enzyme.Duplicated(weights_mid, dweights_mid),
    Enzyme.Const(weights_ctx),
    Enzyme.Const(nvocab), Enzyme.Const(ndim),
    Enzyme.Const(tok_mid), Enzyme.Const(tok_ctx), Enzyme.Const(x),
)

What version of Enzyme are you using? The current one is 0.13.28.

Enzyme v0.13.28. It works on Julia 1.10 but not on 1.11.

1 Like

I’ve also encountered a few more bugs with Enzyme on 1.11 than on 1.10. And its behavior can be platform-dependent, since the LLVM codegen is different.

2 Likes

We’re in the nightmare of wsmoses

2 Likes