Hello all,
I would like to differentiate a function that uses a fixed random seed, but I get a "can't differentiate foreigncall expression" error when using Zygote. Any advice would be appreciated.
The following is a minimal working example, together with one possible workaround that is somewhat limiting for my use case.
```julia
using Zygote
using Random

# Random.seed!() does not work with Zygote: it produces a
# "can't differentiate foreigncall expression" error
function simulator(x, id::Int64)
    Random.seed!(id)
    return simulator(x)
end

# This is a workaround: the simulator needs to take an RNG as input
function simulator(x, rng::AbstractRNG)
    noise1 = randn(rng)
    noise2 = randn(rng)
    @show noise1
    @show noise2
    return x + noise1 + noise2
end

function simulator(x)
    noise1 = randn()
    noise2 = randn()
    @show noise1
    @show noise2
    return x + noise1 + noise2
end

function distance(sim, obs)
    return sum((sim - obs) .^ 2)
end

# This works
function loss(x, obsdata, id::Int64)
    rng = Xoshiro(id)
    sim = simulator(x, rng)
    return distance(sim, obsdata)
end

# This doesn't work
function loss_with_issue(x, obsdata, id::Int64)
    sim = simulator(x, id)
    return distance(sim, obsdata)
end

# data
myobs = 2.0
# seed used to fix the noise
id = 123
# test point
xtest = 3.0

# This works
Zygote.gradient(x -> loss(x, myobs, id), xtest)
# analytic gradient for comparison (the noise does not depend on x)
2 * (simulator(xtest, Xoshiro(id)) - myobs)

# This throws an error: "can't differentiate foreigncall expression"
Zygote.gradient(x -> loss_with_issue(x, myobs, id), xtest)
```
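Because the noise does not depend on `x`, the derivative of the loss is `2*(sim - obs)`, which is the reference value computed above next to the working `Zygote.gradient` call. A minimal sketch that checks this without Zygote, using a central finite difference on the seeded-RNG variant. The helper names `sim_rng`, `loss_rng` and `fd_grad` are hypothetical, the `@show` calls are dropped, the squared error replaces `distance` (equivalent for scalars), and `Xoshiro` assumes Julia 1.7 or later:

```julia
using Random

# Hypothetical copy of `simulator(x, rng)` without the @show output
sim_rng(x, rng::AbstractRNG) = x + randn(rng) + randn(rng)

# Loss with the noise fixed by seeding a fresh Xoshiro from `id`
loss_rng(x, obs, id::Integer) = (sim_rng(x, Xoshiro(id)) - obs)^2

# Central finite difference; meaningful only because the same seed gives the same noise
fd_grad(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

obs, id, x0 = 2.0, 123, 3.0

# Same seed, same draws: the loss is a deterministic function of x
same_loss = loss_rng(x0, obs, id) == loss_rng(x0, obs, id)

analytic = 2 * (sim_rng(x0, Xoshiro(id)) - obs)
numeric = fd_grad(x -> loss_rng(x, obs, id), x0)
println((same_loss, analytic, numeric))
```

The finite difference only works here because reseeding with the same `id` reproduces the same draws; with unseeded noise, `f(x + h)` and `f(x - h)` would see different noise values.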
Arguably `simulator(x, rng::AbstractRNG)` is cleaner code and may be preferable anyway, but I also need to be able to differentiate my loss for simulators such as `simulator(x)` that do not accept an explicit RNG instance.
Does anyone know how to make Zygote work without passing an RNG instance around, i.e. for the `loss_with_issue` case?
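A possible direction for the `loss_with_issue` case, offered as an untested assumption rather than a known fix: ChainRulesCore provides `ignore_derivatives`, which runs a closure normally in the forward pass but tells AD systems that use ChainRules rules (recent Zygote versions do) not to differentiate into it. Wrapping only the `Random.seed!` call this way might avoid the foreigncall error. The sketch assumes Zygote and a recent ChainRulesCore release are installed, and that Zygote already gets through the `randn()` calls on the global RNG, as the question implies; `simulator_global` and `loss_ignored` are hypothetical names:

```julia
using Random, Zygote, ChainRulesCore

# Hypothetical stand-in for a simulator that only uses the global RNG
function simulator_global(x)
    noise1 = randn()
    noise2 = randn()
    return x + noise1 + noise2
end

# Seed inside ignore_derivatives: the closure still runs in the forward pass,
# but AD uses the rule for ignore_derivatives instead of tracing into seed!
function loss_ignored(x, obs, id::Integer)
    ChainRulesCore.ignore_derivatives(() -> Random.seed!(id))
    sim = simulator_global(x)
    return sum((sim - obs) .^ 2)
end

obs, id, x0 = 2.0, 123, 3.0
g = Zygote.gradient(x -> loss_ignored(x, obs, id), x0)[1]

# Reference: reseed, redraw the same two numbers, then use 2*(sim - obs)
Random.seed!(id)
ref = 2 * (x0 + randn() + randn() - obs)
println((g, ref))
```

If this approach holds up, the same wrapper could go inside `simulator(x, id::Int64)` while leaving `simulator(x)` itself unchanged.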
Many thanks!