Hi all, I am trying to compute the gradient of a Flux model with an AD package other than Zygote or ForwardDiff (Zygote cannot differentiate through my real loss function, and ForwardDiff is too slow). Since my loss function returns a scalar, ReverseDiff seemed like a natural choice. Unfortunately, I cannot get a gradient out of it. MWE:
#create the Flux model
using Flux
# Conv(3×3, 4 => 2 channels) -> BatchNorm over 2 channels -> element-wise ReLU.
# `SamePad()` derives the padding from the kernel size, so it cannot drift out of sync
# with the (3, 3) literal the way a hand-computed `div(3, 2)` could; for a 3×3 kernel
# with stride 1 it produces the same padding of 1 on every side.
model = Chain(Conv((3, 3), 4 => 2; pad=SamePad()), BatchNorm(2), x -> relu.(x))
#model input (7×7 spatial, 4 channels, batch of 1), Float32 to match the layer weights
inp = rand(Float32, 7, 7, 4, 1)
#get explicit parameters of the model
using Optimisers
# θ is the flat vector of trainable parameters; nn rebuilds a model from such a vector
θ, nn = Optimisers.destructure(model)
#define loss function
using Statistics
"""
    lossf(pars, net, inp, data)

Return the mean squared error between the prediction of the model rebuilt from the
flat parameter vector `pars` and the target `data`.

`net` is called as `net(pars)` to obtain a callable model, which is then applied to
`inp`; the element-wise residual against `data` is squared and averaged with `mean`.
"""
function lossf(pars, net, inp, data)
    rebuilt = net(pars)
    residual = rebuilt(inp) .- data
    return mean(residual .^ 2)
end
#generate data
# Generate the target in Float32 to match the model and input. A plain
# `rand(7, 7, 2, 1)` is Float64, which would promote the residual, and hence the loss,
# to Float64.
D = rand(Float32, 7, 7, 2, 1)
#get the gradient
using ReverseDiff
# The closure fixes the model builder, input and target so that only the flat
# parameter vector θ is tracked by ReverseDiff.
# NOTE(review): with the Conv layer this currently throws an UndefRefError inside
# NNlib's conv_direct!, where the rebuilt Conv holds Array{TrackedReal} weights rather
# than a TrackedArray — TODO confirm whether a ReverseDiff-compatible conv path exists.
g = ReverseDiff.gradient(pars -> lossf(pars, nn, inp, D), θ)
Executing this code throws the following error:
ERROR: TaskFailedException
nested task error: UndefRefError: access to undefined reference
Stacktrace:
[1] getindex
@ ./array.jl:925 [inlined]
[2] getindex
@ ./subarray.jl:282 [inlined]
[3] conv_direct!(y::SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, cdims::DenseConvDims{3, 3, 3, 6, 3}, ::Val{(3, 3, 1)}, ::Val{2}, ::Val{(1, 1, 1, 1, 0, 0)}, ::Val{(1, 1, 1)}, ::Val{(1, 1, 1)}, fk::Val{false}; alpha::ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, beta::Bool)
@ NNlib ~/.julia/packages/NNlib/ydqxJ/src/impl/conv_direct.jl:111
[4] conv_direct!(y::SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, cdims::DenseConvDims{3, 3, 3, 6, 3}; alpha::ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, beta::Bool)
@ NNlib ~/.julia/packages/NNlib/ydqxJ/src/impl/conv_direct.jl:50
[5] conv_direct!
@ ~/.julia/packages/NNlib/ydqxJ/src/impl/conv_direct.jl:47 [inlined]
[6] (::NNlib.var"#308#312"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, 3, 3, 6, 3}, SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, SubArray{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5, Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
@ NNlib ./threadingconstructs.jl:258
Stacktrace:
[1] sync_end(c::Channel{Any})
@ Base ./task.jl:436
[2] macro expansion
@ ./task.jl:455 [inlined]
[3] conv!(out::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, in1::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, in2::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 5}, cdims::DenseConvDims{3, 3, 3, 6, 3}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NNlib ~/.julia/packages/NNlib/ydqxJ/src/conv.jl:205
[4] conv!
@ ~/.julia/packages/NNlib/ydqxJ/src/conv.jl:185 [inlined]
[5] conv!(y::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, x::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, w::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NNlib ~/.julia/packages/NNlib/ydqxJ/src/conv.jl:145
[6] conv!
@ ~/.julia/packages/NNlib/ydqxJ/src/conv.jl:140 [inlined]
[7] conv(x::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, w::Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NNlib ~/.julia/packages/NNlib/ydqxJ/src/conv.jl:88
[8] conv
@ ~/.julia/packages/NNlib/ydqxJ/src/conv.jl:83 [inlined]
[9] (::Conv{2, 4, typeof(identity), Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}})(x::Array{Float32, 4})
@ Flux ~/.julia/packages/Flux/uCLgc/src/layers/conv.jl:202
[10] macro expansion
@ ~/.julia/packages/Flux/uCLgc/src/layers/basic.jl:53 [inlined]
[11] _applychain(layers::Tuple{Conv{2, 4, typeof(identity), Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}}, BatchNorm{typeof(identity), Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}, Float32, Vector{Float32}}, var"#29#30"}, x::Array{Float32, 4})
@ Flux ~/.julia/packages/Flux/uCLgc/src/layers/basic.jl:53
[12] (::Chain{Tuple{Conv{2, 4, typeof(identity), Array{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 4, Array{Float32, 4}, Array{Float32, 4}}}, 4}, Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}}, BatchNorm{typeof(identity), Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}, Float32, Vector{Float32}}, var"#29#30"}})(x::Array{Float32, 4})
@ Flux ~/.julia/packages/Flux/uCLgc/src/layers/basic.jl:51
[13] lossf(pars::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, net::Optimisers.Restructure{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, var"#29#30"}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:σ, :weight, :bias, :stride, :pad, :dilation, :groups), Tuple{Tuple{}, Int64, Int64, Tuple{Tuple{}, Tuple{}}, NTuple{4, Tuple{}}, Tuple{Tuple{}, Tuple{}}, Tuple{}}}, NamedTuple{(:λ, :β, :γ, :μ, :σ², :ϵ, :momentum, :affine, :track_stats, :active, :chs), Tuple{Tuple{}, Int64, Int64, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Tuple{}}}, Tuple{}}}}}, inp::Array{Float32, 4}, data::Array{Float64, 4})
@ Main ~/Documents/blindeblur/RegularizedSelfDeblur/src/UnetR2.jl:287
[14] (::var"#41#42")(pars::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}})
@ Main ~/Documents/blindeblur/RegularizedSelfDeblur/src/UnetR2.jl:290
[15] ReverseDiff.GradientTape(f::var"#41#42", input::Vector{Float32}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}})
@ ReverseDiff ~/.julia/packages/ReverseDiff/YkVxM/src/api/tape.jl:199
[16] gradient(f::Function, input::Vector{Float32}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}) (repeats 2 times)
@ ReverseDiff ~/.julia/packages/ReverseDiff/YkVxM/src/api/gradients.jl:22
[17] top-level scope
@ ~/Documents/blindeblur/RegularizedSelfDeblur/src/UnetR2.jl:290
Am I doing something wrong, or is it simply not possible to combine a Flux model with ReverseDiff? I found a question where `Flux.destructure`
was the cause of the problem, but the suggested workaround no longer works. If it is not possible, can you suggest another (preferably reverse-mode) AD package that could be used instead?