[ANN] BesselK.jl: an AD-compatible modified second-kind Bessel function

Hey all. Just wanted to share a project that I am just finishing up with some collaborators: BesselK.jl. It is a package for evaluating the second-kind modified Bessel function \mathcal{K}_\nu(x) that has been specifically designed to have accurate ForwardDiff.jl-based derivatives \tfrac{\partial}{\partial \nu} \mathcal{K}_\nu(x). While those derivatives do in some sense exist in closed form, the expressions for them are challenging enough to implement (and would be slow enough to evaluate) that they don't really seem like viable options to us. So we have taken a different approach here.

In order to avoid naming conflicts with SpecialFunctions.besselk, we export adbesselk and adbesselkxv, where adbesselkxv gives x^\nu \mathcal{K}_\nu(x), a specific scaling that comes up in the Matérn covariance function, which was the primary motivation for this work. Here is a very unexciting demo:

using BesselK, ForwardDiff
(v,x) = (1.1, 1.1)
ForwardDiff.derivative(_v->adbesselk(_v, x), v) # accurate to an absolute tolerance of at least 1e-10
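
If you want to sanity-check that claim yourself, a quick way is a central finite difference in \nu (a minimal sketch; the step size is just an illustrative choice):

# rough check of the AD derivative against a central difference in v
h = 1e-6
fd = (adbesselk(v + h, x) - adbesselk(v - h, x)) / (2 * h)
ad = ForwardDiff.derivative(_v -> adbesselk(_v, x), v)
abs(fd - ad) # small, up to finite-difference truncation error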

This was enough of a project that we ended up writing a paper about it. Without getting into too many details here, suffice it to say that a lot of expressions for \mathcal{K}_\nu come after limits have been taken, and for some involved reasons those are not usable in implementations you want to pass AD through. So there was a lot of work to be done for the cases \nu + \tfrac{1}{2} \in \mathbb{Z} and \nu \in \mathbb{Z}, for example.

As an added bonus, if you're okay with giving up the last couple trailing digits, derivatives with respect to the argument are also accurate and SO much faster than the two calls to SpecialFunctions.besselk that give the exact derivative (see the comparison sketch after the list below). Here is a general summary of the package:

  • Fast direct evaluations, sometimes at the cost of the last few digits: the exported function adbesselk actually calls SpecialFunctions.besselk if v isa AbstractFloat, so by default you will always get the most accurate result. But BesselK._besselk(v,x) is significantly faster than SpecialFunctions.besselk(v,x) almost everywhere in the domain, and is never slower. So if you're willing to give up a couple digits, you can see some very significant speed gains.

  • Accurate first and second derivatives with ForwardDiff: This is the whole point of the package, really. The entire implementation of BesselK._besselk, our direct implementation, was designed for this to work well. Higher derivatives, though, are not guaranteed to work, so use at your own risk.

  • Fast derivatives: If you use adaptive finite difference derivatives, like those from the beautiful FiniteDifferences.jl, you'll get something comparably accurate, but it allocates and takes 5+ function calls to SpecialFunctions.besselk, which means more than an order of magnitude more runtime cost than ForwardDiff+BesselK for almost no accuracy gain in most cases. All derivatives with ForwardDiff.jl+BesselK.jl are also non-allocating, so they behave a bit better in multithreaded settings.
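
Here is the comparison sketch promised above. The reference derivative uses the standard recurrence \tfrac{\partial}{\partial x} \mathcal{K}_\nu(x) = -\tfrac{1}{2}\left(\mathcal{K}_{\nu-1}(x) + \mathcal{K}_{\nu+1}(x)\right), which is exactly the two-call route mentioned earlier; the variable names here are mine:

using BesselK, ForwardDiff, SpecialFunctions

(v, x) = (1.1, 1.1)

# one non-allocating dual-number pass through BesselK
dx_ad = ForwardDiff.derivative(_x -> adbesselk(v, _x), x)

# "exact" derivative at the cost of two SpecialFunctions.besselk calls:
# dK_v/dx = -(K_{v-1}(x) + K_{v+1}(x))/2
dx_exact = -(besselk(v - 1, x) + besselk(v + 1, x)) / 2

abs(dx_ad - dx_exact) # agreement up to the last couple digits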

As a closer, here is an example implementation of the Matérn covariance function that is now fully AD-compatible in all three parameters:

using BesselK, LinearAlgebra, SpecialFunctions # adbesselkxv, norm, gamma

# Matern covariance; params = (sg, rho, nu) = (scale, range, smoothness).
function matern(x, y, params)
  (sg, rho, nu) = params
  dist = norm(x - y)
  iszero(dist) && return sg*sg # limiting value at zero distance
  arg = sqrt(2*nu)*dist/rho
  # the arg^nu factor of the Matern function is folded into adbesselkxv
  (sg*sg*(2^(1-nu))/gamma(nu))*adbesselkxv(nu, arg)
end
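
To make "AD-compatible in all three parameters" concrete, here is the kind of call this enables (the point and parameter values below are just made up for illustration):

using ForwardDiff

pt1 = [0.0, 0.0]
pt2 = [0.5, 0.25]
params = [1.0, 1.0, 1.25] # (sg, rho, nu)

# the third component is the smoothness partial, which is the hard one
ForwardDiff.gradient(p -> matern(pt1, pt2, p), params)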

For plenty of people reading this who don’t particularly care about implementation details, the real point is that you can add this to your dependency tree and then forget about it and start fitting smoothness parameters in a way that is composable with whatever fancy method you use to fit Gaussian processes.

If you do use this project for something, though, please cite the paper. This was really a lot of work, and the reason it hasn't been done already is because it is a huge pain. I know it seems silly to cite a paper about one little exported function that is sort of an afterthought in your dependency tree compared to whatever fancy Bayesian package you use to fit stuff, but it really took some R&D and grinding to actually put it together and make it so trivial to drop in and forget about.


Sorry to double tap this thread, but it occurs to me to mention something: there is an existing limitation that more or less means you need to use ForwardDiff.jl for the AD. The problem is that I don't know how to check via the type system whether a forward evaluation or a derivative is being computed, other than checking something like v isa Dual{T,V,N} where{T,V,N}. But maybe somebody else knows and can help here: is there some more general way of checking at the type-system level whether a function call is for an AD derivative or not?

To explain why the code currently does that, there are some special cases where \mathcal{K}_\nu(x) has nice reductions, like when \nu + \tfrac{1}{2} \in \mathbb{Z}, and I want to keep those in there because they are much faster. But those reduced forms are predictably incorrect when you are taking a derivative, so I have a code pattern of

[...]
# Duals (and anything else that isn't a plain float) take the AD-safe branch.
is_ad = !(v isa AbstractFloat)
return is_ad ? more_expensive_version(...) : less_expensive_version(...)
[...]
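
To make the "predictably incorrect" part concrete, here is a hedged, simplified illustration that is not the package's internal code. At \nu = \tfrac{1}{2} there is the classical closed form \mathcal{K}_{1/2}(x) = \sqrt{\pi/(2x)}\, e^{-x}, and \nu has dropped out of the expression entirely, so a ForwardDiff.Dual carrying a \nu-partial that hits such a branch comes back with a zero derivative even though \tfrac{\partial}{\partial \nu} \mathcal{K}_\nu(x) is not zero there:

# NOT package internals -- just the failure mode. This closed form is exact
# and fast at nu = 1/2, but nu no longer appears in it, so any nu-partial
# propagated through this branch is identically zero.
besselk_half(x) = sqrt(pi / (2 * x)) * exp(-x) # K_{1/2}(x)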

I am sure that this code not working out of the box with something like Zygote limits its ability to be seamlessly integrated with fancier GP tools, so if anybody has advice on how to solve that I'd be very interested.


We played with this problem at some point, and my recollection is that the hardest case to handle accurately is when \nu is close (by say 1e-8) to those points. Were you able to deal with that?

We do have a solution for that, although it probably isn't the most efficient thing possible. Our code has a lot of return is_near_int(v+0.5, tol=0.001) ? one_version(...) : another_version(...) in it, because it turns out that the method from the Temme 1975 JCP paper works pretty well in those settings (example), so we really relied on that. It does come at the cost of occasionally missing the most specialized branches, though, like \nu + \tfrac{1}{2} \in \mathbb{Z} for second derivatives with respect to \nu, which is disappointing.
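
(For concreteness, that helper presumably looks something like the following; the name and call signature are from the code above, but the body is my guess:)

# hypothetical body for the helper named above
is_near_int(x; tol=0.001) = abs(x - round(x)) < tol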

The biggest problem we had was that in those limiting cases there were a million little things that were defined in limiting forms that didn't preserve derivative information with respect to \nu, and so we couldn't use any of those expressions. But the method in the Temme paper reduced that problem to a single coefficient that had to be implemented properly near the origin, and that coefficient involved just a handful of univariate functions where we could drop in local expansions around the origin, so that all code branches preserve derivative information in \nu and all the partials are correct. It isn't pretty, but it did simplify the problem to the point where I was able to grind out the implementation.