How to make Zygote avoid differentiating with respect to some fields of a struct

A workaround is to wrap the fields you want to treat as constants in an auxiliary identity function whose gradient is blocked:

julia> nograd(x) = x
nograd (generic function with 1 method)

julia> Zygote.@nograd nograd

julia> function loss(model)
         total = 0.0
         for n in 1:N
           c = nograd(model.C[n])  # C[n] enters the loss as a constant weight
           total += sum(c * (model(X[n]) .- Y[n]).^2)
         end
         total
       end
loss (generic function with 1 method)

julia> dmodel = gradient(loss, model)[1]
(W = [-2.6711034296853375 -0.2254551144087391 … 1.4033779589003235 3.1219721905976305; 2.5271537498909424 -0.8264030561248071 … -1.663924873993493 -1.3721983219005442], b = [1.8739815151086066, -0.8248439951554849], C = nothing)
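For anyone who wants to run this end to end, here is a self-contained sketch. The `Model` struct, its callable definition, and the data `N`, `X`, `Y` are all hypothetical stand-ins (the post doesn't show them); only the field names `W`, `b`, `C` are taken from the gradient output above. It also swaps `Zygote.@nograd` for ChainRulesCore's `@non_differentiable`, since as far as I know recent Zygote versions deprecate `@nograd` in favour of that macro. The last line is a small scalar check that `nograd(x)` is treated as a constant factor.

```julia
# Sketch assuming Zygote and ChainRulesCore are installed.
using Zygote
using ChainRulesCore

# Hypothetical model: affine map with trainable W and b, plus per-sample
# weights C that should NOT be differentiated.
struct Model
    W::Matrix{Float64}
    b::Vector{Float64}
    C::Vector{Float64}
end

# Make the struct callable, as `model(X[n])` in the post assumes.
(m::Model)(x) = m.W * x .+ m.b

# Identity function whose derivative is declared to be absent
# (replacement for the deprecated `Zygote.@nograd nograd`).
nograd(x) = x
@non_differentiable nograd(::Any)

# Hypothetical toy data: N samples, 3-dim inputs, 2-dim targets.
const N = 4
const X = [randn(3) for _ in 1:N]
const Y = [randn(2) for _ in 1:N]

function loss(model)
    total = 0.0
    for n in 1:N
        c = nograd(model.C[n])  # weight affects the value but gets no gradient
        total += sum(c * (model(X[n]) .- Y[n]).^2)
    end
    return total
end

model = Model(randn(2, 3), randn(2), rand(N))
dmodel = gradient(loss, model)[1]  # NamedTuple with W, b gradients and C = nothing

# Scalar check: nograd(x) acts as a constant, so d/dx[nograd(x) * x] is x, not 2x.
dx = gradient(x -> nograd(x) * x, 3.0)[1]
```

The gradient comes back as a NamedTuple mirroring the struct's fields, with `C` set to `nothing`, just like the REPL output above.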
