Parallelizing Loss Function

Dear All,

I would like to ask you for advice on how to parallelize a loss function in Flux. Context: I am solving a discounted infinite-horizon optimization problem (in economics). It has three controls (Cs, Ci, Cr), which I parameterize with deep neural networks (four hidden layers with 32 neurons each). To approximate the infinite-horizon problem, I use a finite-horizon approximation (let's say 150 periods).

The problem is that I need to evaluate this loss function over a large grid of points (a few thousand), and that takes a lot of time. My networks have relatively small layers, so a GPU doesn't help there. Evaluation at a single point is inherently serial and can't be parallelized, but I think I could multithread/distribute the per-point evaluations across multiple CPU cores (6 cores on my laptop) or on my university's cluster.

So, I would like to ask for advice on the best way to parallelize this type of loss function on CPUs. Should I try something like FLoops.jl? I would also like to ask whether I should try FastChain from DiffEqFlux instead of a regular Chain. Does it provide a significant speed-up, or are my networks too large for it (around 8,700 parameters per network)?
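For concreteness, the FastChain version I have in mind would look roughly like this (layer sizes follow my description above; tanh is just a placeholder activation, and the names are mine):

using DiffEqFlux

# Sketch of one policy network: the three states as inputs,
# four hidden layers of 32 neurons, scalar consumption as output.
Cs_fast = FastChain(FastDense(3, 32, tanh), FastDense(32, 32, tanh),
                    FastDense(32, 32, tanh), FastDense(32, 32, tanh),
                    FastDense(32, 1))
pCs = initial_params(Cs_fast)    # explicit parameter vector
# Unlike a Chain, a FastChain is called with its parameters: Cs_fast(x, pCs)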

Example code:

#(5) Build law of motion
@unpack α,ℬ,θ,φ,𝖉,π1,π2,π3,πr,πd,ν = Mod1
# 𝛀 is a 3×N matrix of states (rows: S, I, R shares; one column per grid point)
function 𝓗(𝛀,Cs,Ci,Ns,Ni)
    𝓢 = 𝛀[1,:]' - π1.*Cs.*Ci.*𝛀[1,:]'.*𝛀[2,:]' - π2.*Ns.*Ni.*𝛀[1,:]'.*𝛀[2,:]' - π3.*𝛀[1,:]'.*𝛀[2,:]'
    𝓘 = (1-πr-πd).*𝛀[2,:]' + π1.*Cs.*Ci.*𝛀[1,:]'.*𝛀[2,:]' + π2.*Ns.*Ni.*𝛀[1,:]'.*𝛀[2,:]' + π3.*𝛀[1,:]'.*𝛀[2,:]'
    𝓡 = 𝛀[3,:]' + πr.*𝛀[2,:]'
    𝞨 = [𝓢;𝓘;𝓡]
    return 𝞨
end

#(6) Utility function
u(c,n) = log(c) - θ/2*n^2

#(7) Build objective function
function 𝓞(x)
    𝛀 = x
    𝓤 = zeros(Float32,1,size(𝛀,2))
    𝓮1 = zeros(Float32,1,size(𝛀,2))
    𝓮2 = zeros(Float32,1,size(𝛀,2))
    𝓮3 = zeros(Float32,1,size(𝛀,2))
    for i in 1:𝓣                 # 𝓣 = number of periods in the finite-horizon approximation
        cs_u = Cs(𝛀)             # Cs, Ci, Cr are the three policy networks
        ci_u = Ci(𝛀)
        cr_u = Cr(𝛀)
        # Non-negativity: clamp consumption at the floor ν
        cs = max.(cs_u,ν)
        ci = max.(ci_u,ν)
        cr = max.(cr_u,ν)
        # Accumulate penalties for violating the non-negativity constraints
        𝓮1 += max.(-cs_u.+ν,0.0000010f0)
        𝓮2 += max.(-ci_u.+ν,0.0000010f0)
        𝓮3 += max.(-cr_u.+ν,0.0000010f0)
        # Cumulate discounted reward
        𝓤 += ℬ^(i-1)*(𝛀[1,:]'.*u.(cs,cs./α) + 𝛀[2,:]'.*u.(ci,ci./(α*φ))
        + 𝛀[3,:]'.*u.(cr,cr./α) + (1 .- 𝛀[1,:]'-𝛀[2,:]'-𝛀[3,:]').*𝖉)
        𝛀 = 𝓗(𝛀,cs,ci,cs./α,ci./(α*φ))
    end
    return -sum(𝓤) + sum(𝓮1.^2) + sum(𝓮2.^2) + sum(𝓮3.^2)
end
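For reference, getting the gradient currently looks something like this (schematically; grid is my 3×N matrix of grid points), and this single call is what I would like to spread across cores:

ps = Flux.params(Cs, Ci, Cr)                # parameters of all three networks
gs = Zygote.gradient(() -> 𝓞(grid), ps)    # gradient of the loss w.r.t. ps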

Any advice/guidance would be welcome!

Best,
Honza

EDIT:

Of course, I need to parallelize in a way that still allows automatic differentiation w.r.t. the network parameters.


Supporting AD in FLoops (both sequential and parallel) or in any Transducers.jl-related package should be straightforward (I think), as long as the user-defined loop bodies and functions are themselves AD-able and the accumulator is type-stable (the latter assumption can be removed with some effort). It's been on a (long) want-to-do list of mine, but I haven't had the time to try it out. Nothing is there yet ATM.

Just FYI, how effective parallelization can be depends on the ratio of serial to parallelizable work (ref: Amdahl's law).
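For example, with 6 cores, if 90% of the work parallelizes, the best possible speedup is 1/(0.1 + 0.9/6) = 4x.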


@tkf Thank you very much! I will try FLoops! My objective function is just a sum/loop over pretty simple functions of the neural networks, so I think there shouldn't be problems in that direction.

So, should I simply write an inner function that performs the sequential loop (the evaluation at one grid point), and then use FLoops to map it over the array of grid points? Roughly like the sketch below.
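Just a sketch of what I mean; loss_one_point would be the serial 𝓣-period rollout for a single grid point, a helper I have not written yet:

using FLoops

# Hypothetical structure: parallel sum over grid points,
# with the serial per-point rollout hidden inside loss_one_point.
function loss_all(points)             # points: 3×N matrix of grid points
    @floop for x in eachcol(points)
        @reduce(total += loss_one_point(x))
    end
    return total
end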

Best,
Honza

Wait no, please don't try :slight_smile:. There is no AD support in FLoops. Sorry if my explanation was unclear.

(I meant to say that I can make it work if I tweak FLoops.jl. But users of FLoops.jl can't.)


@tkf Thank you! Is there some other tool that would allow differentiation through multithreaded/distributed code, or do I have to code it by hand?
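Or, since the loss is just a sum over grid points, would it be reasonable to split the grid into chunks, take the gradient of each chunk's loss in its own task, and add the gradients up (the gradient of a sum is the sum of the gradients)? A rough, untested sketch of what I mean (I am not sure Zygote's implicit-parameter mode is thread-safe, so this may be naive):

using Flux, Zygote
using Base.Threads: @spawn, nthreads

# Split the grid columns into chunks, differentiate the per-chunk loss in
# parallel tasks, and sum the resulting gradients. Assumes every parameter
# receives a gradient in every chunk (true here, since each chunk evaluates
# all three networks).
function threaded_gradient(loss, ps, 𝛀; nchunks = nthreads())
    chunks = Iterators.partition(1:size(𝛀, 2), cld(size(𝛀, 2), nchunks))
    tasks = [@spawn(gradient(() -> loss(𝛀[:, c]), ps)) for c in chunks]
    grads = fetch.(tasks)             # one Zygote.Grads per chunk
    total = grads[1]
    for g in grads[2:end], p in ps
        total[p] .+= g[p]             # accumulate gradients in place
    end
    return total
end

# Usage (schematically): ps = Flux.params(Cs, Ci, Cr)
#                         gs = threaded_gradient(𝓞, ps, grid)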