Flux on GPU and inference optimization

I’m currently running an AlphaZero replica with Flux. As you may know, most of the time is spent in self-play, which essentially means running neural network inference. I’m currently using a batch size of 512, and when measuring the time spent on inference I get close to the “theoretical timing”, i.e. the mean time of:

for k in 1:n_sim
    model(batch)
end

where batch is rand(Float32, 7, 7, 3, 512) |> gpu. So far so good. The problem is that the same model running in another implementation, written in Rust, is 10 times faster: counting only the inference time, my version is already at least 5 times slower than the whole Rust loop.
So I wonder: is there any penalty for calling cuDNN from Flux? Would it be faster to use CUDA.jl/cuDNN directly, or am I doing something dumb that makes the forward pass slow?
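To get reliable per-call numbers, it can help to keep in mind that CUDA.jl queues GPU work asynchronously, so a timer can stop before the kernels have actually finished. The sketch below times one synchronized forward pass with `CUDA.@sync` and BenchmarkTools' `@belapsed`. It assumes Flux 0.14 or later with CUDA.jl, cuDNN.jl and BenchmarkTools.jl installed (recent Flux versions need `cuDNN` loaded for GPU convolutions); `toy_model`, `synced` and `time_forward` are hypothetical names, and `toy_model` is a small stand-in, not the network from this post. It falls back to the CPU when no functional GPU is found.

```julia
using Flux, CUDA, cuDNN, BenchmarkTools

# Hypothetical stand-in network with the post's input shape (7×7 board, 3 planes).
# It is NOT the resnetwork_2H model defined below.
toy_model() = Chain(
    Conv((3, 3), 3 => 64, pad=1, bias=false), BatchNorm(64, relu),
    Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu),
    Flux.flatten, Dense(64 * 49, 539), softmax)

const ON_GPU = CUDA.functional()
const dev = ON_GPU ? gpu : cpu

# Run f() and, on the GPU, block until all queued kernels have finished.
synced(f) = ON_GPU ? CUDA.@sync(f()) : f()

function time_forward(; batchsize = 512)
    model = toy_model() |> dev
    Flux.testmode!(model)                  # BatchNorm uses its running statistics
    x = rand(Float32, 7, 7, 3, batchsize) |> dev
    synced(() -> model(x))                 # warm-up run (compilation, one-time setup)
    # Minimum wall time, in seconds, of one synchronized forward pass.
    t = @belapsed synced(() -> $model($x))
    return t, size(model(x))
end

t, outsize = time_forward()
println("forward pass: ", round(t * 1e3, digits = 3), " ms, output size ", outsize)
```

Multiplying `t` by `n_sim` gives a synchronized estimate to compare against the measured loop time.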
For the record, here is the model definition:

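using Flux

# Residual wrapper: output = relu(layers(x) + x)
mutable struct ResnetB{T}
    layers::T
end
Flux.@functor ResnetB

# Conv -> BatchNorm+ReLU -> Conv body, wrapped in a skip connection
function ResNetBlock(n::Int)
    layers = Chain(
        Conv((3, 3), n => n, pad=1, stride=1, bias=false),
        BatchNorm(n, relu),
        Conv((3, 3), n => n, pad=1, stride=1, bias=false),
    )
    return ResnetB(layers)
end

function (m::ResnetB)(x)
    return relu.(m.layers(x) .+ x)
end

abstract type Network end  # assumption: the real project defines this elsewhere

# Two-headed network; field order matches the constructor call in resnetb_2H
mutable struct resnetwork_2H <: Network
    base    # input convolution
    res     # residual tower
    value   # value head (tanh, 1 output)
    policy  # policy head (softmax, 539 outputs)
end
Flux.@functor resnetwork_2H

function (m::resnetwork_2H)(x)
    b = m.base(x)
    b = m.res(b)
    return m.policy(b), m.value(b)
end

# Input is 7×7×3×batch, so each head flattens 32 channels × 49 squares
function resnetb_2H(n_filter, n_tower, dense)
    return resnetwork_2H(Chain(Conv((3, 3), 3 => n_filter, stride=1, pad=1, bias=false), BatchNorm(n_filter, relu)),
        Chain([ResnetB(Chain(ResNetBlock(n_filter), ResNetBlock(n_filter))) for k in 1:n_tower]...),
        Chain(Conv((1, 1), n_filter => 32, stride=1, pad=0, bias=false), BatchNorm(32, relu), Flux.flatten, Dense(32 * 49, dense, relu), Dense(dense, 1, tanh)),
        Chain(Conv((1, 1), n_filter => 32, stride=1, pad=0, bias=false), BatchNorm(32, relu), Flux.flatten, Dense(32 * 49, 539), softmax)) |> gpu
end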

Thanks in advance

At this model size, it’s unlikely you’ll see anything but increased overhead from using CUDA. If you want to try anyway, the CUDA.CUDNN module should have most of what you need. I would recommend looking at how AlphaZero.jl parallelizes training efficiently (e.g. the discussion in “Questions on parallelization”, issue #71 of jonathan-laurent/AlphaZero.jl on GitHub). A related library to try is PumasAI/SimpleChains.jl, which I imagine should be closer to your Rust implementation because it runs on the CPU.
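Whether per-call GPU overhead dominates at this model size can be checked by timing the same network on the CPU and the GPU across a few batch sizes. This is a sketch, not a definitive benchmark: it assumes Flux 0.14 or later with CUDA.jl, cuDNN.jl and BenchmarkTools.jl installed; `toy_model`, `cpu_time` and `gpu_time` are hypothetical names, and `toy_model` is a small stand-in for the network in the question. The GPU column is skipped when no functional CUDA device is found, and CPU numbers depend on how many threads Julia and BLAS are using.

```julia
using Flux, CUDA, cuDNN, BenchmarkTools

# Hypothetical stand-in network on a 7×7×3 board input (not the question's model).
toy_model() = Chain(
    Conv((3, 3), 3 => 64, pad=1, bias=false), BatchNorm(64, relu),
    Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu),
    Flux.flatten, Dense(64 * 49, 539), softmax)

# Minimum time (seconds) of one forward pass on the CPU.
function cpu_time(model, batchsize)
    x = rand(Float32, 7, 7, 3, batchsize)
    model(x)                                   # warm-up
    return @belapsed $model($x)
end

# Minimum time (seconds) of one synchronized forward pass on the GPU,
# or `nothing` when no functional CUDA device is available.
function gpu_time(model, batchsize)
    CUDA.functional() || return nothing
    gmodel = model |> gpu
    x = rand(Float32, 7, 7, 3, batchsize) |> gpu
    CUDA.@sync gmodel(x)                       # warm-up
    return @belapsed CUDA.@sync $gmodel($x)    # wait for kernels before stopping the clock
end

model = toy_model()
Flux.testmode!(model)                          # inference mode for BatchNorm
results = Dict{Int,Tuple{Float64,Union{Float64,Nothing}}}()
for bs in (1, 32, 512)
    tc = cpu_time(model, bs)
    tg = gpu_time(model, bs)
    results[bs] = (tc, tg)
    println("batch $bs: CPU ", round(tc * 1e3, digits = 3), " ms",
            tg === nothing ? "" : ", GPU " * string(round(tg * 1e3, digits = 3)) * " ms")
end
```

If the CPU were competitive at small batches while the GPU pulled ahead at 512, that would suggest per-call overhead rather than slow cuDNN kernels.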

Thanks. My implementation already uses the GPU and is as parallel as AlphaZero.jl, and so is the Rust implementation (which, by the way, is not mine). My problem is that mine is much slower. Usually you can match C++, or at least stay within 2x; here the timings are 10x worse…