Autodiff with Zygote: issues with setting seeds

Hello all,

I would like to differentiate a function with a fixed seed, but I get a “can’t differentiate foreigncall” error when using Zygote. Any advice would be appreciated.

The following is a minimal working example, together with one possible workaround that is somewhat limiting for me.

using Zygote
using Random

# Random.seed!() does not work with Zygote. It produces a "can't differentiate foreigncall" error.
function simulator(x, id::Int64)
    Random.seed!(id)
    return simulator(x)
end

"This is a workaround. The simulator needs to take an RNG as input."
function simulator(x, rng::AbstractRNG)
    noise1 = randn(rng)
    noise2 = randn(rng)
    @show noise1
    @show noise2
    return x+noise1+noise2
end

function simulator(x)
    noise1 = randn()
    noise2 = randn()
    @show noise1
    @show noise2
    return x+noise1+noise2
end

function distance(sim, obs)
    return sum((sim-obs).^2)
end

"This will work"
function loss(x, obsdata, id::Int64)
    rng = Xoshiro(id)
    sim = simulator(x, rng)
    return distance(sim, obsdata)
end

"This won't work"
function loss_with_issue(x, obsdata, id::Int64)
    sim = simulator(x, id)
    return distance(sim, obsdata)
end

# data
myobs = 2.0;

# to fix the seed
id = 123

# test point
xtest = 3.0

# This works
Zygote.gradient(x->loss(x, myobs, id), xtest)
# Analytic gradient for comparison: d/dx (sim - obs)^2 = 2*(sim - obs)
2*(simulator(xtest, Xoshiro(id))-myobs)

# This throws an error: "Can't differentiate foreigncall expression"
Zygote.gradient(x->loss_with_issue(x, myobs, id), xtest)

Arguably simulator(x, rng::AbstractRNG) is cleaner code and may be preferable anyway, but I also need to differentiate my loss for simulators such as simulator(x) that do not accept an explicit RNG instance.

Does anyone know how to make Zygote work without having to pass around an RNG instance, i.e. for the loss_with_issue case?

Many thanks!

I think you can just tell Zygote not to look inside that function, like so:

julia> Zygote.gradient(x->loss_with_issue(x, myobs, id), xtest)
noise1 = -0.6457306721039767
noise2 = -1.4632513788889214
ERROR: Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_get_current_task), Ref{Task}, svec(), 0, :(:ccall))).
Stacktrace:
...
  [4] setstate!
    @ /Applications/Julia-1.10.app/Contents/Resources/julia/share/julia/stdlib/v1.10/Random/src/Xoshiro.jl:132 [inlined]

julia> function simulator(x, id::Int64)
           Zygote.@ignore Random.seed!(id)
           return simulator(x)
       end
simulator (generic function with 3 methods)

julia> Zygote.gradient(x->loss_with_issue(x, myobs, id), xtest)
noise1 = -0.6457306721039767
noise2 = -1.4632513788889214
(-2.2179641019857965,)

I believe that could be made permanent by a one-line PR here.
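
If you have several simulators that only use the global RNG, the same trick can be factored out into a small wrapper so that none of them need to be edited. This is only a minimal sketch: with_seed, noisy_simulator and seeded_loss are hypothetical names made up for illustration (they are not part of Zygote or Random), and it assumes the simulator draws all of its randomness from the default task-local RNG.

```julia
using Zygote
using Random

# Hypothetical helper (not part of Zygote or Random): seed the global RNG,
# hidden from Zygote, then call any simulator that draws from that RNG.
function with_seed(f, id::Integer, args...)
    Zygote.@ignore Random.seed!(id)
    return f(args...)
end

# Stand-in simulator that only uses the global RNG, as in the question.
noisy_simulator(x) = x + randn() + randn()

# Squared distance between the seeded simulation and the observation.
seeded_loss(x, obs, id) = (with_seed(noisy_simulator, id, x) - obs)^2

g, = Zygote.gradient(x -> seeded_loss(x, 2.0, 123), 3.0)

# Analytic check: re-seed, rerun the simulator, and compare with 2 * (sim - obs).
Random.seed!(123)
expected = 2 * (noisy_simulator(3.0) - 2.0)
@show g ≈ expected
```

Because the seed is reset on every call, each evaluation of seeded_loss sees the same noise draws, so the function Zygote differentiates is deterministic in x.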


Tangential remark: why do you want to differentiate a function that returns random values? Autodiff engines are not designed to handle such situations by default, so you might get unexpected (and backend-dependent) results.

Thank you very much. That indeed resolves the issue.

The motivation is the implementation of a statistical inference procedure that works by fixing the seed of the stochastic generative model (the simulator). Details about the method can be found here.