ReverseDiff with Flux model

Hi all, I am trying to get the gradient of a Flux model with a package other than Zygote or ForwardDiff (Zygote cannot get through my real loss function, and ForwardDiff is too slow). Since my loss function returns a scalar value, it occurred to me that ReverseDiff could be the way to go. Unfortunately, I am not able to get the gradient with it. MWE:

# create the Flux model
using Flux
model = Chain(Conv((3,3), 4=>2; pad=div(3, 2)), BatchNorm(2), x->relu.(x))
# model input
inp = rand(Float32, 7,7,4,1)
# get explicit parameters of the model
using Optimisers
θ, nn = Optimisers.destructure(model)
# define loss function
using Statistics
lossf(pars, net, inp, data) = mean((net(pars)(inp).-data).^2)
# generate data
D = rand(7,7,2,1)
# get the gradient
using ReverseDiff
g = ReverseDiff.gradient((pars)->lossf(pars, nn, inp, D), θ)

Executing this code returns the following error:

ERROR: TaskFailedException

    nested task error: UndefRefError: access to undefined reference
    Stacktrace:
     [1] getindex
       @ ./array.jl:925 [inlined]
     [2] getindex
       @ ./subarray.jl:282 [inlined]
     [3] conv_direct!(y::SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, cdims::DenseConvDims{3, 3, 3, 6, 3}, ::Val{(3, 3, 1)}, ::Val{2}, ::Val{(1, 1, 1, 1, 0, 0)}, ::Val{(1, 1, 1)}, ::Val{(1, 1, 1)}, fk::Val{false}; alpha::ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, beta::Bool)
       @ NNlib ~/.julia/packages/NNlib/ydqxJ/src/impl/conv_direct.jl:111
     [4] conv_direct!(y::SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, cdims::DenseConvDims{3, 3, 3, 6, 3}; alpha::ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, beta::Bool)
       @ NNlib ~/.julia/packages/NNlib/ydqxJ/src/impl/conv_direct.jl:50
     [5] conv_direct!
       @ ~/.julia/packages/NNlib/ydqxJ/src/impl/conv_direct.jl:47 [inlined]
     [6] (::NNlib.var"#308#312"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, 3, 3, 6, 3}, SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
       @ NNlib ./threadingconstructs.jl:258
Stacktrace:
  [1] sync_end(c::Channel{Any})
    @ Base ./task.jl:436
  [2] macro expansion
    @ ./task.jl:455 [inlined]
  [3] conv!(out::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, in1::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, in2::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, cdims::DenseConvDims{3, 3, 3, 6, 3}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/ydqxJ/src/conv.jl:205
  [4] conv!
    @ ~/.julia/packages/NNlib/ydqxJ/src/conv.jl:185 [inlined]
  [5] conv!(y::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, x::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, w::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/ydqxJ/src/conv.jl:145
  [6] conv!
    @ ~/.julia/packages/NNlib/ydqxJ/src/conv.jl:140 [inlined]
  [7] conv(x::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, w::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/ydqxJ/src/conv.jl:88
  [8] conv
    @ ~/.julia/packages/NNlib/ydqxJ/src/conv.jl:83 [inlined]
  [9] (::Conv{2, 4, typeof(identity), Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}})(x::Array{Float32, 4})
    @ Flux ~/.julia/packages/Flux/uCLgc/src/layers/conv.jl:202
 [10] macro expansion
    @ ~/.julia/packages/Flux/uCLgc/src/layers/basic.jl:53 [inlined]
 [11] _applychain(layers::Tuple{Conv{2, 4, typeof(identity), Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}}, BatchNorm{typeof(identity), Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}, Float32, Vector{Float32}}, var"#29#30"}, x::Array{Float32, 4})
    @ Flux ~/.julia/packages/Flux/uCLgc/src/layers/basic.jl:53
 [12] (::Chain{Tuple{Conv{2, 4, typeof(identity), Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}}, BatchNorm{typeof(identity), Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}, Float32, Vector{Float32}}, var"#29#30"}})(x::Array{Float32, 4})
    @ Flux ~/.julia/packages/Flux/uCLgc/src/layers/basic.jl:51
 [13] lossf(pars::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, net::Optimisers.Restructure{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, var"#29#30"}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:σ, :weight, :bias, :stride, :pad, :dilation, :groups), Tuple{Tuple{}, Int64, Int64, Tuple{Tuple{}, Tuple{}}, NTuple{4, Tuple{}}, Tuple{Tuple{}, Tuple{}}, Tuple{}}}, NamedTuple{(:λ, :β, :γ, :μ, :σ², :ϵ, :momentum, :affine, :track_stats, :active, :chs), Tuple{Tuple{}, Int64, Int64, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Tuple{}}}, Tuple{}}}}}, inp::Array{Float32, 4}, data::Array{Float64, 4})
    @ Main ~/Documents/blindeblur/RegularizedSelfDeblur/src/UnetR2.jl:287
 [14] (::var"#41#42")(pars::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}})
    @ Main ~/Documents/blindeblur/RegularizedSelfDeblur/src/UnetR2.jl:290
 [15] ReverseDiff.GradientTape(f::var"#41#42", input::Vector{Float32}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/YkVxM/src/api/tape.jl:199
 [16] gradient(f::Function, input::Vector{Float32}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}) (repeats 2 times)
    @ ReverseDiff ~/.julia/packages/ReverseDiff/YkVxM/src/api/gradients.jl:22
 [17] top-level scope
    @ ~/Documents/blindeblur/RegularizedSelfDeblur/src/UnetR2.jl:290

Am I doing something wrong, or is it simply not possible to combine a Flux model with ReverseDiff? I found a question where Flux.destructure was the cause of the problem, but the suggested workaround no longer works. If it is not possible, can you suggest any other (preferably reverse-mode) AD package that could be used here?

Besides not working, anything that ends up with an Array{TrackedReal} is likely to be quite slow.

It might be possible to use Tracker instead: it is integrated with NNlib (which is where conv lives) and understands structured gradients, so no destructure is needed. But I'm not sure.
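Untested, but a minimal sketch of what the Tracker route might look like on the MWE above (whether Restructure keeps everything as TrackedArrays is exactly the open question):

```julia
using Flux, Optimisers, Tracker, Statistics

model = Chain(Conv((3,3), 4=>2; pad=1), BatchNorm(2), x -> relu.(x))
inp = rand(Float32, 7, 7, 4, 1)
D = rand(Float32, 7, 7, 2, 1)
θ, nn = Optimisers.destructure(model)

lossf(pars) = mean((nn(pars)(inp) .- D) .^ 2)

# Tracker.gradient tracks the parameter vector as a whole (a TrackedArray),
# which is what would avoid the per-element TrackedReal arrays that break
# NNlib's conv_direct!; Tracker.data unwraps the tracked result
g = Tracker.data(Tracker.gradient(lossf, θ)[1])
```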

Fixing this may be easier. Can you make a small example of the problem?

Just to get a better understanding of the problem, where does the Array{TrackedReal} come from?

Thanks, I will give it a try.

I probably could. The problem is that I need the gradient of a function that itself takes a gradient of a different function (with respect to a different set of parameters) and updates those parameters, which is done through Zygote and Optimisers. There is a try/catch statement inside that code which prevents Zygote from differentiating through it, so I was hoping I could get around it with a different AD package.
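For context, a stripped-down sketch of the nested structure described above (all names here are hypothetical stand-ins, not the real code):

```julia
using Zygote, Optimisers

# Hypothetical stand-ins for the real objectives:
inner_loss(p, q) = sum(abs2, p .- q)   # inner objective, differentiated by Zygote
outer_metric(p) = sum(abs2, p)         # outer objective

function inner_update(θ_inner, θ_outer)
    # inner gradient w.r.t. θ_inner; in the real code a try/catch around
    # this step is what stops Zygote from differentiating through it
    g = Zygote.gradient(p -> inner_loss(p, θ_outer), θ_inner)[1]
    state = Optimisers.setup(Optimisers.Descent(0.1f0), θ_inner)
    _, θ_new = Optimisers.update(state, θ_inner, g)
    return θ_new
end

θ_init  = rand(Float32, 8)
θ_outer = rand(Float32, 8)
outer_loss(θo) = outer_metric(inner_update(θ_init, θo))
# the goal: differentiate outer_loss w.r.t. θ_outer with a second AD package
```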