I have implemented a multithreaded loss/gradient evaluation for a scientific ML problem I'm working on. Each thread computes the loss and gradient for its batch via Enzyme, and the per-batch results are recombined by a weighted average (weighted by the number of data points in each batch). I've verified that this gives the correct gradient and loss, as well as a modest speedup using up to 16 threads on my full problem.
Now I want to integrate these multithreaded functions into my training loop, which uses Optimization.jl, and I'm having a hard time figuring out the right way to define the function calls. In particular, I want to compute the loss and its gradient in a single pass with Enzyme.ReverseWithPrimal. I'm trying to make this work with Optimization.solve() by caching both results when Optimization requests the loss, and then returning the cached gradient, without recomputing, when the gradient is requested.
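To make the intent concrete before the full MWE, here is a stripped-down sketch of just the caching pattern, using a toy quadratic loss (the names toy_loss, ToyCache, toy_f, and toy_g! are purely illustrative; the Enzyme.gradient call mirrors the one in the MWE below). One ReverseWithPrimal call produces both the loss and the gradient, and a second request for the same parameters is served from the cache:
using Enzyme

# Toy objective standing in for the threaded loss; `data` plays the role of `args`
toy_loss(q, data) = sum(abs2, q .- data)

mutable struct ToyCache
    p_last::Vector{Float64}
    loss::Float64
    grad::Vector{Float64}
    valid::Bool
end
ToyCache(p) = ToyCache(copy(p), 0.0, zero(p), false)

function toy_loss_and_grad!(cache::ToyCache, p, data)
    if cache.valid && p == cache.p_last
        return cache.loss, cache.grad          # reuse the cached pair
    end
    # One reverse pass gives both the gradient and the primal loss
    res = Enzyme.gradient(Enzyme.ReverseWithPrimal, Enzyme.Const(q -> toy_loss(q, data)), p)
    copyto!(cache.grad, res[1][1])             # gradient w.r.t. p
    cache.loss = res[2]                        # primal loss value
    copyto!(cache.p_last, p)
    cache.valid = true
    return cache.loss, cache.grad
end

# The two entry points I would like Optimization.jl to call:
toy_f(p, data, cache) = toy_loss_and_grad!(cache, p, data)[1]
toy_g!(G, p, data, cache) = copyto!(G, toy_loss_and_grad!(cache, p, data)[2])

# Manual check that the second call is served from the cache
p0 = randn(3); data0 = randn(3); cache0 = ToyCache(p0)
l0 = toy_f(p0, data0, cache0)      # runs Enzyme once, caches loss and gradient
G0 = similar(p0)
toy_g!(G0, p0, data0, cache0)      # no second Enzyme call for the same p0
What I can't work out is how to hand toy_f / toy_g! (or rather their full-sized counterparts in the MWE) to OptimizationFunction and Optimization.solve() so that they get called with the signatures Optimization.jl expects.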
Here is my MWE:
# MWE of multithreaded gradient with Optimization
# import packages
begin
using Base.Threads
using Lux
using Plots
using JLD2
using DelimitedFiles
using Optimization
using OptimizationOptimisers
using BenchmarkTools
using Enzyme
using ComponentArrays
using Random
using Printf
using Statistics # for mean() used in lossDiff
using Dates
using Profile
Random.seed!(1234)
end
# Single-threaded functions
begin # Functions
function build_model(n_in, n_out, n_layers, n_nodes;
act_fun=leakyrelu, last_fun=relu)
# Input layer
first_layer = Lux.Dense(n_in, n_nodes, act_fun)
# Hidden block
hidden = (Lux.Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers)
# Output layer
last_layer = Lux.Dense(n_nodes => n_out, last_fun)
return Chain(first_layer, hidden..., last_layer)
end
function eval_model(m, params, st, x, tx, ty)
x_eval = standardize_to_existing(x, tx)
M = first(Lux.apply(m, x_eval, params, st))
Z = denormalize_data(M, ty)
return Z
end
function standardize_to_existing(X, tx)
# Standardize data (or model outputs) using the mean and std of a previous dataset
# Data are stored columnwise (each row is a feature)
# Loop over features, writing into a buffer
Xnorm = zeros(eltype(X), size(X))
means = tx[1]
stdevs = tx[2]
num_feats = size(X,1)
for i in 1:num_feats
Xnorm[i,:] = (X[i,:] .- means[i]) / stdevs[i]
end
return Xnorm
end
function denormalize_data(Yn, ty)
# Data are stored columnwise (each row is a feature)
# ty holds the absolute maximum of each feature
maxes = ty
# Loop over features, writing into a buffer
Y = zeros(eltype(Yn), size(Yn))
num_feats = size(Y,1)
for i in 1:num_feats
Y[i,:] = Yn[i,:] * maxes[i]
end
return Y
end
function recursive_convert(T, x)
if x isa AbstractArray
return convert.(T, x) # elementwise convert
elseif x isa NamedTuple
return NamedTuple{keys(x)}(recursive_convert(T, v) for v in values(x))
else
return x
end
end
function lossDiff(p, args)
y = args[2]
dsigs = calculateMultiDiffCrossSections(p, args)
return mean((y - dsigs).^2)
end
function calculateDifferentialCrossSection(U, theta)
return sum(U)*cos.(theta)
end
function calculateMultiDiffCrossSections(p, args)
X = args[1]
thetas = args[3]
M = args[4]
st = args[5]
tx = args[6]
ty = args[7]
nlen = size(thetas,1)
rlen = 100
datalen = 0
for i in 1:nlen
datalen += length(thetas[i])
end
dσ = zeros(eltype(X), datalen)
j = 1
for i in 1:nlen
exp_len = length(thetas[i])
j_next = j + exp_len
U = eval_model(M, p, st, X[:, (i-1)*rlen+1 : i*rlen], tx, ty)
sig = calculateDifferentialCrossSection(U, thetas[i])
dσ[j:j_next-1] = sig
j = j_next
end
return dσ
end
end
# Functions for multithreading
begin
function build_args_subset(args, i_start::Int, i_end::Int)
X, y, thetas, M, st, tx, ty = args # unpack in the same order as the global args tuple
# ---- 2) X_sub (features × (n_samples * rlen)) ------------------------
rlen = 100
col_start = (i_start - 1) * rlen + 1
col_end = i_end * rlen
X_sub = @view X[:, col_start:col_end] # preserves 4×K orientation
# ---- 3) thetas_sub (arrays indexed by sample) -----------
thetas_sub = thetas[i_start:i_end]
# ---- 4) y_sub (flattened vector of targets for these experiments) ----
lens = length.(thetas) # lengths per experiment (full dataset)
# handle edge cases when i_start == 1
start_idx = (i_start == 1) ? 1 : sum(lens[1:i_start-1]) + 1
end_idx = sum(lens[1:i_end])
y_sub = copy(@view y[start_idx:end_idx]) # copy to avoid aliasing with global y
# Defensive check: y_sub length should equal sum(length.(thetas_sub))
expected = sum(length.(thetas_sub))
@assert length(y_sub) == expected "y_sub length mismatch: got $(length(y_sub)) expected $expected for i_start=$i_start i_end=$i_end"
# ---- 5) return args tuple ------------
return (X_sub, y_sub, thetas_sub, M, st, tx, ty)
end
function loss_and_grad_threaded(p, args; nbatches = 2)
theta = args[3] # vector of vectors (per-experiment angle data)
nlen = size(theta, 1)
# ---- Partition experiments into batches ----
base = div(nlen, nbatches)
rems = rem(nlen, nbatches)
starts = Vector{Int}(undef, nbatches)
ends = Vector{Int}(undef, nbatches)
cur = 1
for b in 1:nbatches
sz = base + (b <= rems ? 1 : 0)
starts[b] = cur
ends[b] = cur + max(sz - 1, 0)
cur += sz
end
# ---- Storage ----
grads = Vector{typeof(p)}(undef, nbatches)
losses = zeros(Float32, nbatches)
weights = zeros(Int, nbatches) # number of datapoints in each batch
# ---- Threaded batch computation ----
@threads for b in 1:nbatches
i1, i2 = starts[b], ends[b]
if i2 < i1
g = similar(p)
fill!(g, zero(eltype(p)))
grads[b] = g
losses[b] = 0f0
weights[b] = 0
continue
end
# compute weight = number of datapoints (angle measurements)
nb = 0
for j in i1:i2
nb += length(theta[j])
end
weights[b] = nb
# build local sliced args
argsb = build_args_subset(args, i1, i2)
# Loss and gradient of the mean loss in a single reverse pass:
# ReverseWithPrimal returns (derivs, val); [1][1] is the gradient w.r.t. p, [2] is the primal loss
GandP = Enzyme.gradient(
Enzyme.ReverseWithPrimal,
Enzyme.Const(ps -> lossDiff(ps, argsb)),
p
)
grads[b] = GandP[1][1]
losses[b] = float(GandP[2])
end
# ---- Weighted reduction ----
total_points = sum(weights)
println("weights: $(weights ./ total_points)\n")
# weighted loss:
total_loss = sum(losses[b] * weights[b] for b in 1:nbatches) / total_points
# weighted gradient:
g_total = similar(p)
fill!(g_total, zero(eltype(p)))
for b in 1:nbatches
if weights[b] == 0
continue
end
# g_total += g_b * weight
add_inplace_scaled!(g_total, grads[b], weights[b])
end
g_total ./= total_points
return total_loss, g_total
end
function add_inplace_scaled!(accum, g, w)
@inbounds for k in eachindex(accum)
accum[k] += g[k] * w
end
return accum
end
# Cache of the most recent (p, loss, grad) so the gradient computed alongside the loss can be reused
mutable struct LGCache
p_last::Vector{Float32}
loss::Float32
grad::Vector{Float32}
valid::Bool
end
function LGCache(p)
LGCache(copy(p), 0f0, similar(p), false)
end
function cached_loss_grad!(cache::LGCache, p, args)
if cache.valid && p == cache.p_last
return cache.loss, cache.grad
end
# recompute
loss, grad = loss_and_grad_threaded(p, args)
cache.loss = loss
copyto!(cache.grad, grad)
copyto!(cache.p_last, p)
cache.valid = true
return loss, grad
end
loss_only(p, args, cache) = cached_loss_grad!(cache, p, args)[1]
function grad_only!(G, p, args, cache)
_, grad = cached_loss_grad!(cache, p, args)
copyto!(G, grad) # fill G with the cached gradient
end
end
data_type = Float32
# Generate dummy data
XSdiff_train = rand(data_type, 9)
theta_train = [
[1.0, 10.0, 20.0, 45.0, 60],
[5.0, 15.0, 25.0, 60]
]
X_train = rand(data_type, 4,200)
tx = (rand(data_type,4,1),rand(data_type,4,1))
ty = rand(data_type, 3, 1)
# Load a model
nlayers = 2
nnodes = 16
model = build_model(4, 3, nlayers, nnodes)
ps, st = f32(Lux.setup(Random.default_rng(), model))
p = ComponentArray(recursive_convert(data_type, ps))
const _st = st
args = (X_train, XSdiff_train, theta_train, model, _st, tx, ty)
# Test loss function evaluation
# losstest = lossDiff(p, args)
# gradtest = Enzyme.gradient(Enzyme.Reverse, ps -> lossDiff(ps, args),p)[1]
# println("Single threaded loss: $losstest")
# println("Single threaded grad 1:5: $(gradtest[1:5])")
# # Test multithreaded versions
# lossmulti, gradmulti = loss_and_grad_threaded(p, args; nbatches=2)
# println("Multithreaded loss: $lossmulti")
# println("Multithreaded grad 1:5: $(gradmulti[1:5])")
# Controls for training loop and saving/printing
max_iters = 10
learning_rate = 0.0001
global time = Dates.now()
# Callback function: prints and saves during training
callback = function (state, l)
iter = state.iter
time_elapsed = Dates.value(Dates.now() - time) / 1000 # elapsed seconds since the last iteration
println("Iteration: $iter || Loss: $l || Time: $time_elapsed s")
global time = Dates.now()
flush(stdout)
return false # do not halt the optimizer
end
# 1. Create a structure or NamedTuple to hold your static data/cache
cache = LGCache(p)
# 2. Redefine the functions to accept only the optimization variables (u)
# and the static parameters (p)
gradfun(G, u, args) = grad_only!(G, u, args, cache)
# 3. Create the OptimizationFunction
optf = OptimizationFunction(
(p, args) -> loss_only(p, args, cache);
grad=gradfun
)
prob = OptimizationProblem(optf, p, args)
res = Optimization.solve(prob, OptimizationOptimisers.Adam(learning_rate),
callback = callback, maxiters = max_iters)
And the error:
LoadError: MethodError: no method matching (::var"#199#200")(::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
The function `#199` exists, but no method is defined for this combination of argument types.
Closest candidates are:
(::var"#199#200")(::Any, ::Any)
@ Main ~/arvind/cachedgrad_mwe.jl:338
Stacktrace:
[1] (::OptimizationFunction{…})(args::ComponentVector{…})
@ SciMLBase ~/.julia/packages/SciMLBase/9sEkh/src/scimlfunctions.jl:4253
[2] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/onzuO/src/OptimizationOptimisers.jl:117 [inlined]
[3] macro expansion
@ ~/.julia/packages/Optimization/PwRNQ/src/utils.jl:32 [inlined]
[4] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/onzuO/src/OptimizationOptimisers.jl:93
[5] solve!
@ ~/.julia/packages/SciMLBase/9sEkh/src/solve.jl:234 [inlined]
[6] #solve#749
@ ~/.julia/packages/SciMLBase/9sEkh/src/solve.jl:131 [inlined]
[7] top-level scope
@ ~/arvind/cachedgrad_mwe.jl:345
[8] include(fname::String)
@ Main ./sysimg.jl:38
[9] top-level scope
@ REPL[4]:1
in expression starting at /home/daningburg/arvind/cachedgrad_mwe.jl:345
Some type information was truncated. Use `show(err)` to see complete types.
Keep in mind I’m using an older version of Enzyme due to an LLVM error that cropped up for me a few months ago. Version info:
Julia v1.11.6
Enzyme v0.13.68
Optimization v4.6.0