Compute gradients in a neural ODE with Zygote

Hello everyone,

I would like to know how to check the gradients for a single batch in this example.

After the computation of l1 I did:

batch1 = train_loader.data[1]

time1 = train_loader.data[2]

loss, back = Zygote.pullback((θ, p, batch, time_batch) -> loss_adjoint(θ, batch, time_batch), pp, batch1, time1)

But this is giving me an error saying:
ERROR: MethodError: no method matching (::var"#10#11")(::Vector{Float32}, ::Matrix{Float32}, ::StepRangeLen{Float32, Float64, Float64, Int64})

How can I correctly get the gradients in this case?
Also, how can I get the gradients with respect to the inputs to the network u?

Best Regards

What versions?

Julia 1.9.2
Zygote v0.6.63
Flux v0.14.2

I think I solved it with:

loss, back = Zygote.pullback((θ, batch, time_batch) -> loss_adjoint(θ, batch, time_batch), pp, batch1, time1)

Here loss comes back as a tuple, since loss_adjoint returns sum(abs2, batch .- pred), pred.
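For anyone hitting the same MethodError: Zygote.pullback(f, xs...) simply calls f(xs...), so the closure's argument count has to match the number of trailing arguments. A toy sketch of both cases (the functions and values here are made up, not the original model):

```julia
using Zygote

# A 3-argument closure matches 3 trailing arguments — this works:
f3 = (θ, batch, t) -> sum(abs2, θ) + sum(batch)
loss, back = Zygote.pullback(f3, Float32[1, 2], Float32[3, 4], 0f0:0.1f0:1f0)

# A 4-argument closure with only 3 trailing arguments reproduces the error:
f4 = (θ, p, batch, t) -> sum(abs2, θ) + sum(batch)
# Zygote.pullback(f4, Float32[1, 2], Float32[3, 4], 0f0:0.1f0:1f0)  # MethodError
```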

When I compute the gradients with:
back(loss)

What am I getting? The gradients w.r.t. the network parameters and the gradients w.r.t. the batch?
Where does the latter come into play?

Edit:

Also,

If I compute the gradients with the batch passed as an argument versus captured as a global, I end up with slightly different values.

using Random
Random.seed!(123);

batch1 = train_loader.data[1]
time1 = train_loader.data[2]
prob2 = ODEProblem{false}(dudt_, u0, tspan, pp, saveat=time1)

function predict_adjoint2(fullp)
    Array(solve(prob2, Tsit5(), p = fullp))
end

function loss_adjoint2(fullp)
    pred = predict_adjoint2(fullp)
    sum(abs2, batch1 .- pred)
end

loss, back = Zygote.pullback((θ, batch, time_batch) -> loss_adjoint(θ, batch, time_batch), pp, batch1, time1)

loss2, back2 = Zygote.pullback((θ) -> loss_adjoint2(θ), pp)

a = back(loss)
b = back2(loss2)
julia> a[1]
25-element Vector{Float32}:
    0.0
   -2.933251f8
    0.0
 -183.55194
    0.0
    0.0
    0.0
    0.0
    0.0
   -3.9543002f6
    0.0
   -8.834232
    0.0
    0.0
    0.0
    0.0
    1.7576794f10
   -1.752914f10
    1.7576794f10
   -1.7576794f10
    1.7576794f10
   -1.7576794f10
   -1.7576794f10
    1.7576794f10
   -1.7576784f10
julia> b[1]
25-element Vector{Float32}:
    0.0
   -2.9332538f8
    0.0
 -183.55225
    0.0
    0.0
    0.0
    0.0
    0.0
   -3.9543042f6
    0.0
   -8.834239
    0.0
    0.0
    0.0
    0.0
    1.7576827f10
   -1.7529172f10
    1.7576827f10
   -1.7576827f10
    1.7576827f10
   -1.7576827f10
   -1.7576827f10
    1.7576827f10
   -1.7576817f10

Shouldn’t they be exactly the same?

Best Regards

The gradients with respect to each of the input arguments.
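To make that concrete, here's a toy sketch (not your model): the pullback returns a tuple with one gradient per argument you passed to Zygote.pullback, and for a scalar loss you typically seed it with 1. The gradient w.r.t. the batch argument is also how you'd get sensitivities w.r.t. the network inputs.

```julia
using Zygote

f(θ, x) = sum(abs2, θ .* x)          # scalar loss of two arguments
y, back = Zygote.pullback(f, [1.0, 2.0], [3.0, 4.0])
dθ, dx = back(1.0)                   # seed with 1 for a scalar output
# dθ == [18.0, 64.0]  (∂f/∂θ = 2 .* (θ .* x) .* x)
# dx == [6.0, 32.0]   (∂f/∂x = 2 .* (θ .* x) .* θ)
```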

loss_adjoint is not defined in your code, so it should error. But if it is just a different call than loss_adjoint2 that passes in the same arguments, then the issue is likely the version with the global batch1: because batch1 is global, it blocks a compiler optimization and thus slightly changes the floating-point results. Remember, floating-point arithmetic is not associative and all of that jazz, and 32-bit floats only have about 8 significant digits, so the difference you're seeing is the expected difference between (x + y) + z and x + (y + z) for floating-point numbers.
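A quick way to see the non-associativity at Float32 precision, with toy numbers:

```julia
x, y, z = 1f8, -1f8, 1f0
(x + y) + z   # == 1.0f0
x + (y + z)   # == 0.0f0, because y + z rounds back to -1f8 (eps(1f8) == 8f0)
```

Reorderings like this happen whenever the compiler changes how a reduction such as sum(abs2, ...) is grouped, which is why the two gradient computations agree only to roughly single-precision accuracy.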