You should order the piped calls a bit differently here:
```julia
ps = Lux.setup(rng, chain)[1] |> gpud |> ComponentArray
```
The GPU device will cast your elements to Float32 anyway.
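For context, here is a minimal sketch of the full parameter setup with that ordering. The two-layer `Dense` network is a hypothetical stand-in for your `chain`. The sketch assumes AMDGPU.jl is installed and loaded so that `gpu_device()` can select it. Without a functional GPU backend, `gpu_device()` falls back to the CPU with a warning, so the snippet still runs. Note that Lux's default initializers already produce `Float32` parameters.

```julia
using Lux, ComponentArrays, Random
# using AMDGPU  # assumption: load the backend package so gpu_device() can select the AMD GPU

rng = Random.default_rng()

# Hypothetical small network standing in for the user's `chain`
chain = Chain(Dense(2 => 16, tanh), Dense(16 => 1))

# Picks a functional GPU backend if one is loaded, otherwise falls back to the CPU with a warning
gpud = gpu_device()

# Move the parameter NamedTuple to the device first, then wrap it in a ComponentArray
ps = Lux.setup(rng, chain)[1] |> gpud |> ComponentArray
```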
Regarding the final error, I am not certain that NeuralPDE supports AMD GPUs (with proper testing), but that is more of a question for @ChrisRackauckas.