ChainRules for functions that "destroy" their input arguments

Some functions use their inputs as workspace, and as a result their documented behaviour is that they destroy the input. For example, svd! works this way.
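As a small illustration using only the LinearAlgebra standard library (the matrix is arbitrary), the factorization returned by svd! is valid, while the input matrix should no longer be trusted to hold the original data:

```julia
using LinearAlgebra

A = [4.0 1.0; 2.0 3.0]  # arbitrary example matrix
A_orig = copy(A)        # keep the input, since svd! may overwrite A

F = svd!(A)  # A is used as workspace; its contents are now unspecified

# The returned factorization still reconstructs the original input...
@assert F.U * Diagonal(F.S) * F.Vt ≈ A_orig
# ...but A itself should not be read again: it generally no longer equals A_orig
```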
Mutating functions are not supported in general, but in this case it seems fair to assume that we can define a rule like the following (example code):

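using LinearAlgebra, ChainRulesCore, ChainRules  # ChainRules provides the rrule for svd

function ChainRulesCore.rrule(::typeof(svd!), A::AbstractMatrix)
    # Call the non-mutating svd instead; A is left intact, which is allowed since
    # callers of svd! must not rely on the contents of A afterwards
    F, pullback_svd = rrule(svd, A)
    pullback_svd!(ΔF) = pullback_svd(ΔF)  # svd's cotangent for A applies unchanged
    return F, pullback_svd!
end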

In effect, this replaces svd! with svd during differentiation. Since callers of svd! may not rely on the contents of A afterwards, this should leave the surrounding code unaffected.

With this setting in mind, I have two questions:
Firstly, is there any reason not to do this? It does feel prone to silent errors that are hard to debug.
Secondly, for rrules like this, test_rrule from ChainRulesTestUtils will fail, because it assumes that the primal inputs are left unchanged. I have a local version, test_mutating_rrule, that works around this by explicitly copying the inputs, and I can file a pull request for it if wanted.

Any thoughts or comments are greatly appreciated.

I am asking because I am defining rules for TensorKit, where similar rules are needed.

This sounds fine, I think, so long as you are sure that no program tries to use the values written into A, which seems a safe assumption if they are officially garbage.

For testing, one way would be to define rrule(::typeof(svd!∘copy), x) = rrule(svd!, x) and test that. ChainRules itself has some examples like this: e.g. collect∘eachslice fits what CRTU (ChainRulesTestUtils) expects, namely an array output, even though eachslice returned a non-array iterator (on Julia 1.6 or so).
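A minimal sketch of this trick, assuming ChainRulesCore and ChainRulesTestUtils are installed. To keep it small and independent of how CRTU handles SVD-valued outputs, it uses a hypothetical destroying function, sumsq_destroy!, as a stand-in for svd!; the name and its behaviour are purely illustrative:

```julia
using ChainRulesCore, ChainRulesTestUtils

# Hypothetical stand-in for a destroying function like svd!: returns the sum of
# squares of x, using x itself as scratch space (its contents are garbage afterwards)
function sumsq_destroy!(x::AbstractVector{<:Real})
    x .= abs2.(x)
    return sum(x)
end

# Rule that never touches x: it evaluates the non-mutating equivalent instead
function ChainRulesCore.rrule(::typeof(sumsq_destroy!), x::AbstractVector{<:Real})
    y = sum(abs2, x)
    g = 2 .* x  # computed eagerly, so the pullback does not read x later
    sumsq_pullback(ȳ) = (NoTangent(), unthunk(ȳ) .* g)
    return y, sumsq_pullback
end

# Composing with copy yields a function that leaves its argument intact, which is
# what test_rrule assumes; its rule simply forwards to the one above
ChainRulesCore.rrule(::typeof(sumsq_destroy! ∘ copy), x) = rrule(sumsq_destroy!, x)

test_rrule(sumsq_destroy! ∘ copy, [1.0, 2.0, 3.0])
```

Without ∘ copy, the primal evaluation inside test_rrule would overwrite the test input, so the later checks would no longer run against the original values.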
