Differentiating a function involving the adjoint of a LinearMap (LinearMaps.jl and Zygote.jl)

Hi! I’m trying to write differentiable code using LinearMaps.jl, but I am running into an issue that I haven’t been able to fix. Below is a minimal working example of the problem.

using Zygote
using LinearMaps

struct Xop <: LinearMap{Float64}
    a::Float64
end

const XopTranspose = LinearMaps.TransposeMap{<:Any, <:Xop}  # alias for the transpose wrapper around Xop

Base.:(*)(X::Xop, u::AbstractVector) = X.a * u
Base.:(*)(Xt::XopTranspose, u::AbstractVector) = Xt.lmap.a .* u
Base.size(X::Xop) = (4, 4)

X = Xop(3.4)

f(a) = sum(X * a)  # simple scalar-valued function involving X *
g(a) = sum(X' * a) # simple scalar-valued function involving X' *

println(gradient(f, collect(1:4))) # this works
println(gradient(g, collect(1:4))) # this doesn't

First, I subtype the abstract type LinearMap to create my own operator (which is just a uniform scaling). I also define a * method for its transpose, using an alias for LinearMaps.TransposeMap wrapping an Xop. I then create an instance of the operator and define a function f involving a matrix-vector product and a function g involving a matrix-transpose-vector product. The function f differentiates just fine, but g throws an error, as shown in the log:

(Fill(3.4, 4),)
MethodError: no method matching adjoint(::NamedTuple{(:lmap,),Tuple{NamedTuple{(:a,),Tuple{Float64}}}})
Closest candidates are:
  adjoint(::Missing) at missing.jl:100
  adjoint(::ChainRulesCore.NotImplemented) at /home/gaurav/.julia/packages/ChainRulesCore/1LqRD/src/differentials/notimplemented.jl:54
  adjoint(::IRTools.Inner.CFG) at /home/gaurav/.julia/packages/IRTools/46viC/src/passes/passes.jl:29
  ...

Stacktrace:
 [1] (::Zygote.var"#back#722")(::NamedTuple{(:lmap,),Tuple{NamedTuple{(:a,),Tuple{Float64}}}}) at /home/gaurav/.julia/packages/Zygote/ajuwN/src/lib/array.jl:392
 [2] (::Zygote.var"#2930#back#723"{Zygote.var"#back#722"})(::NamedTuple{(:lmap,),Tuple{NamedTuple{(:a,),Tuple{Float64}}}}) at /home/gaurav/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [3] g at ./In[8]:17 [inlined]
 [4] (::typeof(∂(g)))(::Float64) at /home/gaurav/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [5] (::Zygote.var"#50#51"{typeof(∂(g))})(::Float64) at /home/gaurav/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
 [6] gradient(::Function, ::Array{Int64,1}) at /home/gaurav/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:76
 [7] top-level scope at In[8]:20
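The argument in the MethodError is Zygote's structural tangent for the transposed map: a NamedTuple that mirrors the lmap field of LinearMaps.TransposeMap, which in turn mirrors the a field of Xop. One rough way to see this is to differentiate with respect to the wrapper itself. The sketch assumes the definitions above and uses transpose(X) so that the XopTranspose method is certainly hit. Per the stack trace, the pullback at Zygote's array.jl:392 then calls adjoint on a tangent of this shape, and no such method exists.

```julia
using Zygote
using LinearMaps

struct Xop <: LinearMap{Float64}
    a::Float64
end

const XopTranspose = LinearMaps.TransposeMap{<:Any, <:Xop}

Base.:(*)(X::Xop, u::AbstractVector) = X.a * u
Base.:(*)(Xt::XopTranspose, u::AbstractVector) = Xt.lmap.a .* u
Base.size(X::Xop) = (4, 4)

X = Xop(3.4)
u = collect(1:4)

# Differentiate with respect to the transposed map itself, not the vector.
# Zygote represents the gradient of a struct as a NamedTuple of its fields,
# so this is nested: one level for TransposeMap.lmap, one for Xop.a.
dXt = gradient(Xt -> sum(Xt * u), transpose(X))[1]
println(dXt)
```

The value in dXt.lmap.a is the derivative of sum(a .* u) with respect to the scalar a, that is, sum(u).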

Does anyone have any idea how I can fix this issue?
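One possible workaround, sketched here and not checked against the Zygote and LinearMaps versions in the trace above: X is a fixed operator and only the vector is being differentiated, so the transposed map can be built once outside the differentiated function. Zygote then never differentiates the call to adjoint. The alias XopWrapped is hypothetical and introduced here. It covers both LinearMaps.TransposeMap and LinearMaps.AdjointMap, so the sketch does not depend on which wrapper X' returns for a real-valued map.

```julia
using Zygote
using LinearMaps

struct Xop <: LinearMap{Float64}
    a::Float64
end

# Hypothetical alias: for a real-valued map, transpose and adjoint coincide,
# so the same method serves whichever wrapper X' produces.
const XopWrapped = Union{LinearMaps.TransposeMap{<:Any, <:Xop},
                         LinearMaps.AdjointMap{<:Any, <:Xop}}

Base.:(*)(X::Xop, u::AbstractVector) = X.a * u
Base.:(*)(Xt::XopWrapped, u::AbstractVector) = Xt.lmap.a .* u
Base.size(X::Xop) = (4, 4)

const X = Xop(3.4)
const Xt = X'  # build the transposed map once, outside the differentiated code

g(a) = sum(Xt * a)  # Zygote only sees the product, not the call to adjoint

println(gradient(g, collect(1:4)))
```

This only helps when X itself is not being differentiated. If gradients with respect to the field a are also needed, this shortcut is not enough.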

The following code, which uses ForwardDiff.jl instead of Zygote.jl, works for me:

using ForwardDiff
using LinearMaps

struct Xop <: LinearMap{Float64}
    a::Float64
end

const XopTranspose = LinearMaps.TransposeMap{<:Any, <:Xop}  # alias for the transpose wrapper around Xop

Base.:(*)(X::Xop, u::AbstractVector) = X.a * u
Base.:(*)(Xt::XopTranspose, u::AbstractVector) = Xt.lmap.a .* u
Base.size(X::Xop) = (4, 4)

X = Xop(3.4)

f(a) = sum(X * a)  # simple scalar-valued function involving X *
g(a) = sum(X' * a) # simple scalar-valued function involving X' *

println(ForwardDiff.gradient(f, collect(1:4))) # this works
println(ForwardDiff.gradient(g, collect(1:4))) # this works, too

Looks like a problem in Zygote.jl to me?
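If the transposed map has to be formed inside the differentiated function, another option is to give Zygote its own rule for adjoint(::Xop). This is again only a sketch. The idea is that Zygote.@adjoint defines a more specific pullback, so the generic rule from the trace (the one calling adjoint on the NamedTuple tangent) is not used for Xop. The rule treats X as a constant and returns nothing as its gradient. That discards any derivative with respect to the field a, which is acceptable only when, as in g, you differentiate with respect to the vector. XopWrapped is a hypothetical alias introduced here.

```julia
using Zygote
using LinearMaps

struct Xop <: LinearMap{Float64}
    a::Float64
end

# Hypothetical alias covering both wrappers, since for a real map X' may return either.
const XopWrapped = Union{LinearMaps.TransposeMap{<:Any, <:Xop},
                         LinearMaps.AdjointMap{<:Any, <:Xop}}

Base.:(*)(X::Xop, u::AbstractVector) = X.a * u
Base.:(*)(Xt::XopWrapped, u::AbstractVector) = Xt.lmap.a .* u
Base.size(X::Xop) = (4, 4)

# Custom pullback for X': the forward pass builds the wrapper as usual, and the
# backward pass ignores the incoming structural tangent Δ instead of calling
# adjoint on it, reporting no gradient for the operator.
Zygote.@adjoint Base.adjoint(X::Xop) = X', Δ -> (nothing,)

X = Xop(3.4)

g(a) = sum(X' * a)  # same g as in the question

println(gradient(g, collect(1:4)))
```

If gradients with respect to the operator's parameters are needed later, the rule would have to map the wrapper's tangent back to a tangent for Xop instead of returning nothing.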