ReverseDiff over Zygote to differentiate a neural network using GalacticOptim - dimension mismatch

Hi all, I am struggling a bit with using GalacticOptim.AutoReverseDiff to take gradients of a loss function that internally computes a gradient with Zygote. Running ForwardDiff over the same function works fine but is quite slow. I have put together an MWE below, along with the resulting error message. Any help in getting around this, or in finding an alternative approach, would be appreciated.

using Zygote
using GalacticOptim
using Flux
using LinearAlgebra
using Statistics
using StatsBase

l1 = Dense(2, 10, tanh)
l2 = Dense(10, 4, identity)
network = Chain(l1, l2)
pars = Flux.params(network)
vec_pars, re = Flux.destructure(network)  # flat parameter vector and reconstructor

function loss(x, p)
    network = re(x)                                  # rebuild the network from the flat parameter vector x
    u = Zygote.gradient(d -> sum(network(d)), p)[1]  # inner gradient w.r.t. the data p, via Zygote
    sum(u)
end

data = rand(2, 1000)
loss(vec_pars, data)

adtype = GalacticOptim.AutoReverseDiff()
optf = GalacticOptim.OptimizationFunction(loss, adtype)
optprob = GalacticOptim.OptimizationProblem(optf, vec_pars, data)

result_ad = GalacticOptim.solve(optprob, ADAM(0.001), maxiters=1)
DimensionMismatch("array could not be broadcast to match destination")
check_broadcast_shape at broadcast.jl:520 [inlined]
check_broadcast_axes at broadcast.jl:523 [inlined]
check_broadcast_axes at broadcast.jl:527 [inlined]
instantiate at broadcast.jl:269 [inlined]
materialize! at broadcast.jl:894 [inlined]
materialize! at broadcast.jl:891 [inlined]
apply!(o::ADAM, x::Vector{Float32}, Δ::Vector{Float64}) at optimisers.jl:179
update!(opt::ADAM, x::Vector{Float32}, x̄::Vector{Float64}) at train.jl:23
update!(opt::ADAM, xs::Params, gs::Zygote.Grads) at train.jl:29
__solve(prob::OptimizationProblem{true, OptimizationFunction{true, GalacticOptim.AutoReverseDiff, typeof(loss), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Vector{Float32}, Matrix{Float64}, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::ADAM, data::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) at solve.jl:106
__solve at solve.jl:66 [inlined]
__solve at solve.jl:66 [inlined]
#solve#468 at solve.jl:3 [inlined]
(::CommonSolve.var"#solve##kw")(::NamedTuple{(:maxiters,), Tuple{Int64}}, ::typeof(solve), ::OptimizationProblem{true, OptimizationFunction{true, GalacticOptim.AutoReverseDiff, typeof(loss), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Vector{Float32}, Matrix{Float64}, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::ADAM) at solve.jl:3
top-level scope at MWE.jl:27
eval at boot.jl:360 [inlined]
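
For reference, the ForwardDiff run I mentioned is, as far as I can tell, the same setup with only the adtype swapped; a minimal sketch, assuming all the definitions from the MWE above:

# Same problem, but with forward-mode AD over the Zygote-based loss.
adtype_fd = GalacticOptim.AutoForwardDiff()
optf_fd = GalacticOptim.OptimizationFunction(loss, adtype_fd)
optprob_fd = GalacticOptim.OptimizationProblem(optf_fd, vec_pars, data)

# This variant runs to completion for me, it is just slow.
result_fd = GalacticOptim.solve(optprob_fd, ADAM(0.001), maxiters=1)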

Open an issue on GalacticOptim.jl. I’ll want to dig into this, and having it in my email will make it easier to track. @dhairyagandhi96 and @mohamed82008, this should be a fun one.

Will do @ChrisRackauckas, thanks. :+1: