I would like to differentiate with respect to all parameters belonging to a model using Zygote, without passing them as explicit inputs to the function I’m differentiating. An MWE taken more or less directly from Flux’s documentation is this:
using Zygote
struct Affine
    W
    b
end
Affine(in::Int, out::Int) = Affine(randn(out, in), randn(out))
(a::Affine)(x) = a.W*x .+ a.b
# Simple model
m = Affine(3, 2)
m(randn(3))
g = Zygote.gradient(()->sum(m(randn(3))), Zygote.Params([m.W, m.b]))
g[m.W], g[m.b]
which outputs
([-1.12455 -0.111599 -1.06465; -1.12455 -0.111599 -1.06465], [1.0, 1.0])
So this works fine for our simple model. Now say I want to differentiate a chained model in the same way:
using Zygote
struct Affine
    W
    b
end
Affine(in::Int, out::Int) = Affine(randn(out, in), randn(out))
(a::Affine)(x) = a.W*x .+ a.b
# Chained model
layers = [Affine(3, 2), Affine(2, 1)]
model(x) = foldl((xx,m)->m(xx), layers, init=x)
g = Zygote.gradient(()->sum(model(randn(3))), Zygote.Params([layers[1].W, layers[1].b]))
Here the model is slightly more complex. Given access only to the model(x) function, is there a good way to extract the parameters? Or would I have to create a Model struct carrying both the layer definitions and the model(x) function?
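The struct-based alternative could look roughly like the following minimal sketch. The names `Model` and `collect_params` are hypothetical and chosen only for illustration; they are not part of Zygote or Flux. The sketch only shows how the layers and the forward pass could travel together so that the parameters can be gathered from the model itself.

```julia
struct Affine
    W
    b
end
Affine(in::Int, out::Int) = Affine(randn(out, in), randn(out))
(a::Affine)(x) = a.W * x .+ a.b

# Hypothetical container: holds the layers and is itself callable.
struct Model
    layers::Vector{Affine}
end
(m::Model)(x) = foldl((xx, layer) -> layer(xx), m.layers; init = x)

# Hypothetical helper: gathers every W and b, in layer order.
function collect_params(m::Model)
    ps = Any[]
    for layer in m.layers
        push!(ps, layer.W)
        push!(ps, layer.b)
    end
    return ps
end

m = Model([Affine(3, 2), Affine(2, 1)])
ps = collect_params(m)   # [W1, b1, W2, b2]
```

The vector returned by `collect_params(m)` could then be wrapped in Zygote.Params in the same way as in the simple example. Whether gradients then come through correctly is not something this sketch demonstrates.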
Furthermore, the gradient call for the chained model fails with a long error like this:
ERROR: MethodError: no method matching zero(::Type{Affine})
Closest candidates are:
zero(::Type{LibGit2.GitHash}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/LibGit2/src/oid.jl:220
zero(::Type{Pkg.Resolve.VersionWeights.VersionWeight}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/Pkg/src/resolve/VersionWeights.jl:19
zero(::Type{Pkg.Resolve.MaxSum.FieldValues.FieldValue}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/Pkg/src/resolve/FieldValues.jl:44
…
Stacktrace:
I have omitted the stack trace. I fail to see why this latter approach differs from the first, since I’m still referencing my parameters through the layers. Shouldn’t the two be equivalent? I must be missing something.