For loop performance vs "functional" performance

Part 2 of the day 19 problem of AoC 2024 requires counting the number of ways a string can be partitioned into pieces taken from a given set of substrings.
This is repeated for a list of strings, and the counts are summed to obtain the total number of admissible partitions.
I understood that, as with other problems of this type, brute-force solutions do not work and a "smart" use of recursion works better. In this case, even though the number of cases to handle grows at each step, you can group those that share the same "evolution" and carry forward only one "representative" per group, so that the number of cases to manage does not grow explosively.
I compared two forms of one solution to this problem and found performance differences that I cannot explain intuitively.
I wonder why, of the two functions f1(…) and f1a(…), the one that uses a for loop (f1a) is slower, even though it allocates less.
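The grouping idea described above can be sketched as a standard dynamic program. This is not the poster's f1/f1a (which work on suffixes, starting from the end of the string) but a forward variant of the same idea: keep one counter per prefix length, so all partial partitions covering the same prefix collapse into a single number. count_partitions and towels are hypothetical names, and the code assumes ASCII input so that byte and character indices coincide.

```julia
# Forward dynamic programming over prefix lengths (hypothetical helper).
# ways[k+1] holds the number of partitions of the first k characters, so all
# partial partitions covering the same prefix are merged into one counter.
# Assumes ASCII input, so byte indices and character indices coincide.
function count_partitions(design::AbstractString, towels)
    n = ncodeunits(design)
    ways = zeros(Int, n + 1)
    ways[1] = 1                                  # the empty prefix has one (empty) partition
    for k in 1:n
        prefix = SubString(design, 1, k)
        for t in towels
            m = ncodeunits(t)
            if m <= k && endswith(prefix, t)
                ways[k + 1] += ways[k - m + 1]   # extend every partition of the shorter prefix by t
            end
        end
    end
    return ways[n + 1]
end

towels = ["r", "wr", "b", "g", "bwu", "rb", "gb", "br"]
count_partitions("gbbr", towels)   # → 4
```

Each position is visited once and each towel is tested once per position, so the work grows roughly with length(design) × length(towels) rather than with the number of partitions.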

From what I can tell, the form with comprehension and broadcasting

        t=length.(ts[findall(endswith.(dsi[1:end-ss],ts))])
        cm[t.+ss].+=cm[ss]

makes more passes over the vectors than the for loop does:

        for e in ts
            if endswith(dsi[1:end-ss],e)
                cm[length(e)+ss]+=cm[ss]
            end
        end

so I would have expected the loop to be faster as well as to allocate less.

f1(...)
st,sd=open("input_aoc24_d19.txt") do f 
    split(read(f, String), "\n\n")
end

ts=sort(split(st,", "),by=length)
ds=split(sd,'\n')


function f1(ds,ts,i)
    dsi=ds[i]
    l=length(dsi)
    cm=zeros(Int,l)
    t=length.(ts[findall(endswith.(dsi[1:end],ts))])
    cm[t].=1 # cm[t].=1 also works with an empty t (why?)
    isempty(t) && return 0
    ss=t[1]
    while true
        ss=findnext(!=(0),cm[1:end-1],ss) 
        (isnothing(ss)||ss==l)  && return cm[end]
        t=length.(ts[findall(endswith.(dsi[1:end-ss],ts))])
        cm[t.+ss].+=cm[ss]
        cm[ss]=0
    end
end


using BenchmarkTools
@btime cm=f1(ds,ts,1)
@btime sum(f1(ds,ts,i) for i in  1:length($ds);init=0) # 666491493769758
f1a(...)
function f1a(ds,ts,i)
    dsi=ds[i]
    l=length(dsi)
    cm=zeros(Int,l)
    t=length.(ts[findall(endswith.(dsi[1:end],ts))])
    cm[t].=1 # cm[t].=1 also works with an empty t (why?)
    isempty(t) && return 0
    ss=t[1]
    while true
        ss=findnext(!=(0),cm[1:end-1],ss) 
        (isnothing(ss)||ss==l)  && return cm[end]
        for e in ts
            if endswith(dsi[1:end-ss],e) 
                cm[length(e)+ss]+=cm[ss]
            end
        end
        cm[ss]=0
    end
end

using BenchmarkTools
@btime cm=f1a(ds,ts,1)
@btime sum(f1a(ds,ts,i) for i in  1:length($ds);init=0) # 666491493769758
julia> @btime cm=f1(ds,ts,1)
  161.400 μs (468 allocations: 273.02 KiB)
178604257966

julia> @btime sum(f1(ds,ts,i) for i in  1:length($ds);init=0) # 666491493769758
  69.592 ms (140919 allocations: 79.66 MiB)
666491493769758

julia> @btime cm=f1a(ds,ts,1)
  359.400 μs (60 allocations: 30.28 KiB)
178604257966

julia> @btime sum(f1a(ds,ts,i) for i in  1:length($ds);init=0) # 666491493769758
  130.744 ms (19415 allocations: 8.94 MiB)
666491493769758

P.S. Don't rule out the possibility that I measured badly.


The difference is that in the broadcasting version you create dsi[1:end-ss] once and use it for all elements in ts, whereas in the loop version you create it for every element.

Using AoC2024/day19/input.txt from asnoyman/AoC2024 on GitHub as the input file (it does not seem to be the same one you are using, since the output differs), with f1 and f1a as above, and f1b similar to f1a but with

dsi_end = dsi[1:end-ss]
for e in ts
    if endswith(dsi_end,e) 
        cm[length(e)+ss]+=cm[ss]
    end
end

I get

julia> @btime sum(f1($ds, $ts, i) for i in 1:length($ds); init=0)  # broadcasted version
  65.037 ms (250735 allocations: 17.22 MiB)
758890600222015

julia> @btime sum(f1a($ds, $ts, i) for i in 1:length($ds); init=0)  # dsi[1:end-ss] inside the loop
  101.799 ms (37804 allocations: 7.92 MiB)
758890600222015

julia> @btime sum(f1b($ds, $ts, i) for i in 1:length($ds); init=0)  # dsi[1:end-ss] outside of the loop
  61.943 ms (37804 allocations: 7.92 MiB)
758890600222015

Without looking at both versions in detail, I can at least tell that you are allocating lots of temporary arrays for no obvious reason throughout f1.

This one is particularly egregious:

        t=length.(ts[findall(endswith.(dsi[1:end-ss],ts))])

The above line creates five arrays, as far as I can tell.
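To make the count concrete, here is the line t=length.(ts[findall(endswith.(dsi[1:end-ss],ts))]) from f1 split into its individual steps, each of which allocates a new object (one of them a String rather than an array). It is followed by a hypothetical helper, matching_lengths! (not from the thread), that collects the same lengths into a reusable buffer with a single loop. The towels and string are made-up example data.

```julia
# The one-liner from f1, split into its individual steps (example data is made up).
dsi = "brwrr"
ts  = ["r", "wr", "b", "g", "bwu", "rb", "gb", "br"]
ss  = 1

prefix = dsi[1:end-ss]           # 1. new String: a copy of the first bytes of dsi
mask   = endswith.(prefix, ts)   # 2. BitVector with one Bool per towel
idx    = findall(mask)           # 3. Vector{Int} of matching positions
sel    = ts[idx]                 # 4. Vector{String} of the matching towels
t      = length.(sel)            # 5. Vector{Int} of their lengths

# Hypothetical alternative: one loop writing into a caller-provided buffer,
# which can be reused across calls instead of allocating steps 2 to 5.
function matching_lengths!(out::Vector{Int}, prefix, ts)
    empty!(out)
    for e in ts
        endswith(prefix, e) && push!(out, length(e))
    end
    return out
end

buf = Int[]
matching_lengths!(buf, prefix, ts)   # → [1, 2]
```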

You are apparently returning a single value. Can’t you do this without allocations, or at most maintain a single work buffer?


After introducing the change suggested by @eldee, the situation doesn't change much.
Were you thinking of something different?

julia> function f1(ds,ts,i)
           dsi=ds[i]
           l=length(dsi)
           cm=zeros(Int,l)
           t=length.(ts[findall(endswith.(dsi[1:end],ts))])
           cm[t].=1 # cm[t].=1 also works with an empty t (why?)
           isempty(t) && return 0
           ss=t[1]
           while true
               ss=findnext(!=(0),cm[1:end-1],ss)
               (isnothing(ss)||ss==l)  && return cm[end]
               dsie=dsi[1:end-ss]
               t=length.(ts[findall(endswith.(dsie,ts))])
               cm[t.+ss].+=cm[ss]
               cm[ss]=0
           end
       end
f1 (generic function with 1 method)

julia> @btime sum(f1(ds,ts,i) for i in  1:length($ds);init=0) # 666491493769758
  69.217 ms (140919 allocations: 79.66 MiB)
666491493769758

Shouldn't the compiler recognize that dsi[1:end-ss] is constant with respect to the loop variable and avoid repeating the computation?
Furthermore, the penalty shows up in the execution time, not in the allocations.

Why do you only interpolate ds inside length, but neither ds nor ts inside f1? (Both are non-const globals, at least in your original code, which could degrade performance.)

This is why I looked for an improvement using a for loop. My question is why this tangled sequence of nested function calls performs better than the for loop.

There shouldn't be any change, as in f1, dsi[1:end-ss] is only evaluated once (per iteration of the while true loop). To reduce allocations, it is indeed better to loop manually, as in f1a. I guess @DNF should have also looked at that version.

I don't know much about compilers, but I wouldn't say it's immediately obvious that you can just move dsi[1:end-ss] outside of the loop. If you were to replace endswith with some mutating function, the versions with the indexing copy inside and outside the loop would behave differently.
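A toy illustration of that point, using a hypothetical mutating predicate pop_matches! (nothing like it appears in the thread): when the loop body mutates the copy, creating the copy inside the loop and hoisting it out of the loop give different results. So a compiler cannot move such a copy in general, unless it can prove the body never mutates it.

```julia
# Hypothetical mutating predicate: removes the last character of buf and
# reports whether it equals c.
function pop_matches!(buf::Vector{Char}, c::Char)
    isempty(buf) && return false
    return pop!(buf) == c
end

# Copy made inside the loop: every iteration starts from a fresh copy of s.
function copy_inside(s::AbstractString, cs)
    hits = 0
    for c in cs
        buf = collect(s)
        hits += pop_matches!(buf, c)
    end
    return hits
end

# Copy hoisted out of the loop: the mutations accumulate across iterations.
function copy_hoisted(s::AbstractString, cs)
    hits = 0
    buf = collect(s)
    for c in cs
        hits += pop_matches!(buf, c)
    end
    return hits
end

copy_inside("abc", "ccc"), copy_hoisted("abc", "ccc")   # → (3, 1)
```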

It's indeed interesting that there's no difference in allocations between creating dsi[1:end-ss] outside of the loop or inside it. Also, endswith(@view(dsi[1:end-ss]), e) results in the same execution time and allocations as without the view. Perhaps the compiler can figure out that it's okay to always reuse the memory allocated for dsi[1:end-ss]? Then again, I would expect the @view approach to save one allocation per iteration of the while true loop.
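For reference, here is what the two forms produce for a String: indexing with a range copies the selected bytes into a new String, while @view returns a SubString{String} that refers to the original bytes. This small snippet (with made-up data) only shows the types; it does not explain why the allocation counts in the benchmarks above come out the same.

```julia
dsi = "brwrr"
ss = 1

copied = dsi[1:end-ss]         # String: the selected bytes are copied into a new string
viewed = @view dsi[1:end-ss]   # SubString{String}: refers to dsi's existing bytes

typeof(copied), typeof(viewed)  # → (String, SubString{String})
copied == viewed                # → true
endswith(viewed, "wr")          # → true
```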


As a side-note, in general it’s not always the case that fewer allocations correspond to faster performance, because of e.g. memory access patterns / caching. Additionally, when reducing the allocations does help, it need not help much. E.g.

function f1b_noallocs(dsi, ts, c)  # c is a Vector{Int} of length >= length(dsi)
    l=length(dsi)
    cm = @view(c[1:l])
    cm .= 0
    first_t = -1
    for e in ts
        if endswith(dsi, e)
            if first_t <= 0
                first_t = length(e)
            end
            cm[length(e)] = 1
        end
    end
    first_t >= 1 || return 0
    ss = first_t
    while true
        ss=findnext(!=(0), @view(cm[1:end-1]), ss)
        (isnothing(ss)||ss==l)  && return cm[end]
        dsi_end = @view dsi[1:end-ss]
        for e in ts
            if endswith(dsi_end, e)
                cm[length(e) + ss] += cm[ss]
            end
        end
        cm[ss]=0
    end
end

function sumf1b_minallocs(ds, ts)
    c = zeros(Int, maximum(length(dsi) for dsi in ds))
    return sum(f1b_noallocs(dsi, ts, c) for dsi in ds)
end
julia> @btime sum(f1b($ds, $ts, i) for i in 1:length($ds); init=0)
  63.070 ms (37804 allocations: 7.92 MiB)
758890600222015

julia> @btime sumf1b_minallocs($ds, $ts)
  56.673 ms (2 allocations: 576 bytes)
758890600222015

Another method, not as fast (5x slower), but the code looks simpler:

function countways(s,ts)
    ( L = length(s) ) == 0 && return 0
    M = zeros(Int, L+1, L+1)
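    for i in 1:maximum(length, ts)
        P = filter(==(i)∘length, ts)
        isempty(P) && continue  # an empty pattern would match "" everywhere and add spurious self-loops
        for m in eachmatch(Regex(join(P, '|')), s; overlap=true)
            M[m.offset, m.offset+length(m.match)] = 1
        end
    end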
    M[L+1,L+1] = 1
    return (M^nextpow(2,L))[1,L+1]
end

This can be optimized a bit further (e.g. by constructing the Regex objects only once for a batch of query strings s).
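One possible version of that optimization, as a sketch: build_regexes and countways_prebuilt are hypothetical names, only towel lengths that actually occur get a Regex (which also avoids empty patterns), and the towels are assumed to contain no regex metacharacters.

```julia
using LinearAlgebra

# Hypothetical helper: one Regex per towel length that actually occurs,
# built once and reused for every design. Assumes no regex metacharacters in ts.
function build_regexes(ts)
    lens = sort!(unique(length.(ts)))
    return [Regex(join(filter(t -> length(t) == n, ts), '|')) for n in lens]
end

# Same matrix construction as countways, but taking prebuilt regexes.
function countways_prebuilt(s, regexes)
    (L = length(s)) == 0 && return 0
    M = zeros(Int, L + 1, L + 1)
    for r in regexes, m in eachmatch(r, s; overlap = true)
        M[m.offset, m.offset + length(m.match)] = 1   # edge: towel spans s[offset:offset+len-1]
    end
    M[L + 1, L + 1] = 1
    return (M^nextpow(2, L))[1, L + 1]
end

towels = ["r", "wr", "b", "g", "bwu", "rb", "gb", "br"]
regexes = build_regexes(towels)   # built once...
sum(countways_prebuilt(d, regexes) for d in ["brwrr", "gbbr", "ubwu"])   # ...reused per design → 6
```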

Testing using:

sum(countways.(ds,Ref(ts)))

reproduces the result from f1.

P.S. In a code-golfing spirit, I’ve whittled the function down to 3 lines:

using IterTools
function countways(s,ts)
    ( M = zeros(Int, length(s)+1, length(s)+1) )[end, end]=1
    M[[CartesianIndex(m.offset, m.offset+length(m.match)) 
      for P in IterTools.groupby(length, sort(ts; by=length)) 
      for m in eachmatch(Regex(join(P, '|')), s; overlap=true)]] .= 1
    return (M^nextpow(2, length(s)))[1, end]
end

with the long-ish middle line doing most of the heavy lifting.


I would say that this observation comes closest to a complete answer to my question.
Although I still wonder whether part of the difference between f1() and f1a() could also be due to the if … end inside the for loop preventing parallelization (e.g. SIMD vectorization) that the compiler could otherwise perform, as it can with findall.().

Nice solution. Two observations:

  1. Does it make sense to move cm[ss] out of the loop, like this?

     cmss=cm[ss]
     for e in ts
         if endswith(dsi_end, e)
             cm[length(e) + ss] += cmss
         end
     end

  2. The maximum() function also has the method maximum(f, itr), so in our case

     function sumf1b_minallocs(ds, ts)
         #c = zeros(Int, maximum(length(dsi) for dsi in ds))
         c = zeros(Int, maximum(length, ds))
         return sum(f1b_noallocs(dsi, ts, c) for dsi in ds)
     end

     also works.


I used regexes to solve part 1 of day 19 interactively.
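The thread does not show that regex, but a minimal sketch of one way to do it is to check whether a design is a full concatenation of towels, i.e. whether it matches ^(?:t1|t2|…)*$. possible_regex and the example data are hypothetical, and the towels are assumed to contain only letters. Note that PCRE backtracking may become slow on long designs that cannot be built.

```julia
# Hypothetical helper: a design is possible iff the whole string matches
# ^(?:t1|t2|...)*$. Assumes the towels contain no regex metacharacters.
possible_regex(towels) = Regex("^(?:" * join(towels, '|') * ")*\$")

towels  = ["r", "wr", "b", "g", "bwu", "rb", "gb", "br"]
designs = ["brwrr", "bggr", "gbbr", "rrbgbr", "ubwu", "bwurrg", "brgr", "bbrgwb"]
r = possible_regex(towels)
count(d -> occursin(r, d), designs)   # → 6
```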

Now I have looked more carefully at your solution and I would like a confirmation and an explanation.
The nextpow(2, …) call is used to speed up the calculation of M^L, right?
Can you explain in some detail (with a minimal example?) the role of M in counting the partitions? Or simply point me to a reference where I can find clarification on the matter.

Yeah, that makes sense. In all likelihood it won’t actually improve performance because cm[ss] is cached or maybe even remains stored in a register. But it certainly should never hurt.

In general I would advise to just test both versions and stick with the one with the ‘best’ combination of speed / legibility. (In this case both versions will be fine.)

Yes.

The idea is to build a graph whose nodes are the character positions (plus one node past the last character), with an edge from i to j whenever the substring s[i:j-1] matches one of the allowed towels. With this graph, the goal is to count the paths from the first character to the end. To allow paths of varying length, the final node gets a self-edge via M[end, end] = 1.

Matrix multiplication then counts the paths through the usual correspondence: (M^k)[i, j] is the number of walks of exactly k edges from node i to node j.
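A minimal example of that correspondence, written without regexes so the construction is explicit. position_graph and count_via_paths are hypothetical names, and the input is assumed to be ASCII. Because every edge moves strictly forward except the self-loop on the final node, each partition corresponds to exactly one walk of k edges for any k ≥ length(s), which is why M^L and M^nextpow(2, L) give the same count.

```julia
using LinearAlgebra

# Adjacency matrix of the "position graph" (hypothetical helper): nodes 1..L+1
# are positions between characters, and M[i, j] = 1 when s[i:j-1] is a towel.
# Assumes ASCII input.
function position_graph(s::AbstractString, ts)
    L = ncodeunits(s)
    M = zeros(Int, L + 1, L + 1)
    for i in 1:L, t in ts
        if startswith(SubString(s, i), t)
            M[i, i + ncodeunits(t)] = 1
        end
    end
    M[L + 1, L + 1] = 1   # self-loop: walks that arrive early can "wait" at the end
    return M
end

# (M^k)[1, L+1] counts walks of exactly k edges from node 1 to node L+1;
# with the self-loop, any k >= L counts every partition exactly once.
function count_via_paths(s, ts; k = ncodeunits(s))
    isempty(s) && return 0
    M = position_graph(s, ts)
    return (M^k)[1, end]
end

towels = ["r", "wr", "b", "g", "bwu", "rb", "gb", "br"]
count_via_paths("rrbgbr", towels)          # → 6
count_via_paths("rrbgbr", towels; k = 8)   # → 6 (nextpow(2, 6) == 8)
```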

Here’s another short approach, which in my eyes is a bit easier and should be roughly in the same ballpark of execution time:

using Memoization

function count_combinations(d, ts)  
    # ts is a list of base Strings. We count in how many ways we can combine some of them to create d.
    length(d) == 0 && return 1
    return sum(@memoize(count_combinations(@view(d[1:end-length(t)]), ts)) for t in Iterators.filter(t -> endswith(d, t), ts); init=0)
end
@btime sum(count_combinations(d, $ts) for d in ds) setup=Memoization.empty_cache!(count_combinations)
  322.416 ms (198049 allocations: 11.03 MiB)
758890600222015

Explicitly caching intermediate counts instead of relying on @memoize is quite a bit faster though, on a similar level as f1/f1b (which are also conceptually similar in approach):

function count_combinations(d, ts, cache::Dict{UInt, Int}=Dict{UInt, Int}())
    # ts is a list of base Strings. We count in how many ways we can combine some of them to create d. 
    # cache stores intermediate results for reuse. Using UInt hashes as keys is a bit faster than something like Union{String, SubString{String}}.
    length(d) == 0 && return 1  # (Alternatively, could add "" => 1 to cache)
    return get!(cache, hash(d)) do  # if hash(d) already exists as a key, look up the value, otherwise add the return value from the anonymous function below
        count = 0
        for t in ts
            if endswith(d, t)
                sub = @view d[1:end-length(t)]
                count += get!(cache, hash(sub)) do
                    count_combinations(sub, ts, cache)
                end
            end
        end
        return count
    end
end
julia> @btime begin
           sum(count_combinations(d, $ts) for d in ds)  # New cache for every d
       end
  68.000 ms (4286 allocations: 1.97 MiB)
758890600222015

julia> @btime begin
           cache = Dict{UInt, Int}()  # Reuse for all d in ds
           sum(count_combinations(d, $ts, cache) for d in ds)
       end
  65.276 ms (38 allocations: 1.42 MiB)
758890600222015
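Since every subproblem of count_combinations is a prefix of the same d, the prefix length alone identifies it. A hypothetical variant keyed by that length (count_prefix and count_by_length are not from the thread) avoids hashing substrings and any chance of two different prefixes sharing a hash key; the trade-off is that the cache cannot be shared across different designs. It assumes ASCII input.

```julia
# Hypothetical variant: every subproblem is a prefix of the same design d,
# so the prefix length k identifies it uniquely. Assumes ASCII input.
function count_prefix(d::AbstractString, ts, k::Int, cache::Dict{Int,Int})
    k == 0 && return 1                        # empty prefix: one (empty) combination
    return get!(cache, k) do
        prefix = SubString(d, 1, k)
        total = 0
        for t in ts
            m = ncodeunits(t)
            if m <= k && endswith(prefix, t)
                total += count_prefix(d, ts, k - m, cache)   # count the shorter prefix
            end
        end
        total
    end
end

# Fresh cache per design, since the keys only make sense for one d.
count_by_length(d, ts) = count_prefix(d, ts, ncodeunits(d), Dict{Int,Int}())

towels  = ["r", "wr", "b", "g", "bwu", "rb", "gb", "br"]
designs = ["brwrr", "bggr", "gbbr", "rrbgbr", "ubwu", "bwurrg", "brgr", "bbrgwb"]
sum(count_by_length(d, towels) for d in designs)   # → 16
```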

This problem seems to be full of inspiration.
I tried to better understand the use of matrix algebra in Dan's solution and, confirming my belief that this problem is structurally the same as the day 10 problem, I found a solution to day 10 (without any claim of efficiency) using the technique brilliantly explained here, which I post for future reference.

Solution to part 2 of AoC 2024 day 10:

inp=readlines("input_aoc24_d10.txt")
m=first.(stack(split.(inp,""),dims=1))   # Char matrix of heights '0'..'9'
cim=CartesianIndices(m)
li=LinearIndices(m)

# am[i,j] = 1 when neighbour j of cell i is exactly one step higher
am=zeros(Int, length(cim), length(cim))
for ci in cim
    for d in (CartesianIndex(-1,0),CartesianIndex(1,0),CartesianIndex(0,-1),CartesianIndex(0,1))
        nb=ci+d
        nb in cim || continue
        am[li[ci],li[nb]]=(m[nb]-m[ci])==1 ? 1 : 0
    end
end

# (am^9)[i,j] = number of trails of 9 uphill steps from cell i to cell j
am9=((am^2)^2)^2*am
sum(am9[i,j] for i in findall(==('0'),vec(m)) for j in findall(==('9'),vec(m)))

You could limit the search to

dsf=filter(x->!endswith(x,"brg"),ds)

although, even after removing 86 of the 400 designs, it doesn't seem to bring a big improvement.