Ambiguous method bug in ReverseDiff

This is my code:

function p2vec(p)
    w_b = p[1:nr] .+ b0;
    w_out = reshape(p[nr + 1:end], ns, nr);
    # w_out = clamp.(w_out, -2.5, 2.5);
    w_in = clamp.(-w_out, 0, 2.5);
    return w_in, w_b, w_out
end

function crnn!(du, u, p, t)
    w_in, w_b, w_out = p2vec(p);
    w_in_x = w_in' * @. log(clamp(u, lb, ub));
    du .= w_out * @. exp(w_in_x + w_b);
end

p = randn(Float32, nr * (ns + 1)) .* 1.f-1;

prob = ODEProblem(crnn!, u0, tspan, saveat=tsteps,
                  atol=atol, rtol=rtol, sensealg=ReverseDiffAdjoint())

function predict_neuralode(u0, p)
    pred = clamp.(Array(solve(prob, alg, u0=u0, p=p; 
                  maxiters=maxiters)), -ub, ub)
    return pred
end

I can get the gradient from ForwardDiff.gradient like this:

function loss_neuralode(p, input, label)
    pred = predict_neuralode(input, p)
    loss = mae(label ./ y_std, pred ./ y_std)
    return loss
end

i_exp = 10
loss_neuralode(p, u0_list[i_exp,:], ode_data_list[i_exp,:,:])
grad = ForwardDiff.gradient(x -> loss_neuralode(x, u0_list[i_exp,:], ode_data_list[i_exp,:,:]), p)

However, when I use ReverseDiff.gradient,

using ReverseDiff
grad = ReverseDiff.gradient(x -> loss_neuralode(x, u0_list[i_exp,:], ode_data_list[i_exp,:,:]), p)

it throws the following error:

MethodError: *(::Adjoint{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}, ::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}) is ambiguous. Candidates:
  *(x::Adjoint{<:Number, <:ReverseDiff.TrackedArray{V, D, 2}}, y::AbstractVector) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:223
  *(x::Adjoint{<:Number, <:ReverseDiff.TrackedArray{V, D}}, y::AbstractVector) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:223
  *(x::Adjoint{<:Number, <:ReverseDiff.TrackedArray{V, D, 2}}, y::AbstractArray) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:223
  *(x::Adjoint{<:Number, <:ReverseDiff.TrackedArray{V, D}}, y::AbstractArray) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:223
  *(x::Adjoint{<:Number, <:AbstractMatrix}, y::ReverseDiff.TrackedArray{V, D, 1}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:218
  *(x::Adjoint{<:Number, <:AbstractMatrix}, y::ReverseDiff.TrackedArray{V, D}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:218
  *(adjA::Adjoint{<:Any, <:AbstractMatrix{T}}, x::AbstractVector{S}) where {T, S} in LinearAlgebra at /dssg/home/acct-esehazenet/hazenet-pg6/software/julia/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:103
  *(x::Adjoint{<:Number, <:AbstractArray}, y::ReverseDiff.TrackedArray{V, D, 1}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:218
  *(x::AbstractMatrix, y::ReverseDiff.TrackedArray{V, D, 1}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:214
  *(x::AbstractArray, y::ReverseDiff.TrackedArray{V, D, 1}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:214
  *(x::Adjoint{<:Number, <:AbstractArray}, y::ReverseDiff.TrackedArray{V, D}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:218
  *(x::AbstractMatrix, y::ReverseDiff.TrackedArray{V, D}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:214
  *(A::AbstractMatrix{T}, x::AbstractVector{S}) where {T, S} in LinearAlgebra at /dssg/home/acct-esehazenet/hazenet-pg6/software/julia/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:54
  *(x::AbstractArray, y::ReverseDiff.TrackedArray{V, D}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:214
Possible fix, define
  *(::Adjoint{<:Number, <:ReverseDiff.TrackedArray{V, D, 2, VA, DA}}, ::ReverseDiff.TrackedArray{V, D, 1, VA, DA}) where {V, D, V, D, VA, DA}

Stacktrace:
  [1] crnn!(du::Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}, u::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, p::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, t::ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}})
    @ Main ~/code/deepAdjoint/case2.ipynb:79
  [2] ODEFunction
    @ ~/.julia/packages/DiffEqBase/V7P18/src/diffeqfunction.jl:248 [inlined]
  [3] (::DiffEqSensitivity.var"#77#86"{ODEFunction{true, typeof(crnn!), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}})(u::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, p::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, t::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}})
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/ZdaQE/src/local_sensitivity/adjoint_common.jl:127
...
    @ ReverseDiff ~/.julia/packages/ReverseDiff/E4Tzn/src/api/gradients.jl:24
 [20] gradient(f::Function, input::Vector{Float32})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/E4Tzn/src/api/gradients.jl:22
 [21] top-level scope
    @ ~/code/deepAdjoint/case2.ipynb:2

Here is my complete test code:

using OrdinaryDiffEq, Flux, Optim, Random, Plots
using Zygote
using ForwardDiff
using LinearAlgebra, Statistics
using ProgressBars, Printf
using Flux.Optimise: update!, ExpDecay

Random.seed!(1234);

###################################
# Arguments
p_cutoff = 0.0;
n_epoch = 100;
n_plot = 100;
opt = ADAMW(0.001, (0.9, 0.999), 1.f-8);
datasize = 100;
tstep = 0.4;
n_exp_train = 20;
n_exp_test = 10;
n_exp = n_exp_train + n_exp_test;
noise = 5.f-2;
ns = 5;
nr = 4;
k = Float32[0.1, 0.2, 0.13, 0.3];
alg = Tsit5();
atol = 1e-5;
rtol = 1e-2;

maxiters = 10000;

lb = 1.f-5;
ub = 1.f1;
####################################

function trueODEfunc(dydt, y, k, t)
    dydt[1] = -2 * k[1] * y[1]^2 - k[2] * y[1];
    dydt[2] = k[1] * y[1]^2 - k[4] * y[2] * y[4];
    dydt[3] = k[2] * y[1] - k[3] * y[3];
    dydt[4] = k[3] * y[3] - k[4] * y[2] * y[4];
    dydt[5] = k[4] * y[2] * y[4];
end

# Generate data sets
u0_list = rand(Float32, (n_exp, ns));
u0_list[:, 1:2] .+= 2.f-1;
u0_list[:, 3:end] .= 0.f0;
tspan = Float32[0.0, datasize * tstep];
tsteps = range(tspan[1], tspan[2], length=datasize);
ode_data_list = zeros(Float32, (n_exp, ns, datasize));
std_list = [];

function max_min(ode_data)
    return maximum(ode_data, dims=2) .- minimum(ode_data, dims=2) .+ lb
end


for i in 1:n_exp
    u0 = u0_list[i, :];
    prob_trueode = ODEProblem(trueODEfunc, u0, tspan, k);
    ode_data = Array(solve(prob_trueode, alg, saveat=tsteps));
    ode_data += randn(size(ode_data)) .* ode_data .* noise
    ode_data_list[i, :, :] = ode_data
    push!(std_list, max_min(ode_data));
end
y_std = maximum(hcat(std_list...), dims=2);

b0 = -10.0

function p2vec(p)
    w_b = p[1:nr] .+ b0;
    w_out = reshape(p[nr + 1:end], ns, nr);
    # w_out = clamp.(w_out, -2.5, 2.5);
    w_in = clamp.(-w_out, 0, 2.5);
    return w_in, w_b, w_out
end

function crnn!(du, u, p, t)
    w_in, w_b, w_out = p2vec(p);
    w_in_x = w_in' * @. log(clamp(u, lb, ub));
    du .= w_out * @. exp(w_in_x + w_b);
end

u0 = u0_list[1, :]
p = randn(Float32, nr * (ns + 1)) .* 1.f-1;
# p[1:nr] .+= b0;

prob = ODEProblem(crnn!, u0, tspan, saveat=tsteps,
                  atol=atol, rtol=rtol, sensealg=ReverseDiffAdjoint())

function predict_neuralode(u0, p)
    pred = clamp.(Array(solve(prob, alg, u0=u0, p=p; 
                  maxiters=maxiters)), -ub, ub)
    return pred
end

function loss_neuralode(p, input, label)
    pred = predict_neuralode(input, p)
    loss = mae(label ./ y_std, pred ./ y_std)
    return loss
end

i_exp = 10
grad = ForwardDiff.gradient(x -> loss_neuralode(x, u0_list[i_exp,:], ode_data_list[i_exp,:,:]), p)

using ReverseDiff
grad = ReverseDiff.gradient(x -> loss_neuralode(x, u0_list[i_exp,:], ode_data_list[i_exp,:,:]), p)

I’d recommend using Zygote for something like this and then only doing ReverseDiff on the scalar part implicitly through the adjoint.
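
A rough sketch of that setup, reusing the names from the code above (the choice of InterpolatingAdjoint here is just one reasonable option, not something the thread fixes): call Zygote.gradient on the loss and pick a sensealg whose vector-Jacobian products go through ReverseDiff.

using Zygote, DiffEqSensitivity  # SciMLSensitivity in newer releases

# adjoint sensitivity analysis; ReverseDiff is only used internally for the
# vector-Jacobian products of crnn! (ReverseDiffVJP(true) would compile the tape)
sense = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())

function predict_neuralode(u0, p)
    sol = solve(prob, alg; u0=u0, p=p, saveat=tsteps,
                abstol=atol, reltol=rtol, sensealg=sense, maxiters=maxiters)
    return clamp.(Array(sol), -ub, ub)
end

grad = Zygote.gradient(x -> loss_neuralode(x, u0_list[i_exp, :],
                                           ode_data_list[i_exp, :, :]), p)[1]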

That’s gotta be throwing tons of warnings about not being correct keyword arguments. Fix the warnings first.
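
Concretely (a sketch of that cleanup, not code from the thread): DiffEq spells the tolerances abstol/reltol, so atol/rtol trigger the unrecognized-keyword warnings, and it is more conventional to pass saveat, the tolerances, and sensealg to solve rather than to ODEProblem:

prob = ODEProblem(crnn!, u0, tspan, p)

sol = solve(prob, alg; saveat=tsteps, abstol=atol, reltol=rtol,
            sensealg=ReverseDiffAdjoint(), maxiters=maxiters)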


Sorry, I should have replied in this thread.

I've changed the code to:

grad = Zygote.gradient(x -> loss_neuralode(x, input, label), p)

However, it reports an error:

type DataType has no field mutable

Stacktrace:
  [1] getproperty
    @ ./Base.jl:33 [inlined]
  [2] adjoint
    @ ~/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:279 [inlined]
  [3] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
  [4] _pullback
    @ ./boot.jl:607 [inlined]
  [5] _pullback(ctx::Zygote.Context, f::Type{NamedTuple{(:u0, :p, :maxiters), Tuple{Vector{Float32}, Vector{Float32}, Int64}}}, args::Tuple{Vector{Float32}, Vector{Float32}, Int64})
    @ Zygote ~/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
...
    @ Zygote ~/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:53
 [17] top-level scope
    @ ~/code/deepAdjoint/ReverseDiff.ipynb:2

I also tried another, simpler case. ForwardDiff.gradient works, but ReverseDiff.gradient and Zygote.gradient both report an error.

Here's the code:

using DifferentialEquations, ReverseDiff
using Zygote, ForwardDiff
function simpleODE!(du, u, p, t)
    du[1] = -p[1]*u[1]
    du[2] = p[1]*u[1] - p[2]*u[2]
    du[3] = p[2]*u[2]
end

u0 = [1.0, 0.0, 0.0]
p = [1.5, 1.0]
tspan = (0.0, 5.0)

prob = ODEProblem(simpleODE!, u0, tspan, p)

sol = solve(prob)

function objective(p)
    prob = ODEProblem(simpleODE!, u0, tspan, p)
    sol = solve(prob)
    return sol[3, end]
end

p0 = [1.5, 1.0] 
# Both report an error
g = ReverseDiff.gradient(objective, p0)
g = Zygote.gradient(objective, p0)
# ForwardDiff.gradient works
g = ForwardDiff.gradient(objective, p0)

So I wonder whether ReverseDiff can be used to get the gradient of an ODE solve at all. It seems that Zygote.gradient uses ReverseDiff by default. On the other hand, when I use Zygote.forwarddiff, I can get the gradient, as shown below.
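
For reference, the Zygote.forwarddiff call that works looks roughly like this (a sketch from memory; forwarddiff just tells Zygote to differentiate objective with ForwardDiff internally):

g = Zygote.gradient(p -> Zygote.forwarddiff(objective, p), p0)[1]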

As the error message says:


ERROR: Compatibility with reverse-mode automatic differentiation requires SciMLSensitivity.jl.
Please install SciMLSensitivity.jl and do `using SciMLSensitivity`/`import SciMLSensitivity`
for this functionality. For more details, see https://sensitivity.sciml.ai/dev/.
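
In other words, loading that package before differentiating should be enough for the simple example. A minimal sketch of that fix (Tsit5 assumed as the solver):

using DifferentialEquations, SciMLSensitivity
using Zygote

function simpleODE!(du, u, p, t)
    du[1] = -p[1]*u[1]
    du[2] = p[1]*u[1] - p[2]*u[2]
    du[3] = p[2]*u[2]
end

u0 = [1.0, 0.0, 0.0]
tspan = (0.0, 5.0)

function objective(p)
    prob = ODEProblem(simpleODE!, u0, tspan, p)
    sol = solve(prob, Tsit5())
    return sol[3, end]
end

# with SciMLSensitivity loaded, the reverse-mode rules for solve are available
g = Zygote.gradient(objective, [1.5, 1.0])[1]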