When I translated code from another Python project into Julia, I ran into a problem. After a lot of effort, I finally reduced it to the following two sample programs.
Julia code
using Zygote
using FiniteDifferences

function test(x)
    y = x * exp(im * 1.5)
    return real(y)
end

x0 = 2.3 + 4.5 * im
f, g = withgradient(test, x0)
println("f = ", f)
println("g = ", g)
The result of running this Julia code is:
f = -4.326031875882528
g = (0.0707372016677029 - 0.9974949866040544im,)
Python code
import jax.numpy as jnp
from jax import value_and_grad

def test(x):
    y = x * jnp.exp(1j * 1.5)
    return y.real

x0 = 2.3 + 4.5 * 1j
f, g = value_and_grad(test)(x0)
print("f = ", f)
print("g = ", g)
The result of running this Python code is:
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
f = -4.326032
g = (0.0707372+0.997495j)
I think the code logic in these two examples is exactly the same, but the two derivatives differ by a complex conjugate. I don't know where else the problem could be.
The Python project as a whole runs fine, but there are problems with the Julia version. I have compared my Julia code with the rest of the Python code and it is the same. My debugging suggests that the Julia code in the example above is what is wrong.
I don't understand what is happening with the Julia code in this example.
Expanding f(x) = real(x * exp(im * 1.5)) with x = u + v*im and exp(im * 1.5) = a + b*im gives f = u*a - v*b, where the imaginary part of the constant, b, appears with a negative sign in the real part of the result. Thus the derivative of f with respect to the imaginary part of x seems to be indeed -b.
I think this is the first time I have differentiated a complex function, so take it with a grain of salt…
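A quick numerical check of that expansion (just a sketch; the names u, v, a, b mirror the notation above and are only for illustration):

```julia
# Write x = u + v*im and exp(im*1.5) = a + b*im, so that
# f(x) = real(x * exp(im*1.5)) = u*a - v*b
u, v = 2.3, 4.5
a, b = cos(1.5), sin(1.5)

f_expanded = u * a - v * b   # ≈ -4.326, matches test(x0) above
df_du =  a                   # ≈  0.0707  (derivative w.r.t. the real part of x)
df_dv = -b                   # ≈ -0.9975  (derivative w.r.t. the imaginary part of x)
```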
The solutions are conjugates of each other.
If I remember rightly, whether a gradient should return the Jacobian transpose or the Jacobian adjoint (i.e. conjugate transpose) is a matter of debate/convention.
That's an interesting discussion, but in this particular case I don't see much room for ambiguity. It is simply a fact that increasing the imaginary part of x0 decreases the function value (which is real here). Thus, the derivative must be negative.
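This is easy to confirm with central finite differences along the real and imaginary axes (a sketch; the step size h and the variable names are just for illustration):

```julia
test(x) = real(x * exp(im * 1.5))
x0 = 2.3 + 4.5im
h = 1e-6

# Derivatives with respect to the real and imaginary parts of x0
df_dre = (test(x0 + h)      - test(x0 - h))      / (2h)   # ≈  0.0707
df_dim = (test(x0 + im * h) - test(x0 - im * h)) / (2h)   # ≈ -0.9975, negative as claimed
```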
Well, one can always choose a convention such that the gradient points towards decreases of the function values instead of increases for complex numbers (but please don't).
But normally the discussion around conventions is instead about what the pullback should compute: either J \cdot u or u \cdot J^\dagger (a Jacobian-vector product or a vector-Jacobian product). Typically, there are various practical and theoretical reasons why a reverse-mode AD system should define the pullback \mathcal{B} of a function f at a point v to take a vector u to
\mathcal{B}_{v}(f)(u) = u \cdot \Big(J(f)(v)\Big)^\dagger
where J(f)(v) is the Jacobian of f at v.
The fact that this is done means that if you calculate a gradient or Jacobian from the pullback, it's easy to accidentally generate incorrectly transposed (or conjugated) gradients if you don't think hard about complex number support.
Not sure if that's what went wrong with the above Jax code or whatever, or if they decided to take some perverse convention where gradients point down complex functions, though.
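In Zygote, for instance, the gradient of a real-valued scalar function is (as far as I understand) obtained by applying the pullback to the seed 1.0. A minimal sketch (using the same test and x0 as in the original post) showing that this reproduces the withgradient result:

```julia
using Zygote

test(x) = real(x * exp(im * 1.5))
x0 = 2.3 + 4.5im

# pullback returns the primal value and the reverse-mode closure at x0
y, back = Zygote.pullback(test, x0)

# Seeding the pullback with 1.0 is how the gradient is assembled; this matches
# the (0.0707... - 0.9975...im) value that withgradient returned above.
g = back(1.0)[1]
println("y = ", y, ", g = ", g)
```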
I don't know what the convention is, and I don't understand why they don't use the same one; using different conventions is very confusing. It cannot be that differentiating a complex function has no definite definition.
My derivation tells me that Zygote's derivative was wrong and Jax's was right.
I don't know why the link says that the derivative calculated by Jax is the conjugate. My derivation and actual calculation tell me that the derivative calculated by Jax is the original derivative.
Okay, then this is something you should take up on the Jax forums, not the Julia forums, because Jax is the odd one out here. PyTorch, Zygote, Tapir, ForwardDiff, and Enzyme all agree with each other; it's Jax that's returning the conjugate of the gradient.
Your derivation is incorrect; you can check @lmiq's post to get the right gradient.
Here are the gradients given by Zygote:
and here are the gradients given by Jax:
which you can see are pointing towards decreases in the function's value along the imaginary axis, not increases.
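The plots are not reproduced here, but the same point can be checked directly: stepping along the imaginary component of Zygote's gradient increases f, while stepping along the imaginary component of the conjugated (Jax-style) gradient decreases it (a sketch; the step size ε is arbitrary):

```julia
using Zygote

f(x) = real(x * exp(im * 1.5))
x0 = 2.3 + 4.5im
ε = 1e-3

g_zygote = Zygote.gradient(f, x0)[1]   # 0.0707... - 0.9975...im
g_jax    = conj(g_zygote)              # 0.0707... + 0.9975...im, the value Jax printed

# Move x0 along the imaginary component of each gradient and see how f changes
println(f(x0 + im * ε * imag(g_zygote)) - f(x0))   # > 0: Zygote's arrow points uphill
println(f(x0 + im * ε * imag(g_jax))    - f(x0))   # < 0: Jax's arrow points downhill
```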
Anyway, it seems that the problem in the Julia code must be somewhere else. The gradients there are "correct", meaning they follow what Julia or Jax expects in each case. What kind of problem did you have with the Julia vs the Jax code?