Flux Restructure for ComponentArrays.jl unreasonably slow

In some situations you have to restructure a lot: for instance, if you use Flux and want to run your batches as separate solves rather than as one big solve, e.g. with an EnsembleProblem. You need something like a ComponentArray to pass the parameters through the solver and to let the adjoint methods differentiate the solve. But restructuring from a ComponentArray is (unreasonably?) slow in Flux. Switching to Lux eliminates the problem, but it seems like something that could be implemented better in Flux or ComponentArrays.

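This hurts in particular when you always use the same parameters, because an identical model gets rebuilt on every call. When differentiating an ODE/SDE solve, you cannot restructure once outside the drift/diffusion functions: the parameters have to enter through the solver's `p` argument so that the adjoint can differentiate with respect to them.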
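To make the setup concrete, here is a minimal sketch of the pattern, assuming OrdinaryDiffEq.jl as the solver (the post does not name one) and a hypothetical random batch. The network is rebuilt with `re(p)` inside the right-hand side, and each batch column is solved as its own trajectory through an `EnsembleProblem`. Gradients through the solve are not shown.

```julia
# Assumptions: OrdinaryDiffEq.jl as the solver and a hypothetical random batch;
# the adjoint/gradient part of the workflow is omitted.
using Flux, ComponentArrays, OrdinaryDiffEq

n = 4
model = Flux.Dense(n => n, tanh)
p0, re = Flux.destructure(model)
p = ComponentArray(p0)

# The solver hands the parameters to the right-hand side as `p`, so the
# network has to be rebuilt from `p` on every call (this is the restructure).
rhs(u, p, t) = re(p)(u)

batch = rand(Float32, n, 8)   # hypothetical batch: one column per sample
prob = ODEProblem(rhs, batch[:, 1], (0.0f0, 1.0f0), p)

# Solve each column as its own trajectory instead of one big solve.
prob_func(prob, i, repeat) = remake(prob; u0 = batch[:, i])
eprob = EnsembleProblem(prob; prob_func = prob_func)
sol = solve(eprob, Tsit5(), EnsembleSerial(); trajectories = size(batch, 2))
```

Every RHS evaluation of every trajectory calls `re(p)`, which is why the restructure cost dominates for small networks.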

Example (not using an ODE solve but just a loop for simplicity):

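```julia
using Flux, Profile, PProf, Random, ComponentArrays

layer_size = 256
layer = Flux.Dense(layer_size => layer_size)
params_, re = Flux.destructure(layer)   # flat parameter vector + Restructure
ps = ComponentArray(params_)             # parameters as a solver would pass them

# Rebuild the layer from the flat parameters on every step, as a drift
# function inside an ODE/SDE solve would have to.
function run_steps(re, ps, steps, input)
    x = input
    for _ in 1:steps
        x = re(ps)(x)
    end
    return x
end

run_steps(re, ps, 1, rand(Float32, layer_size))   # compile first so it isn't profiled
Profile.clear()
@profile run_steps(re, ps, 100000, rand(Float32, layer_size))
pprof(; web=true)
```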

Here, about 10% of the time is spent in the `sgemv` matrix-vector multiply, another 10% in the rest of the `Dense` call, and about 75% in `Restructure`. Naturally this gets worse for smaller networks, where the multiply is cheaper relative to the restructure. As far as I can read the flame graph, the restructure spends a lot of its time in the GC.
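To check that split without the profiler, the pieces can be timed separately. This sketch assumes BenchmarkTools.jl (not used in the original post) and times the restructure alone, the forward pass alone, and both together; no numbers are given here since they depend on the machine and package versions.

```julia
# Assumption: BenchmarkTools.jl for timing instead of the profiler.
using Flux, ComponentArrays, BenchmarkTools

layer_size = 256
layer = Flux.Dense(layer_size => layer_size)
params_, re = Flux.destructure(layer)
ps = ComponentArray(params_)
x = rand(Float32, layer_size)

rebuilt = re(ps)        # restructure once, outside the timings below

@btime $re($ps)         # cost of rebuilding the layer alone
@btime $rebuilt($x)     # cost of the forward pass alone
@btime $re($ps)($x)     # rebuild + forward pass, as in the loop above
```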

Is there a way to mitigate this, in particular when the parameters stay the same? I think this would make some of the example code a lot faster, too.
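One possible workaround for a fixed architecture is to skip the generic `Restructure` and build the layer from views into the flat vector, so no parameter arrays are copied on each call. This is a hypothetical sketch: `dense_from_flat` is not a Flux function, it assumes `destructure` lays out a `Dense` layer as the column-major weight followed by the bias, and it has not been checked against the adjoint methods. For a ComponentArray, `getdata` from ComponentArrays.jl gives the underlying plain vector. Whether avoiding the copies removes most of the GC overhead would need measuring.

```julia
using Flux, ComponentArrays

# Hypothetical helper (not part of Flux): rebuild a Dense layer from views
# into a flat parameter vector instead of copying. Assumes the layout produced
# by Flux.destructure for a Dense layer: vec(weight) followed by bias.
function dense_from_flat(p::AbstractVector, n_in::Int, n_out::Int, σ = identity)
    nw = n_out * n_in
    W = reshape(view(p, 1:nw), n_out, n_in)   # no copy: a reshaped view
    b = view(p, nw+1:nw+n_out)                # no copy: a view of the bias
    return Flux.Dense(W, b, σ)
end

layer_size = 256
layer = Flux.Dense(layer_size => layer_size)
params_, re = Flux.destructure(layer)
ps = ComponentArray(params_)
x = rand(Float32, layer_size)

# Use the plain vector wrapped by the ComponentArray.
y_view = dense_from_flat(getdata(ps), layer_size, layer_size)(x)
y_re = re(ps)(x)
```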

This is one of the reasons all of the documentation recently changed from Flux to Lux. DiffEqFlux v2 came out just last week, and one of its big changes was being Lux-native through and through for better performance and reliability. I recommend checking out the new stable docs, which showcase using Lux:

https://docs.sciml.ai/DiffEqFlux/stable/examples/neural_ode/

Lux doesn't require destructure/restructure, since the parameters are passed to the model explicitly, so that issue disappears entirely.
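For comparison, a sketch of the same loop in Lux, assuming the `Lux.setup` / `model(x, ps, st)` interface. The parameters are an explicit argument on every call, so nothing is rebuilt per step; `run_steps` is just an illustrative name.

```julia
using Lux, Random, ComponentArrays

rng = Random.default_rng()
model = Lux.Dense(256 => 256)
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)          # flat parameter vector with named blocks

# Same loop as the Flux example, but the parameters are passed explicitly:
# the layer is never reconstructed.
function run_steps(model, ps, st, steps, x)
    for _ in 1:steps
        x, st = model(x, ps, st)
    end
    return x
end

y = run_steps(model, ps, st, 1000, rand(Float32, 256))
```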


Thanks for answering! Sure, I'm using Lux now too. Still, a lot of people use Flux, so I'll just open an issue in Flux and see what happens over there: "Restructure for ComponentArrays.jl unreasonably slow", FluxML/Flux.jl issue #2238 on GitHub.