Is it hard to support Julia UDFs in DuckDB?

@asbisen so, @tqml was the mind behind the great macro, which is a big UI improvement imo for UDFs in DuckDB.

you can see an example in this discussion. the issue we both ran into (for the version i merged and this macro) was a segmentation fault when one of the arguments has a data type other than Int/Float

the code is available in the duckdb fork on his repo

in short it looks like this

my_sum = (a, b) -> a + b
fun = @create_scalar_function "my_sum" [Int64, Int64] Int64 my_sum
register_scalar_function(db, fun)
execute(db, "SELECT my_sum(1,2)")

Hey @drizk1 thank you for bringing it up again. I developed a PoC that works, but it has some limitations and is not quite ready for merging. I also did not have much time to work on it in the last few months.

  • I made a macro and a non-macro version. The non-macro wrapper has very poor performance because it uses generic types and the dispatch happens at runtime.
  • The macro versions need to use eval() on ARM processors (or non-x86 architectures), which seems wrong? But I’m quite new to macro programming in Julia and would love some input here.
  • I could at least solve the string variant. DuckDB has a function assign_string_element() that does exactly what we need; a sketch follows below.
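
For illustration, here is a minimal sketch of how a string result could be written, assuming the Julia package exposes a binding that mirrors the C API function duckdb_vector_assign_string_element (the wrapper and helper names here are hypothetical):

    # Hedged sketch: write a Julia String into row `row` of a DuckDB output
    # vector, assuming a low-level binding mirroring the C API
    # duckdb_vector_assign_string_element(vector, index, str).
    function write_string_result(output_vec, row::Integer, s::AbstractString)
        # DuckDB vector indices are 0-based, Julia's are 1-based
        DuckDB.duckdb_vector_assign_string_element(output_vec, row - 1, s)
    end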

Link to the implementation: duckdb/tools/juliapkg/src/scalar_function.jl at julia/scalar-function · tqml/duckdb · GitHub


Update: I opened a thread to gather ideas on how this can be implemented properly. The main problem seems to be that @cfunction only accepts objects in global scope (at least to my understanding).

Not sure if a simple API like

fun = @create_scalar_function my_sum(::Int64, ::Int64)::Int64

is possible due to this limitation (which might be a bummer :confused: )
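
As an aside, a possible direction (a minimal sketch, not tested against the DuckDB C API) is the closure form of @cfunction, which interpolates a runtime value with $ and works in local scope on platforms that support closure trampolines (notably not all non-x86 targets, which may relate to the eval() issue above):

    # Sketch: runtime @cfunction closure. The returned Base.CFunction must be
    # kept alive for as long as the raw pointer is in use.
    my_sum = (a, b) -> a + b
    cf = @cfunction($my_sum, Cint, (Cint, Cint))
    ptr = Base.unsafe_convert(Ptr{Cvoid}, cf)
    GC.@preserve cf ccall(ptr, Cint, (Cint, Cint), Cint(1), Cint(2))  # returns 3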

macro _create_scalar_function_new(func_expr)

    func, parameters, return_type = _udf_parse_function_expr(func_expr)
    func_esc = esc(func)
    func_name = string(func)
    parameter_names = [p[1] for p in parameters]
    parameter_types = [p[2] for p in parameters]
    parameter_types_vec = Expr(:vect, parameter_types...) # create a vector expression, e.g. [Int, Int]
    wrapper_expr = _udf_generate_wrapper(func_expr)

    quote
        local wrapper = $(wrapper_expr)
        local fun = ScalarFunction($func_name, $parameter_types_vec, $return_type, $func_esc, wrapper)

        wrapper(C_NULL, C_NULL, C_NULL) # works
        fun.wrapper(C_NULL, C_NULL, C_NULL) # works

        # Everything below only works in GLOBAL scope in the REPL
        # ptr = @cfunction(fun.wrapper, Cvoid, (duckdb_function_info, duckdb_data_chunk, duckdb_vector))
        # ccall(ptr, Cvoid, (duckdb_function_info, duckdb_data_chunk, duckdb_vector), C_NULL, C_NULL, C_NULL) # doesn't work
        # duckdb_scalar_function_set_function(fun.handle, ptr)
        
        fun
    end

end


Another update from my side :smiley:

I opened a pull request at the DuckDB repo that implements UDFs as suggested in the GitHub Discussion.

Implementing the creation of the function pointer for C was more difficult than I thought, so there might be a better way to do this. If you have any suggestions or improvement ideas, I’m happy for any “pointers” :slight_smile:

Usage will look like this:

using DuckDB, DataFrames

f_add = (a, b) -> a + b
db = DuckDB.DB()
con = DuckDB.connect(db)

# Create the scalar function 
# the second argument can be omitted if the function name is identical to a global symbol
fun = DuckDB.@create_scalar_function f_add(a::Int, b::Int)::Int f_add
DuckDB.register_scalar_function(con, fun)

df = DataFrame(a = [1, 2, 3], b = [1, 2, 3])
DuckDB.register_table(con, df, "test1")

result = DuckDB.execute(con, "SELECT f_add(a, b) as result FROM test1") |> DataFrame

Feedback of course welcome :slight_smile:


This is awesome work.


I can’t wait for this PR to get merged. This would provide a powerful capability to push Julia functions down to process larger-than-memory datasets. An excellent companion to DataFrames.jl. Thanks so much for the hard work.


Thank you for the kind words :slight_smile:
Also, good news, the PR got merged :slight_smile:
Let me know in case you run into any issues


DuckDB has another kind of UDF: Aggregation Functions

Probably not as useful as scalar functions, but they could still enable some nice processing on large datasets. It basically needs five functions to work (plus an optional cleanup); a rough sketch follows the list:

  • get_state_size(fun) → should return the size of the state in bytes
  • init_state(fun, state_ptr) → should initialize the state to the memory pointed at state_ptr
  • update(fun, chunk, state) → read input from the chunks and update the state
  • combine(fun, state1, state2) → combine states (in multi-threading environments)
  • finalize(fun, states, result_vec) → write the result to result_vec
  • cleanup(states) → function that can perform cleanup tasks (optional)
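
To make those shapes concrete, here is a hedged sketch of what the callbacks could look like for a COUNT-style aggregate whose state is a single Int64. The signatures follow the list above and are illustrative, not the exact DuckDB.jl API (in particular, length(chunk) stands in for however the chunk exposes its row count):

    # Illustrative only: raw-pointer state handling for a COUNT aggregate
    get_state_size(fun) = sizeof(Int64)

    init_state(fun, state_ptr) = unsafe_store!(Ptr{Int64}(state_ptr), 0)

    function update(fun, chunk, state_ptr)
        p = Ptr{Int64}(state_ptr)
        unsafe_store!(p, unsafe_load(p) + length(chunk)) # count the rows in this chunk
    end

    function combine(fun, state1, state2)
        # merge the partial count from state1 into state2 (multi-threading)
        p1, p2 = Ptr{Int64}(state1), Ptr{Int64}(state2)
        unsafe_store!(p2, unsafe_load(p1) + unsafe_load(p2))
    end

    function finalize(fun, states, result_vec)
        # write one result per group state
        for (i, s) in enumerate(states)
            result_vec[i] = unsafe_load(Ptr{Int64}(s))
        end
    end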

Because directly working with the DuckDB data chunks and vectors is a bit cumbersome, I built a proof of concept that abstracts away most of the boilerplate and lets you work with the correct Julia types (so Julia Arrays, Strings, etc.)

Usage looks like this:

fun = DuckDB.AggregateFunction("my_avg", U, Y, X, init_function, update_function!, combine_function!, finalize_function!)
DuckDB.register_aggregate_function(con, fun)

where U, Y, X are the data types of the input, the output and the state respectively.

The functions should look like this

init_function = () -> initial_state
update_function! = (x, u1, u2, ...) -> x[1] = some_calc(x[1], u1, u2, ...)
combine_function! = (x1, x2) -> x1[1] = merge_states(x1[1], x2[1])
finalize_function! = (y, x) -> y[1] = get_result(x[1])

The state x is currently always passed as an array (similar to the in-place form in DifferentialEquations.jl). The state size function is just sizeof(X).

Feedback welcome :slight_smile:
Let me know what you think and if you have any suggestions or improvement ideas regarding the Julia Interface/API. Below are two usage examples.

Current Issues:
Memory management might be a problem. If someone uses reference types, they could get garbage collected while DuckDB still holds a state pointer to them. I currently circumvented this by creating a global dictionary that keeps a (pointer, Ref{state}) pair to prevent the state from getting garbage collected. Maybe there is a better way to do this; a sketch of the current pattern follows.
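
A minimal sketch of that rooting pattern (the dictionary and helper names are made up for illustration):

    # Hypothetical: a global dict keyed by the raw state pointer keeps a Ref
    # alive so the Julia state cannot be collected while DuckDB still holds
    # the pointer.
    const _LIVE_STATES = Dict{Ptr{Cvoid}, Ref}()

    function root_state!(state_ptr::Ptr{Cvoid}, state)
        _LIVE_STATES[state_ptr] = Ref(state) # rooted until explicitly removed
        return state
    end

    unroot_state!(state_ptr::Ptr{Cvoid}) = delete!(_LIVE_STATES, state_ptr)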

Examples

Here is a basic example that calculates the average of a column:

    # Calculate the average of a column
    using DuckDB, DataFrames, Test

    U = (Int,)          # Input
    X = Tuple{Int, Int} # State = (sum, count)
    Y = Float64         # Output

    init_function = () -> (0, 0)
    update_function = function (x, u)
        _sum, _count = x[1]
        _sum += sum(u)
        _count += length(u)
        x[1] = (_sum, _count)
        return nothing
    end
    combine_function = (x1, x2) -> (x1[1] = x1[1] .+ x2[1]; nothing) # merge (sum, count) states element-wise
    finalize_function = (y, x) -> (y .= [xi[1] / xi[2] for xi in x]; nothing)

    fun =
        DuckDB.AggregateFunction("my_avg", U, Y, X, init_function, update_function, combine_function, finalize_function)

    N = 10_000_000
    df = DataFrame(a = 1:N, b = 1:N, c = rand(N), d = rand(N))
    db = DuckDB.DB()
    con = DuckDB.connect(db)
    DuckDB.register_table(con, df, "test1")
    DuckDB.register_aggregate_function(con, fun)

    result = DuckDB.execute(con, "SELECT my_avg(a) as result FROM test1") |> DataFrame
    @show result

    @test result.result[1] == sum(df.a) / length(df.a)
    @info "Agg Func Stats" fun.stats

Another example that computes the least squares parameter estimate over a dataset in DuckDB:

    using DuckDB, DataFrames, LinearAlgebra, Test

    U = (Float64, Float64, Float64)                # Input
    X = Tuple{Matrix{Float64}, Vector{Float64}}    # State
    Y = Tuple{Float64, Float64}                    # Output

    # Covariance Matrix, parameter vector
    λ = 0.999 # Forgetting factor
    P0 = 1e6 .* [1.0 0.0; 0.0 1.0]
    x0 = (P0, zeros(2)) # TODO Decide how to track the state so it doesn't get garbage collected
    rls_init_function = () -> (copy(P0), zeros(2))
    rls_update_function = function (x, y, u1, u2)
        P, θ = x[1]
        for (yi, ui1, ui2) in zip(y, u1, u2)
            ϕ = [ui1, ui2]
            k = P * ϕ / (λ + ϕ' * P * ϕ)
            P .= (P - k * ϕ' * P) / λ
            θ .= θ + k * (yi - dot(ϕ, θ))
        end
    end
    combine_function = (x1, x2) -> error("This is not supported!") # Needed for multi-threading?
    finalize_function = function (y, x)
        @info "Finalize RLS" y x
        P, θ = x[1]
        y[1] = (θ[1], θ[2]) # we only support scalars for now
        return
    end

    fun = DuckDB.AggregateFunction(
        "r_least_squares",
        U,
        Y,
        X,
        rls_init_function,
        rls_update_function,
        combine_function,
        finalize_function
    )

    # Dummy Data
    N = 1_000
    theta_true = [0.4, 0.2]
    u = randn(N, 2)
    y = u * theta_true
    df = DataFrame(y = y, u1 = u[:, 1], u2 = u[:, 2])

    db = DuckDB.DB()
    con = DuckDB.connect(db)
    DuckDB.register_table(con, df, "test1")
    DuckDB.register_aggregate_function(con, fun)

    result = DuckDB.execute(con, "SELECT r_least_squares(y,u1,u2) as result FROM test1") |> DataFrame
    @info "Agg Func Stats" fun.stats

    # Check if the result is close to the true value
    @test result.result[1] ≈ (theta_true[1], theta_true[2]) atol=1e-2 broken=true

    @show result
    # Prints something like:
    #     Row │ result
    #         │ Tuple…
    #    ─────┼────────────
    #       1 │ (0.4, 0.2)


Thank you for all the work on Julia support on DuckDB, I use it heavily in my work.

Maybe this might help, instead of the global dictionary?

help?> GC.@preserve
  GC.@preserve x1 x2 ... xn expr

  Mark the objects x1, x2, ... as being in use during the evaluation of the expression expr. This is only required in unsafe code where expr implicitly uses memory or other resources owned by one of the xs.

See for example What am I doing wrong with `unsafe_wrap()`? - #2 by simeonschaub on how to use @preserve with unsafe_wrap. I imagine this use case should be similar?
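
For reference, a minimal self-contained sketch of that @preserve-with-unsafe_wrap pattern:

    # Keep `buf` rooted while the wrapped array (which aliases buf's memory,
    # own = false) is in use; without @preserve, buf could in principle be
    # collected while `a` still points at its memory.
    buf = Vector{UInt8}(undef, 16)
    GC.@preserve buf begin
        a = unsafe_wrap(Array, pointer(buf), length(buf); own = false)
        fill!(a, 0x01)
    end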

It’s fantastic that this PR got merged; I am very excited to use DuckDB as a backend in more Julia projects. Is there any plan to add documentation for this feature to the Julia Package – DuckDB page? I would not have been aware of this awesome new feature had I not stumbled onto this Discourse thread, and I fear many other users will also be unaware.


Might be an option; however, this requires the user to be aware that the state could be garbage collected and that an additional reference needs to be kept.
If the user misses it, it could lead to segmentation faults.

I would prefer something that lets the user write normal Julia code and everything “just works” without having to worry about these things.


Totally agree, documenting it is on my todo list :slight_smile:


Yes, I agree, that would be problematic. What I had in mind was something much closer to what you are doing, except that instead of storing Ref(state) in a global dict, the Julia DuckDB library would wrap the user-defined functions inside a GC.@preserve r begin ... end block, with r = Ref(state).

I see now that this might be a little bit tricky to implement because, for as long as DuckDB has a state pointer, code should run within the GC.@preserve block. I’m not an expert on C interop / duckdb internals so I confess I’m not sure how feasible that is.

I don’t think I can wrap everything in a GC.@preserve block.

You basically create the AggregateFunction, and then Julia doesn’t know when it will be called. DuckDB calls the function pointer to the C wrapper, which in turn calls the user-provided functions. When the C wrapper returns, we would also leave any GC.@preserve blocks that we define there. The state, however, must persist across multiple invocations. So the only way I can think of to use GC.@preserve would be at the top level, wrapping the creation of the AggregateFunction and all DuckDB.execute calls that use the aggregate function.
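
A sketch of why that is (all names hypothetical, just to show the scoping):

    # The root created by @preserve ends when the wrapper returns, but DuckDB
    # keeps the state pointer across calls, so the state could still be
    # collected between invocations.
    function c_wrapper(info, chunk, out)
        state = load_state(info)       # hypothetical: recover the Julia state
        GC.@preserve state begin
            user_update!(state, chunk) # hypothetical user callback
        end                            # <- root ends here
        return nothing
    end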

Some other ways that I could think of that might work:

  • The first would be to estimate the true byte size of the object (recursively, over all nested elements), tell this to DuckDB so it can allocate enough memory, and then serialize the object to that memory location (and deserialize it afterwards). This does not feel very performant and would be difficult to implement.
  • Another would be to move the global reference dictionary to a local dictionary inside the AggregateFunction and use the cleanup handler from DuckDB (see the sketch after this list). For this we would need to store the pointer to the AggregateFunction in the state object as well (but not expose it to the user). The cleanup method then gets an array of states, retrieves the pointer to the function object, and starts deleting all elements in the local dictionary.
  • Somehow telling Julia to disable garbage collection for a specific object and then re-enabling garbage collection once the aggregate function is done.
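
A rough sketch of the second idea (all names hypothetical):

    # Per-function registry instead of a global dict. Each state would also
    # carry the pointer of its owning AggregateFunction so the cleanup
    # callback can find the right registry and unroot its states.
    mutable struct AggFunctionRegistry
        live::Dict{Ptr{Cvoid}, Ref} # state pointer -> rooted Julia state
    end

    function cleanup(registry::AggFunctionRegistry, state_ptrs)
        for p in state_ptrs
            delete!(registry.live, p) # drop the root; GC may now reclaim it
        end
    end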

I was wondering if x = unsafe_wrap(...; own = false) already did that to some extent (tell Julia that the underlying memory of x is managed by an external program and shouldn’t be garbage collected when x goes out of scope), but I’ve never used it in the case of C interoperability.

There are also Base.@_gc_preserve_begin / end (as in here), but I think those should only be used internally in Base Julia.