Flux.gradient throws MethodError even though loss is evaluated just fine

Hello, I am fairly new to Julia and I am having trouble implementing a custom neural network. I am trying to take the derivative of the loss function with respect to the parameters of the model, but Flux.gradient throws an error even though I can call the loss function by itself without any problems. I would appreciate any help I can get.

using Flux: gradient
using Flux.Optimise: update!, Descent
using Plots
using DifferentialEquations: Tsit5, ODEProblem, solve
using DiffEqSensitivity
include("funs.jl");
"""
    loss(x, ps)

Return the mean, over the samples in `x`, of the summed squared difference between the
model prediction `∂Φ∂τᵢ(sample, ps)` and the reference value `∂Φ∂τᵢ_gov(sample)`.

# Arguments
- `x`: collection of input samples; each element is passed unchanged to both
  `∂Φ∂τᵢ_gov` and `∂Φ∂τᵢ`.
- `ps`: model parameters, forwarded as-is to `∂Φ∂τᵢ`.

# Returns
The accumulated squared error divided by `length(x)`.
"""
function loss(x, ps)
    # Distinct accumulator name: a local called `loss` would shadow this function.
    total = 0.0
    # Iterate the samples themselves rather than assuming indices run over 1:length(x).
    for sample in x
        y = ∂Φ∂τᵢ_gov(sample)
        ŷ = ∂Φ∂τᵢ(sample, ps)
        total += sum((ŷ .- y) .^ 2)
    end
    return total / length(x)
end;

# Initialize the parameters of the model.
# 5 NNs are needed to evaluate ∂Φ∂τᵢ for the loss function; each NN has an
# architecture of 1x3x3x1. `Any[...]` keeps the container a Vector{Any}.
ps = Any[initialize([1, 3, 3, 1]) for _ in 1:5]

# One input sample, wrapped in an outer vector so `loss` can iterate over samples.
τ = [[-10000.0, -10000.0, -10000.0]]

# The forward evaluation of the loss works.
loss(τ, ps)
#Output: 8.747999941984282e27

# Differentiating the same loss with respect to the parameters.
gs = gradient(p -> loss(τ, p), ps)
MethodError: no method matching similar(::Float64, ::Int64)
Closest candidates are:
  similar(::Test.GenericArray, ::Integer...) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/Test/src/Test.jl:1831
  similar(::ReverseDiff.TrackedArray, ::Union{Integer, AbstractUnitRange}...) at ~/.julia/packages/ReverseDiff/Y5qec/src/tracked.jl:387
  similar(::BitArray, ::Int64...) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/bitarray.jl:369
  ...

Stacktrace:
  [1] ODEAdjointProblem(sol::SciMLBase.ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, var"#f#1"{Tuple{Vector{Any}, Vector{Any}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, var"#f#1"{Tuple{Vector{Any}, Vector{Any}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, sensealg::QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, g::Function, t::Vector{Float64}, dg::Nothing, callback::Nothing)
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/quadrature_adjoint.jl:66
  [2] _adjoint_sensitivities(sol::SciMLBase.ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, var"#f#1"{Tuple{Vector{Any}, Vector{Any}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, var"#f#1"{Tuple{Vector{Any}, Vector{Any}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, sensealg::QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, g::DiffEqSensitivity.var"#df#251"{Vector{Any}, Colon}, t::Vector{Float64}, dg::Nothing; abstol::Float64, reltol::Float64, callback::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/quadrature_adjoint.jl:252
  [3] adjoint_sensitivities(::SciMLBase.ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, SciMLBase.NullParameters, SciMLBase.ODEFunction{false, var"#f#1"{Tuple{Vector{Any}, Vector{Any}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, var"#f#1"{Tuple{Vector{Any}, Vector{Any}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Vararg{Any}; sensealg::QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, kwargs::Base.Pairs{Symbol, Union{Nothing, Float64}, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:callback, :reltol, :abstol), Tuple{Nothing, Float64, Float64}}})
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/sensitivity_interface.jl:6
  [4] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#250"{Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, Float64, SciMLBase.NullParameters, Tuple{}, Colon, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}})(Δ::Vector{Any})
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/T7LDZ/src/concrete_solve.jl:249
  [5] ZBack
    @ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:204 [inlined]
  [6] (::Zygote.var"#kw_zpullback#41"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#250"{Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, Float64, SciMLBase.NullParameters, Tuple{}, Colon, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}})(dy::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:230
  [7] #212
    @ ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:203 [inlined]
  [8] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.var"#kw_zpullback#41"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#250"{Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, QuadratureAdjoint{0, true, Val{:central}, ZygoteVJP}, Float64, SciMLBase.NullParameters, Tuple{}, Colon, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}}}}})(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [9] Pullback
    @ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:165 [inlined]
 [10] (::typeof(∂(#solve#40)))(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#212#213"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#40))})(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:203
 [12] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#40))}})(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [13] Pullback
    @ ~/.julia/packages/DiffEqBase/ziNGu/src/solve.jl:159 [inlined]
 [14] (::typeof(∂(solve##kw)))(Δ::RecursiveArrayTools.VectorOfArray{Any, 1, Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/Library/CloudStorage/OneDrive-purdue.edu/Research/Others/Tinkering/funs.jl:27 [inlined]
 [16] (::typeof(∂(node)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/Library/CloudStorage/OneDrive-purdue.edu/Research/Others/Tinkering/funs.jl:44 [inlined]
 [18] (::typeof(∂(∂Φ∂τᵢ)))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [19] Pullback
    @ ./In[2]:5 [inlined]
 [20] (::typeof(∂(loss)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [21] Pullback
    @ ./In[17]:1 [inlined]
 [22] (::typeof(∂(#8)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [23] (::Zygote.var"#56#57"{typeof(∂(#8))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
 [24] gradient(f::Function, args::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76
 [25] top-level scope
    @ In[17]:1
 [26] eval
    @ ./boot.jl:373 [inlined]
 [27] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1196

I did not include the contents of funs.jl to reduce clutter, but please let me know if that is needed for identifying the issue. It contains the functions ∂Φ∂τᵢ_gov, ∂Φ∂τᵢ and initialize.

When trying to reproduce I get

ERROR: UndefVarError: initialize not defined

This is the same question, with the same answer, as "Using Flux: gradient on DifferentialEquations: solve results in an error" and https://www.reddit.com/r/Julia/comments/u4qgxc/flux_gradient_cant_differentiate_a_function/.