How to create tracked `cumsum` using Flux.jl?

I wanted to create a cumsum of tracked values. How do I do that?

I tried the following, which failed:

W = param(rand(12,1))
cumsum(W.data, dims = 1) # works
cumsum(W, dims = 1) # this fails

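I got it to work using matrix multiplication, but that’s not ideal as it’s more cumbersome:

mw = Matrix{Float64}(undef, 12, 12)
mw .= 1
# zero out the entries above the diagonal, leaving a lower-triangular matrix of ones
for i = 1:12
    for j = i+1:12
        mw[i, j] = 0
    end
end
mw*W # this works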
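The reason the matrix product works can be checked on plain, untracked data: multiplying by a lower-triangular matrix of ones sums every entry up to and including the current row. The sketch below uses `tril` from the `LinearAlgebra` standard library instead of the explicit loops; the names `n`, `W`, and `L` are only illustrative.

```julia
using LinearAlgebra

n = 12
W = rand(n, 1)          # plain (untracked) data, for illustration only
L = tril(ones(n, n))    # ones on and below the diagonal, zeros above

# Row i of L * W is W[1] + ... + W[i], i.e. the cumulative sum along dims = 1
@assert L * W ≈ cumsum(W, dims = 1)
```

The cost is an n×n matrix and O(n²) work per call, which is probably why a direct `cumsum` gradient is preferable.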

For the case of a vector, I think this is right:

using Flux
using Flux.Tracker: TrackedVector, @grad, track

Base.cumsum(x::TrackedVector) = track(cumsum, x)

@grad function cumsum(x::TrackedVector)
    # forward pass on the untracked data; the pullback is a reversed cumulative sum of Δ
    cumsum(x.data), Δ -> ( reverse(cumsum(reverse(Δ))) ,)
end

Tracker.gradcheck(x -> sum(sin, cumsum(x)), randn(3))
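To see why the pullback is `reverse(cumsum(reverse(Δ)))`: for a vector, `cumsum(x)` equals `L * x` with `L` lower-triangular ones, so the gradient with respect to `x` is `L' * Δ`, which sums `Δ` from the current index to the end. Here is a minimal sketch in plain Julia (no Flux needed) that checks this identity and compares it with central finite differences. The helper names `cumsum_pullback` and `basis` are made up for the illustration.

```julia
using LinearAlgebra

# Hypothetical helper: the same expression used as the pullback above
cumsum_pullback(Δ) = reverse(cumsum(reverse(Δ)))

n = 5
L = tril(ones(n, n))            # cumsum(x) == L * x for a vector x
Δ = randn(n)
@assert cumsum_pullback(Δ) ≈ L' * Δ

# Finite-difference check on f(x) = sum(sin, cumsum(x)),
# where the incoming sensitivity is Δ = cos.(cumsum(x))
f(x) = sum(sin, cumsum(x))
x = randn(n)
analytic = cumsum_pullback(cos.(cumsum(x)))

h = 1e-6
basis(i) = Float64.((1:n) .== i)    # i-th unit vector
numeric = [(f(x .+ h .* basis(i)) - f(x .- h .* basis(i))) / (2h) for i in 1:n]

@assert maximum(abs, analytic .- numeric) < 1e-6
```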

I think I am starting to get it.

But I can’t seem to make it work for the case of an array. See the code below:

Base.cumsum(x::TrackedArray; dims=1) = track(x->cumsum(x, dims = dims), x)

@grad function cumsum(x::TrackedArray; dims=1)
    cumsum(x.data, dims=dims), Δ -> ( reverse(cumsum(reverse(Δ), dims=dims)) ,)
end

Tracker.gradcheck(x -> sum(cumsum(x,  dims=1)), randn(3,1))

There’s a trick to get around keyword arguments, and you need to tell reverse what dimension too. I think this is correct now… if you agree, perhaps worth making a Flux PR?

using Flux
using Flux.Tracker: TrackedArray, @grad, track

Base.cumsum(x::TrackedArray; dims=1) = track(cumsum, x, dims)

@grad cumsum(x::TrackedArray, dims) =
    cumsum(x.data, dims=dims), Δ -> ( reverse(cumsum(reverse(Δ, dims=dims), dims=dims), dims=dims) , nothing)

Tracker.gradcheck(x -> sum(sin, cumsum(x)), randn(3))
Tracker.gradcheck(x -> sum(sin, cumsum(x, dims=1)), randn(3))

Tracker.gradcheck(x -> sum(sin, cumsum(x,  dims=1)), randn(3,4))
Tracker.gradcheck(x -> sum(sin, cumsum(x,  dims=2)), randn(3,4))
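The `dims`-aware pullback can also be sanity-checked without Flux. The sketch below is plain Julia and compares the reverse/cumsum/reverse expression against central finite differences for both dimensions of a 3×4 matrix. `cumsum_pullback` and `numeric_grad` are hypothetical helper names, not part of Flux or Base.

```julia
# Hypothetical helper mirroring the pullback above: reverse along `dims`,
# accumulate along `dims`, then reverse back along `dims`
cumsum_pullback(Δ, dims) =
    reverse(cumsum(reverse(Δ, dims = dims), dims = dims), dims = dims)

# Central finite differences of a scalar function f at x (illustration only)
function numeric_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = randn(3, 4)
for d in (1, 2)
    f(y) = sum(sin, cumsum(y, dims = d))
    Δ = cos.(cumsum(x, dims = d))       # sensitivity of sum(sin, ·)
    @assert maximum(abs, cumsum_pullback(Δ, d) .- numeric_grad(f, x)) < 1e-6
end
```

Reversing along the wrong dimension (or along none) makes this check fail for matrices, which matches the point above about telling `reverse` which dimension to use.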

How did you figure this out? Did you look at the source code?

I don’t remember, maybe? Or perhaps from one of @MikeInnes’s posts on here?

Perhaps the docs could use a PR too, to explain this.

Tracking now supports keyword arguments just fine. Let me know if you have any issues getting this together (there’s actually also a PR on it that I need to get to).

I was doing it on JuliaBox, so I wasn’t using the latest version. Maybe that’s why?

Very possibly. You should be able to add Flux#master on JuliaBox if you want the latest.

hey @MikeInnes stupid question: what is the : in

Base.sum(xs::TrackedArray; dims = :) 

where to read up on that? thanks.

@floswald, I think it basically means “all dimensions” here.
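A quick way to see this in plain Julia, assuming a recent 1.x version: `:` is the singleton `Colon()`, and Base's own `sum` accepts it as the `dims` keyword to mean a reduction over the whole array. The array `A` is just an example.

```julia
A = reshape(1:6, 2, 3)      # [1 3 5; 2 4 6]

# `:` is an ordinary value, the singleton instance of the type Colon
@assert (:) === Colon()

# As `dims`, it means "reduce over every dimension", giving a scalar
@assert sum(A; dims = :) == sum(A) == 21

# An integer `dims` keeps the reduced dimension with length 1
@assert size(sum(A; dims = 1)) == (1, 3)

# The same value is what indexing passes for "everything along this axis"
@assert A[:, 1] == [1, 2]
```

Using `:` as the default lets a single method signature cover both the "reduce everything" and the "reduce along some dimensions" cases.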

As a general tip, when googling for this kind of stuff, I find it useful to include the word “julialang”

Oh yeah, of course. I mean, I know what colon() means, but I never saw it as an argument of a function. But then again, why not?

Also, in general you can find the source with

using Flux; methods(sum, (Flux.TrackedArray, ))

or, even better,

edit(first(methods(sum, (Flux.TrackedArray, ))))

will open it in your editor directly.
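The same lookup works for any function, so it can be tried without Flux installed. This is a small sketch using only Base and the `InteractiveUtils` standard library (which the REPL loads automatically, but a script does not); the argument type `Matrix{Float64}` is just an example.

```julia
using InteractiveUtils   # provides @which and edit outside the REPL

# All `sum` methods applicable to a single Matrix{Float64} argument
ms = methods(sum, (Matrix{Float64},))
m = first(ms)

# Each Method records where it was defined
println(m.file, ":", m.line)

# @which resolves the method a specific call would dispatch to
w = @which sum(rand(2, 2))
println(w)
```

Calling `edit(m)` on the resulting method then opens that location in the configured editor, as described above.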