Autodifferentiation of a recursive function

I am interested in wavelets W_n(t,p) that are generated by convolving a Chebyshev polynomial of the first kind T_n(t) with a complex sinusoid \exp(pt):

W_n(t,p) = T_n(t) * \exp(p t)

Here n = 0,1,2,\ldots, t \in \mathbb{R} and p \in \mathbb{C}. Note that W_n(t,p) is a holomorphic function in p.

For example,

W_0(t,p) = \begin{cases} 0, &t < -1\\ \frac{e^{p+pt}-1}{p}, &t \in [-1,1] \\ \frac{2e^{pt} \sinh(p)}{p}, &t > 1 \end{cases}

It turns out that for n \geq 2 these wavelets satisfy the following recursion:

W_n(t,p) = 2t W_{n-1}(t,p) - 2\frac{\partial W_{n-1}(t,p)}{\partial p} - W_{n-2}(t,p)
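
A short derivation of this recursion, as a sketch. It assumes that the convolution is taken with the causal exponential and that T_n is restricted to [-1,1], which is what the W_0 example above implies:

W_n(t,p) = \int_{-1}^{\min(t,1)} T_n(\tau) \, e^{p(t-\tau)} \, d\tau

\frac{\partial W_{n-1}(t,p)}{\partial p} = \int_{-1}^{\min(t,1)} (t-\tau) \, T_{n-1}(\tau) \, e^{p(t-\tau)} \, d\tau = t W_{n-1}(t,p) - \int_{-1}^{\min(t,1)} \tau \, T_{n-1}(\tau) \, e^{p(t-\tau)} \, d\tau

Substituting the Chebyshev recurrence T_n(\tau) = 2\tau T_{n-1}(\tau) - T_{n-2}(\tau) into the first integral and eliminating the \tau T_{n-1} term with the second line gives

W_n(t,p) = 2 \int_{-1}^{\min(t,1)} \tau \, T_{n-1}(\tau) \, e^{p(t-\tau)} \, d\tau - W_{n-2}(t,p) = 2t W_{n-1}(t,p) - 2\frac{\partial W_{n-1}(t,p)}{\partial p} - W_{n-2}(t,p)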

I would like to implement this recursion in Julia by using automatic differentiation (AD) to calculate the \frac{\partial W_{n-1}(t,p)}{\partial p} term, but I am not sure this is practically feasible. I would like to evaluate W_n(t,p) up to n \sim 100.

Here’s a naive implementation using ForwardDiff.jl which only accepts real p for simplicity.

using ForwardDiff: derivative

function W(n, t, p)
    n == 0 && return W₀(t, p)
    n == 1 && return W₁(t, p)
    2t*W(n-1, t, p) - 2∂W(n-1, t, p) - W(n-2, t, p)
end

function ∂W(n, t, p)
    f(q) = W(n, t, q)
    derivative(f, p)
end

function W₀(t, p) # Implements the example equation $W_0$ above
    t < -1 && return complex(0.)
    t ≤ 1 && return (exp(p+p*t)-1)/p
    2*exp(p*t)*sinh(p)/p
end

function W₁(t, p)
    t < -1 && return complex(0.)
    t ≤ 1 && return -(1+exp(p+p*t)*(p-1)+p*t)/p^2
    -2*exp(p*t)*(p*cosh(p)-sinh(p))/p^2
end

The problem is that this implementation loses accuracy very fast to the point of producing O(1) errors at n=8. At n=9 it actually returns -Inf (the correct answer is 32.0746).

julia> t = .5; p = -.1;
julia> W(2, t, p), W(3, t, p), W(4, t, p), W(5, t, p), W(6, t, p), W(7, t, p), W(8, t, p), W(9, t, p)
(-0.7221760345050515, 0.1500347620713356, 0.15994672950485722, 0.18399848059162338, -0.036349924464325056, -0.09898354246626184, 1.7540793533883656, -Inf)
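
One plausible contributor to the accuracy loss (an assumption, not something established in this thread) is cancellation in the closed-form base cases: for small |p| the numerators are differences of nearly equal numbers divided by powers of p, and every additional p-derivative makes that cancellation more severe. The sketch below only illustrates the effect on the t ∈ [-1,1] branch of W_0, using `expm1` from Base. The names `W0_naive` and `W0_stable` are hypothetical, and this does not by itself explain the -Inf at n=9.

```julia
# Hypothetical helper names; only Base functions are used.
W0_naive(t, p)  = (exp(p + p*t) - 1) / p   # formula from the post, branch t ∈ [-1, 1]
W0_stable(t, p) = expm1(p * (1 + t)) / p   # same quantity without the explicit "exp(...) - 1"

t = 0.5
for p in (-1e-1, -1e-8, -1e-12)
    # As p → 0 both expressions should approach 1 + t.
    println("p = ", p, "   naive = ", W0_naive(t, p), "   expm1 = ", W0_stable(t, p))
end
```

The same kind of rewrite is harder for W₁ and for the derivatives that AD takes of these formulas, which is where the loss of digits would be expected to grow.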

In addition, the compilation times increase very rapidly with n. I think this is caused by ForwardDiff.jl calculating partial derivatives of very high orders due to the recursion, and Julia compiling a specialized version of W(n, t, p) for each order.

I also tried Zygote.jl but this gives me a missing adjoint error. This post is already very long, so I decided to only illustrate the problem with ForwardDiff.jl.

In any case I suspect that the underlying problem might be independent of the AD engine used. These higher order partial derivatives do indeed appear when one differentiates the recursive relation above with respect to p, but I had a hunch that AD could somehow avoid calculating them. Am I wrong?

What problem do you want us to look at first: accuracy or performance?

I think improving performance is the easiest to start out with.

As I said, I think the key is in avoiding calculation of the higher order partial derivatives by reusing derivative information as n increases, but I am not sure at all if this is possible with AD. (It is possible in principle with finite differences.)

The short answer is that this recursion is a really bad method to implement. For n=100, this will require approximately ϕ^100 ≈ 8*10^20 recursive calls. You would get much better results by just numerically computing the convolution.
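
The ϕ^n figure corresponds to a plain two-term, Fibonacci-style recursion. In the code as posted, ∂W(n-1, t, p) evaluates W(n-1, …) a second time, so the call count may grow even faster. The sketch below just counts calls under that assumption; it ignores the extra cost of the nested dual numbers, and `naive_calls` and `fib_calls` are hypothetical names.

```julia
# Evaluations of W triggered by one call W(n, t, p) in the posted code, assuming
# ∂W(n-1, …) re-runs W(n-1, …) once: c(n) = 1 + 2c(n-1) + c(n-2), c(0) = c(1) = 1.
function naive_calls(n)
    prev2 = prev1 = big(1)
    for _ in 2:n
        prev2, prev1 = prev1, 1 + 2prev1 + prev2
    end
    return prev1
end

# Same count for a plain Fibonacci-style recursion: c(n) = 1 + c(n-1) + c(n-2).
function fib_calls(n)
    prev2 = prev1 = big(1)
    for _ in 2:n
        prev2, prev1 = prev1, 1 + prev1 + prev2
    end
    return prev1
end

println("Fibonacci-style, n = 100: ", Float64(fib_calls(100)))
println("posted code,     n = 100: ", Float64(naive_calls(100)))
```

Either way the count is exponential in n, which is the point of the answer above.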

So it does not make sense for me to investigate further?

Profiler shows

    n == 1 && return W₁(t, p)

as one of the main culprits…

Thank you for your answer. The code given above is a naive implementation, so I’m not sure if this recursion is such a bad approach. (But please correct me if I’m wrong.)

This is because in practice I never need to calculate W_n(t,p) for an isolated value of n, but always a list \{W_n(t,p)\}_{n=0}^N of them. The recursive relationship would enable me to construct this list sequentially.

For example, using finite differences to construct this list is possible in O(N), but might be too inaccurate. I’ll implement it and see how it compares.

This is like the classic Fibonacci function and no, it doesn’t make sense to investigate further in recursive form. It’d make more sense to compute iteratively, starting at W_0 and working your way up, remembering the last 2 W values and the previous derivative so that you can always just use them…

W2prev = ...   # W_0
W1prev = ...   # W_1

for i in 2:order
    Wnext = 2*t*W1prev - 2*.... - W2prev
    W2prev = W1prev
    W1prev = Wnext
end

Except you’d have to figure out how to put something in the … slot to deal with the derivative. This is probably doable.
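
One way to handle the derivative slot without AD, offered as a sketch and not as what anyone in this thread had in mind: carry the Maclaurin coefficients of W_n in p instead of a single value, so that the p-derivative becomes a shift of the coefficient vector. Expanding the exponential under the convolution integral for t ∈ [-1,1] gives the coefficients of W_0 and W_1 directly. Everything here is an assumption for illustration: the name `W_list`, the truncation knob `K`, and the restriction to t ∈ [-1,1]. I have not checked how it behaves numerically for n near 100 or for large |p|.

```julia
# Sketch: returns [W_0(t,p), ..., W_N(t,p)] for t in [-1, 1].
# Each W_n is stored as coefficients a[k+1] of p^k; d/dp maps a[k+1] -> (k+1)*a[k+2].
function W_list(N, t, p; K = 40)
    -1 ≤ t ≤ 1 || throw(ArgumentError("this sketch only covers t in [-1, 1]"))
    u = 1 + float(t)
    M = K + N                          # every step n-1 -> n uses up one coefficient
    e = [u]                            # e[k+1] = u^(k+1) / k!
    for k in 1:M
        push!(e, e[end] * u / k)
    end
    prev2 = [e[k+1] / (k + 1) for k in 0:M]                               # W_0
    prev1 = [t * e[k+1] / (k + 1) - u * e[k+1] / (k + 2) for k in 0:M]    # W_1
    out = [evalpoly(p, prev2)]
    N ≥ 1 && push!(out, evalpoly(p, prev1))
    for n in 2:N
        m = length(prev1) - 1          # coefficients that survive this step
        # Coefficient form of W_n = 2t W_{n-1} - 2 ∂W_{n-1}/∂p - W_{n-2}
        next = [2t * prev1[k+1] - 2 * (k + 1) * prev1[k+2] - prev2[k+1] for k in 0:m-1]
        push!(out, evalpoly(p, next))
        prev2, prev1 = prev1, next
    end
    return out
end

Ws = W_list(5, 0.5, -0.1)
println(Ws)
```

For t = 0.5 and p = -0.1 the entries for n = 2 and n = 3 can be compared against the values printed in the original post. The work is O(N(N+K)) arithmetic operations, with no recursion and no nested dual numbers.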

Exactly. My main question to the AD wizards is whether something like this is indeed possible. :slight_smile:

Ah, TCO and CPS for AD. This will be quite entertaining (for some degree of entertaining)…

Well, I think you should be able to come up with a recursive formula for the derivative as well, right? Then you just autodiff the base case and work your way up?

Nope, I guess the problem is that dWn/dp depends on the derivative of dW(n-1)/dp so you’re doing higher order derivatives further and further until it explodes?
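
To make the dependency explicit, here is the recursion differentiated k times with respect to p (a worked step, assuming the recursion from the original post):

\frac{\partial^k W_n(t,p)}{\partial p^k} = 2t \frac{\partial^k W_{n-1}(t,p)}{\partial p^k} - 2\frac{\partial^{k+1} W_{n-1}(t,p)}{\partial p^{k+1}} - \frac{\partial^k W_{n-2}(t,p)}{\partial p^k}

Each step down in n raises the derivative order by at most one, so W_N needs \partial^k W_1 / \partial p^k for k \leq N-1 and \partial^k W_0 / \partial p^k for k \leq N-2. The set of pairs (n,k) with n + k \leq N has O(N^2) entries, so the total work is polynomial once each entry is computed only once, but the base cases really do have to be differentiated to order about N.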

Suppose that you have an infinitely fast autodiff, so we don’t care about inefficiency. It’s still a potentially HUGE numerical stability problem, right? If this is a wavelet with some oscillations, then d^100/dp^100 will have Fourier coefficients that grow like p^100, right?

So, can we help OP or should we let OP figure it out?

That’s a very good point that I will have to clear up first.

My investigations show some things that might be relevant to the numerical stability:

  • The magnitude of W_n(t,p) decreases for increasing n (given that \Re(p) < 0, which is always true)
  • W_n(t,p) always tapers off as e^{pt} for t > 1
  • W_n(t,p) oscillates smoothly in [-1,1] with roughly n zeros

This suggests that the recursion might be well-behaved, but I’m relatively inexperienced with matters of numerical stability. @dlakelan, I would love to hear your thoughts on this, if you have time.

So, can we help OP or should we let OP figure it out?

I’ll think about it some more :slight_smile:. Thank you all very much for your time and help!

It’s been a while since I’ve thought about this kind of math, so let me know if I’m way off base.

You’re taking derivatives with respect to p. p is a complex number and there’s some requirement for Re(p) < 0, but is the magnitude of p less than 1? If so, p^n will approach 0 for large n and I think that’s what will get you stability. But if abs(p) can be 1 or larger, then it seems it will be very badly behaved due to growth?

If abs(p) == 1, like in a Fourier representation, then any numerical error in the calculations can give you abs(p) > 1 and then explosion, so it will be numerically unstable even though mathematically it is supposed to just orbit at abs(p) = 1.

Also, if abs(p) << 1 then you can probably treat d^n/dp^n as very small for large n, and so you might do an approximation with just 4 or 5 terms? In which case perhaps a recursion relationship with truncation would be sufficient?

A classic “solution” for the explosion of computation caused by the Fibonacci recursion is to memoize the function. So perhaps Memoization.jl would let you write the recursive version and still have it be efficient. If I were going to try this, I’d try to write a wanddwdp function that returns both W(t,p) and dW/dp(t,p). Do the base case and then the recursion, then memoize the function, and see if you get anything useful (because this should be pretty easy to write).

I’m having fun… trying it out for you… be back in a few mins.

Haha, see you here then. I’m doing the math analysis.

You give the base case for n == 0, what happens if n == 1?

It is given by the W₁(t, p) function in the code above.

Ah, you want something like the 100th derivative of something via AD, IIUC? That sounds complicated.

Yep. I’m beginning to think it is not possible. By “it” I mean what @dlakelan said previously:

Except you’d have to figure out how to put something in the … slot to deal with the derivative. This is probably doable.

AFAIU, AD allows you with certain interfaces to reuse computation of your function for the 1st and 2nd derivative, but I have yet to see interfaces for n-th derivatives… But, OTOH derivatives are somehow recursive too…
