Idea to make Zygote support mutation in easy cases

Motivation

The appeal of automatic differentiation is that you can write a function in ordinary Julia and have its gradient computed automatically. On the surface, disallowing mutation in that function seems reasonable: it should be fairly easy to write code that doesn’t mutate arrays.

The problem is that mutation is tucked away in many places throughout the Julia ecosystem, often as an implementation detail that is invisible to the user of that functionality. For example, the typed comprehension Float64[f(x) for x in input] mutates internally. To quote the Zygote documentation: “Non-mutating functions may also use mutation under the hood. This can be done for performance reasons or code re-use.”
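A rough sketch of what such a typed comprehension amounts to, assuming a simplified model rather than Base’s actual implementation (`typed_comprehension_sketch` is a hypothetical name): an uninitialized buffer is allocated and then filled element by element with `setindex!`, which is the kind of array mutation Zygote refuses to differentiate.

```julia
# Simplified model of Float64[f(x) for x in input]; Base's real code differs in detail.
function typed_comprehension_sketch(f, input)
    out = Vector{Float64}(undef, length(input))  # uninitialized buffer
    for (i, x) in enumerate(input)
        out[i] = f(x)  # setindex!: the hidden mutation
    end
    return out
end

squares = typed_comprehension_sketch(abs2, [1, 2, 3])
```

The caller never sees the buffer being written, which is why this kind of mutation is so easy to trip over.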

Idea

The fundamental issue with mutation is that you lose information when you overwrite a value. The idea is to support mutation whenever the value prior to the mutation is not used in the gradient calculation. In my estimation, this would fix most of the hard-to-spot or unexpected mutations, because Zygote would no longer throw for operations that mutate in a way that is irrelevant to the gradient calculation (e.g. mutation used for code re-use within a library function). Further, in all or almost all remaining cases it should be possible to avoid the error by adding a copy just before the mutation, if performance isn’t a major concern.
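As a minimal illustration of the copy workaround, consider a hypothetical in-place helper `rescale_to_unit_sum!`. Passing it `copy(x)` instead of `x` costs one allocation but leaves the pre-mutation values of `x` intact for anything that still needs them.

```julia
# Hypothetical library-style helper that rescales its argument in place.
# After the broadcasted assignment, the original values of v are gone.
function rescale_to_unit_sum!(v::AbstractVector{<:AbstractFloat})
    v ./= sum(v)
    return v
end

x = [1.0, 2.0, 3.0]
# x is still needed afterwards (e.g. by a rule that captured it), so the
# mutation is made irrelevant by handing the helper a fresh copy.
p = rescale_to_unit_sum!(copy(x))
```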

Implementation

An array a::T where T<:AbstractArray could be represented on the backward pass as mutable struct WrappedArray{T}; const a::T; valid::Bool; end, with valid initially true. Pullbacks for mutating operations would set valid to false, and any access to a for gradient computation would be gated by a runtime validity check. These checks would cost something even when there is no mutation, but the cost should be negligible for arrays of more than a few elements. Ideally someone more familiar with Zygote.jl could devise a system that has no runtime cost in most cases.
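The wrapper described above might look roughly like the following plain-Julia sketch. It is not Zygote code: `invalidate!`, `checked`, and `MutatedArrayError` are hypothetical names standing in for whatever a mutating operation’s pullback and a gradient rule would actually call. The `const` field syntax requires Julia 1.8 or later.

```julia
# Sketch of the proposed wrapper (const fields in mutable structs need Julia >= 1.8).
mutable struct WrappedArray{T<:AbstractArray}
    const a::T      # the wrapped array; the binding itself never changes
    valid::Bool     # set to false once a mutation has overwritten `a`
end
WrappedArray(a::AbstractArray) = WrappedArray(a, true)

struct MutatedArrayError <: Exception end
Base.showerror(io::IO, ::MutatedArrayError) =
    print(io, "array was overwritten by a mutation; its old value is needed for the gradient")

# Hypothetical: called from the pullback of a mutating operation such as fill!.
invalidate!(w::WrappedArray) = (w.valid = false; w)

# Hypothetical: every read of `a` for gradient computation goes through this check.
function checked(w::WrappedArray)
    w.valid || throw(MutatedArrayError())
    return w.a
end

# Usage: a pullback for y = sum(abs2, x) that captured the wrapper.
w = WrappedArray([1.0, 2.0, 3.0])
pb = dy -> 2 .* checked(w) .* dy
pb(1.0)         # [2.0, 4.0, 6.0] while w is still valid
invalidate!(w)  # from now on, pb(1.0) throws MutatedArrayError
```

In this sketch the check runs once per read of the whole array, not once per element, which is what keeps it cheap for arrays of more than a few elements.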

cc @MikeInnes, who has previously approached this issue
cc @ToucheSir, this proposal hopes to begin to address your comment here


Have you checked Zygote’s buffers? It seems they do exactly what you suggest.

Mutation is confusing, but this sounds similar to a proposal in ChainRules#521, where the idea is to make return fill!(similar(x), y) work by giving fill! a rule that poisons the gradient of its first argument.

But at present Zygote (1) doesn’t call the pullback at all when the function’s return value is not used (as is common in mutating code paths), and (2) seems, since that issue was opened, to have been taught to ignore ChainRules’s not-implemented mechanism.

It’s possible that this has other problems too.

I want to differentiate generic Julia code that predates, or is otherwise unconcerned with, compatibility with Zygote. For example, I can’t expect StatsBase’s mean function to use Zygote’s buffers as a way to fix Zygote.jl issue #1128, “gradient() fails on array mutation for `mean(f, x; dims)`”.

There is a small but pivotal difference between that proposal and mine. Rather than poisoning the gradient, I poison the data and set the gradient to zero (ideally a structural zero). A key objection to ChainRules#521 is “that it will cause any other rule which has captured x to give wrong answers” (@mcabbott here). By poisoning the data, any other rule that captured the array before the mutation will throw instead.
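The objection quoted above is easiest to see in a toy setting with no AD library at all. Below, `sumsq_with_pullback` is a hypothetical hand-written rule whose pullback closes over `x` by reference, as rules commonly do; mutating `x` afterwards silently changes what the pullback computes. Poisoning the data would turn this silent wrong answer into an error.

```julia
# Hypothetical hand-written rule for y = sum(x .^ 2).
# The pullback reads x when it is *called*, not when it is created.
function sumsq_with_pullback(x::AbstractVector{<:Real})
    y = sum(abs2, x)
    sumsq_pullback(dy) = 2 .* x .* dy
    return y, sumsq_pullback
end

x = [1.0, 2.0, 3.0]
y, back = sumsq_with_pullback(x)
fill!(x, 0.0)   # mutation after the rule captured x
dx = back(1.0)  # [0.0, 0.0, 0.0]; the correct gradient is [2.0, 4.0, 6.0]
```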

Yes, after thinking some more, passing something back in the gradient won’t always be enough. That’s what 521 looked at, and what I read this as saying: alter the gradient representation.

Now you seem to be saying that the object with this extra flag is present on the forward pass. In that case you can simulate the effect by having the pullback write NaN into the original a (or just restore the values it held before the mutation).
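A toy reverse pass, written by hand with no AD library, shows why restoring the values works: pullbacks run in reverse order, so the mutating step’s pullback (here a hypothetical `back_fill` closure) runs before the earlier rule that captured `x` reads it. Replacing the `copyto!` with `fill!(x, NaN)` would give the NaN variant, making the gradient `[NaN, NaN, NaN]` instead of silently wrong. The `copy(x)` in the forward pass is the extra allocation this approach costs.

```julia
# Hand-written toy reverse pass for: y = sum(abs2, x); fill!(x, 0.0); return y
function toy_reverse_pass_with_restore()
    x = [1.0, 2.0, 3.0]

    # Forward step 1: y = sum(x .^ 2). Its pullback captures x by reference.
    y = sum(abs2, x)
    back_sumsq = dy -> 2 .* x .* dy

    # Forward step 2: fill!(x, 0.0), saving the old values first (one extra copy).
    saved = copy(x)
    fill!(x, 0.0)
    back_fill = () -> copyto!(x, saved)

    # Reverse pass: pullbacks run in reverse order, so x is restored
    # before back_sumsq reads it.
    back_fill()
    return y, back_sumsq(1.0)
end

y_restore, g_restore = toy_reverse_pass_with_restore()  # 14.0, [2.0, 4.0, 6.0]
```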

What this still won’t solve is that, when the return value of fill!(xs, y) is discarded, its pullback won’t get the gradient for the new xs. In the spirit of trying to make simple cases work, you could (modulo (1) and (2) above) have the pullback return NotImplemented for both dx and dy. That could perhaps let something like x[1] = 0 work when you don’t want the gradient of x, but it won’t help for sum(Float32[x for _ in 1:3]).

I believe this is an example of what you are referring to, and yes, with my original proposal it would return [0, 0, 0]

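using Zygote

gradient([1, 2, 3]) do x
    y = x' * x                 # scalar x⋅x; its correct gradient is 2x = [2, 4, 6]
    fill!(x, zero(eltype(x)))  # overwrite x after y has been computed from it
    y
end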

I don’t fully understand your point (2) above, but regarding point (1): the standard Julia compiler only elides calls whose return values are unused if they are free of side effects. Could Zygote do something similar for pullbacks? That is, only elide the pullback of a call whose return value is ignored when that pullback has no side effects. The trivial pullback of a function marked @nograd is free of side effects.
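A toy model of the suggested rule, assuming each recorded call carries two hypothetical flags: whether its primal result was used, and whether its pullback has side effects. Zygote is source-to-source rather than tape-based, so this only illustrates the decision rule, not how Zygote is structured. Pullbacks are simplified to zero-argument closures that log their name.

```julia
# Hypothetical toy record of one forward-pass call.
struct TapeEntry
    pullback::Function      # closure run during the reverse pass
    result_used::Bool       # was the primal result used downstream?
    has_side_effects::Bool  # does the pullback do anything observable, e.g. poison data?
end

# Run pullbacks in reverse order, eliding only those that are unused *and* pure.
function reverse_pass!(tape::Vector{TapeEntry})
    ran = Int[]
    for i in reverse(eachindex(tape))
        entry = tape[i]
        if entry.result_used || entry.has_side_effects
            entry.pullback()
            push!(ran, i)
        end
    end
    return ran
end

calls = String[]
tape = [
    TapeEntry(() -> push!(calls, "y = sum(x)"), true, false),
    TapeEntry(() -> push!(calls, "fill!(x, 0)"), false, true),    # result unused, but must run
    TapeEntry(() -> push!(calls, "@nograd f(x)"), false, false),  # unused and pure: elided
]
ran = reverse_pass!(tape)  # [2, 1]; calls == ["fill!(x, 0)", "y = sum(x)"]
```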

Yes, this is closer to what I am suggesting. Unfortunately, writing NaN is not a viable option because not all arrays support NaN, and restoring the original values is not ideal because that would require every mutating operation to make a copy. I don’t want sum(Float32[x for _ in 1:3]) to allocate twice: first for the array that is populated and summed, and then for a copy of that uninitialized array to restore later. Ideally there would be some way to poison these arrays at compile time.
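The first objection is easy to check in plain Julia: `fill!` converts its value to the array’s element type, and `NaN` has no `Int` representation, so NaN poisoning fails for integer arrays.

```julia
# Writing NaN as a poison value works for floating-point arrays...
xf = [1.0, 2.0, 3.0]
fill!(xf, NaN)

# ...but not for an integer array: converting NaN to Int throws InexactError.
xi = [1, 2, 3]
nan_poison_failed = try
    fill!(xi, NaN)
    false
catch err
    err isa InexactError
end
```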

Zygote only sees untyped, pre-inference IR (think @code_lowered, or @code_warntype without the type information). It simply does not have enough information to do anything beyond the simplest, surface-level transformations.
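To see the distinction being drawn, compare the two compilation stages with Julia’s own reflection functions. This uses only standard Julia reflection, not Zygote; the claim that Zygote works at the lowered stage is taken from the comment above. `double_plus_one` is a hypothetical example function.

```julia
# A tiny function to inspect at two stages of compilation.
double_plus_one(x) = 2x + 1

# Lowered code: the untyped form that exists before type inference runs.
lowered = only(code_lowered(double_plus_one, (Float64,)))

# Typed code: only available after inference; pairs the IR with an inferred return type.
typed_ir, rettype = only(code_typed(double_plus_one, (Float64,)))
```

Printing `lowered` shows the statements with no inferred types attached, whereas `rettype` is only known once inference has run.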

That said, it looks like we can force Zygote to evaluate pullbacks for mutating functions even if the result isn’t used. See ChainRules.jl PR #521, “rrule for fill!” by CarloLucibello.

One issue with doing this via a wrapper is that there are hundreds of existing ChainRules rrules (spread over numerous packages) that take array arguments. Making all of them aware of the wrapper type and able to check for validity/poison would be a massive effort, hence the exploration of less disruptive alternatives like writing NaNs. It’s quite possible this would require major changes to ChainRulesCore’s interface, as discussed in ChainRulesCore.jl issues #452 (“Ability to specify different rules based on what combinations of inputs are actually being used”) and #242 (“mutating calls”), among others.
