I am trying to train a custom, fairly complex model, and I am getting an error during the reverse pass that computes the gradients (I use `DiffEqFlux.sciml_train` for optimization). I am struggling to understand what it means and how to fix it. The full (lengthy) trace can be found here.
Since I have played around with this model for quite a while, I am fairly certain which block of code causes the error, but I am not sure which specific line or function is responsible. The following code defines the function that is passed to `DiffEqFlux.ODEProblem` and produces the derivatives (I wrap it so that only the `du, u, p, t` arguments remain, but I have omitted the wrapper for brevity):
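```julia
hybrid_camkii_system = function (du, u, p, t, I, net, L, SIZE_R)
    N_R = SIZE_R[1] * SIZE_R[2]                 # number of entries in the matrix Y
    # first 64 derivatives from partial_camkii_system, using the last 6 parameters
    du[1:64] = partial_camkii_system(u[1:64], p[end-5:end], L)
    # network output: one rate per column of Y
    kin_vec = softplus.(net(p[N_R+1:end-6])(u), 1)
    Y = reshape(p[1:N_R], SIZE_R)
    # state value where Y < 0, 1 where Y > 0, 0 where Y == 0
    bal_mat = u[64:end] .* (Y .< 0) + (Y .> 0)
    sto_mat = Y .* bal_mat
    d_lat = sum(kin_vec' .* sto_mat, dims=2)    # SIZE_R[1] × 1 matrix
    du[1] = du[1] + I(t)                        # external input
    du[64] = du[64] + d_lat[1]                  # du[64] gets contributions from both blocks
    du[65:end] = d_lat[2:end]
end
```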
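A minimal sketch of the kind of wrapper meant above, assuming a plain closure that captures the extra arguments. The names `toy_system`, `I_ext` and `scale` are hypothetical stand-ins, not the real model; the point is only that the resulting `f!` has the `(du, u, p, t)` signature that `ODEProblem(f!, u0, tspan, p)` expects.

```julia
# Hypothetical stand-in with extra trailing arguments, like hybrid_camkii_system
toy_system = function (du, u, p, t, I, scale)
    du .= -scale .* p[1] .* u    # simple linear decay
    du[1] += I(t)                # external input on the first state
    return nothing
end

I_ext(t) = sin(t)    # hypothetical external input
scale = 2.0          # hypothetical extra argument

# Closure that fixes the extra arguments, leaving only (du, u, p, t)
f! = (du, u, p, t) -> toy_system(du, u, p, t, I_ext, scale)

du = zeros(3)
f!(du, [1.0, 2.0, 3.0], [0.5], 0.0)
```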
I suspect the issue lies specifically in the lines defining `bal_mat`, `sto_mat` and `d_lat` (the other lines have worked fine in previous versions of the code), but I don't fully understand what the problem is or where exactly it is located. Any help is appreciated, and I can provide extra information as needed. I am working in Julia v1.7.1.
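To make the array shapes in the suspect lines concrete, here is a toy, forward-only version of just the `bal_mat`, `sto_mat` and `d_lat` computation, using plain Base Julia. The sizes (3 lattice states, 2 columns of `Y`) and all values are made up for illustration; `u_lat` stands in for `u[64:end]` and the fixed `kin_vec` stands in for the network output. It does not reproduce the gradient error, it only shows what the broadcasts compute.

```julia
SIZE_R = (3, 2)                        # hypothetical: 3 lattice states, 2 columns
N_R = SIZE_R[1] * SIZE_R[2]
p_Y = [-1.0, 1.0, 0.0, 0.0, -1.0, 1.0] # stands in for p[1:N_R]
Y = reshape(p_Y, SIZE_R)               # column-major: [-1 0; 1 -1; 0 1]

u_lat = [2.0, 3.0, 5.0]                # stands in for u[64:end]
kin_vec = [0.5, 0.1]                   # stands in for the network output

# u_lat is broadcast across columns; Y .< 0 and Y .> 0 are BitMatrix masks
bal_mat = u_lat .* (Y .< 0) + (Y .> 0) # [2 0; 1 3; 0 1]
sto_mat = Y .* bal_mat                 # [-2 0; 1 -3; 0 1]
d_lat = sum(kin_vec' .* sto_mat, dims=2)  # 3×1 Matrix of row sums

first_entry = d_lat[1]                 # would be added to du[64]
rest = d_lat[2:end]                    # linear indexing gives a Vector for du[65:end]
```

In the real model the same broadcasts require `length(u[64:end]) == SIZE_R[1]` and a network output of length `SIZE_R[2]`.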