I’m getting a strange error when using `OMEinsum`, `CuArrays` and `Zygote` together. Here is the MRE:
```julia
using CuArrays
using Zygote
using OMEinsum

mutable struct S{T}
    c::T
end

function f(α, s, H)
    # c[d,b] = Σ_t α[t,b] * H[d,t,b], written into the mutable field s.c
    @ein s.c[d,b] := α[t,b] * H[d,t,b]
    sum(s.c)
end

α = rand(3, 2)
H = rand(5, 3, 2)
s = S(rand(5, 2))
```
```julia
julia> gradient(x -> f(x, s, H), α) # works
([1.8856141333787124 3.177915606920301; 1.7898891125983922 3.2721594554280564; 1.959869250310523 2.247906876515321],)
```
Now if I call `f` with `cu`-ified inputs, it errors:
```julia
julia> gradient(x -> f(x, S(cu(s.c)), cu(H)), cu(α)) # errors
ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::CuArray{Float32,2,Nothing})
```
Yet the forward pass works:

```julia
julia> f(cu(α), S(cu(s.c)), cu(H))
8.397266f0
```
Does anyone have an explanation for what is happening? And what would be a workaround here?
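One way to narrow this down is to compute the contraction into a local variable instead of writing it into the mutable field `s.c`. This is a sketch under an unverified assumption: that the failure comes from Zygote differentiating through the write to `s.c`, since the `Base.RefValue{Any}` in the error looks like a gradient for a mutable struct. The helper name `f_nomut` is hypothetical. On the CPU, its gradient should equal `H` summed over its first dimension, because c[d,b] = Σ_t α[t,b] H[d,t,b] gives ∂(Σ c)/∂α[t,b] = Σ_d H[d,t,b]:

```julia
using Zygote
using OMEinsum

# Hypothetical variant of f that does not mutate a struct:
# c[d,b] = Σ_t α[t,b] * H[d,t,b] is stored in a local variable.
function f_nomut(α, H)
    @ein c[d,b] := α[t,b] * H[d,t,b]
    sum(c)
end

α = rand(3, 2)
H = rand(5, 3, 2)

# CPU check: d(sum(c))/dα[t,b] should equal Σ_d H[d,t,b]
g = gradient(x -> f_nomut(x, H), α)[1]
expected = dropdims(sum(H; dims=1); dims=1)
println(isapprox(g, expected))

# GPU comparison against the failing call (needs a CUDA device, so left commented):
# using CuArrays
# gradient(x -> f_nomut(x, cu(H)), cu(α))
```

If the GPU call of `f_nomut` succeeds while the original `f` fails, that would suggest the mutation of `s.c` is the problem rather than OMEinsum's GPU path. This is only a hypothesis to test.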