Trying to understand the "gradient()" function from the Flux.jl/Zygote.jl package

Hi there,

I need some help digging into https://github.com/FluxML/Zygote.jl/blob/master/src/compiler/interface.jl.

So it starts like this:

function gradient(f, args...)
  y, back = pullback(f, args...)
  return back(sensitivity(y)) #sensitivity(y::Number) = one(y)
end
function pullback(f, args...)
  y, back = _pullback(f, args...)
  y, Δ -> tailmemaybe(back(Δ)) #tailmemaybe(::Nothing) = nothing, tailmemaybe(x::Tuple) = Base.tail(x)
end
"""
    tail(x::Tuple)::Tuple

Return a `Tuple` consisting of all but the first component of `x`.
Example:
julia> Base.tail((1,2,3,4))
(2, 3, 4)
"""
_pullback(f, args...) = _pullback(Context(), f, args...)
Context() = Context(nothing, nothing)

mutable struct Context <: AContext
  cache::Union{IdDict{Any,Any},Nothing}
  globals::Union{Dict{GlobalRef,Any},Nothing}
end

So yeah, I just can’t understand what is going on. I don’t understand the “mutable struct Context” or how/why this works. Thank you in advance!

I am not sure what level your question is at (the details, or the overview), but you might want to look at my lecture notes on differentiable programming and adjoint methods:

https://mitmath.github.io/18337/lecture11/adjoints

The required background is an understanding of reverse-mode AD:

https://mitmath.github.io/18337/lecture10/estimation_identification

which itself has a prereq of forward-mode AD in some sense:

https://mitmath.github.io/18337/lecture9/autodiff_dimensions

With that in mind, you can see that the function it’s calling is just doing the nested pullbacks. Then yeah… then there are some nasty details.
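To make “nested pullbacks” concrete, here is a minimal hand-written sketch that does not use Zygote at all. The names `pb_square`, `pb_sin`, and `pb_composed` are made up for illustration. Each one returns the value together with a closure that maps an output sensitivity back to an input sensitivity.

```julia
# Hypothetical helper names; this is not Zygote's API.
# Each "pullback" returns (value, back), where `back` maps the
# sensitivity of the output to the sensitivity of the input.
pb_square(x) = (x^2, Δ -> Δ * 2x)
pb_sin(x)    = (sin(x), Δ -> Δ * cos(x))

# Pullback of x -> sin(x^2): run the pieces forward,
# then call their `back` closures in reverse order.
function pb_composed(x)
    a, back_square = pb_square(x)
    y, back_sin    = pb_sin(a)
    return y, Δ -> back_square(back_sin(Δ))
end

y, back = pb_composed(1.5)
dydx = back(one(y))   # seed with 1, as `sensitivity(y)` does for a Number
```

Here `dydx` should agree with the chain rule result `cos(1.5^2) * 2 * 1.5`. Roughly speaking, Zygote generates closures like these for you instead of you writing them by hand.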

For understanding the context idea, you might want to look at Cassette.jl:

https://jrevels.github.io/Cassette.jl/latest/overdub.html

Zygote doesn’t use Cassette but the overdubbing idea is very similar.

Context is there to allow Zygote to track things to do with the derivative of a function with respect to global variables (rather than its inputs).

It’s not normally needed, which is why both fields default to `nothing`.
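As a rough sketch of why the struct is mutable and its fields are `Union{..., Nothing}`: the storage can start out as `nothing` and only be allocated when something actually needs it. `LazyContext` and `cache!` below are illustrative names I made up, not Zygote’s definitions.

```julia
# Hypothetical stand-in for the idea; not Zygote's actual code.
mutable struct LazyContext
    cache::Union{IdDict{Any,Any},Nothing}
end
LazyContext() = LazyContext(nothing)   # starts empty, like Context(nothing, nothing)

# Allocate the dictionary only the first time something needs it.
# The struct must be mutable so the field can be reassigned here.
function cache!(cx::LazyContext)
    if cx.cache === nothing
        cx.cache = IdDict{Any,Any}()
    end
    return cx.cache
end

cx = LazyContext()
was_empty = cx.cache === nothing       # nothing has been allocated yet
cache!(cx)[:some_key] = 1.0            # first use allocates the IdDict
```

For most functions the dictionaries are never touched, so the common case pays nothing for them.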

Hi Chris, thank you for the extensive technical notes; I will need some time to understand them.

I am easily confused by all the different method names and their interactions.

Could you make a minimal, self-contained working example in Julia, without the context and unnecessary tests?

I am not sure if my request is reasonable or even possible, but I am sure that with basic code and less math I will feel a bit less “overwhelmed”. Thank you.

I will dig into the context idea later; for now I have a hard time understanding even the @generated function idea. Hehe, I have a long way to go :wink:

Mike builds up a small example in his diff-zoo repository on GitHub (MikeInnes/diff-zoo, “Differentiation for Hackers”).
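In the meantime, here is a stripped-down sketch that mirrors the shape of the `gradient` / `pullback` / `_pullback` snippet from the first post, with no Context and no code generation. All names (`_my_pullback`, `my_pullback`, `my_gradient`) are made up for illustration, and the rules are written by hand for just two functions. As far as I can tell, the reason for the `Base.tail` call is that the low-level pullback also returns a slot for the function itself.

```julia
# Illustrative names only; this is not Zygote's code.

# Low-level rule: like `_pullback`, `back` returns one entry per argument
# *including the function itself* in the first slot. Plain functions
# have no fields, so their slot is `nothing`.
_my_pullback(::typeof(*), a, b) = (a * b, Δ -> (nothing, Δ * b, Δ * a))
_my_pullback(::typeof(sin), x)  = (sin(x), Δ -> (nothing, Δ * cos(x)))

# User-facing wrapper: drop the function's slot with Base.tail,
# which is the role `tailmemaybe` plays in the original snippet.
function my_pullback(f, args...)
    y, back = _my_pullback(f, args...)
    return y, Δ -> Base.tail(back(Δ))
end

# Gradient of a scalar output: seed the pullback with one(y).
function my_gradient(f, args...)
    y, back = my_pullback(f, args...)
    return back(one(y))
end

g_mul = my_gradient(*, 2.0, 3.0)   # (3.0, 2.0): derivative w.r.t. each argument
g_sin = my_gradient(sin, 0.0)      # (1.0,)
```

The real package differs in that it derives the `_pullback` methods for arbitrary code automatically, which is where the @generated functions come in.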
