Dear all,
Since this is my first post, I want to thank you for all the help I have received over my many years of using Julia (I have always found the answers to my problems in previous posts!).
On to the problem:
I need to compute the quantity
\mathscr{L}(\theta; \mathbf{x}, \mathbf{y}) = \sum_i \left[\log(1 + e^{\eta_i(\mathbf{x}, \mathbf{y}, \theta)}) + \eta_i(\mathbf{x}, \mathbf{y}, \theta)y_i\right]
where
\mathbf{\eta}(\mathbf{x}, \mathbf{y}, \theta) = \mathbf{h}(\mathbf{x}, \theta) + \mathbf{W}(\mathbf{x}, \theta)\mathbf{y}
with \mathbf{y}\in\{0,1\}^N, \mathbf{x}\in\mathbb{R}^M, \mathbf{h}\in\mathbb{R}^N, \mathbf{\eta}\in\mathbb{R}^N, \mathbf{W}\in\mathbb{R}^{N\times N} (highly sparse) and \theta\in\mathbb{R}^P.
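As a reference for checking the output of any AD backend, the expression above can also be differentiated by hand. The following is a sketch under the assumption that \mathbf{x} and \mathbf{y} are held constant; \sigma (the logistic function) and the index k over the components of \theta are notation introduced here only for this derivation:

```latex
\frac{\partial \mathscr{L}}{\partial \eta_i} = \sigma(\eta_i) + y_i,
\qquad
\sigma(t) = \frac{1}{1 + e^{-t}}

\frac{\partial \mathscr{L}}{\partial \theta_k}
  = \sum_i \left[\sigma(\eta_i) + y_i\right]
    \left(\frac{\partial h_i}{\partial \theta_k}
          + \sum_j \frac{\partial W_{ij}}{\partial \theta_k}\, y_j\right)
```

The first line follows from d/d\eta \log(1 + e^{\eta}) = e^{\eta}/(1 + e^{\eta}); the second is the chain rule applied to \eta_i = h_i + \sum_j W_{ij} y_j.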
Since N can be of the order of 10^5 and I need to compute \mathscr{L} many times, I decided to keep \mathbf{h}, \mathbf{W} and the vector \mathbf{w} of the nonzero elements of \mathbf{W} in a cache struct. My implementation of \mathscr{L} looks like this:
function npll!(
    mem::AbstractIsingCache,
    rsp::AbstractCovariatesResponseModel,
    y::AbstractVector{Bool},
    p::ComponentArray,
    x::AbstractVector{<:Real}
)
    # Get references to cache elements
    h = deposit_vector(mem)
    w = crosstalk_vector(mem)
    W = crosstalk_matrix(mem)
    # Update cache with the response for this specific data point
    deposit!(h, rsp, p, x)
    crosstalk!(w, rsp, p, x)
    update_crosstalk_matrix!(mem)
    #################################################
    # Compute local conditional fields `η` in-place #
    #################################################
    # The `h` vector is overwritten by `deposit!`, then we add the crosstalk
    # influence from neighbors `W*y`. `h` now holds the local fields `η`.
    Tv = eltype(h)
    mul!(h, W, y, one(Tv), one(Tv))
    ########################################################
    # Compute the final negative pseudo-log-likelihood sum #
    ########################################################
    # `h` here is equivalent to `η` from the formula.
    @tullio n_p_ll := log1pexp(h[i]) + h[i] * y[i]
    #n_p_ll = sum(log1pexp.(h) .+ h .* y)
    return -n_p_ll
end
deposit! and crosstalk! are just linear algebra operations on the parameters and the x vector, where I use @tullio to contract indices.
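For testing purposes, the cached computation can be compared against a naive allocating version of the formula. The block below is only a sketch: `npll_reference`, `npll_inplace!` and `softplus` are hypothetical names that are not part of PIXIE, `softplus` is a hand-written stand-in for `log1pexp`, and `h` and `W` are passed in directly instead of being built from `p` and `x`. Both functions return the sum as written in the formula, whereas `npll!` returns `-n_p_ll`.

```julia
using LinearAlgebra, SparseArrays

# Numerically stable log(1 + exp(t)); hand-written stand-in for `log1pexp`.
softplus(t::Real) = max(t, zero(t)) + log1p(exp(-abs(t)))

# Hypothetical allocating reference: builds η = h + W*y as a fresh vector,
# then evaluates the sum from the formula.
function npll_reference(h::AbstractVector{<:Real}, W::AbstractMatrix{<:Real},
                        y::AbstractVector{Bool})
    η = h + W * y
    return sum(softplus(η[i]) + η[i] * y[i] for i in eachindex(η, y))
end

# In-place variant of the same sum: `mul!(h, W, y, α, β)` overwrites `h`
# with α*W*y + β*h, so no temporary vector is allocated for η.
function npll_inplace!(h::AbstractVector{<:Real}, W::AbstractMatrix{<:Real},
                       y::AbstractVector{Bool})
    mul!(h, W, y, true, true)
    return sum(softplus(h[i]) + h[i] * y[i] for i in eachindex(h, y))
end

# Tiny made-up example (2 nodes) to compare the two versions.
h = [0.0, 1.0]
W = sparse([1, 2], [2, 1], [0.5, 0.5], 2, 2)
y = [true, false]

ref = npll_reference(h, W, y)
inplace = npll_inplace!(copy(h), W, y)
```

Here `ref` and `inplace` should agree up to floating-point rounding; the in-place version mirrors the `mul!` call inside `npll!`, while the reference version is the one that is easy to hand to any AD backend.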
Every update function is tested and works as expected. The issue arises when I try to use DifferentiationInterface to compute the gradient with respect to the parameters. Trying different backends, I found that only Mooncake is able to compute the gradient:
julia> using PIXIE #my package
julia> using DifferentiationInterface
julia> using SparseArrays
julia> s = RectangularPixel(1000,10)
RectangularPixel(1000, 10, :rook)
julia> r = SimpleResponse(s, num_covariates=4)
SimpleResponse(4, 10000, 18990)
julia> c = ConditionalIsing(s, r) # constructor for my model
ConditionalIsing{PIXIE.IsingCachePool{PIXIE.IsingCacheCPU{Float64, Int64}}}(RectangularPixel(1000, 10, :rook), SimpleResponse(4, 10000, 18990), PIXIE.IsingCachePool{PIXIE.IsingCacheCPU{Float64, Int64}}(Channel{PIXIE.IsingCacheCPU{Float64, Int64}}(16), 16))
julia> p = PIXIE.init_parameters(r)
ComponentVector{Float64}(deposit = (bias = [...], linear = [...], quadratic = [...]), crosstalk = (bias = [...], linear = [...]))
julia> cc = take!(PIXIE.pool(c))
PIXIE.IsingCacheCPU{Float64, Int64}([...], [...], [...], PIXIE.CrosstalkEdgesHSCDS{Int64}([1, 10], [1, 10000, 19990], [...]))
julia> y = sprand(Bool, 10_000, 0.001)
10000-element SparseVector{Bool, Int64} with 11 stored entries:
[175 ] = 1
[2224] = 1
[2527] = 1
[3617] = 1
[3684] = 1
[4060] = 1
[4908] = 1
[6762] = 1
[6890] = 1
[7340] = 1
[8740] = 1
julia> x = rand(4)
4-element Vector{Float64}:
0.3524710528059386
0.3988983273927492
0.6920326598241286
0.7996804030096054
julia> f(p, mem, rsp, y, x) = PIXIE.npll!(mem, rsp, y, p, x)
f (generic function with 1 method)
julia> import Mooncake
julia> backend = AutoMooncake()
AutoMooncake()
julia> gradient(f, backend, p, Cache(cc), Constant(r), Constant(y), Constant(x))
Mooncake.Tangent{@NamedTuple{data::Vector{Float64}, axes::Mooncake.NoTangent}}((data = [-0.7969185840561417, -0.11077871561171315, -0.7656696626602588, -0.2705008115642961, -0.40970563708520497, -0.10346557289376614, -0.4827844568393658, -0.5948295114134357, -0.44043476113139235, -0.546995154664162 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], axes = Mooncake.NoTangent()))
But if I try with Enzyme, the computation fails:
julia> import Enzyme
julia> backend = AutoEnzyme()
AutoEnzyme()
julia> gradient(f, backend, p, Cache(cc), Constant(r), Constant(y), Constant(x))
ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Failure within method: MethodInstance for PIXIE.deposit!(::Vector{…}, ::SimpleResponse, ::ComponentArrays.ComponentVector{…}, ::Vector{…})
Hint: catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.
If you have Cthulu.jl loaded you can also use `code_typed(err; interactive = true)` to interactively introspect the code.
Mismatched activity for: store {} addrspace(10)* %3, {} addrspace(10)* addrspace(11)* %.fca.2.gep59, align 8, !dbg !242, !noalias !194 const val: {} addrspace(10)* %3
value=Unknown object of type Vector{Float64}
llvalue={} addrspace(10)* %3
Stacktrace:
[1] tile_halves
@ ~/.julia/packages/Tullio/2zyFP/src/threads.jl:136
[2] threader
@ ~/.julia/packages/Tullio/2zyFP/src/threads.jl:65
[3] macro expansion
@ ~/.julia/packages/Tullio/2zyFP/src/macro.jl:1004
[4] deposit!
@ ~/Documents/PhD/Alignment/PIXIE/src/covariates_response_models.jl:127
Stacktrace:
[1] tile_halves
@ ~/.julia/packages/Tullio/2zyFP/src/threads.jl:136 [inlined]
[2] threader
@ ~/.julia/packages/Tullio/2zyFP/src/threads.jl:65 [inlined]
[3] macro expansion
@ ~/.julia/packages/Tullio/2zyFP/src/macro.jl:1004 [inlined]
[4] deposit!
@ ~/Documents/PhD/Alignment/PIXIE/src/covariates_response_models.jl:127
[5] npll!
@ ~/Documents/PhD/Alignment/PIXIE/src/pseudologlikelihood.jl:40
[6] f
@ ./REPL[12]:1 [inlined]
[7] f
@ ./REPL[12]:0 [inlined]
[8] diffejulia_f_61625_inner_71wrap
@ ./REPL[12]:0
[9] macro expansion
@ ~/.julia/packages/Enzyme/eJcor/src/compiler.jl:5875 [inlined]
[10] enzyme_call
@ ~/.julia/packages/Enzyme/eJcor/src/compiler.jl:5409 [inlined]
[11] CombinedAdjointThunk
@ ~/.julia/packages/Enzyme/eJcor/src/compiler.jl:5295 [inlined]
[12] autodiff
@ ~/.julia/packages/Enzyme/eJcor/src/Enzyme.jl:521 [inlined]
[13] autodiff
@ ~/.julia/packages/Enzyme/eJcor/src/Enzyme.jl:542 [inlined]
[14] gradient(::typeof(f), ::DifferentiationInterfaceEnzymeExt.EnzymeGradientPrep{…}, ::AutoEnzyme{…}, ::ComponentArrays.ComponentVector{…}, ::Cache{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
@ DifferentiationInterfaceEnzymeExt ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl:303
[15] gradient(::typeof(f), ::AutoEnzyme{…}, ::ComponentArrays.ComponentVector{…}, ::Cache{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
@ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/zJHX8/src/first_order/gradient.jl:63
[16] top-level scope
@ REPL[18]:1
Some type information was truncated. Use `show(err)` to see complete types.
julia> try
gradient(f, backend, p, Cache(cc), Constant(r), Constant(y), Constant(x))
catch err
code_typed(err)
end
1-element Vector{Any}:
CodeInfo(
1 βββ %1 = Base.getfield(h, :size)::Tuple{Int64}
β %2 = Base.getfield(R, :num_nodes)::Int64
β %3 = $(Expr(:boundscheck, true))::Bool
β %4 = Base.getfield(%1, 1, %3)::Int64
β %5 = (%4 === %2)::Bool
β %6 = (%5 === false)::Bool
βββββ goto #3 if not %6
2 βββ goto #4
3 βββ goto #4
4 βββ %10 = Ο (#2 => false, #3 => true)::Bool
βββββ goto #5
5 βββ %12 = Base.not_int(%10)::Bool
βββββ goto #6
6 βββ goto #8 if not %12
7 βββ %15 = invoke PIXIE.DimensionMismatch("h vector size does not match response model"::String)::DimensionMismatch
β PIXIE.throw(%15)::Union{}
βββββ unreachable
8 βββ %18 = ComponentArrays.getfield(p, :data)::Vector{Float64}
β %19 = $(Expr(:boundscheck, true))::Bool
βββββ goto #13 if not %19
9 βββ %21 = Base.getfield(%18, :size)::Tuple{Int64}
β %22 = $(Expr(:boundscheck, true))::Bool
β %23 = Base.getfield(%21, 1, %22)::Int64
β %24 = Base.bitcast(UInt64, %23)::UInt64
β %25 = Base.ult_int(0x0000000000000000, %24)::Bool
β %26 = Base.bitcast(UInt64, %23)::UInt64
β %27 = Base.ult_int(0x000000000003344f, %26)::Bool
β %28 = Base.and_int(%25, %27)::Bool
β %29 = Base.or_int(false, %28)::Bool
βββββ goto #11 if not %29
10 ββ goto #12
11 ββ invoke Base.throw_boundserror(%18::Vector{Float64}, (1:210000,)::Tuple{UnitRange{Int64}})::Union{}
βββββ unreachable
12 ββ nothing::Nothing
13 ββ goto #14
14 ββ goto #15
15 ββ goto #16
16 ββ goto #17
17 ββ %39 = $(Expr(:boundscheck, true))::Bool
βββββ goto #19 if not %39
18 ββ nothing::Nothing
19 ββ %42 = %new(SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, %18, (1:10000,), 0, 1)::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}
βββββ goto #20
20 ββ goto #21
21 ββ goto #22
22 ββ goto #23
23 ββ %47 = ComponentArrays.getfield(p, :data)::Vector{Float64}
β %48 = $(Expr(:boundscheck, true))::Bool
βββββ goto #28 if not %48
24 ββ %50 = Base.getfield(%47, :size)::Tuple{Int64}
β %51 = $(Expr(:boundscheck, true))::Bool
β %52 = Base.getfield(%50, 1, %51)::Int64
β %53 = Base.bitcast(UInt64, %52)::UInt64
β %54 = Base.ult_int(0x0000000000000000, %53)::Bool
β %55 = Base.bitcast(UInt64, %52)::UInt64
β %56 = Base.ult_int(0x000000000003344f, %55)::Bool
β %57 = Base.and_int(%54, %56)::Bool
β %58 = Base.or_int(false, %57)::Bool
βββββ goto #26 if not %58
25 ββ goto #27
26 ββ invoke Base.throw_boundserror(%47::Vector{Float64}, (1:210000,)::Tuple{UnitRange{Int64}})::Union{}
βββββ unreachable
27 ββ nothing::Nothing
28 ββ goto #29
29 ββ goto #30
30 ββ goto #31
31 ββ goto #32
32 ββ %68 = $(Expr(:boundscheck, true))::Bool
βββββ goto #34 if not %68
33 ββ nothing::Nothing
34 ββ %71 = %new(SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, %47, (10001:50000,), 10000, 1)::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}
βββββ goto #35
35 ββ goto #36
36 ββ %74 = %new(Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, %71, (10000, 4), ())::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}
βββββ goto #37
37 ββ goto #38
38 ββ %77 = ComponentArrays.getfield(p, :data)::Vector{Float64}
β %78 = $(Expr(:boundscheck, true))::Bool
βββββ goto #43 if not %78
39 ββ %80 = Base.getfield(%77, :size)::Tuple{Int64}
β %81 = $(Expr(:boundscheck, true))::Bool
β %82 = Base.getfield(%80, 1, %81)::Int64
β %83 = Base.bitcast(UInt64, %82)::UInt64
β %84 = Base.ult_int(0x0000000000000000, %83)::Bool
β %85 = Base.bitcast(UInt64, %82)::UInt64
β %86 = Base.ult_int(0x000000000003344f, %85)::Bool
β %87 = Base.and_int(%84, %86)::Bool
β %88 = Base.or_int(false, %87)::Bool
βββββ goto #41 if not %88
40 ββ goto #42
41 ββ invoke Base.throw_boundserror(%77::Vector{Float64}, (1:210000,)::Tuple{UnitRange{Int64}})::Union{}
βββββ unreachable
42 ββ nothing::Nothing
43 ββ goto #44
44 ββ goto #45
45 ββ goto #46
46 ββ goto #47
47 ββ %98 = $(Expr(:boundscheck, true))::Bool
βββββ goto #49 if not %98
48 ββ nothing::Nothing
49 ββ %101 = %new(SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, %77, (50001:210000,), 50000, 1)::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}
βββββ goto #50
50 ββ goto #51
51 ββ %104 = %new(Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, %101, (10000, 4, 4), ())::Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}
βββββ goto #52
52 ββ goto #53
53 ββ %107 = Base.getfield(h, :size)::Tuple{Int64}
β %108 = $(Expr(:boundscheck, true))::Bool
β %109 = Base.getfield(%107, 1, %108)::Int64
β %110 = %new(Base.OneTo{Int64}, %109)::Base.OneTo{Int64}
β %111 = Core.tuple(%110)::Tuple{Base.OneTo{Int64}}
β %112 = (%109 === 10000)::Bool
βββββ goto #55 if not %112
54 ββ goto #56
55 ββ goto #56
56 ββ %116 = Ο (#54 => %112, #55 => false)::Bool
βββββ goto #58 if not %116
57 ββ goto #59
58 ββ %119 = invoke Base.Broadcast.DimensionMismatch("array could not be broadcast to match destination"::String)::DimensionMismatch
β Base.Broadcast.throw(%119)::Union{}
βββββ unreachable
59 ββ goto #60
60 ββ goto #61
61 ββ %124 = Base.getfield(h, :size)::Tuple{Int64}
β %125 = $(Expr(:boundscheck, true))::Bool
β %126 = Base.getfield(%124, 1, %125)::Int64
β %127 = %new(Base.OneTo{Int64}, %126)::Base.OneTo{Int64}
β %128 = Core.tuple(%127)::Tuple{Base.OneTo{Int64}}
β %129 = (%126 === %109)::Bool
β %130 = (%129 === false)::Bool
βββββ goto #63 if not %130
62 ββ goto #64
63 ββ goto #64
64 ββ %134 = Ο (#62 => false, #63 => true)::Bool
βββββ goto #65
65 ββ goto #153 if not %134
66 ββ %137 = (%126 === 10000)::Bool
β %138 = (%137 === false)::Bool
βββββ goto #68 if not %138
67 ββ goto #69
68 ββ goto #69
69 ββ %142 = Ο (#67 => false, #68 => true)::Bool
βββββ goto #70
70 ββ goto #72 if not %142
71 ββ invoke Base.copyto!(h::Vector{Float64}, %42::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})::Any
βββββ goto #154
72 ββ %147 = Base.getfield(h, :size)::Tuple{Int64}
β %148 = $(Expr(:boundscheck, true))::Bool
β %149 = Base.getfield(%147, 1, %148)::Int64
β %150 = (%149 === 0)::Bool
β %151 = Base.not_int(%150)::Bool
βββββ goto #85 if not %151
73 ββ %153 = Base.getfield(h, :ref)::MemoryRef{Float64}
β %154 = Base.getfield(%153, :mem)::Memory{Float64}
β %155 = $(Expr(:foreigncall, :(:jl_genericmemory_owner), Any, svec(Any), 0, :(:ccall), :(%154)))::Any
β %156 = (%155 isa Memory{Float64})::Bool
βββββ goto #75 if not %156
74 ββ %158 = Ο (%155, Memory{Float64})
βββββ goto #76
75 ββ nothing::Nothing
76 ββ %161 = Ο (#74 => %158, #75 => %154)::Memory{Float64}
β %162 = Base.getfield(%161, :ptr)::Ptr{Nothing}
β %163 = Base.bitcast(Ptr{Float64}, %162)::Ptr{Float64}
β %164 = Core.bitcast(Core.UInt, %163)::UInt64
βββββ goto #77
77 ββ goto #78
78 ββ %167 = Base.getfield(%18, :ref)::MemoryRef{Float64}
β %168 = Base.getfield(%167, :mem)::Memory{Float64}
β %169 = $(Expr(:foreigncall, :(:jl_genericmemory_owner), Any, svec(Any), 0, :(:ccall), :(%168)))::Any
β %170 = (%169 isa Memory{Float64})::Bool
βββββ goto #80 if not %170
79 ββ %172 = Ο (%169, Memory{Float64})
βββββ goto #81
80 ββ nothing::Nothing
81 ββ %175 = Ο (#79 => %172, #80 => %168)::Memory{Float64}
β %176 = Base.getfield(%175, :ptr)::Ptr{Nothing}
β %177 = Base.bitcast(Ptr{Float64}, %176)::Ptr{Float64}
β %178 = Core.bitcast(Core.UInt, %177)::UInt64
βββββ goto #82
82 ββ goto #83
83 ββ goto #84
84 ββ %182 = (%164 === %178)::Bool
β %183 = Base.not_int(%182)::Bool
β %184 = Base.not_int(%183)::Bool
βββββ goto #86
85 ββ goto #86
86 ββ %187 = Ο (#84 => %184, #85 => false)::Bool
βββββ goto #107 if not %187
87 ββ %189 = $(Expr(:foreigncall, :(:jl_alloc_genericmemory), Ref{Memory{Float64}}, svec(Any, Int64), 0, :(:ccall), Memory{Float64}, 10000, 10000))::Memory{Float64}
β %190 = Core.memoryrefnew(%189)::MemoryRef{Float64}
β %191 = %new(Vector{Float64}, %190, (10000,))::Vector{Float64}
β %192 = $(Expr(:boundscheck, true))::Bool
βββββ goto #89 if not %192
88 ββ nothing::Nothing
89 ββ %195 = $(Expr(:boundscheck, false))::Bool
βββββ goto #91 if not %195
90 ββ nothing::Nothing
91 ββ goto #92
92 ββ goto #93
93 ββ goto #94
94 ββ goto #95
95 ββ goto #96
96 ββ goto #97
97 ββ goto #98
98 ββ goto #99
99 ββ goto #100
100 β %207 = %new(SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, %18, (1:10000,), 0, 1)::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}
βββββ goto #101
101 β goto #102
102 β goto #103
103 β goto #104
104 β goto #105
105 β invoke Base.copyto!(%191::Vector{Float64}, %207::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})::Any
β %214 = %new(SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, %191, (1:10000,), 0, 1)::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}
βββββ goto #106
106 β goto #108
107 β goto #108
108 β %218 = Ο (#106 => %191, #107 => %18)::Vector{Float64}
β %219 = Ο (#106 => %214, #107 => %42)::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}
βββββ goto #109
109 β %221 = %new(Base.Broadcast.Extruded{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Bool}, Tuple{Int64}}, %219, (true,), (1,))::Base.Broadcast.Extruded{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Bool}, Tuple{Int64}}
βββββ goto #110
110 β %223 = Core.tuple(%221)::Tuple{Base.Broadcast.Extruded{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Bool}, Tuple{Int64}}}
βββββ goto #111
111 β %225 = %new(Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Base.Broadcast.Extruded{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Bool}, Tuple{Int64}}}}, nothing, identity, %223, %111)::Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Base.Broadcast.Extruded{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Bool}, Tuple{Int64}}}}
βββββ goto #112
112 β %227 = Base.slt_int(0, %109)::Bool
βββββ goto #151 if not %227
113 β nothing::Nothing
114 β %230 = Ο (#113 => 0, #150 => %317)::Int64
β %231 = Base.slt_int(%230, %109)::Bool
βββββ goto #151 if not %231
115 β %233 = Base.add_int(%230, 1)::Int64
β %234 = $(Expr(:boundscheck, false))::Bool
βββββ goto #120 if not %234
116 β %236 = Core.tuple(%233)::Tuple{Int64}
β %237 = Base.sub_int(%233, 1)::Int64
β %238 = Base.bitcast(UInt64, %237)::UInt64
β %239 = Base.bitcast(UInt64, %109)::UInt64
β %240 = Base.ult_int(%238, %239)::Bool
βββββ goto #118 if not %240
117 β goto #119
118 β invoke Base.throw_boundserror(%110::Base.OneTo{Int64}, %236::Tuple{Int64})::Union{}
βββββ unreachable
119 β nothing::Nothing
120 β goto #121
121 β goto #122
122 β goto #123
123 β %249 = $(Expr(:boundscheck, false))::Bool
βββββ goto #130 if not %249
124 β goto #125
125 β %252 = Base.sub_int(%233, 1)::Int64
β %253 = Base.bitcast(UInt64, %252)::UInt64
β %254 = Base.bitcast(UInt64, %109)::UInt64
β %255 = Base.ult_int(%253, %254)::Bool
β %256 = Base.and_int(%255, true)::Bool
βββββ goto #126
126 β goto #128 if not %256
127 β goto #129
128 β %260 = Core.tuple(%233)::Tuple{Int64}
β invoke Base.throw_boundserror(%225::Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Base.Broadcast.Extruded{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Bool}, Tuple{Int64}}}}, %260::Tuple{Int64})::Union{}
βββββ unreachable
129 β nothing::Nothing
130 β %264 = $(Expr(:boundscheck, false))::Bool
βββββ goto #135 if not %264
131 β %266 = Core.tuple(%233)::Tuple{Int64}
β %267 = Base.sub_int(%233, 1)::Int64
β %268 = Base.bitcast(UInt64, %267)::UInt64
β %269 = Base.ult_int(%268, 0x0000000000002710)::Bool
βββββ goto #133 if not %269
132 β goto #134
133 β invoke Base.throw_boundserror(%219::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, %266::Tuple{Int64})::Union{}
βββββ unreachable
134 β nothing::Nothing
135 β %275 = Base.add_int(0, %233)::Int64
β %276 = $(Expr(:boundscheck, false))::Bool
βββββ goto #139 if not %276
136 β %278 = Base.sub_int(%275, 1)::Int64
β %279 = Base.bitcast(Base.UInt, %278)::UInt64
β %280 = Base.getfield(%218, :size)::Tuple{Int64}
β %281 = $(Expr(:boundscheck, true))::Bool
β %282 = Base.getfield(%280, 1, %281)::Int64
β %283 = Base.bitcast(Base.UInt, %282)::UInt64
β %284 = Base.ult_int(%279, %283)::Bool
βββββ goto #138 if not %284
137 β goto #139
138 β %287 = Core.tuple(%275)::Tuple{Int64}
β invoke Base.throw_boundserror(%218::Vector{Float64}, %287::Tuple{Int64})::Union{}
βββββ unreachable
139 β %290 = Base.getfield(%218, :ref)::MemoryRef{Float64}
β %291 = Base.memoryrefnew(%290, %275, false)::MemoryRef{Float64}
β %292 = Base.memoryrefget(%291, :not_atomic, false)::Float64
βββββ goto #140
140 β goto #141
141 β goto #142
142 β goto #143
143 β goto #144
144 β goto #145
145 β %299 = $(Expr(:boundscheck, false))::Bool
βββββ goto #149 if not %299
146 β %301 = Base.sub_int(%233, 1)::Int64
β %302 = Base.bitcast(UInt64, %301)::UInt64
β %303 = Base.getfield(h, :size)::Tuple{Int64}
β %304 = $(Expr(:boundscheck, true))::Bool
β %305 = Base.getfield(%303, 1, %304)::Int64
β %306 = Base.bitcast(UInt64, %305)::UInt64
β %307 = Base.ult_int(%302, %306)::Bool
βββββ goto #148 if not %307
147 β goto #149
148 β %310 = Core.tuple(%233)::Tuple{Int64}
β invoke Base.throw_boundserror(h::Vector{Float64}, %310::Tuple{Int64})::Union{}
βββββ unreachable
149 β %313 = Base.getfield(h, :ref)::MemoryRef{Float64}
β %314 = Base.memoryrefnew(%313, %233, false)::MemoryRef{Float64}
β Base.memoryrefset!(%314, %292, :not_atomic, false)::Float64
βββββ goto #150
150 β %317 = Base.add_int(%230, 1)::Int64
β $(Expr(:loopinfo, Symbol("julia.simdloop"), nothing))::Nothing
βββββ goto #114
151 β goto #152
152 β goto #154
153 β invoke Base.Broadcast.throwdm(%128::Tuple{Base.OneTo{Int64}}, %111::Tuple{Base.OneTo{Int64}})::Union{}
βββββ unreachable
154 β goto #155
155 β goto #156
156 β goto #157
157 β %327 = Base.getfield(x, :size)::Tuple{Int64}
β %328 = $(Expr(:boundscheck, true))::Bool
β %329 = Base.getfield(%327, 1, %328)::Int64
β %330 = (4 === %329)::Bool
β %331 = Base.not_int(%330)::Bool
βββββ goto #159 if not %331
158 β %333 = Base.getfield(x, :size)::Tuple{Int64}
β %334 = $(Expr(:boundscheck, true))::Bool
β %335 = Base.getfield(%333, 1, %334)::Int64
β %336 = Core.tuple("second dimension of A, ", 4, ", does not match length of x, ", %335)::Tuple{String, Int64, String, Int64}
β %337 = %new(Base.LazyString, %336, Base.nothing)::LazyString
β %338 = %new(Base.DimensionMismatch, %337)::DimensionMismatch
β LinearAlgebra.throw(%338)::Union{}
βββββ unreachable
159 β %341 = Base.getfield(h, :size)::Tuple{Int64}
β %342 = $(Expr(:boundscheck, true))::Bool
β %343 = Base.getfield(%341, 1, %342)::Int64
β %344 = (10000 === %343)::Bool
β %345 = Base.not_int(%344)::Bool
βββββ goto #161 if not %345
160 β %347 = Base.getfield(h, :size)::Tuple{Int64}
β %348 = $(Expr(:boundscheck, true))::Bool
β %349 = Base.getfield(%347, 1, %348)::Int64
β %350 = Core.tuple("first dimension of A, ", 10000, ", does not match length of y, ", %349)::Tuple{String, Int64, String, Int64}
β %351 = %new(Base.LazyString, %350, Base.nothing)::LazyString
β %352 = %new(Base.DimensionMismatch, %351)::DimensionMismatch
β LinearAlgebra.throw(%352)::Union{}
βββββ unreachable
161 β invoke LinearAlgebra.BLAS.gemv!('N'::Char, 1.0::Float64, %74::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, x::Vector{Float64}, 1.0::Float64, h::Vector{Float64})::Vector{Float64}
βββββ goto #162
162 β goto #163
163 β goto #164
164 β goto #165
165 β nothing::Nothing
β nothing::Nothing
β %362 = Base.getfield(x, :size)::Tuple{Int64}
β %363 = $(Expr(:boundscheck, true))::Bool
β %364 = Base.getfield(%362, 1, %363)::Int64
β %365 = (%364 === 4)::Bool
βββββ goto #182 if not %365
166 β nothing::Nothing
β %368 = Base.getfield(x, :size)::Tuple{Int64}
β %369 = $(Expr(:boundscheck, true))::Bool
β %370 = Base.getfield(%368, 1, %369)::Int64
β %371 = (%370 === 4)::Bool
βββββ goto #181 if not %371
167 β nothing::Nothing
β %374 = Base.getfield(h, :size)::Tuple{Int64}
β %375 = $(Expr(:boundscheck, true))::Bool
β %376 = Base.getfield(%374, 1, %375)::Int64
β %377 = Base.getfield(h, :size)::Tuple{Int64}
β %378 = $(Expr(:boundscheck, true))::Bool
β %379 = Base.getfield(%377, 1, %378)::Int64
β %380 = (10000 === %379)::Bool
βββββ goto #180 if not %380
168 β nothing::Nothing
β %383 = Base.sle_int(1, %376)::Bool
βββββ goto #170 if not %383
169 β goto #171
170 β goto #171
171 β %387 = Ο (#169 => %376, #170 => 0)::Int64
β %388 = %new(UnitRange{Int64}, 1, %387)::UnitRange{Int64}
βββββ goto #172
172 β goto #173
173 β goto #174
174 β %392 = Core.tuple(%388)::Tuple{UnitRange{Int64}}
βββββ goto #175
175 β %394 = Base.sub_int(%387, 1)::Int64
β %395 = Base.add_int(1, %394)::Int64
β %396 = invoke Base.Threads.nthreads()::Int64
β %397 = Base.mul_int(%395, 16)::Int64
β %398 = Base.checked_sdiv_int(%397, 262144)::Int64
β %399 = Base.slt_int(0, %397)::Bool
β %400 = (%399 === true)::Bool
β %401 = Base.mul_int(%398, 262144)::Int64
β %402 = (%401 === %397)::Bool
β %403 = Base.not_int(%402)::Bool
β %404 = Base.and_int(%400, %403)::Bool
β %405 = Core.zext_int(Core.Int64, %404)::Int64
β %406 = Core.and_int(%405, 1)::Int64
β %407 = Base.add_int(%398, %406)::Int64
β %408 = Base.slt_int(%407, %396)::Bool
β %409 = Core.ifelse(%408, %407, %396)::Int64
β %410 = Base.slt_int(%395, %409)::Bool
β %411 = Core.ifelse(%410, %395, %409)::Int64
β %412 = Base.slt_int(1, %411)::Bool
βββββ goto #177 if not %412
176 β %414 = Core.tuple(h, %104, x)::Tuple{Vector{Float64}, Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{Float64}}
β invoke Tullio.thread_halves(PIXIE.var"#ππΈπ!#19"()::PIXIE.var"#ππΈπ!#19", Vector{Float64}::Type{Vector{Float64}}, %414::Tuple{Vector{Float64}, Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{Float64}}, %392::Tuple{UnitRange{Int64}}, (1:4, 1:4)::Tuple{UnitRange{Int64}, UnitRange{Int64}}, %411::Int64, true::Bool)::Any
βββββ goto #178
177 β %417 = Core.tuple(h, %104, x)::Tuple{Vector{Float64}, Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{Float64}}
β %418 = Tullio.tile_halves::typeof(Tullio.tile_halves)
βββββ invoke %418(PIXIE.var"#ππΈπ!#19"()::PIXIE.var"#ππΈπ!#19", Vector{Float64}::Type{Vector{Float64}}, %417::Tuple{Vector{Float64}, Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Vector{Float64}}, %392::Tuple{UnitRange{Int64}}, (1:4, 1:4)::Tuple{UnitRange{Int64}, UnitRange{Int64}}, true::Bool, true::Bool)::Nothing
178 β goto #179
179 β return PIXIE.nothing
180 β (throw)("range of index i must agree")::Union{}
βββββ unreachable
181 β (throw)("range of index k must agree")::Union{}
βββββ unreachable
182 β (throw)("range of index j must agree")::Union{}
βββββ unreachable
) => Nothing
julia>
Following the suggestions in this thread, I tried to define f as a closure, but the issue persists.
All the other backends that support Cache() fail as well.
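Option (b) from the Enzyme error message (runtime activity) can presumably be requested through DifferentiationInterface by setting the mode of the backend. Below is a minimal sketch on a toy function: `toy_loss!` and `toy_f` are hypothetical stand-ins for `npll!` and `f`, and whether this setting resolves the failure inside `deposit!` is untested here.

```julia
using DifferentiationInterface
import Enzyme

# Hypothetical stand-in for `npll!`: writes into a preallocated buffer,
# then reduces it to a scalar.
function toy_loss!(buf::Vector{Float64}, p::Vector{Float64}, x::Vector{Float64})
    buf .= p .* x
    return sum(abs2, buf)
end

# Same argument order as `f` in the session above: the differentiated
# parameters come first, followed by the contexts.
toy_f(p, buf, x) = toy_loss!(buf, p, x)

# Reverse mode with runtime activity enabled, as suggested by the error text.
backend = AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))

p = [1.0, 2.0, 3.0]
x = [0.5, 1.0, 2.0]

# `Cache` marks the scratch buffer, `Constant` marks data that is not differentiated.
g = gradient(toy_f, backend, p, Cache(zeros(3)), Constant(x))
```

For this toy function the gradient is `2 .* p .* x .^ 2`, which makes the result easy to check by hand before trying the same backend on the real model. As the error message notes, runtime activity may slightly reduce performance.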
Do you have any suggestions on how to make the design more AD-friendly while still avoiding large temporary vectors during the computation? Benchmarks confirm that the cache gives considerable performance gains.
Thank you all for the help!