In some situations you have to restructure a lot if you use Flux and want to run your batches as separate solves rather than one big solve, for instance with an `EnsembleProblem`. You then need something like a `ComponentArray` to pass the parameters through the solver, so that the adjoint methods can differentiate the solve. But restructuring from a `ComponentArray` is unreasonably(?) slow in Flux. Switching to Lux eliminates these problems, but this seems like something that could be handled better in Flux or ComponentArrays.
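For contrast, here is a minimal Base-Julia sketch of the explicit-parameter style that Lux uses. It does not use Lux itself, and `dense_apply` and `iterate_layer` are hypothetical helpers. The layer is a plain function of the input and a parameter container, so a new parameter vector never requires rebuilding a model object. A `ComponentArray` built from the same `NamedTuple` should also work here, since it supports `ps.weight` property access.

```julia
# Hypothetical Lux-style dense layer (identity activation, like Flux.Dense's default):
# parameters are passed in explicitly, nothing is rebuilt per call.
# `ps` can be any container with `weight` and `bias` fields (NamedTuple, ComponentArray, ...).
dense_apply(x::AbstractVector, ps) = ps.weight * x .+ ps.bias

# Apply the same layer `steps` times, reusing the same parameters.
function iterate_layer(x, ps, steps)
    for _ in 1:steps
        x = dense_apply(x, ps)
    end
    return x
end

n = 256
# Scale the random weights down so repeated application does not blow up.
ps = (weight = randn(Float32, n, n) ./ n, bias = zeros(Float32, n))
y = iterate_layer(rand(Float32, n), ps, 100)
```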
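This is especially wasteful when the parameters are the same on every call, because the model is rebuilt each time anyway. When differentiating an ODE/SDE solve, restructuring once outside of the drift/diffusion functions is not allowed, so the rebuild happens on every evaluation of the drift/diffusion.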
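A dependency-free sketch of what that constraint looks like: the drift receives the flat parameter vector `p` and has to rebuild the layer from it on every call. A fixed-step explicit Euler loop stands in for the ODE solver. `restructure`, `drift`, and `euler` are hypothetical helpers for illustration, not Flux or DifferentialEquations.jl functions, and the flat layout `[vec(W); b]` is an assumption.

```julia
# Hypothetical helper: rebuild (W, b) for an n => n dense layer from a flat
# vector laid out as [vec(W); b]. Slicing with p[range] copies on every call.
function restructure(p::AbstractVector, n::Int)
    W = reshape(p[1:n*n], n, n)
    b = p[n*n+1:n*n+n]
    return W, b
end

# Drift f(u, p, t): the parameters arrive flat through `p`, so the layer is
# rebuilt here, once per evaluation, instead of once outside the solve.
function drift(u, p, t)
    W, b = restructure(p, length(u))
    return W * u .+ b
end

# Stand-in for the ODE solver: `nsteps` explicit Euler steps of size `dt`.
function euler(f, u0, p, t0, dt, nsteps)
    u = copy(u0)
    for k in 0:nsteps-1
        u = u .+ dt .* f(u, p, t0 + k * dt)
    end
    return u
end
```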
Example (not using an ODE solve but just a loop for simplicity):
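```julia
using Flux, Profile, PProf, Random, ComponentArrays

layer_size = 256
layer = Flux.Dense(layer_size => layer_size)
params_, re = Flux.destructure(layer)   # flat parameter vector + Restructure object
params = ComponentArray(params_)

# Rebuild the layer from `params` on every step, as a drift function would have to.
# Named `run_steps` rather than `eval` to avoid clashing with the module's own `eval`.
function run_steps(re, params, steps, input)
    x = input
    for _ in 1:steps
        x = re(params)(x)
    end
    return x
end

input = rand(Float32, layer_size)
run_steps(re, params, 1, input)   # warm-up call so compilation is not profiled
Profile.clear()
@profile run_steps(re, params, 100_000, input)
pprof(; web=true)
```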
Here, about 10% of the time is spent in the `sgemv` matrix-vector multiplication, another 10% in the rest of the `Dense` call, and about 75% in the `Restructure`. This gets worse for smaller networks, presumably because the rebuild overhead then weighs more relative to the multiplication. As far as I can read the flame graph, the `Restructure` spends a lot of its time in garbage collection.
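One plausible source of the GC pressure is that each rebuild copies the weights out of the flat vector. I have not verified this against Flux's internals, so treat it as an assumption. The Base-Julia sketch below contrasts a copying rebuild with one built from views of the same vector and measures the bytes allocated per call. `dense_copy`, `dense_view`, and `bytes_per_call` are hypothetical helpers, not Flux functions.

```julia
# Apply an n => n dense layer whose parameters live in a flat vector `p`
# laid out as [vec(W); b] (layout is an assumption for this sketch).

# Copying rebuild: p[range] allocates fresh arrays on every call.
function dense_copy(p::AbstractVector, x::AbstractVector)
    n = length(x)
    W = reshape(p[1:n*n], n, n)
    b = p[n*n+1:n*n+n]
    return W * x .+ b
end

# Non-copying rebuild: views alias `p`, so mainly the outputs are allocated.
function dense_view(p::AbstractVector, x::AbstractVector)
    n = length(x)
    W = reshape(view(p, 1:n*n), n, n)
    b = view(p, n*n+1:n*n+n)
    return W * x .+ b
end

# Bytes allocated by one call, measured after a warm-up call to exclude compilation.
function bytes_per_call(f, p, x)
    f(p, x)
    return @allocated f(p, x)
end

n = 256
p = rand(Float32, n * n + n)
x = rand(Float32, n)
println("copy: ", bytes_per_call(dense_copy, p, x), " bytes per call")
println("view: ", bytes_per_call(dense_view, p, x), " bytes per call")
```

If the copying version allocates at least the size of the weight matrix on every call while the view version allocates only small output vectors, that would be consistent with the GC time in the flame graph. Whether a view-based rebuild cooperates with the AD backend used for the adjoint is not addressed by this sketch.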
Could there be a way to mitigate this specific problem, in particular when the same parameters are reused on every call? I think this would make some of the example code a lot faster too.