Implicit gradients with a mutable struct return an error

Hi,

When I try to compute gradients with respect to a mutable struct using implicit parameters (`Params`), the code below fails, but the exact same code works when the struct is not mutable:

using Zygote

mutable struct Linear_  # affine layer; the `mutable` keyword is the only difference from the working variant
    W::Array{Float64, 2}  # weight matrix, created as ni×no by the constructor
    b::Array{Float64, 2}  # bias, stored as a 1×no row matrix
    
    Linear_(ni::Int,no::Int) = new(rand(ni,no),rand(1,no))  # both fields are initialized with `rand`
end

function (l::Linear_)(x)  # makes a Linear_ instance callable, e.g. layer1(input)
    return x * l.W .+ l.b  # matrix product with W, then the 1×no bias is broadcast across the rows
end

function mse(input::Array{Float64, 2}, target::Array{Float64, 2})::Float64  # mean squared error between two matrices
    return sum((input.-target).^2)/length(input)  # sum of squared element-wise differences divided by the element count
end

n_samples = 7  # number of rows in `input` and `target`
n_features = 5  # number of columns in `input`; input width of layer1
n_outputs = 3  # number of columns in `target`; output width of layer2

input = rand(n_samples, n_features)
target = rand(n_samples, n_outputs)

layer1 = Linear_(n_features,10)  # 10 is the hidden width shared by both layers
layer2 = Linear_(10,n_outputs)

g = gradient(() -> mse(layer2(layer1(input)), target), Params([layer1, layer2]))  # implicit-parameter form; this call raises the MethodError shown below
g.grads

Output:

MethodError: no method matching getindex(::Nothing)

Stacktrace:
  [1] (::Zygote.var"#back#212"{:b, Zygote.Context, Linear_, Matrix{Float64}})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/lib/lib.jl:233
  [2] (::Zygote.var"#1744#back#213"{Zygote.var"#back#212"{:b, Zygote.Context, Linear_, Matrix{Float64}}})(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [3] Pullback
    @ ./In[4]:11 [inlined]
  [4] (::typeof(∂(λ)))(Δ::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
  [5] Pullback
    @ ./In[4]:28 [inlined]
  [6] (::typeof(∂(#7)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface2.jl:0
  [7] (::Zygote.var"#84#85"{Params, typeof(∂(#7)), Zygote.Context})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface.jl:343
  [8] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/rv6db/src/compiler/interface.jl:76
  [9] top-level scope
    @ In[4]:28
 [10] eval
    @ ./boot.jl:360 [inlined]
 [11] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1116

I can however get the gradients via explicit parameters:

gradient((l1,l2) -> mse(l2(l1(input)), target), layer1, layer2)  # layers passed as explicit arguments; yields one (W, b) gradient per layer

((W = [10.300679088087453 10.568342281725188 … 3.649421030153132 5.493540298981077; 7.077079146811202 7.255902898063759 … 2.4956256731239326 3.786621582675167; … ; 6.169735183400968 6.33463893073102 … 2.2021692558068606 3.2655913400304204; 10.387690691720055 10.65536098102742 … 3.6648346302631385 5.569931540137668], b = [15.15247655498283 15.543308168010066 … 5.354689243962132 8.104862332066384]), (W = [20.2302939176642 20.59477956756408 10.960969614757374; 11.540421752151698 11.643524500040572 6.293879532698054; … ; 13.778773349620447 14.029981443678864 7.46339545056888; 10.509508900591154 10.668660992058987 5.682118298408729], b = [7.810980389487534 7.916217478882832 4.248374194822758]))

And here is the same code as in the first snippet, using implicit gradients, but now with a non-mutable struct:

using Zygote

# Immutable variant of the affine layer: the definition matches the first
# snippet except that the `mutable` keyword is absent.
struct Linear_
    W::Array{Float64, 2}  # weight matrix, created as ni×no by the constructor
    b::Array{Float64, 2}  # bias, stored as a 1×no row matrix

    # Both fields are initialized with `rand`.
    Linear_(ni::Int,no::Int) = new(rand(ni,no),rand(1,no))
end

# Makes a Linear_ instance callable: matrix product with W, then the 1×no
# bias is broadcast across the rows of the result.
function (l::Linear_)(x)
    return x * l.W .+ l.b
end

# Mean squared error: sum of squared element-wise differences divided by the
# number of elements in `input`.
function mse(input::Array{Float64, 2}, target::Array{Float64, 2})::Float64
    return sum((input.-target).^2)/length(input)
end

n_samples = 7   # number of rows in `input` and `target`
n_features = 5  # number of columns in `input`; input width of layer1
n_outputs = 3   # number of columns in `target`; output width of layer2

input = rand(n_samples, n_features)
target = rand(n_samples, n_outputs)

layer1 = Linear_(n_features,10)  # 10 is the hidden width shared by both layers
layer2 = Linear_(10,n_outputs)

# Implicit-parameter form: the loss closure takes no arguments and the layers
# are handed to Zygote through `Params`.
g = gradient(() -> mse(layer2(layer1(input)), target), Params([layer1, layer2]))
g.grads

Output:

IdDict{Any, Any} with 6 entries:
  Linear_([0.996399 0.3126… => (W = [8.39223 9.37895 8.61237; 6.8855 7.70569 7.…
  :(Main.layer1)            => (W = [7.58618 6.41697 … 4.74191 7.71558; 7.78701…
  :(Main.target)            => [-0.615714 -0.765777 -0.712478; -0.835387 -0.894…
  :(Main.layer2)            => (W = [8.39223 9.37895 8.61237; 6.8855 7.70569 7.…
  :(Main.input)             => [5.58015 5.44244 … 6.14441 5.18824; 6.76961 6.63…
  Linear_([0.127018 0.2413… => (W = [7.58618 6.41697 … 4.74191 7.71558; 7.78701…

Is this some known behaviour of Zygote? I have been reading the docs and the repo but I cannot find anything related.

Thanks!

1 Like

Interestingly, omitting layer1/2 from the param list does give the correct gradients and you can even index them as desired (best workaround for now), so this is definitely a bug. I could also repro with just a single layer like sum(layer1(input)). Could you open an issue so we can get some eyes on it?

1 Like