My model:
using Random
using Flux, DiffEqFlux, DifferentialEquations
function model(dZ, Z, params, t)
W0 = reshape(params[1:12], (1, 12))
b0 = reshape(params[13:24], (1, 12))
W1 = reshape(params[25:36], (12, 1))
b1 = params[end]
N = Integer(length(Z) / 4);
X = Z[1:N];
Y = Z[N + 1:2 * N];
Vx = Z[2 * N + 1:3 * N];
Vy = Z[3 * N + 1:4 * N];
Vxdiff = broadcast(-, Vx, Vx');
Vydiff = broadcast(-, Vy, Vy');
Xdiff = broadcast(-, X, X');
Ydiff = broadcast(-, Y, Y');
R2 = Xdiff.^2 + Ydiff.^2;
r = reshape(R2, (N * N, :))
X0 = r * W0 .+ b0
Z0 = max.(X0, 0)
X1 = Z0 * W1 .+ b1
RR = reshape(X1, (N, N))
dVx = -sum(Vxdiff .* RR, dims = 2) ./ N
dVy = -sum(Vydiff .* RR, dims = 2) ./ N;
dZ[:] = [reshape(Vx, (N, 1)); reshape(Vy, (N, 1)); dVx; dVy] # ERROR LINE
end;
t = collect(0:0.1:10)
u0 = rand(20, 1)
tspan = (0.0, 1.0)
p_nominal = rand(37, 1);
prob = ODEProblem(model, u0, tspan, p_nominal)
data_sol = solve(prob, Tsit5(), saveat = 0.1)
p = rand(37, 1);
p = param(p)
function predict_rd()
diffeq_rd(p, prob, Tsit5(), saveat = 0.1)
end
loss_rd() = sum(abs2, predict_rd() - data_sol);
print(loss_rd());
My error is
ERROR: LoadError: Not implemented: convert tracked Tracker.TrackedReal{Float64} to tracked Float64
Stacktrace:
[1] error(::String) at .\error.jl:33
[2] convert(::Type{Tracker.TrackedReal{Float64}}, ::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at C:\Users\.julia\packages\Tracker\6wcYJ\src\lib\real.jl:39
[3] setindex!(::Array{Tracker.TrackedReal{Float64},2}, ::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}, ::Int64) at .\array.jl:767
[4] macro expansion at .\multidimensional.jl:701 [inlined]
[5] macro expansion at .\cartesian.jl:64 [inlined]
[6] macro expansion at .\multidimensional.jl:696 [inlined]
[7] _unsafe_setindex! at .\multidimensional.jl:689 [inlined]
[8] _setindex! at .\multidimensional.jl:684 [inlined]
[9] setindex! at .\abstractarray.jl:1020 [inlined]
[10] model(::Array{Tracker.TrackedReal{Float64},2}, ::Array{Tracker.TrackedReal{Float64},2}, ::TrackedArray{…,Array{Float64,2}}, ::Float64) at d:\question\run.jl:34
[11] ODEFunction at C:\Users\mzhen\.julia\packages\DiffEqBase\ZQVwI\src\diffeqfunction.jl:107 [inlined]
...
So the problem is here:
dZ[:] = [reshape(Vx, (N, 1)); reshape(Vy, (N, 1)); dVx; dVy]
How can I update dZ to avoid an error?
Thanks!