Using ChainRules rrule for mutating function to work with Zygote

Hi there,

I would like to add an rrule for a function that mutates an array, so that I can differentiate through it.
The following functions work with ForwardDiff and ReverseDiff, but give me a mutation error with Zygote:


using ChainRulesCore, Zygote, ForwardDiff, ReverseDiff
"""
    vec_to_mat(x_vec::AbstractVector{T}) where {T<:Real}

Return a newly allocated square matrix that has the entries of `x_vec` on its main
diagonal and zeros everywhere else.
"""
function vec_to_mat(x_vec::AbstractVector{T}) where {T<:Real}
    n = length(x_vec)
    diag_mat = zeros(T, n, n)
    # Copy the vector onto the diagonal one entry at a time (element-wise `setindex!`).
    @inbounds @simd for i in eachindex(x_vec)
        diag_mat[i, i] = x_vec[i]
    end
    return diag_mat
end
# Example input used throughout: a 3-element Float64 vector.
x_vec = [1., 2., 3.]
x_m = vec_to_mat(x_vec) # 3×3 Matrix whose diagonal elements equal x_vec

"""
    my_AD_function(x_vec::AbstractVector{T}) where {T<:Real}

Scalar test function for the AD backends: expand `x_vec` with `vec_to_mat` and return
the sum of all entries of the resulting matrix.
"""
function my_AD_function(x_vec::AbstractVector{T}) where {T<:Real}
    return sum(vec_to_mat(x_vec))
end
my_AD_function(x_vec) # 6.0, the sum of the diagonal entries
ForwardDiff.gradient(my_AD_function, x_vec) # works: Vector{Float64} with 3 elements
ReverseDiff.gradient(my_AD_function, x_vec) # works: Vector{Float64} with 3 elements
Zygote.gradient(my_AD_function, x_vec)[1] # errors: "Mutating arrays is not supported -- called setindex!..."

I would like the gradient only with respect to the 3 input parameters, and the function should not change the gradient at all. Here is my attempt at an rrule, but Zygote still fails with it:

# Attempted reverse rule for `vec_to_mat`: returns the primal result together with a
# pullback closure.
# NOTE(review): the pullback passes the incoming cotangent `x` through `vec_to_mat` again
# rather than mapping it back to something shaped like the input `v`; presumably it should
# extract the diagonal of `x` instead -- confirm.
function ChainRulesCore.rrule(::typeof(vec_to_mat), v::AbstractVector{T}) where {T<:Real}
    L = vec_to_mat(v)  # forward pass: build the matrix from `v`
    pullback_vec_to_mat(x) = NoTangent(), vec_to_mat(x)  # (tangent for the function itself, tangent for `v`)
    return L, pullback_vec_to_mat
end
Zygote.gradient(my_AD_function, x_vec) # still errors: "Mutating arrays is not supported -- called setindex!..."

I can get the example working by relaxing the type constraints on the mutating function (which I would rather not do), but then I get the gradient with respect to the (larger) matrix, not with respect to the input vector:

# Untyped variant: build a square matrix of side `size(x_vec, 1)` and copy the entries of
# `x_vec` onto its main diagonal; all other entries are zero.
function vec_to_mat(x_vec)
    n = size(x_vec, 1)
    out = zeros(eltype(x_vec), n, n)
    @inbounds @simd for k in eachindex(x_vec)
        out[k, k] = x_vec[k]
    end
    return out
end
# Scalar objective: the sum of every entry of the matrix produced by `vec_to_mat`.
function my_AD_function(x_vec::AbstractVector{T}) where {T<:Real}
    return sum(vec_to_mat(x_vec))
end
Zygote.gradient(my_AD_function, x_vec) # 3×3 Matrix{Float64}, not a Vector{Float64} with 3 elements

Question: How can I add the rrule such that I only get the gradient with respect to the 3-dimensional input vector, not with respect to the enlarged 3×3 matrix?

I think this should work; note that the pullback shouldn’t call the same function again, but should instead perform the reverse operation:

using LinearAlgebra: diag  # `diag` is not exported by Base, so it must be imported

# Reverse rule for `vec_to_mat`.
# Forward: `v` is placed on the diagonal of a square matrix.
# Backward: only the diagonal entries of the output depend on `v`, so the tangent for `v`
# is the diagonal of the incoming matrix cotangent.
function ChainRulesCore.rrule(::typeof(vec_to_mat), v::AbstractVector{T}) where {T<:Real}
    L = vec_to_mat(v)  # forward pass: from v to matrix
    # backward pass: from matrix to vector; `unthunk` is a no-op unless the caller
    # hands over a lazy (thunked) cotangent
    pullback_vec_to_mat(x) = (NoTangent(), diag(unthunk(x)))
    return L, pullback_vec_to_mat
end

If this is really the function you want, then note that there is a built-in version:

# NOTE(review): `diagm` comes from LinearAlgebra and `gradient` from Zygote; this assumes
# both were loaded earlier in the REPL session -- confirm.
julia> gradient(x -> sum(abs2, diagm(x)), [1,2,3])
([2, 4, 6],)
2 Likes

Thanks a lot, that worked! Is there also a way to define an in-place version that does not return a new matrix at each call? This works:

using ChainRulesCore, Zygote, ForwardDiff, ReverseDiff

# Allocate an n×n zero matrix (n = length of `x_vec`) and write `x_vec` onto its diagonal.
function vec_to_mat(x_vec::AbstractVector{T}) where {T<:Real}
    n = length(x_vec)
    result = zeros(T, n, n)
    @inbounds @simd for j in eachindex(x_vec)
        result[j, j] = x_vec[j]
    end
    return result
end
"""
    mat_to_vec(x_mat::AbstractMatrix{T}) where {T<:Real}

Return a new vector of length `size(x_mat, 1)` whose `i`-th entry is `x_mat[i, i]`.

# Throws
- `BoundsError`: if `x_mat` has fewer columns than rows, so that some `x_mat[i, i]`
  does not exist.
"""
function mat_to_vec(x_mat::AbstractMatrix{T}) where {T<:Real}
    K = size(x_mat, 1)
    x_vec = zeros(eltype(x_mat), K)
    idx = eachindex(x_vec)
    # One up-front check of the whole index range; without it the `@inbounds` loop below
    # would silently read out of bounds for a matrix with more rows than columns.
    checkbounds(x_mat, idx, idx)
    @inbounds @simd for iter in idx
        x_vec[iter] = x_mat[iter, iter]
    end
    return x_vec
end
# Reverse rule for `vec_to_mat`: the primal output is the matrix built from `v`, and the
# pullback turns a matrix cotangent back into a vector via `mat_to_vec`.
function ChainRulesCore.rrule(::typeof(vec_to_mat), v::AbstractVector{T}) where {T<:Real}
    function pullback_vec_to_mat(mat)
        # First slot: tangent for the function itself; second slot: tangent for `v`.
        return NoTangent(), mat_to_vec(mat)
    end
    return vec_to_mat(v), pullback_vec_to_mat
end
x_vec = [1., 2., 3.]
x_m = vec_to_mat(x_vec) # 3×3 Matrix whose diagonal elements equal x_vec
mat_to_vec(x_m) # reads the diagonal back out as a 3-element vector
# Scalar objective: the sum of the squared entries of the matrix built by `vec_to_mat`.
function my_func(x_vec::AbstractVector{T}) where {T<:Real}
    return sum(vec_to_mat(x_vec) .^ 2)
end
my_func(x_vec)
ForwardDiff.gradient(my_func, x_vec) # Vector{Float64} with 3 elements: [2.0, 4.0, 6.0]
ReverseDiff.gradient(my_func, x_vec) # Vector{Float64} with 3 elements: [2.0, 4.0, 6.0]
Zygote.gradient(my_func, x_vec)[1] # Vector{Float64} with 3 elements: [2.0, 4.0, 6.0]

but the in-place version unfortunately fails:

"""
    vec_to_mat!(x_mat::AbstractMatrix{T}, x_vec::AbstractVector{T}) where {T<:Real}

Write the entries of `x_vec` onto the main diagonal of `x_mat` in place and return
`x_mat`. Entries of `x_mat` that are not on that diagonal are left untouched.

# Throws
- `BoundsError`: if `x_mat[i, i]` does not exist for some index `i` of `x_vec`.
"""
function vec_to_mat!(x_mat::AbstractMatrix{T}, x_vec::AbstractVector{T}) where {T<:Real}
    idx = eachindex(x_vec)
    # One up-front check of the whole index range; without it the `@inbounds` loop below
    # would write outside `x_mat` whenever the matrix is smaller than the vector.
    checkbounds(x_mat, idx, idx)
    @inbounds @simd for iter in idx
        x_mat[iter, iter] = x_vec[iter]
    end
    return x_mat
end
"""
    mat_to_vec!(x_vec::AbstractVector{T}, x_mat::AbstractMatrix{T}) where {T<:Real}

Overwrite each entry `x_vec[i]` with the diagonal entry `x_mat[i, i]` and return `x_vec`.

# Throws
- `BoundsError`: if `x_mat[i, i]` does not exist for some index `i` of `x_vec`.
"""
function mat_to_vec!(x_vec::AbstractVector{T}, x_mat::AbstractMatrix{T}) where {T<:Real}
    idx = eachindex(x_vec)
    # One up-front check of the whole index range; without it the `@inbounds` loop below
    # would read outside `x_mat` whenever the matrix is smaller than the vector.
    checkbounds(x_mat, idx, idx)
    @inbounds @simd for iter in idx
        x_vec[iter] = x_mat[iter, iter]
    end
    return x_vec
end
# Attempted reverse rule for the in-place `vec_to_mat!`.
# NOTE(review): this method takes `(v::AbstractVector, x_mat::AbstractMatrix)`, but
# `vec_to_mat!` is defined and called with the matrix first, `vec_to_mat!(x_mat, x_vec)`,
# so this signature does not match those calls -- presumably why the rule is never used.
# NOTE(review): the pullback returns two values although the primal takes two arguments;
# presumably one tangent per argument is expected in addition to the one for the function
# itself -- confirm against the ChainRulesCore docs.
function ChainRulesCore.rrule(::typeof(vec_to_mat!), v::AbstractVector{T}, x_mat::AbstractMatrix{T}) where {T<:Real}
    L = vec_to_mat!(x_mat, v)  # forward pass: writes `v` onto the diagonal of `x_mat`
    # NOTE(review): `mat_to_vec!` is declared as `(x_vec, x_mat)` and writes into its first
    # argument; here the cotangent `mat` is passed first and `v` second.
    pullback_vec_to_mat(mat) = NoTangent(), mat_to_vec!(mat, v)
    return L, pullback_vec_to_mat
end
x_vec = [1., 2., 3.]
x_m = zeros(3, 3) # preallocated 3×3 buffer
vec_to_mat!(x_m, x_vec) # writes x_vec onto the diagonal of x_m in place
mat_to_vec!(x_vec, x_m) # copies the diagonal of x_m back into x_vec in place

# Scalar objective using the in-place helper: allocate a zero matrix, let `vec_to_mat!`
# fill its diagonal from `x_vec`, then sum the squared entries of that matrix.
function my_func2(x_vec::AbstractVector{T}) where {T<:Real}
    n = length(x_vec)
    buffer = zeros(T, n, n)
    vec_to_mat!(buffer, x_vec)
    return sum(buffer .^ 2)
end
my_func2(x_vec)
ForwardDiff.gradient(my_func2, x_vec) # Vector{Float64} with 3 elements: [2.0, 4.0, 6.0]
ReverseDiff.gradient(my_func2, x_vec) # Vector{Float64} with 3 elements
Zygote.gradient(my_func2, x_vec)[1] # errors: "Mutating arrays is not supported -- called setindex!(::Matrix{Float64}..."