This is my code
function p2vec(p)
w_b = p[1:nr] .+ b0;
w_out = reshape(p[nr + 1:end], ns, nr);
# w_out = clamp.(w_out, -2.5, 2.5);
w_in = clamp.(-w_out, 0, 2.5);
return w_in, w_b, w_out
end
function crnn!(du, u, p, t)
w_in, w_b, w_out = p2vec(p);
w_in_x = w_in' * @. log(clamp(u, lb, ub));
du .= w_out * @. exp(w_in_x + w_b);
end
p = randn(Float32, nr * (ns + 1)) .* 1.f-1;
prob = ODEProblem(crnn!, u0, tspan, saveat=tsteps,
atol=atol, rtol=rtol,sensealg=ReverseDiffAdjoint())
function predict_neuralode(u0, p)
pred = clamp.(Array(solve(prob, alg, u0=u0, p=p;
maxiters=maxiters)), -ub, ub)
return pred
end
I can get gradient from ForwardDiff.gradien like that:
function loss_neuralode(p, input, label)
pred = predict_neuralode(input, p)
loss = mae(label ./ y_std, pred ./ y_std)
return loss
end
i_exp = 10
loss_neuralode(p, u0_list[i_exp,:], ode_data_list[i_exp,:,:])
grad = ForwardDiff.gradient(x -> loss_neuralode(x, u0_list[i_exp,:], ode_data_list[i_exp,:,:]), p)
However, when I use ReverseDiff.gradient,
using ReverseDiff
grad = ReverseDiff.gradient(x -> loss_neuralode(x, u0_list[i_exp,:], ode_data_list[i_exp,:,:]), p)
It shows there is an error
Output exceeds the size limit. Open the full output data in a text editor
MethodError: *(::Adjoint{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}, ::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}) is ambiguous. Candidates:
*(x::Adjoint{<:Number, <:ReverseDiff.TrackedArray{V, D, 2}}, y::AbstractVector) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:223
*(x::Adjoint{<:Number, <:ReverseDiff.TrackedArray{V, D}}, y::AbstractVector) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:223
*(x::Adjoint{<:Number, <:ReverseDiff.TrackedArray{V, D, 2}}, y::AbstractArray) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:223
*(x::Adjoint{<:Number, <:ReverseDiff.TrackedArray{V, D}}, y::AbstractArray) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:223
*(x::Adjoint{<:Number, <:AbstractMatrix}, y::ReverseDiff.TrackedArray{V, D, 1}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:218
*(x::Adjoint{<:Number, <:AbstractMatrix}, y::ReverseDiff.TrackedArray{V, D}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:218
*(adjA::Adjoint{<:Any, <:AbstractMatrix{T}}, x::AbstractVector{S}) where {T, S} in LinearAlgebra at /dssg/home/acct-esehazenet/hazenet-pg6/software/julia/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:103
*(x::Adjoint{<:Number, <:AbstractArray}, y::ReverseDiff.TrackedArray{V, D, 1}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:218
*(x::AbstractMatrix, y::ReverseDiff.TrackedArray{V, D, 1}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:214
*(x::AbstractArray, y::ReverseDiff.TrackedArray{V, D, 1}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:214
*(x::Adjoint{<:Number, <:AbstractArray}, y::ReverseDiff.TrackedArray{V, D}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:218
*(x::AbstractMatrix, y::ReverseDiff.TrackedArray{V, D}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:214
*(A::AbstractMatrix{T}, x::AbstractVector{S}) where {T, S} in LinearAlgebra at /dssg/home/acct-esehazenet/hazenet-pg6/software/julia/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:54
*(x::AbstractArray, y::ReverseDiff.TrackedArray{V, D}) where {V, D} in ReverseDiff at /dssg/home/acct-esehazenet/hazenet-pg6/.julia/packages/ReverseDiff/E4Tzn/src/derivatives/linalg/arithmetic.jl:214
Possible fix, define
*(::Adjoint{<:Number, <:ReverseDiff.TrackedArray{V, D, 2, VA, DA}}, ::ReverseDiff.TrackedArray{V, D, 1, VA, DA}) where {V, D, V, D, VA, DA}
Stacktrace:
[1] crnn!(du::Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}, u::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, p::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, t::ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}})
@ Main ~/code/deepAdjoint/case2.ipynb:79
[2] ODEFunction
@ ~/.julia/packages/DiffEqBase/V7P18/src/diffeqfunction.jl:248 [inlined]
[3] (::DiffEqSensitivity.var"#77#86"{ODEFunction{true, typeof(crnn!), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}})(u::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, p::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, t::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}})
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/ZdaQE/src/local_sensitivity/adjoint_common.jl:127
...
@ ReverseDiff ~/.julia/packages/ReverseDiff/E4Tzn/src/api/gradients.jl:24
[20] gradient(f::Function, input::Vector{Float32})
@ ReverseDiff ~/.julia/packages/ReverseDiff/E4Tzn/src/api/gradients.jl:22
[21] top-level scope
@ ~/code/deepAdjoint/case2.ipynb:2