This is an example of an aggregate UDF to compute a mean. I think a lot of this code could benefit from a macro, similar to the one @tqml wrote for a scalar UDF.
I do not think there is a need to hold a reference to the `State` object to protect it from the Julia GC, since DuckDB will allocate that memory itself and pass a pointer. Instead, you need to define a function that returns the size of the `State` structure, and then `unsafe_load` the memory at that pointer.
When updating the `State`, I am not sure how best to modify the state through the C pointer while avoiding the extra allocation and the `unsafe_store!()` calls. Perhaps someone has a suggestion on how that can be done more efficiently.
code
import DuckDB
# Alias for the raw aggregate-state pointer passed by the DuckDB C API.
# Declared `const` so the global binding has a concrete type — non-const
# globals are untyped and type-unstable in Julia.
const duckdb_aggregate_state = Ptr{Cvoid} # TODO: define in DuckDB ctypes.jl
# Running aggregate state for the mean: the accumulated sum of the valid
# inputs seen so far and how many of them there were.
mutable struct State
    sum::Float64
    count::Int64
    # A freshly initialized state has seen no rows.
    State() = new(0.0, 0)
end
# Size callback: DuckDB allocates the per-group state memory itself, so it
# only needs to know how many bytes one `State` occupies.
state_size(info::DuckDB.duckdb_function_info)::DuckDB.idx_t = sizeof(State)
# Init callback: construct a zeroed `State` and write it into the
# DuckDB-allocated slot that `x` points to.
state_init = function (
    info::DuckDB.duckdb_function_info, x::duckdb_aggregate_state)
    slot = convert(Ptr{State}, x)
    Base.unsafe_store!(slot, State())
    return nothing
end
# Update callback: for each row of the input chunk, fold the row's value into
# that row's aggregate state, skipping NULLs via the validity mask.
update_function = function (info::DuckDB.duckdb_function_info,
input::DuckDB.duckdb_data_chunk, states::Ptr{duckdb_aggregate_state})
# Wrap the raw chunk handle; the `false` flag presumably means "do not take
# ownership / do not destroy on finalize" — TODO confirm against DuckDB.DataChunk.
input = DuckDB.DataChunk(input, false)
row_count = DuckDB.get_size(input)
# Single-argument aggregate: only column 1 of the chunk is read.
input_vector = DuckDB.get_vector(input, 1)
input_data = DuckDB.get_array(input_vector, Float64, row_count)
input_validity = DuckDB.get_validity(input_vector, row_count)
# `states` points to an array with one state pointer per input row.
states = Base.unsafe_convert(Ptr{Ptr{State}}, states)
for i in 1:row_count
# NULL inputs are skipped — they contribute to neither sum nor count.
if DuckDB.isvalid(input_validity, i) == 1
p = Base.unsafe_load(states, i)
# NOTE(review): `State` is a mutable struct, so it is worth confirming whether
# unsafe_load/unsafe_store! here copy the raw struct bytes or traffic in a boxed
# object reference — this affects both the extra allocation the issue text asks
# about and whether the Julia GC can collect the state out from under DuckDB.
state = Base.unsafe_load(p)
state.sum += input_data[i]
state.count += 1
Base.unsafe_store!(p, state)
end
end
return nothing
end
# Combine callback: DuckDB passes `count` pairs of states and expects a
# pairwise merge (source[i] folded into target[i]). The original version
# merged only the first pair, silently dropping data whenever more than one
# state is combined at once (e.g. during parallel aggregation).
combine_function = function (
    info::DuckDB.duckdb_function_info, source::Ptr{duckdb_aggregate_state},
    target::Ptr{duckdb_aggregate_state}, count::DuckDB.idx_t)
    sources = Base.unsafe_convert(Ptr{Ptr{State}}, source)
    targets = Base.unsafe_convert(Ptr{Ptr{State}}, target)
    for i in 1:count
        src = Base.unsafe_load(Base.unsafe_load(sources, i))
        tp = Base.unsafe_load(targets, i)
        tgt = Base.unsafe_load(tp)
        # Sums and counts are additive, so merging is a component-wise add.
        tgt.sum += src.sum
        tgt.count += src.count
        Base.unsafe_store!(tp, tgt)
    end
    return nothing
end
# Finalize callback: emit one result (the mean) per state into the result
# vector, starting at `offset`.
finalize_function = function (info::DuckDB.duckdb_function_info,
    source::Ptr{duckdb_aggregate_state}, result::DuckDB.duckdb_vector,
    count::DuckDB.idx_t, offset::DuckDB.idx_t)
    states = Base.unsafe_wrap(
        Array, Base.unsafe_convert(Ptr{Ptr{State}}, source), count)
    # Wrap `offset + count` elements: results are written at 1-based indices
    # offset+1 .. offset+count, so wrapping only `count` elements (as the
    # original did) goes out of bounds whenever offset > 0 — which happens
    # for windowed aggregates.
    result = DuckDB.get_array(DuckDB.Vec(result), Float64, offset + count)
    for i in 1:count
        state = Base.unsafe_load(states[i])
        # NOTE(review): a state with count == 0 (all-NULL group) yields NaN
        # here; confirm whether DuckDB expects NULL to be set instead.
        result[offset + i] = state.sum / state.count
    end
    return nothing
end
#
# Open an in-memory database and turn each Julia callback into a C function
# pointer with the exact signature the DuckDB C API expects for aggregates.
db = DuckDB.DB()
# idx_t state_size(duckdb_function_info)
state_size_cfunction = @cfunction(state_size, DuckDB.idx_t,
(DuckDB.duckdb_function_info,))
# void state_init(duckdb_function_info, duckdb_aggregate_state)
state_init_cfunction = @cfunction(state_init, Cvoid,
(DuckDB.duckdb_function_info, duckdb_aggregate_state))
# void update(duckdb_function_info, duckdb_data_chunk, duckdb_aggregate_state*)
update_cfunction = @cfunction(update_function, Cvoid,
(DuckDB.duckdb_function_info, DuckDB.duckdb_data_chunk,
Ptr{duckdb_aggregate_state}))
# void combine(duckdb_function_info, duckdb_aggregate_state*,
#              duckdb_aggregate_state*, idx_t)
combine_cfunction = @cfunction(combine_function, Cvoid,
(DuckDB.duckdb_function_info, Ptr{duckdb_aggregate_state},
Ptr{duckdb_aggregate_state}, DuckDB.idx_t))
# void finalize(duckdb_function_info, duckdb_aggregate_state*, duckdb_vector,
#               idx_t count, idx_t offset)
finalize_cfunction = @cfunction(finalize_function, Cvoid,
(DuckDB.duckdb_function_info, Ptr{duckdb_aggregate_state},
DuckDB.duckdb_vector, DuckDB.idx_t, DuckDB.idx_t))
# create an aggregate function handle and name it "my_mean"
f = DuckDB.duckdb_create_aggregate_function()
DuckDB.duckdb_aggregate_function_set_name(f, "my_mean")
# add parameter and return type: DOUBLE -> DOUBLE; the logical type handle is
# copied by the add/set calls, so it is destroyed immediately afterwards
type = DuckDB.duckdb_create_logical_type(DuckDB.DUCKDB_TYPE_DOUBLE)
DuckDB.duckdb_aggregate_function_add_parameter(f, type)
DuckDB.duckdb_aggregate_function_set_return_type(f, type)
DuckDB.duckdb_destroy_logical_type(type)
# set aggregate functions: wire up the five C callbacks defined above
DuckDB.duckdb_aggregate_function_set_functions(
f, state_size_cfunction, state_init_cfunction,
update_cfunction, combine_cfunction, finalize_cfunction)
# register on the connection, then release the local handle (the database
# keeps its own copy once registered)
DuckDB.duckdb_register_aggregate_function(db.main_connection.handle, f)
DuckDB.duckdb_destroy_aggregate_function(f)
# aggregate: plain GROUP-less aggregate — built-in mean vs. my_mean should agree
DuckDB.query(db, "select mean(i) from generate_series(100000) t(i)")
DuckDB.query(db, "select my_mean(i) from generate_series(100000) t(i)")
# window aggregate: exercises the finalize offset path (sliding 3-row window)
DuckDB.query(db,
"select i, mean(i) over (order by i ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) from generate_series(100000) t(i)")
DuckDB.query(db,
"select i, my_mean(i) over (order by i ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) from generate_series(100000) t(i)")
# list_aggregate: calls the UDF by name over small lists (with NULLs from lag
# at the start of the series, exercising the validity-mask branch)
DuckDB.query(db,
"select i, list_aggregate([i, lag(i, 1) over w, lag(i, 2) over w], 'mean') m3 from generate_series(1000000) t(i) window w as (order by i)")
DuckDB.query(db,
"select i, list_aggregate([i, lag(i, 1) over w, lag(i, 2) over w], 'my_mean') m3 from generate_series(1000000) t(i) window w as (order by i)")