Given that Zygote does not support mutation, how does Recur get away with it?

From Flux.jl’s source code here:

mutable struct Recur{T}
  cell::T
  init
  state
end

Recur(m, h = hidden(m)) = Recur(m, h, h)

function (m::Recur)(xs...)
  h, y = m.cell(m.state, xs...)
  m.state = h
  return y
end

it looks like the Recur struct, which wraps every recurrent layer, mutates its state field in the forward pass. But Zygote.jl does not support mutation, so why does this not throw something like ERROR: Mutation is not supported!, as it usually does in such cases?

For context on where this came up: I was implementing a custom stateful layer. Initially, I defined MyCustomRecurrentCell similarly to RNNCell and relied on Recur to handle its state mutation, just as RNNCell does. But then I discovered that, because Recur's fields are not type annotated, my state was being inferred as Any, which propagated to all the other layers, so I was getting Any everywhere. I then filed this issue and decided to reimplement Recur with type-annotated fields. Having done so, I now get something like a Mutation is not supported error. So how is Recur able to get away with mutating its state?

@Azamat I was wondering if you figured out what was going on with the state mutation?

I’m also still puzzled about how the Recur struct manages to mutate its state.
More specifically, I wonder whether the reason could be the use of broadcasting to apply the forward pass over the sequence, as in rnn.(x). Do you know if there is any guarantee that the broadcast is actually applied sequentially, which is needed for the state to be updated in the right order?
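
Whatever broadcasting does or does not guarantee, an explicit loop makes the order unambiguous. Here is a minimal sketch of that idea; `Accumulator` and `sequence_loss` are hypothetical stand-ins for a stateful layer and a loss over a sequence, not anything from Flux:

```julia
# Hypothetical stand-in for a stateful layer: it adds each input to its state.
mutable struct Accumulator
    state::Float64
end

function (a::Accumulator)(x)
    a.state += x          # rebinds the field; the output depends on all earlier inputs
    return a.state
end

# Step through the sequence one element at a time, so the order of the
# state updates is spelled out rather than left to broadcast.
function sequence_loss(layer, xs)
    total = 0.0
    for x in xs
        total += layer(x)
    end
    return total
end

forward  = sequence_loss(Accumulator(0.0), [1.0, 2.0, 3.0])
reversed = sequence_loss(Accumulator(0.0), [3.0, 2.0, 1.0])
println((forward, reversed))   # the two differ, so ordering matters here
```

Because the result depends on the order of the inputs, a loop like this removes any reliance on how rnn.(x) happens to iterate.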

@jeremiedb No, I haven’t. Maybe @MikeInnes can shed some light on this.

Besides being type-unstable, there are some other issues with the current implementation of recurrent layers in Flux (see e.g. Flux.jl#1089), which is why I think they need to be redesigned completely with attention to performance.

Zygote doesn’t support mutation of arrays, but you can mutate other objects just fine (e.g. try code that uses dictionaries).
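
To make the distinction concrete, here is a small sketch contrasting the two kinds of mutation. `Holder`, `rebind_field` and `write_array` are hypothetical names made up for illustration, and the exact error text may differ between Zygote versions:

```julia
using Zygote

# Hypothetical mutable container with an untyped field, loosely like Recur.
mutable struct Holder
    state
end

# Rebinds a field of a mutable struct. A new array is allocated by the
# broadcast and stored with setfield!; no existing array is written to.
function rebind_field(x)
    h = Holder([1.0, 2.0])
    h.state = h.state .* x
    return sum(h.state)
end

# Writes into an existing array in place (setindex! on an Array).
function write_array(x)
    buf = zeros(2)
    buf[1] = x
    return sum(buf)
end

grad_field = gradient(rebind_field, 3.0)

# Capture the error instead of letting it propagate.
array_error = try
    gradient(write_array, 3.0)
    nothing
catch err
    err
end

println(grad_field)
println(array_error)
```

The first function only replaces what the field points to, which is the pattern Recur uses; the second writes into an array, which is the case Zygote rejects.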

What might be happening when you add a type constraint is that Julia has to convert the array before storing it, which might call a mutating kernel. Whatever the exact reason, the error must be coming from mutation of an array.
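
As a rough illustration of the conversion step (not a diagnosis of the reimplemented Recur, whose code isn't shown here), assigning to a type-annotated field goes through convert. `TypedState` below is a hypothetical struct:

```julia
# Hypothetical struct with a concretely typed field.
mutable struct TypedState
    state::Matrix{Float32}
end

s = TypedState(zeros(Float32, 2, 2))

# Same type: convert is a no-op, so the very same array object is stored.
same = ones(Float32, 2, 2)
s.state = same
stored_same_object = s.state === same

# Different element type: convert builds a new Matrix{Float32} and copies
# the values into it before the field is set.
other = ones(Float64, 2, 2)
s.state = other
stored_eltype = eltype(s.state)

println((stored_same_object, stored_eltype))
```

If the value being stored does not already have exactly the annotated type, that extra copy is one place where an in-place array write could enter the picture.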

Feel free to try it, but I’d be very surprised if adding type restrictions to Recur gave any meaningful performance improvement. I know the Julia manual talks about type inference and globals etc., but ML really has very different performance constraints; there’s plenty of active work here and it’s definitely not the case that these things are written without attention to performance.

Although I remain unclear on how Zygote gets away with the apparent array mutation that happens within Recur, the test below suggests that implementing RNNCell and Recur as immutable structs with type annotations doesn’t bring any improvement (in either time or memory allocations):

using Flux
using Flux: Recur, @functor, glorot_uniform, hidden
using BenchmarkTools

# immutable alternative - no hidden state since redundant with one defined in Recur
struct MyRNNCell{F,T}
  σ::F
  Wi::Matrix{T}
  Wh::Matrix{T}
  b::Vector{T}
end

# initializer
MyRNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform) = MyRNNCell(σ, init(out, in), init(out, out), init(out))

# overload
function (m::MyRNNCell{F,T})(h::Matrix{T}, x::Matrix{T}) where {F,T}
  σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
  h = σ.(Wi*x .+ Wh*h .+ b)
  return h, h
end

@functor MyRNNCell Wi, Wh, b

# immutable alternative implementation of Recur
struct MyRecur{F,T}
  cell::MyRNNCell{F,T}
  init::Vector{T}
  state::Matrix{T}
end

# overload
function (m::MyRecur)(xs...)
  h, y = m.cell(m.state, xs...)
  m.state .= h
  return y
end

@functor MyRecur cell, init


# define model based on alternative immutable struct
m = MyRecur(MyRNNCell(128,256), zeros(Float32,256), zeros(Float32,256, 512))
# original / benchmark model
bm = RNN(128,256)

# generate data
x = rand(Float32, 128,512)
xx = [rand(Float32, 128, 512) for i in 1:100]

# all clean inference
@code_warntype m(x)
# inference issues raised
@code_warntype bm(x)

@btime m(x)
# 2.562 ms (11 allocations: 1.50 MiB)

@btime bm(x)
# 2.553 ms (11 allocations: 1.50 MiB)

Any pointers on why both approaches have the same performance despite the gain expected from type stability (at least according to @code_warntype) would be helpful, as this breaks the rationale I had taken from the manual's performance tips.

There is no array mutation. You can see that the Recur struct modifies its own mutable state field here. But it’s a struct that gets modified, not an array.

Performance is always context-dependent. If code is not type inferred, Julia inserts dynamic dispatches to figure out what method should be called at run time; this costs about 100ns - 1μs. That’s a disaster if you have a scalar loop (since scalar operations can take about a nanosecond, so the overhead is significant), but it’s a non-issue if you’re working with big arrays (as in ML) since individual array operations can easily take milliseconds. Adding a microsecond of dispatch to each operation isn’t even noticeable.
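
A rough way to see this for yourself, using only Base: the sketch below compares a layer with an untyped field against a typed one. `UntypedLayer`, `TypedLayer`, `apply` and `time_calls` are hypothetical names, and the timings will vary by machine, so none are quoted here:

```julia
# Two hypothetical layers holding the same weights; only the field
# annotation differs.
struct UntypedLayer
    W                          # field type is Any
end

struct TypedLayer{M<:AbstractMatrix}
    W::M
end

apply(layer, x) = layer.W * x

W = rand(Float32, 256, 256)
x = rand(Float32, 256, 256)
untyped, typed = UntypedLayer(W), TypedLayer(W)

# What inference concludes about the return type of each call.
inferred_untyped = Base.return_types(apply, (UntypedLayer, Matrix{Float32}))[1]
inferred_typed   = Base.return_types(apply, (typeof(typed), Matrix{Float32}))[1]

# Time n calls; each call is dominated by the matrix multiplication.
function time_calls(layer, x, n)
    apply(layer, x)            # warm-up, so compilation is not timed
    return @elapsed begin
        for _ in 1:n
            apply(layer, x)
        end
    end
end

println("inferred: ", inferred_untyped, " vs ", inferred_typed)
println("untyped:  ", time_calls(untyped, x, 20), " s")
println("typed:    ", time_calls(typed, x, 20), " s")
```

The untyped version pays for one dynamic dispatch per call, but the matrix product it dispatches to is the same method in both cases.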

Type inference might matter more in future once the compiler does more array optimisations, but for now it’s largely not worth worrying about.

> Type inference might matter more in future once the compiler does more array optimisations, but for now it’s largely not worth worrying about.

In my use case, I was using LSTMs with OMEinsum.jl. Because of their current type-unstable implementation, the LSTM's output was being inferred as Any or Array{Any}, and as a result it was not being dispatched to the right GPU-optimized kernel in OMEinsum but to a generic slow one, which is when I filed the issue above. All of this was happening in the backward pass, so it was really hard to debug and fix. So type inference was a deal-breaker in my case.

Hm, inference shouldn’t influence dispatch; the actual runtime types determine which method is called. Unless you somehow manually call into inference and do logic based on that, that is.
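
A small sketch of what I mean, with hypothetical names (`kernel`, `Box`, `call`) and no claim about what OMEinsum actually does internally:

```julia
# Two methods: a specialised one and a generic fallback.
kernel(x::Vector{Float64}) = "specialized"
kernel(x::AbstractVector)  = "generic"

# The field is untyped, so inference only knows `b.data` as Any.
struct Box
    data
end

call(b) = kernel(b.data)

# The runtime type is Vector{Float64}, so the specialised method is chosen
# even though inference could not see that.
fast = call(Box([1.0, 2.0]))

# Here the array itself really is a Vector{Any} at run time, so the generic
# method is the correct match for the value that was passed.
slow = call(Box(Any[1.0, 2.0]))

println((fast, slow))
```

So if a generic method is being hit, it may be worth checking whether the array that reaches it actually has an Any element type at run time, rather than just being inferred as such.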