Currying and constant propagation, Base.@pure

Dear Julians,

I have a generic model function with many parameters. I want to hold subsets of these parameters constant, and obtain well-specialized methods on a reduced parameter set.
Unfortunately, in many cases no constant propagation happens without Base.@pure. I see this macro as a last resort, since it can silently produce bugs if used incorrectly.

So, my question: Is there a better way to do the following?

using BenchmarkTools

function doWork(x::T,y::T,z::T) where {T<:Float64}
    Base.@pure fa(x::T)= mapreduce(i->x+sin(i),+,1:1000; init=zero(T))
    Base.@pure fb(y::T)= mapreduce(i->y+sin(i),+,1:1000; init=zero(T))
    Base.@pure fc(z::T)= mapreduce(i->z+sin(i),+,1:1000; init=zero(T))
    a = fa(x)
    b = fb(y)
    c = fc(z)
    return a+b+c
end

@btime doWork(rand(),rand(),rand())

# Removing Base.@pure triples timings
@btime (x->doWork(x,0.2,0.3)).([0.11,0.12,0.13])
@btime (y->doWork(0.1,y,0.3)).([0.21,0.22,0.23])
@btime (z->doWork(0.1,0.2,z)).([0.31,0.32,0.33])

Thanks a lot.
Andreas

If you mean that particular example:

julia> function doWork(x::T,y::T,z::T) where {T<:Float64}
           Base.@pure fa(x::T)= mapreduce(i->x+sin(i),+,1:1000; init=zero(T))
           Base.@pure fb(y::T)= mapreduce(i->y+sin(i),+,1:1000; init=zero(T))
           Base.@pure fc(z::T)= mapreduce(i->z+sin(i),+,1:1000; init=zero(T))
           a = fa(x)
           b = fb(y)
           c = fc(z)
           return a+b+c
       end
doWork (generic function with 1 method)

julia> doWork(0.3,0.4,0.34)
1042.44190890222

julia> 3sum(sin,1:1000) + 1000*(0.3+0.4+0.34)
1042.4419089022194

But in general, why not just use loops?
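A minimal sketch of what a loop version of that example could look like. The names `shifted_sin_sum` and `do_work_loop` are made up for illustration and are not from the original post:

```julia
# Plain-loop version of one inner reduction: sum of (x + sin(i)) for i in 1:n.
function shifted_sin_sum(x::Float64, n::Int = 1000)
    acc = zero(x)
    for i in 1:n
        acc += x + sin(i)
    end
    return acc
end

# Same structure as doWork, built from three independent loops.
do_work_loop(x, y, z) = shifted_sin_sum(x) + shifted_sin_sum(y) + shifted_sin_sum(z)

do_work_loop(0.3, 0.4, 0.34)
```

Up to floating-point rounding, this should agree with the closed form `3sum(sin,1:1000) + 1000*(x+y+z)` used in the check above.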

Using StaticNumbers.jl (perrutquist/StaticNumbers.jl on GitHub), maybe behind a function barrier, seems to be a useful approach here.
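A rough sketch of the function-barrier idea. To keep it runnable without extra packages, it uses `Val` from Base rather than StaticNumbers.jl, so it shows the same pattern with a different mechanism. `pow_loop` and `pow_barrier` are hypothetical names:

```julia
# Inner kernel: C (factor) and N (iteration count) are type parameters,
# so the compiler sees them as constants when specializing this method.
function pow_loop(s, ::Val{C}, ::Val{N}) where {C, N}
    for _ in 1:N
        s = C * s
    end
    return s
end

# Function barrier: runtime values are lifted into the type domain here.
# Each new (c, n) pair compiles a new specialization of pow_loop,
# so this only suits a small set of distinct constants.
pow_barrier(s, c, n) = pow_loop(s, Val(c), Val(n))

pow_barrier(1.0, 1.01, 1000)  # 1.01^1000 by repeated multiplication
```

Whether the loop itself gets folded away is still up to the compiler. The barrier only guarantees that the kernel is specialized on the constants.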

@Elrod: Thanks for your suggestion. This was an example to demonstrate the issue. I'd change coding style to loops if that helped. I tried a simple summation loop, but it does not constant-propagate either.

@tkf: This is an interesting package. I could not get it to propagate in the mapreduce case, but I will play with it and see if I can make it work. Thanks!

Even with the very basic loops/recursive structures below, I cannot get this to work without Base.@pure. Am I doing something wrong there?

It would be really convenient to have this working in the future. It can be worked around of course, but it will take a few more lines.

using BenchmarkTools
using StaticNumbers

# Calculate 1.01^1000
# Recursive
function recPow0(x,i,j)
    i>j && return x
    recPow0(1.01*x,i+1,j)
end
# Loop
function loopPow0(s,i,j)
    for i0=i:j
        s= 1.01*s
    end
    s
end
# Recursive static
function recPow1(x,i,j)
    i>j && return x
    recPow1(static(1.01)*x,i+static(1),j)
end
# Loop static
function loopPow1(s,i,j)
    for i0=i:j
        s= static(1.01)*s
    end
    s
end

testRecPow0() = recPow0(1.0, 1, 1000)
testLoopPow0() = loopPow0(1.0, 1, 1000)
testRecPow1() = recPow1(static(1.0), static(1), static(1000))
testLoopPow1() = loopPow1(static(1.0), static(1), static(1000))


@btime testRecPow0()
@btime testLoopPow0()

@btime testRecPow1()
@btime testLoopPow1()
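
For reference, the manual workaround mentioned above could look roughly like the sketch below, applied to the original doWork example. It does not rely on constant propagation at all: the contributions of the fixed parameters are computed once when the closure is built. `part` and `fix_yz` are made-up names for illustration:

```julia
# Hypothetical helper mirroring the inner reductions of doWork.
part(v::Float64) = mapreduce(i -> v + sin(i), +, 1:1000; init = 0.0)

# Hold y and z constant: their parts are evaluated once, here,
# and captured by the returned closure.
function fix_yz(y::Float64, z::Float64)
    bc = part(y) + part(z)
    return x -> part(x) + bc
end

model_x = fix_yz(0.2, 0.3)
model_x.([0.11, 0.12, 0.13])
```

The cost is one such constructor per subset of fixed parameters, which is where the few extra lines come from.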