Cannot compute gradient/apply back method to solve function

Hi,

I’m facing an issue that prevents me from training my model with sciml_train.
I have the following training data:


model = Normal(5., 0.1)
raw_data = rand(model, (1, 100))
x = rand(model, 100)
u0 = [x; 0] 
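# the trailing 0 accumulates the change in log-density along the trajectory (FFJORD-style augmented state)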

I have an ODE defined as:

function initial_value_problem(u::AbstractArray, p, t) #dynamics of z and logpz given in the paper 
    z = u[1:end-1]
    f = re(p)
    [f(z')'; -t_J(f, z)] 
end


prob = ODEProblem(initial_value_problem, u0, tspan)

Here f is a simple neural network with weights p:

nn = Chain(Dense(1, 1, swish), Dense(1, 1))
p, re = Flux.destructure(nn)
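
For context, Flux.destructure splits nn into a flat parameter vector p and a reconstructor re, so calling re(p) inside the ODE right-hand side rebuilds the same Chain from the current parameters. A minimal check (sketch):

f = re(p)              # rebuild the network from the flat parameter vector
f([0.5]) ≈ nn([0.5])   # same weights, so same output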

I also define a function predict_adjoint as follows:


function predict_adjoint(p) # we want to solve the IVP with the adjoint method
    _prob = remake(prob, u0 = u0, p = p)
    return solve(_prob, Tsit5(),
                 saveat = 0f0:0.1f0:10f0,
                 sensealg = DiffEqFlux.InterpolatingAdjoint(checkpointing = true)).u[end]
end

Then, when I try to compute the gradient of a loss:

function loss_adjoint(p)
    pz = Normal(0.0, 1.0)
    preds = predict_adjoint(p)
    z0 = preds[1:end-1]
    delta_logp = preds[end]

    logpz = DistributionsAD.logpdf.(pz, z0)
    logpx = logpz .- delta_logp
    loss = -mean(logpx)
    loss
end

I get the following error:

ERROR: MethodError: no method matching flatten(::Array{Float64,2})

Investigating this error, I found that it seems to come from the predict_adjoint step inside loss_adjoint.

I don’t understand why, because loss_adjoint returns a scalar and every operation invoked inside it should be differentiable by Zygote, so I should be able to compute the gradient.
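
One way to narrow this down would be to differentiate predict_adjoint on its own (a minimal sketch, assuming the definitions above):

Zygote.gradient(q -> sum(predict_adjoint(q)), p)   # does the solve itself differentiate?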

Do you know where I’m going wrong?

Thank you

t_J isn’t defined here. Could you share the whole script so I can copy/paste it, run it, and see the issue?

Is this an FFJORD implementation that uses Zygote inside?

Hey Chris, thank you for your help.

Sorry, I should have included it. Here is the full script:


using Flux, DiffEqFlux, DifferentialEquations
using Distributions, DistributionsAD, Zygote, Statistics

model = Normal(5., 0.1)

raw_data = rand(model, (1, 100))

# we want to learn a mapping from the base distribution to the data distribution,
# i.e. map generated points onto the data

# define the base distribution
pz = Normal(0.0, 1.0)

nn = Chain(Dense(1, 1, swish), Dense(1, 1))

p_init, re = Flux.destructure(nn)

x = rand(model, 100)

u0 = [x; 0] 

# Simulation interval and intermediary points
tspan = (0.0, 10.0)
tsteps = 0.0:0.1:10.0

function t_J(f, z::AbstractArray) # takes a function and a vector (size (N,)); returns the sum of the Jacobian diagonal
    N = size(z)[1]
    diag_Jacobian_nn = zeros(N)
    for t in 1:N
        g(x) = f(x')[t]
        y, back = Zygote.pullback(g, z) # we want the gradient of this function
        grad = back(1)
        diag_Jacobian_nn[t] = grad[1][t]
    end
    sum(diag_Jacobian_nn)
end

function initial_value_problem(du::AbstractArray, u::AbstractArray, p, t) # dynamics of z and log p(z) as given in the paper
    z = u[1:end-1]
    f = re(p)                # rebuild the network from the flat parameter vector p
    du[1:end-1] = f(z')
    du[end] = -t_J(f, z)
    return du
end

# try with the out-of-place form ivp(u, p, t)

function initial_value_problem(u::AbstractArray, p, t) # dynamics of z and log p(z) as given in the paper
    z = u[1:end-1]
    f = re(p)
    [f(z')'; -t_J(f, z)]
end

prob = ODEProblem(initial_value_problem, u0, tspan)

function predict_adjoint(p) # solve the IVP with the adjoint method
    _prob = remake(prob, u0 = u0, p = p)
    return solve(_prob, Tsit5(),
                 saveat = 0f0:0.1f0:10f0,
                 sensealg = DiffEqFlux.InterpolatingAdjoint(checkpointing = true)).u[end]
end

function loss_adjoint(p)
    pz = Normal(0.0, 1.0)
    preds = predict_adjoint(p)
    z0 = preds[1:end-1]
    delta_logp = preds[end]

    logpz = DistributionsAD.logpdf.(pz, z0)
    logpx = logpz .- delta_logp
    loss = -mean(logpx)
    loss
end

function cb2(l)
    @show l
    false
end

opt = ADAM(0.1)

res = DiffEqFlux.sciml_train(loss_adjoint, p_init, opt, maxiters = 100, cb = cb2)

I’m concerned about Zygote because I understand that sciml_train calls Zygote.pullback to compute the gradient. Running the pullback on loss_adjoint works, but applying the returned back function does not:

y, back = Zygote.pullback(loss_adjoint, p)
back(y)  #error
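
For a scalar loss the pullback seed is normally 1; a minimal sketch of the gradient call that sciml_train effectively makes:

l, back = Zygote.pullback(loss_adjoint, p_init)
grad = back(one(l))[1]   # seed the scalar loss with 1 to get dloss/dp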

So the issue seems to be ReverseDiff getting in the way of the backwards pass of broadcasting.

It is possible to get around it by defining a more specific broadcast adjoint for Zygote to pick up, specifically for swish. Something like:

Zygote.@adjoint function Base.broadcasted(::typeof(swish), x)
  y = swish.(x)
  y, Δ -> begin
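    # the first slot is the gradient w.r.t. the broadcasted function itself, hence nothing;
    # the x slot uses the closed-form derivative swish'(x) = swish(x) + σ(x) * (1 - swish(x))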
    (nothing, Δ .* (y .+ (σ.(x) .* (one.(x) .- y))), )
  end
end

There is a DimensionMismatch error too, though it seems unrelated to this.

This adjoint is a bit orthogonal to Zygote really, but good that we can work around it.