Hi! I’m trying to write differentiable code (with Zygote.jl) using LinearMaps.jl. However, I am running into an issue that I haven’t been able to fix. Below is a minimal working example of the problem:
```julia
using Zygote
using LinearMaps

# Custom operator: a uniform scaling by `a`
struct Xop <: LinearMap{Float64}
    a::Float64
end

# Type alias for the transpose of an Xop
const XopTranspose = LinearMaps.TransposeMap{<:Any, <:Xop}

Base.:(*)(X::Xop, u::AbstractVector) = X.a * u
Base.:(*)(Xt::XopTranspose, u::AbstractVector) = Xt.lmap.a .* u
Base.size(X::Xop) = (4, 4)

X = Xop(3.4)
f(a) = sum(X * a)   # simple scalar-valued function involving X *
g(a) = sum(X' * a)  # simple scalar-valued function involving X' *

println(gradient(f, collect(1:4)))  # this works
println(gradient(g, collect(1:4)))  # this doesn't
```
First, I subtype `LinearMap` to create my own operator (which is just a uniform scaling). I also define a type alias for its transpose (which equals its adjoint, since the operator is real) and its action on a vector. I make an instance of the operator and define a function `f` involving a matrix-vector product and a function `g` involving a matrix-transpose-vector product. `f` differentiates just fine, but `g` throws an error, as shown in the log:
```
(Fill(3.4, 4),)
MethodError: no method matching adjoint(::NamedTuple{(:lmap,),Tuple{NamedTuple{(:a,),Tuple{Float64}}}})
Closest candidates are:
  adjoint(::Missing) at missing.jl:100
  adjoint(::ChainRulesCore.NotImplemented) at /home/gaurav/.julia/packages/ChainRulesCore/1LqRD/src/differentials/notimplemented.jl:54
  adjoint(::IRTools.Inner.CFG) at /home/gaurav/.julia/packages/IRTools/46viC/src/passes/passes.jl:29
  ...

Stacktrace:
 [1] (::Zygote.var"#back#722")(::NamedTuple{(:lmap,),Tuple{NamedTuple{(:a,),Tuple{Float64}}}}) at /home/gaurav/.julia/packages/Zygote/ajuwN/src/lib/array.jl:392
 [2] (::Zygote.var"#2930#back#723"{Zygote.var"#back#722"})(::NamedTuple{(:lmap,),Tuple{NamedTuple{(:a,),Tuple{Float64}}}}) at /home/gaurav/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [3] g at ./In[8]:17 [inlined]
 [4] (::typeof(∂(g)))(::Float64) at /home/gaurav/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [5] (::Zygote.var"#50#51"{typeof(∂(g))})(::Float64) at /home/gaurav/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
 [6] gradient(::Function, ::Array{Int64,1}) at /home/gaurav/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:76
 [7] top-level scope at In[8]:20
```
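As far as I can tell from frame [1] of the trace, the error is raised inside the pullback Zygote attaches to `adjoint` (in `lib/array.jl`), which applies `adjoint` to the incoming cotangent. For `X'` that cotangent is a `NamedTuple` mirroring the map's fields rather than an array. Here is a package-free sketch of just that step. `generic_adjoint_pullback` is a hypothetical stand-in I wrote for illustration, not Zygote's actual code, and the cotangent value `10.0` is made up; only its shape matches the log.

```julia
using LinearAlgebra  # provides `adjoint` for arrays

# Hypothetical stand-in for the `adjoint` pullback in frame [1]: it simply
# applies `adjoint` to whatever cotangent Δ it receives.
generic_adjoint_pullback(Δ) = (adjoint(Δ),)

# Case 1: the primal is a plain matrix, so Δ is an array and this works.
Δ_matrix = [1.0 2.0; 3.0 4.0]
ok = generic_adjoint_pullback(Δ_matrix)

# Case 2: the primal is the map returned by `X'`, so Δ is a structural
# NamedTuple with the same shape as in the log (field `lmap`, then field `a`).
# The value 10.0 is illustrative only.
Δ_struct = (lmap = (a = 10.0,),)

# `adjoint` has no method for NamedTuple, so this reproduces the MethodError.
err = try
    generic_adjoint_pullback(Δ_struct)
    nothing
catch e
    e
end
println(typeof(err))
```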
Does anyone have any idea how I can fix this issue?