Flux.destructure a Flux.Chain that takes an argument?

Calling destructure on a function is what won't work, and probably can't: destructure has no way to know what the function returns. Even if it could see which arrays the function closes over, it could not build a re that constructs another instance of the function with different parameters.
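A small Flux-free illustration of the closure problem. Julia lowers a closure to a hidden struct whose fields are the captured variables, so the arrays are visible, but nothing describes how to construct a new closure around different arrays. The name make_f is made up for this sketch.

```julia
# A function that returns a closure; the closure captures the array `w`.
make_f(w) = x -> sum(w .* x)

f = make_f([1.0, 2.0])

# The closure's hidden type has one field per captured variable.
fieldnames(typeof(f))    # (:w,)
getfield(f, :w)          # the captured array [1.0, 2.0]
f([1.0, 1.0])            # 3.0
```

Seeing the field is the easy half; destructure would also need a way to rebuild f with a new w, and a closure type offers no public constructor for that.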

For destructure to look inside an object, that object has to be a struct whose type has been marked with @functor.
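A heavily simplified, hypothetical sketch of the idea behind destructure may help here. It is not Flux's implementation, which works generically through Functors.jl for any @functor type. It flattens the arrays stored in a struct's fields into one vector and rebuilds by calling the struct's constructor. The names Affine and toy_destructure are invented for this example.

```julia
# Hypothetical two-field layer; not a Flux type.
struct Affine
    W::Matrix{Float64}
    b::Vector{Float64}
end

# Toy destructure: flatten the fields into one vector and return a function
# that rebuilds an Affine from any vector of the same length.
function toy_destructure(a::Affine)
    szW = size(a.W)
    nW = prod(szW)
    flat = vcat(vec(a.W), a.b)
    # Rebuilding works only because we know the constructor Affine(W, b).
    rebuild(v) = Affine(reshape(v[1:nW], szW), v[nW+1:end])
    return flat, rebuild
end

a = Affine([1.0 2.0; 3.0 4.0], [5.0, 6.0])
flat, re = toy_destructure(a)
flat                # [1.0, 3.0, 2.0, 4.0, 5.0, 6.0] (column-major order)
a2 = re(zeros(6))   # an Affine with all-zero W and b
```

Marking a type with @functor gives Functors.jl exactly the two things this toy hard-codes: which fields to walk into, and how to call the constructor again.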

julia> Flux.destructure(cust_chain)  # empty
(Bool[], Restructure(#cust_chain, ..., 0))

julia> methods(typeof(cust_chain)) # this is why Functors cannot reconstruct this
# 0 methods for type constructor

julia> struct Maker; pre; post; end

julia> (m::Maker)(v) = Flux.Chain(m.pre, Flux.Scale(v, false), m.post)

julia> m = Maker(Flux.Dense(10 => 10, tanh), Flux.Dense(10 => 1));

julia> m_im = m(fill(im, 10))  # here Scale can be explored, hence 131
Chain(
  Dense(10 => 10, tanh),                # 110 parameters
  Scale(10; bias=false),                # 10 parameters
  Dense(10 => 1),                       # 11 parameters
)                   # Total: 5 arrays, 131 parameters, 848 bytes.

julia> m_im(rand(10))
1-element Vector{ComplexF64}:
 0.0 - 0.47595892328291134im

julia> Flux.destructure(m)  # still empty
(Bool[], Restructure(Maker, ..., 0))
 
julia> Flux.@functor Maker  # only now will destructure be able to look inside

julia> p, re = Flux.destructure(m);  # 121 parameters, v does not exist yet

julia> re(rand(121))
Maker(Dense(10 => 10, tanh), Dense(10 => 1))

julia> re(ones(121))(pi)(ones(10,3))
1×3 Matrix{Float64}:
 32.4159  32.4159  32.4159

julia> (m::Maker)(v, x) = Flux.Chain(m.pre, y -> y .* v, m.post)(x)  # two steps in one, tidier?

julia> re(ones(121))(pi, ones(10,3))
1×3 Matrix{Float64}:
 32.4159  32.4159  32.4159
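For anyone without Flux at hand, here is a hypothetical pure-Julia analogue of the two-argument Maker. The trainable arrays live in fields, while the scale v is an ordinary call argument and so never counts as a parameter. TinyMaker is an invented name. It has no biases, so its numbers differ from the REPL output above: with all-ones weights the result is close to 10π rather than 32.4159.

```julia
# Hypothetical analogue of Maker: weights are fields, `v` is a call argument.
struct TinyMaker
    W1::Matrix{Float64}
    W2::Matrix{Float64}
end

# Same shape as (m::Maker)(v, x): hidden layer, elementwise scale by v, output layer.
(m::TinyMaker)(v, x) = m.W2 * (tanh.(m.W1 * x) .* v)

m = TinyMaker(ones(10, 10), ones(1, 10))
out = m(pi, ones(10, 3))                    # 1×3 matrix, each entry ≈ 10π
nparams = length(m.W1) + length(m.W2)      # 110, independent of v
```

Because v only appears as an argument, changing it per call never changes what a flatten-and-rebuild step has to handle.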