Using custom gradient function with Optimization.jl

I have implemented a multithreaded version of a scientific ML problem I’m solving. Each thread computes a loss and gradient via Enzyme, and these are recombined by weighted averaging. I’ve verified I’m getting the correct gradient and loss, as well as a modest speedup using up to 16 threads with my full problem.
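
Concretely, if batch b has loss L_b, gradient g_b, and w_b data points, the recombination is L = sum_b(w_b * L_b) / sum_b(w_b), and likewise for the gradient (this is what loss_and_grad_threaded in the MWE below does).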

Now I want to integrate these multithreaded functions into my training loop, which uses Optimization.jl, but I'm having a hard time figuring out the right way to define the loss and gradient callbacks. In particular, I want to compute the loss and its gradient in a single pass with Enzyme.ReverseWithPrimal. My plan is to make this work with Optimization.solve() by caching the gradient when Optimization requests the loss value, and then returning the cached gradient without recomputing it when the gradient is requested.
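
For reference, by "single pass" I mean the pattern below, where one Enzyme.gradient call returns both the gradient and the primal value (toy function and names, just for illustration):

# Toy illustration: gradient and primal loss from one Enzyme call
using Enzyme
toy_loss(u) = sum(abs2, u)
res = Enzyme.gradient(Enzyme.ReverseWithPrimal, toy_loss, [1.0, 2.0, 3.0])
res[1][1]   # gradient -> [2.0, 4.0, 6.0]
res[2]      # primal loss value -> 14.0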

Here is my MWE:

# MWE of multithreaded gradient with Optimization

# import packages
begin
    using Base.Threads
    using Lux
    using Plots
    using JLD2
    using DelimitedFiles
    using Optimization
    using OptimizationOptimisers
    using BenchmarkTools
    using Enzyme
    using ComponentArrays
    using Random
    using Printf
    using Dates
    using Profile
    using Statistics   # provides mean() used in lossDiff
    Random.seed!(1234)
end

# Single-threaded functions
begin # Functions
    function build_model(n_in, n_out, n_layers, n_nodes;
                        act_fun=leakyrelu, last_fun=relu)
        # Input layer
        first_layer = Lux.Dense(n_in, n_nodes, act_fun)

        # Hidden block
        hidden = (Lux.Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers)

        # Output layer
        last_layer = Lux.Dense(n_nodes => n_out, last_fun)

        return Chain(first_layer, hidden..., last_layer)
    end

    function eval_model(m, params, st, x, tx, ty)
        x_eval = standardize_to_existing(x, tx)
        M = first(Lux.apply(m, x_eval, params, st))
        Z = denormalize_data(M, ty)
        return Z
    end

    function standardize_to_existing(X, tx)
        # Normalize data or results to the norm of a previous dataset
        # Data points divided columnwise (each row a feature)

        # Make a loop to do this with a buffer
        Xnorm = zeros(eltype(X), size(X))
        means = tx[1]
        stdevs = tx[2]
        num_feats = size(X,1)
        for i in 1:num_feats
            Xnorm[i,:] = (X[i,:] .- means[i]) / stdevs[i]
        end
        return Xnorm
    end

    function denormalize_data(Yn, ty)
        # Data points divided columnwise (each row a feature)
        # Find the absolute maximum of each feature and save these
        maxes = ty

        # Make a loop to do this with a buffer
        Y = zeros(eltype(Yn), size(Yn))
        num_feats = size(Y,1)
        for i in 1:num_feats
            Y[i,:] = Yn[i,:] * maxes[i]
        end
        return Y
    end


    function recursive_convert(T, x)
        if x isa AbstractArray
            return convert.(T, x)  # elementwise convert
        elseif x isa NamedTuple
            return NamedTuple{keys(x)}(recursive_convert(T, v) for v in values(x))
        else
            return x
        end
    end

    function lossDiff(p, args)
        y = args[2]
        dsigs = calculateMultiDiffCrossSections(p, args)
        return mean((y - dsigs).^2)
    end

    function calculateDifferentialCrossSection(U, theta)
        return sum(U)*cos.(theta)
    end

    function calculateMultiDiffCrossSections(p, args)
        X = args[1]
        thetas = args[3]
        M = args[4]
        st = args[5]
        tx = args[6]
        ty = args[7]
        nlen = size(thetas,1)
        rlen = 100
        datalen = 0
        for i in 1:nlen
            datalen += length(thetas[i])
        end
        dσ = zeros(eltype(X), datalen)
        j = 1
        for i in 1:nlen
            exp_len = length(thetas[i])
            j_next = j + exp_len
            U = eval_model(M, p, st, X[:, (i-1)*rlen+1 : i*rlen], tx, ty)
            sig = calculateDifferentialCrossSection(U, thetas[i])
            dσ[j:j_next-1] = sig
            j = j_next
        end
        return dσ
    end
end


# Functions for multithreading
begin
    function build_args_subset(args, i_start::Int, i_end::Int)
        X, y, thetas, M, st, tx, ty = args   # same ordering as the global args tuple

        # ---- 2) X_sub (features × (n_samples * rlen)) ------------------------
        rlen = 100
        col_start = (i_start - 1) * rlen + 1
        col_end   = i_end * rlen
        X_sub = @view X[:, col_start:col_end]   # preserves 4×K orientation

        # ---- 3) thetas_sub (arrays indexed by sample) -----------
        thetas_sub = thetas[i_start:i_end]

        # ---- 4) y_sub (flattened vector of targets for these experiments) ----
        lens = length.(thetas)  # lengths per experiment (full dataset)
        # handle edge cases when i_start == 1
        start_idx = (i_start == 1) ? 1 : sum(lens[1:i_start-1]) + 1
        end_idx   = sum(lens[1:i_end])
        y_sub = copy(@view y[start_idx:end_idx])   # copy to avoid aliasing with global y

        # Defensive check: y_sub length should equal sum(length.(thetas_sub))
        expected = sum(length.(thetas_sub))
        @assert length(y_sub) == expected "y_sub length mismatch: got $(length(y_sub)) expected $expected for i_start=$i_start i_end=$i_end"

        # ---- 5) return args tuple  ------------
        return (X_sub, y_sub, thetas_sub, M, st, tx, ty)
    end

    function loss_and_grad_threaded(p, args; nbatches = 2)
        theta = args[3]             # vector of vectors (per-experiment angle data)
        nlen = size(theta, 1)

        # ---- Partition experiments into batches ----
        base = div(nlen, nbatches)
        rems = rem(nlen, nbatches)

        starts = Vector{Int}(undef, nbatches)
        ends   = Vector{Int}(undef, nbatches)

        cur = 1
        for b in 1:nbatches
            sz = base + (b <= rems ? 1 : 0)
            starts[b] = cur
            ends[b]   = cur + max(sz - 1, 0)
            cur += sz
        end

        # ---- Storage ----
        grads   = Vector{typeof(p)}(undef, nbatches)
        losses  = zeros(Float32, nbatches)
        weights = zeros(Int, nbatches)  # number of datapoints in each batch

        # ---- Threaded batch computation ----
        @threads for b in 1:nbatches
            i1, i2 = starts[b], ends[b]

            if i2 < i1
                g = similar(p)
                fill!(g, zero(eltype(p)))
                grads[b] = g
                losses[b]  = 0f0
                weights[b] = 0
                continue
            end

            # compute weight = number of datapoints (angle measurements)
            nb = 0
            for j in i1:i2
                nb += length(theta[j])
            end
            weights[b] = nb

            # build local sliced args
            argsb = build_args_subset(args, i1, i2)

            # gradient of mean loss
            GandP = Enzyme.gradient(
                Enzyme.ReverseWithPrimal,
                Enzyme.Const(ps -> lossDiff(ps, argsb)),
                p
            )
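            # GandP is a (derivs, val) pair: derivs[1] is the gradient w.r.t. p, val is the primal loss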

            grads[b]  = GandP[1][1]
            losses[b] = float(GandP[2])
        end

        # ---- Weighted reduction ----
        total_points = sum(weights)
        println("weights: $(weights ./ total_points)\n")

        # weighted loss:
        total_loss = sum(losses[b] * weights[b] for b in 1:nbatches) / total_points

        # weighted gradient:
        g_total = similar(p)
        fill!(g_total, zero(eltype(p)))

        for b in 1:nbatches
            if weights[b] == 0
                continue
            end
            # g_total += g_b * weight
            add_inplace_scaled!(g_total, grads[b], weights[b])
        end

        g_total ./= total_points

        return total_loss, g_total
    end

    function add_inplace_scaled!(accum, g, w)
        @inbounds for k in eachindex(accum)
            accum[k] += g[k] * w
        end
        return accum
    end

    mutable struct LGCache
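        # Note: a ComponentVector passed to the constructor is converted to a plain
        # Vector{Float32} by these field types; only the raw parameter values are cached.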
        p_last::Vector{Float32}
        loss::Float32
        grad::Vector{Float32}
        valid::Bool
    end

    function LGCache(p)
        LGCache(copy(p), 0f0, similar(p), false)
    end

    function cached_loss_grad!(cache::LGCache, p, args)
        if cache.valid && p == cache.p_last
            return cache.loss, cache.grad
        end
        # recompute
        loss, grad = loss_and_grad_threaded(p, args)
        cache.loss = loss
        copyto!(cache.grad, grad)
        copyto!(cache.p_last, p)
        cache.valid = true
        return loss, grad
    end

    loss_only(p, args, cache) = cached_loss_grad!(cache, p, args)[1]

    function grad_only!(G, p, args, cache)
        _, grad = cached_loss_grad!(cache, p, args)
        copyto!(G, grad)     # fill G with the cached gradient
    end

end 

data_type = Float32

# Generate dummy data
XSdiff_train = rand(data_type, 9)
theta_train = [
    [1.0, 10.0, 20.0, 45.0, 60],
    [5.0, 15.0, 25.0, 60]
]
X_train = rand(data_type, 4,200)
tx = (rand(data_type,4,1),rand(data_type,4,1))
ty = rand(data_type, 3, 1)

# Load a model
nlayers = 2
nnodes = 16
model = build_model(4, 3, nlayers, nnodes)
ps, st = f32(Lux.setup(Random.default_rng(), model))
p = ComponentArray(recursive_convert(data_type, ps))
const _st = st

args = (X_train, XSdiff_train, theta_train, model, _st, tx, ty)

# Test loss function evaluation
# losstest = lossDiff(p, args)
# gradtest = Enzyme.gradient(Enzyme.Reverse, ps -> lossDiff(ps, args),p)[1]
# println("Single threaded loss: $losstest")
# println("Single threaded grad 1:5: $(gradtest[1:5])")

# # Test multithreaded versions
# lossmulti, gradmulti = loss_and_grad_threaded(p, args; nbatches=2)
# println("Multithreaded loss: $lossmulti")
# println("Multithreaded grad 1:5: $(gradmulti[1:5])")




# Controls for training loop and saving/printing
max_iters = 10
learning_rate = 0.0001
global time = Dates.now()

# Callback function: prints and saves during training
callback = function (state, l)
    iter = state.iter
    time_elapsed = Dates.value(Dates.now() - time) / 1000   # elapsed time in seconds
    println("Iteration: $iter || Loss: $l || Time: $time_elapsed s")
    global time = Dates.now()
    flush(stdout)
    return false   # returning false means "do not halt the optimization"
end


# 1. Create a structure or NamedTuple to hold your static data/cache
cache = LGCache(p)

# 2. Redefine the functions to accept only the optimization variables (u)
# and the static parameters (p)
gradfun(G, u, args) = grad_only!(G, u, args, cache)

# 3. Create the OptimizationFunction
optf = OptimizationFunction(
    (p, args) -> loss_only(p, args, cache);
    grad=gradfun
)


prob = OptimizationProblem(optf, p, args)

res = Optimization.solve(prob, OptimizationOptimisers.Adam(learning_rate),
                         callback = callback, maxiters = max_iters)

And the error:

LoadError: MethodError: no method matching (::var"#199#200")(::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
The function `#199` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  (::var"#199#200")(::Any, ::Any)
   @ Main ~/arvind/cachedgrad_mwe.jl:338

Stacktrace:
 [1] (::OptimizationFunction{…})(args::ComponentVector{…})
   @ SciMLBase ~/.julia/packages/SciMLBase/9sEkh/src/scimlfunctions.jl:4253
 [2] macro expansion
   @ ~/.julia/packages/OptimizationOptimisers/onzuO/src/OptimizationOptimisers.jl:117 [inlined]
 [3] macro expansion
   @ ~/.julia/packages/Optimization/PwRNQ/src/utils.jl:32 [inlined]
 [4] __solve(cache::OptimizationCache{…})
   @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/onzuO/src/OptimizationOptimisers.jl:93
 [5] solve!
   @ ~/.julia/packages/SciMLBase/9sEkh/src/solve.jl:234 [inlined]
 [6] #solve#749
   @ ~/.julia/packages/SciMLBase/9sEkh/src/solve.jl:131 [inlined]
 [7] top-level scope
   @ ~/arvind/cachedgrad_mwe.jl:345
 [8] include(fname::String)
   @ Main ./sysimg.jl:38
 [9] top-level scope
   @ REPL[4]:1
in expression starting at /home/daningburg/arvind/cachedgrad_mwe.jl:345
Some type information was truncated. Use `show(err)` to see complete types.

Keep in mind I’m using an older version of Enzyme due to an LLVM error that cropped up for me a few months ago. Version info:
Julia v1.11.6
Enzyme v0.13.68
Optimization v4.6.0