DuckDB has another kind of UDF: aggregate functions
They are probably not as useful as scalar functions, but they could still enable some nice processing on large datasets. An aggregate function needs five callbacks, plus an optional sixth for cleanup:
- `get_state_size(fun)` → returns the size of the state in bytes
- `init_state(fun, state_ptr)` → initializes the state in the memory pointed to by `state_ptr`
- `update(fun, chunk, state)` → reads the input from a data chunk and updates the state
- `combine(fun, state1, state2)` → combines two states (needed when the aggregation runs on multiple threads)
- `finalize(fun, states, result_vec)` → writes the result to `result_vec`
- `cleanup(states)` → performs cleanup tasks (optional)
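To make the order of these callbacks more concrete, here is a rough pure-Julia simulation of the lifecycle: one state per worker thread, `update` called once per chunk, `combine` to merge the partial states, and `finalize` at the end. It does not call DuckDB. `run_aggregate` is a hypothetical helper, the thread split and chunk size are made-up parameters (real DuckDB vectors hold up to 2048 rows), and `finalize` here simply returns a value instead of writing to `result_vec`.

```julia
# Hypothetical stand-in for the callback order of an aggregate; it does NOT call DuckDB.
function run_aggregate(init, update!, combine!, finalize_fn, data::AbstractVector;
                       nthreads::Int = 2, chunksize::Int = 3)
    # Split the rows into contiguous per-"thread" ranges.
    ranges = collect(Iterators.partition(eachindex(data), cld(length(data), nthreads)))
    # init: one state per thread, wrapped in a 1-element vector so callbacks can replace it.
    states = [[init()] for _ in ranges]
    for (state, r) in zip(states, ranges)
        # update: feed this thread's rows chunk by chunk.
        for chunk in Iterators.partition(data[r], chunksize)
            update!(state, chunk)
        end
    end
    # combine: fold every partial state into the first one.
    for other in states[2:end]
        combine!(states[1], other)
    end
    # finalize: turn the merged state into the result.
    return finalize_fn(states[1][1])
end

# Example aggregate: sum of squares (illustrative callbacks).
sumsq_init() = 0
sumsq_update!(x, u) = (x[1] += sum(abs2, u); nothing)
sumsq_combine!(x1, x2) = (x1[1] += x2[1]; nothing)

total = run_aggregate(sumsq_init, sumsq_update!, sumsq_combine!, identity, collect(1:10))
println(total)   # 385
```

Because the partial states are merged with `combine`, the result must not depend on how the rows are split: calling `run_aggregate` with `nthreads = 3, chunksize = 4` gives the same 385.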
Because working directly with DuckDB data chunks and vectors is a bit cumbersome, I built a proof of concept that abstracts away most of the boilerplate and lets you work with the corresponding Julia types (Julia Arrays, Strings, etc.).
Usage looks like this:
```julia
fun = DuckDB.AggregateFunction("my_avg", U, Y, X, init_function, update_function!, combine_function!, finalize_function!)
DuckDB.register_aggregate_function(con, fun)
```
where `U`, `Y` and `X` are the data types of the input, the output and the state, respectively.
The functions should look like this:
```julia
init_function = () -> initial_state
update_function! = (x, u1, u2, ...) -> x[1] = some_calc(x[1], u1, u2, ...)
combine_function! = (x1, x2) -> x1[1] = merge_states(x1[1], x2[1])
finalize_function! = (y, x) -> y[1] = get_result(x[1])
```
The state `x` is currently always passed as an array (similar to the in-place form in DifferentialEquations.jl), so the functions update `x[1]` instead of returning a new state. The state size function is just `sizeof(X)`.
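A minimal sketch of that convention in plain Julia, without DuckDB: because `Tuple{Int, Int}` is immutable, wrapping the state in a one-element `Vector` is what lets an update callback replace it so the caller sees the new value. The name `avg_update!` is illustrative only and not part of the proof of concept.

```julia
# Sketch of the one-element-array state convention (illustrative names).
X = Tuple{Int, Int}          # state = (sum, count), an immutable tuple
state = X[(0, 0)]            # Vector{X} holding a single state

# The callback cannot mutate the tuple itself, but it can replace state[1];
# the caller sees the change because the Vector object is shared.
function avg_update!(x, u)
    s, c = x[1]
    x[1] = (s + sum(u), c + length(u))
    return nothing
end

avg_update!(state, [1, 2, 3])
avg_update!(state, [4, 5])
println(state[1])            # (15, 5)

# The byte size of one state, as reported by the state size function:
println(sizeof(X))           # 16 on a 64-bit system (two Int64 fields)
```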
Feedback welcome
Let me know what you think and if you have any suggestions or improvement ideas regarding the Julia Interface/API. Below are two usage examples.
Current Issues:
Memory management might be a problem. If someone uses reference types, the state objects could get garbage collected while DuckDB still holds a pointer to them. I currently work around this with a global dictionary that stores a (pointer, Ref{state}) pair for each state, which keeps it from being garbage collected. Maybe there is a better way to do this.
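A rough sketch of what such a registry could look like. `STATE_ROOTS`, `root_state!`, `release_state!` and `needs_rooting` are hypothetical names, and `Libc.malloc` only stands in for the state memory DuckDB would allocate. The `isbitstype` check shows which states actually need rooting: a `Tuple{Int, Int}` is plain bytes, while a `Tuple{Matrix{Float64}, Vector{Float64}}` only holds pointers to heap-allocated arrays.

```julia
# Hypothetical GC-root registry (names are made up for illustration).
# DuckDB only stores raw bytes, so a state containing Julia references must stay
# reachable from Julia until DuckDB no longer uses its memory.
const STATE_ROOTS = Dict{Ptr{Cvoid}, Any}()
const ROOTS_LOCK = ReentrantLock()   # callbacks may run on several DuckDB threads

needs_rooting(::Type{T}) where {T} = !isbitstype(T)

function root_state!(ptr::Ptr{Cvoid}, state)
    lock(ROOTS_LOCK) do
        # Mirrors the (pointer, Ref{state}) pair; the Dict entry keeps it reachable.
        STATE_ROOTS[ptr] = Ref(state)
    end
    return ptr
end

function release_state!(ptr::Ptr{Cvoid})
    lock(ROOTS_LOCK) do
        delete!(STATE_ROOTS, ptr)    # the state becomes collectable again
    end
    return nothing
end

X = Tuple{Matrix{Float64}, Vector{Float64}}
println(needs_rooting(Tuple{Int, Int}))  # false
println(needs_rooting(X))                # true

ptr = Libc.malloc(sizeof(X))             # stand-in for DuckDB-owned state memory
root_state!(ptr, (zeros(2, 2), zeros(2)))
println(length(STATE_ROOTS))             # 1
release_state!(ptr)                      # e.g. from the optional cleanup callback
Libc.free(ptr)
println(length(STATE_ROOTS))             # 0
```

Pairing `release_state!` with the optional `cleanup(states)` callback would bound the registry's lifetime to DuckDB's use of the state, though that is an assumption about how the proof of concept could be wired up.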
Examples
Here is a basic example that calculates the average of a column:
```julia
using DuckDB, DataFrames, Test

# Calculate the average of a column
U = (Int,)             # Input
X = Tuple{Int, Int}    # State = (sum, count)
Y = Float64            # Output

init_function = () -> (0, 0)

update_function = function (x, u)
    _sum, _count = x[1]
    _sum += sum(u)
    _count += length(u)
    x[1] = (_sum, _count)
    return nothing
end

# Merge two partial (sum, count) states element-wise (not tested).
# Tuples have no `+` method, so `x1 .+= x2` would throw a MethodError.
combine_function = (x1, x2) -> (x1[1] = x1[1] .+ x2[1]; nothing)

finalize_function = (y, x) -> (y .= [xi[1] / xi[2] for xi in x]; nothing)

fun = DuckDB.AggregateFunction("my_avg", U, Y, X, init_function, update_function, combine_function, finalize_function)

N = 10_000_000
df = DataFrame(a = 1:N, b = 1:N, c = rand(N), d = rand(N))

db = DuckDB.DB()
con = DuckDB.connect(db)
DuckDB.register_table(con, df, "test1")
DuckDB.register_aggregate_function(con, fun)

result = DuckDB.execute(con, "SELECT my_avg(a) as result FROM test1") |> DataFrame
@show result
@test result.result[1] == sum(df.a) / length(df.a)
@info "Agg Func Stats" fun.stats
```
Another example, which computes a recursive least squares (RLS) parameter estimate with a forgetting factor over a dataset in DuckDB:
```julia
using DuckDB, DataFrames, LinearAlgebra, Test

U = (Float64, Float64, Float64)               # Input: (y, u1, u2)
X = Tuple{Matrix{Float64}, Vector{Float64}}   # State: (covariance matrix P, parameter vector θ)
Y = Tuple{Float64, Float64}                   # Output

λ = 0.999                                     # Forgetting factor
P0 = 1e6 .* [1.0 0.0; 0.0 1.0]
x0 = (P0, zeros(2)) # TODO Decide how to track the state, so it doesn't get garbage collected

rls_init_function = () -> (copy(P0), zeros(2))

rls_update_function = function (x, y, u1, u2)
    P, θ = x[1]                               # P and θ are updated in place
    for (yi, ui1, ui2) in zip(y, u1, u2)
        ϕ = [ui1, ui2]
        k = P * ϕ / (λ + ϕ' * P * ϕ)          # gain vector
        P .= (P - k * ϕ' * P) / λ             # covariance update
        θ .= θ + k * (yi - dot(ϕ, θ))         # parameter update
    end
    return nothing
end

# Needed for multi-threading? RLS states can't simply be merged.
# (`Exception` is abstract and can't be constructed, so use `error`.)
combine_function = (x1, x2) -> error("combine is not supported for RLS")

finalize_function = function (y, x)
    @info "Finalize RLS" y x
    P, θ = x[1]
    y[1] = (θ[1], θ[2]) # we only support scalars for now
    return nothing
end

fun = DuckDB.AggregateFunction(
    "r_least_squares",
    U,
    Y,
    X,
    rls_init_function,
    rls_update_function,
    combine_function,
    finalize_function
)

# Dummy data
N = 1_000
theta_true = [0.4, 0.2]
u = randn(N, 2)
y = u * theta_true
df = DataFrame(y = y, u1 = u[:, 1], u2 = u[:, 2])

db = DuckDB.DB()
con = DuckDB.connect(db)
DuckDB.register_table(con, df, "test1")
DuckDB.register_aggregate_function(con, fun)

result = DuckDB.execute(con, "SELECT r_least_squares(y,u1,u2) as result FROM test1") |> DataFrame
@info "Agg Func Stats" fun.stats

# Check if the result is close to the true value
# (isapprox has no method for tuples, so compare as vectors)
@test collect(result.result[1]) ≈ theta_true atol=1e-2
@show result
# Prints something like:
# Row │ result
#     │ Tuple…
# ─────┼────────────
#   1 │ (0.4, 0.2)
```
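Regarding the combine question: one possible alternative, not the approach used above, is to aggregate the normal-equation sums A = Σ ϕϕᵀ and b = Σ ϕ·y instead of the RLS state. These sums simply add up across chunks and threads, so `combine` becomes an addition and `finalize` solves A θ = b. The trade-off is that the forgetting factor λ is dropped, since it depends on row order, which is not something the partial states of a multi-threaded aggregate preserve. The `ls_*` names are hypothetical and the sketch runs without DuckDB, splitting the rows into two partial states by hand.

```julia
using LinearAlgebra

# Combinable least-squares state (a sketch, not part of the proof of concept):
# (A, b) = (Σ ϕϕᵀ, Σ ϕ y), both of which add up across chunks/threads.
ls_init() = (zeros(2, 2), zeros(2))

function ls_update!(x, y, u1, u2)
    A, b = x[1]
    for (yi, ui1, ui2) in zip(y, u1, u2)
        ϕ = [ui1, ui2]
        A .+= ϕ * ϕ'      # accumulate the Gram matrix
        b .+= ϕ .* yi     # accumulate the cross term
    end
    return nothing
end

function ls_combine!(x1, x2)
    A1, b1 = x1[1]
    A2, b2 = x2[1]
    A1 .+= A2
    b1 .+= b2
    return nothing
end

# Solve the normal equations A θ = b and return θ as a tuple.
ls_finalize(x) = Tuple(x[1][1] \ x[1][2])

N = 1_000
theta_true = [0.4, 0.2]
u = randn(N, 2)
y = u * theta_true

# Two partial states, as if the rows had been split across two threads.
s1, s2 = [ls_init()], [ls_init()]
ls_update!(s1, y[1:500], u[1:500, 1], u[1:500, 2])
ls_update!(s2, y[501:end], u[501:end, 1], u[501:end, 2])
ls_combine!(s1, s2)
θ = ls_finalize(s1)
println(θ)   # should be close to (0.4, 0.2) for this noise-free data
```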