Hi all, I am struggling a bit with using GalacticOptim.AutoReverseDiff to compute the gradient of a loss function that internally calculates a gradient with Zygote. Running ForwardDiff over the same function works fine but is quite slow. I have created an MWE, shown below along with the error message. Any help with getting around this, or with an alternative approach, would be appreciated.
using Zygote
using GalacticOptim
using Flux
using LinearAlgebra
using Statistics
using StatsBase
# Small test network: 2 inputs -> 10 hidden units (tanh) -> 4 outputs
l1 = Dense(2, 10, tanh)
l2 = Dense(10, 4, identity)
network = Chain(l1, l2)
pars = Flux.params(network)
# Flatten the parameters into a single vector; `re` rebuilds the network from such a vector
vec_pars, re = Flux.destructure(network)
function loss(x, p)
    network = re(x)
    # Gradient of the summed network output with respect to the input data, via Zygote
    u = Zygote.gradient(d -> sum(network(d)), p)[1]
    sum(u)
end
data = rand(2, 1000)
loss(vec_pars, data)  # calling the loss directly works

adtype = GalacticOptim.AutoReverseDiff()
optf = GalacticOptim.OptimizationFunction(loss, adtype)
optprob = GalacticOptim.OptimizationProblem(optf, vec_pars, data)
result_ad = GalacticOptim.solve(optprob, ADAM(0.001), maxiters=1)  # throws the error below
DimensionMismatch("array could not be broadcast to match destination")
check_broadcast_shape at broadcast.jl:520 [inlined]
check_broadcast_axes at broadcast.jl:523 [inlined]
check_broadcast_axes at broadcast.jl:527 [inlined]
instantiate at broadcast.jl:269 [inlined]
materialize! at broadcast.jl:894 [inlined]
materialize! at broadcast.jl:891 [inlined]
apply!(o::ADAM, x::Vector{Float32}, Δ::Vector{Float64}) at optimisers.jl:179
update!(opt::ADAM, x::Vector{Float32}, x̄::Vector{Float64}) at train.jl:23
update!(opt::ADAM, xs::Params, gs::Zygote.Grads) at train.jl:29
__solve(prob::OptimizationProblem{true, OptimizationFunction{true, GalacticOptim.AutoReverseDiff, typeof(loss), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Vector{Float32}, Matrix{Float64}, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::ADAM, data::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) at solve.jl:106
__solve at solve.jl:66 [inlined]
__solve at solve.jl:66 [inlined]
#solve#468 at solve.jl:3 [inlined]
(::CommonSolve.var"#solve##kw")(::NamedTuple{(:maxiters,), Tuple{Int64}}, ::typeof(solve), ::OptimizationProblem{true, OptimizationFunction{true, GalacticOptim.AutoReverseDiff, typeof(loss), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Vector{Float32}, Matrix{Float64}, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::ADAM) at solve.jl:3
top-level scope at MWE.jl:27
eval at boot.jl:360 [inlined]
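For reference, here is a minimal sketch of the ForwardDiff version that runs for me (just slowly). It reuses the same loss, parameters, and data, and only swaps the adtype to GalacticOptim.AutoForwardDiff(); the _fd variable names are just mine, to keep it separate from the failing ReverseDiff setup.

# Same problem, but differentiated with ForwardDiff instead of ReverseDiff.
# This runs to completion, it is just slow as the network grows.
adtype_fd = GalacticOptim.AutoForwardDiff()
optf_fd = GalacticOptim.OptimizationFunction(loss, adtype_fd)
optprob_fd = GalacticOptim.OptimizationProblem(optf_fd, vec_pars, data)
result_fd = GalacticOptim.solve(optprob_fd, ADAM(0.001), maxiters=1)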