Looking for a SciMLSensitivity Example with LinearSolve

I’m new to SciMLSensitivity.jl and am trying to learn the ropes. As a simple first exercise, I wanted to compute the sensitivities of the solution of a LinearProblem with respect to its parameters (the entries of A and b). I couldn’t find an example of this in the documentation (perhaps I missed it, or maybe it was too simple to include). Does anyone have one, or can someone tell me what I’m doing wrong here:

```julia
using Zygote
using SciMLSensitivity
using ForwardDiff
using LinearSolve
import Random
Random.seed!(1234)

N = 2

function test_func(x::AbstractVector{T}) where {T<:Real}
    A = reshape(x[1:N*N], (N,N))
    b = x[N*N+1:end]
    # This works:
    # sol = A\b
    # But this seems to not work:
    prob = LinearProblem(A, b)
    sol = solve(prob)
    return sum(sol)
end

# Random Point
x0 = rand(N*N+N)

# Try with Zygote
grad_zygote = Zygote.gradient(test_func, x0)
display(grad_zygote[1])

# Compare with ForwardDiff
grad_forwarddiff = ForwardDiff.gradient(test_func, x0)
display(grad_forwarddiff)
```

The Zygote call fails with the following error:

```
ERROR: type Fill has no field u
Stacktrace:
  [1] getproperty
    @ .\Base.jl:37 [inlined]
  [2] (::LinearSolve.var"#∇linear_solve#103"{…})(∂sol::FillArrays.Fill{…})
    @ LinearSolve C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\adjoint.jl:58
  [3] ZBack
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
  [4] (::Zygote.var"#291#292"{Tuple{…}, Zygote.ZBack{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206
  [5] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
  [6] #solve#5
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:188 [inlined]
  [7] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
  [8] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
  [9] #2169#back
    @ C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [10] solve
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:186 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [12] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [13] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [14] #solve#4
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:183 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [16] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [17] #2169#back
    @ C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [18] solve
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:182 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [20] test_func
    @ c:\GitCode\marklau\sandbox\sensitivity\test_zygote.jl:17 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [22] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91
 [23] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:148
 [24] top-level scope
    @ c:\GitCode\marklau\sandbox\sensitivity\test_zygote.jl:25
Some type information was truncated. Use `show(err)` to see complete types.
```

Any help/advice would be greatly appreciated! Thanks!

Maybe I should have read the GitHub issues more thoroughly… Is this related to SciMLSensitivity.jl issue #832, “Adjoints for LinearSolve”, in the sense that this functionality is still a work in progress?
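For reference, the adjoint (reverse-mode) rule for a linear solve has a standard closed form. With u = A⁻¹b and a scalar loss L(u), one extra solve with the transpose of A gives both sensitivities; for test_func, L = sum(u), so ∂L/∂u is a vector of ones. This is the textbook result, not a description of how LinearSolve.jl or SciMLSensitivity.jl implement it internally.

$$
\begin{aligned}
A u &= b, & A^{\top}\lambda &= \frac{\partial L}{\partial u},\\
\frac{\partial L}{\partial b} &= \lambda, & \frac{\partial L}{\partial A} &= -\lambda\, u^{\top}.
\end{aligned}
$$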

Can you open an issue in the LinearSolve.jl repo?

As a temporary workaround, return sum(sol.u) instead of sum(sol); that should work.
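When trying the workaround, it helps to have an independent reference to compare Zygote’s output against. The sketch below uses only the LinearAlgebra standard library and assumes the same packing as test_func (the first N*N entries of x are vec(A) in column-major order, the rest are b). For L = sum(u) with A u = b, one extra solve Aᵀλ = ones gives ∂L/∂b = λ and ∂L/∂A = -λuᵀ. The helper names analytic_gradient, sum_of_solution and fd_gradient are hypothetical, and a central finite-difference gradient is included as a second check. This is the textbook adjoint, not LinearSolve.jl’s internal implementation.

```julia
using LinearAlgebra

# Hypothetical helper: closed-form gradient of sum(A \ b) with respect to
# x = [vec(A); b], using the same packing as test_func.
function analytic_gradient(x::AbstractVector{<:Real}, N::Integer)
    A = reshape(x[1:N*N], (N, N))   # column-major, as in test_func
    b = x[N*N+1:end]
    u = A \ b                       # forward solve: A u = b
    λ = A' \ ones(N)                # adjoint solve: Aᵀ λ = ∂L/∂u = ones
    dA = -λ * u'                    # ∂L/∂A = -λ uᵀ
    db = λ                          # ∂L/∂b = λ
    return vcat(vec(dA), db)        # pack the same way as x
end

# Hypothetical reference objective using plain backslash (no LinearSolve).
sum_of_solution(x, N) = sum(reshape(x[1:N*N], (N, N)) \ x[N*N+1:end])

# Hypothetical central finite-difference gradient, used as an independent check.
function fd_gradient(f, x::AbstractVector{<:Real}; h = 1e-6)
    g = similar(x, Float64)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

# Usage on a fixed, well-conditioned point (not the random x0 from the question).
N = 2
x0 = vcat(vec([2.0 0.5; 0.3 1.5]), [1.0, -1.0])
g_analytic = analytic_gradient(x0, N)
g_fd = fd_gradient(x -> sum_of_solution(x, N), x0)
println(g_analytic)
println(isapprox(g_analytic, g_fd; rtol = 1e-6))
```

With the workaround applied, Zygote.gradient(test_func, x0)[1] and ForwardDiff.gradient(test_func, x0) should both agree with analytic_gradient(x0, N) up to floating-point error. A random x0 can produce a badly conditioned A, which makes the finite-difference comparison less reliable; that is why the usage above uses a fixed matrix.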

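This is likely a bug in how the getindex(sol, sym) rrule is handled. The stack trace is consistent with that: sum(sol) differentiates through the solution as an array, so LinearSolve’s pullback (∇linear_solve in adjoint.jl) receives a FillArrays.Fill cotangent and then tries to read its u field, which fails with “type Fill has no field u”.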

Thanks! I’ve opened LinearSolve.jl issue #483, “LinearSolve with SciMLSensitivity Solution Handling Requires sol.u”.