`Zygote.gradient` is 54000 TIMES slower than `jax.gradient`

I’m almost unable to compute ant gradients (EDIT 9 hrs later: meant to say “any gradients”, but got gradients for ants) in Zygote. I initially attempted to optimize my model using around a million parameters, but it got stuck computing the first gradient. I simplified the model so now it had only 401108 parameters. Still couldn’t get a single gradient. Finally, I used just 100 data points, waited some more and got my gradient after about 6 minutes.

Six minutes to compute one gradient of a logistic regression?? Also the gradient is with respect to only half of the parameters.

Julia code

using Statistics
import Random, Zygote

# Short aliases used in the method signatures below.
const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}

# Logistic sigmoid σ(x) = 1 / (1 + e^(-x)). The integer literal `1` promotes to the
# float type of `exp(-x)`, so a Float32 argument gives a Float32 result.
logistic_sigmoid(x::Real) = 1 / (1 + exp(-x))

# Mean negative log-likelihood of the Bool labels `x`, where sample k has predicted
# probability σ(Wmid[tok_mid[k], :] ⋅ Wctx[tok_ctx[k], :]).
# Logs the loss with `@info` and returns it.
#
# NOTE(review): the `nll = ...` inside the `let` does NOT create a new variable.
# `nll` is already a local of `loss` (assigned by `nll = -mean(...)`), so the
# generator's closure assigns to that outer variable, and Julia has to box it
# (`Core.Box`). Using a different inner name (or no assignment) avoids this.
# NOTE(review): the generator passed to `mean` is itself lowered to a closure. Later
# in this thread, materializing a similar generator with `[...]` greatly reduced
# Zygote's first-call (compilation) time.
function loss(
	Wmid::AM{<:Real}, Wctx::AM{<:Real},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
)
	nll = -mean(
		let
			# Predicted probability that this (i, j) pair is labelled true;
			# `@views` avoids copying the two rows.
			pij = @views logistic_sigmoid(Wmid[i, :]' * Wctx[j, :])
			# Per-sample log-likelihood: the value of the `let`, hence of each
			# generator element.
			nll = Xij * log(pij) + (1 - Xij) * log(1 - pij)
		end
		for (i, j, Xij) in zip(tok_mid, tok_ctx, x)
	)
	@info "loss: $nll"
	nll
end

# Build a random problem and return the Zygote gradient of the loss with respect to
# the flattened `Wmid` parameters, holding `Wctx` fixed. Before that, it evaluates
# the loss once and times a second evaluation.
#
# Arguments: `rng` is the random generator; `nvocab` and `ndim` are the rows and
# columns of each weight matrix; `nsamples` is the number of (tok_mid, tok_ctx, x)
# samples.
# Returns a vector of length nvocab * ndim.
function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer)
	# NOTE(review): `_loss` captures `tok_mid`, `tok_ctx` and `x`, which are only
	# assigned further down. Julia boxes captured variables that are assigned after
	# the closure is created (`Core.Box`), making `_loss` type-unstable (check with
	# `@code_warntype`). Defining `_loss` after the data avoids that.
	_loss(Wmid::AV{<:Real}, Wctx::AV{<:Real}) = loss(
		reshape(Wmid, nvocab, ndim), reshape(Wctx, nvocab, ndim),
		tok_mid, tok_ctx, x
	)
	ntrues = nsamples ÷ 2

	tok_mid = rand(rng, 1:nvocab, nsamples)
	tok_ctx = rand(rng, 1:nvocab, nsamples)
	# The first half of the labels are true, the rest false.
	x = [trues(ntrues); falses(nsamples - ntrues)]

	# NOTE(review): `0.1` is a Float64 literal, so these products are Float64 vectors
	# even though `randn` draws Float32. Use `0.1f0 .* randn(...)` to stay in Float32.
	weights_mid = 0.1randn(rng, Float32, nvocab * ndim)
	weights_ctx = 0.1randn(rng, Float32, nvocab * ndim)
	@info "Number of parameters:" size(weights_mid) size(weights_ctx) total=(length(weights_mid) + length(weights_ctx))
	# The first call compiles `_loss`; the second, timed call measures run time only.
	_loss(weights_mid, weights_ctx)
	@timev _loss(weights_mid, weights_ctx)

	# Gradient w.r.t. the first argument only; `weights_ctx` is captured as a constant.
	Zygote.gradient(
		Wmid -> _loss(Wmid, weights_ctx), weights_mid
	)[1]
end

# Single run: a 100277-row vocabulary, 100 samples and 2 columns per row. Because this
# is the first call, the timing includes compiling the gradient.
grad = @timev train(Random.Xoshiro(5), 100277, 100, 2)
@show size(grad)
@info "Gradient sum:" sum(grad)

Julia output

> julia --project test_grad.jl
┌ Info: Number of parameters:
│   size(weights_mid) = (200554,)
│   size(weights_ctx) = (200554,)
└   total = 401108
[ Info: loss: 0.6925831785105732
[ Info: loss: 0.6925831785105732
  0.000092 seconds (213 allocations: 9.227 KiB)
elapsed time (ns):  91884.0
gc time (ns):       0
bytes allocated:    9448
pool allocs:        213
non-pool GC allocs: 0
minor collections:  0
full collections:   0
[ Info: loss: 0.6925831785105732
344.948917 seconds (42.56 M allocations: 2.558 GiB, 0.18% gc time, 99.93% compilation time)
elapsed time (ns):  3.44948917302e11
gc time (ns):       608519881
bytes allocated:    2746926240
pool allocs:        42536936
non-pool GC allocs: 967
malloc() calls:     24705
free() calls:       24872
minor collections:  99
full collections:   1
size(grad) = (200554,)
┌ Info: Gradient sum:
└   sum(grad) = -0.018946058381392624

Time: 344.948917 seconds

Python code

import time
import jax, jax.numpy as np, jax.random as rnd

def logistic_sigmoid(x):
	"""Elementwise logistic sigmoid 1 / (1 + exp(-x)), computed with jax.numpy."""
	return 1 / (1 + np.exp(-x))

@jax.jit
def loss(Wmid, Wctx, tok_mid, tok_ctx, x):
	"""Mean negative log-likelihood of the labels ``x``.

	Sample k has probability sigmoid(Wmid[tok_mid[k]] . Wctx[tok_ctx[k]]).
	The per-sample term is vectorised over samples with ``jax.vmap``.
	"""
	def nll(i, j, Xij):
		# Log-likelihood (not yet negated) of one (row i, row j, label) sample.
		wi, wj = Wmid[i, :], Wctx[j, :]
		pij = logistic_sigmoid(wi.dot(wj))
		return Xij * np.log(pij) + (1 - Xij) * np.log(1 - pij)

	return -jax.vmap(nll)(tok_mid, tok_ctx, x).mean()

def train(rng, nvocab: int, nsamples: int, ndim: int):
	"""Build a random problem and return ``(loss, grad)`` w.r.t. flattened ``Wmid``.

	``rng`` is a JAX PRNG key; the gradient has ``nvocab * ndim`` entries and
	``Wctx`` is held fixed.
	"""
	def _loss(Wmid, Wctx):
		return loss(
			np.reshape(Wmid, (nvocab, ndim)),
			np.reshape(Wctx, (nvocab, ndim)),
			tok_mid, tok_ctx, x
		)

	ntrues = nsamples // 2

	# NOTE(review): the same key `rng` is reused for all four random draws below, so
	# tok_ctx == tok_mid and weights_ctx == weights_mid. Split it first, e.g.
	# `k1, k2, k3, k4 = rnd.split(rng, 4)`, to get independent samples.
	tok_mid = rnd.choice(rng, np.arange(nvocab), (nsamples, ))
	tok_ctx = rnd.choice(rng, np.arange(nvocab), (nsamples, ))
	# The first half of the labels are 1.0, the rest 0.0.
	x = np.concatenate([
		np.ones(ntrues), np.zeros(nsamples - ntrues)
	])

	weights_mid = 0.1 * rnd.normal(rng, (nvocab * ndim, ))
	weights_ctx = 0.1 * rnd.normal(rng, (nvocab * ndim, ))
	print(
		"Number of parameters:\n"
		f"\t{weights_mid.shape=}\n",
		f"\t{weights_ctx.shape=}\n",
		f"\ttotal={weights_mid.size + weights_ctx.size}"
	)

	# NOTE(review): when `train` runs under `jax.jit`, this prints the tracer at
	# trace time rather than the number; `jax.debug.print` prints runtime values.
	nll = _loss(weights_mid,weights_ctx,)
	print("Loss:", nll)
	dloss_Wmid = jax.value_and_grad(lambda Wmid: _loss(Wmid, weights_ctx))
	# NOTE(review): never used.
	dloss_Wctx = jax.value_and_grad(lambda Wctx: _loss(weights_mid, Wctx))

	return dloss_Wmid(weights_mid)

# Compile the whole of `train` (including the gradient) with fixed sizes; the PRNG
# key is the only traced argument.
train_jitted = jax.jit(lambda key: train(key, 100277, 100, 2))
# First call: tracing + compilation + execution.
tbegin = time.time()
value, grad = jax.block_until_ready(train_jitted(rnd.key(5)))
print(f"{grad.shape=}")
print("Done in", time.time() - tbegin, "seconds (with compilation)")

# Second call reuses the compiled function. block_until_ready waits for the
# asynchronously dispatched result, so the timing covers the actual computation.
tbegin = time.time()
value, grad = jax.block_until_ready(train_jitted(rnd.key(5)))
print(f"{grad.shape=}")
print("Done in", time.time() - tbegin, "seconds (2nd run)")

print("Gradient sum:", grad.sum())

Python output

> python test_grad.py
Number of parameters:
        weights_mid.shape=(200554,)
        weights_ctx.shape=(200554,)
        total=401108
Loss: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
grad.shape=(200554,)
Done in 0.4434378147125244 seconds (with compilation)
grad.shape=(200554,)
Done in 0.006359100341796875 seconds (2nd run)
Gradient sum: 0.0016649812

Time after compilation: 0.0064 seconds

Question

So Julia seems to be 344.948917 / 0.0064 == 53898 times slower than JAX. What am I doing wrong in my Julia code?

I tried ForwardDiff as well, it’s just as slow.

Is your goal to get help making this faster?

If so, I am reasonably sure you can help Zygote along here; it certainly hates indexing. You could also try Enzyme instead. In both cases the first run includes compilation time, which you appear to be including in the timings above.

2 Likes

Wait, there’s no obvious mistake in my code? I thought I’d been staring at this code for too long and was just missing something basic…

Yes, I’d like to make the Julia code faster. Not just faster, I’m simply trying to make it do the job. I couldn’t even begin estimating the original model with around a million parameters. I waited more than an hour (!) and still it didn’t finish computing even the first gradient.

Sure, I’m including compilation time in the Julia timings, but if I include JIT time in JAX timings, Julia is still 344.948917 / 0.45 == 766 times slower than JAX. That can’t be right, but this is exactly what I’m getting, I even ran it multiple times, getting these exact results every time.

What can I do to use less indexing? The only indexing is in logistic_sigmoid(Wmid[i, :]' * Wctx[j, :]). I think I have to use indexing here because otherwise Wmid' * Wctx will allocate a 100277-by-100277 matrix, but I don’t have nearly enough RAM (40 gigs assuming Float32) to store it.

I tried Enzyme with DifferentiationInterface.jl, but I got this error:

ERROR: LoadError: Function argument passed to autodiff cannot be proven readonly.
If the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)
See https://enzyme.mit.edu/index.fcgi/julia/stable/faq/#Activity-of-temporary-storage for more information.
The potentially writing call is   store {} addrspace(10)* %.fca.1.2.extract, {} addrspace(10)** %.fca.1.2.gep, align 8, !dbg !11, !noalias !22, using   %.fca.1.2.gep = getelementptr inbounds { {} addrspace(10)*, { i64, i64, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } }, { {} addrspace(10)*, { i64, i64, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } }* %.innerparm, i64 0, i32 1, i32 2, !dbg !11

Stacktrace:
  [1] augmented_julia__2_7525_inner_3wrap
    @ test_grad.jl:0
  [2] macro expansion
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5340 [inlined]
  [3] enzyme_call
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4878 [inlined]
  [4] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4814 [inlined]
  [5] autodiff
    @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:396 [inlined]
  [6] autodiff
    @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:524 [inlined]
  [7] gradient!(::var"#2#4"{Vector{Float64}, var"#_loss#3"{Int64, Int64}}, ::Vector{Float64}, ::DifferentiationInterfaceEnzymeExt.EnzymeGradientPrep{Vector{Float64}}, ::AutoEnzyme{Nothing, Nothing}, ::Vector{Float64})
    @ DifferentiationInterfaceEnzymeExt ~/.julia/packages/DifferentiationInterface/mXEZA/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl:272
  [8] gradient!
    @ ~/.julia/packages/DifferentiationInterface/mXEZA/src/fallbacks/no_prep.jl:55 [inlined]
  [9] train(rng::Random.Xoshiro, nvocab::Int64, nsamples::Int64, ndim::Int64, backend::AutoEnzyme{Nothing, Nothing})
    @ Main test_grad.jl:50
 [10] top-level scope
    @ ./timing.jl:581
in expression starting at test_grad.jl:57

Pretty sure the store {} addrspace(10)* %.fca.1.2.extract... stuff is LLVM IR, but I have no idea what it means or where exactly it comes from. Okay, store probably means that something somewhere is writing to memory, but it specifically says “Function argument passed to autodiff cannot be proven readonly”, yet I’m sure none of my code in loss modifies any of its arguments, especially not Wmid or Wctx that contain the model’s parameters.

I reduced the time a bit using

using Statistics,LinearAlgebra
import Random, Zygote,Enzyme

# Logistic sigmoid σ(x) = 1 / (1 + e^(-x)).
logistic_sigmoid(x::Real) = 1 / (1 + exp(-x))

# Mean negative log-likelihood of the labels `x`, where sample k has predicted
# probability σ(Wmid[tok_mid[k], :] ⋅ Wctx[tok_ctx[k], :]).
#
# Arguments:
#   Wmid, Wctx       - weight matrices whose rows are indexed by token ids
#   tok_mid, tok_ctx - row indices into Wmid / Wctx, one pair per sample
#   x                - 0/1 (e.g. Bool) label per sample
# Returns the scalar loss.
function loss(
	Wmid, Wctx,
	tok_mid, tok_ctx, x
)
	nll = -mean(
		let
			pij = logistic_sigmoid(@view(Wmid[i, :]) ⋅ @view(Wctx[j, :]))
			# The `let` evaluates to this per-sample log-likelihood. Do not write
			# `nll = ...` here: inside the generator's closure that would assign the
			# outer local `nll` and force Julia to box it (`Core.Box`).
			Xij * log(pij) + (1 - Xij) * log(1 - pij)
		end
		for (i, j, Xij) in zip(tok_mid, tok_ctx, x)
	)
	return nll
end

# Build a random problem instance and return the Zygote gradient of the loss with
# respect to the flattened `Wmid` parameters, holding `Wctx` fixed.
#
# Arguments:
#   rng      - random number generator for tokens and initial weights
#   nvocab   - rows (vocabulary size) of each weight matrix
#   nsamples - number of (tok_mid, tok_ctx, x) samples; the first half are labelled true
#   ndim     - columns of each weight matrix
# Returns a vector of length nvocab * ndim.
function train(rng, nvocab, nsamples, ndim)
	ntrues = nsamples ÷ 2

	tok_mid = rand(rng, 1:nvocab, nsamples)
	tok_ctx = rand(rng, 1:nvocab, nsamples)
	x = [trues(ntrues); falses(nsamples - ntrues)]

	# `0.1f0` keeps the weights Float32; a plain `0.1` would promote them to Float64.
	weights_mid = 0.1f0.*randn(rng, Float32, nvocab * ndim)
	weights_ctx = 0.1f0.*randn(rng, Float32, nvocab * ndim)

	# Defined only after everything it captures has been assigned. A closure created
	# before its captured variables are assigned makes Julia box them (`Core.Box`),
	# which leaves the closure type-unstable.
	_loss(Wmid, Wctx) = loss(
		reshape(Wmid, nvocab, ndim), reshape(Wctx, nvocab, ndim),
		tok_mid, tok_ctx, x
	)

	# Enzyme alternative. Enzyme struggles with the `_loss` closure, so a function
	# taking every argument explicitly is needed. For a scalar loss the return
	# activity must be `Enzyme.Active`; with `Enzyme.Const` no gradient is propagated.
	# dweights_mid = Enzyme.make_zero(weights_mid)
	# Enzyme.autodiff(Enzyme.Reverse, _loss, Enzyme.Active, Enzyme.Duplicated(weights_mid, dweights_mid), Enzyme.Const(weights_ctx))

	# Return Zygote's gradient itself, not a separately allocated zero buffer.
	return Zygote.gradient(
		Wmid -> _loss(Wmid, weights_ctx), weights_mid
	)[1]
end

# Time one full run, then profile a second one.
# NOTE(review): `@profview` is not defined by the packages imported in this snippet;
# it is provided by the VS Code Julia REPL or by ProfileView.jl.
grad = @timev train(Random.Xoshiro(5), 100277, 100, 2)
@profview train(Random.Xoshiro(5), 100277, 100, 2)
@show size(grad)
@info "Gradient sum:" sum(grad)

giving 0.18 s on the second run (still bad compared with Python, but better). Enzyme might help a lot; however, it doesn’t like closures (the `_loss` function), so you would have to write another function that takes all the needed arguments and then call `autodiff` the way I did in the commented-out line. If you want compilation through MLIR (which we will never beat by design, and which is what JAX uses), you need to be on Linux or macOS and use Reactant.jl to compile it.

1 Like

Key point here :point_up: See Advanced tutorial · DifferentiationInterface.jl for how to avoid the closure when using Enzyme through DifferentiationInterface.

2 Likes

I get 43 ms with @yolhan_mannes 's code (Zygote), if I time it more carefully.

One idea to reduce indexing is to call `eachrow`. Here that cuts memory by 7x but doesn’t actually add speed:

julia> grad = @time train(Random.Xoshiro(5), 100277, 100, 2);  # second run
  0.203837 seconds (447.67 k allocations: 332.314 MiB, 22.56% gc time, 59.60% compilation time)

julia> grad = @btime train(Random.Xoshiro(5), 100277, 100, 2);  # better benchmark
  43.267 ms (68154 allocations: 314.30 MiB)

julia> function loss(
               Wmid, Wctx,
               tok_mid, tok_ctx, x
       )
               # Pick the needed rows up front with `eachrow(W)[tok]` instead of
               # slicing `W[i, :]` for every sample.
               nll = -mean([
                       let
                               pij = logistic_sigmoid(dot(Wmidi, Wctxj))
                               # NOTE(review): `nll` is also the outer local, so this
                               # assigns it from inside the comprehension's closure
                               # (forcing a `Core.Box`); a different name avoids that.
                               nll = Xij * log(pij) + (1 - Xij) * log(1 - pij)
                       end
                       for (Wmidi, Wctxj, Xij) in zip(eachrow(Wmid)[tok_mid], eachrow(Wctx)[tok_ctx], x)
               ])
               # @info "loss: $nll"
               nll
       end
loss (generic function with 1 method)

julia> grad = @btime train(Random.Xoshiro(5), 100277, 100, 2);
  72.506 ms (807163 allocations: 44.95 MiB)

julia> function loss(
               Wmid, Wctx,
               tok_mid, tok_ctx, x
       )
         # Pick the needed rows once, then accumulate the log-likelihood in a plain
         # loop with a Float32 accumulator (`0f0`); the result is its negated mean.
         Wmidrows = eachrow(Wmid)[tok_mid]
         Wctxrows = eachrow(Wctx)[tok_ctx]
         acc = 0f0
         for k in eachindex(x)
           pij = logistic_sigmoid(dot(Wmidrows[k], Wctxrows[k]))
           Xij = x[k]
           acc += Xij * log(pij) + (1 - Xij) * log(1 - pij)
         end
         -acc / length(x)
       end
loss (generic function with 1 method)

julia> grad = @btime train(Random.Xoshiro(5), 100277, 100, 2);
  77.063 ms (815180 allocations: 48.09 MiB)
2 Likes

Slightly modified code with Enzyme and DifferentiationInterface

using Statistics; import Random
using DifferentiationInterface
import Zygote, Enzyme

# Short aliases for the argument annotations below.
const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}

# Logistic sigmoid σ(x) = 1 / (1 + e^(-x)).
logistic_sigmoid(x::Real) = 1 / (1 + exp(-x))

# Mean negative log-likelihood of the Bool labels `x`. The parameters are taken as
# flat vectors so that they can be the differentiated argument of `gradient!`.
#
# Arguments:
#   Wmid, Wctx       - flattened parameters, viewed as (nvocab, ndim) matrices
#   nvocab, ndim     - matrix dimensions
#   tok_mid, tok_ctx - row indices into those matrices, one pair per sample
#   x                - label per sample
# Returns the scalar loss and logs it with `@info`.
function loss(
	Wmid::AV{<:Real}, Wctx::AV{<:Real},
	nvocab::Integer, ndim::Integer,
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
)
	# Bind the reshaped matrices to NEW names. Reassigning the arguments `Wmid` and
	# `Wctx` while the `nll_one` closure captures them would force Julia to box both
	# (`Core.Box`), making the closure type-unstable.
	Wm, Wc = reshape(Wmid, nvocab, ndim), reshape(Wctx, nvocab, ndim)

	# Log-likelihood of one (row index, row index, label) sample.
	nll_one((i, j, Xij)) = begin
		pij = @views logistic_sigmoid(Wm[i, :]' * Wc[j, :])
		Xij * log(pij) + (1 - Xij) * log(1 - pij)
	end

	nll = -mean(nll_one.(zip(tok_mid, tok_ctx, x)))
	@info "loss: $nll"
	return nll
end

# Build a random problem and compute, through DifferentiationInterface, the gradient
# of `loss` w.r.t. the flattened `Wmid` parameters using the given `backend`.
# Every other input is passed as a `Constant` context instead of being captured in a
# closure. Returns the gradient vector (length nvocab * ndim).
function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer, backend)
	ntrues = nsamples ÷ 2

	tok_mid = rand(rng, 1:nvocab, nsamples)
	tok_ctx = rand(rng, 1:nvocab, nsamples)
	# The first half of the labels are true, the rest false.
	x = [trues(ntrues); falses(nsamples - ntrues)]

	# `0.1f0` (not `0.1`) keeps the weights Float32.
	weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
	weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
	@info "Number of parameters:" size(weights_mid) size(weights_ctx) total=(length(weights_mid) + length(weights_ctx))

	# Output buffer for `gradient!`. Uninitialized memory is fine here because the
	# gradient is stored into it.
	dweights_mid = similar(weights_mid)
	@info "Computing gradient..." backend
	@timev gradient!(
		loss,
		dweights_mid, backend,
		weights_mid, Constant(weights_ctx), Constant(nvocab), Constant(ndim),
		Constant(tok_mid), Constant(tok_ctx), Constant(x)
	)
	dweights_mid
end

# First call: the @timev report includes compiling the gradient.
grad = train(Random.Xoshiro(5), 100277, 100, 2, AutoZygote())
@show size(grad)
@info "Gradient sum:" sum(grad)

# Second call with identical inputs: measures the already-compiled gradient.
println("\n\n2nd run!!")
grad = train(Random.Xoshiro(5), 100277, 100, 2, AutoZygote())
@show size(grad)
@info "Gradient sum:" sum(grad)
  1. I changed this line weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim) to use a Float32 literal 0.1f0. As pointed out by @yolhan_mannes, my original code used 0.1 which switched all computations to Float64.
  2. I removed all closures except nll_one in loss and used Constant in the call to gradient!.

Results

> julia --project test_grad.jl
┌ Info: Number of parameters:
│   size(weights_mid) = (200554,)
│   size(weights_ctx) = (200554,)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoZygote()
[ Info: loss: 0.69258314
231.698408 seconds (7.15 M allocations: 656.995 MiB, 0.15% gc time, 99.95% compilation time)
elapsed time (ns):  2.31698408198e11
gc time (ns):       352916979
bytes allocated:    688909032
pool allocs:        7142087
non-pool GC allocs: 346
malloc() calls:     4144
free() calls:       5145
minor collections:  42
full collections:   1
size(grad) = (200554,)
┌ Info: Gradient sum:
└   sum(grad) = -0.018946057f0


2nd run!!
┌ Info: Number of parameters:
│   size(weights_mid) = (200554,)
│   size(weights_ctx) = (200554,)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoZygote()
[ Info: loss: 0.69258314
  0.090889 seconds (5.10 k allocations: 305.902 MiB, 49.59% gc time)
elapsed time (ns):  9.0888894e7
gc time (ns):       45070100
bytes allocated:    320761112
pool allocs:        4598
non-pool GC allocs: 104
malloc() calls:     402
free() calls:       717
minor collections:  43
full collections:   0
size(grad) = (200554,)
┌ Info: Gradient sum:
└   sum(grad) = -0.018946057f0
  1. The first run takes 231.698408 seconds (99.95% compilation time).
  2. The second run takes 0.090889 seconds. This is 2549 times faster than the first run, but still 0.090889 / 0.0064 == 14 times slower than JAX.
  3. Thus, the vast majority of the time is spent compiling something when evaluating the gradient for the first time.
3 Likes

I haven’t checked this carefully, but I think that rewriting this to not make slices at all can get us to 2 ms:

julia> function loss(
               Wmid, Wctx,
               tok_mid, tok_ctx, x
       )
       # Gather the needed rows into two contiguous matrices and take all row-wise
       # dot products at once, so Zygote sees a few whole-array operations instead of
       # one slice per sample.
       tmp = sum(Wmid[tok_mid, :] .* Wctx[tok_ctx, :]; dims=2) |> vec
       -mean(@. x * log(logistic_sigmoid(tmp)) + (1 - x) * log(1 - logistic_sigmoid(tmp)))
       end
loss (generic function with 1 method)

julia> grad = @btime train(Random.Xoshiro(5), 100277, 100, 2);
  1.892 ms (274 allocations: 4.62 MiB)

julia> using Tullio  # first way I thought of

julia> function loss(
               Wmid, Wctx,
               tok_mid, tok_ctx, x
       )
       # Same computation as the gathered-rows version, with the row-wise dot
       # products written in Tullio.jl index notation.
       @tullio tmp[k] := Wmid[tok_mid[k], c] * Wctx[tok_ctx[k], c]  # sum over c
       -mean(@. x * log(logistic_sigmoid(tmp)) + (1 - x) * log(1 - logistic_sigmoid(tmp)))  # sum over k
       end
loss (generic function with 1 method)

julia> grad = @btime train(Random.Xoshiro(5), 100277, 100, 2);
  1.916 ms (348 allocations: 4.62 MiB)

All definitions of loss give me zero gradient, so there could be mistakes.

1 Like

Can confirm: I’m now getting 3 seconds for the first run and 0.29 seconds for the second run with Zygote using zip(eachrow(Wmid)[tok_mid], eachrow(Wctx)[tok_ctx], x).

The moment I delete the square brackets here:

# NOTE(review): the inner `nll = ...` assigns the outer `nll` from inside the
# comprehension's closure, which forces a `Core.Box`; use a different inner name.
nll = -mean([ # here
	let
		pij = logistic_sigmoid(dot(Wmidi, Wctxj))
		nll = Xij * log(pij) + (1 - Xij) * log(1 - pij)
	end
	for (Wmidi, Wctxj, Xij) in zip(eachrow(Wmid)[tok_mid], eachrow(Wctx)[tok_ctx], x)
]) # and here

…the argument of `mean` becomes a lazy generator, which creates a closure (AFAIK), and I get 244 seconds with 99.9% compilation time. Add the brackets back (thus constructing an unnecessary intermediate vector that’s immediately consumed by `mean`) and I get 3 seconds (an 81 times speedup).

So it looks like closures (including generator expressions) and indexing are evil.

This code actually takes 375 or 383 seconds for me (slower than my original code).

It’s just the first run, however, after restarting Julia, loading all packages and running loss for the first time. 4 subsequent runs are way faster at about 0.09 seconds each (still 14 times slower than JAX tho).


This code, on the other hand, takes only 13 seconds (29 times faster than the above version). Subsequent runs take around 0.09 seconds.

So something fishy’s going on during that first run, I’m getting 99.92% compilation time for all versions of the code.


Measurement code as in `Zygote.gradient` is 54000 TIMES slower than `jax.gradient` - #7 by ForceBru

# Gradient w.r.t. `weights_mid` only; every other argument is passed as a
# DifferentiationInterface `Constant` context.
@timev gradient!(
  loss,
  dweights_mid, backend,
  weights_mid, Constant(weights_ctx), Constant(nvocab), Constant(ndim),
  Constant(tok_mid), Constant(tok_ctx), Constant(x)
)

Please try timing these with BenchmarkTools.jl as above, @btime. Your time is 45x slower than my laptop on battery, so I suspect it’s not accurate. Or else time much longer runs. (Have not tried Jax locally.)

Edit – or maybe I’m lost as to what version is what. The version with sum(Wmid[tok_mid, :] .* Wctx[tok_ctx, :]; dims=2) is fastest for me, 1.2 ms now.

Both ways make a closure. As you say, the one with the square brackets is less lazy, but sometimes that’s better.

Writing sum(Wmid[tok_mid, :] .* Wctx[tok_ctx, :]; dims=2) allocates even bigger arrays, but they are contiguous. And doing so gives Zygote far fewer operations to think about. The total memory allocated ends up lower.

This makes me rethink adding a Reactant.jl wrapper to DifferentiationInterface.jl. Sure, this would only compile the gradient step and not, say, an entire training loop, but it could still be worth it in such cases.

1 Like

I’m down to 1.5 ms. Enzyme has issues with `reshape` on 1.11, I think (maybe on 1.10 too; I didn’t try), so I used Mooncake:

using Statistics; import Random
using DifferentiationInterface,Tullio,LinearAlgebra,BenchmarkTools
import Zygote, Enzyme, Mooncake

const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}

# Logistic sigmoid with a Float32 numerator. `@fastmath` allows non-IEEE
# (reassociated/approximate) arithmetic in the body.
@fastmath logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))

# Mean negative log-likelihood, written as a plain scalar loop over samples with no
# closures and no temporary arrays.
#
# Arguments:
#   Wmid, Wctx       - flattened parameters, viewed as (nvocab, ndim) matrices
#   nvocab, ndim     - matrix dimensions
#   tok_mid, tok_ctx - row indices into those matrices, one pair per sample
#   x                - 0/1 (e.g. Bool) label per sample
# Returns -(1/N) Σ_k [x_k log p_k + (1 - x_k) log(1 - p_k)], where
# p_k = σ(Wmid[tok_mid[k], :] ⋅ Wctx[tok_ctx[k], :]).
function loss(
	Wmid, Wctx,
	nvocab, ndim,
	tok_mid, tok_ctx, x
)
    Wmid = reshape(Wmid, nvocab, ndim)
    Wctx = reshape(Wctx, nvocab, ndim)
    nll = 0.0f0
    # @inbounds assumes every tok_mid[c] / tok_ctx[c] is a valid row index.
    @inbounds @fastmath for c in eachindex(x)
        Xij = x[c]
        dotprod = dot(@view(Wmid[tok_mid[c], :]), @view(Wctx[tok_ctx[c], :]))
        pij = logistic_sigmoid(dotprod)  # computed once, used by both terms
        nll += Xij * log(pij)
        nll += (1 - Xij) * log(1 - pij)
    end
    # The loop accumulated the log-likelihood. The loss is its NEGATED mean, matching
    # the `-mean(...)` definitions used elsewhere, so the gradient sign agrees.
    return -nll / length(x)
end

# Build a random problem and compute, with the DifferentiationInterface backend
# `backend`, the gradient of `loss` w.r.t. the flattened `Wmid` parameters. All other
# inputs are passed as `Constant` contexts. Returns the gradient vector.
function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer, backend)
	ntrues = nsamples ÷ 2

	tok_mid = rand(rng, 1:nvocab, nsamples)
	tok_ctx = rand(rng, 1:nvocab, nsamples)
	# The first half of the labels are true, the rest false.
	x = [trues(ntrues); falses(nsamples - ntrues)]

	weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
	weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
	# @info "Number of parameters:" size(weights_mid) size(weights_ctx) total=(length(weights_mid) + length(weights_ctx))

    # @info loss(weights_mid, weights_ctx, nvocab, ndim, tok_mid, tok_ctx, x)

	# Zeroed output buffer of the same shape; `gradient!` stores the result in it.
	dweights_mid = Enzyme.make_zero(weights_mid)
	# @info "Computing gradient..." backend
    gradient!(loss,dweights_mid,backend,weights_mid,Constant(weights_ctx),Constant(nvocab),Constant(ndim),Constant(tok_mid),Constant(tok_ctx),Constant(x))
	dweights_mid
end

# Benchmark with Mooncake through DifferentiationInterface. `$` interpolation stops
# BenchmarkTools from timing global-variable access.
# NOTE(review): `rng` is a mutable generator shared by all samples, so each
# benchmark sample draws different random data.
rng = Random.Xoshiro(5)
bck = AutoMooncake(config=nothing)
@btime train($rng, 100277, 100, 2,$bck )
2 Likes

Yes, but the big question remains: is Reactant something like an interface, or is it more like CuArray :cry:

I was thinking I could define a magical AutoReactant(AutoEnzyme()) which compiles every gradient call during the preparation step DI.prepare_gradient

Oh, maybe it would have the same limitation as ForwardDiff then (the function must accept generic input types), plus the no-Windows part, obviously. I’m not sure Reactant will ever use anything other than Enzyme, so I don’t know whether it’s important to have AutoEnzyme here, besides making it clear to users.

LMAO, well, I calculated the mean time instead of the minimum and got 78.884 ms. Benchmark code:

# Mean negative log-likelihood. The needed rows are picked up front via
# `eachrow(W)[tok]`, and the per-sample terms are materialized with `[...]` before
# `mean`. The thread reports that this form compiles far faster under Zygote than the
# lazy-generator version.
function loss_mean(
		Wmid, Wctx,
		tok_mid, tok_ctx, x
)
	Wmidrows = @inbounds eachrow(Wmid)[tok_mid]
	Wctxrows = @inbounds eachrow(Wctx)[tok_ctx]

	-mean([
		let
			pij = logistic_sigmoid(dot(Wmidi, Wctxj))
			# Here `nll` is local to the `let` (loss_mean has no outer `nll`).
			nll = Xij * log(pij) + (1 - Xij) * log(1 - pij)
		end
		for (Wmidi, Wctxj, Xij) in zip(Wmidrows, Wctxrows, x)
	])
end

# Report mean statistics over the BenchmarkTools samples. The `train` called here is
# a variant (not shown) that takes the loss function as its second argument.
using BenchmarkTools
@show mean(
  @benchmark let
    train(Random.Xoshiro(5), loss_mean, 100277, 100, 2, AutoZygote())
  end
)

I also removed all printing (@info and the like) from train.


  1. This code (with the list/vector comprehension) is the fastest at startup (13 seconds vs 300+ seconds) with Zygote.
  2. All codes are almost equally fast (within a second) during subsequent runs.
  3. Enzyme says “constant memory is stored or returned to a differentiable variable” when executing Wctxrows = @inbounds eachrow(Wctx)[tok_ctx]. The stacktrace shows there’s a setindex! call at this line. Yes, I’m using Constant(weights_ctx) in the call to gradient! because I’m not differentiating wrt weights_ctx. Enzyme doesn’t seem to think so, though
1 Like

Well, the fast ones are under 2ms, the slow ones are at 80ms. All less than a second but presumably you want longer runs. (Using the mean or min is fine.)

The fast ones (<2ms) are my Zygote here and yolhan’s Mooncake here. It would be interesting to have the times for these on your computer, to better compare to Jax times.

Agreed, startup times vary wildly, but I haven’t recorded them.

Finally made it work with Enzyme (I needed context, so I used Enzyme.autodiff directly).
We get 984.600 μs !!! :slight_smile: :slight_smile: :slight_smile: 6X over jax without Reactant, imagine with it
981 with fastmath

using Statistics; import Random
using DifferentiationInterface,Tullio,LinearAlgebra,BenchmarkTools
import Zygote, Enzyme, Mooncake

const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}
# Logistic sigmoid with a Float32 numerator: σ(x) = 1 / (1 + e^(-x)).
logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))

# Mean negative log-likelihood, written as a plain scalar loop over samples with no
# closures and no temporary arrays.
#
# Arguments:
#   Wmid, Wctx       - flattened parameters, viewed as (nvocab, ndim) matrices
#   nvocab, ndim     - matrix dimensions
#   tok_mid, tok_ctx - row indices into those matrices, one pair per sample
#   x                - 0/1 (e.g. Bool) label per sample
# Returns -(1/N) Σ_k [x_k log p_k + (1 - x_k) log(1 - p_k)], where
# p_k = σ(Wmid[tok_mid[k], :] ⋅ Wctx[tok_ctx[k], :]).
function loss(
	Wmid, Wctx,
	nvocab, ndim,
	tok_mid, tok_ctx, x
)
    Wmid = reshape(Wmid, nvocab, ndim)
    Wctx = reshape(Wctx, nvocab, ndim)
    # Typed accumulator keeps the loop type-stable in Float32.
    nll::Float32 = 0.0f0
    # @inbounds assumes every tok_mid[c] / tok_ctx[c] is a valid row index.
    @inbounds for c in eachindex(x)
        Xij = x[c]
        dotprod = dot(@view(Wmid[tok_mid[c], :]), @view(Wctx[tok_ctx[c], :]))
        pij = logistic_sigmoid(dotprod)  # computed once, used by both terms
        nll += Xij * log(pij)
        nll += (1 - Xij) * log(1 - pij)
    end
    # The loop accumulated the log-likelihood. The loss is its NEGATED mean, matching
    # the `-mean(...)` definitions used elsewhere, so the gradient sign agrees.
    return -nll / length(x)
end

# Build a random problem and return the gradient of `loss` w.r.t. the flattened
# `Wmid` parameters, calling Enzyme's reverse mode directly.
# NOTE(review): the `backend` argument is accepted but never used.
function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer, backend)
	ntrues = nsamples ÷ 2

	tok_mid = rand(rng, 1:nvocab, nsamples)
	tok_ctx = rand(rng, 1:nvocab, nsamples)
	# The first half of the labels are true, the rest false.
	x = [trues(ntrues); falses(nsamples - ntrues)]

	weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
	weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)

	# Zero-initialized shadows; Enzyme's reverse pass accumulates gradients into them.
	dweights_mid = Enzyme.make_zero(weights_mid)
    dweights_ctx = Enzyme.make_zero(weights_ctx)
	# @info "Computing gradient..." backend

    # weights_mid: Duplicated, so its gradient lands in dweights_mid.
    # weights_ctx: DuplicatedNoNeed. It gets a shadow even though only dweights_mid is
    #   returned, presumably to avoid the Const-activity error reported earlier.
    # Remaining arguments: Const (not differentiated).
    Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Duplicated(weights_mid, dweights_mid), Enzyme.DuplicatedNoNeed(weights_ctx, dweights_ctx), Enzyme.Const(nvocab), Enzyme.Const(ndim), Enzyme.Const(tok_mid), Enzyme.Const(tok_ctx), Enzyme.Const(x))
	dweights_mid
end

rng = Random.Xoshiro(5)
# bck = AutoMooncake(config=nothing)
# NOTE(review): `train` ignores its `backend` argument and calls Enzyme.autodiff itself.
bckenz = AutoEnzyme()
@btime train($rng, 100277, 100, 2,$bckenz )
9 Likes

What do you mean by “I needed context”? It would be useful for me to improve the DI wrappers.

I also made it work with DI and AutoEnzyme(mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), but it seems you managed without runtime activity.