[Optimization] How would you speed up this RNN Flux / Zygote code?

Hey all,
I have worked with TensorFlow 1.x/2.x, PyTorch and JAX for years, and I have now found Julia, where I see a lot more opportunity.
I would like to ask everyone, so that we can learn from each other: how would you optimise code like this? What is the maximum speed you can reach with it?

using Flux
using Distributions

batchsize, timesteps = 2000, 200
data = (randn(Float32, timesteps, batchsize, 2),
		randn(Float32, timesteps, batchsize, 1))

init_norm() = (shape) -> randn(Float32, shape)
mse_fn(Y_pred, Y_true) = mean((Y_pred .- Y_true).^2)

function model(prms, data)
  X, Y = data
  # One (batchsize, 2) matrix per time step.
  X_t = [X[t, :, :] for t in 1:timesteps]::Array{Array{Float32,2},1}
  initial_state = ones(Float32, 2000, 1)
  ylist, state = Array{Array{Float32,2},1}(), [initial_state]
  function predict_cell(state, inputs)
    predict_fn_const(prms, inputs, state)
  end
  cell = Flux.Recur(predict_cell, state)
  ylist = cell.(X_t)::Array{Array{Float32,2},1}
  # concat result back.
  predict = cat(dims=2, ylist...)'
  predict = reshape(predict, size(predict)..., 1)
  loss = mse_fn(predict, Y)
end

function init(data)
  prms = [init_norm()((2000, 1))]
  loss = model(prms, data)
  opt = ADAM(0.06, (0.3, 0.7))
  loss, prms, opt
end

function gradient_pro(f, args...)
  @time y, back = Flux.pullback(f, args...)  # forwardprop time
  return y, back(one(y))
end

function step(data, prms, opt)
  # Note: grads are computed but not applied to prms; this only times the passes.
  loss, grads = gradient_pro((p) -> model(p, data), prms)
  return loss, prms, opt
end

function predict_fn_const(prms, inputs, state)
  # SOME random multiplications...
  i1, i2 = inputs[:, 1:1], inputs[:, 2:2]
  v1 = prms[1]
  o14 = i1 + i2
  o104 = o14 .* v1
  r206 = state[1]
  out, next_states = o104 + r206, [o104]
  return next_states, out
end

function test()
  loss, params, opt = init(data)
  for i in 1:10
    @time loss, params, opt = step(data, params, opt)   # forward & backwardprop time
  end
end

@time test()

My run results:

0.838855 seconds (1.48 M allocations: 89.775 MiB)
  1.736759 seconds (2.47 M allocations: 1.347 GiB, 2.03% gc time)
  0.012596 seconds (36.88 k allocations: 16.770 MiB)
  0.531656 seconds (53.78 k allocations: 1.228 GiB, 25.55% gc time)
  0.011113 seconds (36.88 k allocations: 16.770 MiB)
  0.430338 seconds (53.78 k allocations: 1.228 GiB, 52.53% gc time)
  0.008225 seconds (36.88 k allocations: 16.770 MiB)
  0.353133 seconds (53.78 k allocations: 1.228 GiB, 2.43% gc time)
  0.011098 seconds (36.88 k allocations: 16.770 MiB)
  0.282289 seconds (53.78 k allocations: 1.228 GiB, 1.17% gc time)
  0.007328 seconds (36.88 k allocations: 16.770 MiB)
  0.223698 seconds (53.78 k allocations: 1.228 GiB, 15.76% gc time)
  0.006849 seconds (36.88 k allocations: 16.770 MiB)
  0.406923 seconds (53.78 k allocations: 1.228 GiB)
  0.007938 seconds (36.88 k allocations: 16.770 MiB)
  0.410816 seconds (53.78 k allocations: 1.228 GiB, 37.27% gc time)
  0.006129 seconds (36.88 k allocations: 16.770 MiB)
  0.218867 seconds (53.78 k allocations: 1.228 GiB)
  0.007709 seconds (36.88 k allocations: 16.770 MiB)
  0.397147 seconds (53.78 k allocations: 1.228 GiB, 11.62% gc time)
  5.268901 seconds (2.95 M allocations: 12.403 GiB, 17.46% gc time)
  5.495894 seconds (3.26 M allocations: 12.430 GiB, 16.74% gc time)

I am fairly confident that a 10x speed-up should be possible.
Also, the 50x difference between backprop and forwardprop seems rather large. Can you help me?

P.S. I also made a reverse diff solution, which is 5x faster, but I didn’t see whether GPU is viable there, so I have to stick with Flux + Zygote.

Be aware that your code will hit https://github.com/FluxML/Flux.jl/issues/1209; you should avoid using broadcast in ylist = cell.(X_t).
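
A minimal sketch of the loop-based alternative, in case it helps: the idea is to step through the sequence in an explicit, ordered way instead of broadcasting a stateful cell. `ToyCell` and `run_sequence` are hypothetical stand-ins (not Flux types or functions) so that the snippet runs with Base Julia only. With Flux, the analogous change would presumably be `[cell(x) for x in X_t]` in place of `cell.(X_t)`.

```julia
# Hypothetical stand-in for a stateful recurrent cell (not a Flux type).
mutable struct ToyCell
    state::Vector{Float32}
end

# Calling the cell returns this step's output and updates the stored state,
# mirroring the post's cell: out = input + state, next state = input.
function (c::ToyCell)(x::Vector{Float32})
    out = c.state .+ x
    c.state = x
    return out
end

# Apply the cell to every time step with a comprehension, which visits
# the elements of xs in order t = 1, 2, ...
function run_sequence(cell::ToyCell, xs::Vector{Vector{Float32}})
    return [cell(x) for x in xs]
end

xs = [fill(Float32(t), 3) for t in 1:4]          # 4 time steps, 3 values each
ys = run_sequence(ToyCell(zeros(Float32, 3)), xs)
```

Each output depends on the state left behind by the previous step, so the evaluation order matters; the comprehension makes that order explicit.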

One of my colleagues and I tried to improve it.
I don’t know whether we changed the output, but the final code looks like this:

using Flux
using Distributions

batchsize, timesteps = Int(2000), Int(200)
data = (randn(Float32, batchsize, timesteps, 2),
				randn(Float32, batchsize, timesteps, 1))

init_norm() = (shape) -> randn(Float32, shape)
mse_fn(Y_pred, Y_true) = mean((Y_pred .- Y_true).^2)

function model(prms, data)
  X, Y = data
  state = ones(Float32, 2000)
  # Zygote.Buffer allows writing into an array inside differentiated code.
  ylist = Flux.Zygote.Buffer(Array{Array{Float32,1},1}(), timesteps)
  @inbounds for t in 1:timesteps::Int
    ylist[t], state = predict_fn_const(prms, view(X,:,t,:), state)
  end
  ypred = cat(dims=2, ylist...)
  loss = mse_fn(ypred, Y)
end

function init(data)
  prms = init_norm()(2000)
  loss = model(prms, data)
  opt = ADAM(0.06, (0.3, 0.7))
  loss, prms, opt
end

function gradient_pro(f, args...)
  @time y, back = Flux.pullback(f, args...)  # forwardprop time
  y, back(one(y))
end

function step(data, prms, opt)
  # @code_warntype(model(prms, data))
  loss, grads = gradient_pro((p) -> model(p, data), prms)
  loss, prms, opt
end

function predict_fn_const(prms, inputs, state)
  # SOME random multiplications...
  i1, i2 = view(inputs, :, 1), view(inputs, :, 2)
  o14 = i1 .+ i2
  o104 = o14 .* prms
  r206 = state
  # Returns (out, next_states).
  out, next_states = o104 .+ r206, o104
end

function test()
  loss, params, opt = init(data)
  for i in 1:10
    @time loss, params, opt = step(data, params, opt)   # forward & backwardprop time
  end
end

@time test()

The results:

  0.352301 seconds (841.94 k allocations: 48.380 MiB)
  0.480803 seconds (1.04 M allocations: 64.940 MiB)
  0.002801 seconds (25.09 k allocations: 9.946 MiB)
  0.004730 seconds (28.00 k allocations: 17.698 MiB)
  0.003337 seconds (25.09 k allocations: 9.946 MiB)
  0.006421 seconds (28.00 k allocations: 17.697 MiB)
  0.019834 seconds (25.09 k allocations: 9.946 MiB, 79.98% gc time)
  0.022040 seconds (28.01 k allocations: 17.698 MiB, 71.97% gc time)
  0.003194 seconds (25.09 k allocations: 9.946 MiB)
  0.005274 seconds (28.00 k allocations: 17.697 MiB)
  0.003315 seconds (25.09 k allocations: 9.946 MiB)
  0.006852 seconds (28.00 k allocations: 17.697 MiB)
  0.004084 seconds (25.09 k allocations: 9.946 MiB)
  0.007252 seconds (28.00 k allocations: 17.697 MiB)
  0.011192 seconds (25.09 k allocations: 9.946 MiB, 64.36% gc time)
  0.013348 seconds (28.01 k allocations: 17.698 MiB, 53.97% gc time)
  0.003219 seconds (25.09 k allocations: 9.946 MiB)
  0.006179 seconds (28.00 k allocations: 17.698 MiB)
  0.003185 seconds (25.09 k allocations: 9.946 MiB)
  0.005193 seconds (28.00 k allocations: 17.698 MiB)
  0.796614 seconds (1.81 M allocations: 255.765 MiB, 2.90% gc time)

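This is a 10x speed-up for sure: a full step went from roughly 0.2 to 0.5 s down to about 0.005 to 0.007 s.
The Buffer alone gave a 10x speed-up, and hcat-style concatenation (cat(dims=2, ...)) works on it pretty well. We also tried column-based indexing, which I think adds some more speed, something like 2x.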
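
To illustrate the column-based indexing point, here is a small sketch with made-up sizes (the array names are hypothetical; only Base Julia is used). Julia arrays are column-major, so with the batch dimension first, one feature of one time step is contiguous in memory, and `view` avoids the copy that `X[t, :, :]` makes.

```julia
batchsize, timesteps = 4, 3

# Time-first layout (as in the first version) and batch-first layout (as in the second).
X_time_first  = reshape(collect(Float32, 1:batchsize*timesteps*2), timesteps, batchsize, 2)
X_batch_first = permutedims(X_time_first, (2, 1, 3))

t = 2
slice_copy = X_time_first[t, :, :]          # slicing with ranges allocates a new array
slice_view = view(X_batch_first, :, t, :)   # a view shares memory with its parent

# Both select the same values for time step t.
same_values = slice_copy == slice_view

# Memory stride between consecutive batch elements of one feature at step t:
col_stride = strides(view(X_batch_first, :, t, 1))[1]   # batch-first: contiguous
row_stride = strides(view(X_time_first, t, :, 1))[1]    # time-first: strided
```

Here `col_stride` is a unit stride, while `row_stride` equals `timesteps`, so the time-first layout jumps through memory for every batch element.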

But it is still not 0.001 s, so I don’t know what else could be improved. Maybe a gradient tape?