Attempted to port an RL algorithm from PyTorch to Flux and it's 10x slower

I ported an algorithm described in the book Grokking Deep Reinforcement Learning (which I recommend if you’re a beginner to RL) from PyTorch to Julia. This algorithm is described in chapter 8 of the book, and it is meant to be a very basic RL algorithm which future chapters build on. The algorithm is called “neural fitted Q-iteration”.

The PyTorch code can be found here: gdrl/chapter-08.ipynb at master · mimoralea/gdrl · GitHub

My Julia / Flux implementation can be found here: gdrl-with-flux/Ch8.jl at 5611b4216ef941f2d89636e8783c56236d70da41 · DevJac/gdrl-with-flux · GitHub

I’ve run that notebook on my computer and it solves the RL environment in about 3000 episodes or less, which takes about 5 minutes. The Julia version takes about an hour to run 3000 episodes, with almost all of the time being spent on line 57: qs = q.network(all_s). Both implementations were run on the CPU. I do not have a GPU.

Again, profiling showed all the time being spent on line 57: qs = q.network(all_s); network is a Flux model with 3 Dense layers. This is the only forward pass through the model that tracks gradients, which I use a few lines later to update the model parameters. Other forward passes through the model are much quicker, which is to be expected because they aren’t tracking gradients.

I believe my algorithm is exactly the same as the linked PyTorch implementation, unless there is some difference between similar library concepts; for example, if the RMSProp optimizers are slightly different in PyTorch vs Flux. I’ve left out some logging and skipped a few “design patterns”, but every other detail is the same as best I can tell.
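For reference, I believe the two libraries’ RMSProp defaults do differ slightly (PyTorch defaults to lr=0.01 and alpha=0.99, while Flux defaults to η=0.001 and ρ=0.9), though I wouldn’t expect that to explain a 10x slowdown. If I wanted them to match I could construct Flux’s optimizer explicitly, something like:

```julia
using Flux

# Mirror PyTorch's torch.optim.RMSprop defaults in Flux's RMSProp.
# Flux's positional arguments are (η, ρ); values here are PyTorch's
# defaults (lr=0.01, alpha=0.99), not the ones used in my actual run.
opt = RMSProp(0.01, 0.99)
```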

Any ideas why this code is so much slower in Julia / Flux?


One thing is you aren’t very consistent about Float32 vs Float64: literals like 1.0 and 0.0 in your code are all Float64 constants. Writing everything as Float32, i.e. 1f0 and 0f0, is probably a good idea. With all of this switching around, you might be going backwards through the network in Float64, which would be one hit. Other than that, the issue might be your BLAS. Have you tried it with MKL.jl to see if it’s an OpenBLAS vs MKL issue?
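For example, here is a quick way to see the promotion happen (a toy check, not your actual code):

```julia
x = rand(Float32, 4)

# A Float64 literal promotes the whole broadcast to Float64:
eltype(1.0 .* x)   # Float64

# A Float32 literal (the 1f0 syntax) keeps it in Float32:
eltype(1f0 .* x)   # Float32
```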


Thank you. I’ll have to check out MKL.jl. Does it make a permanent change to my Julia install? (Wouldn’t be a problem, I can reinstall easily enough.) Should I use the 64-bit MKL, since I’m not using numpy in this Julia program?

I’ll try using Float32 everywhere and see if that changes anything. My profiling showed 96% of the time spent on line 57 specifically, so I was ignoring all performance problems that didn’t directly affect that line. Still, I’ve seen in Python that profiling data sometimes gets ascribed to the wrong line. Perhaps the same is happening here in Julia? The profiling samples for that gradient closure might all be getting assigned to its first line, even though it’s actually the other lines using the time? Maybe? I hadn’t considered this.

That makes sense it’s that line. Making sure it’s done in Float32 and using a fast BLAS is really what you need here.

My optimization loop was taking about 10 seconds with OpenBLAS. After installing MKL.jl it was only taking about 8 seconds. A good improvement, I’m happy to learn about MKL.jl.

However, I timed the PyTorch optimization loop and it takes only 0.2 seconds. So I’m still unable to get anything near PyTorch performance out of Julia / Flux; right now Julia is performing about 40x to 50x slower.

I changed my update loop as follows, and then got similar performance from Flux as I did from PyTorch. In other words, this change sped up my code about 40x.

I changed:

function update(q, sars, k=40)
    all_s = convert(Array{Float32,2}, reduce(hcat, map(x -> copy(x.s), sars)))
    all_s′ = reduce(hcat, map(x -> copy(x.s′), sars))
    all_a = reduce(hcat, map(x -> x.a == 0 ? [1.0, 0.0] : [0.0, 1.0], sars))
    all_r = reduce(hcat, map(x -> x.r, sars))
    all_f = reduce(hcat, map(x -> x.f ? 0.0 : 1.0, sars))
    prms = params(q)
    for k in 1:k
        target = all_r + 1.0 * (maximum(q.network(all_s′), dims=1)) .* all_f
        grads = gradient(prms) do
            qs = q.network(all_s)
            predicted = sum(qs .* all_a, dims=1)
            mean((predicted .- target).^2)
        end
        update!(opt, prms, grads)
    end
end

to:

function do_update(q, s, a, r, s′, f, k=40)
    for _ in 1:k
        target = r + 1.0 * (maximum(q.network(s′), dims=1)) .* f
        grads = gradient(params(q)) do
            qs = q.network(s)
            predicted = sum(qs .* a, dims=1)
            mean((predicted .- target).^2)
        end
        update!(opt, params(q), grads)
    end
end

function update(q, sars, k=40)
    all_s = convert(Array{Float32,2}, reduce(hcat, map(x -> copy(x.s), sars)))
    all_s′ = convert(Array{Float32,2}, reduce(hcat, map(x -> copy(x.s′), sars)))
    all_a = convert(Array{Float32,2}, reduce(hcat, map(x -> x.a == 0 ? [1.0, 0.0] : [0.0, 1.0], sars)))
    all_r = convert(Array{Float32,2}, reduce(hcat, map(x -> x.r, sars)))
    all_f = convert(Array{Float32,2}, reduce(hcat, map(x -> x.f ? 0.0 : 1.0, sars)))
    @assert typeof(all_s) == Array{Float32, 2}
    @assert typeof(all_s′) == Array{Float32, 2}
    @assert typeof(all_a) == Array{Float32, 2}
    @assert typeof(all_r) == Array{Float32, 2}
    @assert typeof(all_f) == Array{Float32, 2}
    do_update(q, all_s, all_a, all_r, all_s′, all_f)
end

As far as I can see the main difference is the conversion to Float32 across the board. I can totally see why that would improve things, but why 40 times? That just seems weird to me. Thanks for sharing the two code bases to compare; I’ll definitely read them carefully. 🙏

Those conversion lines with reduce(hcat, map(...)) look really inefficient. Are they really necessary? What is sars here? It seems to be a vector with element type SARSF, is that right?

Assuming I don’t misunderstand, this seems like a mistake:

sars = SARSF[]

It creates a vector of abstract element type. You should specify, if you can, the type parameters S and A. Here’s an example timing, where I’m just wildly guessing at the types you are working with:

sars = SARSF[]
for i in 1:100
    push!(sars, SARSF(rand(-10:10), rand(Bool), rand(Float32), rand(-100:100), rand(Bool)))
end

julia> @btime convert(Array{Float32,2}, reduce(hcat, map(x -> copy(x.s), $sars)));
  480.201 μs (3435 allocations: 180.52 KiB)

Alternative. Here we get concrete element type:

 sars_ = [SARSF(rand(-10:10), rand(Bool), rand(Float32), rand(-100:100), rand(Bool)) for i in 1:100];

julia> @btime Float32.(getproperty.($sars_, :s));
  2.889 μs (101 allocations: 5.17 KiB)

Maybe the kind of work being done is different, but you get an idea about a different approach, I hope. The performance difference is not all about the concrete element type, but it also avoids a huge amount of unnecessary allocations of intermediate arrays.


Good going. This eltype thing is, unfortunately, a frequent problem. Would be nice if it could warn you.

Note that target is still going to be Float64 here, as it has a 1.0 literal, so possibly the gradients will be too:

You can probably also write the conversions more compactly, something like Float32.(reduce(hcat, map(x -> x.s, sars))); no need to copy there, I believe. (As DNF says, too, I see.)

These asserts seem meaningless. If you convert, the types will be correct, or an error will be thrown, there is no reason to do assert. Not that it hurts, it’s just meaningless and noisy.


Why not use a StructArray for sars? That should obviate any need for complicated conversions.
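A rough sketch of what I mean, guessing at the SARSF field types (names and shapes here are illustrative, not from your code):

```julia
using StructArrays

# Hypothetical concrete SARSF struct; I'm guessing at the field types.
struct SARSF
    s::Vector{Float32}
    a::Int
    r::Float32
    s′::Vector{Float32}
    f::Bool
end

sars = StructArray{SARSF}(undef, 0)
push!(sars, SARSF(rand(Float32, 4), 0, 1f0, rand(Float32, 4), false))

# Column access is free; no map over the structs needed:
all_r = sars.r                 # Vector{Float32}
all_s = reduce(hcat, sars.s)   # state vectors still need one hcat
```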

I certainly don’t understand all the code here, but to me it looks like you should be able to get a hundred-fold, or even much more, speedup of the update function. Depending, of course, on what’s going on inside do_update, which I cannot grok.

(In fact, by using an explicit loop, with no multithreading, I am able to speed up this line

convert(Array{Float32,2}, reduce(hcat, map(x -> copy(x.s), S)))

by a factor of 2500).

I am not sure I understand what is really going on, but it seems to me that since you did not pass the “k” parameter to the do_update function call, inside that function k is always its default value of 40, whatever value of “k” was passed to the update function. Is this a potential problem?

Guess I’ll explain the algorithm since it will help cement what I’m trying to learn.

In reinforcement learning an agent attempts to maximize the rewards which are given by the environment. There is the concept of a Q function which takes a state and action and returns the expected future returns for taking the action in that state. Q is what I’m trying to approximate with Flux.

For example, in Poker your state might be 2 aces in your hand, and your actions are either to fold or call, etc. Q is a conceptual thing (although it can sometimes be calculated exactly), but if we can approximate it well enough, then Q can tell us which action maximizes our rewards in any situation, and that’s the entire purpose of reinforcement learning.

In my code Q is Q. Q is trying to approximate the expected future rewards for a state and action.

In a RL environment everything that happens can be encoded as a sequence of (s, a, r, s′) tuples, where s is the state, a is the action taken in state s, r is the reward immediately received, and s′ is the next state.

The update function implements an algorithm in which, for each (s, a, r, s′) tuple, you update Q(s, a) to more closely match r + Q(s′, best_a_for_s′). In words: we update our function approximation so that the value of a state and action more closely matches the immediate reward plus the expected future rewards.
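In toy Julia, with a stand-in linear “network” instead of my actual Flux model (all names and shapes here are illustrative), the target computation looks like:

```julia
W = rand(Float32, 2, 4)   # 2 actions, 4 state features
Q(s) = W * s              # stand-in for q.network

s′ = rand(Float32, 4, 8)  # batch of 8 next-states, one per column
r  = rand(Float32, 1, 8)  # immediate rewards
f  = ones(Float32, 1, 8)  # 0 where the episode ended, 1 otherwise
γ  = 1f0                  # discount factor

# r + γ * max over actions of Q(s′), masked to zero at terminal states:
target = r .+ γ .* maximum(Q(s′), dims=1) .* f
```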

It’s complicated. One cool thing is that the function approximator is “bootstrapping” itself: the neural network generates its own training data! Training cycles like this seem common in reinforcement learning, although most algorithms are more sophisticated, with a larger cycle, and for good reason: as you can imagine, a neural network generating its own training data isn’t always stable.


And yes, I know there are a variety of sub-optimal practices in this code. It’s been through a lot of revision while I tried to get it to show even the smallest signs of working. I’ve given no effort to optimizing code unless profiling shows it is a hot spot.

That sounds pretty cool.

There’s no intention to criticize the code; it’s just that since you are specifically trying to compare performance to PyTorch, it seems reasonable to point out that there are still potentially enormous speed-ups to be found.


Yeah. Now that I’ve found this 40x speedup, I have reason to improve other areas of the code. As it was, whatever was causing that huge slowdown was dominating all other concerns, but now I think these other suggestions are worth considering.

I’m still scratching my head a bit about my profiling results. It was showing all time being spent on one line, and I haven’t changed that line or the data/types of the variables used on that line.

    grads = gradient(prms) do
        qs = q.network(all_s)
        predicted = sum(qs .* all_a, dims=1)
        mean((predicted .- target).^2)
    end

Profiling showed all time being spent on qs = q.network(all_s). I wonder if Julia inlined or otherwise rewrote the 3 lines of that closure and the profiling was ascribing all the time to the first line? The changes I made definitely affect the latter 2 lines, but not the first.

I would have guessed that most of the time is in the gradient of q.network(all_s) alone, that has lots of big matrices. Changing the datatype of target and all_a may change the types in the gradient there, they flow backwards via qs.

How well profiling can disentangle this I don’t know. Zygote certainly re-writes this. You could try timing just one gradient call, and one forwards evaluation, to see?
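Something like this, with a small stand-in model of roughly the right shape (sizes are guesses, not your actual network):

```julia
using Flux, Statistics

# Stand-in 3-layer Dense network; input and layer sizes are guesses.
model = Chain(Dense(4, 32, relu), Dense(32, 32, relu), Dense(32, 2))
x = rand(Float32, 4, 512)

@time model(x)                                                # forward only
@time gradient(() -> mean(model(x).^2), Flux.params(model))   # forward + backward
```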

To wrap up this thread: In the end I’m very happy with the performance of Julia and Flux. The core learning loop, the forward and backward pass through the network, performs similarly in both PyTorch and Flux.

However, overall Julia is much faster. I suspect this is because all the in-between work, which happens in either plain Python or plain Julia, is much faster in Julia. I guess it’s no surprise that plain Julia is faster than plain Python, but in this case there is a lot of “book-keeping” going on outside of the actual learning loop, and Julia really seems to shine there.

Start to end, my laptop can train a Python / PyTorch solution in about 5 minutes. It takes about 90 seconds in Julia / Flux, and that’s even using foreign calls to the OpenAIGym Python library.
