I’m having trouble debugging why ForwardDiff is not working for my function. Below is a mwe. Its very close to working, but I can’t figure it out.
using ForwardDiff
using LinearAlgebra
struct Stack
M::Matrix{Int}
Us::Vector{Matrix{Complex{Float64}}}
end
stack = Stack(rand(0:1,4,10), [rand(4,2) for i in 1:10])
function update!(stack::Stack, x::Vector, p::Int)
# A = exp.(Diagonal(stack.M[:,p] )) # works
A = exp.(Diagonal(stack.M[:,p] * x[p])) # doesn't work
B = qr!(A[:, 1:2], Val(true)).Q |> Matrix
stack.Us[p] .= B
end
update!(stack, rand(10), 1)
function f(stack, x, p)
update!(stack, x, p)
return diag(sum(stack.Us)) |> sum |> real
end
g = x -> ForwardDiff.gradient(x -> f(stack, x, 1), x)
f(stack, rand(10),1)
g( rand(10))
⇡5% [I] ➜ jul --project mwe_forwarddiff.jl
ERROR: LoadError: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10})
Closest candidates are:
Float64(::Real, !Matched::RoundingMode) where T<:AbstractFloat at rounding.jl:200
Float64(::T<:Number) where T<:Number at boot.jl:718
Float64(!Matched::Int8) at float.jl:60
...
Stacktrace:
[1] convert(::Type{Float64}, ::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10}) at ./number.jl:7
[2] Complex{Float64}(::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10}, ::Int64) at ./complex.jl:12
[3] Complex{Float64}(::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10}) at ./complex.jl:35
[4] convert(::Type{Complex{Float64}}, ::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10}) at ./number.jl:7
[5] setindex!(::Array{Complex{Float64},2}, ::ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10}, ::Int64) at ./array.jl:780
[6] copyto! at ./abstractarray.jl:807 [inlined]
[7] copyto! at ./abstractarray.jl:799 [inlined]
[8] copyto! at ./broadcast.jl:883 [inlined]
[9] copyto! at ./broadcast.jl:842 [inlined]
[10] materialize! at ./broadcast.jl:801 [inlined]
[11] update!(::Stack, ::Array{ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10},1}, ::Int64) at /home/meach/.julia/dev/PDQMC/src/mwe_forwarddiff.jl:15
[12] f(::Stack, ::Array{ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10},1}, ::Int64) at /home/meach/.julia/dev/PDQMC/src/mwe_forwarddiff.jl:21
[13] (::getfield(Main, Symbol("##6#8")))(::Array{ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10},1}) at /home/meach/.julia/dev/PDQMC/src/mwe_forwarddiff.jl:28
[14] vector_mode_gradient(::getfield(Main, Symbol("##6#8")), ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10,Array{ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10},1}}) at /home/meach/.julia/packages/ForwardDiff/N0wMF/src/apiutils.jl:37
[15] gradient(::Function, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10,Array{ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10},1}}, ::Val{true}) at /home/meach/.julia/packages/ForwardDiff/N0wMF/src/gradient.jl:17
[16] gradient(::Function, ::Array{Float64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10,Array{ForwardDiff.Dual{ForwardDiff.Tag{getfield(Main, Symbol("##6#8")),Float64},Float64,10},1}}) at /home/meach/.julia/packages/ForwardDiff/N0wMF/src/gradient.jl:15 (repeats 2 times)
[17] (::getfield(Main, Symbol("##5#7")))(::Array{Float64,1}) at /home/meach/.julia/dev/PDQMC/src/mwe_forwarddiff.jl:28
[18] top-level scope at /home/meach/.julia/dev/PDQMC/src/mwe_forwarddiff.jl:32
[19] include at ./boot.jl:328 [inlined]
[20] include_relative(::Module, ::String) at ./loading.jl:1094
[21] include(::Module, ::String) at ./Base.jl:31
[22] exec_options(::Base.JLOptions) at ./client.jl:295
[23] _start() at ./client.jl:468
in expression starting at /home/meach/.julia/dev/PDQMC/src/mwe_forwarddiff.jl:32
Any ideas?