How to create tracked `cumsum` using Flux.jl?

I wanted to create a `cumsum` of tracked values. How do I do that?

I tried the following, which failed:

```
W = param(rand(12,1))
cumsum(Tracker.data(W), dims = 1) # works
cumsum(W, dims = 1) # this fails
```

I got it to work using matrix multiplication, but that’s not ideal because it’s more cumbersome:

```
mw = Matrix{Float64}(undef, 12, 12)
mw .= 1
for i = 1:12
    for j = i+1:12
        mw[i,j] = 0
    end
end
mw
mw*W # this works
```
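As a quick sanity check of the workaround above, here is a minimal sketch in plain Julia (no Flux, so `w` is an ordinary untracked array standing in for `W`). It builds the same lower-triangular matrix of ones with `tril` and confirms that multiplying by it matches `cumsum` along dimension 1.

```julia
using LinearAlgebra  # for tril

n = 12
w = rand(n, 1)            # plain (untracked) column, same shape as W above

# Lower-triangular matrix of ones: row i adds up entries 1 through i
mw = tril(ones(n, n))

# Multiplying by mw reproduces the cumulative sum along dimension 1
@assert mw * w ≈ cumsum(w, dims = 1)
```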

For the case of a vector, I think this is right:

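```
using Flux
using Flux.Tracker: track, @grad

# Route cumsum on tracked vectors through the tracker
Base.cumsum(x::TrackedVector) = track(cumsum, x)

# Forward value plus a pullback mapping the output gradient Δ to the input gradient
@grad function cumsum(x::TrackedVector)
    cumsum(x.data), Δ -> ( reverse(cumsum(reverse(Δ))) ,)
end
```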
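One way to see why the pullback is a reversed cumulative sum: for a vector, `cumsum(x)` equals `L * x` with `L` a lower-triangular matrix of ones, so the gradient with respect to `x` is `L' * Δ`, which sums each entry of `Δ` with everything after it. The sketch below checks this in plain Julia; `suffix_sum` is a hypothetical helper name, not part of Flux.

```julia
using LinearAlgebra  # for tril

n = 5
Δ = randn(n)
L = tril(ones(n, n))   # cumsum(x) == L * x for a vector x

# Hypothetical helper: the same expression used in the pullback above
suffix_sum(Δ) = reverse(cumsum(reverse(Δ)))

# The transpose of L applied to Δ is exactly the reversed cumulative sum
@assert L' * Δ ≈ suffix_sum(Δ)
@assert suffix_sum([1.0, 2.0, 3.0]) == [6.0, 5.0, 3.0]
```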

I think I am starting to get it.

But I can’t seem to make it work for the case of an array. See the code below:

```
Base.cumsum(x::TrackedArray; dims=1) = track(x->cumsum(x, dims = dims), x)

@grad function cumsum(x::TrackedArray; dims=1)
    cumsum(x.data, dims=dims), Δ -> ( reverse(cumsum(reverse(Δ), dims=dims)) ,)
end
```

There’s a trick to get around keyword arguments, and you need to tell `reverse` what dimension too. I think this is correct now… if you agree, perhaps worth making a Flux PR?

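```
using Flux
using Flux.Tracker: track, @grad

# Pass `dims` as a positional argument so the gradient definition receives it
Base.cumsum(x::TrackedArray; dims=1) = track(cumsum, x, dims)

# Forward value plus a pullback; `nothing` is the gradient for `dims`
@grad function cumsum(x::TrackedArray, dims)
    cumsum(x.data, dims=dims), Δ -> ( reverse(cumsum(reverse(Δ, dims=dims), dims=dims), dims=dims) , nothing)
end

Tracker.gradcheck(x -> sum(sin, cumsum(x, dims=1)), randn(3))

Tracker.gradcheck(x -> sum(sin, cumsum(x,  dims=1)), randn(3,4))
Tracker.gradcheck(x -> sum(sin, cumsum(x,  dims=2)), randn(3,4))
```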
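For anyone who wants to check the pullback formula without Flux installed, here is a rough sketch in plain Julia that compares it against central finite differences, using the same test function as the `gradcheck` calls above. The names `cumsum_pullback`, `f`, `analytic_grad` and `numeric_grad` are hypothetical helpers for this illustration only.

```julia
# Hypothetical helper: the pullback expression from the gradient definition above
cumsum_pullback(Δ, dims) = reverse(cumsum(reverse(Δ, dims = dims), dims = dims), dims = dims)

# Same scalar test function as in the gradcheck calls
f(x, dims) = sum(sin, cumsum(x, dims = dims))

# Analytic gradient: the derivative of sum(sin, y) is cos.(y), pulled back through cumsum
analytic_grad(x, dims) = cumsum_pullback(cos.(cumsum(x, dims = dims)), dims)

# Central finite differences, perturbing one entry at a time
function numeric_grad(x, dims; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp, dims) - f(xm, dims)) / (2h)
    end
    return g
end

x = randn(3, 4)
for dims in (1, 2)
    @assert isapprox(analytic_grad(x, dims), numeric_grad(x, dims); atol = 1e-6)
end
```

Dropping the `dims` keyword from the inner `reverse` calls makes the `dims = 2` check fail, which is the point made above about telling `reverse` which dimension to use.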

How did you figure out this? Did you look at the source code?

I don’t remember, maybe? Or perhaps from one of @MikeInnes’s posts on here?

Perhaps the docs could use a PR too, to explain this.

Tracking now supports keyword arguments just fine. Let me know if you have any issues getting this together (there’s actually also a PR on it that I need to get to).


Very possibly. You should be able to `add Flux#master` on JuliaBox if you want the latest.

hey @MikeInnes stupid question: what is the `:` in

```
Base.sum(xs::TrackedArray; dims = :)
```

?
where to read up on that? thanks.

@floswald, I think it basically means “all dimensions” here.

https://docs.julialang.org/en/v1/base/punctuation/index.html

As a general tip, when googling for this kind of stuff, I find it useful to include the word “julialang”


oh yeah, of course. i mean I know of course what `colon()` means, but i never saw it as an argument of a function. but then again, why not?
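To make the `dims = :` default concrete, here is a small sketch in plain Julia. A bare `:` is the singleton value `Colon()`, so a function can take it as a default and branch on it. The function `total` is a hypothetical example, not something from Flux.

```julia
# `:` on its own is the singleton value Colon()
@assert (:) === Colon()
@assert (:) isa Colon

# Hypothetical function that uses `:` as the default meaning "all dimensions"
total(xs; dims = :) = dims isa Colon ? sum(xs) : sum(xs, dims = dims)

A = [1 2; 3 4]
@assert total(A) == 10                                # sum over everything
@assert total(A, dims = 1) == [4 6]                   # column sums, 1×2
@assert total(A, dims = 2) == reshape([3, 7], 2, 1)   # row sums, 2×1
```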

Also, in general you can find the source with

```
using Flux; methods(sum, (Flux.TrackedArray, ))
```

or, even better,

```
edit(first(methods(sum, (Flux.TrackedArray, ))))
```

will open it in your editor directly.
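If you only want the location rather than opening an editor, a minimal sketch along the same lines is to ask `which` for the method and read its fields. This example uses a plain `Vector{Float64}` instead of a tracked array so that it runs without Flux.

```julia
using InteractiveUtils  # loaded automatically in the REPL; needed in scripts for edit/@which

# The Method that would be called for sum on a Vector{Float64}
m = which(sum, (Vector{Float64},))

# Print the file and line where that method is defined
println(m.name, " is defined at ", m.file, ":", m.line)
```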