xiaodai
November 20, 2018, 1:57pm
I wanted to create a cumsum of tracked values. How do I do that?
I tried the following, which failed:
using Flux
W = param(rand(12, 1))             # tracked 12×1 parameter
cumsum(Tracker.data(W), dims = 1)  # works on the underlying plain array
cumsum(W, dims = 1)                # this fails
I got it to work using matrix multiplication, but that’s not ideal because it’s more cumbersome:
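mw = ones(12, 12)          # 12×12 matrix of ones
for i = 1:12, j = i+1:12
    mw[i, j] = 0           # zero the strict upper triangle, leaving lower-triangular ones
end
mw * W                     # row i sums W[1:i], i.e. cumsum along dims = 1; this works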
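Why the matrix trick works can be checked in plain Julia (no Flux involved; the name L is purely illustrative): multiplying by a lower-triangular matrix of ones reproduces cumsum along dims = 1. Presumably Tracker handles mw*W because it only needs the gradient of a matrix product, which it already knows. A minimal sketch:

```julia
using LinearAlgebra

n = 12
W = rand(n, 1)

# Lower-triangular matrix of ones: L[i, j] = 1 when j <= i, else 0.
L = [j <= i ? 1.0 : 0.0 for i in 1:n, j in 1:n]

# Row i of L*W sums W[1:i], i.e. the cumulative sum down the first dimension.
@assert L * W ≈ cumsum(W, dims = 1)

# The same matrix built with LinearAlgebra's LowerTriangular wrapper.
@assert Matrix(LowerTriangular(ones(n, n))) == L
```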
For the case of a vector, I think this is right:
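using Flux
using Flux.Tracker: TrackedVector, @grad, track
# Route cumsum on tracked vectors through Tracker's custom-gradient machinery.
Base.cumsum(x::TrackedVector) = track(cumsum, x)
# Forward pass on the raw data; the pullback is a reverse cumulative sum,
# because ∂L/∂x[j] = sum(Δ[j:end]).
@grad function cumsum(x::TrackedVector)
  cumsum(x.data), Δ -> (reverse(cumsum(reverse(Δ))),)
end
Tracker.gradcheck(x -> sum(sin, cumsum(x)), randn(3))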
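The pullback can be checked without Flux: the Jacobian of cumsum on a length-n vector is the lower-triangular ones matrix J, so reverse mode needs J'Δ, whose j-th entry is sum(Δ[j:end]). The sketch below (cumsum_pullback is a hypothetical helper name, not a Flux function) compares the two numerically.

```julia
# Hypothetical helper: the backward pass for y = cumsum(x) on a vector.
# Entry j of the result is sum(Δ[j:end]), a "reverse cumulative sum".
cumsum_pullback(Δ::AbstractVector) = reverse(cumsum(reverse(Δ)))

n = 5
Δ = randn(n)

# Jacobian of cumsum: J[i, j] = ∂y[i]/∂x[j] = 1 if j <= i else 0.
J = [j <= i ? 1.0 : 0.0 for i in 1:n, j in 1:n]

# Reverse mode computes J' * Δ; the reverse-cumsum trick gives the same vector.
@assert cumsum_pullback(Δ) ≈ J' * Δ

# Small concrete case: Δ = [1, 2, 3] gives [1+2+3, 2+3, 3].
@assert cumsum_pullback([1, 2, 3]) == [6, 5, 3]
```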
xiaodai
November 20, 2018, 9:09pm
I think I am starting to get it.
But I can’t seem to make it work for arrays. See the code below:
Base.cumsum(x::TrackedArray; dims=1) = track(x->cumsum(x, dims = dims), x)
@grad function cumsum(x::TrackedArray; dims=1)
cumsum(x.data, dims=dims), Δ -> ( reverse(cumsum(reverse(Δ), dims=dims)) ,)
end
Tracker.gradcheck(x -> sum(cumsum(x, dims=1)), randn(3,1))
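There’s a trick to get around keyword arguments: pass dims to track as an ordinary positional argument, so the @grad rule receives it positionally and can return nothing as its gradient. You also need to tell reverse which dimension to use. I think this is correct now; if you agree, perhaps it’s worth making a Flux PR?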
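using Flux
using Flux.Tracker: TrackedArray, @grad, track
# Pass dims positionally to track so the @grad rule below sees it as a normal argument.
Base.cumsum(x::TrackedArray; dims=1) = track(cumsum, x, dims)
# Pullback: reverse along dims, cumsum along dims, reverse back; dims itself gets `nothing`.
@grad cumsum(x::TrackedArray, dims) =
  cumsum(x.data, dims=dims), Δ -> (reverse(cumsum(reverse(Δ, dims=dims), dims=dims), dims=dims), nothing)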
Tracker.gradcheck(x -> sum(sin, cumsum(x)), randn(3))
Tracker.gradcheck(x -> sum(sin, cumsum(x, dims=1)), randn(3))
Tracker.gradcheck(x -> sum(sin, cumsum(x, dims=1)), randn(3,4))
Tracker.gradcheck(x -> sum(sin, cumsum(x, dims=2)), randn(3,4))
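The same rule can be sanity-checked for matrices in plain Julia. In this sketch cumsum_rule and fd_grad are hypothetical helpers, and a central finite-difference check stands in for Tracker.gradcheck. It mirrors the structure above: dims is a positional argument and its "gradient" is nothing.

```julia
# Hypothetical rule: forward value plus a pullback closure, with dims passed positionally.
function cumsum_rule(x::AbstractArray, dims::Integer)
    y = cumsum(x, dims = dims)
    # Reverse cumulative sum along `dims`; `nothing` stands for the non-differentiable dims.
    pullback(Δ) = (reverse(cumsum(reverse(Δ, dims = dims), dims = dims), dims = dims), nothing)
    return y, pullback
end

# Central finite-difference gradient of a scalar function f at x.
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for k in eachindex(x)
        xp = copy(x); xp[k] += h
        xm = copy(x); xm[k] -= h
        g[k] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = randn(3, 4)
for d in (1, 2)
    # Loss = sum(sin, cumsum(x, dims = d)), so the incoming gradient is cos.(y).
    y, back = cumsum_rule(x, d)
    gx, gdims = back(cos.(y))
    @assert gdims === nothing
    @assert isapprox(gx, fd_grad(z -> sum(sin, cumsum(z, dims = d)), x); rtol = 1e-5, atol = 1e-6)
end
```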
xiaodai
November 21, 2018, 8:59am
How did you figure this out? Did you look at the source code?
I don’t remember; maybe? Or perhaps from one of @MikeInnes’s posts on here?
Perhaps the docs could use a PR too, to explain this.
Tracking now supports keyword arguments just fine, e.g. . Let me know if you have any issues getting this together (there’s actually also a PR on it that I need to get to).
xiaodai
November 21, 2018, 10:14pm
I was doing it on JuliaBox, so I’m not using the latest version. Maybe that’s why?
Very possibly. You should be able to add Flux#master on JuliaBox if you want the latest version.
Hey @MikeInnes, stupid question: what does the : mean in
Base.sum(xs::TrackedArray; dims = :)
and where can I read up on that? Thanks.
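@floswald, I think it basically means “all dimensions” here. On its own, : is the singleton value Colon(), and sum(A; dims = :) (the default) reduces over every dimension of A.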
https://docs.julialang.org/en/v1/base/punctuation/index.html
As a general tip, when googling for this kind of stuff, I find it useful to include the word “julialang”
https://www.google.com/search?q=julialang+colon+symbol
Oh yeah, of course. I know what colon() means, but I had never seen it passed as an argument to a function. But then again, why not?
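Because : is just an ordinary value, a function can use it as a keyword default and dispatch on its type. The sketch below is plain Julia; mysum and _mysum are hypothetical names that only illustrate the dims = : convention, not how Flux or Base are actually written.

```julia
# `:` on its own is the singleton value Colon().
@assert (:) === Colon()

# Hypothetical example: a keyword that defaults to `:`, meaning "all dimensions".
mysum(A::AbstractArray; dims = :) = _mysum(A, dims)

# Dispatch on the type of `dims` to pick the behaviour.
_mysum(A, ::Colon) = sum(A)           # reduce over every dimension
_mysum(A, d) = sum(A; dims = d)       # reduce over the given dimension(s)

A = [1 2; 3 4]
@assert mysum(A) == 10
@assert mysum(A; dims = 1) == [4 6]
@assert sum(A; dims = :) == sum(A)    # Base.sum accepts the same convention
```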
Also, in general you can find the source with:
using Flux; methods(sum, (Flux.TrackedArray, ))
or, even better,
edit(first(methods(sum, (Flux.TrackedArray, ))))
will open it in your editor directly.
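If you only want the file and line rather than opening an editor, InteractiveUtils (loaded by default in the REPL) can report that too. This sketch uses Base.sum on an ordinary Vector{Float64} rather than Flux.TrackedArray, so it runs without Flux installed.

```julia
using InteractiveUtils  # provides which, functionloc and edit

# All methods of `sum` applicable to a single AbstractArray argument.
ms = methods(sum, (AbstractArray,))
println(length(ms), " method(s) found")

# The specific method that sum(::Vector{Float64}) dispatches to.
m = which(sum, (Vector{Float64},))

# (file, line) where that method is defined; edit(m) would open the same spot.
file, line = functionloc(m)
println(file, ":", line)
```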