Mooncake.jl with variable-sized input

This post is intended as a log for people interested in Mooncake. At the same time, I would welcome feedback.

Background: I am interested in learning from structured data; think of JSON documents, for example. I am one of the authors of the Mill.jl and JsonGrinder.jl packages, which facilitate training classifiers on data stored in JSON.

I have been following the development of Mooncake.jl, because the development of Zygote has stalled and Julia should have an actively developed AD system. One of the issues with processing JSON data is that almost every sample has a different size, and padding everything to the same size would be very difficult. Therefore I have not considered Reactant.jl, which compiles for fixed array shapes (though I do play with it).

Mooncake.jl has the same limitation (the cache returned by prepare_gradient_cache is tied to the sizes of the arguments it was prepared with), but since Mooncake.jl is written in Julia, I was hoping I could lift this requirement.

Yesterday, I decided to give Mooncake.jl a try with the help of Claude. First of all, it is important to understand that the cache object returned by prepare_gradient_cache allocates space for the arguments of the differentiated function, not for intermediate results. This is where the requirement of fixed array sizes comes from. It also means that if these buffers are resized to accommodate new arguments, the cache can be reused. Claude implemented a function resize_gradient_cache, which adjusts the preallocated space to the sizes of the new arguments, and it just works. An example of use:

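```julia
# resize_gradient_cache is the custom function defined below, not a built-in Mooncake function.
f_vec(x) = sum(abs2, x)
cache = prepare_gradient_cache(f_vec, [1.0, 2.0])
# The input grows from 2 to 3 elements: resize the buffers instead of preparing a new cache.
cache = resize_gradient_cache(cache, f_vec, [1.0, 2.0, 3.0])
val, (_, grad) = value_and_gradient!!(cache, f_vec, [1.0, 2.0, 3.0])
```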
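To make the point about what the cache stores concrete, here is a toy sketch in plain Julia. It is not Mooncake's implementation: ToyGradCache, toy_resize! and toy_value_and_gradient! are hypothetical names, and the gradient of sum(abs2, x) is hand-written as 2x rather than computed by AD. It only illustrates why a cache that holds argument-shaped buffers (and nothing sized by intermediates) can be reused once those buffers are resized.

```julia
# Hypothetical toy cache, NOT Mooncake's Cache: it preallocates only a
# gradient buffer shaped like the argument of f(x) = sum(abs2, x).
struct ToyGradCache
    grad::Vector{Float64}
end

ToyGradCache(x::Vector{Float64}) = ToyGradCache(similar(x))

# Analogue of resize_gradient_cache: keep the buffer if the size matches,
# otherwise resize the argument-shaped buffer in place.
function toy_resize!(cache::ToyGradCache, x::Vector{Float64})
    length(cache.grad) == length(x) || resize!(cache.grad, length(x))
    return cache
end

# Hand-written gradient of sum(abs2, x) (which is 2x), stored in the buffer.
# Like a prepared cache, it refuses inputs of a size it was not set up for.
function toy_value_and_gradient!(cache::ToyGradCache, x::Vector{Float64})
    length(cache.grad) == length(x) ||
        throw(DimensionMismatch("cache sized for $(length(cache.grad)) elements, got $(length(x))"))
    cache.grad .= 2 .* x
    return sum(abs2, x), cache.grad
end

toy = ToyGradCache([1.0, 2.0])
toy_resize!(toy, [1.0, 2.0, 3.0])
val, g = toy_value_and_gradient!(toy, [1.0, 2.0, 3.0])  # val == 14.0, g == [2.0, 4.0, 6.0]
```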

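The code for resize_gradient_cache is below. It uses Mooncake internals such as Cache and PreparedCacheInputSpec without qualification, so it is meant to be evaluated inside the Mooncake module.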

```julia
function resize_gradient_cache(cache::Cache, fx...)
    specs = getfield(cache, :input_specs)

    # Arity check.
    length(specs) == length(fx) ||
        _throw_prepared_cache_spec_error(:arity, 0, length(specs), length(fx))

    # Type check: the compiled rule is specialised per type; a type mismatch means the
    # caller needs a fresh cache from prepare_gradient_cache, not a resize.
    for i in 1:length(specs)
        T = typeof(specs[i]).parameters[1]
        typeof(fx[i]) == T || _throw_prepared_cache_spec_error(:type, i, T, typeof(fx[i]))
    end

    # Fast path: all top-level array sizes already match.
    all(1:length(specs)) do i
        !(fx[i] isa AbstractArray) || size(fx[i]) == specs[i].size
    end && return cache

    # Sizes changed: allocate new gradient buffers.
    #
    # No warmup pass is needed here. The rule's internal Stacks are at position=0 after
    # any completed value_and_gradient!! call (forward pushes, reverse pops balance out).
    # prepare_gradient_cache runs a warmup purely to (a) validate the return type and
    # (b) pre-size the Stack backing vectors; neither is necessary here because:
    #   (a) the return type was already validated when the original cache was built, and
    #       prepare_gradient_cache enforces IEEEFloat so output_spec is always scalar.
    #   (b) Stack backing vectors grow automatically on the first real call if needed.
    tangents = map(zero_tangent, fx)

    input_specs = map(fx) do x
        x isa AbstractArray ? PreparedCacheInputSpec(typeof(x), size(x)) :
                              PreparedCacheInputSpec(typeof(x), ())
    end
    # output_spec is always scalar (IEEEFloat): prepare_gradient_cache rejects anything else.
    output_spec = getfield(cache, :output_spec)

    dests = isnothing(getfield(cache, :dests)) ? nothing :
            tuple(map(friendly_tangent_cache, fx)...)
    return Cache(cache.rule, nothing, tangents, dests, nothing, input_specs, output_spec)
end
```

This worked, but then I asked whether there is a way to declare an argument constant, so that no gradient is computed for it. This is a frequent situation in machine learning, where the gradient is computed with respect to the model but not with respect to the data.

The proposed solution is a wrapper type whose tangent type is NoTangent. The complete example with a unit test is:

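```julia
using Mooncake, Test

# Wrapper marking an argument as constant: its tangent type is NoTangent,
# so no gradient buffer sized to the wrapped data is preallocated for it.
struct ConstParam{T}
    val::T
end
Mooncake.tangent_type(::Type{<:ConstParam}) = Mooncake.NoTangent

@testset "ConstParam" begin
    loss(w, x::AbstractArray) = sum(w * x)
    loss(w, x::ConstParam) = loss(w, x.val)

    w = randn(2, 2)
    x_10 = randn(2, 10)
    x_100 = randn(2, 100)
    config = Mooncake.Config(; friendly_tangents=true)

    # Cache prepared with a 2×10 batch wrapped in ConstParam...
    cache_10 = prepare_gradient_cache(loss, w, ConstParam(x_10); config=config)
    # ...and a reference cache prepared directly for the plain 2×100 array.
    cache_100 = prepare_gradient_cache(loss, w, x_100; config=config)

    # cache_10 is reused with the 2×100 batch without any resize.
    val_10, (_, ∇w_10, _) = value_and_gradient!!(cache_10, loss, w, ConstParam(x_100))
    val_100, (_, ∇w_100, _) = value_and_gradient!!(cache_100, loss, w, x_100)

    @test val_10 ≈ val_100
    @test ∇w_10 ≈ ∇w_100
end
```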
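A sketch of how this might look in a training loop, assuming the ConstParam wrapper, the prepare_gradient_cache / value_and_gradient!! calls and the Config(; friendly_tangents=true) option exactly as used in the unit test: a single cache, prepared with a 10-column batch, is reused for batches of other sizes. Each gradient is checked against the closed form: since sum(w * x) = Σ_{i,j,k} w[i, j] x[j, k], the derivative with respect to w[i, j] is Σ_k x[j, k], i.e. ∇w = ones(size(w, 1), size(x, 2)) * x'. The loop relies on the same behaviour the unit test checks, namely that a ConstParam argument of a different size is accepted by the prepared cache.

```julia
using Mooncake

# Same wrapper as in the unit test: no tangent is stored for a ConstParam argument.
struct ConstParam{T}
    val::T
end
Mooncake.tangent_type(::Type{<:ConstParam}) = Mooncake.NoTangent

loss(w, x::AbstractArray) = sum(w * x)
loss(w, x::ConstParam) = loss(w, x.val)

w = randn(2, 2)
config = Mooncake.Config(; friendly_tangents=true)

# Prepare one cache using a batch of 10 samples (columns).
cache = Mooncake.prepare_gradient_cache(loss, w, ConstParam(randn(2, 10)); config=config)

for n in (10, 100, 3, 57)  # batches of varying size, all through the same cache
    x = randn(2, n)
    val, (_, ∇w, _) = Mooncake.value_and_gradient!!(cache, loss, w, ConstParam(x))
    # Closed form: the derivative of sum(w * x) with respect to w is ones(2, n) * x'.
    @assert val ≈ sum(w * x)
    @assert ∇w ≈ ones(size(w, 1), n) * x'
end
```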

The above solution is cute, in my view. I have tried it with Mill.jl and it works, though Mooncake.jl is currently slower than Zygote.jl. I attribute this to the fact that Mill.jl has been heavily optimized for Zygote.jl (as I understand it, there are unnecessary allocations with Mooncake.jl caused by AD rules that are not well written).

Feedback from Mooncake experts would be much appreciated.
