The tracked parameters are not updated during Flux training

Hey,
Sorry it took a while to get to this. The issue was that you were mixing the old, deprecated syntax with the new one. What you wanted was:

    # u0 and p are passed to ODEProblem and, again, to `solve` as keyword
    # arguments; the keyword form is the fix this reply recommends.
    sol = solve(
        ODEProblem(VO_system!, u0, tspan, p),
        Tsit5();
        u0 = u0,
        p = p,
        saveat = range(tspan[1], tspan[2], length = 500),  # 500 evenly spaced save times
        sensealg = QuadratureAdjoint(),  # sensitivity algorithm used when differentiating solve
        abstol = 1e-10,
        reltol = 1e-10,
    )

That is, call `solve` and pass `u0` and `p` as keyword arguments. Here is a full example that double-checks the gradient values against ForwardDiff:

using Flux, OrdinaryDiffEq, DiffEqSensitivity


# Encapsulate all device parameters in a VO struct.
#
# The field types are type parameters (instead of `::Any`) so that code reading
# `device.Cth` / `device.Rth` is specialized on the concrete stored type.
# In this example both fields hold one-element `Vector{Float64}`s, e.g.
# `VO([1.63e-13], [4.0e5])`.
#
# Fields (order matters for the positional constructor `VO(Cth, Rth)`):
# - `Cth`: used as the divisor of the T-state derivative in the device ODE
#          (presumably a thermal capacitance — confirm).
# - `Rth`: used to scale T inside that derivative (presumably a thermal
#          resistance — confirm).
struct VO{C,R}
    Cth::C
    Rth::R
end

# Register VO with Flux so its fields are traversed, then declare the trainable
# fields. Rth is listed before Cth, so `Flux.params` on a VO yields [Rth, Cth]
# (as the "Tracked Parameters" output of this example shows).
Flux.@functor VO
function Flux.trainable(vo::VO)
    return (vo.Rth, vo.Cth)
end

# Zero-argument constructor: a VO holding this example's initial parameter
# values, each wrapped in a one-element vector.
function VO()
    cth_init = [1.63e-13]
    rth_init = [4.0e5]
    return VO(cth_init, rth_init)
end


# Calling a VO device integrates its ODE and returns the `solve` result.
#
# Arguments:
# - `Input_time_series`: part of the call signature but not read by this method.
# - `p`: parameter vector laid out as [Cth, Rth]; defaults to the device's own
#   current values. Passing it explicitly lets a caller differentiate with
#   respect to it (as the ForwardDiff check in this example does).
#
# State layout: `nstate` V entries (initially 1.0, dV/dt = -V) followed by
# `nstate` T entries (initially 300.0, dT/dt = (1/Cth) * (V - T/Rth)).
# The system is integrated over (0.0, 1e-5) and saved at 500 evenly spaced times.
function (device::VO)(Input_time_series, p = [device.Cth[1], device.Rth[1]])
    nstate = 10
    tspan = (0.0, 1e-5)
    u0 = vcat(fill(1.0, nstate), fill(300.0, nstate))

    # In-place right-hand side; `params` unpacks as (Cth, Rth).
    function vo_rhs!(du, u, params, t)
        Cth, Rth = params
        V = u[1:nstate]
        T = u[(nstate + 1):end]
        du[1:nstate] = -V
        du[(nstate + 1):end] .= 1 ./ Cth .* (V .- T ./ Rth)
        return nothing
    end

    save_times = range(tspan[1], tspan[2], length = 500)
    problem = ODEProblem(vo_rhs!, u0, tspan, p)

    # u0 and p are also given to `solve` as keywords — the non-deprecated form
    # this reply recommends for differentiating with respect to them.
    return solve(
        problem,
        Tsit5();
        u0 = u0,
        p = p,
        saveat = save_times,
        sensealg = QuadratureAdjoint(),
        abstol = 1e-10,
        reltol = 1e-10,
    )
end


# ---------------------------------------------------------------------------
# Gradient computation. No optimizer step is taken here; this only checks that
# gradients reach the tracked parameters.
# ---------------------------------------------------------------------------

device = VO()
model = device            # alias used by `loss` and by the ForwardDiff check
ps = Flux.params(model)   # implicit parameters, in the order Flux.trainable lists them

# Mean squared error of the model's solution against `y`. `x` is forwarded as
# the model's (unused) input argument.
function loss(x, y)
    return Flux.mse(model(x), y)
end

println("Tracked Parameters: ", ps)

gs = Flux.gradient(ps) do
    loss(0, 0)
end
@show gs[ps[1]]
@show gs[ps[2]]

using ForwardDiff
# Cross-check the Flux.gradient result above with forward-mode AD.
# The evaluation point is read from the model instead of re-typing the
# constructor defaults, so both checks always use the same parameters.
# The vector is laid out as [Cth, Rth] (the model's `p` layout), while `ps` is
# ordered [Rth, Cth]; hence grad[2] pairs with gs[ps[1]] and grad[1] with gs[ps[2]].
grad = ForwardDiff.gradient(
    p -> Flux.mse(model(0, p), 0),
    [model.Cth[1], model.Rth[1]],
)
@show grad[2]
@show grad[1]

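which outputs the following. Note that `ps` is ordered `[Rth, Cth]` while the ForwardDiff vector is `[Cth, Rth]`, so `gs[ps[1]]` pairs with `grad[2]` and `gs[ps[2]]` with `grad[1]`. The two methods agree to about 10 significant digits: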

Tracked Parameters: Params([[400000.0], [1.63e-13]])
gs[ps[1]] = [3.937418595344859e9]
gs[ps[2]] = [-4.787204526955042e25]
grad[2] = 3.9374185953439655e9
grad[1] = -4.787204527049037e25
1 Like