Hi there,
I need some help digging into https://github.com/FluxML/Zygote.jl/blob/master/src/compiler/interface.jl.
It starts like this:
function gradient(f, args...)
    y, back = pullback(f, args...)
    return back(sensitivity(y))  # sensitivity(y::Number) = one(y)
end
function pullback(f, args...)
    y, back = _pullback(f, args...)
    y, Δ -> tailmemaybe(back(Δ))  # tailmemaybe(::Nothing) = nothing, tailmemaybe(x::Tuple) = Base.tail(x)
end
"""
    tail(x::Tuple)::Tuple
Return a `Tuple` consisting of all but the first component of `x`.
Example:
julia> Base.tail((1,2,3,4))
(2, 3, 4)
"""
_pullback(f, args...) = _pullback(Context(), f, args...)
Context() = Context(nothing, nothing)
mutable struct Context <: AContext
    cache::Union{IdDict{Any,Any},Nothing}
    globals::Union{Dict{GlobalRef,Any},Nothing}
end
So yeah, I just can’t understand what is going on. I don’t understand the “mutable struct Context” and how/why this works. Thank you in advance!
I am not sure what level your question is at (the details, or the overview), but you might want to look at my lecture notes on differentiable programming and adjoint methods:
https://mitmath.github.io/18337/lecture11/adjoints
The required background is an understanding of reverse-mode AD:
https://mitmath.github.io/18337/lecture10/estimation_identification
which itself has a prereq of forward-mode AD in some sense:
https://mitmath.github.io/18337/lecture9/autodiff_dimensions
With that in mind, you can see that the function it’s calling is just doing the nested pullbacks. Then yeah… then there are some nasty details.
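To make "nested pullbacks" concrete, here is a minimal hand-written sketch that does not use Zygote at all. The names my_pullback and my_gradient are made up for illustration. Zygote generates the equivalent of the middle method automatically from the source of f, whereas here it is written out by hand for one function.

```julia
# Hypothetical names (my_pullback, my_gradient); this is not Zygote's code.

# Each primitive returns its value plus a closure `back` that maps the output
# sensitivity Δ to a tuple: (gradient w.r.t. the function itself, gradients w.r.t. args...).
my_pullback(::typeof(sin), x) = sin(x), Δ -> (nothing, Δ * cos(x))
my_pullback(::typeof(*), a, b) = a * b, Δ -> (nothing, Δ * b, Δ * a)

# A composite function and its pullback, built by chaining the primitive
# pullbacks: forward pass in order, backward pass in reverse order.
f(x) = x * sin(x)
function my_pullback(::typeof(f), x)
    s, back_sin = my_pullback(sin, x)    # forward: s = sin(x)
    y, back_mul = my_pullback(*, x, s)   # forward: y = x * s
    function back(Δ)
        _, dx1, ds = back_mul(Δ)         # reverse through *
        _, dx2 = back_sin(ds)            # reverse through sin
        return (nothing, dx1 + dx2)      # x is used twice, so the contributions add
    end
    return y, back
end

# Same shape as the `gradient` quoted in the question: seed with one(y),
# then drop the first slot (the gradient w.r.t. f itself) with Base.tail.
function my_gradient(f, args...)
    y, back = my_pullback(f, args...)
    return Base.tail(back(one(y)))
end

my_gradient(f, 2.0)  # a 1-tuple holding sin(2.0) + 2.0 * cos(2.0)
```

The Base.tail call is why tailmemaybe shows up in the real code: every pullback also returns a slot for the function object, and the user-facing API drops it.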
For understanding the context idea, you might want to look at Cassette.jl:
https://jrevels.github.io/Cassette.jl/latest/overdub.html
Zygote doesn’t use Cassette but the overdubbing idea is very similar.
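As a rough illustration of the overdubbing idea (a sketch only, using neither Cassette nor Zygote): every call inside a function gets rewritten to pass through a context-aware wrapper. The names CountCtx and overdub_sketch are made up, and the rewrite of g is done by hand here, whereas the real tools generate it automatically.

```julia
# Hypothetical context type that just counts how many calls pass through it.
mutable struct CountCtx
    calls::Int
end

# Fallback: record the call in the context, then run the function unchanged.
overdub_sketch(ctx::CountCtx, f, args...) = (ctx.calls += 1; f(args...))

# An ordinary function...
g(x) = sin(x) + cos(x)

# ...and its "overdubbed" version, written by hand: each inner call
# (sin, cos, +) is routed through overdub_sketch with the same context.
overdub_sketch(ctx::CountCtx, ::typeof(g), x) =
    overdub_sketch(ctx, +, overdub_sketch(ctx, sin, x), overdub_sketch(ctx, cos, x))

ctx = CountCtx(0)
result = overdub_sketch(ctx, g, 1.0)  # same value as g(1.0); ctx.calls counts the 3 inner calls
```

Replace "count the call" with "record a pullback for the call" and you get roughly the shape of what _pullback(ctx, f, args...) is doing.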
Context is there to allow Zygote to track things to do with the derivative of a function with respect to global variables (rather than its inputs).
It’s not normally needed, which is why it defaults to empty.
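A minimal sketch of that pattern, assuming a simplified stand-in for the real struct: MyContext, globals!, accum_global! and pullback_h are all made-up names, and the pullback is written by hand. The point is only that the dictionary starts as nothing and is allocated the first time a gradient for a global has to be stored.

```julia
# Hypothetical, simplified stand-in for a context that tracks gradients of globals.
mutable struct MyContext
    globals::Union{Dict{GlobalRef,Any},Nothing}
end
MyContext() = MyContext(nothing)  # empty by default: most code never touches globals

# Allocate the dictionary lazily, only when it is first needed.
globals!(cx::MyContext) =
    cx.globals === nothing ? (cx.globals = Dict{GlobalRef,Any}()) : cx.globals

# Accumulate a gradient contribution for one global variable.
function accum_global!(cx::MyContext, ref::GlobalRef, Δ)
    d = globals!(cx)
    d[ref] = haskey(d, ref) ? d[ref] + Δ : Δ
    return nothing
end

scale = 3.0        # a global that h reads
h(x) = scale * x

# Hand-written pullback of h: the gradient w.r.t. x is returned as usual,
# while the gradient w.r.t. the global `scale` is stashed in the context.
function pullback_h(cx::MyContext, x)
    y = scale * x
    back(Δ) = (accum_global!(cx, GlobalRef(@__MODULE__, :scale), Δ * x); Δ * scale)
    return y, back
end

cx = MyContext()
y, back = pullback_h(cx, 2.0)
dx = back(1.0)     # gradient w.r.t. x; the gradient w.r.t. `scale` is now in cx.globals
```

The struct has to be mutable so that the nothing field can be swapped for a real dictionary after the context has been created.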
Hi Chris, thank you for the extensive technical material; I will need some time to understand it.
I am easily confused by all the different method names and their interactions.
Could you make a minimal, self-contained working example in Julia, without the context and unnecessary tests?
I am not sure if my request is reasonable or even possible, but I am sure that with basic code and less math I will feel a bit less “overwhelmed”. Thank you.
I will dig into this context idea later; for now I have a hard time understanding even the @generated function idea. Hehe, I have a long way to go.
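On the @generated part, a tiny sketch of what the macro does in general may help (describe is a made-up name, unrelated to Zygote). Inside a generated function the arguments are bound to their types rather than their values, and the body returns an expression that becomes the actual method body for those types.

```julia
@generated function describe(x)
    # Here `x` is the *type* of the argument (e.g. Int64), not its value.
    if x <: Integer
        # The returned expression is compiled as the method body;
        # inside it, `x` refers to the runtime value again.
        return :(string("integer: ", x))
    else
        return :(string("something else: ", x))
    end
end

describe(3)    # "integer: 3"
describe(2.5)  # "something else: 2.5"
```

That is the hook that lets a library build code from the types of f and args... instead of you writing every pullback by hand.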