Make mutating function more AD-friendly

Dear all,
Since this is my first post, I want to start by thanking you all for the help I have received over the many years I have used Julia (I have always found the answer to my problems in previous posts!).

Diving into the problem: what I need to do is compute the quantity

\mathcal{L}(\theta; \mathbf{x}, \mathbf{y}) = \sum_i \left[\log(1 + e^{\eta_i(\mathbf{x}, \mathbf{y}, \theta)}) + \eta_i(\mathbf{x}, \mathbf{y}, \theta)\, y_i\right]

where

\mathbf{\eta}(\mathbf{x}, \mathbf{y}, \theta) = \mathbf{h}(\mathbf{x}, \theta) + \mathbf{W}(\mathbf{x}, \theta)\mathbf{y}

with \mathbf{y}\in\{0,1\}^N, \mathbf{x}\in\mathbb{R}^M, \mathbf{h}\in\mathbb{R}^N, \mathbf{\eta}\in\mathbb{R}^N, \mathbf{W}\in\mathbb{R}^{N\times N} (but very sparse), and \theta\in\mathbb{R}^P.
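
Just to make the notation concrete, a naive dense implementation of this quantity would look like the sketch below (illustration only, assuming h and W have already been evaluated at the given x and θ; this is not the code I actually use):

using LogExpFunctions: log1pexp  # numerically stable log(1 + exp(x))

# Naive dense sketch of the pseudo-log-likelihood sum (allocates temporaries)
function pll_naive(h::AbstractVector, W::AbstractMatrix, y::AbstractVector{Bool})
    η = h .+ W * y                      # local conditional fields
    return sum(log1pexp.(η) .+ η .* y)  # sum_i [log(1 + exp(η_i)) + η_i * y_i]
end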

Since N can be of the order of 10^5 and I need to compute \mathcal{L} many times, I decided to store \mathbf{h}, \mathbf{W}, and the vector \mathbf{w} of the nonzero elements of the matrix in a cache struct. My implementation of \mathcal{L} looks like this:

function npll!(
        mem::AbstractIsingCache,
        rsp::AbstractCovariatesResponseModel,
        y::AbstractVector{Bool},
        p::ComponentArray,
        x::AbstractVector{<:Real}
    )
    # Get references to cache elements
    h = deposit_vector(mem)
    w = crosstalk_vector(mem)
    W = crosstalk_matrix(mem)

    # Update cache with the response for this specific data point
    deposit!(h, rsp, p, x)
    crosstalk!(w, rsp, p, x)
    update_crosstalk_matrix!(mem)

    #################################################
    # Compute local conditional fields `η` in-place  #
    #################################################
    # The `h` vector is overwritten by `deposit!`, then we add the crosstalk
    # influence from neighbors `W*y`. `h` now holds the local fields `η`.
    Tv = eltype(h)
    mul!(h, W, y, one(Tv), one(Tv))

    ########################################################
    # Compute the final negative pseudo-log-likelihood sum #
    ########################################################
    # `h` here is equivalent to `η` from the formula.
    @tullio n_p_ll := log1pexp(h[i]) + h[i] * y[i]
    #n_p_ll = sum(log1pexp.(h) .+ h .* y)
    return -n_p_ll
end

deposit! and crosstalk! are just linear-algebra operations on the parameters and the x vector, where I use @tullio to contract indices.
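
For illustration, the contractions inside those functions have this shape (a toy example with made-up sizes, not my actual model code):

using Tullio

A = rand(5, 3); x = rand(3); h = zeros(5)
@tullio h[i] += A[i, j] * x[j]  # in-place index contraction, equivalent to h .+= A * x
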
Every update function is tested and works as expected. Now, the issue arises when I try to use DifferentiationInterface to compute the gradient with respect to the parameters. Playing with backends, I found that only Mooncake is able to compute the gradient:


julia> using PIXIE #my package

julia> using DifferentiationInterface

julia> using SparseArrays

julia> s = RectangularPixel(1000,10)
RectangularPixel(1000, 10, :rook)

julia> r = SimpleResponse(s, num_covariates=4)
SimpleResponse(4, 10000, 18990)

julia> c = ConditionalIsing(s, r) # constructor for my model
ConditionalIsing{PIXIE.IsingCachePool{PIXIE.IsingCacheCPU{Float64, Int64}}}(RectangularPixel(1000, 10, :rook), SimpleResponse(4, 10000, 18990), PIXIE.IsingCachePool{PIXIE.IsingCacheCPU{Float64, Int64}}(Channel{PIXIE.IsingCacheCPU{Float64, Int64}}(16), 16))

julia> p = PIXIE.init_parameters(r)
ComponentVector{Float64}(deposit = (bias = [...], linear = [...], quadratic = [...]), crosstalk = (bias = [...], linear = [...]))


julia> cc = take!(PIXIE.pool(c)) 
PIXIE.IsingCacheCPU{Float64, Int64}([...], [...], [...], PIXIE.CrosstalkEdgesHSCDS{Int64}([1, 10], [1, 10000, 19990], [...]))


julia> y = sprand(Bool, 10_000, 0.001)
10000-element SparseVector{Bool, Int64} with 11 stored entries:
  [175 ]  =  1
  [2224]  =  1
  [2527]  =  1
  [3617]  =  1
  [3684]  =  1
  [4060]  =  1
  [4908]  =  1
  [6762]  =  1
  [6890]  =  1
  [7340]  =  1
  [8740]  =  1

julia> x = rand(4)
4-element Vector{Float64}:
 0.3524710528059386
 0.3988983273927492
 0.6920326598241286
 0.7996804030096054

julia> f(p, mem, rsp, y, x) = PIXIE.npll!(mem, rsp, y, p, x)
f (generic function with 1 method)

julia> import Mooncake

julia> backend = AutoMooncake()
AutoMooncake()

julia> gradient(f, backend, p, Cache(cc), Constant(r), Constant(y), Constant(x))
Mooncake.Tangent{@NamedTuple{data::Vector{Float64}, axes::Mooncake.NoTangent}}((data = [-0.7969185840561417, -0.11077871561171315, -0.7656696626602588, -0.2705008115642961, -0.40970563708520497, -0.10346557289376614, -0.4827844568393658, -0.5948295114134357, -0.44043476113139235, -0.546995154664162  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], axes = Mooncake.NoTangent()))

But if I try with Enzyme, the computation fails:

julia> import Enzyme

julia> backend = AutoEnzyme()
AutoEnzyme()

julia> gradient(f, backend, p, Cache(cc), Constant(r), Constant(y), Constant(x))
ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
 Failure within method: MethodInstance for PIXIE.deposit!(::Vector{…}, ::SimpleResponse, ::ComponentArrays.ComponentVector{…}, ::Vector{…})
Hint: catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.
If you have Cthulu.jl loaded you can also use `code_typed(err; interactive = true)` to interactively introspect the code.
Mismatched activity for:   store {} addrspace(10)* %3, {} addrspace(10)* addrspace(11)* %.fca.2.gep59, align 8, !dbg !242, !noalias !194 const val: {} addrspace(10)* %3
 value=Unknown object of type Vector{Float64}
 llvalue={} addrspace(10)* %3

Stacktrace:
 [1] tile_halves
   @ ~/.julia/packages/Tullio/2zyFP/src/threads.jl:136
 [2] threader
   @ ~/.julia/packages/Tullio/2zyFP/src/threads.jl:65
 [3] macro expansion
   @ ~/.julia/packages/Tullio/2zyFP/src/macro.jl:1004
 [4] deposit!
   @ ~/Documents/PhD/Alignment/PIXIE/src/covariates_response_models.jl:127

Stacktrace:
  [1] tile_halves
    @ ~/.julia/packages/Tullio/2zyFP/src/threads.jl:136 [inlined]
  [2] threader
    @ ~/.julia/packages/Tullio/2zyFP/src/threads.jl:65 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Tullio/2zyFP/src/macro.jl:1004 [inlined]
  [4] deposit!
    @ ~/Documents/PhD/Alignment/PIXIE/src/covariates_response_models.jl:127
  [5] npll!
    @ ~/Documents/PhD/Alignment/PIXIE/src/pseudologlikelihood.jl:40
  [6] f
    @ ./REPL[12]:1 [inlined]
  [7] f
    @ ./REPL[12]:0 [inlined]
  [8] diffejulia_f_61625_inner_71wrap
    @ ./REPL[12]:0
  [9] macro expansion
    @ ~/.julia/packages/Enzyme/eJcor/src/compiler.jl:5875 [inlined]
 [10] enzyme_call
    @ ~/.julia/packages/Enzyme/eJcor/src/compiler.jl:5409 [inlined]
 [11] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/eJcor/src/compiler.jl:5295 [inlined]
 [12] autodiff
    @ ~/.julia/packages/Enzyme/eJcor/src/Enzyme.jl:521 [inlined]
 [13] autodiff
    @ ~/.julia/packages/Enzyme/eJcor/src/Enzyme.jl:542 [inlined]
 [14] gradient(::typeof(f), ::DifferentiationInterfaceEnzymeExt.EnzymeGradientPrep{…}, ::AutoEnzyme{…}, ::ComponentArrays.ComponentVector{…}, ::Cache{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
    @ DifferentiationInterfaceEnzymeExt ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl:303
 [15] gradient(::typeof(f), ::AutoEnzyme{…}, ::ComponentArrays.ComponentVector{…}, ::Cache{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/zJHX8/src/first_order/gradient.jl:63
 [16] top-level scope
    @ REPL[18]:1
Some type information was truncated. Use `show(err)` to see complete types.

julia> try 
           gradient(f, backend, p, Cache(cc), Constant(r), Constant(y), Constant(x))
       catch err
           code_typed(err)
       end
1-element Vector{Any}:
 CodeInfo(
1 ─── %1   = Base.getfield(h, :size)::Tuple{Int64}
β”‚     %2   = Base.getfield(R, :num_nodes)::Int64
β”‚     %3   = $(Expr(:boundscheck, true))::Bool
β”‚     %4   = Base.getfield(%1, 1, %3)::Int64
β”‚     %5   = (%4 === %2)::Bool
β”‚     %6   = (%5 === false)::Bool
└────        goto #3 if not %6
2 ───        goto #4
3 ───        goto #4
4 ┄── %10  = Ο† (#2 => false, #3 => true)::Bool
└────        goto #5
5 ─── %12  = Base.not_int(%10)::Bool
└────        goto #6
6 ───        goto #8 if not %12
7 ─── %15  = invoke PIXIE.DimensionMismatch("h vector size does not match response model"::String)::DimensionMismatch
β”‚            PIXIE.throw(%15)::Union{}
└────        unreachable
8 ─── %18  = ComponentArrays.getfield(p, :data)::Vector{Float64}
β”‚     %19  = $(Expr(:boundscheck, true))::Bool
└────        goto #13 if not %19
9 ─── %21  = Base.getfield(%18, :size)::Tuple{Int64}
β”‚     %22  = $(Expr(:boundscheck, true))::Bool
β”‚     %23  = Base.getfield(%21, 1, %22)::Int64
β”‚     %24  = Base.bitcast(UInt64, %23)::UInt64
β”‚     %25  = Base.ult_int(0x0000000000000000, %24)::Bool
β”‚     %26  = Base.bitcast(UInt64, %23)::UInt64
β”‚     %27  = Base.ult_int(0x000000000003344f, %26)::Bool
β”‚     %28  = Base.and_int(%25, %27)::Bool
β”‚     %29  = Base.or_int(false, %28)::Bool
└────        goto #11 if not %29
10 ──        goto #12
11 ──        invoke Base.throw_boundserror(%18::Vector{Float64}, (1:210000,)::Tuple{UnitRange{Int64}})::Union{}
└────        unreachable
12 ──        nothing::Nothing
13 ┄─        goto #14
14 ──        goto #15
15 ──        goto #16
16 ──        goto #17
17 ── %39  = $(Expr(:boundscheck, true))::Bool
└────        goto #19 if not %39
18 ──        nothing::Nothing
19 ┄─ %42  = %new(SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, %18, (1:10000,), 0, 1)::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}
└────        goto #20
20 ──        goto #21
21 ──        goto #22
22 ──        goto #23
23 ── %47  = ComponentArrays.getfield(p, :data)::Vector{Float64}
β”‚     %48  = $(Expr(:boundscheck, true))::Bool
└────        goto #28 if not %48
24 ── %50  = Base.getfield(%47, :size)::Tuple{Int64}
β”‚     %51  = $(Expr(:boundscheck, true))::Bool
β”‚     %52  = Base.getfield(%50, 1, %51)::Int64
β”‚     %53  = Base.bitcast(UInt64, %52)::UInt64
β”‚     %54  = Base.ult_int(0x0000000000000000, %53)::Bool
β”‚     %55  = Base.bitcast(UInt64, %52)::UInt64
β”‚     %56  = Base.ult_int(0x000000000003344f, %55)::Bool
β”‚     %57  = Base.and_int(%54, %56)::Bool
β”‚     %58  = Base.or_int(false, %57)::Bool
└────        goto #26 if not %58
25 ──        goto #27
26 ──        invoke Base.throw_boundserror(%47::Vector{Float64}, (1:210000,)::Tuple{UnitRange{Int64}})::Union{}
└────        unreachable
27 ──        nothing::Nothing
28 ┄─        goto #29
29 ──        goto #30
30 ──        goto #31
31 ──        goto #32
32 ── %68  = $(Expr(:boundscheck, true))::Bool
└────        goto #34 if not %68
33 ──        nothing::Nothing
34 ┄─ %71  = %new(SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, %47, (10001:50000,), 10000, 1)::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}
└────        goto #35
35 ──        goto #36
36 ── %74  = %new(Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, %71, (10000, 4), ())::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}
└────        goto #37
37 ──        goto #38
38 ── %77  = ComponentArrays.getfield(p, :data)::Vector{Float64}
β”‚     %78  = $(Expr(:boundscheck, true))::Bool
└────        goto #43 if not %78
39 ── %80  = Base.getfield(%77, :size)::Tuple{Int64}
β”‚     %81  = $(Expr(:boundscheck, true))::Bool
β”‚     %82  = Base.getfield(%80, 1, %81)::Int64
β”‚     %83  = Base.bitcast(UInt64, %82)::UInt64
β”‚     %84  = Base.ult_int(0x0000000000000000, %83)::Bool
β”‚     %85  = Base.bitcast(UInt64, %82)::UInt64
β”‚     %86  = Base.ult_int(0x000000000003344f, %85)::Bool
β”‚     %87  = Base.and_int(%84, %86)::Bool
β”‚     %88  = Base.or_int(false, %87)::Bool
└────        goto #41 if not %88
40 ──        goto #42
41 ──        invoke Base.throw_boundserror(%77::Vector{Float64}, (1:210000,)::Tuple{UnitRange{Int64}})::Union{}
└────        unreachable
42 ──        nothing::Nothing
43 ┄─        goto #44
44 ──        goto #45
45 ──        goto #46
46 ──        goto #47
47 ── %98  = $(Expr(:boundscheck, true))::Bool
└────        goto #49 if not %98
48 ──        nothing::Nothing
49 ┄─ %101 = %new(SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, %77, (50001:210000,), 50000, 1)::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}
└────        goto #50
50 ──        goto #51
51 ── %104 = %new(Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, %101, (10000, 4, 4), ())::Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}
└────        goto #52
52 ──        goto #53
53 ── %107 = Base.getfield(h, :size)::Tuple{Int64}
β”‚     %108 = $(Expr(:boundscheck, true))::Bool
β”‚     %109 = Base.getfield(%107, 1, %108)::Int64
β”‚     %110 = %new(Base.OneTo{Int64}, %109)::Base.OneTo{Int64}
β”‚     %111 = Core.tuple(%110)::Tuple{Base.OneTo{Int64}}
β”‚     %112 = (%109 === 10000)::Bool
└────        goto #55 if not %112
54 ──        goto #56
55 ──        goto #56
56 ┄─ %116 = Ο† (#54 => %112, #55 => false)::Bool
└────        goto #58 if not %116
57 ──        goto #59
58 ── %119 = invoke Base.Broadcast.DimensionMismatch("array could not be broadcast to match destination"::String)::DimensionMismatch
β”‚            Base.Broadcast.throw(%119)::Union{}
└────        unreachable
59 ──        goto #60
60 ──        goto #61
61 ── %124 = Base.getfield(h, :size)::Tuple{Int64}
β”‚     %125 = $(Expr(:boundscheck, true))::Bool
β”‚     %126 = Base.getfield(%124, 1, %125)::Int64
β”‚     %127 = %new(Base.OneTo{Int64}, %126)::Base.OneTo{Int64}
β”‚     %128 = Core.tuple(%127)::Tuple{Base.OneTo{Int64}}
β”‚     %129 = (%126 === %109)::Bool
β”‚     %130 = (%129 === false)::Bool
└────        goto #63 if not %130
62 ──        goto #64
63 ──        goto #64
64 ┄─ %134 = Ο† (#62 => false, #63 => true)::Bool
└────        goto #65
65 ──        goto #153 if not %134
66 ── %137 = (%126 === 10000)::Bool
β”‚     %138 = (%137 === false)::Bool
└────        goto #68 if not %138
67 ──        goto #69
68 ──        goto #69
69 ┄─ %142 = Ο† (#67 => false, #68 => true)::Bool
└────        goto #70
70 ──        goto #72 if not %142
71 ──        invoke Base.copyto!(h::Vector{Float64}, %42::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})::Any
└────        goto #154
72 ── %147 = Base.getfield(h, :size)::Tuple{Int64}
β”‚     %148 = $(Expr(:boundscheck, true))::Bool
β”‚     %149 = Base.getfield(%147, 1, %148)::Int64
β”‚     %150 = (%149 === 0)::Bool
β”‚     %151 = Base.not_int(%150)::Bool
└────        goto #85 if not %151
73 ── %153 = Base.getfield(h, :ref)::MemoryRef{Float64}
β”‚     %154 = Base.getfield(%153, :mem)::Memory{Float64}
β”‚     %155 = $(Expr(:foreigncall, :(:jl_genericmemory_owner), Any, svec(Any), 0, :(:ccall), :(%154)))::Any
β”‚     %156 = (%155 isa Memory{Float64})::Bool
└────        goto #75 if not %156
74 ── %158 = Ο€ (%155, Memory{Float64})
└────        goto #76
75 ──        nothing::Nothing
76 ┄─ %161 = Ο† (#74 => %158, #75 => %154)::Memory{Float64}
β”‚     %162 = Base.getfield(%161, :ptr)::Ptr{Nothing}
β”‚     %163 = Base.bitcast(Ptr{Float64}, %162)::Ptr{Float64}
β”‚     %164 = Core.bitcast(Core.UInt, %163)::UInt64
└────        goto #77
77 ──        goto #78
78 ── %167 = Base.getfield(%18, :ref)::MemoryRef{Float64}
β”‚     %168 = Base.getfield(%167, :mem)::Memory{Float64}
β”‚     %169 = $(Expr(:foreigncall, :(:jl_genericmemory_owner), Any, svec(Any), 0, :(:ccall), :(%168)))::Any
β”‚     %170 = (%169 isa Memory{Float64})::Bool
└────        goto #80 if not %170
79 ── %172 = Ο€ (%169, Memory{Float64})
└────        goto #81
80 ──        nothing::Nothing
81 ┄─ %175 = Ο† (#79 => %172, #80 => %168)::Memory{Float64}
β”‚     %176 = Base.getfield(%175, :ptr)::Ptr{Nothing}
β”‚     %177 = Base.bitcast(Ptr{Float64}, %176)::Ptr{Float64}
β”‚     %178 = Core.bitcast(Core.UInt, %177)::UInt64
└────        goto #82
82 ──        goto #83
83 ──        goto #84
84 ── %182 = (%164 === %178)::Bool
β”‚     %183 = Base.not_int(%182)::Bool
β”‚     %184 = Base.not_int(%183)::Bool
└────        goto #86
85 ──        goto #86
86 ┄─ %187 = Ο† (#84 => %184, #85 => false)::Bool
└────        goto #107 if not %187
87 ── %189 = $(Expr(:foreigncall, :(:jl_alloc_genericmemory), Ref{Memory{Float64}}, svec(Any, Int64), 0, :(:ccall), Memory{Float64}, 10000, 10000))::Memory{Float64}
β”‚     %190 = Core.memoryrefnew(%189)::MemoryRef{Float64}
β”‚     %191 = %new(Vector{Float64}, %190, (10000,))::Vector{Float64}
β”‚     %192 = $(Expr(:boundscheck, true))::Bool
└────        goto #89 if not %192
88 ──        nothing::Nothing
89 ┄─ %195 = $(Expr(:boundscheck, false))::Bool
└────        goto #91 if not %195
90 ──        nothing::Nothing
91 ┄─        goto #92
92 ──        goto #93
93 ──        goto #94
94 ──        goto #95
95 ──        goto #96
96 ──        goto #97
97 ──        goto #98
98 ──        goto #99
99 ──        goto #100
100 ─ %207 = %new(SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, %18, (1:10000,), 0, 1)::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}
└────        goto #101
101 ─        goto #102
102 ─        goto #103
103 ─        goto #104
104 ─        goto #105
105 ─        invoke Base.copyto!(%191::Vector{Float64}, %207::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})::Any
β”‚     %214 = %new(SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, %191, (1:10000,), 0, 1)::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}
└────        goto #106
106 ─        goto #108
107 ─        goto #108
108 β”„ %218 = Ο† (#106 => %191, #107 => %18)::Vector{Float64}
β”‚     %219 = Ο† (#106 => %214, #107 => %42)::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}
└────        goto #109
109 ─ %221 = %new(Base.Broadcast.Extruded{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Bool}, Tuple{Int64}}, %219, (true,), (1,))::Base.Broadcast.Extruded{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Bool}, Tuple{Int64}}
└────        goto #110
110 ─ %223 = Core.tuple(%221)::Tuple{Base.Broadcast.Extruded{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Bool}, Tuple{Int64}}}
└────        goto #111
111 ─ %225 = %new(Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Base.Broadcast.Extruded{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Bool}, Tuple{Int64}}}}, nothing, identity, %223, %111)::Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Base.Broadcast.Extruded{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Bool}, Tuple{Int64}}}}
└────        goto #112
112 ─ %227 = Base.slt_int(0, %109)::Bool
└────        goto #151 if not %227
113 ─        nothing::Nothing
114 β”„ %230 = Ο† (#113 => 0, #150 => %317)::Int64
β”‚     %231 = Base.slt_int(%230, %109)::Bool
└────        goto #151 if not %231
115 ─ %233 = Base.add_int(%230, 1)::Int64
β”‚     %234 = $(Expr(:boundscheck, false))::Bool
└────        goto #120 if not %234
116 ─ %236 = Core.tuple(%233)::Tuple{Int64}
β”‚     %237 = Base.sub_int(%233, 1)::Int64
β”‚     %238 = Base.bitcast(UInt64, %237)::UInt64
β”‚     %239 = Base.bitcast(UInt64, %109)::UInt64
β”‚     %240 = Base.ult_int(%238, %239)::Bool
└────        goto #118 if not %240
117 ─        goto #119
118 ─        invoke Base.throw_boundserror(%110::Base.OneTo{Int64}, %236::Tuple{Int64})::Union{}
└────        unreachable
119 ─        nothing::Nothing
120 β”„        goto #121
121 ─        goto #122
122 ─        goto #123
123 ─ %249 = $(Expr(:boundscheck, false))::Bool
└────        goto #130 if not %249
124 ─        goto #125
125 ─ %252 = Base.sub_int(%233, 1)::Int64
β”‚     %253 = Base.bitcast(UInt64, %252)::UInt64
β”‚     %254 = Base.bitcast(UInt64, %109)::UInt64
β”‚     %255 = Base.ult_int(%253, %254)::Bool
β”‚     %256 = Base.and_int(%255, true)::Bool
└────        goto #126
126 ─        goto #128 if not %256
127 ─        goto #129
128 ─ %260 = Core.tuple(%233)::Tuple{Int64}
β”‚            invoke Base.throw_boundserror(%225::Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Base.Broadcast.Extruded{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Bool}, Tuple{Int64}}}}, %260::Tuple{Int64})::Union{}
└────        unreachable
129 ─        nothing::Nothing
130 β”„ %264 = $(Expr(:boundscheck, false))::Bool
└────        goto #135 if not %264
131 ─ %266 = Core.tuple(%233)::Tuple{Int64}
β”‚     %267 = Base.sub_int(%233, 1)::Int64
β”‚     %268 = Base.bitcast(UInt64, %267)::UInt64
β”‚     %269 = Base.ult_int(%268, 0x0000000000002710)::Bool
└────        goto #133 if not %269
132 ─        goto #134
133 ─        invoke Base.throw_boundserror(%219::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, %266::Tuple{Int64})::Union{}
└────        unreachable
134 ─        nothing::Nothing
135 β”„ %275 = Base.add_int(0, %233)::Int64
β”‚     %276 = $(Expr(:boundscheck, false))::Bool
└────        goto #139 if not %276
136 ─ %278 = Base.sub_int(%275, 1)::Int64
β”‚     %279 = Base.bitcast(Base.UInt, %278)::UInt64
β”‚     %280 = Base.getfield(%218, :size)::Tuple{Int64}
β”‚     %281 = $(Expr(:boundscheck, true))::Bool
β”‚     %282 = Base.getfield(%280, 1, %281)::Int64
β”‚     %283 = Base.bitcast(Base.UInt, %282)::UInt64
β”‚     %284 = Base.ult_int(%279, %283)::Bool
└────        goto #138 if not %284
137 ─        goto #139
138 ─ %287 = Core.tuple(%275)::Tuple{Int64}
β”‚            invoke Base.throw_boundserror(%218::Vector{Float64}, %287::Tuple{Int64})::Union{}
└────        unreachable
139 β”„ %290 = Base.getfield(%218, :ref)::MemoryRef{Float64}
β”‚     %291 = Base.memoryrefnew(%290, %275, false)::MemoryRef{Float64}
β”‚     %292 = Base.memoryrefget(%291, :not_atomic, false)::Float64
└────        goto #140
140 ─        goto #141
141 ─        goto #142
142 ─        goto #143
143 ─        goto #144
144 ─        goto #145
145 ─ %299 = $(Expr(:boundscheck, false))::Bool
└────        goto #149 if not %299
146 ─ %301 = Base.sub_int(%233, 1)::Int64
β”‚     %302 = Base.bitcast(UInt64, %301)::UInt64
β”‚     %303 = Base.getfield(h, :size)::Tuple{Int64}
β”‚     %304 = $(Expr(:boundscheck, true))::Bool
β”‚     %305 = Base.getfield(%303, 1, %304)::Int64
β”‚     %306 = Base.bitcast(UInt64, %305)::UInt64
β”‚     %307 = Base.ult_int(%302, %306)::Bool
└────        goto #148 if not %307
147 ─        goto #149
148 ─ %310 = Core.tuple(%233)::Tuple{Int64}
β”‚            invoke Base.throw_boundserror(h::Vector{Float64}, %310::Tuple{Int64})::Union{}
└────        unreachable
149 β”„ %313 = Base.getfield(h, :ref)::MemoryRef{Float64}
β”‚     %314 = Base.memoryrefnew(%313, %233, false)::MemoryRef{Float64}
β”‚            Base.memoryrefset!(%314, %292, :not_atomic, false)::Float64
└────        goto #150
150 ─ %317 = Base.add_int(%230, 1)::Int64
β”‚            $(Expr(:loopinfo, Symbol("julia.simdloop"), nothing))::Nothing
└────        goto #114
151 β”„        goto #152
152 ─        goto #154
153 ─        invoke Base.Broadcast.throwdm(%128::Tuple{Base.OneTo{Int64}}, %111::Tuple{Base.OneTo{Int64}})::Union{}
└────        unreachable
154 β”„        goto #155
155 ─        goto #156
156 ─        goto #157
157 ─ %327 = Base.getfield(x, :size)::Tuple{Int64}
β”‚     %328 = $(Expr(:boundscheck, true))::Bool
β”‚     %329 = Base.getfield(%327, 1, %328)::Int64
β”‚     %330 = (4 === %329)::Bool
β”‚     %331 = Base.not_int(%330)::Bool
└────        goto #159 if not %331
158 ─ %333 = Base.getfield(x, :size)::Tuple{Int64}
β”‚     %334 = $(Expr(:boundscheck, true))::Bool
β”‚     %335 = Base.getfield(%333, 1, %334)::Int64
β”‚     %336 = Core.tuple("second dimension of A, ", 4, ", does not match length of x, ", %335)::Tuple{String, Int64, String, Int64}
β”‚     %337 = %new(Base.LazyString, %336, Base.nothing)::LazyString
β”‚     %338 = %new(Base.DimensionMismatch, %337)::DimensionMismatch
β”‚            LinearAlgebra.throw(%338)::Union{}
└────        unreachable
159 ─ %341 = Base.getfield(h, :size)::Tuple{Int64}
β”‚     %342 = $(Expr(:boundscheck, true))::Bool
β”‚     %343 = Base.getfield(%341, 1, %342)::Int64
β”‚     %344 = (10000 === %343)::Bool
β”‚     %345 = Base.not_int(%344)::Bool
└────        goto #161 if not %345
160 ─ %347 = Base.getfield(h, :size)::Tuple{Int64}
β”‚     %348 = $(Expr(:boundscheck, true))::Bool
β”‚     %349 = Base.getfield(%347, 1, %348)::Int64
β”‚     %350 = Core.tuple("first dimension of A, ", 10000, ", does not match length of y, ", %349)::Tuple{String, Int64, String, Int64}
β”‚     %351 = %new(Base.LazyString, %350, Base.nothing)::LazyString
β”‚     %352 = %new(Base.DimensionMismatch, %351)::DimensionMismatch
β”‚            LinearAlgebra.throw(%352)::Union{}
└────        unreachable
161 ─        invoke LinearAlgebra.BLAS.gemv!('N'::Char, 1.0::Float64, %74::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, x::Vector{Float64}, 1.0::Float64, h::Vector{Float64})::Vector{Float64}
└────        goto #162
162 ─        goto #163
163 ─        goto #164
164 ─        goto #165
165 ─        nothing::Nothing
β”‚            nothing::Nothing
β”‚     %362 = Base.getfield(x, :size)::Tuple{Int64}
β”‚     %363 = $(Expr(:boundscheck, true))::Bool
β”‚     %364 = Base.getfield(%362, 1, %363)::Int64
β”‚     %365 = (%364 === 4)::Bool
└────        goto #182 if not %365
166 ─        nothing::Nothing
β”‚     %368 = Base.getfield(x, :size)::Tuple{Int64}
β”‚     %369 = $(Expr(:boundscheck, true))::Bool
β”‚     %370 = Base.getfield(%368, 1, %369)::Int64
β”‚     %371 = (%370 === 4)::Bool
└────        goto #181 if not %371
167 ─        nothing::Nothing
β”‚     %374 = Base.getfield(h, :size)::Tuple{Int64}
β”‚     %375 = $(Expr(:boundscheck, true))::Bool
β”‚     %376 = Base.getfield(%374, 1, %375)::Int64
β”‚     %377 = Base.getfield(h, :size)::Tuple{Int64}
β”‚     %378 = $(Expr(:boundscheck, true))::Bool
β”‚     %379 = Base.getfield(%377, 1, %378)::Int64
β”‚     %380 = (10000 === %379)::Bool
└────        goto #180 if not %380
168 ─        nothing::Nothing
β”‚     %383 = Base.sle_int(1, %376)::Bool
└────        goto #170 if not %383
169 ─        goto #171
170 ─        goto #171
171 β”„ %387 = Ο† (#169 => %376, #170 => 0)::Int64
β”‚     %388 = %new(UnitRange{Int64}, 1, %387)::UnitRange{Int64}
└────        goto #172
172 ─        goto #173
173 ─        goto #174
174 ─ %392 = Core.tuple(%388)::Tuple{UnitRange{Int64}}
└────        goto #175
175 ─ %394 = Base.sub_int(%387, 1)::Int64
β”‚     %395 = Base.add_int(1, %394)::Int64
β”‚     %396 = invoke Base.Threads.nthreads()::Int64
β”‚     %397 = Base.mul_int(%395, 16)::Int64
β”‚     %398 = Base.checked_sdiv_int(%397, 262144)::Int64
β”‚     %399 = Base.slt_int(0, %397)::Bool
β”‚     %400 = (%399 === true)::Bool
β”‚     %401 = Base.mul_int(%398, 262144)::Int64
β”‚     %402 = (%401 === %397)::Bool
β”‚     %403 = Base.not_int(%402)::Bool
β”‚     %404 = Base.and_int(%400, %403)::Bool
β”‚     %405 = Core.zext_int(Core.Int64, %404)::Int64
β”‚     %406 = Core.and_int(%405, 1)::Int64
β”‚     %407 = Base.add_int(%398, %406)::Int64
β”‚     %408 = Base.slt_int(%407, %396)::Bool
β”‚     %409 = Core.ifelse(%408, %407, %396)::Int64
β”‚     %410 = Base.slt_int(%395, %409)::Bool
β”‚     %411 = Core.ifelse(%410, %395, %409)::Int64
β”‚     %412 = Base.slt_int(1, %411)::Bool
└────        goto #177 if not %412
176 ─ %414 = Core.tuple(h, %104, x)::Tuple{Vector{Float64}, Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{Float64}}
β”‚            invoke Tullio.thread_halves(PIXIE.var"#π’œπ’Έπ“‰!#19"()::PIXIE.var"#π’œπ’Έπ“‰!#19", Vector{Float64}::Type{Vector{Float64}}, %414::Tuple{Vector{Float64}, Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{Float64}}, %392::Tuple{UnitRange{Int64}}, (1:4, 1:4)::Tuple{UnitRange{Int64}, UnitRange{Int64}}, %411::Int64, true::Bool)::Any
└────        goto #178
177 ─ %417 = Core.tuple(h, %104, x)::Tuple{Vector{Float64}, Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{Float64}}
β”‚     %418 = Tullio.tile_halves::typeof(Tullio.tile_halves)
└────        invoke %418(PIXIE.var"#π’œπ’Έπ“‰!#19"()::PIXIE.var"#π’œπ’Έπ“‰!#19", Vector{Float64}::Type{Vector{Float64}}, %417::Tuple{Vector{Float64}, Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{Float64}}, %392::Tuple{UnitRange{Int64}}, (1:4, 1:4)::Tuple{UnitRange{Int64}, UnitRange{Int64}}, true::Bool, true::Bool)::Nothing
178 β”„        goto #179
179 ─        return PIXIE.nothing
180 ─        (throw)("range of index i must agree")::Union{}
└────        unreachable
181 ─        (throw)("range of index k must agree")::Union{}
└────        unreachable
182 ─        (throw)("range of index j must agree")::Union{}
└────        unreachable
) => Nothing

Following the suggestions in this thread, I tried to define f as a closure, but the issue persists.

All the other backends that support Cache() fail as well.

Do you have any suggestions on how to make the design more AD-friendly while avoiding the creation of big temporary vectors during the computation? Benchmarks confirm that the cache gives me considerable performance gains.

Thank you all for the help!

In theory, the reverse-mode AD backends that support mutation in Julia are Enzyme, Mooncake and ReverseDiff. In practice, it can be hard to get ReverseDiff working due to type issues. However, Enzyme should be able to handle a lot of code out of the box.

  • The right way to fix this would be to figure out if Enzyme itself has a problem with your code, or if the problem comes from the DI wrapper. Trying out Enzyme’s native API should give you the answer:

EDIT: this was incorrect, see below

Enzyme.gradient(Reverse, f, p, Duplicated(cc, make_zero(cc)), Const(r), Const(y), Const(x))
  • In the meantime, if you need to use DI, you can change the backend object to comply with the suggestion in the error message[1]. However, keep in mind that runtime activity will make things slightly slower.
backend = AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))

  1. DI used to provide a nicer error hint explaining this, but it was broken by an otherwise harmless Enzyme PR because the exception type changed. I had to modify Julia internals to reactivate it, so we’ll have it back in Julia 1.13.

Thank you @gdalle for the quick reply. Indeed, enabling runtime activity made Enzyme work through the DI interface, although it is slower than Mooncake:

  • Mooncake:
    julia> @benchmark gradient($f, $prep, $backend, $p, Cache($cc), Constant($r), Constant($y), Constant($x))
    BenchmarkTools.Trial: 422 samples with 1 evaluation per sample.
     Range (min … max):  10.691 ms … 28.670 ms  ┊ GC (min … max): 0.00% … 60.04%
     Time  (median):     11.438 ms              ┊ GC (median):    0.00%
     Time  (mean ± σ):   11.832 ms ±  1.349 ms  ┊ GC (mean ± σ):  1.62% ±  5.34%
    
      10.7 ms      Histogram: log(frequency) by time      16.9 ms <
    
     Memory estimate: 2.55 MiB, allocs estimate: 104.
    
  • Enzyme (mode = Enzyme.set_runtime_activity(Enzyme.Reverse)):
    julia> @benchmark gradient($f, $prep, $backend, $p, Cache($cc), Constant($r), Constant($y), Constant($x))
    ┌ Warning: active variables passed by value to jl_new_task are not yet supported
    └ @ Enzyme.Compiler ~/.julia/packages/Enzyme/eJcor/src/rules/parallelrules.jl:726
    ┌ Warning: active variables passed by value to jl_new_task are not yet supported
    └ @ Enzyme.Compiler ~/.julia/packages/Enzyme/eJcor/src/rules/parallelrules.jl:726
    BenchmarkTools.Trial: 160 samples with 1 evaluation per sample.
     Range (min … max):  25.923 ms … 65.722 ms  ┊ GC (min … max): 0.00% … 50.01%
     Time  (median):     29.565 ms              ┊ GC (median):    0.00%
     Time  (mean ± σ):   31.312 ms ±  4.983 ms  ┊ GC (mean ± σ):  4.61% ±  9.90%
    
      25.9 ms      Histogram: log(frequency) by time      46.7 ms <
    
     Memory estimate: 8.92 MiB, allocs estimate: 119330.
    

Meanwhile, trying out Enzyme’s native API gave this error:

julia> Enzyme.gradient(Reverse, f, Duplicated(cc, make_zero(cc)), Const(r), Const(y), Const(x))
ERROR: MethodError: no method matching primal_return_type(::Expr)
The function `primal_return_type` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  primal_return_type(::Mode, ::Type, ::Type)
   @ Enzyme ~/.julia/packages/Enzyme/eJcor/src/typeutils/inference.jl:136

Stacktrace:
 [1] (::Core.GeneratedFunctionStub)(world::UInt64, source::LineNumberNode, args::Any)
   @ Core ./boot.jl:707
 [2] primal_return_type_generator(world::UInt64, source::Any, self::Any, mode::Type, ft::Type, tt::Type)
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/eJcor/src/typeutils/inference.jl:118
 [3] autodiff
   @ ~/.julia/packages/Enzyme/eJcor/src/Enzyme.jl:387 [inlined]
 [4] autodiff
   @ ~/.julia/packages/Enzyme/eJcor/src/Enzyme.jl:542 [inlined]
 [5] macro expansion
   @ ~/.julia/packages/Enzyme/eJcor/src/sugar.jl:286 [inlined]
 [6] gradient(::ReverseMode{…}, ::typeof(f), ::Duplicated{…}, ::Const{…}, ::Const{…}, ::Const{…})
   @ Enzyme ~/.julia/packages/Enzyme/eJcor/src/sugar.jl:273
 [7] top-level scope
   @ REPL[14]:1
Some type information was truncated. Use `show(err)` to see complete types.

julia> show(err)
1-element ExceptionStack:
MethodError: no method matching primal_return_type(::Expr)
The function `primal_return_type` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  primal_return_type(::Mode, ::Type, ::Type)
   @ Enzyme ~/.julia/packages/Enzyme/eJcor/src/typeutils/inference.jl:136

Stacktrace:
 [1] (::Core.GeneratedFunctionStub)(world::UInt64, source::LineNumberNode, args::Any)
   @ Core ./boot.jl:707
 [2] primal_return_type_generator(world::UInt64, source::Any, self::Any, mode::Type, ft::Type, tt::Type)
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/eJcor/src/typeutils/inference.jl:118
 [3] autodiff
   @ ~/.julia/packages/Enzyme/eJcor/src/Enzyme.jl:387 [inlined]
 [4] autodiff
   @ ~/.julia/packages/Enzyme/eJcor/src/Enzyme.jl:542 [inlined]
 [5] macro expansion
   @ ~/.julia/packages/Enzyme/eJcor/src/sugar.jl:286 [inlined]
 [6] gradient(::ReverseMode{false, false, false, FFIABI, false, false}, ::typeof(f), ::Duplicated{PIXIE.IsingCacheCPU{Float64, Int64}}, ::Const{SimpleResponse}, ::Const{SparseVector{Bool, Int64}}, ::Const{Vector{Float64}})
   @ Enzyme ~/.julia/packages/Enzyme/eJcor/src/sugar.jl:273
 [7] top-level scope
   @ REPL[14]:1

My bad, I gave you the wrong instructions for using Enzyme’s API (the perk of often using DI is that I don’t have to memorize it anymore :sunglasses:). You cannot pass Duplicated to gradient, but you can pass Const. So here it is:

Enzyme.gradient(Reverse, f, p, cc, Const(r), Const(y), Const(x))

This version doesn’t allow you to manage storage for the shadows though. To do that, I think you have to go one level deeper:

dp, dcc = make_zero(p), make_zero(cc)
# always ensure that `dp` is zeroed out before the next call
Enzyme.autodiff(Reverse, f, Duplicated(p, dp), Duplicated(cc, dcc), Const(r), Const(y), Const(x))

I haven’t run it locally but it should be better. My earlier suggestion was an unholy mix of the two approaches.

Does this work? In terms of speed, how does it compare to DI+Enzyme and DI+Mooncake?

I cannot blame you, DI is super handy!

Both of your suggestions threw the very same error:

ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
 Failure within method: MethodInstance for PIXIE.deposit!(::Vector{…}, ::SimpleResponse, ::ComponentArrays.ComponentVector{…}, ::Vector{…})
Hint: catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.
If you have Cthulu.jl loaded you can also use `code_typed(err; interactive = true)` to interactively introspect the code.
Mismatched activity for:   store {} addrspace(10)* %3, {} addrspace(10)* addrspace(11)* %.fca.2.gep59, align 8, !dbg !242, !noalias !194 const val: {} addrspace(10)* %3
 value=Unknown object of type Vector{Float64}
 llvalue={} addrspace(10)* %3

Stacktrace:
 [1] tile_halves
   @ ~/.julia/packages/Tullio/2zyFP/src/threads.jl:136
 [2] threader
   @ ~/.julia/packages/Tullio/2zyFP/src/threads.jl:65
 [3] macro expansion
   @ ~/.julia/packages/Tullio/2zyFP/src/macro.jl:1004
 [4] deposit!
   @ ~/Documents/PhD/Alignment/PIXIE/src/covariates_response_models.jl:127

Stacktrace:
  [1] tile_halves
    @ ~/.julia/packages/Tullio/2zyFP/src/threads.jl:136 [inlined]
  [2] threader
    @ ~/.julia/packages/Tullio/2zyFP/src/threads.jl:65 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Tullio/2zyFP/src/macro.jl:1004 [inlined]
  [4] deposit!
    @ ~/Documents/PhD/Alignment/PIXIE/src/covariates_response_models.jl:127
  [5] npll!
    @ ~/Documents/PhD/Alignment/PIXIE/src/pseudologlikelihood.jl:40
  [6] f
    @ ./REPL[12]:1 [inlined]
  [7] f
    @ ./REPL[12]:0 [inlined]
  [8] diffejulia_f_12219_inner_71wrap
    @ ./REPL[12]:0
  [9] macro expansion
    @ ~/.julia/packages/Enzyme/eJcor/src/compiler.jl:5875 [inlined]
 [10] enzyme_call
    @ ~/.julia/packages/Enzyme/eJcor/src/compiler.jl:5409 [inlined]
 [11] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/eJcor/src/compiler.jl:5295 [inlined]
 [12] autodiff
    @ ~/.julia/packages/Enzyme/eJcor/src/Enzyme.jl:521 [inlined]
 [13] autodiff
    @ ~/.julia/packages/Enzyme/eJcor/src/Enzyme.jl:542 [inlined]
 [14] macro expansion
    @ ~/.julia/packages/Enzyme/eJcor/src/sugar.jl:286 [inlined]
 [15] gradient(::ReverseMode{…}, ::typeof(f), ::ComponentArrays.ComponentVector{…}, ::PIXIE.IsingCacheCPU{…}, ::Const{…}, ::Const{…}, ::Const{…})
    @ Enzyme ~/.julia/packages/Enzyme/eJcor/src/sugar.jl:273
 [16] top-level scope
    @ REPL[18]:1
Some type information was truncated. Use `show(err)` to see complete types.

Looking at my code, this is where things fail:

function deposit!(h, R::SimpleResponse, p, x)
    if size(h) ≠ (num_nodes(R),)
        throw(DimensionMismatch("h vector size does not match response model"))
    end
    Tv = eltype(h)
    # Energy deposit response
    hb = p.deposit.bias
    hl = p.deposit.linear
    hq = p.deposit.quadratic
    # Bias part
    h .= hb
    # Adding linear part
    mul!(h, hl, x, one(Tv), one(Tv))
    # Adding quadratic part
    # Element wise quadratic form. Einstein notation
    @tullio h[i] += 0.5 * hq[i, j, k] * x[j] * x[k] # <-- THIS IS THE LINE THAT GIVES ERROR!
    return nothing
end
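
Just to be explicit about what that line does: it is the element-wise quadratic form in x, i.e. the equivalent of this plain loop (written out only for clarity, ignoring Tullio's threading and tiling):

for i in axes(hq, 1), j in axes(hq, 2), k in axes(hq, 3)
    h[i] += 0.5 * hq[i, j, k] * x[j] * x[k]
end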

Is Tullio known for having issues integrating with Enzyme?

P.S. Sorry for not providing an MWE; I will try to create one later this week.

Okay, then if you want to use Enzyme’s native API, you also need to activate runtime activity for the time being:

mode = Enzyme.set_runtime_activity(Reverse)
Enzyme.autodiff(mode, f, Duplicated(p, dp), Duplicated(cc, dcc), Const(r), Const(y), Const(x))

My hope is that it shouldn’t be too different from the performance of DI+Enzyme. I’m curious that you get a 2.5x slowdown wrt Mooncake though. If you can share a profiling flame graph with StatProfilerHTML.jl, we might be able to figure out why.
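
Something along these lines should produce one (an untested sketch, reusing the objects from your earlier messages):

using StatProfilerHTML

# run the call under the profiler and write an interactive HTML flame graph to disk
@profilehtml gradient(f, backend, p, Cache(cc), Constant(r), Constant(y), Constant(x))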

For some reason I wasn’t able to generate an HTML flame graph, so I ended up saving ProfileView files instead.

You can find the .jlprof files here

I hope they can still help in some way!

Can you explain how I open those? A quick look at the ProfileView.jl README did not reveal the answer.

For sharing profiles, PProf.jl is quite nice since it

  1. Also shows the C runtime and C libraries
  2. Works with sites like pprof.me and flamegraph.com

(I have been meaning to add an “upload profile” function to PProf.jl)
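
A minimal workflow would look something like this (a sketch, reusing the gradient call from earlier in the thread):

using Profile, PProf

Profile.clear()
@profile gradient(f, backend, p, Cache(cc), Constant(r), Constant(y), Constant(x))
pprof()  # writes profile.pb.gz and opens the interactive pprof web UI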

One thing to possibly test is:

Enzyme.autodiff(Reverse, f, Duplicated(p, dp), Duplicated(cc, dcc), Duplicated(r, make_zero(r)), Duplicated(y, make_zero(y)), Duplicated(x, make_zero(x)))

And the variants where some of the arguments are constant and some are duplicated.
E.g. Enzyme might need to propagate a gradient through the shadow of an input argument, but will fail to do so if that argument is marked Const.


@gdalle in the same drive you can also find the .pb.gz profile databases. I hope the content is reasonable and helps!

Regarding the suggestion to mark all the arguments as Duplicated:

it does not throw any error, but it returns a tuple of nothings:

julia> Enzyme.autodiff(Reverse, f, Duplicated(p, dp), Duplicated(cc, dcc), Duplicated(r, make_zero(r)), Duplicated(y, make_zero(y)), Duplicated(x, make_zero(x)))
β”Œ Warning: active variables passed by value to jl_new_task are not yet supported
β”” @ Enzyme.Compiler ~/.julia/packages/Enzyme/eJcor/src/rules/parallelrules.jl:726
β”Œ Warning: active variables passed by value to jl_new_task are not yet supported
β”” @ Enzyme.Compiler ~/.julia/packages/Enzyme/eJcor/src/rules/parallelrules.jl:726
((nothing, nothing, nothing, nothing, nothing),)

which I do not think is the expected behaviour. For now I will stick to the Mooncake backend (or the slower Enzyme one) in my work, but feel free to ask more questions if I can help you!

That is expected behavior: calling autodiff with Duplicated arguments accumulates the gradient in place into the shadow (the second argument of each Duplicated).
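
In other words, the workflow looks like this (an untested sketch reusing the names from the earlier messages; note the runtime-activity mode from above):

mode = Enzyme.set_runtime_activity(Enzyme.Reverse)

Enzyme.make_zero!(dp)   # reset the parameter shadow before each call
Enzyme.make_zero!(dcc)  # reset the cache shadow as well
Enzyme.autodiff(mode, f, Duplicated(p, dp), Duplicated(cc, dcc), Const(r), Const(y), Const(x))
dp  # now holds the gradient with respect to p, accumulated in place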