Discrepancy between complex gradients calculated with Zygote.jl and Python's JAX

I’m not sure, but it seems to me that the Zygote gradient is correct here. If we write the constant as $a + bi$ and expand the function, we get

$$
f(x) = \operatorname{Re}\big[(a + bi)(x_r + i x_i)\big] = \operatorname{Re}\big[a x_r + i^2 b x_i + i(b x_r + a x_i)\big] = a x_r - b x_i, \qquad x \equiv x_r + i x_i
$$

where the imaginary part of the constant, $b$, appears with a negative sign in the real part of the result. Thus the derivative of $f$ with respect to the imaginary part of $x$ does indeed seem to be $-b$ (and with respect to the real part, $a$).
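As a quick sanity check of the algebra, here is a minimal Python sketch that estimates both partial derivatives of $f(x) = \operatorname{Re}[c\,x]$ by central finite differences, independently of any AD library. The constant `c = 2 + 3j` (so $a = 2$, $b = 3$), the evaluation point `x0`, and the step `h` are arbitrary values chosen for illustration. Since $f$ is linear in $x_r$ and $x_i$, the estimates should match $a$ and $-b$ up to rounding.

```python
# Finite-difference check of the partial derivatives of f(x) = Re[c * x]
# with respect to the real and imaginary parts of x.
# c, x0 and h below are arbitrary illustrative values.

def f(x: complex, c: complex) -> float:
    """Real part of c * x, the function expanded above."""
    return (c * x).real


def partials(c: complex, x0: complex, h: float = 1e-6) -> tuple[float, float]:
    """Central differences of f along the real and imaginary axes of x."""
    # Perturb only the real part of x.
    d_real = (f(x0 + h, c) - f(x0 - h, c)) / (2 * h)
    # Perturb only the imaginary part of x.
    d_imag = (f(x0 + 1j * h, c) - f(x0 - 1j * h, c)) / (2 * h)
    return d_real, d_imag


if __name__ == "__main__":
    c = 2.0 + 3.0j  # a = 2, b = 3
    x0 = 0.5 - 1.25j
    d_real, d_imag = partials(c, x0)
    print(f"df/dx_r ~ {d_real:.6f}  (a  = {c.real})")
    print(f"df/dx_i ~ {d_imag:.6f}  (-b = {-c.imag})")
```

Running it should print values very close to $2$ and $-3$, i.e. $a$ and $-b$, consistent with the expansion.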

I think this is the first time I have differentiated a complex function, so take this with a grain of salt…
