Parameter sensitivity of ODE with parameter-dependent event

Hey,

I’m trying to compute parameter sensitivities of an ODE problem with a time event whose trigger time is itself a parameter.
Just throwing AD at the problem doesn’t seem to work; the derivative with respect to the event time comes out as zero:

using OrdinaryDiffEq
using SciMLSensitivity
using ForwardDiff

# At t = p[2], we assign p[1] <- p[3]
function rhs!(du, u, p, t)
    du[1] = -u[1] + p[1]
end

u0 = [1.0]
p_start = [1.2, 2.0, 0.1]

prob = ODEProblem(rhs!, u0, (0.0, 10.0), p_start)

function loss(p)
    _prob = remake(prob, p=p)

    function condition_disc(u, t, integrator)
        return t == integrator.p[2]
    end

    function condition_cont(u, t, integrator)
        return t - integrator.p[2]
    end

    function affect!(integrator)
        # Triggered at t = p[2], use p[3] instead of p[1] for the remaining time
        integrator.p[1] = p[3]
    end

    # Force a step to end exactly at the event time (value only, without the dual part)
    tstops = [ForwardDiff.value(p[2])]
    # Alternative I also tried:
    # cb = DiscreteCallback(condition_disc, affect!)
    cb = ContinuousCallback(condition_cont, affect!)
    sol = solve(_prob, Tsit5(), saveat = 0.0:0.1:10.0, tstops = tstops,
                callback = cb, sensealg = ForwardDiffSensitivity())
    # Sum of squared deviations from 1 over all saved time points
    return sum(abs2, sol .- 1)
end

ForwardDiff.gradient(loss, p_start)

This yields

3-element Vector{Float64}:
   -2.9178227175008735
    0.0
 -116.18633113248218

But a simple central finite difference in p[2] gives

julia> (loss([1.20, 2.01, 0.1]) - loss([1.2, 1.99, 0.1])) / 0.02
-7.956884892551486
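
As a sanity check, the problem is simple enough to solve by hand. The following is only a sketch in my own notation (τ for the event time p[2], u_0 for the initial value, t_i for the saveat points, L for the loss; none of these names appear in the code), so please double-check it. It uses the right-hand side just before the event, du/dt = -u(τ) + p_1:

```latex
\begin{aligned}
u(t) &= p_1 + (u_0 - p_1)\,e^{-t}, & t &\le \tau, \\
u(t) &= p_3 + \bigl(u(\tau) - p_3\bigr)\,e^{-(t-\tau)}, & t &> \tau, \\[4pt]
\frac{\partial u(t)}{\partial \tau}
  &= \bigl(\dot u(\tau^-) + u(\tau) - p_3\bigr)\,e^{-(t-\tau)}
   = (p_1 - p_3)\,e^{-(t-\tau)}, & t &> \tau, \\
\frac{\partial L}{\partial \tau}
  &= \sum_{t_i > \tau} 2\,\bigl(u(t_i) - 1\bigr)\,(p_1 - p_3)\,e^{-(t_i-\tau)}.
\end{aligned}
```

With p_1 = 1.2 and p_3 = 0.1 the factor p_1 - p_3 = 1.1 is nonzero, so I would not expect the 0.0 entry in the gradient. Note that the loss has a kink whenever τ coincides with a save point, which is the case here (t = 2.0 lies on the saveat grid), so the central difference above is only a rough reference value.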

Neither DiscreteCallback nor ContinuousCallback gives a nonzero derivative for p[2].
How can I get the derivative with respect to p[2] using AD?