AD with Enzyme.jl: "task switch not allowed from inside staged nor pure functions"

Hi all,

I am trying to use Enzyme.jl to compute a gradient. I want to use gradient descent to find the optimal parameters to reproduce the spectrum of a family of Hamiltonians (i.e. just 2x2 Hermitian matrices). Here is my code (slimmed down to be standalone, hence the magic numbers everywhere):

using Enzyme

function H(kfrac, Δ)
    γ = 1 + exp(-2im*π * kfrac[1]) + exp(-2im*π * kfrac[2])
    [Δ γ;
     γ' -Δ]
end

function eigvals2d(H)
    # only works for 2x2 matrices
    # @assert size(H) == (2, 2)
    # assume H is Hermitian; otherwise the result is undefined
    a = H[1, 1]
    b = H[1, 2]
    d = H[2, 2]
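
    # closed-form eigenvalues of the Hermitian matrix [a b; b' d]:
    #   λ = (a + d)/2 ± √((d - a)^2 + 4|b|^2) / 2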

    Δ = (d-a)^2 + 4abs2(b)

    (a+d)/2 .+ √Δ/2 * [1, -1]
end

function loss(Es, Δ)
    # compute eigenvalues with Δ
    kpath = [0.0  -0.00277778  -0.00555556  -0.00833333;
             0.0   0.00277778   0.00555556   0.00833333;
             0.0   0.0          0.0          0.0]
    
    # Espred = real.(hcat([eigvals2d(H(kpath[:, k], Δ)) for k in 1:4]...))
    Espred = zeros(2, 4)
    for k in 1:4
        Espred[:, k] = eigvals2d(H(kpath[:, k], Δ))
    end

    # return mean square error
    err = reshape(Es, :) - reshape(Espred, :)
    sum(err .* err) / length(err)
end

Δ = 0.5
Es = [3.0413812651491097 3.0410808004168115 3.0401795026932263 3.0386776613775344; 
    -3.0413812651491097 -3.0410808004168115 -3.0401795026932263 -3.0386776613775344]

@show loss(Es, Δ)
autodiff(ReverseWithPrimal, loss, Active, Const(Es), Active(Δ))

Here is the error I get:

task switch not allowed from inside staged nor pure functions

Stacktrace:
  [1] try_yieldto(undo::typeof(Base.ensure_rescheduled))
    @ Base ./task.jl:767
  [2] wait()
    @ Base ./task.jl:837
  [3] uv_write(s::Base.PipeEndpoint, p::Ptr{UInt8}, n::UInt64)
    @ Base ./stream.jl:992
  [4] unsafe_write(s::Base.PipeEndpoint, p::Ptr{UInt8}, n::UInt64)
    @ Base ./stream.jl:1064
  [5] unsafe_write
    @ ./io.jl:362 [inlined]
  [6] write
    @ ./strings/io.jl:244 [inlined]
  [7] print
    @ ./strings/io.jl:246 [inlined]
  [8] print(::IJulia.IJuliaStdio{Base.PipeEndpoint}, ::String, ::String, ::Vararg{String})
    @ Base ./strings/io.jl:46
  [9] println(::IJulia.IJuliaStdio{Base.PipeEndpoint}, ::String, ::Vararg{String})
    @ Base ./strings/io.jl:75
 [10] println(xs::String)
    @ Base ./coreio.jl:4
 [11] nodecayed_phis!(mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler/optimize.jl:205
 [12] optimize!
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler/optimize.jl:1137 [inlined]
 [13] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9262
 [14] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool) (repeats 2 times)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9830
 [15] cached_compilation
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9864 [inlined]
 [16] (::Enzyme.Compiler.var"#474#475"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9921
 [17] JuliaContext(f::Enzyme.Compiler.var"#474#475"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kwkKA/src/driver.jl:58
 [18] #s289#473
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9882 [inlined]
 [19] var"#s289#473"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [20] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
 [21] autodiff(::ReverseMode{true, FFIABI}, ::Const{typeof(loss)}, ::Type{Active}, ::Const{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:207
 [22] autodiff(::ReverseMode{true, FFIABI}, ::typeof(loss), ::Type, ::Const{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:222
 [23] top-level scope
    @ In[1]:44

One more strange thing: I don’t get the same error message when I run this in a JupyterLab notebook and when I run it directly in the REPL. The message above is from the notebook; here is the one from the REPL:

ERROR: AssertionError: value_type(v) == nty
Stacktrace:
  [1] nodecayed_phis!(mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler/optimize.jl:208
  [2] optimize!
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler/optimize.jl:1137 [inlined]
  [3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9262
  [4] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool) (repeats 2 times)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9830
  [5] cached_compilation
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9864 [inlined]
  [6] (::Enzyme.Compiler.var"#474#475"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9921
  [7] JuliaContext(f::Enzyme.Compiler.var"#474#475"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kwkKA/src/driver.jl:58
  [8] #s289#473
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9882 [inlined]
  [9] var"#s289#473"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
 [11] autodiff(::ReverseMode{true, FFIABI}, ::Const{typeof(loss)}, ::Type{Active}, ::Const{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:207
 [12] autodiff(::ReverseMode{true, FFIABI}, ::typeof(loss), ::Type, ::Const{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:222
 [13] top-level scope
    @ REPL[9]:1

Do you know how to solve this?
I could not find much information on how Enzyme handles more complex AD situations, such as mutable data and array arguments. As far as I understood, extra care (using Duplicated instead of Active) is only required for non-constant arrays, and the rest can be left as is.
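To make sure I have the annotations right, here is the kind of toy call I had in mind (my own sketch, not taken from the code above): a non-constant array argument gets a Duplicated shadow that receives its gradient, while the scalar parameter stays Active.

using Enzyme

g(x, p) = sum(x .^ 2) * p       # x: array input, p: scalar parameter

x  = [1.0, 2.0, 3.0]
dx = zero(x)                    # shadow array, filled in place with ∂g/∂x
p  = 0.5
res = autodiff(Reverse, g, Active, Duplicated(x, dx), Active(p))
# res holds the derivative with respect to the Active argument p; dx now contains ∂g/∂x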

I didn’t post the full standard output I get, since a lot of assembly (or LLVM IR?) is dumped, and it’s too large for a post.

Thank you for making the MWE standalone!

Could you open an issue on the EnzymeAD/Enzyme.jl GitHub repository? This is a bug inside Enzyme itself.

Here is the issue for reference: EnzymeAD/Enzyme.jl issue #1172, "task switch not allowed from inside staged nor pure functions".

FastDifferentiation.jl might work for you. I believe this code computes the derivative you want, if I’m interpreting your problem correctly. I assumed Es was not constant but was a variable input, and that you were only computing the derivative with respect to Δ.

module Hamiltonian
using FastDifferentiation

function H(kfrac, Δ)
    γ = 1 + exp(-2im * π * kfrac[1]) + exp(-2im * π * kfrac[2])
    [Δ γ;
        γ' -Δ]
end

function eigvals2d(H)
    # only works for 2x2 matrices
    # @assert size(H) == (2, 2)
    # assume H is Hermitian; otherwise the result is undefined
    a = H[1, 1]
    b = H[1, 2]
    d = H[2, 2]

    Δ = (d - a)^2 + 4abs2(b)

    (a + d) / 2 .+ √Δ / 2 * [1, -1]
end

function loss(Es, Δ)
    # compute eigenvalues with Δ
    kpath = [0.0 -0.00277778 -0.00555556 -0.00833333;
        0.0 0.00277778 0.00555556 0.00833333;
        0.0 0.0 0.0 0.0]

    # Espred = real.(hcat([eigvals2d(H(kpath[:, k], Δ)) for k in 1:4]...))
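    # element type typeof(Δ) lets Espred hold symbolic FastDifferentiation nodes instead of Float64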
    Espred = zeros(typeof(Δ), 2, 4)
    for k in 1:4
        Espred[:, k] = eigvals2d(H(kpath[:, k], Δ))
    end

    # return mean square error
    err = reshape(Es, :) - reshape(Espred, :)
    sum(err .* err) / length(err)
end

function test_deriv()
    Es_vals = [3.0413812651491097 3.0410808004168115 3.0401795026932263 3.0386776613775344;
        -3.0413812651491097 -3.0410808004168115 -3.0401795026932263 -3.0386776613775344]

    Es = make_variables(:Es, 2, 4)
    @variables Δ

    jac = jacobian([loss(Es, Δ)], [Δ])
    fjac = make_function(jac, [vec(Es)..., Δ])
    # fjac = make_function(jac, [vec(Es)..., Δ], in_place=true) #twice as fast
    fjac([vec(Es_vals)..., 0.5])
end
export test_deriv

end # module Hamiltonian
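
If I evaluate the module (e.g. by pasting it into the REPL), usage would be something like the following sketch; the result should be the 1×1 Jacobian d(loss)/dΔ evaluated at Es_vals and Δ = 0.5:

using .Hamiltonian
test_deriv()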