I’m trying to pass a gradient through a remote call with Zygote. The server seems to be passing the gradient back properly, the server decodes and returns it from the adjoint, but the final gradient being returned is coming back empty.
using Zygote
using Zygote: @adjoint
using HTTP
using JSON3
using Sockets
using NamedTupleTools
mutable struct Data
x::Float64
y::Float64
Data() = new()
Data(x, y) = new(x,y)
end
struct RemoteEnv
url::String
end
recursive_namedtuple(x::Any) = x
recursive_namedtuple(d::Dict) = namedtuple(Dict(k => recursive_namedtuple(v) for (k, v) in d))
JSON3.StructType(::Type{Data}) = JSON3.Mutable()
JSON3.StructType(::Type{Base.RefValue{Any}}) = JSON3.Mutable()
function RemoteGradient(r::RemoteEnv, d::Data)
url = string(r.url, "/gradient")
resp = HTTP.request("POST", url, [], JSON3.write(d))
g = JSON3.read(resp.body, Tuple{Base.RefValue{Any}})
println(g)
return g
end
RemoteScore(r::RemoteEnv, d::Data)::Float64 = 0.
@adjoint function RemoteScore(remote::RemoteEnv, d::Data)
return 0., (v) -> (nothing, nothing, RemoteGradient(remote, d)...)
end
env = RemoteEnv("http://score-server2:3000")
println(gradient((v) -> RemoteScore(env, Data(v, v)), 1.))
^^ Returns (nothing,)
. The remote gradient data seems to return properly.
using JSON3
using HTTP
using Sockets
using Zygote
mutable struct Data
x::Float64
y::Float64
Data() = new()
end
JSON3.StructType(::Type{Data}) = JSON3.Mutable()
JSON3.StructType(::Type{Base.RefValue{Any}}) = JSON3.Mutable()
corsHeaders = [
"Access-Control-Allow-Origin" => "*",
"Access-Control-Allow-Methods" => "POST, GET, OPTIONS, DELETE, PUT",
"Access-Control-Allow-Headers" => "*"
]
function gradientEndpoint(req::HTTP.Request)
@show req
d = JSON3.read(IOBuffer(HTTP.payload(req)), Data)
gs = gradient((d) -> d.x+d.y, d)
return HTTP.Response(200, corsHeaders, body=JSON3.write(gs))
end
function corsOptions(req::HTTP.Request)
return HTTP.Response(200, corsHeaders)
end
const OPTIMIZE_ROUTER = HTTP.Router()
HTTP.@register(OPTIMIZE_ROUTER, "OPTIONS", "/gradient", corsOptions)
HTTP.@register(OPTIMIZE_ROUTER, "POST", "/gradient", gradientEndpoint)
HTTP.serve(OPTIMIZE_ROUTER, ip"0.0.0.0", parse(Int, ENV["PORT"]))
I’m also open to alternative approaches, but would like 2 separate applications where the gradient flows from one to the other.