# Differentiating a function involving the adjoint of a LinearMap (LinearMaps.jl and Zygote.jl)

Hi! I’m trying to write differentiable code using `LinearMaps.jl`, but I am running into an issue that I haven’t been able to fix. Below is a minimal working example of the problem.

```julia
using Zygote
using LinearMaps

struct Xop <: LinearMap{Float64}
    a::Float64
end

const XopTranspose = LinearMaps.TransposeMap{<:Any, <:Xop}

Base.:(*)(X::Xop, u::AbstractVector) = X.a * u
Base.:(*)(Xt::XopTranspose, u::AbstractVector) = Xt.lmap.a .* u
Base.size(X::Xop) = (4, 4)

X = Xop(3.4)

f(a) = sum(X * a)  # simple scalar-valued function involving X *
g(a) = sum(X' * a) # simple scalar-valued function involving X' *

gradient(f, ones(4))  # works
gradient(g, ones(4))  # throws the error in the log below
```

First, I subtype `LinearMap` to create my own operator (which is just a uniform scaling). I also define a type alias, `XopTranspose`, for the transposed operator and give it its own `*` method. I then make an instance of the operator and define a function `f` involving a matrix-vector product and a function `g` involving a matrix-transpose-vector product. The function `f` differentiates just fine, but `g` throws an error, as shown in the log:

```
(Fill(3.4, 4),)
Closest candidates are:
...

Stacktrace:
[1] (::Zygote.var"#back#722")(::NamedTuple{(:lmap,),Tuple{NamedTuple{(:a,),Tuple{Float64}}}}) at /home/gaurav/.julia/packages/Zygote/ajuwN/src/lib/array.jl:392
[3] g at ./In[8]:17 [inlined]
[4] (::typeof(∂(g)))(::Float64) at /home/gaurav/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[5] (::Zygote.var"#50#51"{typeof(∂(g))})(::Float64) at /home/gaurav/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
[7] top-level scope at In[8]:20
```
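My reading of the trace, which is an interpretation and not confirmed: frame [1] is a pullback in Zygote's `lib/array.jl` that receives a `NamedTuple{(:lmap,), ...}`, i.e. Zygote's structural gradient with respect to the wrapper object that `X'` returns. The sketch below computes that gradient directly. It assumes that `X'` of a real-valued `LinearMap` is a `LinearMaps.TransposeMap` (the `@assert` checks this), and it differentiates with respect to the wrapper so Zygote never has to go through `adjoint`:

```julia
using Zygote
using LinearMaps

struct Xop <: LinearMap{Float64}
    a::Float64
end

const XopTranspose = LinearMaps.TransposeMap{<:Any, <:Xop}

Base.:(*)(Xt::XopTranspose, u::AbstractVector) = Xt.lmap.a .* u
Base.size(X::Xop) = (4, 4)

# Build the transposed operator outside the differentiated code, so that
# Zygote never differentiates the call to `adjoint` itself.
Xt = Xop(3.4)'
@assert Xt isa XopTranspose  # assumption: real-valued maps are wrapped in TransposeMap

# Gradient of a scalar function with respect to the wrapper itself. Zygote
# represents it structurally, as a NamedTuple mirroring the struct's fields.
dXt = gradient(T -> sum(T * ones(4)), Xt)[1]
println(dXt)
```

If this runs, the printed value has the same `(lmap = (a = ...,),)` shape as the argument in frame [1], which suggests the generic reverse rule for `adjoint` is what cannot handle it.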

Does anyone have any idea how I can fix this issue?
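One possible workaround, sketched under assumptions not checked against every Zygote and LinearMaps version: if only the gradient with respect to the input vector is needed, a custom `Zygote.@adjoint` for `Base.adjoint(::Xop)` can take precedence over Zygote's generic rule for `adjoint` and simply report no gradient for `X`. The trade-off is that any derivative with respect to the operator's scale `X.a` is discarded, so this is not suitable if you also want that gradient. Making `XopTranspose` and `X` `const` is an illustrative choice, not something the original example requires.

```julia
using Zygote
using LinearMaps

struct Xop <: LinearMap{Float64}
    a::Float64
end

const XopTranspose = LinearMaps.TransposeMap{<:Any, <:Xop}

Base.:(*)(X::Xop, u::AbstractVector) = X.a * u
Base.:(*)(Xt::XopTranspose, u::AbstractVector) = Xt.lmap.a .* u
Base.size(X::Xop) = (4, 4)

# Custom reverse rule for `X'` when X is an Xop. Being more specific than
# Zygote's generic rule for `adjoint`, it is used instead of it. The pullback
# ignores the incoming (NamedTuple) cotangent and reports no gradient for X,
# so derivatives with respect to the operator's scale `X.a` are dropped.
Zygote.@adjoint function Base.adjoint(X::Xop)
    adjoint_pullback(Δ) = (nothing,)
    return Base.adjoint(X), adjoint_pullback
end

const X = Xop(3.4)

f(a) = sum(X * a)   # scalar-valued function involving X *
g(a) = sum(X' * a)  # scalar-valued function involving X' *

gradient(f, ones(4))  # gradient with respect to the input vector
gradient(g, ones(4))  # no longer goes through the generic `adjoint` rule
```

If the gradient with respect to `X.a` is needed as well, the pullback would instead have to map the incoming `(lmap = (a = ...,),)` cotangent back to a gradient for `X`, which this sketch does not attempt.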

The following code, which uses ForwardDiff instead of Zygote, works for me:

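```julia
using ForwardDiff
using LinearMaps

struct Xop <: LinearMap{Float64}
    a::Float64
end

const XopTranspose = LinearMaps.TransposeMap{<:Any, <:Xop}

Base.:(*)(X::Xop, u::AbstractVector) = X.a * u
Base.:(*)(Xt::XopTranspose, u::AbstractVector) = Xt.lmap.a .* u
Base.size(X::Xop) = (4, 4)

X = Xop(3.4)

f(a) = sum(X * a)  # simple scalar-valued function involving X *
g(a) = sum(X' * a) # simple scalar-valued function involving X' *

# Forward mode: ForwardDiff pushes a vector of dual numbers through the
# user-defined `*` methods, so no reverse rule for `adjoint` is involved.
ForwardDiff.gradient(f, ones(4))
ForwardDiff.gradient(g, ones(4))
```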