When I build a model from custom structs that subtype `AbstractArray`, `Flux.update!` fails to update the model. For my use case, it is particularly convenient to do computations with these custom structs.
Is there a simple way to make this work while keeping the functionality that comes from the `AbstractArray` supertype?
Here is a minimal example.
```julia
using Flux

# Parameters are stored unconstrained in `x`; the array's elements are exp.(x) > 0
struct PositiveVector{T} <: AbstractArray{T,1}
    x::Vector{T}
end
(v::PositiveVector)() = exp.(v.x)
Base.size(v::PositiveVector) = size(v.x)
Base.IndexStyle(::Type{<:PositiveVector}) = IndexLinear()
Base.getindex(v::PositiveVector, i::Int) = exp(v.x[i])
@Flux.functor PositiveVector (x,)   # only the field x is trainable

m = PositiveVector([.1, .2])
loss(m) = sum(([1., 2.] - m) .^ 2.)  # treats m as an AbstractArray
opt_state = Flux.setup(Adam(), m)
grads = Flux.gradient(loss, m)
Flux.update!(opt_state, m, grads[1]) # throws the error below
```
Running this code with Julia v1.9.4 and Flux v0.14.7 gives:
```
ERROR: type Array has no field x
Stacktrace:
 [1] getproperty
   @ ./Base.jl:37 [inlined]
 [2] functor(#unused#::Type{PositiveVector{Float64}}, x::Vector{Float64})
   @ Main ~/.julia/packages/Functors/rlD70/src/functor.jl:38
 [3] (::Optimisers.var"#13#15"{PositiveVector{Float64}})(x̄::Vector{Float64})
   @ Optimisers ~/.julia/packages/Optimisers/NnLqJ/src/interface.jl:116
 [4] map
   @ ./tuple.jl:273 [inlined]
 [5] _grads!(dict::IdDict{Optimisers.Leaf, Any}, tree::NamedTuple{(:x,), Tuple{Optimisers.Leaf{Optimisers.Adam, Tuple{Vector{Float64}, Vector{Float64}, Tuple{Float64, Float64}}}}}, x::PositiveVector{Float64}, x̄s::Vector{Float64})
   @ Optimisers ~/.julia/packages/Optimisers/NnLqJ/src/interface.jl:116
 [6] update!(::NamedTuple{(:x,), Tuple{Optimisers.Leaf{Optimisers.Adam, Tuple{Vector{Float64}, Vector{Float64}, Tuple{Float64, Float64}}}}}, ::PositiveVector{Float64}, ::Vector{Float64})
   @ Optimisers ~/.julia/packages/Optimisers/NnLqJ/src/interface.jl:74
 [7] top-level scope
```
If I skip the `AbstractArray` interface and instead call the `PositiveVector` inside the loss function (so it is converted to a `Vector{T}` before the loss is computed), everything works:
```julia
loss(m) = sum(([1., 2.] - m()) .^ 2.)
```
However, this effectively gives up the benefit of defining `PositiveVector` as a custom `AbstractArray` subtype.
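A minimal sketch of one possible workaround, assuming (as the stack trace suggests) that `Flux.gradient(loss, m)` returns the gradient with respect to the stored values `m[i] = exp(x[i])` as a plain `Vector`, while `Flux.setup` builds a state tree shaped like the functor, `(x = Leaf(...),)`. That mismatch is why Optimisers ends up calling `x.x` on a `Vector`. The helper `to_structural` below is hypothetical (not part of Flux): it applies the chain rule to turn `dL/dm` into the structural gradient `(x = dL/dx,)` before calling `update!`, so `loss` can keep using `m` as an `AbstractArray`.

```julia
using Flux

# Same type as in the question: elements are exp.(x), so they are always positive
struct PositiveVector{T} <: AbstractArray{T,1}
    x::Vector{T}
end
(v::PositiveVector)() = exp.(v.x)
Base.size(v::PositiveVector) = size(v.x)
Base.IndexStyle(::Type{<:PositiveVector}) = IndexLinear()
Base.getindex(v::PositiveVector, i::Int) = exp(v.x[i])
@Flux.functor PositiveVector (x,)

# Hypothetical helper (not a Flux API): convert the array-shaped gradient dL/dm,
# taken with respect to the stored values m[i] = exp(x[i]), into the structural
# gradient (x = dL/dx,) that matches the optimiser state built by Flux.setup.
# Chain rule: dL/dx[i] = dL/dm[i] * exp(x[i]) = dL/dm[i] * m[i]
to_structural(v::PositiveVector, dm::AbstractVector) = (x = dm .* v(),)

loss(m) = sum(([1.0, 2.0] - m) .^ 2)   # still uses m as an AbstractArray

m = PositiveVector([0.1, 0.2])
opt_state = Flux.setup(Adam(), m)       # state tree shaped like (x = Leaf(...),)
grads = Flux.gradient(loss, m)          # grads[1] is a plain Vector, i.e. dL/dm
Flux.update!(opt_state, m, to_structural(m, grads[1]))  # mutates m.x in place
```

The trade-off is that the derivative of the `exp` reparametrisation is written by hand in the helper, so it has to be kept consistent with `getindex`.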