Optimize loss calculation on GPU

Is there anything I can do to speed up this loss calculation? It's the trajectory balance loss for a flow network, from [2310.02779] Expected flow networks in stochastic environments and two-player zero-sum games.
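
For context, the trajectory balance loss has the form

    L(τ) = (log Z + Σ_t log P_F(a_t | s_t) − log R(τ) − Σ_t log P_B(s_t | s_{t+1}))²

and in the two-player setting the two players' terms end up on opposite sides of the balance, which is why num and den below accumulate odd and even steps separately, with m.z playing the role of the flow term.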

function gpu_train(policy, opt, datas, batch_size=512, λ=15)  # NN is a global: action count and max trajectory length
    # Sample a batch on CPU
    idxs = rand(1:length(datas), batch_size)
    batch = datas[idxs]

    # Build CPU arrays (NN possible actions, up to NN steps per trajectory)
    states_cpu  = zeros(Float32, 2*NN, batch_size, NN)
    actions_cpu = zeros(Float32, NN, batch_size, NN)
    rewards_cpu = zeros(Float32, 1, batch_size)
    dones_cpu   = zeros(Float32, 1, batch_size, NN)
    masks_cpu   = ones(Float32, NN, batch_size, NN)

    for (j, traj) in enumerate(batch)
        rewards_cpu[1, j] = λ * traj[end].winner
        for (k, d) in enumerate(traj)
            states_cpu[:, j, k] .= d.state
            actions_cpu[d.action, j, k] = 1      # one-hot encode the chosen action
            dones_cpu[1, j, k] = d.done ? 0 : 1  # 1 while the trajectory is still active
            masks_cpu[:, j, k] = d.mask          # valid-action mask at this step
        end
    end
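    # (Aside: these five CPU buffers are reallocated on every call; since
    #  batch_size and NN are fixed, they could be allocated once outside
    #  gpu_train and reused, though the fill loop is likely cheap next to
    #  the forward passes.)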

    # Move the full arrays to the GPU in one transfer each
    states = cu(states_cpu)
    actions = cu(actions_cpu)
    den = cu(rewards_cpu)   # denominator starts from the (scaled) reward
    dones = cu(dones_cpu)   # note: 1 = still active, 0 = finished
    masks = cu(masks_cpu)
    num = cu(ones(Float32, 1, batch_size))
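    # (Aside: these are five separate synchronous host-to-device copies;
    #  packing everything into one contiguous buffer, or pinning the host
    #  arrays, might shave transfer time, but I haven't measured it.)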
    gs = gradient(policy) do m
        num = exp.(num .* m.z)  # exp(m.z) broadcast across the batch (learned flow term)

        # log(number of valid actions) per step, zeroed once a trajectory is done
        nactions = log.(sum(masks, dims=1)) .* dones
        for t in 1:NN
            # stop once every trajectory in the batch has terminated
            # (this reduction syncs the GPU with the CPU on every step)
            if sum(dones[:, :, t]) == 0
                break
            end
            # mask out invalid actions with a large negative value before the softmax
            logits = m(states[:, :, t]) .* masks[:, :, t] .- (1 .- masks[:, :, t]) .* 1.0f10
            # log-probability of the action actually taken
            logp = sum(logsoftmax(logits, dims=1) .* actions[:, :, t], dims=1)
            if t % 2 == 1
                # odd steps (first player) accumulate into the numerator
                num += logp .* dones[:, :, t]
                num += nactions[:, :, t]
            else
                # even steps (second player) accumulate into the denominator
                den += logp .* dones[:, :, t]
                den += nactions[:, :, t]
            end
        end

        return Flux.mse(num, den)
    end
    Flux.update!(opt, policy, gs[1])
end
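
The part I suspect most is the inner loop: it launches NN small forward passes, and the sum(dones[:, :, t]) == 0 check forces a GPU-to-CPU sync on every step. One idea I've sketched but not benchmarked is to fold the time dimension into the batch and call the network once. This is a rough draft, not a drop-in replacement: it assumes the policy accepts a (2NN, batch_size*NN) matrix (e.g. a Chain of Dense layers), reuses the states/actions/masks/dones/den arrays built exactly as above, and trades extra FLOPs on padded steps for a single large matmul:

function gpu_train_batched(policy, opt, datas, batch_size=512, λ=15)
    # ... build states, actions, den (rewards), dones, masks exactly as in
    # gpu_train above, and move them to the GPU with cu(...) ...

    # +1 on odd steps (player 1, numerator side), -1 on even steps (player 2)
    signs = cu(reshape(Float32[isodd(t) ? 1 : -1 for t in 1:NN], 1, 1, NN))

    gs = gradient(policy) do m
        # One forward pass over every (trajectory, step) pair:
        # (2NN, batch, NN) -> (2NN, batch*NN) -> (NN, batch, NN)
        logits = reshape(m(reshape(states, 2*NN, :)), NN, batch_size, NN)
        logits = logits .* masks .- (1 .- masks) .* 1.0f10

        # Log-prob of the taken action at every step; terminal and padded
        # steps contribute zero because dones is zero there
        logp = sum(logsoftmax(logits, dims=1) .* actions, dims=1) .* dones
        nact = log.(sum(masks, dims=1)) .* dones

        # Collapse the step dimension with the per-player signs, so that
        # lhs - den equals num - den from the loop version
        lhs = exp.(m.z) .+ dropdims(sum((logp .+ nact) .* signs, dims=3), dims=3)
        return Flux.mse(lhs, den)
    end
    Flux.update!(opt, policy, gs[1])
end

The loop version stops at the longest trajectory in the batch, while this always processes all NN steps, so whichever is faster probably depends on how early games typically end.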