Dense matrix × sparse binary vector product

Hi,
I need to train very shallow networks (two dense layers) whose inputs are sparse binary vectors (i.e. only 0 and 1 entries, batched; this is for NNUE). I tried to write custom kernels for the forward and backward passes, but at best they are on par with simple dense matrix products. I know that some dedicated trainers (PyTorch with custom CUDA kernels, or Rust + CUDA) are much faster because they exploit the sparsity and the fact that the product reduces to adding columns. But currently I'm stuck.

What I tried was passing each batch as a matrix whose entries are the indices of the ones, along with a count of ones per sample (see the kernels below). This is very naive and quite slow. Also, the adjoint allocates a lot.

Any hints?
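For reference, the identity such kernels rely on is that for a binary input `x`, `A * x` is just the sum of the columns `A[:, j]` over the active indices `j`, so no multiplications are needed. Below is a minimal CPU sketch of the padded-index batch format described above plus a column-sum forward pass, assuming each sample arrives as a `Vector{Int32}` of active feature indices. The helper names `pack_batch` and `sparse_forward_cpu` are hypothetical; the point is only to have a simple correctness oracle to compare a GPU kernel against.

```julia
# Hypothetical helper: pack a batch of active-index lists into the padded format
# used by the kernels. Column b of `idx` holds the active indices of sample b,
# padded with 0, and `nnz[b]` is the number of valid entries in that column.
function pack_batch(active::Vector{Vector{Int32}})
    batch = length(active)
    max_nnz = maximum(length, active; init = 0)
    idx = zeros(Int32, max_nnz, batch)
    nnz = Vector{Int32}(undef, batch)
    for (b, feats) in enumerate(active)
        nnz[b] = length(feats)
        idx[1:length(feats), b] .= feats
    end
    return idx, nnz
end

# Hypothetical CPU reference for Y = A * X with binary X: each output column is
# the sum of the columns of A selected by that sample's active indices.
function sparse_forward_cpu(A::Matrix{Float32}, idx::Matrix{Int32}, nnz::Vector{Int32})
    m = size(A, 1)
    batch = length(nnz)
    Y = zeros(Float32, m, batch)
    @inbounds for b in 1:batch, k in 1:nnz[b]
        j = idx[k, b]
        for row in 1:m
            Y[row, b] += A[row, j]
        end
    end
    return Y
end

# Usage: compare against the dense product with an explicit 0/1 matrix.
A = Float32[1 2 3 4; 5 6 7 8]              # m = 2 outputs, n = 4 features
active = [Int32[1, 3], Int32[2], Int32[]]   # three samples, the last one empty
idx, nnz = pack_batch(active)
Y = sparse_forward_cpu(A, idx, nnz)

X = zeros(Float32, 4, 3)
for (b, feats) in enumerate(active), j in feats
    X[j, b] = 1
end
@assert Y ≈ A * X
```

The same `idx`/`nnz` arrays can be moved to the GPU with `CuArray(...)` and fed to a kernel, and its output compared with `Y` on small random batches.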

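```julia
using CUDA

# ------------------------
# Forward kernel: Y = A * X, where each binary column of X is given by its active indices
# ------------------------
function matvec_sparse_forward_kernel!(
    Y::CuDeviceMatrix{Float32},          # m × batch
    A::CuDeviceMatrix{Float32},          # m × n
    v_idx::CuDeviceMatrix{Int32},        # max_nnz × batch (column b holds the active indices of sample b, padded with 0)
    nnz_per_vec::CuDeviceVector{Int32}   # batch-length vector of nnz counts
)
    row = (blockIdx().x - 1) * blockDim().x + threadIdx().x   # output row (1-based)
    vec = (blockIdx().y - 1) * blockDim().y + threadIdx().y   # sample index in the batch

    m = size(A, 1)
    batch = size(Y, 2)
    if row <= m && vec <= batch
        acc = 0.0f0
        @inbounds nnz = nnz_per_vec[vec]
        # Binary input: the product is a sum of the selected columns of A
        @inbounds for k in 1:nnz
            j = v_idx[k, vec]
            acc += A[row, j]
        end
        @inbounds Y[row, vec] = acc
    end
    return
end

# ------------------------
# Backward kernel (atomic): dW += dZ * X'
# dW must be zeroed before launch, since the kernel only accumulates.
# ------------------------
function matvec_sparse_backward_kernel!(
    dW::CuDeviceMatrix{Float32},         # m × n
    dZ::CuDeviceMatrix{Float32},         # m × batch
    X::CuDeviceMatrix{Int32},            # max_nnz × batch, same padded layout as v_idx
    nnz_per_vec::CuDeviceVector{Int32}   # batch-length vector of nnz counts
)
    row = (blockIdx().x - 1) * blockDim().x + threadIdx().x   # 1-based
    vec = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    m = size(dW, 1)
    batch = size(X, 2)
    if row > m || vec > batch
        return
    end

    @inbounds dout = dZ[row, vec]
    @inbounds nnz = nnz_per_vec[vec]
    @inbounds for i in 1:nnz
        j = X[i, vec]
        # Different samples can share feature j, hence the atomic add
        CUDA.@atomic dW[row, j] += dout
    end
    return
end
```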
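A matching CPU reference for the backward kernel may help when checking it. With binary `X`, the weight gradient `dW = dZ * X'` is a scatter-add: sample `b` adds its gradient column `dZ[:, b]` to every column `dW[:, j]` whose feature `j` is active, which is exactly what the atomic kernel does. The sketch below assumes the same padded `idx`/`nnz` layout; the function name `sparse_backward_cpu!` is hypothetical. Since it runs sequentially it needs no atomics, and like the GPU kernel it only accumulates, so it zeroes `dW` first.

```julia
# Hypothetical CPU reference for dW = dZ * X' with binary X (scatter-add of
# gradient columns into the weight columns of the active features).
function sparse_backward_cpu!(dW::Matrix{Float32}, dZ::Matrix{Float32},
                              idx::Matrix{Int32}, nnz::Vector{Int32})
    fill!(dW, 0f0)                 # accumulation below assumes dW starts at zero
    m = size(dW, 1)
    @inbounds for b in 1:length(nnz), k in 1:nnz[b]
        j = idx[k, b]
        for row in 1:m
            dW[row, j] += dZ[row, b]
        end
    end
    return dW
end

# Usage: three samples over n = 4 features, m = 2 outputs.
idx = Int32[1 2 0; 3 0 0]          # padded active indices, one column per sample
nnz = Int32[2, 1, 0]
dZ  = Float32[1 10 100; 2 20 200]
dW  = sparse_backward_cpu!(zeros(Float32, 2, 4), dZ, idx, nnz)

X = zeros(Float32, 4, 3)
for b in 1:3, k in 1:nnz[b]
    X[idx[k, b], b] = 1
end
@assert dW ≈ dZ * X'
```

The third sample has no active features, so its gradient column never reaches `dW`, which the dense check confirms.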