# Cannot compute gradient/apply back method to solve function

Hi,

I’m facing an issue that prevents me from training my model with sciml_train.
I have the following training data:


```julia
model = Normal(5., 0.1)
raw_data = rand(model, (1, 100))
x = rand(model, 100)  # a vector of 100 samples, so that [x; 0] below works
u0 = [x; 0]
```


I have an ODE defined as:

```julia
function initial_value_problem(u::AbstractArray, p, t) # dynamics of z and log p(z), as given in the paper
    z = u[1:end-1]
    f = re(p)
    [f(z')'; -t_J(f, z)]
end
```

```julia
prob = ODEProblem(initial_value_problem, u0, tspan)
```



where f is a simple neural network with weights p:

```julia
nn = Chain(Dense(1, 1, swish), Dense(1, 1))
p, re = Flux.destructure(nn)
```

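In case it helps to see what `destructure` gives you: `p` is a flat vector of all the network's weights and biases, and `re` rebuilds a network from any such vector, which is why `re(p)` inside the ODE lets the solver treat the weights as parameters. A minimal self-contained sketch (the shapes below are just illustrative):

```julia
using Flux

nn = Chain(Dense(1, 1, swish), Dense(1, 1))
p, re = Flux.destructure(nn)

# p is a flat Vector of all parameters: two Dense(1, 1) layers → 4 numbers
length(p)  # 4

# re(p) rebuilds a network with those weights; re applied to a perturbed
# vector gives a perturbed network
nn2 = re(p)
x = randn(1, 5)
nn2(x) == nn(x)  # the rebuilt network computes the same function
```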

I also define a function predict_adjoint as follows:

```julia
function predict_adjoint(p)
    _prob = remake(prob, u0 = u0, p = p)
    return solve(_prob, Tsit5(), # [u0, 0f0]
        checkpointing = true).u[end]
end
```



Then, when I try to compute the gradient of a loss:

```julia
function loss_adjoint(p)
    preds = predict_adjoint(p)
    pz = Normal(0.0, 1.0)
    z0 = preds[1:end-1]
    delta_logp = preds[end]

    logpx = logpdf.(pz, z0) .- delta_logp
    loss = -mean(logpx)
    loss
end
```


I get the following error:

```
ERROR: MethodError: no method matching flatten(::Array{Float64,2})
```


Investigating this error, I found that it seems to come from the predict_adjoint step inside loss_adjoint.

I don’t understand why, because loss_adjoint returns a scalar and every operation invoked inside it should be differentiable by Zygote, so I should be able to compute the gradient.

Do you know where I am wrong?

Thank you

t_J isn’t defined here. Could you share the whole script so I can copy/paste it, run it, and see the issue?

Is this a FFJORD implementation where it’s using Zygote inside?

Hey Chris, thank you for your help.

Sorry, I should have copied it too. Here is the full script:


```julia
using Flux, DiffEqFlux, OrdinaryDiffEq, Distributions, Zygote, Statistics

model = Normal(5., 0.1)
raw_data = rand(model, (1, 100))

# we want to learn a mapping from the base distribution to the model,
# i.e. a mapping from generated points to data

# defining the base distribution
pz = Normal(0.0, 1.0)

nn = Chain(Dense(1, 1, swish), Dense(1, 1))
p_init, re = Flux.destructure(nn)

x = rand(model, 100)
u0 = [x; 0]

# simulation interval and intermediary points
tspan = (0.0, 10.0)
tsteps = 0.0:0.1:10.0
```

```julia
# takes a function and a vector (size (N,)) and returns a scalar:
# the trace of the Jacobian of f at z
function t_J(f, z::AbstractArray)
    N = length(z)
    diag_Jacobian_nn = zeros(N)
    for t in 1:N
        g(x) = f(x')[t]
        y, back = Zygote.pullback(g, z) # we want the gradient of this function
        diag_Jacobian_nn[t] = back(one(y))[1][t]
    end
    sum(diag_Jacobian_nn)
end
```

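The trace-of-Jacobian trick can be sanity-checked in isolation against a function whose Jacobian trace is known. A small self-contained sketch (the helper `trace_jacobian` is just an illustrative reimplementation of the same pullback-per-coordinate idea, not part of the script above):

```julia
using Zygote

# the t-th Jacobian diagonal entry is the t-th component of the
# gradient of the scalar function z -> f(z)[t]
function trace_jacobian(f, z::AbstractVector)
    s = 0.0
    for t in 1:length(z)
        g(x) = f(x)[t]
        y, back = Zygote.pullback(g, z)
        s += back(one(y))[1][t]
    end
    return s
end

f(z) = 2 .* z                 # Jacobian is 2I, so the trace is 2N
trace_jacobian(f, randn(5))   # ≈ 10.0
```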
```julia
# in-place form
function initial_value_problem(du::AbstractArray, u::AbstractArray, p, t) # dynamics of z and log p(z), as given in the paper
    z = u[1:end-1]
    f = re(p)      # rebuild the network from the flat parameter vector
    du[1:end-1] = f(z')
    du[end] = -t_J(f, z)
    return du
end

# try with the ivp(u, p, t) form
function initial_value_problem(u::AbstractArray, p, t)
    z = u[1:end-1]
    f = re(p)
    [f(z')'; -t_J(f, z)]
end
```

```julia
prob = ODEProblem(initial_value_problem, u0, tspan)

function predict_adjoint(p)
    _prob = remake(prob, u0 = u0, p = p)
    return solve(_prob, Tsit5(), # [u0, 0f0]
        checkpointing = true).u[end]
end
```

```julia
function loss_adjoint(p)
    preds = predict_adjoint(p)
    pz = Normal(0.0, 1.0)
    z0 = preds[1:end-1]
    delta_logp = preds[end]

    logpx = logpdf.(pz, z0) .- delta_logp
    loss = -mean(logpx)
    loss
end
```

```julia
function cb2(l)
    @show l
    false
end

opt = ADAM(0.1) # the optimizer was not specified in the original snippet; ADAM assumed
res = DiffEqFlux.sciml_train(loss_adjoint, p_init, opt, maxiters = 100, cb = cb2)
```



I’m concerned about Zygote because, as I understand it, sciml_train calls Zygote.pullback to compute the gradient. Running the pullback on loss_adjoint works, but calling the returned back function then fails:

```julia
y, back = Zygote.pullback(loss_adjoint, p)
back(y)  # error
```
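As an aside, for a scalar loss the pullback is usually seeded with 1 (the sensitivity of the loss with respect to itself) rather than with the loss value; seeding with y only scales the gradient and is not what triggers the MethodError here. For reference, a minimal self-contained sketch of the pullback pattern that sciml_train relies on:

```julia
using Zygote

loss(p) = sum(abs2, p)   # toy scalar loss

p = [1.0, 2.0, 3.0]
y, back = Zygote.pullback(loss, p)

# back takes the sensitivity of the output; for a scalar loss, seed with 1
grad = back(one(y))[1]   # gradient of loss wrt p, here 2p = [2.0, 4.0, 6.0]
```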


So the issue seems to be ReverseDiff getting in the way of the backwards pass of broadcasting.

It is possible to get around it by defining a more specific method for Zygote to pick up, specifically for swish. Something like:

```julia
Zygote.@adjoint function Base.broadcasted(::typeof(swish), x)
    y = swish.(x)
    y, Δ -> begin
        (nothing, Δ .* (y .+ (σ.(x) .* (one.(x) .- y))),)
    end
end
```
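One way to check a hand-written adjoint like this is to compare its closed-form derivative against a numerical one. A self-contained sketch using only the definition swish(x) = x * σ(x), so it runs without the adjoint above:

```julia
# swish(x) = x * σ(x); its derivative is swish(x) + σ(x) * (1 - swish(x)),
# which is exactly the expression used in the custom adjoint
σ(x) = 1 / (1 + exp(-x))
swish(x) = x * σ(x)
dswish(x) = swish(x) + σ(x) * (1 - swish(x))

x = 0.7
fd = (swish(x + 1e-6) - swish(x - 1e-6)) / 2e-6  # central finite difference
abs(dswish(x) - fd) < 1e-6                       # the two derivatives agree
```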


There is a DimensionMismatch error too, though it seems unrelated to this.

This adjoint is a bit orthogonal to Zygote really, but good that we can work around it.