How to properly generate a function based on a condition?

Hello!

I have some pseudo code from a package I am working on:

    using FastPow  # provides @fastpow

    function GenerateOutputKernelValues(SimMetaData)
        flag = SimMetaData.FlagOutputKernelValues
        
        if flag
            # Return a function that performs the calculations
            return function(αD, q, ∇ᵢWᵢⱼ, ichunk, SimThreadedArrays, i, j)
                Wᵢⱼ = @fastpow αD*(1-q/2)^4*(2*q + 1)
                SimThreadedArrays.KernelThreaded[ichunk][i] += Wᵢⱼ
                SimThreadedArrays.KernelThreaded[ichunk][j] += Wᵢⱼ
                SimThreadedArrays.KernelGradientThreaded[ichunk][i] += ∇ᵢWᵢⱼ
                SimThreadedArrays.KernelGradientThreaded[ichunk][j] += -∇ᵢWᵢⱼ
                return Wᵢⱼ
            end
        else
            # Return a function that does nothing
            return function(αD, q, ∇ᵢWᵢⱼ, ichunk, SimThreadedArrays, i, j)
                return nothing
            end
        end
    end

The main idea is that, to avoid checking an if statement a million times inside a hot loop, I design the function exactly as I need it for this simulation load case and let it stay "constant" throughout the whole simulation.

What I lose by doing it this way is some transparency with regard to the function inputs etc., so I wondered if there was a smarter way.

Thanks!

Can’t you put that flag into the type system? That would allow optimisations at compile time (on top of readability improvements, if you use this pattern regularly).
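For illustration, a minimal sketch of that idea using a Val argument and multiple dispatch. The names output_kernel! and neighbor_loop! are hypothetical, and plain vectors stand in for SimThreadedArrays. The runtime Bool is converted to a Val once, outside the loop; each method keeps its own explicit signature, and the Val{false} method is a no-op the compiler can inline away.

```julia
# Hypothetical helper: the Val{true} method accumulates kernel values,
# the Val{false} method does nothing and can be inlined away.
function output_kernel!(::Val{true}, kernel::Vector{Float64}, W::Float64, i::Int, j::Int)
    kernel[i] += W
    kernel[j] += W
    return nothing
end

output_kernel!(::Val{false}, kernel::Vector{Float64}, W::Float64, i::Int, j::Int) = nothing

# The hot loop receives the flag as a type-level value, so Julia compiles
# a separate specialization for Val{true} and for Val{false}.
function neighbor_loop!(flag::Val, kernel::Vector{Float64}, neighbor_pairs::Vector{Tuple{Int,Int}})
    for (i, j) in neighbor_pairs
        W = 1.0  # placeholder for the real kernel value
        output_kernel!(flag, kernel, W, i, j)
    end
    return kernel
end

# Convert the runtime Bool into a Val once, outside the loop.
neighbor_pairs = [(1, 2), (2, 3)]
with_output = neighbor_loop!(Val(true), zeros(3), neighbor_pairs)     # [1.0, 2.0, 1.0]
without_output = neighbor_loop!(Val(false), zeros(3), neighbor_pairs) # [0.0, 0.0, 0.0]
```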

Checking a bool at runtime is really fast, especially if the value is constant/predictable/varies rarely. Then branch prediction may almost eliminate the cost. Returning a function like this, on the other hand, leads to a type instability which could be costly.

Have you tried benchmarking to confirm your suspicion?

Yes, I have. Since this is a particle simulation with N-body interactions, the number of interactions is large, and removing the branches entirely shaves roughly 5-10% off the neighbor-looping time in my small test case.

This might not sound like much, but given the sheer volume of calculations and my goal of making it faster and faster, I want to see if I can make it fast in a nice way too!

That could be an idea!

So currently I have:

    using Parameters, TimerOutputs, ProgressMeter  # @with_kw, TimerOutput, ProgressUnknown

    @with_kw mutable struct SimulationMetaData{Dimensions, FloatType <: AbstractFloat}
        SimulationName::String
        SaveLocation::String
        HourGlass::TimerOutput                  = TimerOutput()
        Iteration::Int                          = 0
        OutputEach::FloatType                   = 0.02 #seconds
        OutputIterationCounter::Int             = 0
        StepsTakenForLastOutput::Int            = 0
        CurrentTimeStep::FloatType              = 0
        TotalTime::FloatType                    = 0
        SimulationTime::FloatType               = 0
        IndexCounter::Int                       = 0
        ProgressSpecification::ProgressUnknown  = ProgressUnknown(desc="Simulation time per output each:", spinner=true, showspeed=true)
        FlagViscosityTreatment::Symbol          = :ArtificialViscosity; @assert in(FlagViscosityTreatment, Set((:None, :ArtificialViscosity, :Laminar, :LaminarSPS))) == true "ViscosityTreatment must be either :None, :ArtificialViscosity, :Laminar, :LaminarSPS"
        VisualizeInParaview::Bool               = true
        ExportSingleVTKHDF::Bool                = true
        ExportGridCells::Bool                   = false
        OpenLogFile::Bool                       = true
        FlagDensityDiffusion::Bool              = false
        FlagLinearizedDDT::Bool                 = false
        FlagOutputKernelValues::Bool            = false
        FlagLog::Bool                           = false
        FlagShifting::Bool                      = false
        FlagSingleStepTimeStepping::Bool        = false
    end

With all kinds of flags, etc. Should I make some kind of supertype, or something else? A small example would help here, thank you!

Are you trying to generate a function with the same call signature because that’s how it’s already referenced downstream? Especially since one of the cases is just a constant, it sounds like you need metaprogramming! That being said, I’m still new myself at building Julia macros, so I guess take this with a grain of salt.

We can use a macro to preprocess the source code at macro-expansion time, right after parsing (even before compile time!), and, if the flag is false, manifestly replace the function call with 0 so there are no function calls to compile, let alone run (who even needs inlining?). Then the compiler should also optimize out any arithmetic on those 0 values for even more speed.

    macro tozero(expr)
        # flag is the global in the module where the macro is defined.
        # It is read when the macro is expanded, not when the code runs,
        # so it must be set before the code using @tozero is expanded.
        if flag
            # When flag is true, the macro just passes through the expression unchanged
            return esc(expr)
        else
            # When flag is false, the macro replaces the expression with 0.0
            return :(0.0)
        end
    end

    function foo(args...) # foo implementation
        w = 1.0 # calculate w
        return w
    end

    flag = true
    @tozero foo(1, 2, 3)  # expands to foo(1, 2, 3)
    # evaluates to 1.0

    flag = false
    @tozero foo(1, 2, 3)  # replaces foo(1, 2, 3) with 0.0
    # evaluates to 0.0

So in your case all you would need to do is find & replace foo(...) with @tozero foo(...). If this were C (which is never a sentence I thought I would say), you could simply use its preprocessor directives:

    // at the top of the file
    #define FLAG 1 // 1 (true) or 0 (false)

    #if FLAG
    // define the function as normal
    float foo(int a, int b, int c) {
        float w = 1.0f; // calculate w
        return w;
    }
    #else
    // manifestly replace source foo calls with 0.0
    #define foo(...) (0.0f)
    #endif

As long as the runtime dispatch is outside the hot loop, this should be fine.
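A sketch of that function-barrier pattern under simplified assumptions (make_kernel_output and run_loop! are hypothetical names, and the arguments are cut down): choosing the closure is type unstable, but because the closure is passed into a separate function, that function is compiled for the concrete closure type, and the dynamic dispatch happens once instead of on every iteration.

```julia
# Return one of two closures; which one is only known at runtime,
# so the return type of this function is not inferable.
function make_kernel_output(flag::Bool)
    if flag
        return function (kernel, W, i, j)
            kernel[i] += W
            kernel[j] += W
            return nothing
        end
    else
        return function (kernel, W, i, j)
            return nothing
        end
    end
end

# Function barrier: `f::F where {F}` forces specialization on the concrete
# closure type, so calls to `f` inside the loop are resolved statically.
function run_loop!(f::F, kernel, neighbor_pairs) where {F}
    for (i, j) in neighbor_pairs
        f(kernel, 1.0, i, j)  # 1.0 stands in for the real kernel value
    end
    return kernel
end

neighbor_pairs = [(1, 2), (2, 3)]
f = make_kernel_output(true)                      # one dynamic dispatch here...
result = run_loop!(f, zeros(3), neighbor_pairs)   # ...not inside the loop
```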

Edit: never mind, looks like I confused Int with Integer

Let me also point out that your mutable struct is an abstract container, which is generally advised against. You may want to try replacing the Ints with Int64s, or try parametrizing them for some potential gains (similar to what you did for the FloatTypes). Is TimerOutput an abstract type as well?

It should be enough to add a type parameter to your struct here (choose a better name than I did :wink:)

    @with_kw mutable struct SimulationMetaData{Flag, Dimensions, FloatType <: AbstractFloat}

where Flag is either Val(true) or Val(false). Then use it:

    function GenerateOutputKernelValues(SimMetaData::SimulationMetaData{Flag}) where Flag
        # Flag is a type parameter, so this condition is known at compile time
        if Flag isa Val{true}
            # ...
        else
            # ...
        end
    end

The compiler will remove the check. You can also add @inline to the functions (both GenerateOutputKernelValues and the functions it returns) to further encourage the compiler, but I think the functions are small enough to be inlined anyways.
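A self-contained sketch of this suggestion, assuming a hypothetical, cut-down SimMeta struct in place of the full SimulationMetaData, and a plain Bool as the type parameter instead of Val(true)/Val(false) (both are valid type parameters). Because OutputKernel is part of the type, `if OutputKernel` is a constant inside each compiled specialization.

```julia
# Hypothetical cut-down metadata type: the flag lives in a type parameter.
struct SimMeta{OutputKernel, FloatType <: AbstractFloat}
    OutputEach::FloatType
end

# Convenience constructor: lift the runtime Bool into the type once, at setup.
SimMeta(output_kernel::Bool, output_each::T) where {T <: AbstractFloat} =
    SimMeta{output_kernel, T}(output_each)

# In each specialization OutputKernel is a compile-time constant,
# so the compiler can drop the untaken branch.
@inline function kernel_output!(::SimMeta{OutputKernel}, kernel, W, i, j) where {OutputKernel}
    if OutputKernel
        kernel[i] += W
        kernel[j] += W
    end
    return nothing
end

meta_on  = SimMeta(true, 0.02)
meta_off = SimMeta(false, 0.02)

k_on, k_off = zeros(2), zeros(2)
kernel_output!(meta_on, k_on, 0.5, 1, 2)    # k_on  == [0.5, 0.5]
kernel_output!(meta_off, k_off, 0.5, 1, 2)  # k_off == [0.0, 0.0]
```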

No, the struct is fine type-wise.
Int is just shorthand for the native-size integer type (Int64 on 64-bit systems), and FloatType is a type parameter of the struct. You are probably confusing Int with Integer, which is indeed an abstract type.
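A quick way to verify this in the REPL, using only Base functions: Int is a concrete alias for the platform's native integer type, while Integer is abstract.

```julia
# Int is an alias for Int64 or Int32, depending on the platform word size.
@show Int === (Sys.WORD_SIZE == 64 ? Int64 : Int32)  # true
@show isconcretetype(Int)        # true: fine as a struct field type
@show isabstracttype(Integer)    # true: a field of this type would be abstractly typed
@show isconcretetype(Float64)    # true: what FloatType becomes for Float64 simulations
```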

If I put all the flags into the type, so that I have, for example, SimMetaData{A,B,C,D,E}, then I no longer need to generate functions outside, right? I can use if statements inside the main function again, or am I misunderstanding?

Generally, the compiler operates on types, so putting information into types gives the compiler more information. If the compiler can statically decide that some condition is always true or false, it removes the check.
So yes, in principle you can put all of your flags into the type (as long as they stay constant throughout your simulation).

If you have more than two flags, I would recommend creating some convenience methods around them, e.g. a struct like

    struct MyFlags{A,B,C,D,E}
        function MyFlags(a, b, c, d, e) # this is type unstable, which may or may not be an issue
            # depends on where you create this
            # probably fine I think
            return new{Val(a), Val(b), Val(c), Val(d), Val(e)}()
        end
    end

    function isAset(::MyFlags{A}) where A
        return A isa Val{true}
    end
    # ... and similarly for B, C, D, E

And then you put this struct into the type of SimulationMetaData.
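A sketch of that last step, assuming a reduced two-flag bundle (called SimFlags here, hypothetical) and a hypothetical Simulation struct. Plain Bool values are used as type parameters instead of Val(a); the flags object is stored in a field whose type is a type parameter, so the accessor results are known to the compiler for each concrete Simulation type.

```julia
# Hypothetical two-flag bundle; the flag values live only in the type.
struct SimFlags{Shifting, DensityDiffusion}
    # Converting runtime Bools into type parameters is type unstable,
    # so call this once, during setup, not in the hot loop.
    SimFlags(shifting::Bool, density_diffusion::Bool) =
        new{shifting, density_diffusion}()
end

# Accessors return compile-time constants for a concrete SimFlags type.
is_shifting(::SimFlags{S, D}) where {S, D} = S
is_density_diffusion(::SimFlags{S, D}) where {S, D} = D

# The simulation struct carries the flags type as a parameter.
struct Simulation{F <: SimFlags}
    flags::F
    dt::Float64
end

# Ordinary if/&& checks on the flags; each is resolved at compile time.
function step_description(sim::Simulation)
    parts = String[]
    is_shifting(sim.flags) && push!(parts, "shifting")
    is_density_diffusion(sim.flags) && push!(parts, "density diffusion")
    return isempty(parts) ? "plain step" : join(parts, " + ")
end

sim = Simulation(SimFlags(true, false), 1e-4)
step_description(sim)  # "shifting"
```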

Got it, I will try to implement that. I first did a full manual implementation to understand it.

How would I additionally include some of the more non-trivial flags:

    FlagViscosityTreatment::Symbol          = :ArtificialViscosity; @assert in(FlagViscosityTreatment, Set((:None, :ArtificialViscosity, :Laminar, :LaminarSPS))) == true "ViscosityTreatment must be either :None, :ArtificialViscosity, :Laminar, :LaminarSPS"

And would I be able to keep the helper message?

Sure, you can use Val to put whatever you want into the type domain :slight_smile:
And yes, you can keep that error message. Just move the check to wherever you construct your type, I suppose.
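A minimal sketch of that, assuming a hypothetical ViscosityTreatment wrapper type and a hypothetical viscosity_term function with a placeholder formula: the Symbol moves into the type parameter, the original check and message move into the constructor, and dispatch replaces branching on the Symbol at runtime.

```julia
# Allowed treatments, as in the original @assert.
const VISCOSITY_TREATMENTS = (:None, :ArtificialViscosity, :Laminar, :LaminarSPS)

# Hypothetical wrapper: the Symbol is stored in the type parameter.
struct ViscosityTreatment{S}
    function ViscosityTreatment(s::Symbol)
        # The check, with the same helpful message, now lives in the constructor.
        s in VISCOSITY_TREATMENTS || throw(ArgumentError(
            "ViscosityTreatment must be either :None, :ArtificialViscosity, :Laminar, :LaminarSPS"))
        return new{s}()
    end
end

# Dispatch on the treatment instead of comparing Symbols in the hot loop.
viscosity_term(::ViscosityTreatment{:None}, x) = zero(x)
viscosity_term(::ViscosityTreatment{:ArtificialViscosity}, x) = 0.5 * x  # placeholder formula
viscosity_term(::ViscosityTreatment{S}, x) where {S} =
    error("viscosity term for :$S is not implemented in this sketch")

vt = ViscosityTreatment(:ArtificialViscosity)
viscosity_term(vt, 2.0)  # 1.0
```

Constructing ViscosityTreatment(:Bogus) throws an ArgumentError carrying the original message.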