Julia 1.9.2
Zygote v0.6.63
Flux v0.14.2
I think I solved it with:
loss, back = Zygote.pullback((θ, batch, time_batch) -> loss_adjoint(θ, batch, time_batch), pp, batch1, time1)
This returns the loss, which comes from the tuple sum(abs2, batch .- pred), pred, together with the pullback back.
When I compute the gradients with:
back(loss)
What am I getting? The gradient w.r.t. the network parameters and the gradient w.r.t. the batch?
Where does the latter come into play?
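For reference, Zygote.pullback(f, args...) returns the primal value and a pullback, and calling the pullback with a cotangent returns a tuple with one gradient per argument of f, in order. Note also that back(loss) scales the gradients by the loss value; for the plain gradient of a scalar loss one usually calls back(one(loss)). A minimal sketch, assuming Zygote is installed (the toy function f below stands in for loss_adjoint and is made up for illustration):

```julia
using Zygote

# toy scalar loss with two arguments, standing in for (θ, batch)
f(θ, x) = sum(abs2, x .- θ)

y, back = Zygote.pullback(f, [1.0, 2.0], [3.0, 4.0])
# y == 8.0

# the pullback takes the cotangent of the scalar output;
# one(y) gives the ordinary gradients
dθ, dx = back(one(y))

# dθ is the gradient w.r.t. the first argument: [-4.0, -4.0]
# dx is the gradient w.r.t. the second argument: [4.0, 4.0]
```

So the second element of the returned tuple is indeed the gradient with respect to the batch argument; it is only needed if you want to differentiate through the data itself.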
Edit:
Also, if I compute the gradients with the batch passed as an explicit argument versus with the batch captured as a global, I end up with slightly different values.
using Random
Random.seed!(123);
batch1 = train_loader.data[1]
time1 = train_loader.data[2]
prob2 = ODEProblem{false}(dudt_, u0, tspan, pp, saveat=time1)
function predict_adjoint2(fullp)
    Array(solve(prob2, Tsit5(), p = fullp))
end

function loss_adjoint2(fullp)
    pred = predict_adjoint2(fullp)
    sum(abs2, batch1 .- pred)
end
loss, back = Zygote.pullback((θ, batch, time_batch) -> loss_adjoint(θ, batch, time_batch), pp, batch1, time1)
loss2, back2 = Zygote.pullback((θ) -> loss_adjoint2(θ), pp)
a = back(loss)
b = back2(loss2)
julia> a[1]
25-element Vector{Float32}:
0.0
-2.933251f8
0.0
-183.55194
0.0
0.0
0.0
0.0
0.0
-3.9543002f6
0.0
-8.834232
0.0
0.0
0.0
0.0
1.7576794f10
-1.752914f10
1.7576794f10
-1.7576794f10
1.7576794f10
-1.7576794f10
-1.7576794f10
1.7576794f10
-1.7576784f10
julia> b[1]
25-element Vector{Float32}:
0.0
-2.9332538f8
0.0
-183.55225
0.0
0.0
0.0
0.0
0.0
-3.9543042f6
0.0
-8.834239
0.0
0.0
0.0
0.0
1.7576827f10
-1.7529172f10
1.7576827f10
-1.7576827f10
1.7576827f10
-1.7576827f10
-1.7576827f10
1.7576827f10
-1.7576817f10
Shouldn’t they be exactly the same?
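For context on why small discrepancies like these can occur even for mathematically equivalent programs: Float32 addition is not associative, so restructuring a computation (e.g. closing over globals versus passing explicit arguments, which can change the code paths the AD takes) may reorder operations and change the low-order digits. A self-contained illustration, unrelated to the ODE code above:

```julia
# Float32 addition is not associative: grouping changes the result
# when small terms are absorbed by a large one.
a = 1f8   # 1.0e8, exactly representable in Float32 (ulp at this scale is 8)
b = 4f0

left  = (a + b) + b   # each 4 is absorbed by 1e8 in turn → 1.0f8
right = a + (b + b)   # 8 is large enough to survive     → 1.00000008f8

left == right         # false
```

Whether this fully accounts for the difference above (adaptive step-size choices in the solver could also contribute) is a separate question, but exact bit-for-bit equality is generally not to be expected.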
Best Regards