Sensitivity of an ODEProblem defined by another package (QuantumToolbox)

Hi @drandran12,

Zygote doesn’t automatically give you a derivative for complex-to-complex functions, since there are several possible definitions, as discussed in the Zygote documentation.
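For reference, with $z = x + iy$, the two Wirtinger derivatives of $f(z)$ are the standard textbook pair below. Nothing here is specific to Zygote. A real-differentiable $f$ is holomorphic exactly when $\partial f/\partial \bar z = 0$, and in that case $\partial f/\partial z = f'(z)$.

$$
\frac{\partial f}{\partial z} = \frac{1}{2}\left(\frac{\partial f}{\partial x} - i\,\frac{\partial f}{\partial y}\right),
\qquad
\frac{\partial f}{\partial \bar z} = \frac{1}{2}\left(\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}\right)
$$

For example, $f(z) = |z|^2 = z\bar z$ gives $\partial f/\partial z = \bar z$ and $\partial f/\partial \bar z = z$, so it is not holomorphic and a single "complex derivative" doesn't describe it.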

One option is to compute the Wirtinger derivatives yourself, for example:

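using Zygote

# Wirtinger derivatives (∂f/∂z, ∂f/∂z̄) of a complex-valued f via Zygote pullbacks,
# seeding one output component at a time with 1 and with im.
function wirtinger(f, x)
    y, back = Zygote.pullback(f, x)
    m = length(y)

    du = [back([i == j ? 1 : 0 for i in 1:m])[1] for j in 1:m]
    dv = [back([i == j ? im : 0 for i in 1:m])[1] for j in 1:m]
    return (conj.(du) + im * conj.(dv)) / 2, (du + im * dv) / 2
end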

and then apply it to your case:

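using QuantumToolbox, SciMLSensitivity, DifferentiationInterface
import FiniteDiff  # needed for the AutoFiniteDiff backend

const N = 20
const a = destroy(N)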

const G, K, γ = 0.002, 0.001, 0.01

coef_Δ(p, t) = p[1]

H = QobjEvo(a' * a, coef_Δ) + K * a' * a' * a * a - G * (a' * a' + a * a)
c_ops = [sqrt(γ) * a]
const L = liouvillian(H, c_ops)

const ψ0 = fock(N, 0)

function my_f_mesolve(p)
    tlist = [0, 40.0]

    sol = mesolve(
        L,
        ψ0,
        tlist,
        progress_bar = Val(false),
        params = p,
        sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
    )

    # Final density matrix, flattened into a complex vector
    return vec(sol.states[end].data)
end

Δ = 1.0 + 0im
params = [Δ]

my_f_mesolve(params)

# Using FiniteDiff
ρ, dρ = DifferentiationInterface.value_and_jacobian(my_f_mesolve, AutoFiniteDiff(), params)

# Using Zygote
dρ_z, dρ_z_bar = wirtinger(my_f_mesolve, params)
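As a sanity check on the Zygote numbers, here is a dependency-free sketch that computes the same two Wirtinger derivatives with central finite differences in $x$ and $y$. `wirtinger_fd` and the step size `h` are my own hypothetical choices, not part of Zygote or QuantumToolbox.

```julia
# Hypothetical helper (not from any package): central-difference Wirtinger
# derivatives of f at a complex point z. f may return a scalar or a vector.
function wirtinger_fd(f, z::Complex; h = 1e-6)
    dfdx = (f(z + h) - f(z - h)) / (2h)            # ∂f/∂x
    dfdy = (f(z + im * h) - f(z - im * h)) / (2h)  # ∂f/∂y
    dfdz    = (dfdx - im * dfdy) / 2               # ∂f/∂z
    dfdzbar = (dfdx + im * dfdy) / 2               # ∂f/∂z̄
    return dfdz, dfdzbar
end

z = 1.5 + 0.5im
wirtinger_fd(z -> z^2, z)   # ≈ (2z, 0): holomorphic
wirtinger_fd(abs2, z)       # ≈ (conj(z), z): not holomorphic
```

For your problem, something like `wirtinger_fd(Δ -> my_f_mesolve([Δ]), 1.0 + 0im)` should return two length-N² vectors. The Zygote result comes back as a vector of length-1 vectors, so you would flatten it first (e.g. `reduce(vcat, dρ_z)`) before comparing. I haven't run this comparison on the QuantumToolbox example itself.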

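However, be aware that reverse-mode autodiff is efficient for functions $\mathbb{R}^n \to \mathbb{R}^m$ with $n > m$. Its cost scales with the number of outputs $m$ (one pullback per output), while forward mode scales with the number of inputs $n$. Here you return the entire density matrix but differentiate with respect to a single parameter. As a result, the `wirtinger` function above calls `back` $2N^2 = 800$ times, and each call solves the adjoint ODE backward in time. Forward-mode autodiff only needs two passes (one for the real and one for the imaginary part of Δ), so I would use it instead.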
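To make the cost argument concrete, here is a toy forward-mode sketch with hand-rolled dual numbers. It is an illustration only; in practice a forward-mode AD package would do this for you. `toy_model` is a hypothetical stand-in for `my_f_mesolve` with one *real* parameter and `m` outputs. A single forward evaluation returns all `m` derivatives at once, whereas reverse mode would need one pullback per output. For a complex Δ, you would run two such passes, seeding the real and the imaginary direction.

```julia
# Toy forward-mode AD (illustration only). A Dual carries a value and its
# derivative with respect to ONE scalar input.
struct Dual
    val::Float64
    der::Float64
end

# Propagation rules: sum rule, product rule, scaling, chain rule for sin
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(c::Real, a::Dual) = Dual(c * a.val, c * a.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)

# Hypothetical stand-in for my_f_mesolve: one input p, m outputs
toy_model(p, m) = [sin(k * p) + p * p for k in 1:m]

# One forward pass seeded with der = 1 yields the full m×1 Jacobian
function jacobian_column(f, p::Real)
    out = f(Dual(p, 1.0))
    return [o.der for o in out]
end

col = jacobian_column(q -> toy_model(q, 400), 0.3)  # 400 derivatives, one pass
```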

I’m not an autodiff expert; I only got Zygote working on simple `mesolve` problems. This definitely needs more tests and support for more use cases.
