How to support `view` in Flux.Tracker?

I have a batched `gemm` and a batched `tr`:

Now I want to implement the gradient as:

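```julia
using Flux.Tracker
using Flux.Tracker: TrackedArray, track, @grad, data
using LinearAlgebra
import Batched

# Batched transpose: swap the first two dims of every slice
# (assumes the batch dimension is the third, as in `size(X, 3)` below).
btranspose(A) = permutedims(A, (2, 1, 3))

Batched.bgemm(A::TrackedArray, B::TrackedArray) = track(Batched.bgemm, A, B)

# For C[:, :, k] = A[:, :, k] * B[:, :, k]:  dA = dC * B',  dB = A' * dC
@grad Batched.bgemm(A, B) = (Batched.bgemm(data(A), data(B)),
    grad -> (Batched.bgemm(data(grad), btranspose(data(B))),
             Batched.bgemm(btranspose(data(A)), data(grad))))

Batched.btr(A::TrackedArray) = track(Batched.btr, A)
# d tr(X) / dX is the identity, scaled by each batch's incoming gradient.
@grad Batched.btr(X) = (Batched.btr(data(X)),
    grad -> (Batched.ScalarIdentity{size(X, 3), size(X, 1)}(grad),))
```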
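For reference, these are the per-slice adjoints the two `@grad` rules are meant to implement, writing the batch index as $k$ and incoming gradients with a bar:

$$\bar A_k = \bar C_k B_k^{\top}, \qquad \bar B_k = A_k^{\top} \bar C_k \qquad \text{for } C_k = A_k B_k,$$

$$\bar X_k = \bar t_k \, I \qquad \text{for } t_k = \operatorname{tr}(X_k).$$

In particular, the gradient with respect to `B` involves `A`, not `B`.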

And it works fine for `TrackedArray`, but when I call `view` on a `TrackedArray` I get a `SubArray` of that `TrackedArray` rather than a `SubArray` of an `Array`, so I can't pass it to `BLAS.gemm!`.

Any idea how I can fix this?

Well, one possible solution might be to define a `gemm!` function that converts those `SubArray`s of `TrackedArray` into `SubArray`s of `Array`, but is there a more elegant approach?
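A minimal sketch of that workaround in plain Julia: the hypothetical helper `rewrap_view` rebuilds a view on top of an unwrapped parent using Base's `parent` and `parentindices`, and the hypothetical `gemm_unwrapped` hands the result to `BLAS.gemm!`. With Flux.Tracker the unwrapping function would presumably be `Tracker.data`; it is passed in as an argument here so the sketch runs without Flux (`identity` stands in for it in a quick check). Since `data` discards the tape, this only fits inside the forward pass of a `@grad` definition, not as a replacement for tracking.

```julia
using LinearAlgebra

# Hypothetical helper: rebuild a view on top of `unwrap(parent(x))`, keeping the
# same indices. With Flux.Tracker, `unwrap` would presumably be `Tracker.data`.
rewrap_view(unwrap, x::SubArray) = view(unwrap(parent(x)), parentindices(x)...)

# Anything that is not a view is simply unwrapped.
rewrap_view(unwrap, x::AbstractArray) = unwrap(x)

# Hypothetical helper: compute A * B with BLAS.gemm! after unwrapping both inputs.
# Assumes both inputs unwrap to strided Float32/Float64 (or complex) matrices.
function gemm_unwrapped(unwrap, A::AbstractMatrix, B::AbstractMatrix)
    a = rewrap_view(unwrap, A)
    b = rewrap_view(unwrap, B)
    C = similar(a, size(a, 1), size(b, 2))
    # C = 1 * a * b + 0 * C, computed in place by BLAS
    BLAS.gemm!('N', 'N', one(eltype(C)), a, b, zero(eltype(C)), C)
    return C
end
```

For example, `gemm_unwrapped(identity, view(X, :, :, 1), view(Y, :, :, 2))` should match `X[:, :, 1] * Y[:, :, 2]` for `Float64` arrays `X` and `Y` with compatible sizes.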