Zygote and StructArrays

It seems Flux.@functor does not get along with StructArrays?

using StructArrays, Flux

struct A{SA}
	s::SA
end

struct Point
	x::Float64
	y::Float64
end

Flux.@functor A
Flux.@functor Point

a = A(StructArray((Point(i,2i) for i=0.0:10.0)))
Flux.params(a) # Params([])

I was expecting Flux.params(a) to detect the two parameter vectors a.s.x and a.s.y contained in the StructArray.

Any suggestions on how I can make this work?

This is a very interesting question.

I ran some tests and found that the problem is the <:Number element-type constraint on the array method of Flux.params! in Flux.

If we remove the <:Number and define a method like this:

julia> Flux.params!(p::Flux.Params, x::AbstractArray, seen = IdSet()) = push!(p, x)

Then we can get:

julia> Flux.params(a)
Params([Point[Point(0.0, 0.0), Point(1.0, 2.0), Point(2.0, 4.0), Point(3.0, 6.0), Point(4.0, 8.0), Point(5.0, 10.0), Point(6.0, 12.0), Point(7.0, 14.0), Point(8.0, 16.0), Point(9.0, 18.0), Point(10.0, 20.0)]])

If this is what you want, maybe create an issue in Flux.jl?
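To see why the <:Number constraint matters here, the following is a minimal sketch using only StructArrays (no Flux). It assumes the same Point struct as in the question; the variable names whole_matches, x_matches and y_matches are only for illustration.

```julia
using StructArrays

struct Point
    x::Float64
    y::Float64
end

s = StructArray((Point(i, 2i) for i = 0.0:10.0))

# The element type is Point, which is not a Number, so a method restricted to
# AbstractArray{<:Number} cannot apply to the StructArray itself.
whole_matches = s isa AbstractArray{<:Number}

# The component arrays are ordinary Float64 vectors and do satisfy the constraint.
x_matches = s.x isa AbstractArray{<:Number}
y_matches = s.y isa AbstractArray{<:Number}
```

So relaxing the constraint makes the whole array of Points a single parameter, which is what the output above shows, rather than exposing s.x and s.y separately.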

It doesn’t seem to work very well.

using StructArrays, Flux, Zygote

Flux.params!(p::Flux.Params, x::AbstractArray, seen = Zygote.IdSet()) = push!(p, x)

struct A{SA}
	s::SA
end

struct Point
	x::Float64
	y::Float64
end

Flux.@functor A
Flux.@functor Point

a = A(StructArray((Point(i,2i) for i=0.0:10.0)))

function fun(a::A)
	sum(a.s.x) + 2sum(a.s.y)
end

ps = Flux.params(a)
gs = gradient(ps) do
	fun(a)
end

julia> gs[a.s]
(fieldarrays = nothing, y = [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0])

For some reason a.s.x is being ignored.

Now I see. Sorry, I wasn't very familiar with StructArrays before.

You have defined the following functors:

Flux.@functor A
Flux.@functor Point

But the one for StructArray is missing. You can define it like this:

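# Build the StructArray first so that its concrete type is available
b = StructArray((Point(i,2i) for i=0.0:10.0))
# Register that concrete type as a functor, restricted to the fields x and y
Flux.@functor typeof(b) (x, y)
a = A(b)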

It should be possible to define a general functor for StructArrays instead of one per concrete type. You may take it as an exercise.


Is there no way to define a generic functor for a StructArray containing any element struct?

I tried the following:

function Flux.trainable(s::StructArray)
    getproperty.(Ref(s), propertynames(s))
end
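As a quick check of what that broadcast expression returns, here is a small sketch using only StructArrays (Flux is not needed for this part). The names colnames and comps are hypothetical, chosen for illustration.

```julia
using StructArrays

struct Point
    x::Float64
    y::Float64
end

s = StructArray([Point(randn(), randn()) for i = 1:3, j = 1:2])

# Property names of a StructArray are the field names of its element type.
colnames = propertynames(s)

# Broadcasting getproperty over that tuple (with s held fixed by Ref)
# yields a tuple holding the component arrays themselves, not copies.
comps = getproperty.(Ref(s), colnames)
```

Because the tuple holds the same array objects as s.x and s.y, anything that collects parameters from it sees the arrays actually stored inside the StructArray.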

With this Flux.trainable definition the params are populated correctly, but I still get an error when trying to differentiate. Here is an example:

using Flux, StructArrays
struct Point
    x::Float64
    y::Float64
end
pnts = StructArray([Point(randn(), randn()) for i = 1:10, j = 1:10])
fun(p::Point, v::Real) = (p.x^2 + p.y^2) * v
V = randn(10,10,10)
ps = Flux.params(pnts)
gs = gradient(ps) do
    sum(fun.(pnts, V))
end

This results in an error.

ERROR: MethodError: no method matching reducedim_init(::typeof(identity), ::typeof(Zygote.accum), ::Array{NamedTuple{(:x, :y),Tuple{Float64,Float64}},3}, ::Tuple{Int64,Int64,Int64})
Closest candidates are:
  reducedim_init(::Any, ::Union{typeof(+), typeof(Base.add_sum)}, ::AbstractArray, ::Any) at reducedim.jl:109
  reducedim_init(::Any, ::Union{typeof(*), typeof(Base.mul_prod)}, ::AbstractArray, ::Any) at reducedim.jl:112
  reducedim_init(::Any, ::typeof(min), ::AbstractArray, ::Any) at reducedim.jl:132
  ...
Stacktrace:
 [1] _mapreduce_dim(::Function, ::Function, ::NamedTuple{(),Tuple{}}, ::Array{NamedTuple{(:x, :y),Tuple{Float64,Float64}},3}, ::Tuple{Int64,Int64,Int64}) at ./reducedim.jl:317
 [2] #mapreduce#580 at ./reducedim.jl:307 [inlined]
 [3] #reduce#582 at ./reducedim.jl:352 [inlined]
 [4] accum_sum(::Array{NamedTuple{(:x, :y),Tuple{Float64,Float64}},3}; dims::Tuple{Int64,Int64,Int64}) at /home/cossio/.julia/packages/Zygote/4tJp5/src/lib/broadcast.jl:36
 [5] unbroadcast at /home/cossio/.julia/packages/Zygote/4tJp5/src/lib/broadcast.jl:48 [inlined]
 [6] map(::typeof(Zygote.unbroadcast), ::Tuple{StructArray{Point,2,NamedTuple{(:x, :y),Tuple{Array{Float64,2},Array{Float64,2}}},Int64},Array{Float64,3}}, ::Tuple{Array{NamedTuple{(:x, :y),Tuple{Float64,Float64}},3},Array{Float64,3}}) at ./tuple.jl:177
 [7] (::Zygote.var"#1734#1741"{Tuple{StructArray{Point,2,NamedTuple{(:x, :y),Tuple{Array{Float64,2},Array{Float64,2}}},Int64},Array{Float64,3}},Val{3},Array{typeof(∂(fun)),3}})(::FillArrays.Fill{Float64,3,Tuple{Base.OneTo{Int64},Base.OneTo{Int64},Base.OneTo{Int64}}}) at /home/cossio/.julia/packages/Zygote/4tJp5/src/lib/broadcast.jl:139
 [8] #4367#back at /home/cossio/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [9] (::Zygote.var"#175#176"{Zygote.var"#4367#back#1745"{Zygote.var"#1734#1741"{Tuple{StructArray{Point,2,NamedTuple{(:x, :y),Tuple{Array{Float64,2},Array{Float64,2}}},Int64},Array{Float64,3}},Val{3},Array{typeof(∂(fun)),3}}},Tuple{NTuple{4,Nothing},Tuple{}}})(::FillArrays.Fill{Float64,3,Tuple{Base.OneTo{Int64},Base.OneTo{Int64},Base.OneTo{Int64}}}) at /home/cossio/.julia/packages/Zygote/4tJp5/src/lib/lib.jl:170
 [10] #344#back at /home/cossio/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [11] broadcasted at ./broadcast.jl:1238 [inlined]
 [12] #5 at ./REPL[9]:2 [inlined]
 [13] (::typeof(∂(#5)))(::Float64) at /home/cossio/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [14] (::Zygote.var"#48#49"{Zygote.Params,Zygote.Context,typeof(∂(#5))})(::Float64) at /home/cossio/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:108
 [15] gradient(::Function, ::Zygote.Params) at /home/cossio/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:45
 [16] top-level scope at REPL[9]:1

Issue: https://github.com/FluxML/Zygote.jl/issues/594
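Reading the stack trace, the failure comes from a dims-reduction over an array of NamedTuples with an operator (Zygote.accum) for which Base has no reducedim_init method. The sketch below tries to mimic that in plain Julia, without Zygote. Note that accum_nt is a hypothetical stand-in for Zygote.accum, not the real function, and whether the first reduce throws may depend on the Julia version, so it is wrapped in try/catch.

```julia
# Hypothetical stand-in for Zygote.accum on NamedTuple gradients:
# adds two NamedTuples field by field.
accum_nt(a::NamedTuple, b::NamedTuple) = map(+, a, b)

grads = [(x = Float64(i), y = 2.0 * j) for i in 1:2, j in 1:3]

# Without `init`, Base has to pick a starting value for an arbitrary operator
# through reducedim_init; this is the step that failed in the trace above.
threw_method_error = try
    reduce(accum_nt, grads; dims = (1, 2))
    false
catch err
    err isa MethodError
end

# Supplying an explicit `init` avoids the reducedim_init lookup entirely.
total = reduce(accum_nt, grads; dims = (1, 2), init = (x = 0.0, y = 0.0))
```

This only illustrates the shape of the problem; it is not how Zygote or ZygoteStructArrays handles it.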

I have created a small package that helps solve some of these issues: https://github.com/cossio/ZygoteStructArrays.jl
