How to make Zygote avoid differentiating with respect to some fields in a struct

Hello everyone,

I have always used ForwardDiff for my automatic differentiation needs, but recently I thought I should try out Zygote instead. My main motivation is Zygote's ability to work with callable structs, which can be used to implement ML models.

The first obstacle I have encountered is the following: besides model parameters, a struct may also hold additional information (e.g. a random seed, various weights, a text description) that should not be included in the automatic differentiation. To make my question more concrete, I post some code below (basically a modified version of an example from the Zygote documentation):

using Zygote

struct Linear
    W
    b
    C
end

(l::Linear)(x) = l.W * x .+ l.b

model = Linear(rand(2, 5), rand(2), rand(3))

N = 3

X = [randn(5) for i in 1:N]
Y = [randn(2) for i in 1:N]

function loss(model)
  loss = 0.0
  for n in 1:N
    loss += sum(model.C[n]*(model(X[n]) .- Y[n]).^2)
  end
  loss
end

Basically, we define a linear model, N = 3 pairs of inputs X and targets Y, and a loss function.
In addition to the original example, the struct here also holds some weighting coefficients C, which are constant and are not free model parameters to be optimised.

However, if I call dmodel = gradient(loss, model)[1], Zygote also returns the gradient with respect to C:

(W = [4.109999051379186 12.539463966870912 … 5.154617937809002 3.818367921164797; 2.7677477951659277 8.485557353981928 … 0.33008326109979613 5.871635193000793], b = [10.270604809342977, 6.932605807555453], C = [5.114710562186542, 8.589047944559583, 15.909538432009729])

Is there perhaps a way of telling Zygote not to differentiate with respect to C? Many thanks.

In Flux, this is typically accomplished using @functor. From the documentation:

To include only certain fields of a struct, one can pass a tuple of field names to @functor :

julia> struct Baz
         x
         y
       end

julia> @functor Baz (x,)

julia> model = Baz(1, 2)
Baz(1, 2)

julia> fmap(float, model)
Baz(1.0, 2)

Apparently, this does not interact perfectly with Zygote, though; see the relevant issue here.

Per the discussion there, it seems that correctly computing the gradient of one field can sometimes depend implicitly on the gradients of other fields, so ignoring those other gradients could silently give incorrect results.
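One way to avoid deciding which struct fields receive gradients is to differentiate with respect to W and b explicitly and rebuild the struct inside a closure. The following is only a minimal sketch under the setup from the question; it is not something proposed in the linked issue, and the accumulator name total is chosen here for illustration:

```julia
using Zygote

struct Linear
    W
    b
    C
end

(l::Linear)(x) = l.W * x .+ l.b

model = Linear(rand(2, 5), rand(2), rand(3))
N = 3
X = [randn(5) for i in 1:N]
Y = [randn(2) for i in 1:N]

function loss(model)
    total = 0.0
    for n in 1:N
        total += sum(model.C[n] * (model(X[n]) .- Y[n]).^2)
    end
    total
end

# C is captured from the existing model and is never passed as an argument
# to `gradient`, so the returned tuple only holds gradients for W and b.
dW, db = gradient((W, b) -> loss(Linear(W, b, model.C)), model.W, model.b)
```

The price of this approach is that the parameters have to be listed by hand, which becomes tedious for structs with many fields.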

A workaround is to define an auxiliary function blocking the gradient:

julia> nograd(x) = x
nograd (generic function with 1 method)

julia> Zygote.@nograd nograd

julia> function loss(model)
         loss = 0.0
         for n in 1:N
           c = nograd(model.C[n]) 
           loss += sum(c*(model(X[n]) .- Y[n]).^2)
         end
         loss
       end
loss (generic function with 1 method)

julia> dmodel = gradient(loss, model)[1]
(W = [-2.6711034296853375 -0.2254551144087391 … 1.4033779589003235 3.1219721905976305; 2.5271537498909424 -0.8264030561248071 … -1.663924873993493 -1.3721983219005442], b = [1.8739815151086066, -0.8248439951554849], C = nothing)

Nice

You can also use the built-in Zygote.dropgrad in place of nograd; it is listed on the Utilities page of the Zygote documentation. For a higher-level API, the linked Flux issue is definitely the one to follow.
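As a rough sketch of what that could look like with the model from the question (the accumulator name total is chosen here for illustration, and the behaviour should be checked against your installed Zygote version):

```julia
using Zygote

struct Linear
    W
    b
    C
end

(l::Linear)(x) = l.W * x .+ l.b

model = Linear(rand(2, 5), rand(2), rand(3))
N = 3
X = [randn(5) for i in 1:N]
Y = [randn(2) for i in 1:N]

function loss(model)
    total = 0.0
    for n in 1:N
        # dropgrad returns its argument unchanged but blocks the gradient
        # from flowing back into model.C
        c = Zygote.dropgrad(model.C[n])
        total += sum(c * (model(X[n]) .- Y[n]).^2)
    end
    total
end

dmodel = gradient(loss, model)[1]
```

As with the nograd helper, dmodel.W and dmodel.b should hold the usual gradients while the C entry should come back empty. Newer Zygote releases may print a deprecation warning for dropgrad that points to ChainRulesCore.ignore_derivatives, so it is worth checking which one your version recommends.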

I had never come across functors, thanks very much for telling me about them. I thought I had looked thoroughly through the Zygote documentation, but apparently I hadn’t. Looks like this could be useful to me in other contexts too…

Thanks for your reply. This is easier for me to understand than the functor solution.