Flux's parameters don't seem to play well with matrix multiplication. When computing `p .* A`, where `p` is a parameter and `A` is a matrix, things go wrong. Here is a simple example:
```julia
using DifferentialEquations
using Flux, DiffEqFlux

function myode(du,u,p,t)
    p1, p2 = p
    A1 = [0.0 1.0; 0.0 0.0]
    A2 = [0.0 0.0; 1.0 0.0]
    du .= p1 .* (A1*u) + p2 .* (A2*u)
end

u0 = [1.0, -1.0]
tspan = (0.0,1.0)
p = [1.0, 2.0]
prob = ODEProblem(myode,u0,tspan,p)

p = param([1.0, 2.0])
params = Flux.Params([p])
diffeq_rd(p, prob, Tsit5())
```
I get the following error:

```
Not implemented: convert tracked Flux.Tracker.TrackedReal{Float64} to tracked Float64
```
I would appreciate any tips on getting around this issue!
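One workaround often suggested for this kind of error (an assumption on our part, not something confirmed in this thread) is to write the right-hand side in out-of-place form. The in-place `myode(du,u,p,t)` writes into a buffer `du` whose element type is presumably fixed by `u0` (`Float64`), so storing tracked values into it needs a conversion that Tracker does not implement. An out-of-place function instead returns a new array whose element type follows from `u` and `p`. The sketch below uses only base Julia; the names `myode!` and the three-argument `myode` are hypothetical, and `BigFloat` merely stands in for a tracked number type to show how the element type propagates.

```julia
# Constant coefficient matrices from the question, hoisted out of the RHS.
const A1 = [0.0 1.0; 0.0 0.0]
const A2 = [0.0 0.0; 1.0 0.0]

# In-place form (as in the question): results are written into the
# preallocated buffer `du`, so they are converted to eltype(du).
function myode!(du, u, p, t)
    p1, p2 = p
    du .= p1 .* (A1 * u) .+ p2 .* (A2 * u)
    return nothing
end

# Out-of-place form (hypothetical rewrite): allocates and returns a new
# vector whose element type is promoted from the types of u and p.
function myode(u, p, t)
    p1, p2 = p
    return p1 .* (A1 * u) .+ p2 .* (A2 * u)
end

u0 = [1.0, -1.0]
p  = [1.0, 2.0]

du = similar(u0)              # Float64 buffer, like a solver cache built from u0
myode!(du, u0, p, 0.0)

du_big = similar(u0)          # still a Float64 buffer
myode!(du_big, u0, big.(p), 0.0)   # BigFloat results are converted to Float64 on write

out_big = myode(u0, big.(p), 0.0)  # out-of-place keeps the BigFloat element type
println(du, " ", eltype(du_big), " ", eltype(out_big))
```

Passing a three-argument function such as `myode(u,p,t)` to `ODEProblem` makes DifferentialEquations.jl treat the problem as out-of-place. Whether that alone resolves the Tracker error with `diffeq_rd` has not been tested here.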