# AD with Enzyme.jl: "task switch not allowed from inside staged nor pure functions"

Hi all,

I am trying to use Enzyme.jl to compute a gradient. I want to use gradient descent to find the optimal parameters that reproduce the spectrum of a family of Hamiltonians (here just 2x2 Hermitian matrices). Here is my code (slimmed down to be standalone, hence the magic numbers everywhere):

```julia
using Enzyme

function H(kfrac, Δ)
    γ = 1 + exp(-2im*π * kfrac[1]) + exp(-2im*π * kfrac[2])
    [Δ  γ;
     γ' -Δ]
end

function eigvals2d(H)
    # only works for 2x2 matrices
    # @assert size(H) == (2, 2)
    # assume H is hermitian; if not, behavior is undefined
    a = H[1, 1]
    b = H[1, 2]
    d = H[2, 2]

    Δ = (d-a)^2 + 4abs2(b)

    (a+d)/2 .+ √Δ/2 * [1, -1]
end

function loss(Es, Δ)
    # compute eigenvalues with Δ
    kpath = [0.0  -0.00277778  -0.00555556  -0.00833333;
             0.0   0.00277778   0.00555556   0.00833333;
             0.0   0.0          0.0          0.0]

    # Espred = real.(hcat([eigvals2d(H(kpath[:, k], Δ)) for k in 1:4]...))
    Espred = zeros(2, 4)
    for k in 1:4
        Espred[:, k] = eigvals2d(H(kpath[:, k], Δ))
    end

    # return mean square error
    err = reshape(Es, :) - reshape(Espred, :)
    sum(err .* err) / length(err)
end

Δ = 0.5
Es = [ 3.0413812651491097   3.0410808004168115   3.0401795026932263   3.0386776613775344;
      -3.0413812651491097  -3.0410808004168115  -3.0401795026932263  -3.0386776613775344]

@show loss(Es, Δ)
autodiff(ReverseWithPrimal, loss, Active, Const(Es), Active(Δ))
```
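
For reference, `eigvals2d` hardcodes the closed-form spectrum of a 2x2 Hermitian matrix $\begin{pmatrix} a & b \\ \bar{b} & d \end{pmatrix}$:

$$
\lambda_\pm = \frac{a+d}{2} \pm \frac{1}{2}\sqrt{(d-a)^2 + 4\lvert b\rvert^2},
$$

which is what `(a+d)/2 .+ √Δ/2 * [1, -1]` evaluates, with the local `Δ` holding the discriminant.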


Here is the error I get:

```
task switch not allowed from inside staged nor pure functions

Stacktrace:
  [1] try_yieldto(undo::typeof(Base.ensure_rescheduled))
  [2] wait()
  [3] uv_write(s::Base.PipeEndpoint, p::Ptr{UInt8}, n::UInt64)
    @ Base ./stream.jl:992
  [4] unsafe_write(s::Base.PipeEndpoint, p::Ptr{UInt8}, n::UInt64)
    @ Base ./stream.jl:1064
  [5] unsafe_write
    @ ./io.jl:362 [inlined]
  [6] write
    @ ./strings/io.jl:244 [inlined]
  [7] print
    @ ./strings/io.jl:246 [inlined]
  [8] print(::IJulia.IJuliaStdio{Base.PipeEndpoint}, ::String, ::String, ::Vararg{String})
    @ Base ./strings/io.jl:46
  [9] println(::IJulia.IJuliaStdio{Base.PipeEndpoint}, ::String, ::Vararg{String})
    @ Base ./strings/io.jl:75
 [10] println(xs::String)
    @ Base ./coreio.jl:4
 [11] nodecayed_phis!(mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler/optimize.jl:205
 [12] optimize!
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler/optimize.jl:1137 [inlined]
 [13] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9262
 [14] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool) (repeats 2 times)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9830
 [15] cached_compilation
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9864 [inlined]
 [16] (::Enzyme.Compiler.var"#474#475"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9921
 [17] JuliaContext(f::Enzyme.Compiler.var"#474#475"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kwkKA/src/driver.jl:58
 [18] #s289#473
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9882 [inlined]
 [19] var"#s289#473"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [20] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
 [21] autodiff(::ReverseMode{true, FFIABI}, ::Const{typeof(loss)}, ::Type{Active}, ::Const{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:207
 [22] autodiff(::ReverseMode{true, FFIABI}, ::typeof(loss), ::Type, ::Const{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:222
 [23] top-level scope
    @ In[1]:44
```


One more strange thing: I don't get the same error message when running in a JupyterLab notebook versus directly in the REPL. The error above is from the notebook; here is the REPL one:

```
ERROR: AssertionError: value_type(v) == nty
Stacktrace:
  [1] nodecayed_phis!(mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler/optimize.jl:208
  [2] optimize!
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler/optimize.jl:1137 [inlined]
  [3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9262
  [4] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool) (repeats 2 times)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9830
  [5] cached_compilation
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9864 [inlined]
  [6] (::Enzyme.Compiler.var"#474#475"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9921
  [7] JuliaContext(f::Enzyme.Compiler.var"#474#475"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kwkKA/src/driver.jl:58
  [8] #s289#473
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9882 [inlined]
  [9] var"#s289#473"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
 [11] autodiff(::ReverseMode{true, FFIABI}, ::Const{typeof(loss)}, ::Type{Active}, ::Const{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:207
 [12] autodiff(::ReverseMode{true, FFIABI}, ::typeof(loss), ::Type, ::Const{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:222
 [13] top-level scope
    @ REPL[9]:1
```


Do you know how to solve this?
I could not find much information on how Enzyme handles more complex AD cases, such as mutable data and array variables. As far as I understood, extra care (using `Duplicated` instead of `Active`) is required only for non-constant arrays, and the rest can be left as is.
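
For instance, here is a minimal sketch of that understanding (assuming the same `loss` as above): if `Es` were itself an active input, my reading is that it would be passed as `Duplicated` with an explicit shadow buffer:

```julia
# Sketch only: Es as an active array input instead of Const.
# The shadow dEs is zero-initialized; reverse mode accumulates ∂loss/∂Es into it.
dEs = zero(Es)
autodiff(ReverseWithPrimal, loss, Active, Duplicated(Es, dEs), Active(Δ))
# After the call, dEs holds the gradient with respect to Es; the derivative
# for Active(Δ) is returned in the derivatives tuple as before.
```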

I didn't paste the full standard output, since a lot of assembly (or LLVM IR?) gets dumped and it's too large for a post.

Thank you for making the MWE standalone!

Could you open an issue on [EnzymeAD/Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)? This is a bug inside Enzyme itself.

Here is the issue for reference: ["task switch not allowed from inside staged nor pure functions" · Issue #1172 · EnzymeAD/Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl/issues/1172)

FastDifferentiation.jl might work for you. I believe this code computes the derivative you want, if I'm interpreting your problem correctly. I assumed `Es` was a variable input rather than a constant, and that you only need the derivative with respect to Δ.

```julia
module Hamiltonian
using FastDifferentiation

function H(kfrac, Δ)
    γ = 1 + exp(-2im * π * kfrac[1]) + exp(-2im * π * kfrac[2])
    [Δ  γ;
     γ' -Δ]
end

function eigvals2d(H)
    # only works for 2x2 matrices
    # @assert size(H) == (2, 2)
    # assume H is hermitian; if not, behavior is undefined
    a = H[1, 1]
    b = H[1, 2]
    d = H[2, 2]

    Δ = (d - a)^2 + 4abs2(b)

    (a + d) / 2 .+ √Δ / 2 * [1, -1]
end

function loss(Es, Δ)
    # compute eigenvalues with Δ
    kpath = [0.0  -0.00277778  -0.00555556  -0.00833333;
             0.0   0.00277778   0.00555556   0.00833333;
             0.0   0.0          0.0          0.0]

    # Espred = real.(hcat([eigvals2d(H(kpath[:, k], Δ)) for k in 1:4]...))
    # widen the element type so symbolic FastDifferentiation nodes fit
    Espred = zeros(typeof(Δ), 2, 4)
    for k in 1:4
        Espred[:, k] = eigvals2d(H(kpath[:, k], Δ))
    end

    # return mean square error
    err = reshape(Es, :) - reshape(Espred, :)
    sum(err .* err) / length(err)
end

function test_deriv()
    Es_vals = [ 3.0413812651491097   3.0410808004168115   3.0401795026932263   3.0386776613775344;
               -3.0413812651491097  -3.0410808004168115  -3.0401795026932263  -3.0386776613775344]

    Es = make_variables(:Es, 2, 4)
    @variables Δ

    jac = jacobian([loss(Es, Δ)], [Δ])
    fjac = make_function(jac, [vec(Es)..., Δ])
    # fjac = make_function(jac, [vec(Es)..., Δ], in_place=true) # twice as fast
    fjac([vec(Es_vals)..., 0.5])
end
export test_deriv

end # module Hamiltonian
```
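
If you paste the module into your session, a usage sketch looks like this:

```julia
# Usage sketch: build and evaluate the Jacobian at Δ = 0.5 with the sample Es.
using .Hamiltonian
jac_val = test_deriv()  # should be a 1×1 array holding ∂loss/∂Δ
```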
