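I am interested in wavelets $W_n(t,p)$ that are generated by convolving a Chebyshev polynomial of the first kind $T_n(x)$ with a complex sinusoid $\exp(pt)$:

$$W_n(t,p) = \int_{-1}^{1} T_n(x)\,\Theta(t-x)\,e^{p(t-x)}\,dx,$$

where $\Theta$ is the Heaviside step function.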
Here $n = 0,1,2,\ldots$, $t \in \mathbb{R}$ and $p \in \mathbb{C}$. Note that $W_n(t,p)$ is holomorphic in $p$.
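For example,

$$W_0(t,p) = \begin{cases} 0, & t < -1,\\[4pt] \dfrac{e^{p(1+t)} - 1}{p}, & -1 \le t \le 1,\\[4pt] \dfrac{2\,e^{pt}\sinh p}{p}, & t > 1. \end{cases}$$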
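For cross-checking closed forms of $W_n$, a brute-force reference can be handy. The sketch below assumes the convolution is $W_n(t,p) = \int_{-1}^{1} T_n(x)\,\Theta(t-x)\,e^{p(t-x)}\,dx$ (the form that reproduces the closed forms for $W_0$ and $W_1$ used in the code below) and evaluates it with composite Simpson's rule using only Base Julia. The names `chebT`, `simpson` and `W_quad`, and the default of 20 000 subintervals, are my own choices; for large $n$ the step count would need to grow because $T_n$ oscillates faster.

```julia
# Brute-force reference for W_n(t, p), assuming the definition
#   W_n(t, p) = ∫_{-1}^{1} T_n(x) Θ(t - x) exp(p (t - x)) dx.
# Slow, but independent of the recursion and of any AD package.

# Chebyshev polynomial of the first kind via T_n = 2x T_{n-1} - T_{n-2}
function chebT(n, x)
    n == 0 && return one(x)
    Tprev, T = one(x), x
    for _ in 2:n
        Tprev, T = T, 2 * x * T - Tprev
    end
    return T
end

# Composite Simpson's rule on [a, b] with m (even) subintervals
function simpson(f, a, b, m)
    h = (b - a) / m
    s = f(a) + f(b)
    for i in 1:m-1
        s += (isodd(i) ? 4 : 2) * f(a + i * h)
    end
    return s * h / 3
end

function W_quad(n, t, p; m = 20_000)
    t < -1 && return zero(exp(p * t))   # Θ(t - x) vanishes on all of [-1, 1]
    b = min(t, 1.0)                     # Θ(t - x) cuts the range at x = t
    return simpson(x -> chebT(n, x) * exp(p * (t - x)), -1.0, b, m)
end
```

For example, `W_quad(0, 0.5, -0.1)` should agree with the closed-form $W_0$ to many digits, since the integrand is smooth. It also accepts complex `p`.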
It turns out that for $n \geq 2$ these wavelets satisfy the following recursion:

$$W_n(t,p) = 2t\,W_{n-1}(t,p) - 2\,\frac{\partial W_{n-1}(t,p)}{\partial p} - W_{n-2}(t,p).$$
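The recursion follows from the three-term recurrence $T_n(x) = 2x\,T_{n-1}(x) - T_{n-2}(x)$. Writing the convolution as $W_n(t,p) = \int_{-1}^{1} T_n(x)\,\Theta(t-x)\,e^{p(t-x)}\,dx$ (my reading of the definition, with $\Theta$ the Heaviside step), splitting $x = t - (t-x)$ and using $\partial_p e^{p(t-x)} = (t-x)\,e^{p(t-x)}$ under the integral gives

$$
\begin{aligned}
W_n(t,p) &= \int_{-1}^{1} \bigl[2x\,T_{n-1}(x) - T_{n-2}(x)\bigr]\,\Theta(t-x)\,e^{p(t-x)}\,dx \\
&= 2\int_{-1}^{1} \bigl[t - (t-x)\bigr]\,T_{n-1}(x)\,\Theta(t-x)\,e^{p(t-x)}\,dx - W_{n-2}(t,p) \\
&= 2t\,W_{n-1}(t,p) - 2\,\frac{\partial W_{n-1}(t,p)}{\partial p} - W_{n-2}(t,p).
\end{aligned}
$$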
I would like to implement this recursion in Julia by using automatic differentiation (AD) to calculate the $\frac{\partial W_{n-1}(t,p)}{\partial p}$ term, but I am not sure this is practically feasible. I would like to evaluate $W_n(t,p)$ up to $n \sim 100$.
Here's a naive implementation using ForwardDiff.jl, which only accepts real `p` for simplicity:
```julia
using ForwardDiff: derivative

# Recursion W_n = 2t W_{n-1} - 2 ∂W_{n-1}/∂p - W_{n-2} for n ≥ 2
function W(n, t, p)
    n == 0 && return W₀(t, p)
    n == 1 && return W₁(t, p)
    2t*W(n-1, t, p) - 2∂W(n-1, t, p) - W(n-2, t, p)
end

# ∂W_n/∂p via forward-mode AD; each nested call wraps p in another Dual
function ∂W(n, t, p)
    f(q) = W(n, t, q)
    derivative(f, p)
end

function W₀(t, p) # Implements the example equation W₀ above
    t < -1 && return complex(0.)
    t ≤ 1 && return (exp(p+p*t)-1)/p
    2*exp(p*t)*sinh(p)/p
end

function W₁(t, p) # Closed form of W₁, i.e. the convolution of T₁(x) = x
    t < -1 && return complex(0.)
    t ≤ 1 && return -(1+exp(p+p*t)*(p-1)+p*t)/p^2
    -2*exp(p*t)*(p*cosh(p)-sinh(p))/p^2
end
```
The problem is that this implementation loses accuracy very quickly, to the point of producing $O(1)$ errors at $n = 8$. At $n = 9$ it actually returns `-Inf` (the correct answer is $32.0746$).
```julia
julia> t = .5; p = -.1;

julia> W(2, t, p), W(3, t, p), W(4, t, p), W(5, t, p), W(6, t, p), W(7, t, p), W(8, t, p), W(9, t, p)
(-0.7221760345050515, 0.1500347620713356, 0.15994672950485722, 0.18399848059162338, -0.036349924464325056, -0.09898354246626184, 1.7540793533883656, -Inf)
```
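One plausible contributor to the accuracy loss (an assumption on my part, not a diagnosis of ForwardDiff.jl internals): unrolling the recursion, $W_n$ needs $p$-derivatives of the closed forms for $W_0$ and $W_1$ up to order $n-1$, and those closed forms divide by powers of $p$. The Base-only sketch below takes $g(p) = (e^{ap}-1)/p$, the $-1 \le t \le 1$ branch of $W_0$ with $a = 1+t$, and computes $g^{(k)}(p)$ two ways: by the Leibniz rule on the quotient as written, and from the equivalent integral $g^{(k)}(p) = \int_0^a s^k e^{ps}\,ds$, which involves no cancellation. The helper names `dg_leibniz` and `dg_integral` are hypothetical; $a = 1.5$ and $p = -0.1$ match $t = 0.5$ and $p = -0.1$ from the session above.

```julia
# Two ways to get the k-th derivative of g(p) = (exp(a*p) - 1)/p.
# Illustration only: a = 1 + t mirrors the -1 ≤ t ≤ 1 branch of W₀.

# Leibniz rule on (exp(a*p) - 1) * (1/p): exact in exact arithmetic, but the
# individual terms are huge and of mixed sign for moderately small |p|.
function dg_leibniz(k, a, p)
    s = 0.0
    for j in 0:k
        fj = j == 0 ? exp(a*p) - 1 : a^j * exp(a*p)   # j-th derivative of exp(a*p) - 1
        m = k - j
        hm = (-1)^m * factorial(m) / p^(m + 1)         # m-th derivative of 1/p
        s += binomial(k, j) * fj * hm
    end
    return s
end

# Same quantity from g(p) = ∫₀ᵃ exp(p*s) ds, i.e. g⁽ᵏ⁾(p) = ∫₀ᵃ sᵏ exp(p*s) ds,
# evaluated with composite Simpson's rule (m even); no cancellation involved.
function dg_integral(k, a, p; m = 2_000)
    h = a / m
    f(s) = s^k * exp(p * s)
    acc = f(0.0) + f(a)
    for i in 1:m-1
        acc += (isodd(i) ? 4 : 2) * f(i * h)
    end
    return acc * h / 3
end

for k in (1, 4, 8, 12, 16)
    println("k = ", k, ":  Leibniz ", dg_leibniz(k, 1.5, -0.1),
            "  integral ", dg_integral(k, 1.5, -0.1))
end
```

With these parameters the largest Leibniz terms grow roughly like $k!/|p|^{k+1}$ while the true derivative stays modest, so the two columns should drift apart as $k$ increases. How closely nested dual numbers track this depends on how ForwardDiff.jl organizes the arithmetic.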
In addition, compilation time increases very rapidly with $n$. I think this is caused by ForwardDiff.jl calculating partial derivatives of very high order due to the recursion, and Julia compiling a specialized version of `W(n, t, p)` for each order.
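Separately from compilation, the naive recursion also does exponentially growing work: as written, `W(n, t, p)` calls `W(n-1, t, p)`, then `∂W(n-1, t, p)` (which evaluates `W(n-1, t, q)` once more on a dual number), then `W(n-2, t, p)`. The Base-only sketch below, with the hypothetical name `basecalls`, counts how often a base case (`W₀` or `W₁`) is reached under that assumption. The count obeys $E_n = 2E_{n-1} + E_{n-2}$ with $E_0 = E_1 = 1$, so it grows roughly like $(1+\sqrt{2})^n$.

```julia
# Number of base-case evaluations (W₀ or W₁) made by the naive recursion,
# assuming W(n) calls W(n-1), ∂W(n-1) (one more W(n-1) on a dual number) and W(n-2).
# E(0) = E(1) = 1 and E(n) = 2E(n-1) + E(n-2); BigInt avoids overflow for large n.
function basecalls(n)
    n <= 1 && return big(1)
    c_prev, c = big(1), big(1)    # E(n-2), E(n-1) before the first step
    for _ in 2:n
        c_prev, c = c, 2 * c + c_prev
    end
    return c
end

for n in (2, 5, 9, 20, 100)
    println("n = ", n, ": ", basecalls(n), " base-case evaluations")
end
```

`basecalls(9)` returns 1393 and `basecalls(100)` exceeds $10^{37}$, before accounting for the extra cost of each additional level of dual-number nesting.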
I also tried Zygote.jl, but it gives me a missing adjoint error. This post is already very long, so I decided to only illustrate the problem with ForwardDiff.jl.
In any case, I suspect that the underlying problem might be independent of the AD engine used. These higher-order partial derivatives do indeed appear when one differentiates the recursive relation above with respect to $p$, but I had a hunch that AD could somehow avoid calculating them. Am I wrong?
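To make this concrete: since $t$ does not depend on $p$, differentiating the recursion $k$ times with respect to $p$ gives

$$\frac{\partial^k W_n}{\partial p^k} = 2t\,\frac{\partial^k W_{n-1}}{\partial p^k} - 2\,\frac{\partial^{k+1} W_{n-1}}{\partial p^{k+1}} - \frac{\partial^k W_{n-2}}{\partial p^k},$$

so each step down in $n$ raises the derivative order by one, and $W_n$ ultimately depends on $\partial^{n-1} W_1/\partial p^{n-1}$ and $\partial^{n-2} W_0/\partial p^{n-2}$.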