Error in trying to use Optimization.jl for LSTM training based on Lux.jl

Hello, I want to try using Optimization.jl to perform model optimization based on Lux.jl. Here’s my code.

using Lux
using Zygote
using StableRNGs
using ComponentArrays
using Optimization
using OptimizationOptimisers

function LSTMCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,2} where {T}
        x = reshape(x, size(x)..., 1)
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        output = [vec(classifier(y))]
        for x in x_rest
            y, carry = lstm_cell((x, carry))
            output = vcat(output, [vec(classifier(y))])
        end
        @return hcat(output...)
    end
end

model = LSTMCompact(3, 10, 1)
ps, st = Lux.setup(StableRNGs.LehmerRNG(1234), model)
ps_axes = getaxes(ComponentVector(ps))
model_func = (x, ps) -> Lux.apply(model, x, ps, st)
x = rand(3, 10)
y = rand(1, 10)

function object(u, p)
    ps = ComponentVector(u, ps_axes)
    sum((model_func(x, ps)[1] .- y) .^ 2)
end

opt_func = Optimization.OptimizationFunction(object, Optimization.AutoZygote())
opt_prob = Optimization.OptimizationProblem(opt_func, Vector(ComponentVector(ps)))
opt_sol = Optimization.solve(opt_prob, OptimizationOptimisers.Adam(0.1), maxiters=1000)

Translation: The code works when using AutoForwardDiff as the AD type, but when using AutoZygote it encounters the following error:

ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{…}})(::NTuple{9, Vector{…}})

Closest candidates are:
  (::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.NotImplemented) where T
   @ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\projection.jl:121
  (::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.AbstractZero) where T
   @ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\projection.jl:120
  (::ChainRulesCore.ProjectTo{AbstractArray})(::ChainRulesCore.Tangent)
   @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:200
  ...

Stacktrace:
  [1] (::ChainRules.var"#480#485"{ChainRulesCore.ProjectTo{…}, Tuple{…}, ChainRulesCore.Tangent{…}})()
    @ ChainRules D:\Julia\Julia-1.10.4\packages\packages\ChainRules\hShjJ\src\rulesets\Base\array.jl:314
  [2] unthunk
    @ D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\tangent_types\thunks.jl:205 [inlined]
  [3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#479#484"{…}})
    @ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\tangent_types\thunks.jl:238
  [4] wrap_chainrules_output
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:110 [inlined]
  [5] map
    @ .\tuple.jl:293 [inlined]
  [6] wrap_chainrules_output
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:111 [inlined]
  [7] (::Zygote.ZBack{ChainRules.var"#vcat_pullback#481"{Tuple{…}, Tuple{…}, Val{…}}})(dy::NTuple{10, Vector{Float64}})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:211
  [8] #21
    @ D:\Julia\Julia-1.10.4\packages\packages\Lux\PsW4M\src\helpers\compact.jl:0 [inlined]
  [9] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Matrix{…}, Nothing})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [10] CompactLuxLayer
    @ D:\Julia\Julia-1.10.4\packages\packages\Lux\PsW4M\src\helpers\compact.jl:366 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Matrix{…}, Nothing})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [12] apply
    @ D:\Julia\Julia-1.10.4\packages\packages\LuxCore\kYVM5\src\LuxCore.jl:171 [inlined]
 [13] #23
    @ e:\JlCode\HydroModels\temp\train_lstm_in_opt.jl:27 [inlined]
 [14] object
    @ e:\JlCode\HydroModels\temp\train_lstm_in_opt.jl:33 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [16] #291
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [17] #2169#back
    @ D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [18] OptimizationFunction
    @ D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\scimlfunctions.jl:3812 [inlined]
 [19] #291
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [20] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [21] #37
    @ D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:94 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [23] #291
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [24] #2169#back
    @ D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [25] #39
    @ D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:97 [inlined]
 [26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [27] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:91
 [28] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:148
 [29] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
    @ OptimizationZygoteExt D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:97
 [30] macro expansion
    @ D:\Julia\Julia-1.10.4\packages\packages\OptimizationOptimisers\AOkbT\src\OptimizationOptimisers.jl:68 [inlined]
 [31] macro expansion
    @ D:\Julia\Julia-1.10.4\packages\packages\Optimization\fPKIF\src\utils.jl:32 [inlined]
 [32] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers D:\Julia\Julia-1.10.4\packages\packages\OptimizationOptimisers\AOkbT\src\OptimizationOptimisers.jl:66
 [33] solve!(cache::OptimizationCache{…})
    @ SciMLBase D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\solve.jl:188
 [34] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
    @ SciMLBase D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\solve.jl:96
 [35] top-level scope
    @ REPL[2]:1
Some type information was truncated. Use `show(err)` to see complete types.

This issue seems to only occur with recurrent neural networks like LSTM, but not with regular fully connected neural networks. So I want to ask if there’s a way to optimize Lux.jl’s LSTMCell and other RNN models using Optimization.jl

This might be for @avikpal

Resolved earlier today Error in trying to use Optimization.jl for LSTM training based on Lux.jl · Issue #1114 · LuxDL/Lux.jl · GitHub