I am training a neural network in a physical context on observables; crucially, the network is producing values which are passed through some physics calculations before being compared to the training data. My code works, but has become quite slow as I’ve acquired more training data (on the order of 30,000 points). I would like to try moving it to Reactant to make use of its optimization and also HPC GPUs.
I’m running into a lot of errors that are either “Constant memory is stored (or returned) to a differentiable variable” or “No augmented forward pass found for ejlstr$UnsafeBufferPointer$…”.
I’d like some help dealing with this, or ideas on refactoring my code to be more friendly to Reactant. I’ve created an MWE that matches the high-level structure of my code quite closely, but leaves out the physics underneath. Here is its current form and the error I’m getting:
begin
using Lux
using Enzyme
using ComponentArrays
using Random
using Reactant
using Statistics
Random.seed!(1234)
end
begin # Functions
function build_model(n_in, n_out, n_layers, n_nodes;
                     act_fun=leakyrelu, last_fun=relu)
    # Build a dense MLP: one input layer, `n_layers` hidden layers of width
    # `n_nodes`, and an output layer with its own activation.
    layers = Any[Lux.Dense(n_in => n_nodes, act_fun)]
    for _ in 1:n_layers
        push!(layers, Lux.Dense(n_nodes => n_nodes, act_fun))
    end
    push!(layers, Lux.Dense(n_nodes => n_out, last_fun))
    return Chain(layers...)
end
function eval_model(m, params, st, x)
    # Apply the Lux model and keep only the prediction, discarding the
    # updated state that Lux returns as the second tuple element.
    out, _ = m(x, params, st)
    return out
end
function combex(x::AbstractMatrix{T}, r::AbstractVector{T}) where T<:AbstractFloat
    # Combine each row of `x` with every value of `r`: the output (transposed)
    # has one column per (r-value, x-row) pair, with r in the first row and the
    # x-row repeated in the remaining rows.
    #
    # Fix: the original inner loop allocated a fresh copy of `x[i, :]` for every
    # single element (xlen*rlen slice copies); we now fill each block of rows
    # with two broadcasted view assignments instead.
    xlen, nx = size(x)
    rlen = length(r)
    X = Array{T}(undef, xlen * rlen, nx + 1)
    @views for i in 1:xlen
        rows = (i - 1) * rlen .+ (1:rlen)
        X[rows, 1] .= r               # first column: the r grid
        X[rows, 2:end] .= x[i:i, :]   # remaining columns: row i of x, repeated
    end
    return X'
end
function recursive_convert(T, x)
    # Recursively convert the numeric-array leaves of a nested NamedTuple
    # (e.g. Lux parameters) to element type T; other leaves pass through.
    x isa AbstractArray && return convert.(T, x)
    if x isa NamedTuple
        converted = map(v -> recursive_convert(T, v), values(x))
        return NamedTuple{keys(x)}(converted)
    end
    return x
end
function lossDiff(p, args)
    # Scalar training loss: compare measured cross sections (args[3]) to the
    # model-predicted ones in log space.
    y_data = args[3]
    y_model = calculateMultiDiffCrossSections(p, args)
    return ln_loss(y_data, y_model)
end
function ln_loss(y_data, y_model)
    # Mean squared difference of logarithms — a relative-error-style loss
    # that is insensitive to the overall scale of the observables.
    return mean(abs2, log.(y_data) .- log.(y_model))
end
# Placeholder for the real physics kernel: returns one Float32 per scattering
# angle in `theta`; all other arguments are ignored in this MWE.
# NOTE(review): `rand` makes the loss stochastic and non-reproducible, and
# differentiating/tracing through RNG calls is hostile to Enzyme/Reactant —
# consider a deterministic stub (e.g. a smooth function of theta) while
# debugging gradient errors.
function calculateDifferentialCrossSection(A, Z, E, U, r, dr, theta, Lrange)
return rand(Float32, size(theta))
end
# Evaluate the network potential for each experiment and assemble all
# differential cross sections into one flat vector matching the data layout.
#
# `args` unpacks as: x (experiment table; used as E=x[i,1], A=x[i,2], Z=x[i,3]
# per the call below), X (network inputs, one column per (r, experiment) pair),
# y (unused here), r, dr, thetas (ragged per-experiment angle lists), Lrange,
# M (Lux model), st (Lux state).
function calculateMultiDiffCrossSections(p, args)
    x = args[1]
    X = args[2]
    r = args[4]
    dr = args[5]
    thetas = args[6]
    Lrange = args[7]
    M = args[8]
    st = args[9]
    nlen = size(x, 1)
    rlen = size(r, 1)
    # Total number of data points across all experiments.
    datalen = 0
    for i in 1:nlen
        datalen += length(thetas[i])
    end
    # NOTE(review): `zeros` plus the in-place writes below are what trip
    # Reactant/Enzyme ("UnsafeBufferPointer"): scalar/slice mutation of traced
    # arrays is unsupported. Building dσ functionally (map over experiments,
    # then vcat/reduce) is the Reactant-friendly refactor.
    dσ = zeros(eltype(x), datalen)
    j = 1
    for i in 1:nlen
        exp_len = length(thetas[i])
        sig = zeros(eltype(x), exp_len)
        j_next = j + exp_len
        if x[i, 2] > 0
            # Single isotope: evaluate the potential on this experiment's r-grid.
            U = eval_model(M, p, st, X[:, (i-1)*rlen+1 : i*rlen])
            sig = calculateDifferentialCrossSection(x[i, 2], x[i, 3], x[i, 1], U, r, dr, thetas[i], Lrange)
        else
            # Natural element: abundance-weighted sum over its isotopes.
            # BUG FIX: the original lambda `x -> x == x[:,3]` shadowed `x` and
            # indexed a scalar; match isotopes by this row's proton number Z.
            nat_inds = findall(z -> z == x[i, 3], nat_zs)
            nat_abund = nat_abunds[nat_inds]
            nat_As = mod.(nat_zaids[nat_inds], 1000)  # A = ZAID mod 1000
            k = 1
            for nat_A in nat_As
                # Overwrite the A row of the inputs for this isotope (mutates X!).
                X[3, (i-1)*rlen+1 : i*rlen] .= nat_A
                U = eval_model(M, p, st, X[:, (i-1)*rlen+1 : i*rlen])
                sig += nat_abund[k] / 100 * calculateDifferentialCrossSection(nat_A, x[i, 3], x[i, 1], U, r, dr, thetas[i], Lrange)
                k += 1
            end
        end
        dσ[j:j_next-1] = sig
        j = j_next
    end
    return dσ
end
end
const cdev = cpu_device()       # Lux host (CPU) device-transfer function
const xdev = reactant_device()  # Lux Reactant (XLA) device-transfer function
# Set up physics parameters
data_type = Float32
Lmax = 15          # highest partial wave in Lrange
dr = 0.1           # radial step (Float64 literal — note: not data_type)
rmin = dr/100
rmax = 12.5
r = Vector{data_type}(rmin:dr:rmax)  # radial mesh, converted to Float32
Lrange = collect(0:Lmax)
# Abundances
# Columns: ZAID (Z*1000 + A), natural abundance in percent.
data_abund = [12024 78.99;
12025 10;
12026 11.01;
14028 92.223;
14029 4.685;
14030 3.092]
const global nat_zaids = Vector{Int}(data_abund[:,1])
const global nat_abunds = Vector{data_type}(data_abund[:,2])
const global nat_zs = nat_zaids .÷ 1000  # proton numbers (Z) of the natural elements
# Generate dummy data
# Rows of x_train are experiments; columns are used as (E, A, Z) by
# calculateMultiDiffCrossSections.
x_train = rand(data_type, 2,3)
X_train = combex(x_train, r)  # (nx+1) x (2*rlen) network inputs: r plus the x row
theta_train = [
[1.0, 10.0, 20.0, 45.0, 60],
[5.0, 15.0, 25.0, 60]
]
# NOTE(review): theta_train holds Float64 vectors while data_type is Float32 —
# confirm the mixed precision is intended.
XSdiff_train = rand(data_type, 9)  # 9 data points = 5 + 4 angles above
# Load a model
nlayers = 2
nnodes = 16
model = build_model(4, 3, nlayers, nnodes)  # 4 inputs = r plus the 3 x-columns
ps, st = f32(Lux.setup(Random.default_rng(), model))
p = xdev(ComponentArray(recursive_convert(data_type, ps)))  # parameters on device
const _st = st
# NOTE(review): x/X/y are moved to the Reactant device but r, dr, theta_train,
# Lrange, model and st stay on the host — mixing host and device data inside a
# differentiated function is a common source of the "constant memory" /
# "UnsafeBufferPointer" Enzyme errors; verify everything traced lives on xdev.
args = (xdev(x_train), xdev(X_train), xdev(XSdiff_train), r, dr, theta_train, Lrange, model, _st)
# Test loss function evaluation
losstest = lossDiff(p, args)
# NOTE(review): the loss is scalar, so Enzyme.gradient(Reverse, ...) is the
# natural call; jacobian of a scalar-valued function is overkill here.
Enzyme.jacobian(Reverse, p -> lossDiff(p, args),p)
losstest
The current error:
No augmented forward pass found for ejlstr$UnsafeBufferPointer$/Users/daningburg/.julia/artifacts/2c69783b22c1072452c6b137cf11806ce31f9f67/lib/libReactantExtra.dylib
at context: %39 = call i64 @"ejlstr$UnsafeBufferPointer$/Users/daningburg/.julia/artifacts/2c69783b22c1072452c6b137cf11806ce31f9f67/lib/libReactantExtra.dylib"(i64 %38) #271, !dbg !325
Stacktrace:
[1] wait
@ ~/.julia/packages/Reactant/QTNFa/src/Types.jl:195
[2] setindex!
@ ~/.julia/packages/Reactant/QTNFa/src/ConcreteRArray.jl:359
[3] macro expansion
@ ./cartesian.jl:62 [inlined]
[4] _unsafe_getindex!
@ ./multidimensional.jl:938 [inlined]
[5] _unsafe_getindex
@ ./multidimensional.jl:929
[6] Array
@ ./boot.jl:579 [inlined]
[7] Array
@ ./boot.jl:591 [inlined]
[8] zeros
@ ./array.jl:589 [inlined]
[9] zeros
@ ./array.jl:585 [inlined]
[10] calculateMultiDiffCrossSections
@ ~/Documents/Code/nuclear-diffprog/MWEs/reactant_and_lux.jl:98
[11] augmented_julia_lossDiff_40511_inner_217wrap
@ ~/Documents/Code/nuclear-diffprog/MWEs/reactant_and_lux.jl:0
[12] macro expansion
@ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5691 [inlined]
[13] enzyme_call
@ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5225 [inlined]
[14] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5164 [inlined]
[15] macro expansion
@ ~/.julia/packages/Enzyme/iosr4/src/rules/jitrules.jl:447 [inlined]
[16] runtime_generic_augfwd(::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::typeof(lossDiff), ::Nothing, ::ComponentVector{…}, ::ComponentVector{…}, ::Tuple{…}, ::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/iosr4/src/rules/jitrules.jl:574
[17] #57
@ ~/Documents/Code/nuclear-diffprog/MWEs/reactant_and_lux.jl:170 [inlined]
[18] augmented_julia__57_29225_inner_1wrap
@ ~/Documents/Code/nuclear-diffprog/MWEs/reactant_and_lux.jl:0
[19] macro expansion
@ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5691 [inlined]
[20] enzyme_call
@ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5225 [inlined]
[21] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5164 [inlined]
[22] autodiff
@ ~/.julia/packages/Enzyme/iosr4/src/Enzyme.jl:408 [inlined]
[23] autodiff
@ ~/.julia/packages/Enzyme/iosr4/src/Enzyme.jl:538 [inlined]
[24] macro expansion
@ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:324 [inlined]
[25] gradient
@ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:262 [inlined]
[26] macro expansion
@ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:861 [inlined]
[27] jacobian_helper
@ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:785 [inlined]
[28] macro expansion
@ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:1239 [inlined]
[29] jacobian(mode::ReverseMode{false, false, false, FFIABI, false, false}, f::var"#57#58", xs::ComponentVector{Float32, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{…}}, Tuple{Axis{…}}})
@ Enzyme ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:1213
[30] top-level scope
@ ~/Documents/Code/nuclear-diffprog/MWEs/reactant_and_lux.jl:170
[31] include(fname::String)
@ Main ./sysimg.jl:38
[32] top-level scope
@ REPL[6]:1
in expression starting at /Users/daningburg/Documents/Code/nuclear-diffprog/MWEs/reactant_and_lux.jl:170