Question about writing ChainRules rules for array-mutating functions

I want to build two functions for replacing values in a vector and an array, as shown below:

function mutate_vec(vec::AbstractVector{T}, new_value::T, idx::Int) where {T}
    vec[idx] = new_value
    return vec
end

function mutate_arr(arr::AbstractArray{T}, new_value::AbstractArray{T}, idx::Tuple) where {T}
    arr[idx..., :] .= new_value
    return arr
end

Since Zygote generally does not support array mutation, I wrote rrule definitions for both functions, following the "Which functions need rules?" page of the ChainRules documentation:

function ChainRules.rrule(::typeof(mutate_vec), vec::AbstractVector{T}, new_value::T, idx::Int) where {T}
    vec = mutate_vec(vec, new_value, idx)
    function mutate_vec_pullback(ȳ)
        return NoTangent(), ones(T, size(vec)), T(1.0), NoTangent()
    end
    return vec, mutate_vec_pullback
end

function ChainRules.rrule(::typeof(mutate_arr), arr::AbstractArray{T}, new_value::AbstractArray{T}, idx::Tuple) where {T}
    arr = mutate_arr(arr, new_value, idx)
    function mutate_arr_pullback(ȳ)
        return NoTangent(), ones(T, size(arr)), ones(T, size(new_value)), NoTangent()
    end
    return arr, mutate_arr_pullback
end

Due to my insufficient understanding of gradients, I am not sure if the rules I wrote are correct, so I hope someone can give me advice.

Hi @chooron!

This could be more explicit in the ChainRules docs, but you must distinguish between two kinds of mutation:

  1. mutation of objects created inside the function
  2. mutation of objects passed as arguments to the function

The first kind of mutation is exactly what ChainRules allows you to solve with custom rules. However, the second kind is still experimental, and you need to take a look at this documentation page to use it.
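To illustrate the first kind, here is a minimal sketch, not a definitive implementation. It assumes a hypothetical non-mutating variant, `replace_at`, which copies its input and mutates only the copy, and it assumes the incoming cotangent `ȳ` is a plain dense array. The rule is defined through ChainRulesCore, the lightweight package that provides `rrule`, `NoTangent`, and `unthunk`. Note that the pullback has to propagate `ȳ` rather than return arrays of ones: the overwritten slot no longer depends on `vec`, and `new_value` only influences that one slot.

```julia
using ChainRulesCore

# Hypothetical non-mutating variant: the copy is created inside the function,
# so the caller's `vec` is never modified.
function replace_at(vec::AbstractVector{T}, new_value::T, idx::Int) where {T}
    out = copy(vec)
    out[idx] = new_value
    return out
end

function ChainRulesCore.rrule(::typeof(replace_at), vec::AbstractVector{T},
                              new_value::T, idx::Int) where {T}
    out = replace_at(vec, new_value, idx)
    function replace_at_pullback(ȳ)
        dy = unthunk(ȳ)              # materialize the cotangent if it is a thunk
        dvec = copy(dy)              # every untouched slot passes its cotangent through
        dvec[idx] = zero(eltype(dy)) # the overwritten slot does not depend on `vec`
        dnew = dy[idx]               # `new_value` only affects out[idx]
        return NoTangent(), dvec, dnew, NoTangent()
    end
    return out, replace_at_pullback
end

# Usage: call the rule directly and inspect the pullback.
out, pb = rrule(replace_at, [1.0, 2.0, 3.0], 10.0, 2)
_, dvec, dnew, _ = pb([0.5, 2.0, 3.0])
# out == [1.0, 10.0, 3.0], dvec == [0.5, 0.0, 3.0], dnew == 2.0
```

The array version would follow the same pattern: zero the overwritten slice in the cotangent for `arr`, and return that slice of `ȳ` as the cotangent for `new_value`.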

Sorry, I didn’t fully read this document, so I missed this part. Thank you very much for your suggestion. I will revise my code according to the document. :blush:
