I made it work on CPU:
begin
    using Lux
    using Enzyme
    using ComponentArrays
    using Random
    using Reactant
    using Statistics
    Reactant.set_default_backend("cpu")
    Random.seed!(1234)
end
begin # Functions
    function build_model(n_in, n_out, n_layers, n_nodes;
                         act_fun=leakyrelu, last_fun=relu)
        # Input layer
        first_layer = Lux.Dense(n_in, n_nodes, act_fun)
        # Hidden block
        hidden = (Lux.Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers)
        # Output layer
        last_layer = Lux.Dense(n_nodes => n_out, last_fun)
        return Chain(first_layer, hidden..., last_layer)
    end

    # Run the model and return only its output (drop the returned state)
    function eval_model(m, params, st, x)
        M = first(m(x, params, st))
        return M
    end

    # Convert every array leaf of a (possibly nested) NamedTuple to element type T
    function recursive_convert(T, x)
        if x isa AbstractArray
            return convert.(T, x) # elementwise convert
        elseif x isa NamedTuple
            return NamedTuple{keys(x)}(recursive_convert(T, v) for v in values(x))
        else
            return x
        end
    end

    # MSE between the measured cross sections and the model predictions
    function lossDiff(p, args)
        y = args[2]
        dsigs = calculateMultiDiffCrossSections(p, args)
        return mean((y - dsigs) .^ 2)
    end

    function calculateDifferentialCrossSection(U, theta)
        return sum(U) * cos.(theta)
    end

    function calculateMultiDiffCrossSections(p, args)
        X = args[1]
        thetas = args[3]
        M = args[4]
        st = args[5]
        nlen = 2   # number of experiments (matches the dummy data below)
        rlen = 100 # model inputs per experiment
        datalen = sum(length.(@view thetas[1:nlen]))
        dσ = [zero(eltype(X)) for _ in 1:datalen]
        j = 1
        for i in 1:nlen
            exp_len = length(thetas[i])
            j_next = j + exp_len
            U = eval_model(M, p, st, @view X[:, (i-1)*rlen+1:i*rlen])
            sig = sum(U) * cos.(thetas[i])
            dσ[j:j_next-1] .= sig
            j = j_next
        end
        return dσ
    end
end
const xdev = reactant_device()
data_type = Float32
# Generate dummy data
XSdiff_train = rand(data_type, 9)
theta_train = [
    [1.0, 10.0, 20.0, 45.0, 60.0],
    [5.0, 15.0, 25.0, 60.0]
]
X_train = rand(data_type, 4, 200)
# Load a model
nlayers = 2
nnodes = 16
model = build_model(4, 3, nlayers, nnodes)
ps, st = f32(Lux.setup(Random.default_rng(), model))
p = xdev(ComponentArray(recursive_convert(data_type, ps)))
const _st = st
args = (xdev(X_train), xdev(XSdiff_train), xdev(theta_train), model, xdev(_st))
# Test loss function evaluation
# losstest = lossDiff(p, args)
# losstest
display(@allowscalar @jit lossDiff(p, args))
dl_dp(p, args) = Enzyme.gradient(Reverse, lossDiff, p, Const(args))
@allowscalar dldp_comp = @compile dl_dp(p, args)
@allowscalar res = dldp_comp(p, args)
display(res)
It should also work on GPU; however, there is a lot of scalar indexing going on, so you should vectorize it a bit more first or write KernelAbstractions.jl kernels.
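As a starting point, here is a rough, untested sketch of a more vectorized calculateMultiDiffCrossSections (the name calculateMultiDiffCrossSectionsVec is made up here, and nlen/rlen stay hard-coded as above); it builds one vector per experiment and concatenates them instead of writing into a preallocated array element by element:

# Sketch only: not verified under @compile / Reactant tracing
function calculateMultiDiffCrossSectionsVec(p, args)
    X = args[1]
    thetas = args[3]
    M = args[4]
    st = args[5]
    nlen = 2   # experiments, as above
    rlen = 100 # model inputs per experiment, as above
    sigs = map(1:nlen) do i
        U = eval_model(M, p, st, @view X[:, (i-1)*rlen+1:i*rlen])
        sum(U) .* cos.(thetas[i]) # one cross-section vector per experiment
    end
    return reduce(vcat, sigs) # single concatenation, no per-element writes
end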
PS: if someone from the Reactant.jl team (@wsmoses for instance) comes by: there is a method ambiguity with fill!(::Vector{Reactant.TracedRNumber{Float32}}, ::Reactant.TracedRNumber{Float32}) which appears when calling
zeros(eltype(X), datalen), which forced me to write [zero(eltype(X)) for _ in 1:datalen] instead. Maybe that is intended, but if I instead allocate a plain Float64/Float32 array and send it to XLA, it is filled with ConcreteNumbers and cannot then be filled with TracedNumbers, leading to
LoadError: MethodError: no method matching _copyto!(::SubArray{Float32, 1, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{UnitRange{Int64}}, false}, ::Base.Broadcast.Broadcasted{Reactant.TracedRArrayOverrides.AbstractReactantArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Reactant.TracedRArray{Float64, 1}}})
So, is there a preferred way to make a temporary array that will be filled with traced data, for now?
The method I used is fine, but it is a plain Julia array of TracedNumbers, which I think is not ideal.
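To make the question concrete, these are roughly the variants I mean (X and datalen as in calculateMultiDiffCrossSections above; the device-buffer line is schematic):

# Inside the traced function:
dσ = zeros(eltype(X), datalen)              # hits the fill! method ambiguity
dσ = [zero(eltype(X)) for _ in 1:datalen]   # works, but is a plain Vector of TracedRNumbers

# Allocating a concrete array up front and sending it to the device instead:
dσ_buf = xdev(zeros(data_type, datalen))    # writing traced values into it raises the _copyto! MethodError above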