Is there anything I can do to speed up this loss calculation? (Trajectory Balance loss for a flow network, from arXiv:2310.02779, "Expected flow networks in stochastic environments and two-player zero-sum games".)
# Trajectory-Balance-style loss (arXiv:2310.02779) for a flow network, trained
# on GPU. Samples `batch_size` trajectories from `datas`, builds dense
# state/action/mask tensors on the CPU, ships them to the GPU once, and takes
# one optimiser step on `policy`.
#
# Arguments:
#   policy     - Flux model; must expose a learnable `z` (log-partition term)
#                and be callable on (2*NN, batch) column matrices.
#   opt        - Flux optimiser state for `policy`.
#   datas      - collection of trajectories; each step carries `state`,
#                `action`, `done`, `mask`, and the final step also `winner`.
#   batch_size - trajectories per update (default 512).
#   λ          - reward scale applied to the terminal `winner` (default 15).
#
# Performance changes vs. the naive version:
#   * The policy is evaluated ONCE on all (batch × time) columns instead of up
#     to NN separate forward passes — one kernel-launch sequence, not NN.
#     NOTE(review): equivalent only if `policy` acts independently per column
#     (e.g. a Chain of Dense layers); a BatchNorm-style layer would see
#     different batch statistics — confirm for your model.
#   * The "any trajectory still active at step t?" test runs on the CPU
#     buffers BEFORE the gradient; the original `sum(dones[:,:,t]) == 0`
#     forced a GPU->CPU device sync on every loop iteration inside AD.
function gpu_train(policy, opt, datas, batch_size=512, λ=15)
    # Sample a batch on the CPU.
    idxs = rand(1:length(datas), batch_size)
    batch = datas[idxs]

    # Dense CPU buffers; padded (post-trajectory) timesteps keep dones == 0
    # and masks == 1, so they contribute nothing to the loss.
    states_cpu  = zeros(Float32, 2*NN, batch_size, NN)
    actions_cpu = zeros(Float32, NN, batch_size, NN)
    rewards_cpu = zeros(Float32, 1, batch_size)
    dones_cpu   = zeros(Float32, 1, batch_size, NN)
    masks_cpu   = ones(Float32, NN, batch_size, NN)
    for (j, traj) in enumerate(batch)
        rewards_cpu[1, j] = λ * traj[end].winner
        for (k, d) in enumerate(traj)
            states_cpu[:, j, k] .= d.state
            actions_cpu[d.action, j, k] = 1      # one-hot encode chosen action
            dones_cpu[1, j, k] = d.done ? 0 : 1  # 1 while trajectory still active
            masks_cpu[:, j, k] = d.mask
        end
    end

    # Last timestep at which any trajectory is still active. Per trajectory
    # `dones` along time is 1,...,1,0,0,..., hence nonincreasing, so looping
    # over 1:Tmax is equivalent to the original break-on-first-empty-step.
    Tmax = something(findlast(t -> any(!iszero, @view dones_cpu[1, :, t]), 1:NN), 0)

    # One host->device transfer per tensor.
    states  = cu(states_cpu)
    actions = cu(actions_cpu)
    den     = cu(rewards_cpu)
    dones   = cu(dones_cpu)
    masks   = cu(masks_cpu)
    num0    = ones(Float32, 1, batch_size) |> gpu

    gs = gradient(policy) do m
        num = exp.(num0 .* m.z)  # partition term (kept as in the original)
        # log(#valid actions) per step, zeroed once a trajectory has finished.
        nactions = log.(sum(masks, dims=1)) .* dones

        if Tmax > 0
            # Single batched forward pass over all active timesteps.
            flat = reshape(view(states, :, :, 1:Tmax), 2*NN, :)
            logits_all = reshape(m(flat), NN, batch_size, Tmax)
            for t in 1:Tmax
                mt = masks[:, :, t]
                # Large negative offset suppresses invalid (masked-out) actions.
                logits = logits_all[:, :, t] .* mt .- (1 .- mt) * 1.0f10
                # log-probability of the action actually taken at step t.
                probs = sum(logsoftmax(logits, dims=1) .* actions[:, :, t], dims=1)
                if t % 2 == 1  # odd steps accumulate into the numerator
                    num += probs .* dones[:, :, t]
                    num += nactions[:, :, t]
                else           # even steps accumulate into the denominator
                    den += probs .* dones[:, :, t]
                    den += nactions[:, :, t]
                end
            end
        end
        return Flux.mse(num, den)
    end
    Flux.update!(opt, policy, gs[1])
end