NeuralPDE.jl to find critical points of a geometric flow equation on the four-torus. This ends up being a fairly large system, with 18 variables in total. I think the size of the system makes the CUDA compiler generate kernels whose arguments exceed the maximum kernel parameter size, i.e. the problem described in CUDA.jl issue #32, "Guard against exceeding maximum kernel parameter size" (JuliaGPU/CUDA.jl on GitHub). Is there a way to avoid this, maybe a parameter for the CUDA compilation? Or is this a bug in the way the kernel gets assembled in the first place from the
There’s a suggestion in that issue: Pass the large argument as a one-element array.
Thanks for the answer @maleadt! I did see the suggestion, but I don't know how I can control the generated CUDA code. Is there a way I can tamper with it after it has been generated automatically? Or are you suggesting I find whatever is generating these huge tuple arguments and patch that library?
Yes. You can't just patch up the generated code, because changing the argument from a value type to a reference type changes the calling convention (i.e. it requires you to pass a pointer to memory, e.g. by using an Array).
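A minimal CUDA.jl sketch of the one-element-array workaround, assuming a working GPU. `BigParams` and `read_by_ref!` are hypothetical names chosen for illustration. Passed directly, a `BigParams` value (4800 bytes) would be copied into the kernel parameter space and exceed the classic 4 KiB limit (recent CUDA versions allow more on newer GPUs). Wrapped in a one-element `CuArray`, the kernel receives only a small device-array handle and loads the struct from global memory. That is the calling-convention change described in the reply: the kernel now takes an array, so the caller has to change too.

```julia
using CUDA

# Hypothetical isbits parameter struct: 600 Float64s = 4800 bytes.
struct BigParams
    coeffs::NTuple{600, Float64}
end

# Takes the parameters by reference: `params_ref` is a one-element device array,
# so only a small array handle is placed in the kernel parameter space.
function read_by_ref!(out, params_ref)
    i = threadIdx().x
    if i <= length(out)
        @inbounds out[i] = params_ref[1].coeffs[i]
    end
    return nothing
end

params = BigParams(ntuple(i -> Float64(i), 600))

# Passing `params` itself to a kernel would put all 4800 bytes in the parameter space.
# Instead, upload it once as a one-element array and pass that.
params_gpu = CuArray([params])
out = CUDA.zeros(Float64, 32)

@cuda threads=32 read_by_ref!(out, params_gpu)
println(Array(out)[1:3])  # [1.0, 2.0, 3.0]
```

Loading the whole struct inside the kernel is not free; if only a few values are needed, keeping them in a flat `CuArray` may be the simpler layout.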
That makes sense. I can see that Lux.jl is creating these named tuples, but it could also be that ComponentArrays.jl introduces the problem. The offending line of code is:
[Lux.setup(Random.default_rng(), c) |> ComponentArray |> gpu .|> Float64 for c in chains]
I'll try to pinpoint it. I'm a bit surprised that I'm the first one to hit this problem, as this seems to be a fairly common way of using Lux.jl parameters on the GPU.
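To narrow down which object is too large, one option is to measure each argument after `CUDA.cudaconvert`, the conversion CUDA.jl applies to kernel arguments. The sketch assumes a working GPU; `kernel_arg_bytes` is a hypothetical helper, and the nested named tuple only mimics the shape of Lux.jl parameters rather than calling Lux. A flat GPU array converts to a single small device-array handle, while a nested named tuple carries one handle per leaf array, so its size grows with the number of layers. Anything that is not an array (scalars, tuples) is copied in full.

```julia
using CUDA

# Hypothetical helper: bytes an argument occupies in the kernel parameter space
# after CUDA.jl's device conversion (CuArray -> CuDeviceArray, applied recursively).
kernel_arg_bytes(x) = sizeof(CUDA.cudaconvert(x))

# A flat parameter vector becomes a single device-array handle.
flat = CUDA.rand(Float64, 10_000)

# A nested named tuple shaped like Lux.jl parameters (illustrative only):
# every leaf array becomes its own handle inside the converted value.
layer() = (weight = CUDA.rand(Float64, 16, 16), bias = CUDA.rand(Float64, 16))
nested = (layer_1 = layer(), layer_2 = layer(), layer_3 = layer())

println(kernel_arg_bytes(flat))
println(kernel_arg_bytes(nested))
```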
This turned out to be caused by a more complicated PDE, which made NeuralPDE.jl compile a fairly complicated loss function. I simplified the constraints by reformulating them as several smaller ones, which solves the problem for me. I'm not sure yet whether I'm paying a performance penalty for this, though.
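A rough, CPU-only illustration of why splitting may help, under the assumption that the generated loss is a closure passed by value into a GPU kernel (e.g. through broadcasting), so everything it captures counts against the parameter limit. A closure stores its captured variables as fields, so splitting one closure that captures all the constraint data into several that each capture a subset keeps each kernel argument smaller. `combined_loss`, `split_losses` and the coefficient tuples are hypothetical and do not reflect how NeuralPDE.jl actually builds its loss functions.

```julia
# Hypothetical constraint data: three groups of 200 coefficients (1600 bytes each).
c1 = ntuple(i -> Float64(i), 200)
c2 = ntuple(i -> 2.0 * i, 200)
c3 = ntuple(i -> 3.0 * i, 200)

# One closure capturing every coefficient: all 4800 bytes travel with it.
combined_loss = let c1 = c1, c2 = c2, c3 = c3
    x -> (x - c1[1])^2 + (x - c2[1])^2 + (x - c3[1])^2
end

# The same loss split into three smaller closures, each capturing one group.
split_losses = (
    let c = c1
        x -> (x - c[1])^2
    end,
    let c = c2
        x -> (x - c[1])^2
    end,
    let c = c3
        x -> (x - c[1])^2
    end,
)

# Summing the pieces gives the same value; on the GPU each piece would be its own launch.
x = 0.5
total_split = sum(f -> f(x), split_losses)

println(sizeof(combined_loss))           # 4800
println(maximum(sizeof, split_losses))   # 1600
println(combined_loss(x) ≈ total_split)  # true
```

The trade-off raised in the post shows up here: several smaller launches replace one large one, which may cost some performance.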
Worth opening an issue to track.
@ChrisRackauckas I think this is essentially a CUDA.jl problem rather than a NeuralPDE.jl one. It is tracked in CUDA.jl issue #32, "Guard against exceeding maximum kernel parameter size" (JuliaGPU/CUDA.jl on GitHub).