Flux.update! not working with custom AbstractArray

When building a model using custom structs that are subtypes of AbstractArray, Flux.update! cannot update the model. For my use case, it is particularly convenient to do computations using these custom structs.

Is there a simple way to get this to work such that I can maintain the functionality that comes from the AbstractArray supertype?

Here is a minimal example.

using Flux
struct PositiveVector{T} <: AbstractArray{T,1}
    x::Vector{T}   # unconstrained parameters; the array's values are exp.(x)
end
(v::PositiveVector)() = exp.(v.x)
Base.size(v::PositiveVector) = size(v.x)
Base.IndexStyle(::Type{<:PositiveVector}) = IndexLinear()
Base.getindex(v::PositiveVector,i::Int) = exp(v.x[i])
@Flux.functor PositiveVector (x,)

m = PositiveVector([.1,.2])

loss(m) = sum(([1.,2.] - m) .^ 2.)
opt_state = Flux.setup(Adam(), m)
grads = Flux.gradient(loss,m)
Flux.update!(opt_state, m, grads[1])   # this is the call that throws

Running this code with Julia v1.9.4 and Flux v0.14.7 gives:

ERROR: type Array has no field x
 [1] getproperty
   @ ./Base.jl:37 [inlined]
 [2] functor(#unused#::Type{PositiveVector{Float64}}, x::Vector{Float64})
   @ Main ~/.julia/packages/Functors/rlD70/src/functor.jl:38
 [3] (::Optimisers.var"#13#15"{PositiveVector{Float64}})(x̄::Vector{Float64})
   @ Optimisers ~/.julia/packages/Optimisers/NnLqJ/src/interface.jl:116
 [4] map
   @ ./tuple.jl:273 [inlined]
 [5] _grads!(dict::IdDict{Optimisers.Leaf, Any}, tree::NamedTuple{(:x,), Tuple{Optimisers.Leaf{Optimisers.Adam, Tuple{Vector{Float64}, Vector{Float64}, Tuple{Float64, Float64}}}}}, x::PositiveVector{Float64}, x̄s::Vector{Float64})
   @ Optimisers ~/.julia/packages/Optimisers/NnLqJ/src/interface.jl:116
 [6] update!(::NamedTuple{(:x,), Tuple{Optimisers.Leaf{Optimisers.Adam, Tuple{Vector{Float64}, Vector{Float64}, Tuple{Float64, Float64}}}}}, ::PositiveVector{Float64}, ::Vector{Float64})
   @ Optimisers ~/.julia/packages/Optimisers/NnLqJ/src/interface.jl:74
 [7] top-level scope

If I don’t use the AbstractArray functionality and call the PositiveVector in the loss function (so it maps to a Vector{T} before computing the loss), then things work:

loss(m) = sum(([1.,2.] - m()) .^ 2.)
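For reference, a complete, self-contained version of this working variant. It assumes the struct stores its unconstrained parameters in a field `x::Vector{T}`, as the methods above imply. Because `loss` only reaches `m` through `m.x`, Zygote returns a structural gradient with a field `x`, which matches the tree built by `Flux.setup`, so `Flux.update!` goes through.

```julia
using Flux

# PositiveVector as in the question: the field x holds unconstrained
# parameters and the array's values are exp.(x).
struct PositiveVector{T} <: AbstractArray{T,1}
    x::Vector{T}
end
(v::PositiveVector)() = exp.(v.x)
Base.size(v::PositiveVector) = size(v.x)
Base.IndexStyle(::Type{<:PositiveVector}) = IndexLinear()
Base.getindex(v::PositiveVector, i::Int) = exp(v.x[i])
Flux.@functor PositiveVector (x,)

m = PositiveVector([0.1, 0.2])
loss(m) = sum(([1.0, 2.0] - m()) .^ 2)   # goes through m.x, not the array interface

opt_state = Flux.setup(Adam(), m)
grads = Flux.gradient(loss, m)            # structural gradient with a field x
opt_state, m = Flux.update!(opt_state, m, grads[1])
```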

However, then I am effectively losing the functionality of defining a custom type of AbstractArray.

This probably isn’t a great idea, but it can be made to work. The problem is a mismatch between how Zygote.jl and Functors.jl (which update! uses to walk the model) think about this type:

julia> loss(m) = sum(([1.,2.] - m) .^ 2.0)  # relies on m::AbstractArray
loss (generic function with 1 method)

julia> opt_state = Flux.setup(Adam(), m)  # sees field, due to @functor PositiveVector
(x = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), ([0.0, 0.0], [0.0, 0.0], (0.9, 0.999))),)

julia> grads = Zygote.gradient(loss,m)  # Zygote regards AbstractArray{Float64} as primal, "natural" gradient
([0.21034183615129542, -1.5571944836796603],)

julia> Flux.update!(opt_state,m,grads[1])  # structures don't match
ERROR: type Array has no field x

Here update! expects a “structural” gradient like (x = [...],), because it walks opt_state and m recursively, field by field. Instead it receives a “natural” one: just an array with one entry per element of m.

The same problem might arise with wrappers like Adjoint, but in fact this works:

julia> opt_state = Flux.setup(Adam(), [1.0 2.0]')
(parent = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), ([0.0 0.0], [0.0 0.0], (0.9, 0.999))),)

julia> grads = Flux.gradient(loss,  [1.0 2.0]')
([-0.0; -0.0;;],)

julia> Flux.update!(opt_state, [1.0 2.0]', grads[1])
((parent = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), ([0.0 0.0], [0.0 0.0], (0.81, 0.998001))),), [1.0; 2.0;;])

This is because the functor definitions for Adjoint alter how the recursive walk used by update! works: they apply adjoint to the “natural” gradient, converting the gradient for y = [1.0 2.0]' into the one for y.parent. You could write similar rules for your type.
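One caveat when doing this for PositiveVector: unlike adjoint, converting its natural gradient needs the primal values exp.(m.x), not just the gradient. So a possibly simpler route is to do the conversion by hand before calling update!. In this sketch `to_structural` is a hypothetical helper, not part of Flux or Optimisers; it applies the chain rule dL/dx[i] = dL/dv[i] * exp(x[i]) and returns a NamedTuple shaped like the `(x = ...,)` tree that `Flux.setup` built.

```julia
using Flux

struct PositiveVector{T} <: AbstractArray{T,1}
    x::Vector{T}   # unconstrained parameters
end
(v::PositiveVector)() = exp.(v.x)
Base.size(v::PositiveVector) = size(v.x)
Base.IndexStyle(::Type{<:PositiveVector}) = IndexLinear()
Base.getindex(v::PositiveVector, i::Int) = exp(v.x[i])
Flux.@functor PositiveVector (x,)

# Hypothetical helper: natural gradient dv with respect to the values
# v[i] = exp(x[i])  ->  structural gradient with respect to the field x.
to_structural(v::PositiveVector, dv::AbstractVector) = (x = dv .* exp.(v.x),)

m = PositiveVector([0.1, 0.2])
loss(m) = sum(([1.0, 2.0] - m) .^ 2)   # uses m through the AbstractArray interface

opt_state = Flux.setup(Adam(), m)
grads = Flux.gradient(loss, m)          # natural gradient: a plain vector
opt_state, m = Flux.update!(opt_state, m, to_structural(m, grads[1]))
```

The loss keeps using m as an AbstractArray; only the gradient handed to update! changes shape.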

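However, I think it will be simpler to call m() explicitly (and perhaps not subtype AbstractArray at all). For instance, I think this will work with GPU arrays, since exp.(v.x) is a single broadcast, while the approach of defining getindex for your type will not: generic AbstractArray code then reads it one element at a time (scalar indexing), which GPU arrays disallow or make very slow.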
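A minimal sketch of the "don't subtype AbstractArray" route. `Positive` is a hypothetical name; the field type is a parameter `A` so that the storage could in principle be another array type (GPU arrays are not tested here). Since `Positive` is a plain struct, Zygote always returns a structural gradient for it, and the only computation on the parameters is a broadcast. The training loop sits inside a function to avoid Julia's global soft-scope rules in scripts.

```julia
using Flux

# Callable wrapper that does NOT subtype AbstractArray.
struct Positive{A<:AbstractArray}
    x::A   # unconstrained parameters
end
(p::Positive)() = exp.(p.x)   # one broadcast, no scalar getindex
Flux.@functor Positive

loss(p) = sum(abs2, [1.0, 2.0] .- p())

function train(p, opt_state; steps = 1000)
    for _ in 1:steps
        g = Flux.gradient(loss, p)[1]            # structural gradient: (x = [...],)
        opt_state, p = Flux.update!(opt_state, p, g)
    end
    return p, opt_state
end

p0 = Positive([0.1, 0.2])
p, opt_state = train(p0, Flux.setup(Adam(0.01), p0))
```

Here p() should end up close to [1.0, 2.0], and its values are positive by construction.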

If the idea is to use these arrays inside existing layers, then a third approach would be to move the positivity constraint into an Optimisers.jl rule, which you compose with Adam (e.g. via OptimiserChain), a bit like ClipGrad.
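A rough sketch of that third approach, assuming Optimisers.jl's interface for custom rules (a subtype of `Optimisers.AbstractRule` with `init` and `apply!` methods, where `update!` then computes `x .- dx`). `KeepPositive`, its `eps` field and `fit!` are hypothetical names. Note that this swaps in a different technique from PositiveVector: instead of the exp reparametrisation, it keeps the stored parameters themselves at or above `eps` by shrinking each step, so the parameters can stay ordinary arrays inside existing layers. The toy gradient is written out by hand to keep the example independent of Flux.

```julia
import Optimisers

# Hypothetical rule: after the previous rule in the chain has produced a step dx,
# cap it so that the updated value x .- dx stays >= eps.
struct KeepPositive{T<:Real} <: Optimisers.AbstractRule
    eps::T
end

Optimisers.init(::KeepPositive, x::AbstractArray) = nothing   # stateless

function Optimisers.apply!(o::KeepPositive, state, x, dx)
    dx = Broadcast.materialize(dx)          # earlier rules may return a lazy broadcast
    return state, min.(dx, x .- o.eps)      # guarantees x .- dx >= eps
end

function fit!(x, target; steps = 200)
    rule = Optimisers.OptimiserChain(Optimisers.Adam(0.1), KeepPositive(1e-3))
    state = Optimisers.setup(rule, x)
    for _ in 1:steps
        g = 2 .* (x .- target)              # gradient of sum(abs2, x .- target)
        state, x = Optimisers.update!(state, x, g)
    end
    return x
end

# The unconstrained optimum of x[1] would be -1.0; the rule stops it at eps.
x = fit!([0.5, 2.0], [-1.0, 3.0])
```

The rule goes after Adam in the chain because it must see the final step, whereas ClipGrad usually goes before Adam because it acts on the raw gradient.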
