Is it possible, perhaps by defining an appropriate pullback, to differentiate expressions that involve a `LinearMap` object? Here’s a small example:
```julia
using LinearAlgebra, Zygote, LinearMaps  # LinearAlgebra provides `norm`
m, n = 2, 3; r = ones(m); x = ones(n)
A = LinearMap(ones(m, n))
f(x) = norm(A*x)^2/2
g(r) = norm(A'*r)^2/2
```
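For reference, the expected gradients follow from the chain rule. Since `A` is real here, `A'` is the plain transpose, and the pullback of $y = Ax$ sends an output cotangent $\bar y$ to $A^\mathsf{T}\bar y$:

$$
\nabla f(x) = A^\mathsf{T}(Ax), \qquad Ax = \begin{bmatrix}3\\3\end{bmatrix} \;\Rightarrow\; \nabla f(x) = \begin{bmatrix}6\\6\\6\end{bmatrix},
$$

$$
\nabla g(r) = A(A^\mathsf{T}r), \qquad A^\mathsf{T}r = \begin{bmatrix}2\\2\\2\end{bmatrix} \;\Rightarrow\; \nabla g(r) = \begin{bmatrix}6\\6\end{bmatrix}.
$$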
Differentiating throws an error. Here’s a partial stacktrace:
```
julia> f'(x) # should evaluate to [6.0, 6.0, 6.0]
ERROR: Can't differentiate foreigncall expression
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] gemv! at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/blas.jl:626 [inlined]
[3] (::typeof(∂(gemv!)))(::Array{Float64,1}) at /Users/mpf/.julia/packages/Zygote/pM10l/src/compiler/interface2.jl:0
[4] gemv! at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/matmul.jl:470 [inlined]
[5] (::typeof(∂(gemv!)))(::Array{Float64,1}) at /Users/mpf/.julia/packages/Zygote/pM10l/src/compiler/interface2.jl:0
[6] mul! at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/matmul.jl:66 [inlined]
[7] mul! at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/matmul.jl:208 [inlined]
[8] _unsafe_mul! at /Users/mpf/.julia/packages/LinearMaps/rw0V6/src/LinearMaps.jl:234 [inlined]
```
A similar error occurs when evaluating `g'(r)`.
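One possible approach, sketched under assumptions rather than offered as the definitive fix, is to give the AD system a rule for `A*x` so that Zygote stops tracing into the BLAS `gemv!` call. This sketch assumes the ChainRulesCore package is installed and that the installed Zygote version picks up ChainRulesCore `rrule`s (older versions may need `Zygote.@adjoint` instead). The rule treats `A` as a constant: it returns no tangent for the map itself, so it gives gradients with respect to `x` only. Because `A'` is itself a `LinearMap`, the same rule should also apply inside `g`. The name `linearmap_mul_pullback` is just an illustrative local name.

```julia
using LinearAlgebra, LinearMaps, Zygote
using ChainRulesCore

# Hypothetical user-defined rule for y = A*x with A::LinearMap, treating A as constant.
# The pullback of x -> A*x maps the output cotangent ȳ to A'*ȳ.
function ChainRulesCore.rrule(::typeof(*), A::LinearMap, x::AbstractVector)
    y = A * x
    function linearmap_mul_pullback(ȳ)
        x̄ = A' * unthunk(ȳ)                 # cotangent with respect to x
        return NoTangent(), NoTangent(), x̄   # no tangent for `*` itself or for A
    end
    return y, linearmap_mul_pullback
end

m, n = 2, 3; r = ones(m); x = ones(n)
A = LinearMap(ones(m, n))
f(x) = norm(A*x)^2/2
g(r) = norm(A'*r)^2/2

f'(x)  # gradient A'*(A*x)
g'(r)  # gradient A*(A'*r)
```

Returning `NoTangent()` for `A` keeps the sketch short; if gradients with respect to the map's underlying matrix were needed, the rule would have to return a structured tangent for the `LinearMap`, which this sketch does not attempt.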