Hi,
I’m trying to get started with Zygote in a numerical simulation. Basically, I need to differentiate through a Laplace solver, and I’m getting an error that I don’t know how to handle.
Here is the code:
```julia
using Plots
using Statistics
using LinearAlgebra
using Zygote

function invert_laplace(ρ, Nz, Lz)
    # Solve the Laplace (Poisson) equation -∂z² ϕ = ρ on a periodic grid
    # ρ: RHS function
    # Nz: Number of grid points
    # Lz: Length of the domain
    Δz = Lz / Nz
    zrg = range(0.0, step=Δz, length=Nz)
    # Squeeze a minus in here.
    invΔz² = -1.0 / Δz / Δz
    # Periodic boundary conditions. First row of A is just [1, 0, ..., 0]
    A = diagm(0 => -2.0 * invΔz² * vcat([-0.5 / invΔz²], ones(Nz - 1)),
              1 => invΔz² * vcat([0.0], ones(Nz - 2)),
              -1 => invΔz² * ones(Nz - 1),
              -Nz + 1 => [invΔz²])
    # Fix ϕ(z=0) = 0 via the first entry of the RHS.
    ρvec = vcat([0.0], ρ(zrg[2:end]))
    ϕ_num = A \ ρvec
    return ϕ_num
end

function my_laplace_test(a)
    # Use a RHS that scales linearly with a. Since
    # solving ∂z² ϕ = -ρ is linear in a,
    # maximum(ϕ_num) should be a.
    Nz = 256
    Lz = 1.0
    ρ = z -> a * 4.0 * π * π * sin.(2π * z)
    ϕ_num = invert_laplace(ρ, Nz, Lz)
    return maximum(ϕ_num)
end

# Now let's try to differentiate through the solver:
res = gradient(my_laplace_test, 1.0)
```
But this gives me an error:
```
ERROR: LoadError: type NamedTuple has no field first
Stacktrace:
[1] getproperty(::NamedTuple{(:second,),Tuple{Array{Float64,1}}}, ::Symbol) at ./Base.jl:33
[2] macro expansion at /Users/ralph/.julia/packages/Zygote/7Jrhj/src/lib/lib.jl:287 [inlined]
[3] (::Zygote.Jnew{Pair{Int64,Array{Float64,1}},Nothing,false})(::NamedTuple{(:second,),Tuple{Array{Float64,1}}}) at /Users/ralph/.julia/packages/Zygote/7Jrhj/src/lib/lib.jl:279
[4] (::Zygote.var"#1727#back#165"{Zygote.Jnew{Pair{Int64,Array{Float64,1}},Nothing,false}})(::NamedTuple{(:second,),Tuple{Array{Float64,1}}}) at /Users/ralph/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[5] Pair at ./pair.jl:12 [inlined]
[6] Pair at ./pair.jl:15 [inlined]
[7] (::typeof(∂(Pair)))(::NamedTuple{(:second,),Tuple{Array{Float64,1}}}) at /Users/ralph/.julia/packages/Zygote/7Jrhj/src/compiler/interface2.jl:0
[8] invert_laplace at /Users/ralph/source/julia/zygote_tests/test_laplace_direct.jl:25 [inlined]
[9] (::typeof(∂(invert_laplace)))(::Array{Float64,1}) at /Users/ralph/.julia/packages/Zygote/7Jrhj/src/compiler/interface2.jl:0
[10] my_laplace_test at /Users/ralph/source/julia/zygote_tests/test_laplace_direct.jl:67 [inlined]
[11] (::Zygote.var"#41#42"{typeof(∂(my_laplace_test))})(::Float64) at /Users/ralph/.julia/packages/Zygote/7Jrhj/src/compiler/interface.jl:40
[12] gradient(::Function, ::Float64) at /Users/ralph/.julia/packages/Zygote/7Jrhj/src/compiler/interface.jl:49
[13] top-level scope at /Users/ralph/source/julia/zygote_tests/test_laplace_direct.jl:72
[14] include(::Function, ::Module, ::String) at ./Base.jl:380
[15] include(::Module, ::String) at ./Base.jl:368
[16] exec_options(::Base.JLOptions) at ./client.jl:296
[17] _start() at ./client.jl:506
```
Here, frame [8] points to the line where the matrix `A` is constructed. Does anyone have an idea how to fix this?
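One possible workaround, sketched under the assumption that only the right-hand side depends on `a`: the matrix `A` and the grid are constants of the problem, so they can be assembled once outside the function being differentiated. Zygote then never has to differentiate through `diagm` or the `Pair` constructor whose pullback fails in the trace. The helper names `laplace_matrix` and `solve_max` are hypothetical, and the sketch does not rely on whether newer Zygote versions handle the `Pair` adjoint. As a cross-check it also computes a central finite difference, which should agree with the AD result and be close to 1, since `maximum(ϕ_num)` should be about `a`.

```julia
using LinearAlgebra
using Zygote

# Hypothetical helper: build the grid and the periodic matrix once, outside AD.
function laplace_matrix(Nz, Lz)
    Δz = Lz / Nz
    zrg = range(0.0, step=Δz, length=Nz)
    invΔz² = -1.0 / Δz / Δz
    # Same matrix as in the question; the first row is [1, 0, ..., 0].
    A = diagm(0 => -2.0 * invΔz² * vcat([-0.5 / invΔz²], ones(Nz - 1)),
              1 => invΔz² * vcat([0.0], ones(Nz - 2)),
              -1 => invΔz² * ones(Nz - 1),
              -Nz + 1 => [invΔz²])
    return A, zrg
end

# Hypothetical helper: only this part is differentiated with respect to a.
function solve_max(a, A, shape)
    ρvec = vcat([0.0], a * shape)  # first entry 0 pins ϕ(z=0) = 0
    ϕ_num = A \ ρvec
    return maximum(ϕ_num)
end

A, zrg = laplace_matrix(256, 1.0)
# z-dependence of ρ (4π² sin(2πz) on the interior points), independent of a.
shape = 4.0 * π * π * sin.(2π * zrg[2:end])

# Reverse-mode gradient; the closure treats A and shape as constants.
g = gradient(a -> solve_max(a, A, shape), 1.0)[1]

# Central finite difference as an AD-free reference value.
h = 1e-4
fd = (solve_max(1.0 + h, A, shape) - solve_max(1.0 - h, A, shape)) / (2h)

println("Zygote: ", g, "  finite difference: ", fd)
```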