I want to build two functions for replacing values in a vector and an array, as shown below:
function mutate_vec(vec::AbstractVector{T}, new_value::T, idx::Int) where {T}
vec[idx] = new_value
vec
end
function mutate_arr(arr::AbstractArray{T}, new_value::AbstractArray{T}, idx::Tuple) where {T}
arr[idx..., :] .= new_value
arr
end
But it is known that Zygote usually does not support mutating arrays, so I wrote rrule rules for these two functions by following the guidelines from Which functions need rules? · ChainRules, as shown below:
function ChainRules.rrule(::typeof(mutate_vec), vec::AbstractVector{T}, new_value::T, idx::Int) where {T}
vec = mutate_vec(vec, new_value, idx)
function mutate_vec_pullback(ȳ)
return NoTangent(), ones(T, size(vec)), T(1.0), NoTangent()
end
return vec, mutate_vec_pullback
end
function ChainRules.rrule(::typeof(mutate_arr), arr::AbstractArray{T}, new_value::AbstractArray{T}, idx::Tuple) where {T}
arr = mutate_arr(arr, new_value, idx)
function mutate_arr_pullback(ȳ)
return NoTangent(), ones(T, size(arr)), ones(T, size(new_value)), NoTangent()
end
return arr, mutate_arr_pullback
end
Due to my insufficient understanding of gradients, I am not sure if the rules I wrote are correct, so I hope someone can give me advice.