In many packages we deal with target functions, in linear or log-space, that then become the object of optimization, statistical density sampling (especially in Bayesian statistics), machine learning, and so on.
Currently, packages have different conventions/APIs for users to implement such target functions - often just a plain function returning a value (trusting the user to know whether that should be a linear or a log value). If there is a gradient, returning a tuple of value and gradient is also common, but not universal.
In particular, when using plain values, there’s always the danger that a less experienced user may return a linear value instead of the log-density value that a density sampler expects, etc.
I’d like to propose a very simple convention that packages could support (in addition to their own way of doing things) to let users/packages express their intentions in a common way, without taking on any dependencies, by using the power of `NamedTuple`s.
This isn’t really special in any way in itself - but it might prove quite powerful if it became a convention that many packages support.
Proposal: Target functions return a `NamedTuple`
```julia
(linval = ...,)
```
or
```julia
(logval = ...,)
```
to indicate whether the result is in lin- or log-space. If the function computes a custom gradient (no auto-diff), it returns
```julia
(linval = ..., grad_linval = [...])
```
or
```julia
(logval = ..., grad_logval = [...])
```
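As a concrete illustration, here is a toy target following this convention - the function name and the standard-normal density are made-up examples, not part of the proposal:

```julia
# Toy target: standard normal log-density with a hand-coded gradient,
# returned as a NamedTuple following the proposed convention.
function my_target(x::AbstractVector{<:Real})
    logd = -0.5 * sum(abs2, x)   # log-density, up to an additive constant
    return (logval = logd, grad_logval = -x)
end
```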
This is up for discussion, of course - the field name could be, e.g., `logval`, `logd` (for log-density), `logresult`, etc. - it should just be short, clear, and a convention we decide to adopt.
The algorithm evaluating the target function can then either complain (“you need to give me the log of the density here!”) or convert as appropriate.
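A minimal sketch of what that check/conversion could look like on the algorithm side (`logvalof` is a hypothetical helper name, not an existing API):

```julia
# Hypothetical helper: extract a log-space value from a result NamedTuple,
# converting from linear space if necessary, complaining otherwise.
function logvalof(result::NamedTuple)
    haskey(result, :logval) && return result.logval
    haskey(result, :linval) && return log(result.linval)
    throw(ArgumentError("target function result has neither logval nor linval"))
end
```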
This would also be implicitly compatible with the `(value, gradient)` tuple that is commonly used in many packages, if they do something like `value, gradient = result` internally.
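That works because a `NamedTuple` iterates over its values, so existing destructuring code keeps working:

```julia
result = (logval = -1.2, grad_logval = [0.5, -0.3])
value, gradient = result   # destructures by position, just like a plain tuple
```

(Of course, this positional fallback only works when the fields come in the expected order.)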
The nice thing about using `NamedTuple`s would be that it’s easily extensible. For example, in expensive calculations (e.g. complex likelihood functions), a user/application may often need to keep track of intermediate/auxiliary results - sometimes for debugging and verification, sometimes also as an important part of later analysis.
So for a use case like MCMC sampling, we would allow the target function to return
```julia
(logval = ..., aux = (my_intermediate_sum = ..., some_other_value = ...))
```
or
```julia
(logval = ..., grad_logval = [...], aux = (my_intermediate_sum = ..., some_other_value = ...))
```
or maybe even something like (no requirement to use `aux` explicitly)
```julia
(logval = ..., my_intermediate_sum = ..., some_other_value = ...)
```
The sampler, optimizer, etc. would then pass this through to its output (assuming the algorithm’s output structure provides a way to store auxiliary information).
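For illustration, a toy likelihood that records an intermediate quantity this way (all names here are made up):

```julia
# Toy likelihood that exposes an intermediate result alongside the log-value.
function my_likelihood(x::AbstractVector{<:Real})
    s = sum(abs2, x)                      # intermediate quantity worth keeping
    return (logval = -0.5 * s, aux = (sum_of_squares = s,))
end
```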
I don’t think we should require a specific order in the `NamedTuple`, long term. With a bit of generated-function magic we can extract what we need, check whether there’s a gradient, and then dispatch in a type-stable fashion (and automatically auto-diff if necessary). We would probably run the target function once first (kind of like broadcast does to determine its result type). We could put together a lightweight package “TargetFunctionConventions.jl” (or so) that packages can use internally, so we don’t have to duplicate that kind of value-extraction/canonicalization code. This would be for algorithm implementations to use internally; the package providing the target function wouldn’t need to depend on it. If there’s interest in this proposal, I’d volunteer to put something lightweight and (at least almost) dependency-free together.
I’d love to hear your thoughts on this.
CC @Tamas_Papp, @cpfiffer, @Kai_Xu, @yebai, @ChrisRackauckas (and I’ll have forgotten many others here, my apologies!)