Differentiating implicit parameters using Zygote in complex hierarchical models

I would like to differentiate all parameters belonging to a model using Zygote without passing them as explicit inputs to the function I’m differentiating. An MWE taken more or less directly from Flux’s documentation is this:

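```julia
using Zygote

struct Affine
    W
    b
end

Affine(in::Int, out::Int) = Affine(randn(out, in), randn(out))

(a::Affine)(x) = a.W * x .+ a.b

# Simple model
m = Affine(3, 2)
x = randn(3)

# Differentiate w.r.t. the implicit parameters m.W and m.b
g = gradient(() -> sum(m(x)), Params([m.W, m.b]))
g[m.W], g[m.b]
```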

which outputs

([-1.12455 -0.111599 -1.06465; -1.12455 -0.111599 -1.06465], [1.0, 1.0])
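That matches what I would expect by hand, assuming the differentiated function is `sum(m(x))`: every row of the gradient with respect to `W` is just `x`, and the gradient with respect to `b` is all ones. Below is a quick finite-difference check of that claim that does not use Zygote at all. The helper name `fd_grad` is my own invention, not a library function.

```julia
# Affine layer as in the example above
struct Affine
    W
    b
end
(a::Affine)(x) = a.W * x .+ a.b

# Hypothetical helper: central finite-difference gradient of the
# scalar function f with respect to the array p
function fd_grad(f, p; h = 1e-6)
    g = zero(p)
    for i in eachindex(p)
        e = zero(p)
        e[i] = h
        g[i] = (f(p .+ e) - f(p .- e)) / (2h)
    end
    return g
end

W, b, x = randn(2, 3), randn(2), randn(3)
gW = fd_grad(W_ -> sum(Affine(W_, b)(x)), W)
gb = fd_grad(b_ -> sum(Affine(W, b_)(x)), b)

# Analytic result: d/dW sum(W*x .+ b) = ones(2) * x', d/db = ones(2)
@assert isapprox(gW, ones(2) * x'; atol = 1e-6)
@assert isapprox(gb, ones(2); atol = 1e-6)
```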

So this works fine for our simple model. Now say I want to differentiate a chained model with Zygote in the same way:

```julia
using Zygote

struct Affine
    W
    b
end

Affine(in::Int, out::Int) = Affine(randn(out, in), randn(out))

(a::Affine)(x) = a.W * x .+ a.b

# Chained model
layers = [Affine(3, 2), Affine(2, 1)]
model(x) = foldl((xx, m) -> m(xx), layers, init=x)
```

Here the model is slightly more complex. Given only the `model(x)` function, is there a good way to extract its parameters? Or would I have to create a `Model` struct carrying both the layer definitions and the `model(x)` function?
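To make the second option concrete, this is roughly what I have in mind. It is only a sketch: `Model` and `params_of` are names I made up for illustration, not part of Zygote or Flux, and I have not confirmed that Zygote differentiates through it.

```julia
struct Affine
    W
    b
end

Affine(in::Int, out::Int) = Affine(randn(out, in), randn(out))

(a::Affine)(x) = a.W * x .+ a.b

# Hypothetical container that owns the layers and is itself callable
struct Model
    layers::Vector{Affine}
end

# Apply the layers in order, feeding each output into the next layer
(m::Model)(x) = foldl((xx, layer) -> layer(xx), m.layers, init=x)

# Hypothetical helper: collect every parameter array of every layer
params_of(m::Model) = [p for layer in m.layers for p in (layer.W, layer.b)]

model = Model([Affine(3, 2), Affine(2, 1)])
ps = params_of(model)   # W and b of layer 1, then W and b of layer 2
y = model(randn(3))     # forward pass, a 1-element vector
```

The idea would be to hand `ps` to Zygote the same way `[m.W, m.b]` is used for the simple model, so the parameters never have to be explicit arguments of `model(x)`.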

Furthermore, my gradient attempt on the chained model fails with a long error like this:

ERROR: MethodError: no method matching zero(::Type{Affine})
Closest candidates are:
zero(::Type{LibGit2.GitHash}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/LibGit2/src/oid.jl:220
zero(::Type{Pkg.Resolve.VersionWeights.VersionWeight}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/Pkg/src/resolve/VersionWeights.jl:19
zero(::Type{Pkg.Resolve.MaxSum.FieldValues.FieldValue}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/Pkg/src/resolve/FieldValues.jl:44

Stacktrace:

I have left out the stacktrace. I fail to see why this latter approach differs from the first, since I’m still referencing my parameters through the layers. Shouldn’t the two be equivalent? I must be missing something.
