Using Reactant with Lux and Enzyme to speed up training in a physics context

I made it work on CPU:

begin
    using Lux
    using Enzyme
    using ComponentArrays
    using Random
    using Reactant
    using Statistics
    Reactant.set_default_backend("cpu")
    Random.seed!(1234)
end

begin # Functions
    function build_model(n_in, n_out, n_layers, n_nodes;
                        act_fun=leakyrelu, last_fun=relu)
        # Input layer
        first_layer = Lux.Dense(n_in => n_nodes, act_fun)

        # Hidden block
        hidden = (Lux.Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers)

        # Output layer
        last_layer = Lux.Dense(n_nodes => n_out, last_fun)

        return Chain(first_layer, hidden..., last_layer)
    end

    function eval_model(m, params, st, x)
        M = first(m(x, params, st))
        return M
    end


    function recursive_convert(T, x)
        if x isa AbstractArray
            return convert.(T, x)  # elementwise convert
        elseif x isa NamedTuple
            return NamedTuple{keys(x)}(recursive_convert(T, v) for v in values(x))
        else
            return x
        end
    end

    function lossDiff(p, args)
        y = args[2]
        dsigs = calculateMultiDiffCrossSections(p, args)
        return mean((y - dsigs).^2)
    end

    function calculateDifferentialCrossSection(U, theta)
        return sum(U)*cos.(theta)
    end

    function calculateMultiDiffCrossSections(p, args)
        X = args[1]
        thetas = args[3]
        M = args[4]
        st = args[5]
        nlen = 2
        rlen = 100
        datalen = sum(length.(@view thetas[1:nlen]))
        dσ = [zero(eltype(X)) for _ in 1:datalen]  # workaround; see the fill! ambiguity note at the end of the post
        j = 1
        for i in 1:nlen 
            exp_len = length(thetas[i])
            j_next = j + exp_len
            U = eval_model(M, p, st, @view X[:,(i-1)*rlen+1 : i*rlen])
            sig = sum(U)*cos.(thetas[i])
            dσ[j:j_next-1] .= sig
            j = j_next
        end
        return dσ
    end
end


const xdev = reactant_device()
data_type = Float32

# Generate dummy data
XSdiff_train = rand(data_type, 9)
theta_train = [
    [1.0, 10.0, 20.0, 45.0, 60.0],
    [5.0, 15.0, 25.0, 60.0]
]
X_train = rand(data_type, 4,200)

# Load a model
nlayers = 2
nnodes = 16
model = build_model(4, 3, nlayers, nnodes)
ps, st = f32(Lux.setup(Random.default_rng(), model))
p = xdev(ComponentArray(recursive_convert(data_type, ps)))
const _st = st

args = (xdev(X_train), xdev(XSdiff_train), xdev(theta_train), model, xdev(_st))

# Test loss function evaluation
# losstest = lossDiff(p, args)
# losstest

display(@allowscalar @jit lossDiff(p, args))
dl_dp(p, args) = Enzyme.gradient(Reverse, lossDiff, p, Const(args))
@allowscalar dldp_comp = @compile dl_dp(p, args)
@allowscalar res = dldp_comp(p, args)
display(res)

It should also work on GPU; however, there is a lot of scalar indexing going on, so you should vectorize it a bit more first or write KernelAbstractions.jl kernels.
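
For what it's worth, here is a rough, untested sketch of how one might cut down the scalar indexing: build the result by concatenating per-experiment pieces instead of writing into a preallocated buffer. It assumes concatenation of traced arrays is supported, and calculateMultiDiffCrossSections_vec is just a name I picked here:

    # Sketch: build dσ by concatenating per-experiment results instead of
    # writing into a preallocated array with scalar indexing.
    function calculateMultiDiffCrossSections_vec(p, args)
        X = args[1]
        thetas = args[3]
        M = args[4]
        st = args[5]
        nlen = 2
        rlen = 100
        pieces = map(1:nlen) do i
            # Model output for the inputs belonging to experiment i
            U = eval_model(M, p, st, @view X[:, (i-1)*rlen+1 : i*rlen])
            # Differential cross section at every angle of experiment i
            sum(U) .* cos.(thetas[i])
        end
        return reduce(vcat, pieces)
    end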

PS: if someone from the Reactant.jl team (@wsmoses for instance) comes by, there is a method ambiguity with fill!(::Vector{Reactant.TracedRNumber{Float32}}, ::Reactant.TracedRNumber{Float32}) that appears when calling zeros(eltype(X), datalen), which forced me to write [zero(eltype(X)) for _ in 1:datalen] instead. Maybe that is intended, but then if I just allocate a plain f32/f64 array and send it to XLA, it is filled with ConcreteNumbers and cannot be filled with TracedNumbers, leading to

LoadError: MethodError: no method matching _copyto!(::SubArray{Float32, 1, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{UnitRange{Int64}}, false}, ::Base.Broadcast.Broadcasted{Reactant.TracedRArrayOverrides.AbstractReactantArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Reactant.TracedRArray{Float64, 1}}})

So is there a preferred way to make a temporary array that will be filled with traced data? The method I used works fine, but it is a plain Julia array of TracedNumbers, which is not ideal I think.
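
For reference, here is a minimal sketch of the three allocation strategies mentioned above, as they would sit inside the traced loss. Only variant (b) is what I actually use; (a) and (c) are the failing cases described, and temp_array_variants is just a throwaway name:

    # Sketch of the temp-array variants inside a function compiled with @jit/@compile.
    function temp_array_variants(X, datalen)
        # (a) hits the fill! method ambiguity:
        #     fill!(::Vector{TracedRNumber{Float32}}, ::TracedRNumber{Float32})
        # dσ = zeros(eltype(X), datalen)

        # (b) workaround used above: a plain Julia Vector of TracedRNumbers
        dσ = [zero(eltype(X)) for _ in 1:datalen]

        # (c) a concrete Float32 buffer moved to the device cannot be written to
        #     with traced values and fails with the _copyto! MethodError shown above
        # dσ = xdev(zeros(Float32, datalen))

        return dσ
    end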
