Hi,
I have the following problem. I need to define a function, say f(a, b), where a and b are matrices. I want to define a gradient for this function for Flux, so if one of the arguments is of type TrackedArray, the call needs to be forwarded to Flux.Tracked.track(f, a, b).
Since TrackedArray is a subtype of AbstractArray, I do not know how to make the function dispatch normally when neither argument is a TrackedArray, while forwarding the call to track when at least one of them is a TrackedArray.
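To make the dispatch problem concrete, here is a minimal sketch. MyTracked and track are hypothetical stand-ins for TrackedArray and Flux.Tracked.track so that it runs without Flux: the real track records the call for differentiation, while this one only returns a marker. With one tracked method per argument position, a call where both arguments are tracked matches two methods equally well, so Julia reports an ambiguity.

```julia
# Hypothetical stand-in for TrackedArray: like the real type, it is a subtype of AbstractArray.
struct MyTracked{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(x::MyTracked) = size(x.data)
Base.getindex(x::MyTracked{T,N}, i::Vararg{Int,N}) where {T,N} = x.data[i...]

# Hypothetical stand-in for Flux.Tracked.track: only returns a marker here.
track(g, a, b) = :tracked

f(a::AbstractMatrix, b::AbstractMatrix) = a * b      # plain path, neither argument tracked
f(a::MyTracked, b::AbstractMatrix) = track(f, a, b)  # first argument tracked
f(a::AbstractMatrix, b::MyTracked) = track(f, a, b)  # second argument tracked
# f(MyTracked(A), MyTracked(B)) matches the last two methods equally well,
# so the call throws a MethodError reporting the ambiguity.
```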
I ended up defining all four methods (in a loop), but I wonder if there is a neater solution. I could also imagine using a generated function.
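The four-method loop mentioned above could look roughly like the following sketch, again with the hypothetical stand-ins MyTracked and track in place of TrackedArray and Flux.Tracked.track. Defining the (tracked, tracked) method explicitly is what removes the ambiguity.

```julia
# Hypothetical stand-in for TrackedArray (a subtype of AbstractArray).
struct MyTracked{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(x::MyTracked) = size(x.data)
Base.getindex(x::MyTracked{T,N}, i::Vararg{Int,N}) where {T,N} = x.data[i...]

# Hypothetical stand-in for Flux.Tracked.track: only returns a marker here.
track(g, a, b) = :tracked

# Generate all four combinations of argument types; @eval splices each
# type name into a method definition.
for TA in (:AbstractMatrix, :MyTracked), TB in (:AbstractMatrix, :MyTracked)
    if TA === :AbstractMatrix && TB === :AbstractMatrix
        @eval f(a::$TA, b::$TB) = a * b           # neither argument tracked
    else
        @eval f(a::$TA, b::$TB) = track(f, a, b)  # at least one argument tracked
    end
end
```

Note that this approach needs 2^n methods for n array arguments, so it grows quickly beyond two arguments.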
Thanks for any opinions.