I have a batched gemm and a batched tr here: https://github.com/Roger-luo/Batched.jl/blob/master/src/gemm.jl

Now I want to implement their gradients like this:
```julia
using Flux.Tracker
using Flux.Tracker: TrackedArray, track, @grad, data
using LinearAlgebra
import Batched

Batched.bgemm(A::TrackedArray, B::TrackedArray) = track(Batched.bgemm, A, B)

# per-batch pullback of C = A * B: ∂L/∂A = ∂L/∂C * Bᵀ, ∂L/∂B = Aᵀ * ∂L/∂C
@grad Batched.bgemm(A, B) = Batched.bgemm(data(A), data(B)), grad -> (Batched.bgemm(data(grad), transpose(data(B))), Batched.bgemm(transpose(data(A)), data(grad)))

Batched.btr(A::TrackedArray) = track(Batched.btr, A)

# per-batch pullback of tr(X): the identity scaled by the incoming gradient
@grad Batched.btr(X) = Batched.btr(data(X)), grad -> (Batched.ScalarIdentity{size(X, 3), size(X, 1)}(grad),)
```
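Here is roughly how I use it, in case the shapes matter (the batch dimension being last and these particular sizes are just for illustration):

```julia
using Flux
using Flux.Tracker: back!

A = Flux.param(rand(3, 4, 10))   # 10 batches of 3×4
B = Flux.param(rand(4, 2, 10))   # 10 batches of 4×2

C = Batched.bgemm(A, B)          # tracked result, 10 batches of 3×2
back!(sum(C))                    # runs through the custom @grad rule above
Flux.Tracker.grad(A)             # gradient accumulated by that rule
```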
And it works fine for a `TrackedArray`, but when I use `view` on a `TrackedArray`, I get a `SubArray` of this `TrackedArray` rather than a `SubArray` of an `Array`, and then I can't use `BLAS.gemm!`. Any idea how I can fix this?
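For concreteness, this is roughly the case that fails (continuing from the setup above, sizes made up):

```julia
VA = view(A, :, :, 1:5)   # SubArray whose parent is the TrackedArray,
VB = view(B, :, :, 1:5)   # not a TrackedArray (or a SubArray of Array)

# this no longer hits the TrackedArray method above, and the underlying
# storage is still tracked, so it never reaches a plain Array for BLAS.gemm!
Batched.bgemm(VA, VB)
```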