Force ZygoteVJP and look at the error.
I was looking at this page: https://docs.juliahub.com/DiffEqSensitivity/02xYn/6.78.4/manual/differential_equation_sensitivities/#DiffEqSensitivity.ZygoteVJP
Is this another package? I can't access anything related to ZygoteVJP. Maybe this is the problem.
ode_sol = solve(ode_nn, BS5(), p=Complex{Float64}.(p), abstol=tol, reltol=tol, sensealg = InterpolatingAdjoint(ZygoteVJP()))
Thanks. When I tried your advice, I got a "no method matching InterpolatingAdjoint" message:
MethodError: no method matching InterpolatingAdjoint(::ZygoteVJP)
Closest candidates are:
InterpolatingAdjoint(; chunk_size, autodiff, diff_type, autojacvec, checkpointing, noisemixing) at ~/.julia/packages/SciMLSensitivity/Wb65g/src/sensitivity_algorithms.jl:374
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0 [inlined]
[2] _pullback(ctx::Zygote.Context{false}, f::Type{InterpolatingAdjoint}, args::ZygoteVJP)
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:9
[3] _pullback
@ ./In[5]:39 [inlined]
[4] _pullback(::Zygote.Context{false}, ::typeof(cost_adjoint_nn), ::ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ::Float64)
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
[5] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:816
[6] adjoint
@ ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:203 [inlined]
[7] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[8] _pullback
@ ~/.julia/packages/SciMLBase/ys6dl/src/scimlfunctions.jl:3596 [inlined]
[9] _pullback(::Zygote.Context{false}, ::OptimizationFunction{true, Optimization.AutoZygote, typeof(cost_adjoint_nn), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}}, ::Float64)
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
[10] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:816
[11] adjoint
@ ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:203 [inlined]
[12] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[13] _pullback
@ ~/.julia/packages/Optimization/aPPOg/src/function/zygote.jl:30 [inlined]
[14] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#156#165"{OptimizationFunction{true, Optimization.AutoZygote, typeof(cost_adjoint_nn), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Float64}, args::ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:32, ShapedAxis((32, 1), NamedTuple())), bias = ViewAxis(33:64, ShapedAxis((32, 1), NamedTuple())))), layer_2 = ViewAxis(65:1120, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_3 = ViewAxis(1121:2176, Axis(weight = ViewAxis(1:1024, ShapedAxis((32, 32), NamedTuple())), bias = ViewAxis(1025:1056, ShapedAxis((32, 1), NamedTuple())))), layer_4 = ViewAxis(2177:2242, Axis(weight = ViewAxis(1:64, ShapedAxis((2, 32), NamedTuple())), bias = ViewAxis(65:66, ShapedAxis((2, 1), NamedTuple())))))}}})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
[15] _apply(::Function, ::Vararg{Any})
It goes in as a keyword: autojacvec = … Just see the docs and interpret what I was saying.
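For anyone landing here with the same error: ZygoteVJP() is exported by SciMLSensitivity (the package formerly called DiffEqSensitivity, which is why the stack trace points into SciMLSensitivity), and the "Closest candidates" line shows that InterpolatingAdjoint only has a keyword constructor, so the VJP backend has to be passed through autojacvec. A minimal sketch, assuming a hypothetical toy ODE du/dt = p[1]*u in place of ode_nn and Tsit5 in place of BS5:

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Hypothetical toy problem standing in for ode_nn: du/dt = p[1] * u
rhs(u, p, t) = p[1] .* u
prob = ODEProblem(rhs, [1.0], (0.0, 1.0), [-0.5])

# InterpolatingAdjoint(ZygoteVJP()) throws a MethodError: the constructor
# only takes keyword arguments, so the VJP choice goes through autojacvec.
sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())

# Sum of squares of the solution at t = 0.0, 0.1, ..., 1.0
function loss(p)
    sol = solve(prob, Tsit5(); p = p, abstol = 1e-8, reltol = 1e-8,
                saveat = 0.1, sensealg = sensealg)
    return sum(abs2, Array(sol))
end

# Reverse-mode gradient with respect to p, computed by the adjoint method
grad = Zygote.gradient(loss, [-0.5])[1]
```

Because u(t) = exp(p t) for this toy problem, the result can be sanity-checked against the analytic gradient, the sum of 2t exp(2pt) over the save points.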
I think it works now:
1) I changed this line:
nn_output = Ωp_nn([t/T],p,st)#there was no st here
to
nn_output,_ = Ωp_nn([t/T],p,st)#there was no st here
2) In addition, I added the following line:
nn_output = Lux.ComponentArray(nn_output)
3) Then I followed your advice and wrote this line:
ode_sol = solve(ode_nn, BS5(), p=Complex{Float64}.(p), abstol=tol, reltol=tol, sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP()))
and this part of the script worked!
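The first two changes follow from how Lux differs from Flux: a Lux model is called as model(x, ps, st) and returns an (output, state) tuple, which is why the output has to be destructured, and the parameters are usually flattened into a ComponentArray (the ComponentVector with layer_1 to layer_4 axes in the stack trace above). A minimal sketch, assuming a hypothetical Chain with the layer sizes visible in that trace; the tanh activations and the values of t and T are placeholders, not the poster's actual code:

```julia
using Lux, ComponentArrays, Random

rng = Random.default_rng()
# Hypothetical network matching the layer sizes in the stack trace: 1 -> 32 -> 32 -> 32 -> 2
model = Chain(Dense(1 => 32, tanh), Dense(32 => 32, tanh),
              Dense(32 => 32, tanh), Dense(32 => 2))
ps, st = Lux.setup(rng, model)

# Flatten the nested NamedTuple of parameters into one vector with named axes
p = ComponentArray(ps)

t, T = 0.3f0, 1.0f0  # placeholder time and horizon (Float32 to match the parameters)
# Lux returns (output, new_state); the state is discarded for this stateless model
nn_output, _ = model([t / T], p, st)
```

With these sizes the flattened vector has 64 + 1056 + 1056 + 66 = 2242 entries, the same as the final index of layer_4 in the stack trace.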
I am closing this topic with the solution. If I have any further questions, I will ask them under a new title.
Awesome, good to hear. It might be good to fix the typo in the title so that people making the same change can find it more easily. Most of this is about the Flux → Lux switch, though.