How to differentiate functions involving arrays with one non-zero element in Zygote

"""
    my_fidelity2(p)

Return the overlap magnitude between the normalized two-qubit state
`p[1] * ψgg + p[2] * ψee` and the reference state `ψgg`.

Here `ψg = [1, 0]` and `ψe = [0, 1]` are the two single-qubit basis vectors,
`ψgg = kron(ψg, ψg)` and `ψee = kron(ψe, ψe)`.

# Arguments
- `p`: indexable collection; only `p[1]` and `p[2]` are read.

# Returns
`abs(dot(ψ, ψgg))`, where `ψ` is the normalized superposition. If both
coefficients are zero, the normalization divides by a zero norm and the result
is `NaN`.
"""
function my_fidelity2(p)
    # Build the basis vectors by concatenation instead of writing into a
    # preallocated array, so the function body performs no in-place mutation.
    ψg = vcat(1.0, zeros(ComplexF64, 1))  # [1, 0]
    # ψe must differ from ψg: with both equal to [1, 0] the superposition
    # collapses onto ψgg and the result is 1 for every p.
    ψe = vcat(zeros(ComplexF64, 1), 1.0)  # [0, 1]

    ψgg = kron(ψg, ψg)
    ψee = kron(ψe, ψe)
    ψ = normalize(p[1] * ψgg + p[2] * ψee)

    return abs(dot(ψ, ψgg))
end
1 Like