Custom RNN cell: help needed to improve performance

I am working on an implementation of a sparsely connected RNN cell.
Training the cell is very slow. My code is not completely type-stable, and there are surely some rookie mistakes as well; I hope these are the only reasons for the poor performance.

I would like to use this model as the basis for future work, provided the training can be sped up.

# Root of the component hierarchy used to assemble the LTC cell below.
abstract type Component end
# Components whose outputs are summed into the cell's current (`I_components` in
# `solve_ode`); ZeroSynapse, SigmoidSynapse and LeakChannel subtype this.
abstract type CurrentComponent <:Component end

"""
    Mapper(W, b)
    Mapper(n::Integer)

Elementwise affine map: calling `m(x)` returns `m.W .* x .+ m.b`, i.e. row `i` of `x`
is scaled by `W[i]` and shifted by `b[i]`. `W` and `b` must share one concrete vector
type. `Mapper(n)` starts from the identity map (`W = ones`, `b = zeros`, `Float32`).
"""
struct Mapper{V<:AbstractVector}
  W::V  # per-row scale
  b::V  # per-row offset
  Mapper(W::V, b::V) where {V<:AbstractVector} = new{V}(W, b)
end
Mapper(n::Integer) = Mapper(ones(Float32, n), zeros(Float32, n))
(mapper::Mapper{<:Vector{T}})(x::AbstractVecOrMat{T}) where {T} = @. mapper.W * x + mapper.b
# Let Flux walk/rebuild a Mapper; both W and b are trainable.
Flux.@functor Mapper
Flux.trainable(mapper::Mapper) = (mapper.W, mapper.b)
function Base.show(io::IO, mapper::Mapper)
  print(io, "Mapper(", length(mapper.W), ")")
end


import Base: +, -
"""
    ZeroSynapse(args...)

Placeholder for an absent connection. Construction ignores any arguments. Calling
`(m::ZeroSynapse)(h, x)` returns a vector of length `size(x, 1)` filled with
`Flux.Zeros()`. The scalar `+`/`-` methods below let those placeholders combine with
real-valued currents element by element.
"""
struct ZeroSynapse <: CurrentComponent end
ZeroSynapse(args...) = ZeroSynapse()
(m::ZeroSynapse)(h,x) = fill(Flux.Zeros(), size(x,1))
Flux.@functor ZeroSynapse
Flux.trainable(m::ZeroSynapse) = Float32[]

# Scalar arithmetic that treats `Flux.Zeros` as an additive identity.
# NOTE(review): this is type piracy (this file owns neither `+`/`-` nor `Flux.Zeros`/`Real`)
# and can change behaviour for other code that loads Flux; consider returning a plain
# typed zero vector from ZeroSynapse instead.
+(::Flux.Zeros, b::Real) = b
+(a::Real, ::Flux.Zeros) = a

-(::Flux.Zeros, b::Real) = -b
# a - 0 == a. The previous definition returned -a, flipping the sign of the minuend.
-(a::Real, ::Flux.Zeros) = a


"""
    SigmoidSynapse(μ, σ, G, E)
    SigmoidSynapse(; μr, σr, Gr, Er)

Sigmoid-gated current component. `(m::SigmoidSynapse)(h, x)` evaluates, elementwise,
`G * sigmoid((x - μ) * σ) * (h - E)`.

The keyword constructor draws each parameter with `rand_uniform(T, lo, hi, 1, 1)[1, :]`
from the corresponding `(lo, hi)` tuple, in the order μ, σ, G, E.
NOTE(review): `rand_uniform` is defined elsewhere; presumably it samples uniformly in
`[lo, hi)`, which makes each parameter a 1-element vector. Confirm.
"""
struct SigmoidSynapse{V} <: CurrentComponent
  μ::V  # input value at which the sigmoid is centred
  σ::V  # slope factor applied to (x - μ)
  G::V  # overall scale of the current
  E::V  # value of h at which the current vanishes
end
function SigmoidSynapse(;
             μr=Float32.((0.3, 0.8)), σr=Float32.((3, 8)), Gr=Float32.((0.001, 1)), Er=Float32.((-0.3, 0.3)))
  draw(range_tuple) = rand_uniform(eltype(range_tuple), range_tuple..., 1, 1)[1, :]
  μ = draw(μr)
  σ = draw(σr)
  G = draw(Gr)
  E = draw(Er)
  return SigmoidSynapse(μ, σ, G, E)
end
function (syn::SigmoidSynapse)(h::AbstractVecOrMat, x::AbstractVecOrMat)
  return @fastmath @. syn.G * sigmoid((x - syn.μ) * syn.σ) * (h - syn.E)
end
Flux.@functor SigmoidSynapse
Flux.trainable(syn::SigmoidSynapse) = (syn.μ, syn.σ, syn.G, syn.E)
Base.show(io::IO, ::SigmoidSynapse) = print(io, "SigmoidSynapse")


"""
    LeakChannel(G, E)
    LeakChannel(n::Int; Gr, Er)

Per-row leak current: calling `m(h)` returns `m.G .* (h .- m.E)`. The `n`-argument
constructor draws `G` and then `E` (each of length `n`) with `rand_uniform` from the
`Gr`/`Er` ranges. NOTE(review): `rand_uniform` is defined elsewhere; presumably uniform
sampling.
"""
struct LeakChannel{V} <: CurrentComponent
  G::V  # per-row scale of the leak current
  E::V  # per-row value of h at which the leak current vanishes
end
function LeakChannel(n::Int; Gr=Float32.((0.001,1)), Er=Float32.((-0.3,0.3)))
  G = rand_uniform(eltype(Gr), Gr..., n)
  E = rand_uniform(eltype(Er), Er..., n)
  return LeakChannel(G, E)
end
(leak::LeakChannel)(h::AbstractVecOrMat) = @fastmath leak.G .* (h .- leak.E)
Flux.@functor LeakChannel
Flux.trainable(leak::LeakChannel) = (leak.G, leak.E)
function Base.show(io::IO, leak::LeakChannel)
  print(io, "LeakChannel(", size(leak.G, 1), ")")
end


"""
    ComponentContainer(components)

Wraps a `(n_src, n_dst)` grid of components. Calling `c(x, I)` forwards to
`doComponentContainer(c.components, x, I)`.
"""
struct ComponentContainer{C}
  components::C
end
function (container::ComponentContainer{C})(x::AbstractVecOrMat, I::AbstractVecOrMat) where {C}
  return doComponentContainer(container.components, x, I)
end
"""
    doComponentContainer(components, x, I)

Evaluate a grid of components and sum them per destination.

For each destination row `dst` of `x` and each source row `src` of `I`, this calls
`components[src, dst](x[dst, :], I[src, :])` and sums the results over `src`. The output
has `size(x)`, and entry `[dst, b]` is the `b`-th element of that sum. `x` and `I` share
their second (batch) dimension.
"""
function doComponentContainer(components, x, I)
  n_dst = size(x, 1)
  n_batch = size(x, 2)
  # One length-`n_batch` vector per destination, concatenated destination after destination.
  flat = mapreduce(vcat, axes(x, 1)) do dst
    mapreduce(src -> components[src, dst](x[dst, :], I[src, :]), +, axes(I, 1))
  end
  # `flat` is destination-major, but Julia reshapes column-major. Reshape to (batch, dst)
  # first and then transpose. Reshaping straight to size(x) would interleave destinations
  # and batch entries whenever n_batch > 1.
  return reshape(permutedims(reshape(flat, n_batch, n_dst)), size(x))
end
# Register ComponentContainer with Flux so functor-based utilities (such as the
# `Flux.destructure` call in the LTCCell constructor) can traverse its `components`.
Flux.@functor ComponentContainer


# Recurrent LTC cell. Each call integrates
#   dh/dt = -cm .* (sens(h, x) + syns(h, h) + leaks(h))
# over t ∈ [0, 1] (see `solve_ode`). The sens/syns/leaks components are also bundled in
# `cc` (a `Flux.Parallel` combining them with `+`). Its flattened parameters and rebuild
# function are stored in `cc_p`/`cc_re`, and the ODE right-hand side rebuilds `cc` from them.
struct LTCCell{W<:Wiring,SE<:ComponentContainer,SY<:ComponentContainer,LE<:LeakChannel,CC<:Flux.Parallel,CCP<:AbstractArray,CCRE<:Function,V<:AbstractArray,S<:AbstractMatrix,SOLVER,SENSEALG}
  wiring::W  # connectivity masks (sens_mask, syn_mask) and neuron counts
  sens::SE  # input->neuron synapses, one per (src, dst) entry of wiring.sens_mask
  syns::SY  # neuron->neuron synapses, one per (src, dst) entry of wiring.syn_mask
  leaks::LE  # per-neuron leak channel
  cc::CC  # Parallel(+, sens, syns, leaks)
  cc_p::CCP  # flat parameter vector of `cc` (from Flux.destructure)
  cc_re::CCRE  # rebuilds `cc` from a parameter vector
  cm::V  # per-neuron factor multiplying the summed current
  state0::S  # initial hidden state, n_total × 1
  solver::SOLVER  # ODE solver passed to `solve`
  sensealg::SENSEALG  # passed as `sensealg` to `solve`
end


"""
    LTCCell(wiring, solver, sensealg; cmr, state0r)

Build an LTC cell for `wiring`. Steps:

- `cm` is drawn with `rand_uniform` from the `cmr` range.
- `state0` is an `n_total × 1` matrix filled with `state0r[1]`.
- Every entry of `wiring.sens_mask` / `wiring.syn_mask` becomes a `ZeroSynapse` when it is
  `0` and a freshly initialised `SigmoidSynapse` otherwise, keeping the mask's shape.
"""
function LTCCell(wiring, solver, sensealg; cmr=Float32.((1.6,2.5)), state0r=Float32.((0.01)))
  n_total = wiring.n_total
  cm = rand_uniform(eltype(cmr), cmr..., n_total)
  state0 = fill(state0r[1], n_total, 1)

  SynapseT = Union{ZeroSynapse,SigmoidSynapse}
  make_synapse(flag) = flag == 0 ? ZeroSynapse() : SigmoidSynapse()
  # Comprehensions over a matrix visit it column-major and keep its shape, so
  # result[src, dst] corresponds to mask[src, dst].
  sens = SynapseT[make_synapse(flag) for flag in wiring.sens_mask]
  syns = SynapseT[make_synapse(flag) for flag in wiring.syn_mask]

  sensc = ComponentContainer(sens)
  synsc = ComponentContainer(syns)
  leaks = LeakChannel(n_total)

  cc = Flux.Parallel(+, sensc, synsc, leaks)
  cc_p, cc_re = Flux.destructure(cc)

  return LTCCell(wiring, sensc, synsc, leaks, cc, cc_p, cc_re, cm, state0, solver, sensealg)
end
# Flux integration: trainable parameters are the flattened component parameters,
# the per-neuron `cm`, and the initial state.
Flux.@functor LTCCell
Flux.trainable(cell::LTCCell) = (cell.cc_p, cell.cm, cell.state0)
function Base.show(io::IO, cell::LTCCell)
  w = cell.wiring
  print(io, "LTCCell(", w.n_sensory, ",", w.n_inter, ",", w.n_command, ",", w.n_motor, ")")
end
# TODO remove in v0.13
# `m.h` is a deprecated alias for `m.state0`; every other property is a plain field read.
function Base.getproperty(cell::LTCCell, sym::Symbol)
  sym === :h || return getfield(cell, sym)
  Zygote.ignore() do
    @warn "LTCCell field :h has been deprecated. Use m::LTCCell.state0 instead."
  end
  return getfield(cell, :state0)
end

# One recurrent step. The stored state `h` is widened to the batch size of `x`
# (a k-column `h` is repeated size(x,2)-k+1 times), then integrated by `solve_ode`.
# The new state is returned twice: as the next carry and as the step output.
function (m::LTCCell{W,SE,SY,LE,CC,CCP,CCRE,V,<:AbstractMatrix{T},SOLVER,SENSEALG})(h::AbstractVecOrMat{T}, x::AbstractVecOrMat{T}) where {W,SE,SY,LE,CC,CCP,CCRE,V,T,SOLVER,SENSEALG}
  ncopies = size(x, 2) - size(h, 2) + 1
  h_batch = repeat(h, 1, ncopies)::Matrix{T}
  h_next = solve_ode(m, h_batch, x)
  return h_next, h_next
end
"""
    solve_ode(m, h, x)

Integrate the cell dynamics `dh/dt = -cm .* I(h, x)` from t = 0 to t = 1, starting at
`h`, with the input `x` held constant. Return only the final state, because the solve
saves neither the start nor intermediate steps.

`I(h, x)` comes from the components rebuilt out of the parameter vector `p = [cc_p; cm]`.
The sensory layer is called with `(h, x)`, the synapse layer with `(h, h)`, and the leak
layer with `(h,)`. The results are combined with the Parallel's `connection` (`+`).
"""
function solve_ode(m,h::Matrix{T},x::Matrix)::Matrix{Float32} where T
  # Bind everything the RHS closure captures *before* defining it. Variables assigned
  # after a closure is created are captured as `Core.Box`. That made every RHS call
  # type-unstable.
  cc_p_l = size(m.cc_p,1)::Int
  cc_re = m.cc_re
  p = vcat(m.cc_p, m.cm)::Vector{Float32}
  pl = size(p,1)::Int

  function dltcdt!(dh,h,p,t)
    cc_p = @view p[1 : cc_p_l]
    cm   = @view p[cc_p_l+1 : pl]
    # NOTE(review): this rebuilds the whole component tree on every RHS evaluation and is
    # likely a dominant cost.
    cc = cc_re(cc_p)
    # Constant tuple indices keep each layer call inferable. This gives the same left
    # fold as the previous runtime-indexed `mapreduce` over 1:length(cc.layers).
    sens_I = cc.layers[1](h, x)
    syns_I = cc.layers[2](h, h)
    I_components = cc.connection(cc.connection(sens_I, syns_I), cc.layers[3](h))
    @. dh = - cm * I_components
    nothing
  end

  prob = ODEProblem{true}(dltcdt!,h,Float32.((0,1)),p)
  sol = solve(prob,m.solver; sensealg=m.sensealg, save_everystep=false, save_start=false, abstol=1e-3, reltol=1e-3)
  return sol[:,:,end]
end




# Recurrent network: input Mapper -> LTCCell -> output Mapper. `state` is the hidden
# state carried between calls (replaced on every forward call).
mutable struct LTCNet{MI<:Mapper,MO<:Mapper,T<:LTCCell,S}
  mapin::MI
  mapout::MO
  cell::T
  state::S
  function LTCNet(mapin, mapout, cell, state)
    MI, MO, T, S = typeof(mapin), typeof(mapout), typeof(cell), typeof(state)
    return new{MI,MO,T,S}(mapin, mapout, cell, state)
  end
end
# Build an LTCNet for `wiring`: identity-initialised input/output Mappers sized by
# `wiring.n_in` / `wiring.n_total`, and a state that starts at the cell's `state0`.
function LTCNet(wiring,solver,sensealg)
  input_map = Mapper(wiring.n_in)
  output_map = Mapper(wiring.n_total)
  ltc = LTCCell(wiring, solver, sensealg)
  return LTCNet(input_map, output_map, ltc, ltc.state0)
end

# Forward one time step: map the input, advance the cell (storing the new state on the
# net), and map the cell output.
function (m::LTCNet{MI,MO,T,<:AbstractMatrix{T2}})(x::AbstractVecOrMat{T2}) where {MI,MO,T,T2}
  cell_in = m.mapin(x)
  new_state, cell_out = m.cell(m.state, cell_in)
  m.state = new_state
  return m.mapout(cell_out)
end

# Flux integration for LTCNet. `reset!` points the state back at the cell's `state0`.
Flux.@functor LTCNet
Flux.trainable(net::LTCNet) = (net.mapin, net.mapout, net.cell)
Flux.reset!(net::LTCNet) = (net.state = net.cell.state0)
function Base.show(io::IO, net::LTCNet)
  print(io, "LTCNet(", net.mapin, ",", net.mapout, ",", net.cell, ")")
end
# `net.init` is a deprecated alias for `net.cell.state0`; other properties are field reads.
function Base.getproperty(net::LTCNet, sym::Symbol)
  sym === :init || return getfield(net, sym)
  Zygote.ignore() do
    @warn "LTCNet field :init has been deprecated. To access initial state weights, use m::LTCNet.cell.state0 instead."
  end
  return getfield(net.cell, :state0)
end




"""
    get_bounds(m::LeakChannel)

Return `(lb, ub)` box constraints for the parameters of `m`. They are laid out like
`Flux.trainable(m) == (m.G, m.E)`: all `G` entries, then all `E` entries.

NOTE(review): the ranges (G ∈ [0.001, 1], E ∈ [-1, 1]) mirror the G/E entries of
`get_bounds(::SigmoidSynapse)`. Confirm they are appropriate for leak channels.
"""
function get_bounds(m::LeakChannel)
  nG, nE = length(m.G), length(m.E)
  lb = Float32[fill(0.001, nG); fill(-1.0, nE)]
  ub = Float32[fill(1.0, nG); fill(1.0, nE)]
  return lb, ub
end
# The method used to be misspelled `get_bouds` (with an empty body), so
# `get_bounds(::LeakChannel)` did not exist. Keep the old name working.
get_bouds(m::LeakChannel) = get_bounds(m)
# Box constraints for (μ, σ, G, E), in the order of `Flux.trainable(::SigmoidSynapse)`.
get_bounds(::SigmoidSynapse) = (Float32[0.1, 1, 0.001, -1], Float32[0.9, 10, 1, 1])
"""
    get_bounds(m::Mapper)

Return `(lb, ub)` box constraints for the parameters of `m`. Every entry of `W`, followed
by every entry of `b`, is bounded to `[-20.1, 20.1]` as `Float32`.
"""
function get_bounds(m::Mapper)
  n = length(m.W) + length(m.b)
  # `fill` replaces splatting two comprehensions into an untyped literal, and it always
  # yields a Vector{Float32}. For an empty Mapper the old literal was `Any[]`, which
  # `f32` may leave untyped.
  lb = fill(Float32(-20.1), n)
  ub = fill(Float32(20.1), n)
  return lb, ub
end
"""
    my_custom_train!(m, loss, ps, data, opt; data_range, lower, upper, cb)

Take one optimiser step per element `d` of `data`, minimising `loss(d...)` with respect
to `ps`. After each gradient, `cb(d..., training_loss, m)` is called before the update.

If both `lower` and `upper` are given, the first parameter array `ps[1]` is clamped
elementwise to `[lower[i], upper[i]]` after every update. `data_range` is accepted but
unused.
"""
function my_custom_train!(m, loss, ps, data, opt; data_range=nothing, lower=nothing, upper=nothing, cb=(args...)->nothing)
  local training_loss
  ps = Zygote.Params(ps)
  # The bounds are optional. The old loop called `length(nothing)` (a MethodError) when
  # they were omitted.
  use_bounds = !(isnothing(lower) || isnothing(upper))
  for d in data
    gs = Zygote.gradient(ps) do
      training_loss = loss(d...)
      return training_loss
    end
    # The default `cb` must accept these arguments. The old `()->nothing` default threw
    # a MethodError here.
    cb(d..., training_loss, m)
    Flux.Optimise.update!(opt, ps, gs)
    if use_bounds
      θ = ps[1]
      for i in eachindex(lower)
        θ[i] = min(upper[i], max(lower[i], θ[i]))
      end
    end
  end
end

"""
    generate_data()

Build a 48-step toy sequence.

- Inputs: each element is a `2×1` `Float32` matrix `[sin(t); cos(t)]`, with `t` evenly
  spaced over `[0, 3π]`.
- Targets: each element is a 1-element `Float32` vector `[sin(u)]`, with `u` evenly spaced
  over `[0, 6π]`.
"""
function generate_data()
  npoints = 48
  input_phase = range(0, stop=3π, length=npoints)
  target_phase = range(0, stop=6π, length=npoints)
  xs = [reshape(Float32[sin(t), cos(t)], 2, 1) for t in input_phase]
  ys = [Float32[sin(u)] for u in target_phase]
  return xs, ys
end

"""
    data(iter; data_x, data_y, short, noisy)

Return `iter` copies of the `(inputs, targets)` sequence pair.

- If `data_y` is `nothing`, both sequences come from `generate_data()` (any `data_x` is
  then replaced as well).
- If `short` is an array `[a, b]`, each copy is cut to elements `a:b`.
- With `noisy=true`, the targets are passed through `add_gauss.(y, 0.02)` (defined
  elsewhere).
"""
function data(iter; data_x=nothing, data_y=nothing, short=false, noisy=false)
  if isnothing(data_y)
    data_x, data_y = generate_data()
  end
  Tx = eltype(data_x[1])
  Ty = eltype(data_y[1])
  samples = Vector{Tuple{Vector{Matrix{Tx}}, Vector{Vector{Ty}}}}()
  for _ in 1:iter
    xs, ys = data_x, data_y
    if short isa Array
      window = short[1]:short[2]
      xs, ys = xs[window], ys[window]
    end
    push!(samples, (xs, noisy ? add_gauss.(ys, 0.02) : ys))
  end
  return samples
end


"""
    loss(x, y, m)

Reset `m` and run it over the input sequence `x`, keeping the last `n_motor` output rows
at each step. Return the mean over time steps of the summed squared error between the last
of those rows and `y[i]`.
"""
function loss(x,y,m::LTCNet{<:Mapper,<:Mapper,<:LTC.LTCCell,<:AbstractMatrix})
  Flux.reset!(m)
  ŷ = map(x) do xi
    out = m(xi)
    out[end-m.cell.wiring.n_motor+1:end, :]
  end
  sq_err = [(ŷ[i][end, :] .- y[i]) .^ 2 for i in 1:length(y)]
  return sum(sum(sq_err)) / length(y)
end
"""
    cb(x, y, l, m)

Training callback: print the loss `l` and return `false`. The other arguments are accepted
for interface compatibility and ignored.
"""
cb(_x, _y, l, _m) = (println(l); false)

"""
    traintest(n, solver=VCABM(), sensealg=InterpolatingAdjoint(...))

Build an LTCNet for `Wiring(2, 1)` and report its parameter count. Then train it with
ClipValue(1) + ADAM(0.03): first on `data(3)`, then on `data(n)`, i.e. `n` passes over the
generated sequence with one gradient step each. The bounds are empty, so no clamping is
applied.
"""
function traintest(n, solver=VCABM(), sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))
  model = LTC.LTCNet(Wiring(2,1), solver, sensealg)
  lower, upper = [], []
  θ = Flux.params(model)

  @show sum(length.(θ))
  @show length(lower)

  train_data = data(n)
  opt = Flux.Optimiser(ClipValue(1), ADAM(0.03))
  objective = (xs, ys) -> loss(xs, ys, model)

  my_custom_train!(model, objective, θ, data(3), opt; cb, lower, upper)
  my_custom_train!(model, objective, θ, train_data, opt; cb, lower, upper)
end

# Two timed runs. The first measurement also includes JIT compilation of the whole
# pipeline. The second (more training passes) is closer to the steady-state cost.
@time traintest(10)
@time traintest(30)

Timing of the second call, `traintest(30)`: 24.968881 seconds (231.50 M allocations: 12.595 GiB, 8.58% gc time)

GitHub repo

Any help is more than welcome!