Flux.destructure a Flux.Chain that takes an argument?

Hello all, I have a use case where I need to use some custom layers in a Flux.Chain that depend on the initial vector passed to the chain (which changes). I can construct the chain easily as shown below (with some dummy layers just for illustrative purposes), but I can’t call Flux.destructure on the chain:

import Flux
n = 10
L1 = Flux.Dense(n, n, tanh)
L2 = Flux.Dense(n, 1, tanh)

cust_chain(v) = Flux.Chain(L1, x -> v .* x, L2)
### test the code
test_in  = rand(n,)
cust_chain(test_in)(test_in)
### so far so good

Flux.destructure(cust_chain)

ERROR: ArgumentError: reducing over an empty collection is not allowed
Stacktrace:
  [1] _empty_reduce_error()
    @ Base ./reduce.jl:301
  [2] mapreduce_empty(f::Function, op::Function, T::Type)
    @ Base ./reduce.jl:344
  [3] reduce_empty(op::Base.MappingRF{typeof(eltype), typeof(promote_type)}, #unused#::Type{AbstractVector})
    @ Base ./reduce.jl:331
  [4] reduce_empty_iter
    @ ./reduce.jl:357 [inlined]
  [5] mapreduce_empty_iter(f::Function, op::Function, itr::Vector{AbstractVector}, ItrEltype::Base.HasEltype)
    @ Base ./reduce.jl:353
  [6] _mapreduce(f::typeof(eltype), op::typeof(promote_type), #unused#::IndexLinear, A::Vector{AbstractVector})
    @ Base ./reduce.jl:402
  [7] _mapreduce_dim
    @ ./reducedim.jl:330 [inlined]
  [8] #mapreduce#725
    @ ./reducedim.jl:322 [inlined]
  [9] mapreduce
    @ ./reducedim.jl:322 [inlined]
 [10] reduce(#unused#::typeof(vcat), A::Vector{AbstractVector})
    @ Base ./abstractarray.jl:1621
 [11] _flatten
    @ ~/.julia/packages/Optimisers/cLLV1/src/destructure.jl:67 [inlined]
 [12] destructure(x::Function)
    @ Optimisers ~/.julia/packages/Optimisers/cLLV1/src/destructure.jl:22
 [13] top-level scope
    @ REPL[68]:1
 [14] top-level scope
    @ ~/.julia/packages/CUDA/Uurn4/src/initialization.jl:52

Is there any way I can either bypass this and code it differently or destructure the chain in some other way? I want to use the chain and optimize its parameters via DiffEqFlux.sciml_train.

The error in the empty case should be fixed in the latest version. But destructure won’t see the parameters in cust_chain: it needs structs which have been marked with @functor. Nor will it see the parameters closed over by x -> v .* x, thus Flux.destructure(Flux.Chain(L1, L2))[1] == Flux.destructure(cust_chain(rand(n)))[1].

That’s okay, I don’t need it to see the parameters of x -> v .* x (this layer is intended to have no trainable parameters). The only things that have trainable parameters are L1 and L2. v is a vector that changes in a dynamical system; it’s not a parameter to be learned.

EDIT: Re: “The error on an empty cases should be fixed on latest version”, I am using v0.13.1 which from what I gather is the latest version. Do you mean the next update?

Ok. Then calling destructure (etc.) after calling the function which returns a Chain ought to do what you want, I think.
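
For illustration, here is a minimal sketch of one way to keep a single parameter vector while v changes: destructure only the trainable layers once, then rebuild them from p and apply the scaling by v by hand. The helper name apply_model is hypothetical (it is not part of Flux), and the Dense(n => n, tanh) constructor form assumes a reasonably recent Flux version.

```julia
import Flux

n = 10
L1 = Flux.Dense(n => n, tanh)
L2 = Flux.Dense(n => 1, tanh)

# Destructure only the layers that hold trainable parameters.
p, re = Flux.destructure(Flux.Chain(L1, L2))

# Hypothetical helper (not part of Flux): rebuild the layers from the flat
# vector p, then scale the hidden activations by the current v.
function apply_model(p, v, x)
    model = re(p)                      # Chain(L1, L2) rebuilt from p
    return model[2](v .* model[1](x))  # same computation as cust_chain(v)(x)
end

v = rand(Float32, n)
x = rand(Float32, n)
y = apply_model(p, v, x)               # 1-element vector
```

Since v is an ordinary argument here, it never enters p, so something like apply_model(p, u, u) could be called with the current state u.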

By latest I mean Optimisers v0.2.4, which ] up should fetch.

I checked the version of Optimisers and it was indeed v0.2.4, but it still resulted in the ArgumentError shown in the original post. Just to check that we’re on the same page: I am trying to call p, re = Flux.destructure(cust_chain), not p, re = Flux.destructure(cust_chain(test_in)). What I would then want to do is something like re(p)(u, u) in an ODEProblem, where u is the current state. The first calls Flux.destructure on a function and errors out; the second calls it on a chain and works just fine (but doesn’t do what I want).

Calling destructure(function) is what won’t work. Maybe can’t: it has no way to know what this function returns, and even if it managed to see what arrays are closed over, it can’t make a re which constructs another instance of the function using different parameters.

For it to be able to look inside an object, that object has to be a struct whose type has been marked with @functor.

julia> Flux.destructure(cust_chain)  # empty
(Bool[], Restructure(#cust_chain, ..., 0))

julia> methods(typeof(cust_chain)) # this is why Functors cannot reconstruct this
# 0 methods for type constructor

julia> struct Maker; pre; post; end

julia> (m::Maker)(v) = Flux.Chain(m.pre, Flux.Scale(v, false), m.post)

julia> m = Maker(Flux.Dense(10 => 10, tanh), Flux.Dense(10 => 1));

julia> m_im = m(fill(im, 10))  # here Scale can be explored, hence 131
Chain(
  Dense(10 => 10, tanh),                # 110 parameters
  Scale(10; bias=false),                # 10 parameters
  Dense(10 => 1),                       # 11 parameters
)                   # Total: 5 arrays, 131 parameters, 848 bytes.

julia> m_im(rand(10))
1-element Vector{ComplexF64}:
 0.0 - 0.47595892328291134im

julia> Flux.destructure(m)  # still empty
(Bool[], Restructure(Maker, ..., 0))
 
julia> Flux.@functor Maker  # only now will destructure be able to look inside

julia> p, re = Flux.destructure(m);  # 121 parameters, v does not exist yet

julia> re(rand(121))
Maker(Dense(10 => 10, tanh), Dense(10 => 1))

julia> re(ones(121))(pi)(ones(10,3))
1×3 Matrix{Float64}:
 32.4159  32.4159  32.4159

julia> (m::Maker)(v, x) = Flux.Chain(m.pre, x -> x.*v, m.post)(x)  # two steps in one, tidier?

julia> re(ones(121))(pi, ones(10,3))
1×3 Matrix{Float64}:
 32.4159  32.4159  32.4159

Thank you so much, this seems to work and is the solution!