Unable to compute gradient of loss of my model with sciml_train

Hello,

I’m facing an error for which I can’t find any solution or prior report.

I’m trying to train a model defined by an ODE, for which I solve the adjoint problem using concrete_solve.

function initial_value_problem(du::AbstractArray, u::AbstractArray, p, t)
    z = u[1:end-1]'           # state; the last entry of u accumulates the log-density change
    f = re(p)                 # rebuild the Flux model from the flat parameter vector p
    du[1:end-1] = f(z)
    du[end] = -t_J(f, z')     # negative trace of the Jacobian of f at z
end

tspan = (0f0, 10f0)                   # matches the saveat range used below
x = Float32.(rand(model, (1, 100)))   # samples drawn from the data distribution `model`
u0 = [Flux.flatten(x) 0f0]            # augment the state with a zero log-density entry
prob = ODEProblem(initial_value_problem, u0, tspan)

function predict_adjoint(x, p) #we want to solve ivp with adjoint method 
    concrete_solve(prob,Tsit5(),u0 = x, p,  #[u0,0f0]
                   saveat=0f0:0.1f0:10f0,sensealg=DiffEqFlux.InterpolatingAdjoint(
                   checkpointing=true))
end


function loss_adjoint(xs::AbstractArray, p)
    xs = [xs]
    pz = Normal(0.0, 1.0)
    @showgrad preds = predict_adjoint(xs, p)[:, :, end]
    z = preds[1:end-1]          # transformed sample
    delta_logp = preds[end]     # accumulated change in log-density

    logpz = DistributionsAD.logpdf(pz, z)
    logpx = logpz .- delta_logp
    loss = -mean(logpx)
end

My loss function is a log-likelihood: the continuous-normalizing-flow identity logpx = logpz .- delta_logp, where delta_logp is the integrated Jacobian trace carried in the last state component.
I’m trying to use sciml_train to train on some generated data, but I always get the same error:

ERROR: MethodError: no method matching similar(::DiffEqBase.NullParameters)

Indeed, when I try to use Zygote.gradient manually on my loss function, it doesn’t work either.
Do you have any explanation? Is the solve function compatible with Zygote.gradient?
By the way, t_J is a function I defined earlier that computes the trace of the Jacobian.
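For completeness, here is a minimal sketch of such a trace-of-the-Jacobian helper (an illustration only, not the exact definition from my code):

using Zygote, LinearAlgebra

# Trace of the Jacobian of f at z via Zygote's full Jacobian.
# Fine for small state dimensions; larger CNFs usually use a stochastic estimator.
t_J(f, z) = tr(Zygote.jacobian(f, z)[1])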

I’m surprised that didn’t just error earlier. The issue is that you weren’t really passing p: in that position it needed to be a keyword argument. concrete_solve is a deprecated function too, so the suggested updated syntax is:

function predict_adjoint(x, p) # solve the IVP with the adjoint method
    _prob = remake(prob, u0 = x, p = p)
    solve(_prob, Tsit5(),
          saveat = 0f0:0.1f0:10f0,
          sensealg = DiffEqFlux.InterpolatingAdjoint(checkpointing = true))
end
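With that, sciml_train can differentiate through the solve. A minimal sketch of the training call, where p0 stands in for whatever initial parameter vector you use (the names here are placeholders):

result = DiffEqFlux.sciml_train(p -> loss_adjoint(xs, p), p0, ADAM(0.01), maxiters = 100)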

Continuing the discussion from Unable to compute gradient of loss of my model with sciml_train:

Hi, I have run into a similar issue. I defined an ODE problem with an external input signal as a parameter. Here is my code. The weird part is that if the actual loss l is returned from my loss_batch function, I get the error, but if I test it by returning a constant value l = 0.1, the training succeeds.
In that case my thinking is that the dudt and predict_batch functions should be fine, but something in the loss makes the gradient unsupported. Any ideas about this? Thanks in advance.

nn_model = FastChain(FastDense(89, 64, tanh), FastDense(64, 1))
pa = initial_params(nn_model)
u01 = Float32.([0.0])

function dudt(u, p, t)
    # look up the external input row for the current time step
    feature = p[1][Int(round(t*25 + 1)), :]
    nn_model(vcat(u, feature), p[2])
end

function predict_batch(fullp, x_input)
    prob_gpu2 = ODEProblem(dudt, u01, tspan, (x_input', fullp))
    Array(concrete_solve(prob_gpu2, Tsit5(), saveat = tsteps))
end

function loss_batch(fullp, x_input, y_output)
    pred = predict_batch(fullp, x_input)
    N = length(pred)
    l = sum(abs2, y_output[1:N] .- pred')
    #l = 0.1
    return l
end

res1 = DiffEqFlux.sciml_train(loss_batch, pa, ADAM(0.05), train_loader, maxiters = 10)

train_loader isn’t defined in your example.

Hi thanks for replying.

train_loader is defined via:

train_loader = DataLoader(xtrain, ytrain, batchsize=10, shuffle=true)

where xtrain holds the training features (feature dimension 88) and ytrain holds the labels. In this toy example, xtrain is an 88×300 matrix and ytrain is a 1×300 array.
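To make the example self-contained, dummy data with those shapes could be built like this (random placeholder values):

using Flux.Data: DataLoader

xtrain = rand(Float32, 88, 300)   # placeholder features
ytrain = rand(Float32, 1, 300)    # placeholder labels
train_loader = DataLoader(xtrain, ytrain, batchsize = 10, shuffle = true)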

During debugging I checked that the ODE solver works fine: I can print the solution of concrete_solve every iteration, so the forward pass should be alright.

However, the error message seems to show a type mismatch in the function where I define the ODEProblem with the external input signal: prob_gpu2 = ODEProblem(dudt, u01, tspan, (x_input', fullp)). That (x_input', fullp) is the only place I have used a Tuple.
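The failure also reproduces outside of sciml_train by differentiating the loss on a single batch (the slicing below just mimics one DataLoader batch):

using Zygote

x_batch = xtrain[:, 1:10]   # one batch of features, 88×10
y_batch = ytrain[:, 1:10]   # matching labels, 1×10
Zygote.gradient(p -> loss_batch(p, x_batch, y_batch), pa)   # errors in the reverse pass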

Oh, adjoints cannot have tuple parameters. The adjoint sensitivity analysis needs p to be an array it can differentiate with respect to.
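One way around it is to close over the external input so that only the flat weight vector travels through p. A sketch of that restructuring (untested against the original data):

# Close over x_input instead of passing it inside a tuple parameter,
# so p is just the flat weight vector the adjoint can handle.
function make_dudt(x_input)
    xt = x_input'
    function dudt(u, p, t)
        feature = xt[Int(round(t*25 + 1)), :]
        nn_model(vcat(u, feature), p)
    end
end

function predict_batch(fullp, x_input)
    prob = ODEProblem(make_dudt(x_input), u01, tspan, fullp)
    Array(concrete_solve(prob, Tsit5(), saveat = tsteps))
end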