How to take the gradient of an ODE system with respect to many data points?

[Image attachment: plot_35 — benchmark plot, not reproduced in text]

I’ll investigate why some of the backends fail

Benchmarking code
begin
    using Statistics
    # using ProfileCanvas
    # using BenchmarkTools
    using DataFrames
    using Plots

    using DifferentiationInterface
    using DifferentiationInterfaceTest

    # using ChainRulesCore
    # using Diffractor
    using Enzyme
    using FastDifferentiation
    using FiniteDiff
    # using FiniteDifferences
    using ForwardDiff
    using PolyesterForwardDiff
    using ReverseDiff
    using SparseDiffTools
    using Symbolics
    using Tapir
    using Tracker
    using Zygote
end

"""
    loss(x, ξ, k, dx)

Mean absolute ODE residual of a four-state system over all data points.

For every data point (one column of the matrices) the residual of each
of the four state equations is formed, the absolute values are summed,
and the result is averaged across data points with `mean`.

# Arguments
- `x`: state matrix; rows 1–4 are the states, one column per data point.
- `dx`: matrix of state derivatives, same layout as `x`.
- `k`: rate-constant matrix; rows 1–3 are the rates, one column per data point.
- `ξ`: coefficient vector indexed `1:60`, consumed in consecutive triples
  `(ξ[3i-2], ξ[3i-1], ξ[3i])` weighting `k1`, `k2`, `k3` respectively.

NOTE(review): the structure looks like a linear-in-parameters reaction
library (SINDy-style) — confirm against the generating model.
"""
function loss(x, ξ, k, dx)
    # Row views (no copies): per-data-point rates, states, and derivatives.
    k1, k2, k3 = view(k, 1, :), view(k, 2, :), view(k, 3, :)
    x1, x2, x3, x4 = view(x, 1, :), view(x, 2, :), view(x, 3, :), view(x, 4, :)
    dx1, dx2, dx3, dx4 = view(dx, 1, :), view(dx, 2, :), view(dx, 3, :), view(dx, 4, :)
    # Fully fused broadcast: y[j] is the summed |residual| of all four
    # equations at data point j. (Comments inside `@. begin` are stripped
    # by the parser and do not affect the macro expansion.)
    y = @. begin
        # Residual of the x3 equation (matches the -dx3 term).
        abs(
            -dx3 +
            k1 * ξ[31] +
            k2 * ξ[32] +
            k3 * ξ[33] +
            (k1 * ξ[10] + k2 * ξ[11] + k3 * ξ[12]) * x1 +
            (k1 * ξ[25] + k2 * ξ[26] + k3 * ξ[27]) * x2 +
            (-k1 * ξ[34] - k2 * ξ[35] - k3 * ξ[36]) * x3 +
            (-k1 * ξ[37] - k2 * ξ[38] - k3 * ξ[39]) * x3 +
            (-k1 * ξ[40] - k2 * ξ[41] - k3 * ξ[42]) * x3 +
            (-k1 * ξ[43] - k2 * ξ[44] - k3 * ξ[45]) * x3 +
            (k1 * ξ[58] + k2 * ξ[59] + k3 * ξ[60]) * x4,
        ) +
        # Residual of the x1 equation (matches the -dx1 term).
        abs(
            -dx1 +
            k1 * ξ[1] +
            k2 * ξ[2] +
            k3 * ξ[3] +
            (-k1 * ξ[10] - k2 * ξ[11] - k3 * ξ[12]) * x1 +
            (-k1 * ξ[13] - k2 * ξ[14] - k3 * ξ[15]) * x1 +
            (k1 * ξ[22] + k2 * ξ[23] + k3 * ξ[24]) * x2 +
            (k1 * ξ[37] + k2 * ξ[38] + k3 * ξ[39]) * x3 +
            (-k1 * ξ[4] - k2 * ξ[5] - k3 * ξ[6]) * x1 +
            (k1 * ξ[52] + k2 * ξ[53] + k3 * ξ[54]) * x4 +
            (-k1 * ξ[7] - k2 * ξ[8] - k3 * ξ[9]) * x1,
        ) +
        # Residual of the x4 equation (matches the -dx4 term).
        abs(
            -dx4 +
            k1 * ξ[46] +
            k2 * ξ[47] +
            k3 * ξ[48] +
            (k1 * ξ[13] + k2 * ξ[14] + k3 * ξ[15]) * x1 +
            (k1 * ξ[28] + k2 * ξ[29] + k3 * ξ[30]) * x2 +
            (k1 * ξ[43] + k2 * ξ[44] + k3 * ξ[45]) * x3 +
            (-k1 * ξ[49] - k2 * ξ[50] - k3 * ξ[51]) * x4 +
            (-k1 * ξ[52] - k2 * ξ[53] - k3 * ξ[54]) * x4 +
            (-k1 * ξ[55] - k2 * ξ[56] - k3 * ξ[57]) * x4 +
            (-k1 * ξ[58] - k2 * ξ[59] - k3 * ξ[60]) * x4,
        ) +
        # Residual of the x2 equation (matches the -dx2 term).
        abs(
            -dx2 +
            k1 * ξ[16] +
            k2 * ξ[17] +
            k3 * ξ[18] +
            (-k1 * ξ[19] - k2 * ξ[20] - k3 * ξ[21]) * x2 +
            (-k1 * ξ[22] - k2 * ξ[23] - k3 * ξ[24]) * x2 +
            (-k1 * ξ[25] - k2 * ξ[26] - k3 * ξ[27]) * x2 +
            (-k1 * ξ[28] - k2 * ξ[29] - k3 * ξ[30]) * x2 +
            (k1 * ξ[40] + k2 * ξ[41] + k3 * ξ[42]) * x3 +
            (k1 * ξ[55] + k2 * ξ[56] + k3 * ξ[57]) * x4 +
            (k1 * ξ[7] + k2 * ξ[8] + k3 * ξ[9]) * x1,
        )
    end
    # Average the per-point summed residuals (Statistics.mean).
    return mean(y)
end

# Problem dimensions and random test data.
v, d = 4, 10    # number of states, number of data points
nk, nξ = 3, 60  # rate constants per data point, length of the coefficient vector

x = rand(v, d)   # states, one column per data point
dx = rand(v, d)  # state derivatives, same layout as x
ξ = rand(nξ)     # library coefficients (the differentiation variable)
k = rand(nk, d)  # rate constants, one column per data point

# Wrap the loss as a function of the coefficient vector alone,
# closing over the global data x, k, and dx.
f(θ) = loss(x, θ, k, dx)

# Smoke test: evaluate once at the sampled coefficients.
f(ξ)

# Single benchmark scenario: the in-place gradient of f evaluated at ξ.
scenarios = [GradientScenario(f; x=ξ, operator=:inplace)]

# AD backends to benchmark. Commented-out entries are the ones that
# currently fail or are excluded from this run (see the note at the top
# of the post); uncomment to retest.
backends = [
    # AutoDiffractor(),
    # AutoEnzyme(; mode=Enzyme.Forward),
    # AutoEnzyme(; mode=Enzyme.Reverse),
    AutoFastDifferentiation(),
    AutoFiniteDiff(),
    # AutoFiniteDifferences(),
    AutoForwardDiff(),
    # AutoPolyesterForwardDiff(; chunksize=8),
    # AutoTracker(),
    # AutoReverseDiff(),
    # AutoSymbolics(),
    # AutoTapir(),
    AutoZygote(),
]

# Run the benchmark for every backend/scenario pair, with logging on so
# backend failures are reported rather than silently skipped.
result = benchmark_differentiation(backends, scenarios; logging=true)

# Tabulate the raw results and keep only the in-place gradient rows.
df = DataFrame(result)
df_filtered = filter(:operator => ==(:gradient!), df)

# Log-scale bar chart of per-backend gradient runtime.
plot_attrs = (
    label=nothing,
    xlabel="backend",
    ylabel="runtime [log]",
    xrotation=10,
    yscale=:log10,
    margin=15Plots.mm,
)
plt = bar(df_filtered.backend, df_filtered.time; plot_attrs...)

savefig(plt, "benchmark.png")
[Forum metadata: 2 Likes]