How to differentiate functions involving arrays with one non-zero element in Zygote

Hello,

I need to differentiate a function like the following with Zygote.jl:

```julia
using LinearAlgebra

function my_fidelity2(p)
  ψg = zeros(ComplexF64, 2)
  ψg[2] = 1.0
  ψe = zeros(ComplexF64, 2)
  ψe[1] = 1.0

  ψgg = kron(ψg, ψg)
  ψee = kron(ψe, ψe)
  ψ = normalize(p[1] * ψgg + p[2] * ψee)

  return abs(dot(ψ, ψgg))
end
```

But here I am mutating arrays (`ψg[2] = 1.0`, `ψe[1] = 1.0`), and Zygote.jl complains about that. The differentiation works with Enzyme.jl, apart from the BLAS warning posted in this issue.

What should I do to make it work with Zygote.jl? I could certainly move the definitions of ψg and ψe outside the function, but what if I want to keep them inside? Is there a way to generate an array with only one nonzero element that Zygote.jl supports?
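One possible way to get such an array without any in-place write is to build it in a single expression, for example with a comprehension. The sketch below is an illustration, not something from this thread: `basis_state` and `my_fidelity2_comp` are hypothetical names, and I am assuming (not verifying here) that Zygote traces the comprehension without complaint, since its result depends only on the integers `N` and `i`.

```julia
using LinearAlgebra

# Hypothetical helper: the i-th standard basis vector of length N, created in
# one shot by a comprehension, so setindex! is never called on an existing array.
basis_state(N, i) = [ComplexF64(j == i) for j in 1:N]

# Same computation as my_fidelity2 in the question, without in-place writes.
function my_fidelity2_comp(p)
    ψg = basis_state(2, 2)   # (0, 1), like ψg[2] = 1.0
    ψe = basis_state(2, 1)   # (1, 0), like ψe[1] = 1.0

    ψgg = kron(ψg, ψg)
    ψee = kron(ψe, ψe)
    ψ = normalize(p[1] * ψgg + p[2] * ψee)

    return abs(dot(ψ, ψgg))
end
```

With Zygote loaded, `Zygote.gradient(my_fidelity2_comp, [0.6, 0.8])` would be the call to try; the forward value for that input is `0.6`.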

I was thinking that I could create a function like

```julia
function my_state(N, i)
  ψ = zeros(ComplexF64, N)
  ψ[i] = 1.0
  return ψ
end
```

And then define a custom rrule for it with ChainRulesCore.jl. I'm new to the autodiff world, so I have no idea how to set that up, or whether this is the best solution.
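A sketch of the custom-rule idea, assuming ChainRulesCore.jl is installed: because `my_state` takes only integers and its output does not depend on the parameters `p`, its rule is simply "no derivative flows through here". ChainRulesCore's `@non_differentiable` macro writes exactly that `rrule` for you, so there is no need to hand-write a pullback.

```julia
using ChainRulesCore

# The mutating constructor from the question: it depends only on the integers
# N and i, never on the parameters being differentiated.
function my_state(N, i)
    ψ = zeros(ComplexF64, N)
    ψ[i] = 1.0
    return ψ
end

# Defines an rrule whose pullback returns NoTangent() for the function and for
# both arguments, so an AD system that uses ChainRules rules (Zygote does)
# treats the call as an opaque constant and never traces the setindex! inside.
@non_differentiable my_state(::Any, ::Any)
```

Inside `my_fidelity2` you would then write `ψg = my_state(2, 2)` and `ψe = my_state(2, 1)`. Zygote picks up ChainRules `rrule`s automatically, so it should skip the mutation inside `my_state`.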

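You can build the vectors by concatenation instead of in-place assignment:

```julia
function my_fidelity2(p)
  ψg = zeros(ComplexF64, 1)
  ψg = [ψg; 1.0]   # (0, 1): the 1 goes in the second slot, as in ψg[2] = 1.0
  ψe = zeros(ComplexF64, 1)
  ψe = [1.0; ψe]   # (1, 0), as in ψe[1] = 1.0

  ψgg = kron(ψg, ψg)
  ψee = kron(ψe, ψe)
  ψ = normalize(p[1] * ψgg + p[2] * ψee)

  return abs(dot(ψ, ψgg))
end
```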
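The same concatenation trick generalizes to a basis vector of any length: put `i - 1` zeros, a one, and `N - i` zeros together with `vcat` (the `[a; b; c]` syntax). This is a minimal sketch; `onehot_vcat` and `onehot_mut` are hypothetical names, and the mutating version is kept only to check that both give the same vector.

```julia
# Hypothetical generalization of the concatenation trick: the i-th basis
# vector of length N is assembled with vcat, and no array is mutated.
onehot_vcat(N, i) = [zeros(ComplexF64, i - 1); one(ComplexF64); zeros(ComplexF64, N - i)]

# Mutating reference (same as my_state in the question), used only for checking.
function onehot_mut(N, i)
    ψ = zeros(ComplexF64, N)
    ψ[i] = 1.0
    return ψ
end

# Both constructions agree for every valid (N, i) in a small range.
for N in 1:5, i in 1:N
    @assert onehot_vcat(N, i) == onehot_mut(N, i)
end
```

In the example above, `onehot_vcat(2, 2)` plays the role of ψg and `onehot_vcat(2, 1)` the role of ψe.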

Thanks, very simple!

I think you could simplify this a lot. Doesn't it just calculate:

```julia
function my_fidelity3(p)
    return abs(p[1])/hypot(p[1], p[2])
end
```
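Why the shortcut holds for this particular example: with ψe = (1, 0) and ψg = (0, 1), the Kronecker products are ψgg = (0, 0, 0, 1) and ψee = (1, 0, 0, 0), which are orthonormal. Julia's `dot` conjugates its first argument, so (taking `p` real, which is the case `hypot(p[1], p[2])` reproduces directly):

$$
\psi = \frac{p_1\,\psi_{gg} + p_2\,\psi_{ee}}{\sqrt{|p_1|^2 + |p_2|^2}}
\quad\Longrightarrow\quad
\bigl|\langle \psi, \psi_{gg} \rangle\bigr|
= \frac{\bigl|\overline{p_1}\,\langle \psi_{gg}, \psi_{gg} \rangle + \overline{p_2}\,\langle \psi_{ee}, \psi_{gg} \rangle\bigr|}{\sqrt{|p_1|^2 + |p_2|^2}}
= \frac{|p_1|}{\sqrt{p_1^2 + p_2^2}},
$$

using $\langle \psi_{gg}, \psi_{gg} \rangle = 1$ and $\langle \psi_{ee}, \psi_{gg} \rangle = 0$.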

Yes, but I wrote this example just as an MWE. In general the two vectors can be much more general. Thanks anyway for this optimized solution.