Dear All,
I would like to ask for advice on how to parallelize a loss function in Flux. Context: I am solving a discounted infinite-horizon optimization problem (in economics). It has three controls (Cs, Ci, Cr), and I parameterize each of them with a deep neural network (four hidden layers of 32 neurons). To approximate the infinite-horizon problem, I use a finite-horizon approximation (say, 150 periods).
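Written out, the example code below appears to compute the following finite-horizon objective. This reading treats the three state rows as susceptible, infected and recovered shares $S_t, I_t, R_t$, an interpretation suggested by the control names and the law of motion. Labor is $n = c/\alpha$ for the $s$ and $r$ types and $n = c/(\alpha\varphi)$ for the $i$ type, and $u_d$ stands for the constant that multiplies the remaining share:

$$
\sum_{t=0}^{\infty} \beta^t r_t \;\approx\; \sum_{t=0}^{T-1} \beta^t r_t, \qquad T = 150,
$$

$$
r_t = S_t\, u\!\left(c^s_t, \tfrac{c^s_t}{\alpha}\right) + I_t\, u\!\left(c^i_t, \tfrac{c^i_t}{\alpha\varphi}\right) + R_t\, u\!\left(c^r_t, \tfrac{c^r_t}{\alpha}\right) + (1 - S_t - I_t - R_t)\, u_d, \qquad u(c, n) = \log c - \tfrac{\theta}{2} n^2 .
$$

The loss is the negative of this sum over all grid points, plus squared penalties on network outputs below $\nu$.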
The problem is that I need to evaluate this loss function over a large grid of points (a few thousand), and that takes a lot of time. My networks have relatively small layers, so a GPU doesn't help. Evaluation at a single point is inherently serial because each period's state depends on the previous one, so it can't be parallelized. However, I think I could multithread or distribute the evaluation across grid points, either over the 6 CPU cores of my laptop or on my university's cluster.
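Parallelizing across grid points can work without touching the serial time loop. In the loss, the networks, the clamping, the reward and the law of motion all act column by column, and the result is a plain sum over columns. This holds as long as the networks contain no batch-coupled layers such as `BatchNorm`. Under that assumption, splitting the grid $X$ into column blocks $X_1, \dots, X_K$ gives

$$
\mathcal{L}(\theta; X) = \sum_{k=1}^{K} \mathcal{L}(\theta; X_k), \qquad \nabla_\theta \mathcal{L}(\theta; X) = \sum_{k=1}^{K} \nabla_\theta \mathcal{L}(\theta; X_k),
$$

where $\theta$ here denotes all network parameters. Each block's value and gradient can be computed on a different core or node and summed afterwards, so automatic differentiation only ever sees a serial computation.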
So I would like to ask for advice on the best way to parallelize this type of loss function on CPUs. Should I try something like FLoops? I would also like to know whether I should try FastChain from DiffEqFlux instead of a regular Chain. Does it provide a significant speed-up, or are my networks too large for it (around 8,700 parameters each)?
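Below is a minimal sketch of one CPU option. It uses plain Base threading (`Threads.@spawn`) rather than FLoops. The grid is split into column blocks, each task runs its own `Zygote.withgradient` on a serial loss, and the results are added, so AD never differentiates through the threading itself.

The sketch assumes the following:
- Flux ≥ 0.13 (for `Dense(in => out)` and `Flux.destructure`).
- Zygote is installed alongside Flux.
- Julia is started with several threads (e.g. `julia -t 6`).
- The loss is a sum over grid columns.

`make_control`, `toy_loss` and `threaded_loss_and_grad` are hypothetical names, the `tanh` activation is an assumption, and `toy_loss` only stands in for the real objective.

```julia
using Flux, Zygote, LinearAlgebra

# Hypothetical control network: 3 state rows in, four hidden layers of 32
# units (tanh is an assumption), one output per grid column.
make_control() = Chain(Dense(3 => 32, tanh), Dense(32 => 32, tanh),
                       Dense(32 => 32, tanh), Dense(32 => 32, tanh),
                       Dense(32 => 1))

# Hypothetical placeholder loss: like the real objective, it is a plain sum of
# per-column terms, so it can be evaluated block by block.
function toy_loss(models, X)
    Cs, Ci, Cr = models
    c = max.(Cs(X), 1f-3) .+ max.(Ci(X), 1f-3) .+ max.(Cr(X), 1f-3)
    return -sum(log.(c))
end

# Split the grid's columns into `nchunks` blocks, differentiate each block in
# its own task, then add the block losses and gradients. Only valid because
# the loss is additive over columns.
function threaded_loss_and_grad(lossfn, models, X; nchunks = Threads.nthreads())
    θ, re = Flux.destructure(models)    # flat parameter vector + rebuilder
    chunk = cld(size(X, 2), nchunks)
    tasks = map(Iterators.partition(axes(X, 2), chunk)) do cols
        Xc = X[:, cols]                 # each task gets its own copy of its block
        Threads.@spawn Zygote.withgradient(p -> lossfn(re(p), Xc), θ)
    end
    results = fetch.(tasks)
    loss = sum(r.val for r in results)
    grad = sum(r.grad[1] for r in results)   # gradients are plain vectors here
    return loss, grad, θ, re
end

# Example usage on a hypothetical random grid of 2,000 points
BLAS.set_num_threads(1)    # optional: keep BLAS threads from competing with tasks
models = (make_control(), make_control(), make_control())
X = rand(Float32, 3, 2_000)
L, g, θ, re = threaded_loss_and_grad(toy_loss, models, X; nchunks = 6)
θ .-= 1f-3 .* g            # one plain gradient-descent step on the flat vector
models = re(θ)             # rebuild (Cs, Ci, Cr) from the updated parameters
```

Flattening the three networks with `Flux.destructure` turns each block's gradient into an ordinary vector, so combining blocks is just vector addition. The first call includes compilation, so any timing comparison should use later calls. The same split-and-sum pattern could in principle run over cluster workers (for example with `Distributed.pmap`), with each worker returning its block's loss and gradient vector.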
Example code:

```julia
# (1)-(4), not shown: `using Flux`, Parameters.jl or UnPack.jl (for @unpack),
# the parameter struct Mod1, the networks Cs, Ci, Cr and the horizon T

# (5) Build law of motion; the state X has 3 rows and one column per grid point
@unpack α, β, θ, φ, ud, π1, π2, π3, πr, πd, ν = Mod1
function lom(X, Cs, Ci, Ns, Ni)
    S, I, R = X[1:1, :], X[2:2, :], X[3:3, :]   # 1×N row slices of the state
    newinf = π1 .* Cs .* Ci .* S .* I .+ π2 .* Ns .* Ni .* S .* I .+ π3 .* S .* I
    Snext = S .- newinf
    Inext = (1 - πr - πd) .* I .+ newinf
    Rnext = R .+ πr .* I
    return [Snext; Inext; Rnext]
end

# (6) Utility function
u(c, n) = log(c) - θ / 2 * n^2

# (7) Build objective function
function loss(x)
    X = x
    N = size(X, 2)
    U  = zeros(Float32, 1, N)   # cumulated discounted reward
    e1 = zeros(Float32, 1, N)   # non-negativity penalties, one per control
    e2 = zeros(Float32, 1, N)
    e3 = zeros(Float32, 1, N)
    for i in 1:T
        cs_u, ci_u, cr_u = Cs(X), Ci(X), Cr(X)
        # Non-negativity: clamp consumption at ν and penalize violations
        cs, ci, cr = max.(cs_u, ν), max.(ci_u, ν), max.(cr_u, ν)
        e1 += max.(ν .- cs_u, 1f-6)
        e2 += max.(ν .- ci_u, 1f-6)
        e3 += max.(ν .- cr_u, 1f-6)
        # Cumulate reward
        S, I, R = X[1:1, :], X[2:2, :], X[3:3, :]
        U += β^(i - 1) * (S .* u.(cs, cs ./ α) .+ I .* u.(ci, ci ./ (α * φ)) .+
                          R .* u.(cr, cr ./ α) .+ (1 .- S .- I .- R) .* ud)
        X = lom(X, cs, ci, cs ./ α, ci ./ (α * φ))
    end
    return -sum(U) + sum(e1 .^ 2) + sum(e2 .^ 2) + sum(e3 .^ 2)
end
```
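Summing the three rows of the law of motion gives a quick consistency check on the state. Here $n^s_t = c^s_t/\alpha$ and $n^i_t = c^i_t/(\alpha\varphi)$, as in the call inside the loop. The new-infection flow $\tau_t$ cancels between the first two rows, and the recovery flow $\pi_r I_t$ cancels between the last two:

$$
\tau_t = \pi_1 c^s_t c^i_t S_t I_t + \pi_2 n^s_t n^i_t S_t I_t + \pi_3 S_t I_t,
$$

$$
S_{t+1} = S_t - \tau_t, \qquad I_{t+1} = (1 - \pi_r - \pi_d) I_t + \tau_t, \qquad R_{t+1} = R_t + \pi_r I_t,
$$

$$
S_{t+1} + I_{t+1} + R_{t+1} = S_t + I_t + R_t - \pi_d I_t .
$$

So if a grid point starts with $S_0 + I_0 + R_0 = 1$, the share $1 - S_t - I_t - R_t$ weighted in the last reward term is the cumulative outflow at rate $\pi_d$.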
Any advice/guidance would be welcomed!
Best,
Honza
EDIT:
Of course, I need to parallelize in a way that still allows automatic differentiation with respect to the network parameters.