Hello, I want to try using Optimization.jl to perform model optimization based on Lux.jl. Here’s my code.
using Lux
using Zygote
using StableRNGs
using ComponentArrays
using Optimization
using OptimizationOptimisers
function LSTMCompact(in_dims, hidden_dims, out_dims)
lstm_cell = LSTMCell(in_dims => hidden_dims)
classifier = Dense(hidden_dims => out_dims, sigmoid)
return @compact(; lstm_cell, classifier) do x::AbstractArray{T,2} where {T}
x = reshape(x, size(x)..., 1)
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
y, carry = lstm_cell(x_init)
output = [vec(classifier(y))]
for x in x_rest
y, carry = lstm_cell((x, carry))
output = vcat(output, [vec(classifier(y))])
end
@return hcat(output...)
end
end
model = LSTMCompact(3, 10, 1)
ps, st = Lux.setup(StableRNGs.LehmerRNG(1234), model)
ps_axes = getaxes(ComponentVector(ps))
model_func = (x, ps) -> Lux.apply(model, x, ps, st)
x = rand(3, 10)
y = rand(1, 10)
function object(u, p)
ps = ComponentVector(u, ps_axes)
sum((model_func(x, ps)[1] .- y) .^ 2)
end
opt_func = Optimization.OptimizationFunction(object, Optimization.AutoZygote())
opt_prob = Optimization.OptimizationProblem(opt_func, Vector(ComponentVector(ps)))
opt_sol = Optimization.solve(opt_prob, OptimizationOptimisers.Adam(0.1), maxiters=1000)
Translation: The code works when using AutoForwardDiff
as the AD type, but when using AutoZygote
it encounters the following error:
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{…}})(::NTuple{9, Vector{…}})
Closest candidates are:
(::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.NotImplemented) where T
@ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\projection.jl:121
(::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.AbstractZero) where T
@ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\projection.jl:120
(::ChainRulesCore.ProjectTo{AbstractArray})(::ChainRulesCore.Tangent)
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:200
...
Stacktrace:
[1] (::ChainRules.var"#480#485"{ChainRulesCore.ProjectTo{…}, Tuple{…}, ChainRulesCore.Tangent{…}})()
@ ChainRules D:\Julia\Julia-1.10.4\packages\packages\ChainRules\hShjJ\src\rulesets\Base\array.jl:314
[2] unthunk
@ D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\tangent_types\thunks.jl:205 [inlined]
[3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#479#484"{…}})
@ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\tangent_types\thunks.jl:238
[4] wrap_chainrules_output
@ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:110 [inlined]
[5] map
@ .\tuple.jl:293 [inlined]
[6] wrap_chainrules_output
@ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:111 [inlined]
[7] (::Zygote.ZBack{ChainRules.var"#vcat_pullback#481"{Tuple{…}, Tuple{…}, Val{…}}})(dy::NTuple{10, Vector{Float64}})
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:211
[8] #21
@ D:\Julia\Julia-1.10.4\packages\packages\Lux\PsW4M\src\helpers\compact.jl:0 [inlined]
[9] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Matrix{…}, Nothing})
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[10] CompactLuxLayer
@ D:\Julia\Julia-1.10.4\packages\packages\Lux\PsW4M\src\helpers\compact.jl:366 [inlined]
[11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Matrix{…}, Nothing})
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[12] apply
@ D:\Julia\Julia-1.10.4\packages\packages\LuxCore\kYVM5\src\LuxCore.jl:171 [inlined]
[13] #23
@ e:\JlCode\HydroModels\temp\train_lstm_in_opt.jl:27 [inlined]
[14] object
@ e:\JlCode\HydroModels\temp\train_lstm_in_opt.jl:33 [inlined]
[15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[16] #291
@ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
[17] #2169#back
@ D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
[18] OptimizationFunction
@ D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\scimlfunctions.jl:3812 [inlined]
[19] #291
@ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
[20] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Float64)
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
[21] #37
@ D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:94 [inlined]
[22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[23] #291
@ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
[24] #2169#back
@ D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
[25] #39
@ D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:97 [inlined]
[26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[27] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:91
[28] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
@ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:148
[29] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
@ OptimizationZygoteExt D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:97
[30] macro expansion
@ D:\Julia\Julia-1.10.4\packages\packages\OptimizationOptimisers\AOkbT\src\OptimizationOptimisers.jl:68 [inlined]
[31] macro expansion
@ D:\Julia\Julia-1.10.4\packages\packages\Optimization\fPKIF\src\utils.jl:32 [inlined]
[32] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers D:\Julia\Julia-1.10.4\packages\packages\OptimizationOptimisers\AOkbT\src\OptimizationOptimisers.jl:66
[33] solve!(cache::OptimizationCache{…})
@ SciMLBase D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\solve.jl:188
[34] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
@ SciMLBase D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\solve.jl:96
[35] top-level scope
@ REPL[2]:1
Some type information was truncated. Use `show(err)` to see complete types.
This issue seems to only occur with recurrent neural networks like LSTM, but not with regular fully connected neural networks. So I want to ask if there’s a way to optimize Lux.jl’s LSTMCell and other RNN models using Optimization.jl