Problem defining a Straight-Through Estimator gradient



Hi all,

I am trying to define a custom straight-through estimator (STE) gradient in Flux. The activation function is the sign function, and its gradient should pass the incoming gradient through unchanged if the absolute value of the input is <= 1, and be zero otherwise. My code is below:

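```julia
using Flux
using Flux.Tracker: @grad

# Forward pass: step function returning a Bool
binarize(x) = x >= 0 ? true : false
# Route tracked scalars through the custom gradient below
binarize(x::Flux.Tracker.TrackedReal) = Flux.Tracker.track(binarize, x)

# Backward pass: pass Δ through when |x| <= 1, otherwise 0
@grad function binarize(x)
    return binarize(Flux.Tracker.data(x)), Δ -> (abs(x) <= 1 ? Δ : 0, )
end
```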
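For reference, the intended rule can be written without Flux or Tracker as two plain functions: a forward step and a backward pass that masks the incoming gradient. This is only a sketch of the math; the names `binarize_forward` and `ste_backward` are hypothetical and not part of any library.

```julia
# Hypothetical helpers illustrating the intended STE rule; no Flux/Tracker needed.

# Forward pass: the same step rule as `binarize` above (true when x >= 0).
binarize_forward(x::Real) = x >= 0

# Backward pass: pass the incoming gradient Δ through unchanged when |x| <= 1,
# otherwise return a zero of the same type as Δ.
ste_backward(x::Real, Δ::Real) = abs(x) <= 1 ? Δ : zero(Δ)

println(binarize_forward(-0.3))   # false
println(ste_backward(0.5, 2.0))   # 2.0
println(ste_backward(-1.5, 2.0))  # 0.0
```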

For a random 5-element vector `a`, I do the following:

```julia
julia> a = param(randn(5))
Tracked 5-element Array{Float64,1}:
 -0.3605564089879154
 -0.7853512499733902
  0.8102988051980005
 -0.9715952052917924
 -1.276343849200165

julia> c = binarize.(a)
5-element BitArray{1}:
 false
 false
  true
 false
 false

julia> Tracker.back!(c, [1,1,1,1,1])

julia> a.grad
5-element Array{Float64,1}:
 0.0
 0.0
 0.0
 0.0
 0.0
```
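Note that `c` is printed as a plain `BitArray{1}`, not as a tracked array. In plain Julia, broadcasting any function that returns `Bool` over a vector collects the results into a `BitVector` (an alias for `BitArray{1}`). The small check below uses the same step rule on an untracked copy of the printed values; it only illustrates the container type and makes no claim about Tracker internals.

```julia
# Same step rule as `binarize` in the post, applied to a plain (untracked) vector.
binarize(x) = x >= 0 ? true : false

v = [-0.3605564089879154, -0.7853512499733902, 0.8102988051980005,
     -0.9715952052917924, -1.276343849200165]

c = binarize.(v)
println(typeof(c))                            # BitVector, i.e. BitArray{1}
println(c == [false, false, true, false, false])  # true
```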

There are a couple of problems here. First, I'd expect the gradient to be a `TrackedArray`.
More importantly, shouldn't the gradient be all ones except for the last element (the only entry with absolute value greater than 1), rather than all zeros?
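As a sanity check on that expectation, applying the straight-through rule by hand to the five printed values, with an incoming gradient of ones, gives the following. This is plain Julia arithmetic, not a call into Flux.

```julia
# The five values of `a` printed in the post.
a = [-0.3605564089879154, -0.7853512499733902, 0.8102988051980005,
     -0.9715952052917924, -1.276343849200165]

# Incoming gradient of ones, like the one passed to `Tracker.back!` above.
Δ = ones(length(a))

# Straight-through rule, elementwise: keep Δ where |a| <= 1, else 0.
expected = ifelse.(abs.(a) .<= 1, Δ, 0.0)

println(expected)  # [1.0, 1.0, 1.0, 1.0, 0.0]
```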

What am I doing wrong here?

Thanks,