From sciml_train to Optimization.jl

Hello,

I am trying to write the following line:

optimized_sol_nn = DiffEqFlux.sciml_train(p -> cost_adjoint_nn(p, 0.08), θ_nn,RADAM(0.003), maxiters = 1000)
with the Optimization.jl package, since sciml_train is deprecated.
I followed the documentation and was able to write the following lines:


adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction(p -> cost_adjoint_nn(p, 0.08), adtype)
#optprob = Optimization.OptimizationProblem(optf, pinit)

optimized_sol_nn = Optimization.solve(optf,
                                      θ_nn,
                                      RADAM(0.003),
                                      callback = callback,
                                      maxiters = 300
)

However, I am having the following error:

1 method for anonymous function “#23”:

[1] (::var"#23#24")(p) in Main at In[19]:6

Stacktrace:
[1] isinplace(f::Function, inplace_param_number::Int64, fname::String, iip_preferred::Bool; has_two_dispatches::Bool, isoptimization::Bool)
@ SciMLBase ~/.julia/packages/SciMLBase/ys6dl/src/utils.jl:279
[2] (OptimizationFunction{true})(f::var"#23#24", adtype::Optimization.AutoZygote; grad::Nothing, hess::Nothing, hv::Nothing, cons::Nothing, cons_j::Nothing, cons_h::Nothing, lag_h::Nothing, hess_prototype::Nothing, cons_jac_prototype::Nothing, cons_hess_prototype::Nothing, lag_hess_prototype::Nothing, syms::Nothing, paramsyms::Nothing, observed::typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), hess_colorvec::Nothing, cons_jac_colorvec::Nothing, cons_hess_colorvec::Nothing, lag_hess_colorvec::Nothing, expr::Nothing, cons_expr::Nothing, sys::Nothing)
@ SciMLBase ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3620
[3] (OptimizationFunction{true})(f::Function, adtype::Optimization.AutoZygote)
@ SciMLBase ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3599
[4] OptimizationFunction(::Function, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ SciMLBase ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3597
[5] OptimizationFunction(::Function, ::Vararg{Any})
@ SciMLBase ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3597
[6] top-level scope
@ In[19]:6

I do not quite understand why this is happening; can someone help me? If needed, I can send my cost_adjoint_nn function too.
Thanks in advance

p,_ -> cost_adjoint_nn(p, 0.08)
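
That is, the objective you give to OptimizationFunction takes two arguments (the parameters and the extra problem data), and solve wants an OptimizationProblem rather than the raw function plus initial guess. A minimal sketch of the pattern, reusing the names from your post:

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p, _) -> cost_adjoint_nn(p, 0.08), adtype)
optprob = Optimization.OptimizationProblem(optf, θ_nn)
optimized_sol_nn = Optimization.solve(optprob, RADAM(0.003); callback = callback, maxiters = 1000)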


Hey thanks!

After doing that, I am now running into the following trouble:

MethodError: no method matching (OptimizationFunction{true})(::Vector{Bool}, ::var"#1#2", ::Optimization.AutoZygote)
Closest candidates are:
(OptimizationFunction{iip})(::Any) where iip at ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3599
(OptimizationFunction{iip})(::Any, ::SciMLBase.AbstractADType; grad, hess, hv, cons, cons_j, cons_h, lag_h, hess_prototype, cons_jac_prototype, cons_hess_prototype, lag_hess_prototype, syms, paramsyms, observed, hess_colorvec, cons_jac_colorvec, cons_hess_colorvec, lag_hess_colorvec, expr, cons_expr, sys) where iip at ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3599

Stacktrace:
[1] OptimizationFunction(::Vector{Bool}, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ SciMLBase ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3597
[2] OptimizationFunction(::Vector{Bool}, ::Vararg{Any})
@ SciMLBase ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3597
[3] top-level scope
@ In[6]:7

Hello,
I added the final version of my code. The kernel version is Julia 1.8.5

using DifferentialEquations, Lux, Optimization, Flux, DiffEqFlux, Optim, SciMLSensitivity, Plots, OrdinaryDiffEq, Zygote, StaticArrays, LinearAlgebra, BenchmarkTools, PaddedViews, LaTeXStrings, PGFPlotsX, PlotThemes, ApproxFun
pgfplotsx();
Plots.PGFPlotsXBackend();
const σ0 = Hermitian(Complex{Float64}[1 0; 0 1]);
const σx = Hermitian(Complex{Float64}[0 1; 1 0]);
const σy = Hermitian(Complex{Float64}[0 -im; im 0]);
const σz = Hermitian(Complex{Float64}[1 0; 0 -1]);
Ωp_nn = Lux.Chain(Lux.Dense(1,32), Lux.Dense(32,32,tanh), Lux.Dense(32,32,tanh), Lux.Dense(32,2))
θ_nn = initial_params(Ωp_nn);
const β = 2π*0.2;

const tol = 1e-7;

const Hϵ = Hermitian(Complex{Float64}[1 0; 0 -1;]);


const T = 6.0;
tspan = (0.0, T);

const θ = π/2;
const Utarget = cos(θ/2)*σ0 + sin(θ/2)*im*σz;
    
const steepness = 20*T;
smooth_square_envelope(t) = coth(steepness/4)*( tanh(steepness*t/(4*T)) - tanh(steepness*(t-T)/(4*T)) ) - 1;

function schrodinger_nn(u, p, t)
    @views @inbounds U = u[1:2,1:2];
    @views @inbounds ℰ = u[3:4,1:2];
    envelope = smooth_square_envelope(t);
    nn_output = Ωp_nn([t/T],p)
    @inbounds Ω = envelope*( nn_output[1]*sin(nn_output[2]) )
    @inbounds H = Hermitian([β Ω; Ω -β]);
    local dℰ = Hermitian(U'*Hϵ*U);
    return [-im*H*U; dℰ; dℰ*ℰ - ℰ*dℰ]; # 1/2 of (dℰ*ℰ - ℰ*dℰ)/2 is in cost_adjoint_nn
end

const u0 = Complex{Float64}[1 0; 0 1; 0 0; 0 0; 0 0; 0 0];

ode_nn = ODEProblem(schrodinger_nn, u0, tspan, θ);

function callback(p, cost)
    return cost < 1e-7
end


function cost_adjoint_nn(p, w=1.0)
    ode_sol = solve(ode_nn, BS5(), p=Complex{Float64}.(p), abstol=tol, reltol=tol)
    usol = last(ode_sol)
    @views @inbounds Ugate = usol[1:2,1:2];
    @views @inbounds ℰ = usol[3:4, 1:2];
    @views @inbounds ℰ2 = usol[5:6, 1:2];

    loss = abs(1.0 - abs(tr(Ugate*Utarget')/2)^2) + w^2*(norm(ℰ)/2)^2 + 4*w^4*(norm(ℰ2)/4)^2

    return loss

end
#######################THE NEW METHOD####################################################################
#In the new method, when I write p -> cost_adjoint_nn(p, 0.08), I get an error about p being undefined.
#However, in the sciml_train version, I could write p -> cost_adjoint_nn(p, 0.08)
#and the code did not ask me for the value of p
pinit = Lux.ComponentArray(θ_nn)#θ_nn#Ωp_nn
#callback(pinit, cost_adjoint_nn(pinit)...; doplot=true)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction(pinit,_ -> cost_adjoint_nn(pinit, 0.08), adtype) #pinit was theta_nn
                                                                                          #theta_nn was p
optprob = Optimization.OptimizationProblem(optf, pinit)

optimized_sol_nn = Optimization.solve(optprob,
                                      #θ_nn,
                                      RADAM(0.003),#RADAM
                                      callback = callback,
                                      maxiters = 1000)
######################THE OLD METHOD###########################################################################
################### This was the old version of the code and it was working without asking of p value###########

#optimized_sol_nn = DiffEqFlux.sciml_train(p -> cost_adjoint_nn(p, 0.08), θ_nn,RADAM(0.003), maxiters = 1000)
optimized_sol_nn2 = Optimization.solve(pinit -> cost_adjoint_nn(pinit, 0.08), optimized_sol_nn.minimizer, BFGS(initial_stepnorm=0.001), maxiters = 1000, allow_f_increases = true)
#optimized_sol_nn2 = DiffEqFlux.sciml_train(p -> cost_adjoint_nn(p, 0.08), optimized_sol_nn.minimizer, BFGS(initial_stepnorm=0.001), maxiters = 1000, allow_f_increases = true)

And after this line `optf = Optimization.OptimizationFunction(pinit,_ -> cost_adjoint_nn(pinit, 0.08), adtype) #pinit was theta_nn`, I am getting the following error message:

> MethodError: no method matching (OptimizationFunction{true})(::Vector{Bool}, ::var"#1#2", ::Optimization.AutoZygote)
> Closest candidates are:
>   (OptimizationFunction{iip})(::Any) where iip at ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3599
>   (OptimizationFunction{iip})(::Any, ::SciMLBase.AbstractADType; grad, hess, hv, cons, cons_j, cons_h, lag_h, hess_prototype, cons_jac_prototype, cons_hess_prototype, lag_hess_prototype, syms, paramsyms, observed, hess_colorvec, cons_jac_colorvec, cons_hess_colorvec, lag_hess_colorvec, expr, cons_expr, sys) where iip at ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3599
> 
> Stacktrace:
>  [1] OptimizationFunction(::Vector{Bool}, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
>    @ SciMLBase ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3597
>  [2] OptimizationFunction(::Vector{Bool}, ::Vararg{Any})
>    @ SciMLBase ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3597
>  [3] top-level scope
>    @ In[6]:8
> 

I really do not understand why. I tried different versions of this line, but I am still getting the error.

Can someone help me?
Thanks in advance

are you on the release versions?

I am using Optimization v3.11.2

I guess it works like this:

optf = Optimization.OptimizationFunction((pinit,_) -> cost_adjoint_nn(pinit, 0.08), adtype) #pinit was theta_nn

We need parentheses in this line, but then I got another error message for another line:

Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Wb65g/src/concrete_solve.jl:92
MethodError: no method matching (::Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), Tuple{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}})(::Vector{Float64}, ::Vector{ComplexF64})
Closest candidates are:
(::Lux.Chain)(::Any, ::Any, ::NamedTuple) at ~/.julia/packages/Lux/6vByk/src/layers/containers.jl:456

Stacktrace:
[1] schrodinger_nn(u::Matrix{ComplexF64}, p::Vector{ComplexF64}, t::Float64)
@ Main ./In[5]:21
[2] ODEFunction
@ ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:2096 [inlined]
[3] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Matrix{ComplexF64}, Nothing, Float64, Vector{ComplexF64}, Float64, Float64, Float64, Float64, Vector{Matrix{ComplexF64}}, ODESolution{ComplexF64, 3, Vector{Matrix{ComplexF64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Matrix{ComplexF64}}}, ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, Vector{ComplexF64}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Vector{Matrix{ComplexF64}}, Vector{Float64}, Vector{Vector{Matrix{ComplexF64}}}, OrdinaryDiffEq.BS5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, OrdinaryDiffEq.BS5ConstantCache{Float64, Float64}, OrdinaryDiffEq.DEOptions{Float64, Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(opnorm), Bool, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Matrix{ComplexF64}, ComplexF64, Nothing, OrdinaryDiffEq.DefaultInit}, cache::OrdinaryDiffEq.BS5ConstantCache{Float64, Float64})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/thWwa/src/perform_step/low_order_rk_perform_step.jl:440
[4] __init(prob::ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, Vector{ComplexF64}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, alg::BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Nothing, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Float64, reltol::Float64, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{Int64}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:save_noise,), Tuple{Bool}}})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/thWwa/src/solve.jl:499
[5] #__solve#623
@ ~/.julia/packages/OrdinaryDiffEq/thWwa/src/solve.jl:5 [inlined]
[6] #solve_call#22
@ ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:494 [inlined]
[7] solve_up(prob::ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, Vector{ComplexF64}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, sensealg::Nothing, u0::Matrix{ComplexF64}, p::Vector{ComplexF64}, args::BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}; kwargs::Base.Pairs{Symbol, Real, NTuple{6, Symbol}, NamedTuple{(:save_noise, :save_start, :save_end, :verbose, :abstol, :reltol), Tuple{Bool, Bool, Bool, Bool, Float64, Float64}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:915
[8] #solve#27
@ ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:825 [inlined]
[9] _concrete_solve_adjoint(::ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, Vector{ComplexF64}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::InterpolatingAdjoint{0, false, Val{:central}, Bool}, ::Matrix{ComplexF64}, ::Vector{ComplexF64}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, saveat::Vector{Float64}, save_idxs::Nothing, kwargs::Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:verbose, :abstol, :reltol), Tuple{Bool, Float64, Float64}}})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Wb65g/src/concrete_solve.jl:279
[10] _concrete_solve_adjoint(::ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, Vector{ComplexF64}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Nothing, ::Matrix{ComplexF64}, ::Vector{ComplexF64}, ::SciMLBase.ChainRulesOriginator; verbose::Bool, kwargs::Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:abstol, :reltol), Tuple{Float64, Float64}}})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Wb65g/src/concrete_solve.jl:163
[11] _solve_adjoint(prob::ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, Float64, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, sensealg::Nothing, u0::Matrix{ComplexF64}, p::Vector{ComplexF64}, originator::SciMLBase.ChainRulesOriginator, args::BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}; merge_callbacks::Bool, kwargs::Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:abstol, :reltol), Tuple{Float64, Float64}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:1323
[12] #rrule#50
@ ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:1276 [inlined]
[13] chain_rrule_kw
@ ~/.julia/packages/Zygote/g2w9o/src/compiler/chainrules.jl:235 [inlined]
[14] macro expansion
@ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0 [inlined]
[15] _pullback
@ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:9 [inlined]
[16] _apply
@ ./boot.jl:816 [inlined]
[17] adjoint
@ ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:203 [inlined]
[18] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[19] _pullback
@ ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:825 [inlined]

That looks to be a typo in the Lux call where you’re missing the st. It has 3 arguments in every tutorial, right? I presume you only put 2?
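
For reference, a Lux Chain is called with three arguments and returns an (output, state) tuple; a minimal sketch:

using Lux, Random
model = Lux.Chain(Lux.Dense(1, 32, tanh), Lux.Dense(32, 2))
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)   # parameters (NamedTuple) and state (NamedTuple)
y, st = model([0.5], ps, st)     # three arguments in, (output, state) out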


I changed the lines as follows:

pinit = Lux.ComponentArray(θ_nn)#θ_nn#Ωp_nn
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction(cost_adjoint_nn, adtype)
optprob = Optimization.OptimizationProblem(optf, pinit, 0.08)                                                                                          
optimized_sol_nn = Optimization.solve(optprob,
                                      RADAM(0.003),#RADAM
                                      callback = callback,
                                      maxiters = 1000)

However, I am still getting catastrophic errors…

like?

Hello,

based on the answer from the GitHub page,
I added these lines to my code:

rng = Random.default_rng()
p, st = Lux.setup(rng, Ωp_nn)

I also changed the line from nn_output = Ωp_nn([t/T],p) to nn_output = Ωp_nn([t/T],p,st) (there was no st there before).
I got the following error message:

type Array has no field layer_1

Stacktrace:
  [1] getproperty
    @ ./Base.jl:38 [inlined]
  [2] macro expansion
    @ ~/.julia/packages/Lux/6vByk/src/layers/containers.jl:0 [inlined]
  [3] applychain(layers::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), Tuple{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}, x::Vector{Float64}, ps::Vector{ComplexF64}, st::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), NTuple{4, NamedTuple{(), Tuple{}}}})
@ Lux ~/.julia/packages/Lux/6vByk/src/layers/containers.jl:460
  [4] (::Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), Tuple{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}})(x::Vector{Float64}, ps::Vector{ComplexF64}, st::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), NTuple{4, NamedTuple{(), Tuple{}}}})
    @ Lux ~/.julia/packages/Lux/6vByk/src/layers/containers.jl:457
  [5] schrodinger_nn(u::Matrix{ComplexF64}, p::Vector{ComplexF64}, t::Float64)
    @ Main ./In[15]:21
  [6] ODEFunction
.....

You missed the line pinit = Lux.ComponentArray(p)
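
That is, after setting up the network, something like this (a sketch using the names from your code):

rng = Random.default_rng()
p, st = Lux.setup(rng, Ωp_nn)   # p is a NamedTuple of layer parameters
pinit = Lux.ComponentArray(p)   # flatten the NamedTuple into a ComponentArray for the optimizer
optprob = Optimization.OptimizationProblem(optf, pinit, 0.08)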


I tried that too, and I got this error:

MethodError: no method matching sin(::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), NTuple{4, NamedTuple{(), Tuple{}}}})
Closest candidates are:
  sin(::T) where T<:Union{Float32, Float64} at special/trig.jl:29
  sin(::ProductFun{<:Chebyshev, <:Chebyshev}) at ~/.julia/packages/ApproxFunOrthogonalPolynomials/Z4COs/src/Spaces/Chebyshev/Chebyshev.jl:247
  sin(::ProductFun) at ~/.julia/packages/ApproxFunBase/yOX4F/src/Multivariate/ProductFun.jl:394
  ...

Stacktrace:

You didn’t pass the ComponentArray, you passed the NamedTuple.


I did not understand: pinit is my ComponentArray and it comes from this line: pinit = Lux.ComponentArray(p)
I then gave pinit to my optimization problem as follows:

pinit = Lux.ComponentArray(p)#θ_nn#Ωp_nn#p
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction(cost_adjoint_nn, adtype)
optprob = Optimization.OptimizationProblem(optf, pinit, 0.08)                                                                                          
optimized_sol_nn = Optimization.solve(optprob,
                                      RADAM(0.003),#RADAM
                                      callback = callback,
                                      maxiters = 1000)

You wrote that I did not pass the ComponentArray, but pinit is my ComponentArray and I did pass it. Am I wrong?

then what’s the …?

Since the error is too long, I exceeded the maximum character limit, so I did not paste all of it; I pasted it up to the “…”. Here is a longer version of the error message (up to the character limit):

┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Wb65g/src/concrete_solve.jl:92
MethodError: no method matching sin(::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), NTuple{4, NamedTuple{(), Tuple{}}}})
Closest candidates are:
  sin(::T) where T<:Union{Float32, Float64} at special/trig.jl:29
  sin(::ProductFun{<:Chebyshev, <:Chebyshev}) at ~/.julia/packages/ApproxFunOrthogonalPolynomials/Z4COs/src/Spaces/Chebyshev/Chebyshev.jl:247
  sin(::ProductFun) at ~/.julia/packages/ApproxFunBase/yOX4F/src/Multivariate/ProductFun.jl:394
  ...

Stacktrace:
  [1] schrodinger_nn(u::Matrix{ComplexF64}, p::ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, t::Float64)
    @ Main ./In[39]:22
  [2] ODEFunction
    @ ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:2096 [inlined]
  [3] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Matrix{ComplexF64}, Nothing, Float64, ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, Float64, Float64, Float64, Float64, Vector{Matrix{ComplexF64}}, ODESolution{ComplexF64, 3, Vector{Matrix{ComplexF64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Matrix{ComplexF64}}}, ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Vector{Matrix{ComplexF64}}, Vector{Float64}, Vector{Vector{Matrix{ComplexF64}}}, OrdinaryDiffEq.BS5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, OrdinaryDiffEq.BS5ConstantCache{Float64, Float64}, OrdinaryDiffEq.DEOptions{Float64, Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(opnorm), Bool, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, 
DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Matrix{ComplexF64}, ComplexF64, Nothing, OrdinaryDiffEq.DefaultInit}, cache::OrdinaryDiffEq.BS5ConstantCache{Float64, Float64})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/thWwa/src/perform_step/low_order_rk_perform_step.jl:440
  [4] __init(prob::ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, alg::BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Nothing, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Float64, reltol::Float64, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{Int64}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:save_noise,), Tuple{Bool}}})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/thWwa/src/solve.jl:499
  [5] #__solve#623
    @ ~/.julia/packages/OrdinaryDiffEq/thWwa/src/solve.jl:5 [inlined]
  [6] #solve_call#22
    @ ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:494 [inlined]
  [7] solve_up(prob::ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, sensealg::Nothing, u0::Matrix{ComplexF64}, p::ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, args::BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}; kwargs::Base.Pairs{Symbol, Real, NTuple{6, Symbol}, NamedTuple{(:save_noise, :save_start, :save_end, :verbose, :abstol, :reltol), Tuple{Bool, Bool, Bool, Bool, Float64, Float64}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:915
  [8] #solve#27
    @ ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:825 [inlined]
  [9] _concrete_solve_adjoint(::ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::InterpolatingAdjoint{0, false, Val{:central}, Bool}, ::Matrix{ComplexF64}, ::ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, saveat::Vector{Float64}, save_idxs::Nothing, kwargs::Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:verbose, :abstol, :reltol), Tuple{Bool, Float64, Float64}}})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Wb65g/src/concrete_solve.jl:279
 [10] _concrete_solve_adjoint(::ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Nothing, ::Matrix{ComplexF64}, ::ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ::SciMLBase.ChainRulesOriginator; verbose::Bool, kwargs::Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:abstol, :reltol), Tuple{Float64, Float64}}})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Wb65g/src/concrete_solve.jl:163
 [11] _solve_adjoint(prob::ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, Float64, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, sensealg::Nothing, u0::Matrix{ComplexF64}, p::ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, originator::SciMLBase.ChainRulesOriginator, args::BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}; merge_callbacks::Bool, kwargs::Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:abstol, :reltol), Tuple{Float64, Float64}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:1323
 [12] #rrule#50
    @ ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:1276 [inlined]
 [13] chain_rrule_kw
    @ ~/.julia/packages/Zygote/g2w9o/src/compiler/chainrules.jl:235 [inlined]
 [14] macro expansion
    @ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0 [inlined]
 [15] _pullback
    @ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:9 [inlined]
 [16] _apply
    @ ./boot.jl:816 [inlined]
 [17] adjoint
    @ ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:203 [inlined]
 [18] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [19] _pullback
    @ ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:825 [inlined]
 [20] _pullback(::Zygote.Context{false}, ::DiffEqBase.var"##solve#27", ::Nothing, ::Nothing, ::ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ::Val{true}, ::Base.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:abstol, :reltol), Tuple{Float64, Float64}}}, ::typeof(solve), ::ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, Float64, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [21] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
 [22] adjoint
    @ ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:203 [inlined]
 [23] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [24] _pullback
    @ ~/.julia/packages/DiffEqBase/QR8gq/src/solve.jl:815 [inlined]
 [25] _pullback(::Zygote.Context{false}, ::CommonSolve.var"#solve##kw", ::NamedTuple{(:p, :abstol, :reltol), Tuple{ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, Float64, Float64}}, ::typeof(solve), ::ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, Float64, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [26] _pullback
    @ ./In[39]:38 [inlined]
 [27] _pullback(::Zygote.Context{false}, ::typeof(cost_adjoint_nn), ::ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [28] _apply
    @ ./boot.jl:816 [inlined]
 [29] adjoint
    @ ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:203 [inlined]
 [30] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [31] _pullback
    @ ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3596 [inlined]
 [32] _pullback(::Zygote.Context{false}, ::OptimizationFunction{true, Optimization.AutoZygote, typeof(cost_adjoint_nn), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [33] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
 [34] adjoint
    @ ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:203 [inlined]
 [35] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [36] _pullback
    @ ~/.julia/packages/Optimization/aPPOg/src/function/zygote.jl:30 [inlined]
 [37] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, typeof(cost_adjoint_nn), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Float64}, args::ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [38] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
 [39] adjoint
    @ ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:203 [inlined]
 [40] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [41] _pullback
    @ ~/.julia/packages/Optimization/aPPOg/src/function/zygote.jl:34 [inlined]
 [42] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#158#167"{Tuple{}, Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, typeof(cost_adjoint_nn), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Float64}}, args::ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [43] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:44
 [44] pullback
    @ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:42 [inlined]
 [45] gradient(f::Function, args::ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}})
   

That looks fine. How are you calling the NN? How is schrodinger_nn now defined?


Here is the schrodinger_nn function:

function schrodinger_nn(u, p, t)
    @views @inbounds U = u[1:2,1:2];
    @views @inbounds ℰ = u[3:4,1:2];
    envelope = smooth_square_envelope(t);
    nn_output = Ωp_nn([t/T],p,st)#there was no st here
    #nn_output = Lux.ComponentArray(nn_output)
    @inbounds Ω = envelope*( nn_output[1]*sin(nn_output[2])) #Lux.ComponentArray(nn_output[1]*Lux.ComponentArray(nn_output[2]))
    @inbounds H = Hermitian([β Ω; Ω -β]);
    local dℰ = Hermitian(U'*Hϵ*U);
    return [-im*H*U; dℰ; dℰ*ℰ - ℰ*dℰ]; # 1/2 of (dℰ*ℰ - ℰ*dℰ)/2 is in cost_adjoint_nn
end

you mean nn_output,st = Ωp_nn([t/T],p,st)? Lux always outputs (value,state)
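
That is, inside schrodinger_nn the call would become, roughly:

nn_output, _ = Ωp_nn([t/T], p, st)   # keep the output, discard the returned state inside the ODE rhs
Ω = envelope*( nn_output[1]*sin(nn_output[2]) )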

If I write nn_output, _ = Ωp_nn([t/T], p, st), I get the following error:

┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Wb65g/src/concrete_solve.jl:92
MethodError: no method matching default_relstep(::Nothing, ::Type{ComplexF64})
Closest candidates are:
  default_relstep(::Type, ::Any) at ~/.julia/packages/FiniteDiff/40JnL/src/epsilons.jl:25
  default_relstep(::Val{fdtype}, ::Type{T}) where {fdtype, T<:Number} at ~/.julia/packages/FiniteDiff/40JnL/src/epsilons.jl:26

Stacktrace:
  [1] finite_difference_jacobian!(J::Matrix{ComplexF64}, f::Function, x::Matrix{ComplexF64}, fdtype::Nothing, returntype::Type, f_in::Nothing) (repeats 2 times)
    @ FiniteDiff ~/.julia/packages/FiniteDiff/40JnL/src/jacobians.jl:298
  [2] jacobian!(J::Matrix{ComplexF64}, f::Function, x::Matrix{ComplexF64}, fx::Nothing, alg::InterpolatingAdjoint{0, false, Val{:central}, Bool}, jac_config::Nothing)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Wb65g/src/derivative_wrappers.jl:150
  [3] _vecjacobian!(dλ::SubArray{ComplexF64, 1, Vector{ComplexF64}, Tuple{UnitRange{Int64}}, true}, y::Matrix{ComplexF64}, λ::SubArray{ComplexF64, 1, Vector{ComplexF64}, Tuple{UnitRange{Int64}}, true}, p::ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, t::Float64, S::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{SciMLSensitivity.AdjointDiffCache{SciMLBase.UDerivativeWrapper{ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Float64, ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}}, SciMLSensitivity.ParamGradientWrapper{ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Float64, Matrix{ComplexF64}}, Nothing, Matrix{ComplexF64}, Matrix{ComplexF64}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, UniformScaling{Bool}}, InterpolatingAdjoint{0, false, Val{:central}, Bool}, Matrix{ComplexF64}, ODESolution{ComplexF64, 3, Vector{Matrix{ComplexF64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Matrix{ComplexF64}}}, ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), 
NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, BS5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Vector{Matrix{ComplexF64}}, Vector{Float64}, Vector{Vector{Matrix{ComplexF64}}}, OrdinaryDiffEq.BS5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats, Nothing}, Nothing, ODEProblem{Matrix{ComplexF64}, Tuple{Float64, Float64}, false, ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}}, isautojacvec::Bool, dgrad::SubArray{ComplexF64, 1, Vector{ComplexF64}, Tuple{UnitRange{Int64}}, true}, dy::Nothing, W::Nothing)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Wb65g/src/derivative_wrappers.jl:262
  [4] vecjacobian!(dλ::SubArray{ComplexF64, 1, Vector{ComplexF64}, Tuple{UnitRange{Int64}}, true}, y::Matrix{ComplexF64}, λ::SubArray{ComplexF64, 1, Vector{ComplexF64}, Tuple{UnitRange{Int64}}, true}, p::ComponentArrays.ComponentVector{ComplexF64, Vector{ComplexF64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, t::Float64, S::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{SciMLSensitivity.AdjointDiffCache{SciMLBase.UDerivativeWrapper{ODEFunction{false, SciMLBase.AutoSpecialize, typeof(schrodinger_nn), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing