Hello everyone,
I have always used ForwardDiff for my automatic differentiation needs, but recently I decided to try Zygote instead. My main motivation is Zygote's ability to work with callable structs, which can be used to implement ML models.
The first obstacle I have encountered is the following: besides the model parameters, a struct may also hold additional information (e.g. a random seed, various weights, or a text description) that should not be included in the automatic differentiation. To make my question more concrete, here is some code (essentially a modified version of an example from the Zygote documentation):
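```julia
using Zygote

# Callable struct: W and b are trainable, C holds constant per-sample weights
struct Linear
    W
    b
    C
end

# Forward pass of the linear model
(l::Linear)(x) = l.W * x .+ l.b

model = Linear(rand(2, 5), rand(2), rand(3))

# N input/target pairs
N = 3
X = [randn(5) for i in 1:N]
Y = [randn(2) for i in 1:N]

# Weighted sum of squared errors; C[n] weights the n-th pair
function loss(model)
    total = 0.0
    for n in 1:N
        total += sum(model.C[n] * (model(X[n]) .- Y[n]) .^ 2)
    end
    return total
end
```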
Basically, we define a linear model, N = 3 pairs of inputs X and targets Y, and a loss function.
In addition to the original example, the struct here also holds weighting coefficients C, which are constant and are not free model parameters to be optimised.
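Written out by hand (with the residual $r_n$ introduced here only as shorthand), the loss above and its partial derivatives are:

$$
L(W, b, C) = \sum_{n=1}^{N} C_n \,\lVert r_n \rVert^2,
\qquad r_n = W x_n + b - y_n
$$

$$
\frac{\partial L}{\partial W} = \sum_{n=1}^{N} 2 C_n \, r_n x_n^{\top}, \qquad
\frac{\partial L}{\partial b} = \sum_{n=1}^{N} 2 C_n \, r_n, \qquad
\frac{\partial L}{\partial C_n} = \lVert r_n \rVert^2 .
$$

Only the first two are needed for optimisation; the third is what a structural gradient of the whole struct also reports for the field C.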
However, if I call `dmodel = gradient(loss, model)[1]`, Zygote also returns the gradient with respect to C:
```
(W = [4.109999051379186 12.539463966870912 … 5.154617937809002 3.818367921164797; 2.7677477951659277 8.485557353981928 … 0.33008326109979613 5.871635193000793], b = [10.270604809342977, 6.932605807555453], C = [5.114710562186542, 8.589047944559583, 15.909538432009729])
```
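The C entries in this output are just the partial derivatives of the loss with respect to each weight. Because the loss is linear in each C[n], every entry should equal the squared residual of the n-th pair. The following self-contained sketch checks this numerically; it repeats the model from above (with the accumulator renamed to `total`), and `dC_manual` is a hypothetical helper name introduced only for the comparison.

```julia
using Zygote

# Same model and loss as in the question, repeated so the snippet runs on its own
struct Linear
    W
    b
    C
end
(l::Linear)(x) = l.W * x .+ l.b

model = Linear(rand(2, 5), rand(2), rand(3))
N = 3
X = [randn(5) for _ in 1:N]
Y = [randn(2) for _ in 1:N]

function loss(model)
    total = 0.0
    for n in 1:N
        total += sum(model.C[n] * (model(X[n]) .- Y[n]) .^ 2)
    end
    return total
end

dmodel = gradient(loss, model)[1]

# Hand-derived d(loss)/d(C[n]): the squared residual of the n-th pair
dC_manual = [sum((model(X[n]) .- Y[n]) .^ 2) for n in 1:N]
println(dmodel.C ≈ dC_manual)  # prints true
```

In other words, Zygote treats C like any other array field of the struct; nothing in the struct definition marks it as constant.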
Is there perhaps a way of telling Zygote not to differentiate with respect to C? Many thanks.
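One possible workaround, sketched here under the assumption that only W and b need gradients, is to differentiate a closure with respect to the trainable arrays alone: the closure rebuilds the model from W and b and captures `model.C` as a constant, so `gradient` returns only `(dW, db)`. This is an illustration using plain `Zygote.gradient`, not necessarily the recommended or most idiomatic approach.

```julia
using Zygote

# Same model and loss as in the question, repeated so the snippet runs on its own
struct Linear
    W
    b
    C
end
(l::Linear)(x) = l.W * x .+ l.b

model = Linear(rand(2, 5), rand(2), rand(3))
N = 3
X = [randn(5) for _ in 1:N]
Y = [randn(2) for _ in 1:N]

function loss(model)
    total = 0.0
    for n in 1:N
        total += sum(model.C[n] * (model(X[n]) .- Y[n]) .^ 2)
    end
    return total
end

# Differentiate only w.r.t. W and b; model.C is captured by the closure
# and does not appear among the returned gradients
dW, db = gradient((W, b) -> loss(Linear(W, b, model.C)), model.W, model.b)

# Full structural gradient, for comparison (this one also contains C)
dmodel = gradient(loss, model)[1]
println(size(dW), " ", size(db))         # prints (2, 5) (2,)
println(dW ≈ dmodel.W && db ≈ dmodel.b)  # prints true
```

The gradients for W and b agree with the structural gradient; the closure version simply never returns a gradient for C, at the cost of having to list the trainable fields explicitly.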