Hello,
I need to use Zygote.jl to differentiate a function like this:
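```julia
using LinearAlgebra

function my_fidelity2(p)
    # single-qubit basis states, built by mutating zero vectors
    ψg = zeros(ComplexF64, 2)
    ψg[2] = 1.0
    ψe = zeros(ComplexF64, 2)
    ψe[1] = 1.0
    # two-qubit product states |gg⟩ and |ee⟩
    ψgg = kron(ψg, ψg)
    ψee = kron(ψe, ψe)
    ψ = normalize(p[1] * ψgg + p[2] * ψee)
    return abs(dot(ψ, ψgg))
end
```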
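Assuming `p` is a real vector with `p[1] > 0`, the function above reduces to a closed form, because `ψgg` and `ψee` are just the basis vectors e₄ and e₁ of ℂ⁴. This gives an independent check on whatever gradient an AD backend returns:

```latex
f(p) = \frac{|p_1|}{\sqrt{p_1^2 + p_2^2}}, \qquad
\frac{\partial f}{\partial p_1} = \frac{p_2^2}{(p_1^2 + p_2^2)^{3/2}}, \qquad
\frac{\partial f}{\partial p_2} = -\frac{p_1 p_2}{(p_1^2 + p_2^2)^{3/2}}
```

For example, `p = [3.0, 4.0]` gives f = 0.6 and ∇f = [0.128, -0.096].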
But here I am mutating arrays, so Zygote.jl complains (it does not support array mutation). The differentiation works with Enzyme.jl, aside from the BLAS warning posted in this issue.
What should I do to make it work with Zygote.jl? I could certainly move the definitions of ψg and ψe outside the function, but what if I want to keep them inside it? Is there a way to generate an array with only one nonzero value that is supported by Zygote.jl?
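One way to keep the state preparation inside the function is to hide it from AD entirely, since `ψgg` and `ψee` do not depend on `p`. The sketch below assumes a ChainRulesCore version that provides `ignore_derivatives` (Zygote honors it): the closure runs normally on the forward pass, but its result is treated as a constant, so the mutation inside is never differentiated. `my_fidelity_ignored` is a hypothetical name for this variant; the body is otherwise the original one.

```julia
using LinearAlgebra
using ChainRulesCore
using Zygote

# Hypothetical variant of my_fidelity2 that keeps the state definitions inside.
function my_fidelity_ignored(p)
    # The do-block closure is executed normally, but AD treats its result as a
    # constant, so the setindex! calls inside are never seen by Zygote.
    ψgg, ψee = ChainRulesCore.ignore_derivatives() do
        ψg = zeros(ComplexF64, 2)
        ψg[2] = 1.0
        ψe = zeros(ComplexF64, 2)
        ψe[1] = 1.0
        (kron(ψg, ψg), kron(ψe, ψe))
    end
    # Only this part depends on p and is differentiated.
    ψ = normalize(p[1] * ψgg + p[2] * ψee)
    return abs(dot(ψ, ψgg))
end

# Usage: the gradient with respect to p
grad = Zygote.gradient(my_fidelity_ignored, [3.0, 4.0])[1]
```

This is only safe because nothing inside the ignored block depends on `p`; wrapping a `p`-dependent computation the same way would silently drop its contribution to the gradient.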
I was thinking that I could create a function like
```julia
function my_state(N, i)
    ψ = zeros(ComplexF64, N)
    ψ[i] = 1.0
    return ψ
end
```
and then defining a custom `rrule` with ChainRulesCore.jl. I'm new to the autodiff world, so I have no idea how to set it up, or whether this is the best solution.
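The `rrule` idea can be sketched as follows. Because `my_state` only takes integers, its rule just returns the primal result plus a pullback that reports `NoTangent()` for the function itself and both arguments; Zygote then calls this rule instead of tracing into the mutation. `my_fidelity_rrule` is a hypothetical variant that builds the 4-element states directly with `my_state` (for length-2 basis vectors, `kron(e₂, e₂)` is e₄ and `kron(e₁, e₁)` is e₁), so the sketch does not rely on Zygote having a rule for `kron`, which I have not checked.

```julia
using LinearAlgebra
using ChainRulesCore
using Zygote

# Basis vector e_i of length N; mutates a freshly allocated array internally.
function my_state(N, i)
    ψ = zeros(ComplexF64, N)
    ψ[i] = 1.0
    return ψ
end

# Custom reverse rule: the output does not depend differentiably on the integer
# inputs, so the pullback returns NoTangent() for (my_state, N, i).
function ChainRulesCore.rrule(::typeof(my_state), N::Integer, i::Integer)
    ψ = my_state(N, i)
    my_state_pullback(Δψ) = (NoTangent(), NoTangent(), NoTangent())
    return ψ, my_state_pullback
end

# Hypothetical variant of my_fidelity2 using the rule-backed helper.
function my_fidelity_rrule(p)
    ψgg = my_state(4, 4)  # |gg⟩ = kron(e₂, e₂)
    ψee = my_state(4, 1)  # |ee⟩ = kron(e₁, e₁)
    ψ = normalize(p[1] * ψgg + p[2] * ψee)
    return abs(dot(ψ, ψgg))
end

grad = Zygote.gradient(my_fidelity_rrule, [3.0, 4.0])[1]
```

ChainRulesCore also offers `@non_differentiable my_state(::Integer, ::Integer)`, which should generate an equivalent rule in one line.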