DuckDB has another kind of UDF: Aggregation Functions
Probably not as useful as scalar functions, but they could still enable some nice processing on large datasets. Five callbacks are needed to make one work, plus an optional cleanup:
get_state_size(fun) → returns the size of the state in bytes
init_state(fun, state_ptr) → initializes the state in the memory pointed to by state_ptr
update(fun, chunk, state) → reads input from the chunk and updates the state
combine(fun, state1, state2) → combines two states (for multi-threaded execution)
finalize(fun, states, result_vec) → writes the result to result_vec
cleanup(states) → performs cleanup tasks (optional)
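To make the control flow concrete, here is a pure-Julia sketch of the lifecycle these callbacks implement (the names are mine, not DuckDB's): one state per thread, updated chunk by chunk, then combined and finalized. It computes a mean.

```julia
# Hypothetical sketch of the aggregate lifecycle, outside DuckDB.

init_state() = (0, 0)                       # state = (sum, count)

function update_state!(state, chunk)        # fold one input chunk into the state
    s, c = state[1]
    state[1] = (s + sum(chunk), c + length(chunk))
    return nothing
end

function combine_states!(s1, s2)            # merge a second thread's state into the first
    a, b = s1[1]
    c, d = s2[1]
    s1[1] = (a + c, b + d)
    return nothing
end

finalize_state(state) = state[1][1] / state[1][2]

# Two "threads", each seeing half of the data in chunks:
s1, s2 = [init_state()], [init_state()]
update_state!(s1, 1:500)
update_state!(s2, 501:1000)
combine_states!(s1, s2)
finalize_state(s1)   # 500.5, the mean of 1:1000
```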
Because directly working with DuckDB's data chunks and vectors is a bit cumbersome, I built a proof of concept that abstracts away most of the boilerplate and lets you work with the corresponding Julia types (so Julia Arrays, Strings, etc.).
Usage looks like this:
fun = DuckDB.AggregateFunction("my_avg", U, Y, X, init_function, update_function!, combine_function!, finalize_function!)
DuckDB.register_aggregate_function(con, fun)
where U, Y, X are the data types of the input, the output and the state respectively.
The functions should look like this:
init_function = () -> initial_state
update_function! = (x, u1, u2, ...) -> x[1] = some_calc(x[1], u1, u2, ...)
combine_function! = (x1, x2) -> x1[1] = merge_states(x1[1], x2[1])
finalize_function! = (y, x) -> y[1] = get_result(x[1])
The state x is currently always passed as an array (similar to the in-place form in DifferentialEquations.jl). The state size function is just sizeof(X).
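As a concrete check of that state-size rule (illustrative, assuming an isbits state like the one in the average example below):

```julia
# For an isbits state, sizeof(X) is the exact number of bytes DuckDB
# needs to allocate per state; e.g. for a (sum, count) tuple state:
X = Tuple{Int64, Int64}
sizeof(X)   # 16: two 8-byte integers
```

For states that hold references (such as a Matrix), sizeof(X) only covers the references themselves, the actual buffers live on the Julia heap, which is exactly what makes the garbage-collection issue below tricky.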
Feedback welcome! Let me know what you think and if you have any suggestions or improvement ideas regarding the Julia interface/API. Below are two usage examples.
Current Issues:
Memory management might be a problem. If someone uses reference types, they could get garbage collected while DuckDB still holds a pointer to the state. I currently work around this by keeping a global dictionary of (pointer, Ref{state}) entries to prevent collection. Maybe there is a better way to do this.
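The rooting workaround could be sketched like this (hypothetical names, not the actual package code):

```julia
# A global dictionary keyed by the state pointer keeps a Ref to the state
# alive, so the Julia GC cannot collect it while DuckDB still holds the
# raw pointer.
const STATE_ROOTS = Dict{Ptr{Cvoid}, Any}()

root_state!(ptr::Ptr{Cvoid}, state) = (STATE_ROOTS[ptr] = Ref(state); nothing)
unroot_state!(ptr::Ptr{Cvoid}) = (delete!(STATE_ROOTS, ptr); nothing)

# In init_state: root the freshly created state; in cleanup: release it.
ptr = Ptr{Cvoid}(UInt(0x1234))   # stand-in for the pointer DuckDB hands us
root_state!(ptr, (zeros(2, 2), zeros(2)))
unroot_state!(ptr)
```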
Examples
Here is a basic example that calculates the average of a column:
using DataFrames, DuckDB, Test

# Calculate the average of a column
U = (Int,) # Input
X = Tuple{Int, Int} # State = (sum, count)
Y = Float64 # Output
init_function = () -> (0, 0)
update_function = function (x, u)
    _sum, _count = x[1]
    _sum += sum(u)
    _count += length(u)
    x[1] = (_sum, _count)
    return nothing
end
combine_function = (x1, x2) -> (x1 .= map((a, b) -> a .+ b, x1, x2); nothing) # element-wise merge of the (sum, count) tuples; `x1 .+= x2` would fail since `+` is not defined on tuples
finalize_function = (y, x) -> (y .= [xi[1] / xi[2] for xi in x]; nothing)
fun = DuckDB.AggregateFunction("my_avg", U, Y, X, init_function, update_function, combine_function, finalize_function)
N = 10_000_000
df = DataFrame(a = 1:N, b = 1:N, c = rand(N), d = rand(N))
db = DuckDB.DB()
con = DuckDB.connect(db)
DuckDB.register_table(con, df, "test1")
DuckDB.register_aggregate_function(con, fun)
result = DuckDB.execute(con, "SELECT my_avg(a) as result FROM test1") |> DataFrame
@show result
@test result.result[1] == sum(df.a) / length(df.a)
@info "Agg Func Stats" fun.stats
Another example that computes the least squares parameter estimate over a dataset in DuckDB:
using DataFrames, DuckDB, LinearAlgebra, Test

U = (Float64, Float64, Float64) # Input
X = Tuple{Matrix{Float64}, Vector{Float64}} # State
Y = Tuple{Float64, Float64} # Output
# Covariance Matrix, parameter vector
λ = 0.999 # Forgetting factor
P0 = 1e6 .* [1.0 0.0; 0.0 1.0]
x0 = (P0, zeros(2)) # TODO Decide how to track the state so it doesn't get garbage collected
rls_init_function = () -> (copy(P0), zeros(2))
rls_update_function = function (x, y, u1, u2)
    P, θ = x[1]
    for (yi, ui1, ui2) in zip(y, u1, u2)
        ϕ = [ui1, ui2]
        k = P * ϕ / (λ + ϕ' * P * ϕ)
        P .= (P - k * ϕ' * P) / λ
        θ .= θ + k * (yi - dot(ϕ, θ))
    end
    return nothing
end
combine_function = (x1, x2) -> error("This is not supported!") # Needed for multi-threading? Note: `throw(Exception(...))` would itself fail, since Exception is abstract
finalize_function = function (y, x)
    @info "Finalize RLS" y x
    P, θ = x[1]
    y[1] = (θ[1], θ[2]) # we only support scalars for now
    return nothing
end
fun = DuckDB.AggregateFunction(
    "r_least_squares",
    U,
    Y,
    X,
    rls_init_function,
    rls_update_function,
    combine_function,
    finalize_function,
)
# Dummy Data
N = 1_000
theta_true = [0.4, 0.2]
u = randn(N, 2)
y = u * theta_true
df = DataFrame(y = y, u1 = u[:, 1], u2 = u[:, 2])
db = DuckDB.DB()
con = DuckDB.connect(db)
DuckDB.register_table(con, df, "test1")
DuckDB.register_aggregate_function(con, fun)
result = DuckDB.execute(con, "SELECT r_least_squares(y,u1,u2) as result FROM test1") |> DataFrame
@info "Agg Func Stats" fun.stats
# Check if the result is close to the true value
@test result.result[1] ≈ (theta_true[1], theta_true[2]) atol=1e-2 broken=true
@show result
# Prints something like:
# Row │ result
# │ Tuple…
# ─────┼────────────
# 1 │ (0.4, 0.2)
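As a sanity check independent of DuckDB: on this noise-free dummy data, the batch ordinary-least-squares solution recovers theta_true essentially exactly, so the streaming RLS estimate above has a known reference to converge to.

```julia
# Batch OLS on the same dummy data, for comparison with the RLS aggregate.
using LinearAlgebra

N = 1_000
theta_true = [0.4, 0.2]
u = randn(N, 2)
y = u * theta_true          # noise-free, so y lies exactly in the column space of u

theta_hat = u \ y           # ordinary least squares via QR factorization
isapprox(theta_hat, theta_true; atol = 1e-8)   # true
```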