A colleague who was contributing to the internals of Zygote wanted to understand what was going on with
```julia
"""
    is_kwfunc(sigt...)

Determines if `sigt` is the type signature of a kwfunction.
Each element of `sigt` should be a type.
Either the first 3 types are a kwfunc type, a NamedTuple and the matching base function type,
or the first argument is the base function type and it is not a kwfunction.
the remaining types in `sigt` are the types of the argument.
"""
is_kwfunc(::Vararg{Any}) = false
is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)

"""
    wrap_chainrules_output(x)

Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally.
"""
@inline wrap_chainrules_output(x) = unthunk(x) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
```
and
```julia
    gradtuplekw = isclosure ? gradtuple2 : gradtuple3
    adj = @q @inline ZygoteRules.adjoint($(fargs...)) where $(Ts...) = $(esc(body))
    quote
        $adj
        @inline function ZygoteRules._pullback($cx, $f::$T, $(args...)) where $(Ts...)
            y, _back = adjoint(__context__, $f, $(argnames...))
            $(mut ? nothing : :(back(::Nothing) = nothing))
            back(Δ) = $gradtuple(_back(Δ))
            return y, back
        end
        @inline function ZygoteRules._pullback($cx, ::$kT, kw, $f::$T, $(args...)) where $(Ts...)
            y, _back = adjoint(__context__, $f, $(argnames...); kw...)
            $(mut ? nothing : :(back(::Nothing) = nothing))
            back(Δ) = $gradtuplekw(_back(Δ))
            return y, back
        end
        nothing
    end
end

macro adjoint(ex)
```
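To make the signature shape in both snippets concrete, here is a small sketch that runs without Zygote. It copies the two `is_kwfunc` methods from the excerpt, and uses a hypothetical function `h` plus a hypothetical `route` function standing in for the two generated `_pullback` methods (no context argument, no gradients; it only reports which method dispatch picked). It assumes `Core.kwfunc` and `Core.kwftype` behave as in the Julia versions Zygote targets.

```julia
# Local copy of the two `is_kwfunc` methods from the Zygote excerpt,
# so this runs without Zygote installed.
is_kwfunc(::Vararg{Any}) = false
is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k === Core.kwftype(f)

# Hypothetical function with a keyword argument.
h(x; y = 0) = x + y

# Type signature of the plain call h(1): (function type, argument types...)
plain_sig = (typeof(h), Int)

# Type signature of the keyword call h(1; y = 2) as the keyword sorter sees it:
# (sorter type, NamedTuple type, function type, argument types...)
kw_sig = (typeof(Core.kwfunc(h)), typeof((y = 2,)), typeof(h), Int)

println(is_kwfunc(plain_sig...))  # false: only the Vararg fallback matches 2 arguments
println(is_kwfunc(kw_sig...))     # true

# Hypothetical stand-in for the two generated `_pullback` methods:
# one for a plain call, one for a call routed through the keyword sorter.
const KwH = Core.kwftype(typeof(h))
route(f::typeof(h), args...) = (:plain, f(args...))
route(::KwH, kw::NamedTuple, f::typeof(h), args...) = (:keyword, f(args...; kw...))

println(route(h, 1))                            # (:plain, 1)
println(route(Core.kwfunc(h), (y = 2,), h, 1))  # (:keyword, 3)
```

The second `route` method only matches because the keyword sorter object arrives as the first positional argument, followed by the keywords as a `NamedTuple` and then the original function; that is the same `::$kT, kw, $f::$T` shape the macro's second method dispatches on.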
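I gave a bit of an explanation of `f(x; kwargs...)` lowering to `Core.kwfunc(f)(kwargs, f, x)`, where `kwargs` is a `NamedTuple` and `Core.kwfunc(f)` is the keyword sorter (whose type is `Core.kwftype(typeof(f))`),
and of the keyword sorter handling putting the keywords in order and filling in defaults for missing ones.
But I don't fully understand it myself.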
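A quick way to check that explanation in the REPL is to call the keyword sorter by hand and compare it with the ordinary keyword call. The function `f` below is hypothetical. As far as I know, on Julia 1.9 and later keyword calls lower to `Core.kwcall(kwargs, f, x)` instead, and `Core.kwfunc(f)` simply returns `Core.kwcall`, so the direct call should work on both; the `Meta.@lower` output will differ between versions.

```julia
# Hypothetical example function with two keywords that have defaults.
f(x; a = 1, b = 2) = (x, a, b)

# Keywords are collected into a NamedTuple; here `b` is given and `a` is missing.
kw = (b = 20,)

# Calling the keyword sorter directly with (kwargs, f, positional args...)
# gives the same result as the surface-syntax keyword call.
direct  = Core.kwfunc(f)(kw, f, 0)
surface = f(0; b = 20)
println(direct == surface)  # true: both are (0, 1, 20), `a` filled from its default

# Keyword order at the call site does not matter; the sorter matches by name.
println(f(0; b = 20, a = 10) == f(0; a = 10, b = 20))  # true

# The type `is_kwfunc` compares against is the type of the sorter object.
println(Core.kwftype(typeof(f)) === typeof(Core.kwfunc(f)))  # true

# Show the lowered form of a keyword call (version dependent).
Meta.@lower f(0; b = 20)
```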
Where should someone go looking if they wanted to understand what is going on with these?
There is a short section about this in the devdocs: