I’m new to SciMLSensitivity.jl and am trying to learn the ropes. As a first exercise, I wanted to compute the sensitivities of the solution of a `LinearProblem` with respect to its parameters (here, the entries of `A` and `b`). I couldn’t find an example of this in the documentation (perhaps I missed it, or maybe it was too simple to include). Does anyone have one, and/or can someone tell me what I’m doing wrong here:
```julia
using Zygote
using SciMLSensitivity
using ForwardDiff
using LinearSolve
import Random

Random.seed!(1234)
N = 2

function test_func(x::AbstractVector{T}) where {T<:Real}
    # First N*N entries form A (column-major), the remaining N form b
    A = reshape(x[1:N*N], (N,N))
    b = x[N*N+1:end]
    # This works:
    # sol = A\b
    # But this seems to not work:
    prob = LinearProblem(A, b)
    sol = solve(prob)
    return sum(sol)
end

# Random point
x0 = rand(N*N+N)

# Try with Zygote
grad_zygote = Zygote.gradient(test_func, x0)
display(grad_zygote[1])

# Compare with ForwardDiff
grad_forwarddiff = ForwardDiff.gradient(test_func, x0)
display(grad_forwarddiff)
```
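As an independent reference for the Zygote/ForwardDiff comparison, this particular objective has a closed-form gradient. Writing u = A⁻¹b and λ = A⁻ᵀ·1 (one extra solve with Aᵀ), the objective f = 1ᵀu has ∂f/∂b = λ and ∂f/∂A = -λuᵀ. The sketch below uses only the LinearAlgebra standard library and checks that formula against central finite differences; `analytic_grad`, `fd_grad` and the fixed, well-conditioned test point are illustrative choices of mine, not part of any package.

```julia
using LinearAlgebra

# Closed-form gradient of f(x) = sum(A \ b), where A = reshape(x[1:N*N], N, N)
# and b = x[N*N+1:end]. With u = A \ b and λ = A' \ ones(N):
#   ∂f/∂b = λ   and   ∂f/∂A = -λ * u'
function analytic_grad(x::AbstractVector, N::Integer)
    A = reshape(x[1:N*N], (N, N))
    b = x[N*N+1:end]
    u = A \ b
    λ = A' \ ones(eltype(x), N)   # adjoint (transposed) solve
    gA = -λ * u'
    # vec flattens column-major, matching how reshape filled A from x
    return vcat(vec(gA), λ)
end

# Central finite differences as an independent check (hypothetical helper)
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x)); e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

N = 2
f(x) = sum(reshape(x[1:N*N], (N, N)) \ x[N*N+1:end])
x0 = [2.0, 0.5, 0.3, 1.5, 1.0, -1.0]   # fixed point; A = [2.0 0.3; 0.5 1.5]
g_exact = analytic_grad(x0, N)
g_fd = fd_grad(f, x0)
println(isapprox(g_exact, g_fd; rtol = 1e-6))
```

The b-part of the gradient is just the adjoint solution λ, which is why reverse-mode rules for linear solves typically cost one extra solve with the transposed matrix.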
The Zygote call fails with the following error:
```
ERROR: type Fill has no field u
Stacktrace:
  [1] getproperty
    @ .\Base.jl:37 [inlined]
  [2] (::LinearSolve.var"#∇linear_solve#103"{…})(∂sol::FillArrays.Fill{…})
    @ LinearSolve C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\adjoint.jl:58
  [3] ZBack
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
  [4] (::Zygote.var"#291#292"{Tuple{…}, Zygote.ZBack{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206
  [5] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
  [6] #solve#5
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:188 [inlined]
  [7] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
  [8] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
  [9] #2169#back
    @ C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [10] solve
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:186 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [12] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [13] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [14] #solve#4
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:183 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [16] #291
    @ C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [17] #2169#back
    @ C:\Users\Mark\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [18] solve
    @ C:\Users\Mark\.julia\packages\LinearSolve\88qI9\src\common.jl:182 [inlined]
 [19] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [20] test_func
    @ c:\GitCode\marklau\sandbox\sensitivity\test_zygote.jl:17 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [22] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91
 [23] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Mark\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:148
 [24] top-level scope
    @ c:\GitCode\marklau\sandbox\sensitivity\test_zygote.jl:25
Some type information was truncated. Use `show(err)` to see complete types.
```
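Reading the trace, the rule that fails is LinearSolve's own reverse rule (`∇linear_solve`, `src/adjoint.jl:58`), and it fails when it accesses the `.u` field of the incoming cotangent `∂sol`. Because `sum(sol)` treats the `LinearSolution` as an array, Zygote hands that rule a `FillArrays.Fill` of ones instead, which has no `u` field. A likely workaround, assuming the rule only needs a cotangent that carries a `u` entry, is to reduce over the plain solution vector `sol.u`. The sketch below makes that one change and compares against ForwardDiff on the `A\b` formulation; `test_func_u`, `ref_func` and the fixed test point are hypothetical names and values, and SciMLSensitivity is left out on the assumption that LinearSolve's rule alone suffices here.

```julia
using Zygote
using ForwardDiff
using LinearSolve

N = 2

# Same objective as test_func, but it sums sol.u (the solution vector)
# rather than the LinearSolution object itself, so the cotangent reaching
# LinearSolve's reverse rule has a `u` field.
function test_func_u(x::AbstractVector{T}) where {T<:Real}
    A = reshape(x[1:N*N], (N, N))
    b = x[N*N+1:end]
    prob = LinearProblem(A, b)
    sol = solve(prob)
    return sum(sol.u)
end

# Reference: the same objective written with backslash
ref_func(x) = sum(reshape(x[1:N*N], (N, N)) \ x[N*N+1:end])

# Fixed, well-conditioned point instead of rand, so the comparison is stable
x0 = [2.0, 0.5, 0.3, 1.5, 1.0, -1.0]
grad_zygote = Zygote.gradient(test_func_u, x0)[1]
grad_ref = ForwardDiff.gradient(ref_func, x0)
println(isapprox(grad_zygote, grad_ref; rtol = 1e-8))
```

The same pattern (differentiate through `sol.u`, not `sol`) should apply to other reductions such as `norm(sol.u)`, though I have only reasoned about `sum` here.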
Any help/advice would be greatly appreciated! Thanks!