Time normalisation results in `nothing` gradients of ODE solutions

The following defines the simple one-dimensional ODE dx/dt = -ωx, together with a function x(ω, t) that solves it. Fixing t = T, I can readily differentiate x(ω, t) with respect to ω:

using DifferentialEquations
using SciMLSensitivity
using Zygote

const T = 19.7

dx_dt(x, ω, t) = -only(ω)*x

function x(ω, t)
    problem = ODEProblem(
        dx_dt,
        [1.0,],    # initial condition
        (0, t),    # time span
        ω,         # parameter (dimensions of 1/time)
    )
    solution = solve(
        problem,
        Tsit5(),
        saveat=[t,],  # only need final value
        sensealg=InterpolatingAdjoint(; autojacvec = ZygoteVJP()),
    )
    return only(solution.u)
end

f(ω) = only(x(ω, T))
gradient(f, [1.0])
# ([-0.13590672273297327],)

However, a certain invariance of my problem under rescaling means that an equivalent solution is given as follows:

function x2(ω, t)
    problem = ODEProblem(
        dx_dt,
        [1.0,],         # initial condition
        (0, t*only(ω)), # non-dimensionalized time span
        [1.0,],         # non-dimensionalized parameter
    )
    solution = solve(
        problem,
        Tsit5(),
        saveat=[t*only(ω),],  # only need final value
        sensealg=InterpolatingAdjoint(; autojacvec = ZygoteVJP()),
    )
    return only(solution.u)
end

f2(ω) = only(x2(ω, T))

Indeed, we can see there is no difference between `f` and `f2`:

ω = [rand(),]
@assert f(ω) ≈ f2(ω)

And yet…

gradient(f2, [1.0])
# nothing
Output of `using Pkg; Pkg.status()`:
(jl_Sd68uJ) pkg> status
Status `/private/var/folders/4n/gvbmlhdc8xj973001s6vdyw00000gq/T/jl_Sd68uJ/Project.toml`
  [f6b3d34e] Bertalanffy v0.1.0 `~/Dropbox/Tumor/Bertalanffy`
  [336ed68f] CSV v0.10.12
  [b0b7db55] ComponentArrays v0.15.7
  [a93c6f00] DataFrames v1.6.1
  [0c46a032] DifferentialEquations v7.12.0
  [d9f16b24] Functors v0.4.5
  [b3c1a2ee] IterationControl v0.5.4
  [3bd65402] Optimisers v0.3.1
  [91a5bcdd] Plots v1.40.0
  [1ed8b502] SciMLSensitivity v7.53.0
  [860ef19b] StableRNGs v1.0.1
  [2913bbd2] StatsBase v0.34.2
  [bd369af6] Tables v1.11.1
  [3a884ed6] UnPack v1.0.2
  [e88e6eb3] Zygote v0.6.69
  [9a3f8284] Random
  [8dfed614] Test
julia> versioninfo()
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (x86_64-apple-darwin22.4.0)
  CPU: 12 × Intel(R) Core(TM) i7-8850H CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
  Threads: 17 on 12 virtual cores
Environment:
  JULIA_LTS_PATH = /Applications/Julia-1.6.app/Contents/Resources/julia/bin/julia
  JULIA_PATH = /Applications/Julia-1.10.app/Contents/Resources/julia/bin/julia
  JULIA_EGLOT_PATH = /Applications/Julia-1.7.app/Contents/Resources/julia/bin/julia
  JULIA_NUM_THREADS = 12
  JULIA_NIGHTLY_PATH = /Applications/Julia-1.10.app/Contents/Resources/julia/bin/julia
  DYLD_FALLBACK_LIBRARY_PATH = /Users/anthony/.julia/artifacts/0233bb40b298b03aa3743cc339b4a5c6816ce583/lib:/Users/anthony/.julia/artifacts/1e901863cf8fbb1ee50a5d0976114a2371899331/lib:/Users/anthony/.julia/artifacts/dcc1b7719d5a106fba77bbc272d231e163d15fe5/lib:/Applications/Julia-1.10.app/Contents/Res

(truncated)

Zygote currently cannot differentiate w.r.t. the end time. This is something that isn’t too hard to fix but it needs an extended dispatch.
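
To see why the end time is the crux here, it may help to write out the exact solutions for the toy problem above (a worked derivation, not output from any package). The original formulation has the closed form

$$
x(\omega, t) = e^{-\omega t}, \qquad \frac{\partial x}{\partial \omega} = -t\, e^{-\omega t}.
$$

In the rescaled formulation, `x2` solves $dy/d\tau = -y$, $y(0) = 1$ with the constant parameter `[1.0]` on the span $[0, \tau_{\mathrm{end}}]$, where $\tau_{\mathrm{end}} = \omega t$. Hence

$$
x_2(\omega, t) = y(\tau_{\mathrm{end}}) = e^{-\tau_{\mathrm{end}}}, \qquad
\frac{\partial x_2}{\partial \omega}
= \left.\frac{dy}{d\tau}\right|_{\tau_{\mathrm{end}}} \frac{\partial \tau_{\mathrm{end}}}{\partial \omega}
= -t\, e^{-\omega t}.
$$

The two derivatives agree and are nonzero, but in `x2` the entire dependence on ω is carried by the end time (and the matching `saveat` value), which is exactly the path that is not currently differentiated.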


Ah, good to know.

Thanks @ChrisRackauckas for taking the time to diagnose my problem and for the lightning response.

It would be nice to have this, as I’m struggling a bit to find a workaround. Should I open an issue at SciMLSensitivity.jl?

My problem is this: I’m using neural ODEs in a clinical application, but I prefer the trained NN (i.e., its parameters) to be independent of the time units adopted for calibration. (I reason that, where possible, NN parameters ought to be dimensionless.) Based on existing models, I postulate the existence of a natural time scale for the problem (which is to be learned but is not part of the NN), and I non-dimensionalize times in the training set by this time scale (think of 1/ω in the example above). But since I don’t know the time scale a priori, I end up with a parameter-dependent end time in my call to the solver, as well as parameter-dependent saveat times.

Yup it could use an issue.

Done.
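
For anyone landing here before that issue is resolved, one possible workaround is sketched below. It is not something proposed in the thread, only an illustration under stated assumptions: substitute s = τ/(ωT) so that the solver always integrates over the fixed span (0, 1), and move the factor ωT into the right-hand side, where ω enters as an ordinary parameter. The names `scaled_dx_ds`, `x3` and `f3` are hypothetical and introduced only for this example.

```julia
using DifferentialEquations
using SciMLSensitivity
using Zygote

const T = 19.7

# Dimensionless dynamics are dx/dτ = -x with τ ∈ [0, ωT].
# With s = τ/(ωT) ∈ [0, 1] this becomes dx/ds = -(ωT) x,
# so ω now appears in the right-hand side instead of the time span.
scaled_dx_ds(x, p, s) = -(T * only(p)) .* x

function x3(ω)
    problem = ODEProblem(
        scaled_dx_ds,
        [1.0],        # initial condition
        (0.0, 1.0),   # fixed, parameter-independent time span
        ω,            # time-scale parameter, differentiated as usual
    )
    solution = solve(
        problem,
        Tsit5();
        saveat=[1.0],  # only need final value
        sensealg=InterpolatingAdjoint(; autojacvec = ZygoteVJP()),
    )
    return only(solution.u)
end

f3(ω) = only(x3(ω))
gradient(f3, [1.0])
```

For this toy ODE the rescaled problem coincides with the original `f`, so the sketch only demonstrates the mechanics. In a setting with several observation times t_i, the same substitution would presumably use the fixed values t_i/T in `saveat`, with the learned time scale multiplying the neural network's output inside the right-hand side.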