I have a function which is built out of a large number of elementary functions, and as such reverse mode AD is much slower than forward mode. I would like to use Zygote for this problem as it has support for complex inputs. Below is the function whose gradient I want to construct:
function wavefunction(Z) ### N not necessary.
slater_det = Matrix{ComplexF64}(undef, N, N)
for i in range(1, n)
for j in range(1, Int(N/n))
map(column_index -> slater_det[j+Int(N/n)*(i-1), column_index] = (Z[column_index]^(j-i))*polynomial_list[i](Z[column_index], Z[1:end .!= column_index]), 1:N)
end
end
# println(exp(-dot(Z, Z)/4))
return logdet(slater_det) - dot(Z, Z)/4
end
# Zygote.refresh()
begin
Z = randn(Float64, 2*N)
logpdf = x->2*real(wavefunction(x[begin:2:end] + 1.0im*x[begin+1:2:end]))
# @benchmark wirtinger(logpdf, Z)
Zygote.forwarddiff(logpdf, Z)
end
The above code returns a scalar and using Forwarddiff.jl directly results in the following error:
MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11})
Closest candidates are:
(::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
(::Type{T})(::T) where T<:Number at boot.jl:772
(::Type{T})(::SymbolicUtils.Symbolic) where T<:Union{AbstractFloat, Integer, Complex{<:AbstractFloat}, Complex{<:Integer}} at ~/.julia/packages/Symbolics/UrqtQ/src/Symbolics.jl:150
...
Stacktrace:
[1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11})
@ Base ./number.jl:7
[2] ComplexF64(re::ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11}, im::ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11})
@ Base ./complex.jl:14
[3] ComplexF64(z::Complex{ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11}})
@ Base ./complex.jl:43
[4] convert(#unused#::Type{ComplexF64}, x::Complex{ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11}})
@ Base ./number.jl:7
[5] setindex!(::Matrix{ComplexF64}, ::Complex{ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11}}, ::Int64, ::Int64)
@ Base ./array.jl:968
[6] (::var"#7#8"{Vector{Complex{ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11}}}, Matrix{ComplexF64}, Int64, Int64})(column_index::Int64)
@ Main ./In[6]:5
[7] iterate
@ ./generator.jl:47 [inlined]
[8] _collect(c::UnitRange{Int64}, itr::Base.Generator{UnitRange{Int64}, var"#7#8"{Vector{Complex{ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11}}}, Matrix{ComplexF64}, Int64, Int64}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
@ Base ./array.jl:807
[9] collect_similar(cont::UnitRange{Int64}, itr::Base.Generator{UnitRange{Int64}, var"#7#8"{Vector{Complex{ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11}}}, Matrix{ComplexF64}, Int64, Int64}})
@ Base ./array.jl:716
[10] map(f::Function, A::UnitRange{Int64})
@ Base ./abstractarray.jl:2933
[11] wavefunction(Z::Vector{Complex{ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11}}})
@ Main ./In[6]:5
[12] (::var"#23#24")(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11}})
@ Main ./In[12]:3
[13] chunk_mode_gradient(f::var"#23#24", x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11}}})
@ ForwardDiff ~/.julia/packages/ForwardDiff/QdStj/src/gradient.jl:150
[14] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11}}}, ::Val{true})
@ ForwardDiff ~/.julia/packages/ForwardDiff/QdStj/src/gradient.jl:21
[15] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#23#24", Float64}, Float64, 11}}}) (repeats 2 times)
@ ForwardDiff ~/.julia/packages/ForwardDiff/QdStj/src/gradient.jl:17
[16] top-level scope
@ In[12]:10
For reference,
Polynomial List consists of compiled functions built from Symbolics.jl.
#### We will create an array of functions recursively.
@variables sym_z, sym_Z[1:N-1]
D = Differential(sym_z)
sym_polynomials = []
for i in range(1, n)
if i==1
eq = prod(i->(sym_z-sym_Z[i])^p, 1:N-1)
append!(sym_polynomials, expand_derivatives(D(eq)))
else
append!(sym_polynomials, expand_derivatives(D(sym_polynomials[end])))
end
end
polynomial_list = map(x->build_function(x,sym_z,sym_Z;expression=Val{false}), sym_polynomials);
How do I resolve this issue? I do not think this is a limitation of Zygote.forwarddiff itself but an issue with my code.