How to create tracked `cumsum` using Flux.jl?

first-steps
flux

#1

I wanted to create a cumsum of tracked values. How do I do that?

I tried the following, which failed:

W = param(rand(12,1))
cumsum(Tracker.data(W), dims = 1) # works
cumsum(W, dims = 1) # this fails

I got it to work using matrix multiplication, but that’s not ideal as it’s more cumbersome:

mw = Matrix{Float64}(undef, 12, 12)
mw .= 1
for i = 1:12
    for j = i+1:12
        mw[i,j]=0
    end
end
mw
mw*W # this works
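For reference, the matrix trick above can be checked without Flux. This is only a sketch on plain arrays (the names `n`, `w` and `L` are made up for illustration): multiplying by a lower-triangular matrix of ones gives the same result as `cumsum` along the first dimension.

```julia
using LinearAlgebra  # for tril

n = 12
w = rand(n, 1)

# Ones on and below the diagonal, zeros above: the same matrix as the loop builds.
L = tril(ones(n, n))

# Row i of L * w sums w[1], ..., w[i], which is exactly the cumulative sum.
@assert L * w ≈ cumsum(w, dims = 1)
```

The matrix version allocates an n×n array, so it is mainly useful as a check rather than as a replacement for `cumsum`.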

#2

For the case of a vector, I think this is right:

using Flux
using Flux.Tracker: TrackedVector, @grad, track

Base.cumsum(x::TrackedVector) = track(cumsum, x)

@grad function cumsum(x::TrackedVector)
    cumsum(x.data), Δ -> ( reverse(cumsum(reverse(Δ))) ,)
end
    
Tracker.gradcheck(x -> sum(sin, cumsum(x)), randn(3))

#3

I think I am starting to get it.

But I can’t seem to make it work for the case of an array. See the code below:

Base.cumsum(x::TrackedArray; dims=1) = track(x->cumsum(x, dims = dims), x)

@grad function cumsum(x::TrackedArray; dims=1)
    cumsum(x.data, dims=dims), Δ -> ( reverse(cumsum(reverse(Δ), dims=dims)) ,)
end
    
Tracker.gradcheck(x -> sum(cumsum(x,  dims=1)), randn(3,1))

#4

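There’s a trick to get around keyword arguments: pass `dims` to `track` as an ordinary positional argument and return `nothing` as its gradient. You also need to tell `reverse` which dimension to act on. I think this is correct now… if you agree, perhaps worth making a Flux PR?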

using Flux
using Flux.Tracker: TrackedArray, @grad, track

Base.cumsum(x::TrackedArray; dims=1) = track(cumsum, x, dims)

@grad cumsum(x::TrackedArray, dims) = 
    cumsum(x.data, dims=dims), Δ -> ( reverse(cumsum(reverse(Δ, dims=dims), dims=dims), dims=dims) , nothing)
    
Tracker.gradcheck(x -> sum(sin, cumsum(x)), randn(3))
Tracker.gradcheck(x -> sum(sin, cumsum(x, dims=1)), randn(3))

Tracker.gradcheck(x -> sum(sin, cumsum(x,  dims=1)), randn(3,4))
Tracker.gradcheck(x -> sum(sin, cumsum(x,  dims=2)), randn(3,4))
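If it helps to see why the reverse/cumsum/reverse form is the right gradient, here is a rough check on plain arrays that does not use Flux at all. The helper names (`cumsum_pullback`, `loss`, `analytic_grad`, `numeric_grad`) are invented for this sketch; it compares the formula against central finite differences, and against the transposed lower-triangular matrix from the first post.

```julia
using LinearAlgebra  # for tril

# Candidate pullback of cumsum along `dims`: a reversed cumulative sum of Δ.
cumsum_pullback(Δ; dims) = reverse(cumsum(reverse(Δ, dims = dims), dims = dims), dims = dims)

# Same scalar test function as in the gradcheck calls above.
loss(x; dims) = sum(sin, cumsum(x, dims = dims))

# Chain rule: d(loss)/dy = cos.(y) with y = cumsum(x), pulled back through cumsum.
analytic_grad(x; dims) = cumsum_pullback(cos.(cumsum(x, dims = dims)); dims = dims)

# Central finite differences, perturbing one entry at a time.
function numeric_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = randn(3, 4)
for d in (1, 2)
    @assert isapprox(analytic_grad(x, dims = d), numeric_grad(z -> loss(z, dims = d), x); atol = 1e-6)
end

# cumsum along dims = 1 is multiplication by L, so its pullback is multiplication by L'.
L = tril(ones(3, 3))
Δ = randn(3, 4)
@assert L' * Δ ≈ cumsum_pullback(Δ, dims = 1)
```

The last two lines are the short explanation: transposing the lower-triangular matrix of ones gives an upper-triangular one, which sums "from the end", hence the two `reverse` calls.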

#5

How did you figure this out? Did you look at the source code?


#6

I don’t remember, maybe? Or perhaps from one of @MikeInnes’s posts on here?

Perhaps the docs could use a PR too, to explain this.


#7

Tracking now supports keyword arguments just fine. Let me know if you have any issues getting this together (there’s actually also a PR on it that I need to get to).


#8

I was doing it on JuliaBox, so not using the latest version. Maybe that’s why?


#9

Very possibly. You should be able to add Flux#master on JuliaBox if you want the latest.


#10

hey @MikeInnes stupid question: what is the : in

Base.sum(xs::TrackedArray; dims = :) 

?
where to read up on that? thanks.


#11

@floswald, I think it basically means “all dimensions” here.

https://docs.julialang.org/en/v1/base/punctuation/index.html

As a general tip, when googling for this kind of stuff, I find it useful to include the word “julialang”

https://www.google.com/search?q=julialang+colon+symbol
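A quick way to convince yourself, using plain Base arrays rather than tracked ones (the function `total` is a made-up example, not anything from Flux):

```julia
A = [1 2; 3 4]

# A bare `:` is an ordinary value of type Colon, so it can be a keyword default.
@assert (:) isa Colon

# With dims = : the reduction runs over every dimension and returns a scalar.
@assert sum(A, dims = :) == sum(A) == 10

# With an integer, only that dimension is collapsed.
@assert sum(A, dims = 1) == [4 6]
@assert sum(A, dims = 2) == reshape([3, 7], 2, 1)

# Hypothetical wrapper using the same default as the signature quoted above.
total(xs; dims = :) = sum(xs, dims = dims)
@assert total(A) == 10
```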


#12

oh yeah, of course. i mean I know of course what colon() means, but i never saw it as an argument of a function. but then again, why not? :slight_smile:


#13

Also, in general you can find the source with

using Flux; methods(sum, (Flux.TrackedArray, ))

or, even better,

edit(first(methods(sum, (Flux.TrackedArray, ))))

will open it in your editor directly.
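The same lookup works for any function. As a sketch that runs without Flux installed, here it is on Base's `sum` for an ordinary vector; the `file` and `line` fields of a `Method` say where the definition lives, which is the location `edit` would open.

```julia
using InteractiveUtils  # provides @which (and edit) outside the REPL

# All methods of `sum` applicable to a single Vector{Float64} argument.
ms = methods(sum, (Vector{Float64},))
m = first(ms)

# Source location of that method.
println(m.file, ":", m.line)

# @which returns the one method a concrete call would dispatch to.
m2 = @which sum([1.0, 2.0])
println(m2)
```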