Enzyme throws `alwaysinline mustprogress` when using GraphNeuralNetworks to mutate Graphs

I am trying to use GraphNeuralNetworks with Flux in a deep-RL-style problem. What I want to achieve with GraphNeuralNetworks is to obtain new edge weights for a modified graph that changes every epoch, and then run a simulation with the resulting graph. Conceptually, this is similar to generating a molecule, running a simulation with that molecule every epoch, and training the model to minimize the loss from that simulation.

I’ve debugged several type-instability issues in my simulation and loss calculations, but I can't seem to make any progress. I keep getting errors like:

ERROR: AssertionError: ; Function Attrs: alwaysinline mustprogress

define internal "enzymejl_parmtype"="4795081968" "enzymejl_parmtype_ref"="0" void @diffejulia__21_59758_inner.1({ { {} addrspace(10)* }, { [3 x {} addrspace(10)*], (more IR omitted)

The last call in the stack is always

 [1] create_abi_wrapper(enzymefn::LLVM.Function, TT::Type, rettype::Type, actualRetType::Type, Mode::Enzyme.API.CDerivativeMode, augmented::Ptr{…}, width::Int64, returnPrimal::Bool, shadow_init::Bool, world::UInt64, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{…})

I’m omitting the lowered IR included in the stacktrace. Here’s an MWE that reproduces this bug:

using Flux, GraphNeuralNetworks, SparseArrays, Random, Enzyme

# Dummy graph construction
# Number of nodes and dimensions
n = 100
D = 2

A = sprand(Float32, n, n, 0.2) # Adjacency matrix
X = rand(Float32, D, n) # Node features. Two features per node
g = GNNGraph(A; ndata = (x = X,))  # node features, accessed below as g.ndata.x

# Model definition
struct GraphPolicy
    gnn::GNNChain
    edge_predictor::DotDecoder
end

Flux.@layer GraphPolicy

function GraphPolicy(nin::Int, d::Int)
    return GraphPolicy(GNNChain(GCNConv(nin => d, tanh),
                                GCNConv(d => d, relu)),
                       DotDecoder())
end

function (model::GraphPolicy)(g::GNNGraph, x)
    x = model.gnn(g, x)
    new_weights = model.edge_predictor(g, x)
    proposed_A = Flux.ignore_derivatives() do
        # set_edge_weight returns a new graph; g itself is not mutated
        g_new = set_edge_weight(g, vec(new_weights))
        return adjacency_matrix(g_new)
    end
    return SparseMatrixCSC{eltype(new_weights)}(proposed_A)
end

# Dummy loss function: modifies X and Z in place and returns a scalar reward
# (a Float64 here, since dt = 0.01 is a Float64 literal)
function simulate_loss(X, Z, A, action; Nt=10, dt=0.01)
    Reward = 0.0f0
    
    # This is where the simulation loop would go
    for t in 1:Nt
        # Directly modify X and Z as step! would
        X .+= dt * randn(Float32, size(X))
        Z .+= dt * randn(Float32, size(Z))
        property_of_A = sum(A)  # Dummy operation on A
        # Reward depends on the old A, the new A, and node features
        Reward -= dt * property_of_A + sum(abs2, X) + sum(abs2, Z) + sum(abs2, A - action)
    end
    return Reward
end

# Setup model and optimizer
model = GraphPolicy(2, 50)
opt = Flux.setup(Adam(1.0f-3), model)
model_emb = WithGraph(model, g) |> Duplicated

# Dummy Z. Z is not part of the node features, but it's used in the simulation
Z = rand(Float32, n, 4)

# Training loop (should trigger the bug)
for epoch in 1:10
    loss, grads = Flux.withgradient(model_emb) do m
        prop_A = m(g.ndata.x) # Returns a SparseMatrix as simulate_loss expects
        simulate_loss(X, Z, A, prop_A)
    end
    Flux.update!(opt, model, grads[1])
end

In case the IR in the stacktrace is interesting:

entry:
  %"'de" = alloca double, align 8
  %1 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %1, align 8
  %2 = call {}*** @julia.get_pgcstack() #12, !dbg !122
  %3 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 13, !dbg !123
  br i1 false, label %err.i, label %ok.i, !dbg !123

err.i:                                            ; preds = %entry
  unreachable

ok.i:                                             ; preds = %entry
  %current_task1.i1 = getelementptr inbounds {}**, {}*** %2, i64 -14
  %4 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 8, !dbg !123
  %"'ip_phi" = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 9, !dbg !123
  %5 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 10, !dbg !123
  %6 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 5, !dbg !123
  %"'ip_phi1" = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 6, !dbg !123
  %7 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 7, !dbg !123
  %8 = bitcast {}*** %current_task1.i1 to {}*, !dbg !123
  %"'mi" = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 3, !dbg !123
  %9 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 4, !dbg !123
  %10 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 0, !dbg !123
  %"'ip_phi3" = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 1, !dbg !123
  %11 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 2, !dbg !123
  %12 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 14, !dbg !125
  br i1 false, label %err3.i, label %ok4.i, !dbg !125

err3.i:                                           ; preds = %ok.i
  unreachable

ok4.i:                                            ; preds = %ok.i
  %13 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 15, !dbg !125
  br i1 false, label %err5.i, label %ok6.i, !dbg !125

err5.i:                                           ; preds = %ok4.i
  unreachable

ok6.i:                                            ; preds = %ok4.i
  %14 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 16, !dbg !125
  br i1 false, label %err7.i, label %julia__23_59836_inner.exit, !dbg !125

err7.i:                                           ; preds = %ok6.i
  unreachable

julia__23_59836_inner.exit:                       ; preds = %ok6.i
  %15 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 11, !dbg !125
  %"'ip_phi7" = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 12, !dbg !125
  br i1 true, label %ret, label %fail, !dbg !122

ret:                                              ; preds = %julia__23_59836_inner.exit
  %"'ipc" = bitcast {} addrspace(10)* %"'ip_phi7" to double addrspace(10)*, !dbg !122
  %"'ipc13" = addrspacecast double addrspace(10)* %"'ipc" to double addrspace(11)*, !dbg !122
  br label %invertret, !dbg !122

fail:                                             ; preds = %julia__23_59836_inner.exit
  unreachable

invertentry:                                      ; preds = %invertok.i
  ret void

invertok.i:                                       ; preds = %invertok4.i
  %16 = call {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @ijl_apply_generic, {} addrspace(10)* @ejl_enz_runtime_generic_rev, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 4532595088 to {}*) to {} addrspace(10)*), {} addrspace(10)* @ejl_enz_val_false, {} addrspace(10)* @ejl_enz_val_false, {} addrspace(10)* @ejl_enz_val_1, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 14204484080 to {}*) to {} addrspace(10)*), {} addrspace(10)* %10, {} addrspace(10)* %9, {} addrspace(10)* %"'mi", {} addrspace(10)* %7, {} addrspace(10)* %"'ip_phi1"), !dbg !123
  %17 = call {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @ijl_apply_generic, {} addrspace(10)* @ejl_enz_runtime_generic_rev, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 14148829904 to {}*) to {} addrspace(10)*), {} addrspace(10)* @ejl_enz_val_false, {} addrspace(10)* @ejl_enz_val_false, {} addrspace(10)* @ejl_enz_val_1, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 14204088016 to {}*) to {} addrspace(10)*), {} addrspace(10)* %6, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 4758987648 to {}*) to {} addrspace(10)*), {} addrspace(10)* @ejl_jl_nothing, {} addrspace(10)* %5, {} addrspace(10)* %"'ip_phi", {} addrspace(10)* addrspacecast ({}* inttoptr (i64 4381468264 to {}*) to {} addrspace(10)*), {} addrspace(10)* @ejl_jl_nothing), !dbg !123
  %18 = call {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @ijl_apply_generic, {} addrspace(10)* @ejl_enz_runtime_generic_rev, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 14148830288 to {}*) to {} addrspace(10)*), {} addrspace(10)* @ejl_enz_val_false, {} addrspace(10)* @ejl_enz_val_false, {} addrspace(10)* @ejl_enz_val_1, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 14204088016 to {}*) to {} addrspace(10)*), {} addrspace(10)* %4, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 4758987648 to {}*) to {} addrspace(10)*), {} addrspace(10)* @ejl_jl_nothing, {} addrspace(10)* %3, {} addrspace(10)* @ejl_jl_nothing, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 6284084368 to {}*) to {} addrspace(10)*), {} addrspace(10)* @ejl_jl_nothing), !dbg !123
  br label %invertentry

invertok4.i:                                      ; preds = %invertok6.i
  br label %invertok.i

invertok6.i:                                      ; preds = %invertjulia__23_59836_inner.exit
  br label %invertok4.i

invertjulia__23_59836_inner.exit:                 ; preds = %invertret
  %19 = call {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @ijl_apply_generic, {} addrspace(10)* @ejl_enz_runtime_generic_rev, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 14148830736 to {}*) to {} addrspace(10)*), {} addrspace(10)* @ejl_enz_val_false, {} addrspace(10)* @ejl_enz_val_false, {} addrspace(10)* @ejl_enz_val_1, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 14204089144 to {}*) to {} addrspace(10)*), {} addrspace(10)* %15, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 6284086232 to {}*) to {} addrspace(10)*), {} addrspace(10)* @ejl_jl_nothing, {} addrspace(10)* %12, {} addrspace(10)* @ejl_jl_nothing, {} addrspace(10)* %13, {} addrspace(10)* @ejl_jl_nothing, {} addrspace(10)* %14, {} addrspace(10)* @ejl_jl_nothing, {} addrspace(10)* %11, {} addrspace(10)* %"'ip_phi3"), !dbg !125
  br label %invertok6.i

invertret:                                        ; preds = %ret
  store double %differeturn, double* %"'de", align 8
  %20 = load double, double* %"'de", align 8, !dbg !122
  store double 0.000000e+00, double* %"'de", align 8, !dbg !122
  %21 = atomicrmw fadd double addrspace(11)* %"'ipc13", double %20 monotonic, align 8, !dbg !122
  br label %invertjulia__23_59836_inner.exit
}

Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = Float64, literal_rt = Any, rettype = Active{Any}, sret_union=false, pactualRetType=Float64

Stacktrace:
  [1] create_abi_wrapper(enzymefn::LLVM.Function, TT::Type, rettype::Type, actualRetType::Type, Mode::Enzyme.API.CDerivativeMode, augmented::Ptr{…}, width::Int64, returnPrimal::Bool, shadow_init::Bool, world::UInt64, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{…})
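The assertion line reports `actualRetType = Float64` but `literal_rt = Any` and `rettype = Active{Any}`, which suggests the function being differentiated had its return type inferred only as `Any`. One plausible contributor (an assumption, not a confirmed diagnosis of the Enzyme bug) is that the closures passed to `Flux.withgradient` in both MWEs read the non-constant globals `g`, `X`, `Z` and `A`, whose types Julia cannot infer. The sketch below uses hypothetical stand-in names (`X_demo`, `loss_global`, `loss_captured`) and only `Base.return_types`, so it runs without Flux, GraphNeuralNetworks or Enzyme:

```julia
# Minimal sketch with hypothetical stand-in names; no Flux, GNN or Enzyme code.
# X_demo plays the role of the non-const globals g, X, Z and A in the MWEs.
X_demo = rand(Float32, 2, 3)

# Reads the untyped global, so inference cannot know what sum(abs2, X_demo) returns.
loss_global(m) = m + sum(abs2, X_demo)

# Same computation, but the data is captured by a closure created in a `let`
# block; the captured variable has a concrete type, so inference succeeds.
loss_captured = let X_demo = rand(Float32, 2, 3)
    m -> m + sum(abs2, X_demo)
end

# Inferred return types for a Float32 argument
rt_global = only(Base.return_types(loss_global, (Float32,)))
rt_captured = only(Base.return_types(loss_captured, (Float32,)))
println("global: ", rt_global, ", captured: ", rt_captured)
```

In the MWEs, the analogous change would be declaring the data `const` or building the loss closure inside a `let` block; whether that avoids the assertion is untested here.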

Just to confirm, what version of Enzyme are you using?

This error looks like one we fixed a while ago.

Sorry, I should have mentioned it from the get-go.

  [7da242da] Enzyme v0.13.56

I think it might be pinned by another package: I work on several computers, and whenever I switch machines and instantiate the environment, Pkg complains about unsatisfiable Enzyme versions.

Sorry for double-posting, but I found that the bug can even result in a segfault. I simplified the MWE even further:

using Flux, GraphNeuralNetworks, SparseArrays, Random, Enzyme

# Dummy graph construction
# Number of nodes and dimensions
n = 100
D = 2

A = sprand(Float32, n, n, 0.2) # Adjacency matrix
X = rand(Float32, D, n) # Node features. Two features per node
g = GNNGraph(A; ndata = (x = X,))  # node features, accessed below as g.ndata.x

# Model definition
struct GraphPolicy
    gnn::GNNChain
    edge_predictor::DotDecoder
end

Flux.@layer GraphPolicy

function GraphPolicy(nin::Int, d::Int)
    return GraphPolicy(GNNChain(GCNConv(nin => d, tanh),
                                GCNConv(d => d, relu)),
                       DotDecoder())
end

function (model::GraphPolicy)(g::GNNGraph, x)
    x = model.gnn(g, x)
    new_weights = model.edge_predictor(g, x)
    return new_weights
end

# Extremely simplified loss
function simulate_loss(model, X, A, g)
    Reward = 0.0f0
    weights = model(g, g.ndata.x)

    for t in 1:100
        # Reward -= sum(abs2, X) # With this line it throws alwaysinline mustprogress
        Reward -= rand() # With this line it segfaults
    end
    return Reward
end

# Setup model and optimizer
model = GraphPolicy(2, 50) |> Duplicated
opt = Flux.setup(Adam(1.0f-3), model)

# Training loop triggers the bug
for epoch in 1:10
    @show epoch
    loss, grads = Flux.withgradient(model -> simulate_loss(model, X, A, g), model)
    Flux.update!(opt, model, grads[1])
end

The simulate_loss function does nothing out of the ordinary: it computes the forward pass of the model (just a simple GNNChain, no embedded graph) and returns an accumulated floating-point reward.
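Separately from the Enzyme bug, the accumulator in this `simulate_loss` is not type-stable: `Reward` starts as `0.0f0` (a `Float32`), but `rand()` returns a `Float64`, so its type changes inside the loop (the first MWE does the same through `dt = 0.01`). The sketch below reproduces just that pattern; the function names are hypothetical, and the loop bound is a runtime argument so inference cannot assume the loop body runs:

```julia
# Minimal sketch with hypothetical names: the accumulation pattern from
# simulate_loss, without any model. The loop bound `nt` is a runtime argument,
# so inference cannot assume the loop body executes.
function accumulate_mixed(nt::Int)
    reward = 0.0f0             # starts as Float32
    for _ in 1:nt
        reward -= rand()       # rand() returns Float64, so reward is promoted
    end
    return reward              # Float32 if the loop never runs, else Float64
end

function accumulate_f32(nt::Int)
    reward = 0.0f0
    for _ in 1:nt
        reward -= rand(Float32)  # Float32 - Float32 stays Float32
    end
    return reward
end

rt_mixed = only(Base.return_types(accumulate_mixed, (Int,)))  # a Union
rt_f32 = only(Base.return_types(accumulate_f32, (Int,)))      # Float32
println("mixed: ", rt_mixed, ", Float32 only: ", rt_f32)
```

Using `rand(Float32)` (or `dt = 0.01f0` in the first MWE) keeps the accumulator `Float32` throughout; whether that changes the Enzyme behavior is not verified here.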

As the comments in the MWE say, this results in a segmentation fault when the line

Reward -= rand()

is used, while the other option for accumulating the reward

Reward -= sum(abs2, X)

doesn’t segfault but produces the original error:

ERROR: AssertionError: ; Function Attrs: alwaysinline mustprogress
define internal "enzymejl_parmtype"="140675577665680" "enzymejl_parmtype_ref"="0" void @diffejulia__5_35361_inner.1({ {} addrspace(10)* } "enzyme_type"="{[-1]:Pointer}" "enzymejl_parmtype"="140673312915600" "enzymejl_parmtype_ref"="0" %0, { {} addrspace(10)* } "enzyme_type"="{[-1]:Pointer}" "enzymejl_parmtype"="140673312915600" "enzymejl_parmtype_ref"="0" %"'", double %differeturn, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg) local_unnamed_addr #10 !dbg !109 {
entry:
  %"'de" = alloca double, align 8
  %1 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %1, align 8
  %2 = call {}*** @julia.get_pgcstack() #12, !dbg !110
  %current_task1.i1 = getelementptr inbounds {}**, {}*** %2, i64 -14
  %3 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 4, !dbg !111
  br i1 false, label %err.i, label %ok.i, !dbg !111

err.i:                                            ; preds = %entry
  unreachable

ok.i:                                             ; preds = %entry
  %4 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 5, !dbg !111
  br i1 false, label %err2.i, label %ok3.i, !dbg !111

err2.i:                                           ; preds = %ok.i
  unreachable

ok3.i:                                            ; preds = %ok.i
  %5 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 6, !dbg !111
  br i1 false, label %err4.i, label %julia__5_35361_inner.exit, !dbg !111

err4.i:                                           ; preds = %ok3.i
  unreachable

julia__5_35361_inner.exit:                        ; preds = %ok3.i
  %6 = bitcast {}*** %current_task1.i1 to {}*, !dbg !111
  %"'mi" = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 2, !dbg !111
  %7 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 3, !dbg !111
  %8 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 0, !dbg !111
  %"'ip_phi" = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* } %tapeArg, 1, !dbg !111
  br i1 true, label %ret, label %fail, !dbg !110

ret:                                              ; preds = %julia__5_35361_inner.exit
  %"'ipc" = bitcast {} addrspace(10)* %"'ip_phi" to double addrspace(10)*, !dbg !110
  %"'ipc7" = addrspacecast double addrspace(10)* %"'ipc" to double addrspace(11)*, !dbg !110
  br label %invertret, !dbg !110

fail:                                             ; preds = %julia__5_35361_inner.exit
  unreachable

invertentry:                                      ; preds = %invertok.i
  ret void

invertok.i:                                       ; preds = %invertok3.i
  br label %invertentry

invertok3.i:                                      ; preds = %invertjulia__5_35361_inner.exit
  br label %invertok.i

invertjulia__5_35361_inner.exit:                  ; preds = %invertret
  %9 = call {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @ijl_apply_generic, {} addrspace(10)* @ejl_enz_runtime_generic_rev, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140673208742096 to {}*) to {} addrspace(10)*), {} addrspace(10)* @ejl_enz_val_false, {} addrspace(10)* @ejl_enz_val_false, {} addrspace(10)* @ejl_enz_val_1, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140673540590544 to {}*) to {} addrspace(10)*), {} addrspace(10)* %8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140674423445368 to {}*) to {} addrspace(10)*), {} addrspace(10)* @ejl_jl_nothing, {} addrspace(10)* %7, {} addrspace(10)* %"'mi", {} addrspace(10)* %3, {} addrspace(10)* @ejl_jl_nothing, {} addrspace(10)* %4, {} addrspace(10)* @ejl_jl_nothing, {} addrspace(10)* %5, {} addrspace(10)* @ejl_jl_nothing), !dbg !111
  br label %invertok3.i

invertret:                                        ; preds = %ret
  store double %differeturn, double* %"'de", align 8
  %10 = load double, double* %"'de", align 8, !dbg !110
  store double 0.000000e+00, double* %"'de", align 8, !dbg !110
  %11 = atomicrmw fadd double addrspace(11)* %"'ipc7", double %10 monotonic, align 8, !dbg !110
  br label %invertjulia__5_35361_inner.exit
}

Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = Float64, literal_rt = Any, rettype = Active{Any}, sret_union=false, pactualRetType=Float64

Stacktrace:
  [1] create_abi_wrapper(enzymefn::LLVM.Function, TT::Type, rettype::Type, actualRetType::Type, Mode::Enzyme.API.CDerivativeMode, augmented::Ptr{…}, width::Int64, returnPrimal::Bool, shadow_init::Bool, world::UInt64, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{…})
    @ Enzyme.Compiler /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:2056
  [2] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:1799
  [3] compile_unhooked(output::Symbol, job::GPUCompiler.CompilerJob{…})
    @ Enzyme.Compiler /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:4819
  [4] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler /scratch/htc/amartine/julia/packages/GPUCompiler/Ecaql/src/driver.jl:67
  [5] compile
    @ /scratch/htc/amartine/julia/packages/GPUCompiler/Ecaql/src/driver.jl:55 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{…}, postopt::Bool)
    @ Enzyme.Compiler /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:5682
  [7] _thunk
    @ /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:5680 [inlined]
  [8] cached_compilation
    @ /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:5734 [inlined]
  [9] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, edges::Vector{…})
    @ Enzyme.Compiler /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:5848
 [10] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type, strongzero::Type)
    @ Enzyme.Compiler /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/compiler.jl:6041
 [11] autodiff_thunk
    @ /scratch/htc/amartine/julia/packages/Enzyme/rwbr4/src/Enzyme.jl:997 [inlined]
 [12] _enzyme_withgradient(f::Function, args::Duplicated{GraphPolicy}; zero::Bool)
    @ FluxEnzymeExt /scratch/htc/amartine/julia/packages/Flux/9PibT/ext/FluxEnzymeExt/FluxEnzymeExt.jl:74
 [13] _enzyme_withgradient
    @ /scratch/htc/amartine/julia/packages/Flux/9PibT/ext/FluxEnzymeExt/FluxEnzymeExt.jl:64 [inlined]
 [14] withgradient(f::Function, args::Duplicated{GraphPolicy})
    @ Flux /scratch/htc/amartine/julia/packages/Flux/9PibT/src/gradient.jl:226
 [15] top-level scope
    @ REPL[35]:3

I looked at the PR referenced in an earlier bug report, but I suspect this is a different bug, as I’m not using fill! or doing any type inference, AFAIK. Besides, I’m using version 0.13.59, and that PR fixed the issue in v0.12.2.

Should I open a bug report on Enzyme, or does this look like a bug in GraphNeuralNetworks, @wsmoses?

Yeah, go ahead and file an issue on Enzyme.jl.
