Hello, I am fairly new to Julia and I am having trouble implementing a custom neural network. I am trying to take the derivative of the loss function with respect to the parameters of the model, but `Flux.gradient` throws an error, even though I can call the loss function by itself without any problems. I would appreciate any help I can get.
using Flux: gradient
using Flux.Optimise: update!, Descent
using Plots
using DifferentialEquations: Tsit5, ODEProblem, solve
using DiffEqSensitivity
include("funs.jl");
"""
    loss(x, ps)

Return the mean, over the samples in `x`, of the summed squared error between the
model output `∂Φ∂τᵢ(xᵢ, ps)` and the reference `∂Φ∂τᵢ_gov(xᵢ)`.

# Arguments
- `x`: a collection of input samples; each element is passed as-is to both
  `∂Φ∂τᵢ_gov` and `∂Φ∂τᵢ`.
- `ps`: model parameters, forwarded unchanged to `∂Φ∂τᵢ`.

# Throws
- `ArgumentError` if `x` is empty. Averaging over zero samples would otherwise
  yield `NaN` (`0/0`) and silently poison training.
"""
function loss(x, ps)
    isempty(x) && throw(ArgumentError("loss: x must contain at least one sample"))
    # The accumulator is deliberately not named `loss`: a local of that name would
    # shadow the function being defined.
    total = 0.0
    for xi in x
        y = ∂Φ∂τᵢ_gov(xi)
        ŷ = ∂Φ∂τᵢ(xi, ps)
        total += sum((ŷ .- y) .^ 2)
    end
    return total / length(x)
end;
# Initialize the parameters of the model.
# Five networks are needed to evaluate ∂Φ∂τᵢ for the loss function; each one has a
# 1x3x3x1 architecture. A comprehension is used instead of `push!` into an untyped
# `[]`, so `ps` gets a concrete element type rather than being a `Vector{Any}`.
ps = [initialize([1, 3, 3, 1]) for _ in 1:5]
# A single sample: a vector holding one 3-component input, every entry -10000.0.
τ = [fill(-10000.0, 3)]

# Forward evaluation of the loss on that sample.
loss(τ, ps)
#Output: 8.747999941984282e27

# Gradient of the loss with respect to the network parameters.
# NOTE(review): the stack trace shows the failure is raised inside the ODE solve
# called from `node` (funs.jl). That ODEProblem has a scalar Float64 state and
# `NullParameters`, and its right-hand side is a closure capturing
# `Tuple{Vector{Any}, Vector{Any}}` (presumably the network weights).
# `QuadratureAdjoint` then fails on `similar(::Float64, ::Int64)`.
# A vector-valued state and passing the weights via `p` may resolve this — TODO
# confirm in funs.jl.
gs = gradient(p -> loss(τ, p), ps)
MethodError: no method matching similar(::Float64, ::Int64)
Closest candidates are:
similar(::Test.GenericArray, ::Integer...) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/Test/src/Test.jl:1831
similar(::ReverseDiff.TrackedArray, ::Union{Integer, AbstractUnitRange}...) at ~/.julia/packages/ReverseDiff/Y5qec/src/tracked.jl:387
similar(::BitArray, ::Int64...) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/bitarray.jl:369
...
Stacktrace:
[1] ODEAdjointProblem(sol::SciMLBase.ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, var"#f#1"{Tuple{Vector{Any}, Vector{Any}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, var"#f#1"{Tuple{Vector{Any}, Vector{Any}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, sensealg::QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, g::Function, t::Vector{Float64}, dg::Nothing, callback::Nothing)
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/quadrature_adjoint.jl:66
[2] _adjoint_sensitivities(sol::SciMLBase.ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, var"#f#1"{Tuple{Vector{Any}, Vector{Any}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, var"#f#1"{Tuple{Vector{Any}, Vector{Any}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, sensealg::QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, g::DiffEqSensitivity.var"#df#251"{Vector{Any}, Colon}, t::Vector{Float64}, dg::Nothing; abstol::Float64, reltol::Float64, callback::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/quadrature_adjoint.jl:252
[3] adjoint_sensitivities(::SciMLBase.ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, var"#f#1"{Tuple{Vector{Any}, Vector{Any}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, var"#f#1"{Tuple{Vector{Any}, Vector{Any}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Vararg{Any}; sensealg::QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, kwargs::Base.Pairs{Symbol, Union{Nothing, Float64}, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:callback, :reltol, :abstol), Tuple{Nothing, Float64, Float64}}})
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/sensitivity_interface.jl:6
[4] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#250"{Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, Float64, SciMLBase.NullParameters, Tuple{}, Colon, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}})(Δ::Vector{Any})
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/concrete_solve.jl:249
[5] ZBack
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:204 [inlined]
[6] (::Zygote.var"#kw_zpullback#41"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#250"{Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, Float64, SciMLBase.NullParameters, Tuple{}, Colon, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}})(dy::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:230
[7] #212
@ ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:203 [inlined]
[8] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.var"#kw_zpullback#41"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#250"{Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, Float64, SciMLBase.NullParameters, Tuple{}, Colon, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}}}})(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[9] Pullback
@ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:165 [inlined]
[10] (::typeof(∂(#solve#40)))(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[11] (::Zygote.var"#212#213"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#40))})(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:203
[12] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#40))}})(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[13] Pullback
@ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:159 [inlined]
[14] (::typeof(∂(solve##kw)))(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[15] Pullback
@ ~/Library/CloudStorage/OneDrive-purdue.edu/Research/Others/Tinkering/funs.jl:27 [inlined]
[16] (::typeof(∂(node)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[17] Pullback
@ ~/Library/CloudStorage/OneDrive-purdue.edu/Research/Others/Tinkering/funs.jl:44 [inlined]
[18] (::typeof(∂(∂Φ∂τᵢ)))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[19] Pullback
@ ./In[2]:5 [inlined]
[20] (::typeof(∂(loss)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[21] Pullback
@ ./In[17]:1 [inlined]
[22] (::typeof(∂(#8)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[23] (::Zygote.var"#56#57"{typeof(∂(#8))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
[24] gradient(f::Function, args::Vector{Any})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76
[25] top-level scope
@ In[17]:1
[26] eval
@ ./boot.jl:373 [inlined]
[27] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1196