Training neural nets with multiple Chains using Lux.jl and CUDA.jl

Hi, I faced a similar problem. Somebody else also created an issue on GitHub.

Here is how I fixed it:

  1. Define all the initial parameters as a single ComponentArray.
using ComponentArrays, CUDA, Lux, Random

# [...]

@parameters x y
@variables f1(..) f2(..) f3(..) f4(..)

# [...]

chain = [chain1, chain2, chain3, chain4]
names = (:f1, :f2, :f3, :f4)  # must match the dependent variables declared above

# collect the initial parameters of every chain into a single flat ComponentArray
init_params = Lux.initialparameters.(Ref(Random.default_rng()), chain)
init_params = NamedTuple{names}(Tuple(init_params))
init_params = ComponentArray(init_params)
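
For what it's worth, the per-chain parameters stay accessible by name inside the flat array. A minimal standalone sketch, with two made-up toy chains (the layer sizes are arbitrary):

using ComponentArrays, Lux, Random

toy_chains = (Chain(Dense(2 => 16, tanh), Dense(16 => 1)),
              Chain(Dense(2 => 16, tanh), Dense(16 => 1)))
ps = Lux.initialparameters.(Ref(Random.default_rng()), toy_chains)
ps = ComponentArray(NamedTuple{(:f1, :f2)}(ps))
ps.f1.layer_1.weight  # weight matrix of the first layer of the first chain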
Edit: Step 2 is no longer necessary as of ComponentArrays@v0.13.3.
  2. Redefine the conversion that happens at NeuralPDE.jl/src/discretize.jl#L480 (necessary for me as of NeuralPDE@v5.3.0 and ComponentArrays@v0.13.2).
using ComponentArrays: GPUComponentArray

# The method must be qualified with the module name to extend the constructor.
function ComponentArrays.ComponentArray(nt::NamedTuple{(:depvar,),
                                                       <:Tuple{GPUComponentArray{T}}}) where {T}
    depvar = cpu(nt.depvar)       # move the GPU parameters to the CPU
    A = ComponentArray(; depvar)  # build the ComponentArray there
    A = T.(gpu(A))                # move back, restoring the element type T
    return A
end
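
The detour through the CPU is the whole trick: the ComponentArray is assembled on the CPU, where the constructor works, and then moved back with gpu. The final T.(...) broadcast restores the original element type, since the GPU transfer can demote Float64 arrays to Float32.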
  3. Move the initial parameters to the GPU.
init_params = Float64.(gpu(init_params))
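
From there the parameters can be handed to the discretization. A sketch, assuming the usual NeuralPDE workflow; pde_system and the GridTraining strategy below are placeholders from your own problem setup:

using NeuralPDE

# `chain` and `init_params` come from the steps above
discretization = PhysicsInformedNN(chain, GridTraining(0.1); init_params = init_params)
prob = discretize(pde_system, discretization)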

Let me know if this works for you. I will try to submit a PR to NeuralPDE with the fix.
