Maybe a noob question.
I have some old code based on Flux and Tracker that computes the derivative of a network's output with respect to a matrix of inputs:
using Flux
ann = Chain(Dense(2, 10, tanh), Dense(10, 1))
u(x) = ann(x)
ux(x) = Tracker.forward(u, x)[2](1)[1]
X = zeros(2, 1000)
u(X)
ux(X)
When I try to switch to Zygote.jl, I’m confused about how to use AD when the function output isn’t a scalar.
How could I rewrite the above code with Zygote?
What you wrote looks to be just ux(x) = gradient(u, x)[1], which is the same on either. But if you need the value too, the equivalent of Tracker.forward is now Zygote.pullback.
I figured out that it was a compatibility issue between Julia v1.3 and v1.2.
When I switched back to v1.2, it worked.
But I’m still confused about how exactly to use the pullback function.
Could you show me the equivalent code using pullback() for the case above?
Many thanks.
using Flux, Zygote
ann = Chain(Dense(2, 10, tanh), Dense(10, 1))
X = rand(2,1000)
value, backpropagator = Zygote.pullback(ann, X)
sensitivity = ones(size(value)) # Some sensitivity; it must have the shape of the output (note: in this case it's not a scalar!)
backpropagator(sensitivity)[1]
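To make the shapes concrete, here is a minimal sketch that uses only Zygote. The names W1, b1, W2, b2 and the hand-written u are hypothetical stand-ins for the Chain above, not Flux's own layers. It shows that the sensitivity has the shape of the output, that the result has the shape of the input, and that with a sensitivity of all ones each row of the result holds one partial derivative per sample (because every column of X is processed independently).

```julia
using Zygote

# Hypothetical hand-written stand-in for Chain(Dense(2, 10, tanh), Dense(10, 1)).
W1, b1 = randn(10, 2), randn(10)
W2, b2 = randn(1, 10), randn(1)
u(x) = W2 * tanh.(W1 * x .+ b1) .+ b2

X = rand(2, 5)                        # 5 sample points, one per column
value, back = Zygote.pullback(u, X)   # value is 1×5
dX = back(ones(size(value)))[1]       # 2×5, same shape as X

ux = dX[1, :]   # derivative w.r.t. the first input, one entry per sample
uy = dX[2, :]   # derivative w.r.t. the second input, one entry per sample

# A sensitivity of all ones is the same as differentiating the summed output.
@assert dX ≈ Zygote.gradient(x -> sum(u(x)), X)[1]
```

Passing a sensitivity with the shape of X instead of the shape of value would fail with a dimension mismatch inside the backward pass.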
Or, in case you want both the value and the gradient of a scalar-valued function:
function value_and_gradient(f, x...)
    value, back = Zygote.pullback(f, x...)
    grad = back(1)[1] # seed with 1 (scalar output); [1] picks the gradient w.r.t. the first argument
    return value, grad
end
using Statistics: mean
ann2 = Chain(Dense(2, 10, tanh), Dense(10, 1), mean) #note: scalar output
value_and_gradient(ann2, X) # ann2(X), ∇ann2(X)
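As a side note, more recent Zygote releases ship a helper that plays the same role as value_and_gradient. The sketch below assumes a Zygote version that provides Zygote.withgradient; the function f is a hypothetical scalar-valued stand-in for ann2, used only so the example runs without Flux.

```julia
using Zygote
using Statistics: mean

# Hypothetical scalar-valued function standing in for ann2.
f(x) = mean(tanh.(x))

x = rand(2, 4)

# Assumption: this Zygote version has withgradient, which returns a
# named tuple with fields val and grad (grad has one entry per argument).
out = Zygote.withgradient(f, x)
val, grad = out.val, out.grad[1]

# The same result through pullback, as in value_and_gradient above.
val2, back = Zygote.pullback(f, x)
grad2 = back(1)[1]
```

On older versions without this helper, the pullback-based function above remains the way to get both results from one forward pass.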