Hi there,
I would like to add an rrule for a function that fills an array by mutation, and I want to differentiate with respect to its input.
The following functions work with ForwardDiff and ReverseDiff, but give me a mutation error with Zygote:
using ChainRulesCore, Zygote, ForwardDiff, ReverseDiff
function vec_to_mat(x_vec::AbstractVector{T}) where {T<:Real}
    K = size(x_vec, 1)
    x_mat = zeros(eltype(x_vec), K, K)
    @inbounds @simd for iter in eachindex(x_vec)
        x_mat[iter, iter] = x_vec[iter]
    end
    return x_mat
end
x_vec = [1., 2., 3.]
x_m = vec_to_mat(x_vec) #3*3 Matrix with diagonal elements == x_vec
function my_AD_function(x_vec::AbstractVector{T}) where {T<:Real}
    x_mat = vec_to_mat(x_vec)
    return sum(x_mat)
end
my_AD_function(x_vec) #6
ForwardDiff.gradient(my_AD_function, x_vec) # Vector{Float64} with 3 elements
ReverseDiff.gradient(my_AD_function, x_vec) # Vector{Float64} with 3 elements
Zygote.gradient(my_AD_function, x_vec)[1] # Mutating arrays is not supported -- called setindex!...
I would like the gradients only with respect to the 3 input parameters, and the function itself should not change the gradients at all. Here is my attempt at an rrule, but Zygote still fails with it:
function ChainRulesCore.rrule(::typeof(vec_to_mat), v::AbstractVector{T}) where {T<:Real}
    L = vec_to_mat(v)
    pullback_vec_to_mat(x) = NoTangent(), vec_to_mat(x)
    return L, pullback_vec_to_mat
end
Zygote.gradient(my_AD_function, x_vec) # Mutating arrays is not supported -- called setindex!...
I can get the example working by relaxing the type constraints on the mutating function (which I do not really want to do), but then I get the gradients with respect to the (larger) matrix, not with respect to the input vector:
function vec_to_mat(x_vec)
    K = size(x_vec, 1)
    x_mat = zeros(eltype(x_vec), K, K)
    @inbounds @simd for iter in eachindex(x_vec)
        x_mat[iter, iter] = x_vec[iter]
    end
    return x_mat
end
function my_AD_function(x_vec::AbstractVector{T}) where {T<:Real}
    x_mat = vec_to_mat(x_vec)
    return sum(x_mat)
end
Zygote.gradient(my_AD_function, x_vec) # 3×3 Matrix{Float64}, not Vector{Float64} with 3 elements
Question: How can I add the rrule so that I only get the gradient with respect to the 3-dimensional input vector, not with respect to the enlarged 3×3 matrix?
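To make clear what I am after, here is a minimal sketch of the shape I think the rule should have. It rests on assumptions I have not confirmed: the pullback receives the K×K cotangent of the output matrix, only the diagonal entries of the output depend on the input, so the tangent for the input should be the diagonal of that cotangent. The names `vec_to_mat_pullback`, `x_mat_bar` and `v_bar` are made up for illustration, and I am assuming a fresh session in which the rule is defined before the first `Zygote.gradient` call.

```julia
using ChainRulesCore, LinearAlgebra, Zygote

# Same function as above: builds a K×K matrix with x_vec on the diagonal.
function vec_to_mat(x_vec::AbstractVector{T}) where {T<:Real}
    K = size(x_vec, 1)
    x_mat = zeros(eltype(x_vec), K, K)
    @inbounds @simd for iter in eachindex(x_vec)
        x_mat[iter, iter] = x_vec[iter]
    end
    return x_mat
end

# Sketch of a rule (an assumption, not a confirmed solution): the pullback
# takes the K×K cotangent of the output and returns one tangent for the
# function itself and one for the input vector v.
function ChainRulesCore.rrule(::typeof(vec_to_mat), v::AbstractVector{T}) where {T<:Real}
    x_mat = vec_to_mat(v)
    function vec_to_mat_pullback(x_mat_bar)
        # Only the diagonal entries of the output depend on v,
        # so the tangent for v is the diagonal of the cotangent.
        v_bar = diag(unthunk(x_mat_bar))
        return NoTangent(), v_bar
    end
    return x_mat, vec_to_mat_pullback
end

my_AD_function(x_vec) = sum(vec_to_mat(x_vec))

x_vec = [1.0, 2.0, 3.0]
grad = Zygote.gradient(my_AD_function, x_vec)[1]  # one entry per input element
```

With this sketch the pullback maps a matrix back to a vector of the input's length, whereas my attempt above called `vec_to_mat` again inside the pullback and therefore returned a matrix.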