Discrepancy between complex gradients calculated with Zygote.jl and Python's Jax

While translating a Python project into Julia, I ran into a problem. After much effort, I narrowed it down to the following two minimal examples.

Julia code
using Zygote
using FiniteDifferences

function test(x)
    y = x * exp(im * 1.5)   # exp(1.5im) = cos(1.5) + sin(1.5)im
    return real(y)          # real-valued function of a complex input
end

x0 = 2.3 + 4.5im
f, g = withgradient(test, x0)

println("f = ", f)
println("g = ", g)

The result of running this Julia code is

f = -4.326031875882528
g = (0.0707372016677029 - 0.9974949866040544im,)

Python code
import jax.numpy as jnp
from jax import value_and_grad

def test(x):
    y = x * jnp.exp(1j * 1.5)
    return y.real

x0 = 2.3 + 4.5 * 1j
f, g = value_and_grad(test)(x0)

print("f = ", f)
print("g = ", g)

The result of running this Python code is

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
f = -4.326032
g = (0.0707372+0.997495j)

As far as I can tell, the logic of the two examples is exactly the same, but the two gradients differ by a complex conjugate. I don't see where else the problem could be.
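For reference, here is a quick check of that observation, comparing the two printed values (Jax runs in single precision by default, hence the loose tolerance):

g_zygote = 0.0707372016677029 - 0.9974949866040544im   # Zygote's result above
g_jax    = 0.0707372 + 0.997495im                       # Jax's result above
println(isapprox(g_zygote, conj(g_jax); atol = 1e-6))   # true: they are conjugates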

The Python project as a whole works fine, but the Julia version has problems. I have compared my Julia code with the rest of the Python code and they match, so my debugging suggests that the Julia example above is where things go wrong.

I don't understand what is happening in the Julia code in this example.

I'm not sure, but it seems to me that the Zygote gradient is correct here. If you write the constant as a + bi and expand the function, we get

f(x \equiv x_r + i x_i) = \mathrm{Re}\big[(a + bi)(x_r + i x_i)\big] = \mathrm{Re}\big[a x_r + i^2 b x_i + i(b x_r + a x_i)\big] = a x_r - b x_i

where the imaginary part of the constant, b, enters the real part of the result with a negative sign. Thus the derivative of f with respect to the imaginary part of x does indeed seem to be -b.
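Here a = cos(1.5) and b = sin(1.5), since exp(1.5im) = cos(1.5) + sin(1.5)im, so a minimal finite-difference sketch of the same expansion (using the test function from the original post) gives:

test(x) = real(x * exp(im * 1.5))
x0 = 2.3 + 4.5im
h = 1e-6

df_dxr = (test(x0 + h) - test(x0)) / h        # ≈  a =  cos(1.5) ≈  0.0707
df_dxi = (test(x0 + h * im) - test(x0)) / h   # ≈ -b = -sin(1.5) ≈ -0.9975
println(df_dxr, "  ", df_dxi)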

I think this is the first time I have differentiated a complex function, so take this with a grain of salt…


The two results are conjugates of each other.
If I remember rightly, whether a gradient should return the Jacobian transpose or the Jacobian adjoint (i.e. conjugate transpose) is a matter of debate/convention.

In 2020 @Mason did a deep dive on the topic:

including Wirtinger derivatives.
Our convention came out of that discussion.


That's an interesting discussion, but in this particular case I don't see much room for ambiguity. It is just a fact that increasing the imaginary part of x0 decreases the function value (which is real here). Thus, the derivative must be negative.

julia> test(x0 + (0.0 + im*0.01)) - test(x0)
-0.009974949866040639

julia> test(x0 + (0.0 - im*0.01)) - test(x0)
0.00997494986603975

Well, for complex numbers one can always choose a convention such that the gradient points towards decreases of the function value instead of increases (but please don't).

But normally the discussion around conventions is instead about what the pullback should compute: either J \cdot u or u \cdot J^\dagger (Jacobian-vector product or vector-Jacobian product). Typically, there are various practical and theoretical reasons why a reverse-mode AD system should define the pullback \mathcal{B} of a function f at a point v so that it maps a vector u to

\mathcal{B}_{v}(f)(u) = u \cdot \Big(J(f)(v)\Big)^\dagger

where J(f)(v) is the Jacobian of f at v.

The fact that this is done means that if you calculate a gradient or Jacobian from the pullback, it's easy to accidentally generate incorrect, conjugated or transposed gradients if you don't think hard about complex-number support.
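For concreteness, here is a small Zygote sketch of that relationship: for a scalar real output, the gradient is just the pullback seeded with 1.

using Zygote

test(x) = real(x * exp(im * 1.5))
x0 = 2.3 + 4.5im

y, back = Zygote.pullback(test, x0)   # the pullback B of test at x0
g_pb    = first(back(1.0))            # u = 1, i.e. u ⋅ J†
g_grad  = first(Zygote.gradient(test, x0))
println(g_pb ≈ g_grad)                # true; both ≈ 0.0707 - 0.9975im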

Not sure if that's what went wrong with the above Jax code, or if they decided to take some perverse convention where gradients point down complex functions.


Edit: Looks like the Jax people are aware of this, and consider it to be a feature: grad returns complex conjugate of the gradient · Issue #9110 · google/jax · GitHub, to which I must say "yikes".


It's not intuitively obvious, but let's derive it:

\frac{\partial f}{\partial x} = \frac{\partial f}{\partial x_r}\frac{\partial x_r}{\partial x} + \frac{\partial f}{\partial x_i}\frac{\partial x_i}{\partial x} \\
= \frac{\partial f}{\partial x_r}\frac{\partial x_r}{\partial (x_r + i x_i)} + \frac{\partial f}{\partial x_i}\frac{\partial x_i}{\partial (x_r + i x_i)} \\
= \frac{\partial f}{\partial x_r}\times 1 + \frac{\partial f}{\partial x_i}\times\frac{1}{i} \\
= \frac{\partial f}{\partial x_r} - i\frac{\partial f}{\partial x_i} \\
= a + ib

If the above derivation is correct, then it seems obvious that Julia took the wrong derivative.

I don't know what convention each library uses, or why they don't use the same one; using different conventions is very confusing. It cannot be that the derivative of a complex function has no definite definition.

My derivation tells me that Zygote's derivative is wrong and Jax's is right.

I don't know why the linked issue says that the derivative calculated by Jax is the conjugate. My derivation and actual calculation tell me that the derivative calculated by Jax is the actual derivative.

It's a little too technical; I can't understand it. :disappointed_relieved:

Enzyme has some good docs on this subject. I'd recommend giving them a quick read: FAQ · Enzyme.jl

Okay, then this is something you should take up on the Jax forums, not the Julia forums, because Jax is the odd one out here. PyTorch, Zygote, Tapir, ForwardDiff, and Enzyme all agree with each other; it's Jax that's returning the conjugate of the gradient.
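One way to see this agreement concretely (a sketch that treats the same test function as a map ℝ² → ℝ over (x_r, x_i), since ForwardDiff only differentiates with respect to real inputs):

using ForwardDiff, Zygote

test(x) = real(x * exp(im * 1.5))

f_r2(v) = test(v[1] + im * v[2])                 # same function on (x_r, x_i)
g_r2 = ForwardDiff.gradient(f_r2, [2.3, 4.5])    # [∂f/∂x_r, ∂f/∂x_i]
g_zy = first(Zygote.gradient(test, 2.3 + 4.5im))

# Zygote packages the real gradient as ∂f/∂x_r + i ∂f/∂x_i (steepest ascent);
# Jax's printed value is the conjugate of this.
println(g_r2[1] + im * g_r2[2] ≈ g_zy)           # true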

Your derivation is incorrect; you can check @lmiq's post to get the right gradient.

Here are the gradients given by Zygote:

[image: plot of the gradient field returned by Zygote]

and here are the gradients given by Jax:

[image: plot of the gradient field returned by Jax]

which you can see are pointing towards decreases in the function's value along the imaginary axis, not increases.
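A quick sketch of the same comparison in code, using Zygote's value and its conjugate (which is what Jax prints above):

using Zygote

test(x) = real(x * exp(im * 1.5))
x0 = 2.3 + 4.5im

g_zygote  = first(gradient(test, x0))   # 0.0707… - 0.9975…im
g_jaxlike = conj(g_zygote)              # the value Jax reports

println(test(x0 + 0.01im) - test(x0))   # negative: f decreases along +i
println(imag(g_zygote))                 # negative: points towards increasing f
println(imag(g_jaxlike))                # positive: points towards decreasing f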


Anyway, it seems that the problem in your Julia code must be somewhere else. The gradients here are "correct", in the sense that each follows the convention Julia or Jax expects. What kind of problem did you have with the Julia vs. the Jax code?
