# ReverseDiff Incorrect Result

I’m running into an issue using ReverseDiff. It seems to be evaluating my function incorrectly when I prerecord a gradient tape:

```julia
using ReverseDiff: GradientTape, gradient!, compile
using DiffResults: GradientResult, value

# Compute the gradient using a prerecorded, compiled gradient tape and return the value
function reversediff_tape_eval(func, val)
    f_tape = GradientTape(func, rand(length(val)))  # tape recorded at a random input
    compiled_f_tape = compile(f_tape)
    results = GradientResult(val)
    gradient!(results, compiled_f_tape, val)
    return value(results)
end

# Compute the gradient without a gradient tape and return the value
function reversediff_eval(func, val)
    results = GradientResult(val)
    gradient!(results, func, val)
    return value(results)
end

# Simple example: f2 branches on its input p
function f1(xs)
    function f2(p)
        s = 0.0
        for x in xs
            if x > p[1]
                s += x
            else
                s -= x
            end
        end
        return s
    end
    return f2
end

# Example input
xs = rand(5) # e.g. [0.698, 0.828, 0.209, 0.661, 0.394]
a = rand(2)  # e.g. [0.705, 0.092]

f = f1(xs)
f(a)                        # e.g. -1.134, correct
reversediff_eval(f, a)      # e.g. -1.134, correct
reversediff_tape_eval(f, a) # incorrect, and varies with the random input the tape was recorded at
```

What’s even more bizarre is that I get different results every time I evaluate `reversediff_tape_eval`…

Anyone have any idea what’s going on?

The incorrect result comes from the fact that a gradient tape records only the operations along the particular branch taken for the input it was recorded with.
In your example, the operations recorded on the tape are entirely determined by the initial input in `f_tape = GradientTape(func, rand(length(val)))`, which is a fresh random `rand(length(val))`. That random recording input, together with the fact that `f2` branches on its input, produces the random behavior of `reversediff_tape_eval(f, a)`.
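A minimal way to see this (the function `g` here is made up for illustration, not from your code): record a tape for a branching function at an input on one branch, then replay it at an input that takes the other branch. The tape replays the recorded operations, not the function.

```julia
using ReverseDiff: GradientTape, gradient!, compile
using DiffResults: GradientResult, value

g(x) = x[1] > 0 ? 2.0 * x[1] : -x[1]   # branches on the sign of the input

tape = compile(GradientTape(g, [1.0]))  # records only the x[1] > 0 branch
results = GradientResult([1.0])

gradient!(results, tape, [-3.0])        # replay at an input on the other branch
value(results)                          # 2.0 * -3.0 = -6.0, but g([-3.0]) = 3.0
```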


I also made a little example based on your original one to make things clearer:

```julia
using ReverseDiff: GradientTape, gradient!, compile
using DiffResults: GradientResult, value

# Compute the gradient using a prerecorded, compiled gradient tape and return the value
function reversediff_tape_eval(func, val)
    random_val = rand(length(val))
    println(random_val)  # show the input the tape is recorded at
    f_tape = GradientTape(func, random_val)
    compiled_f_tape = compile(f_tape)
    results = GradientResult(val)
    gradient!(results, compiled_f_tape, val)
    return value(results)
end

# Compute the gradient without a gradient tape and return the value
function reversediff_eval(func, val)
    results = GradientResult(val)
    gradient!(results, func, val)
    return value(results)
end

# Simple example
function f1(xs)
    function f2(p)
        s = 0.0
        for x in xs
            if x > p[1]
                s += x
            end
        end
        return s
    end
    return f2
end

# Example input
xs = [0.5, 0.5, 0.5]
a = rand(2)

f = f1(xs)
f(a)
reversediff_eval(f, a)
reversediff_tape_eval(f, a)
## No matter what `a` is:
## if the first element of the tape's random input is smaller than 0.5, e.g. [0.308748, 0.747674] => 1.5
## if the first element is not smaller than 0.5, e.g. [0.804809, 0.21977]  => 0.0
```

So the behavior here is more extreme, and it’s not hard to see the connection between the result and the initial input provided to the tape.


You can also use the `@forward` macro from ReverseDiff to forward-differentiate the part of your algorithm that contains the branch. This should result in the tape giving the right answer even when a different branch is taken at run time: see the `@forward` section of the ReverseDiff.jl API documentation.
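A sketch of that suggestion (the scalar kernel `branchy` and the inputs are made up, not from the thread): wrapping the branching scalar function in `@forward` records it on the tape as a single forward-mode instruction, so the branch is re-evaluated on every replay instead of being frozen at record time.

```julia
using ReverseDiff: GradientTape, gradient!, compile, @forward
using DiffResults: GradientResult, value

# Hypothetical scalar kernel containing the branch; @forward tells ReverseDiff
# to differentiate it with forward mode and record it as one tape instruction.
@forward branchy(x, p) = x > p ? x : zero(x)

xs = [0.2, 0.6, 0.9]
f(p) = sum(branchy(x, p[1]) for x in xs)

tape = compile(GradientTape(f, rand(1)))  # recorded at a random input
results = GradientResult(rand(1))
gradient!(results, tape, [0.5])
value(results)  # should be 0.6 + 0.9 = 1.5, whichever branches were taken at record time
```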


The `@forward` macro seems to work only on functions with scalar inputs, though.