# Jacobian of a network in the loss function with Flux

Dear all,
For a SciML problem in quantum mechanics (strongly inspired by what I saw during @ChrisRackauckas's talk in Grenoble), I need to differentiate a loss function that contains the derivative of the network. Here is a minimal example of the kind of problem I want to solve. In this very crude example, I try to train a network to represent a complex function. The network has two inputs (the real and imaginary parts of a complex time) and two outputs (the real and imaginary parts of the value of the function I want to represent).

The loss is a linear combination of two terms: a comparison between the outputs and the values of the true function, and a comparison between the complex derivative of the network and the true derivative.
The complex derivative is constructed from the Jacobian of the network.
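For reference, with z = x + iy and f = u + iv, the complex (Wirtinger) derivative can be read off the real 2×2 Jacobian J = [∂u/∂x ∂u/∂y; ∂v/∂x ∂v/∂y] as ∂f/∂z = ½[(J₁₁ + J₂₂) + i(J₂₁ − J₁₂)], which reduces to f′(z) when f is holomorphic. This is the combination the code below uses. The sketch checks it on f(z) = sin z with an analytic Jacobian instead of the network; the helper names `wirtinger_from_jacobian` and `jacobian_sin` are hypothetical.

``````julia
# Wirtinger derivative ∂f/∂z from the real Jacobian of (x, y) -> (u, v),
# where z = x + im*y, f(z) = u + im*v and
# J = [∂u/∂x  ∂u/∂y;
#      ∂v/∂x  ∂v/∂y]
function wirtinger_from_jacobian(J::AbstractMatrix{<:Real})
    re  = (J[1, 1] + J[2, 2]) / 2   # (u_x + v_y) / 2
    im_ = (J[2, 1] - J[1, 2]) / 2   # (v_x - u_y) / 2
    return complex(re, im_)
end

# Hypothetical helper: analytic Jacobian of sin(x + iy) = sin(x)cosh(y) + i cos(x)sinh(y)
function jacobian_sin(x, y)
    ux = cos(x) * cosh(y);   uy = sin(x) * sinh(y)
    vx = -sin(x) * sinh(y);  vy = cos(x) * cosh(y)
    return [ux uy; vx vy]
end

z = 2.0 + 0.3im
d = wirtinger_from_jacobian(jacobian_sin(real(z), imag(z)))
println(d, "  vs  ", cos(z))   # the two values agree up to rounding
``````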

Here is the code:

``````julia
# Restrict the visible GPUs; this must be set before CUDA is initialized
ENV["CUDA_VISIBLE_DEVICES"] = "0,1"

using Revise
using Flux
using CUDA
using Statistics
using Plots
using Zygote
using ProgressMeter

BACKEND = "CPU"
Flux.gpu_backend!(BACKEND)  # "CPU" or "CUDA"; do not forget to restart Julia after changing it

if BACKEND == "CUDA"
    CUDA.device!(0)
    device = Flux.get_device("CUDA", 0)
    println("Device selected: ", device)
elseif BACKEND == "CPU"
    device = Flux.get_device("CPU")
end

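# Generate training data: z = x + i*y with x in [1, 10] and y in [0, 1]
function generate_data(N)
    x = rand(1.0:0.1:10.0, N)    # real parts in [1, 10]
    y = rand(0.0:0.01:1.0, N)    # imaginary parts in [0, 1]
    f = sin.(x .+ im .* y)       # target values sin(z)
    fprime = cos.(x .+ im .* y)  # target derivatives cos(z)
    # Return 2×N matrices: row 1 = real part, row 2 = imaginary part
    return copy(hcat(x, y)'), copy(hcat(real(f), imag(f))'), copy(hcat(real(fprime), imag(fprime))')
end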

# Define the model
model = Chain(
    Dense(2, 2048, leakyrelu),
    Dense(2048, 2048, leakyrelu),
    # Dense(512, 512, relu),
    # Dense(512, 256, relu),
    Dense(2048, 2)
) |> device

jac = x -> Zygote.jacobian(model, x)

# Kernel to fill the complex derivatives
function fill_deriv!(deriv, Jacob, NN)
    idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    i = 2 * (idx - 1) + 1
    if i <= NN
        real_part = 0.5 * (Jacob[i, i] + Jacob[i+1, i+1])
        imag_part = 0.5 * (Jacob[i+1, i] - Jacob[i, i+1])
        deriv[div(i, 2) + 1] = ComplexF32(real_part, imag_part)
    end
    return
end

# Adapt the `complex_derivative` function for GPU
function complex_derivative(z::CuArray{Float32, 2})
    Jacob = jac(z)[1]
    NN = size(Jacob, 2)
    deriv = CUDA.fill(ComplexF32(0.0, 0.0), div(NN, 2))

    # Use a loop on the GPU (kernel launch not written yet, so `deriv` stays zero)

    return deriv
end

# CPU version: Jacob is (2N)×(2N); rows/columns 2i-1 and 2i belong to sample i
function complex_derivative(z::Array)
    Jacob = jac(z)[1]
    NN = size(Jacob, 2)
    deriv = Vector{ComplexF32}(undef, div(NN, 2))
    for i in 1:div(NN, 2)
        real_part = 0.5 * (Jacob[2*i-1, 2*i-1] + Jacob[2*i, 2*i])
        imag_part = 0.5 * (Jacob[2*i, 2*i-1] - Jacob[2*i-1, 2*i])
        deriv[i] = ComplexF32(real_part, imag_part)
    end
    return deriv
end

# Define the loss function
function myloss(model, x, f, fprime)
    y = model(x)
    yprime = complex_derivative(x)
    return Flux.mse(y, f) + Flux.mse([real.(yprime) imag.(yprime)]', fprime)
end

# Define the optimizer
optim = Flux.setup(Adam(), model)

# Generate training data
N = 3000
x_train, f_train, fprime_train = generate_data(N)
x_train = x_train |> device
f_train = f_train |> device
fprime_train = fprime_train |> device

# To check that myloss is working:
println("test myloss: ", myloss(model, x_train[:, 1:10], f_train[:, 1:10], fprime_train[:, 1:10]))

# Training loop
epochs = 1000
batch_size = 64

P = Flux.DataLoader((x_train, f_train, fprime_train), batchsize=batch_size, shuffle=true)

# Train for one epoch and return the mean batch loss
function train_epoch!(model, optim, P)
    total = 0.0
    for (x, f, fprime) in P
        loss, grads = Flux.withgradient(m -> myloss(m, x, f, fprime), model)
        Flux.update!(optim, model, grads[1])
        total += loss
    end
    return total / length(P)
end

# Training process
global min_loss = Inf
@showprogress color=:blue for epoch in 1:epochs
    train_loss = train_epoch!(model, optim, P)
    if train_loss < min_loss
        global min_loss = train_loss
        global best_model = deepcopy(model)
    end
    println("Epoch $epoch: Training loss = $train_loss")
end
println("Best Training loss = $min_loss")
``````

When running on the GPU, I get the following error:

``````
ERROR: LoadError: `llvmcall` must be compiled to be called
``````

This message is rather obscure. On the CPU, however, I get the following error, which matches what I have read about using `Zygote.jacobian` inside a loss function:

``````
ERROR: LoadError: Mutating arrays is not supported -- called setindex!(Vector{ComplexF32}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations
``````

Is there a way to fix this? For my real problem, I really need the derivative of the network with respect to the complex input time, and I am now really stuck!

I have read some posts suggesting `destructure`, but I have to admit that I did not manage to make it work.

If anyone has a suggestion, I would greatly appreciate it. Thanks in advance.

Baptiste


It's helpful to share the full stack trace of the error. But from reading just the top level, my guess is that it's pointing to the `deriv[i] = ComplexF32(real_part, imag_part)` line in the CPU `complex_derivative`, and what it's saying is that mutation of this form is not supported by the Zygote autodiff engine. To fix this, simply change that implementation to a `map`:

``````julia
deriv = map(1:div(NN, 2)) do i
    real_part = 0.5 * (Jacob[2*i-1, 2*i-1] + Jacob[2*i, 2*i])
    imag_part = 0.5 * (Jacob[2*i, 2*i-1] - Jacob[2*i-1, 2*i])
    ComplexF32(real_part, imag_part)
end
``````
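To see what that `map` computes, here is a package-free sketch. It assumes, as the `2*i-1`/`2*i` indexing does, that the Jacobian of a 2×N input is a (2N)×(2N) matrix whose 2×2 diagonal blocks belong to individual samples (column-major `vec` order). The Jacobian is built analytically for sin rather than by Zygote, and the helpers `batch_jacobian_sin` and `derivs_from_batch_jacobian` are hypothetical names.

``````julia
# Hypothetical helper: analytic (2N)×(2N) Jacobian of z -> sin(z), applied column-wise
# to a 2×N matrix [real parts; imag parts]. Building it mutates J, which is fine here
# because this part is not differentiated.
function batch_jacobian_sin(X::AbstractMatrix{<:Real})
    N = size(X, 2)
    J = zeros(2N, 2N)
    for k in 1:N
        x, y = X[1, k], X[2, k]
        # 2×2 block [u_x u_y; v_x v_y] for u + iv = sin(x + iy)
        J[2k-1:2k, 2k-1:2k] = [cos(x)*cosh(y)   sin(x)*sinh(y);
                               -sin(x)*sinh(y)  cos(x)*cosh(y)]
    end
    return J
end

# Same non-mutating extraction as the `map` suggested above
function derivs_from_batch_jacobian(Jacob::AbstractMatrix)
    NN = size(Jacob, 2)
    return map(1:div(NN, 2)) do i
        real_part = 0.5 * (Jacob[2*i-1, 2*i-1] + Jacob[2*i, 2*i])
        imag_part = 0.5 * (Jacob[2*i, 2*i-1] - Jacob[2*i-1, 2*i])
        ComplexF32(real_part, imag_part)
    end
end

X = [1.0 2.5 7.0;
     0.2 0.5 0.9]                     # three samples z_k = X[1, k] + im*X[2, k]
d = derivs_from_batch_jacobian(batch_jacobian_sin(X))
z = complex.(X[1, :], X[2, :])
println(maximum(abs.(d .- cos.(z))))  # small: only Float32 rounding remains
``````

Note that this only removes the explicit `setindex!`; whether Zygote can then differentiate through `Zygote.jacobian` itself is a separate question.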

Sorry for not posting the whole error message; I'm new to this forum and haven't adopted all the best practices yet.

With the change you suggested, I still get the same type of error.

Here is my complete code (modified so that the Jacobian is recomputed each time the loss function is called):

``````julia
# Restrict the visible GPUs; this must be set before CUDA is initialized
ENV["CUDA_VISIBLE_DEVICES"] = "0,1"

using Revise
using Flux
using CUDA
using Statistics
using Plots
using Zygote
using ProgressMeter

BACKEND = "CPU"
Flux.gpu_backend!(BACKEND)  # "CPU" or "CUDA"; do not forget to restart Julia after changing it

if BACKEND == "CUDA"
    CUDA.device!(0)
    device = Flux.get_device("CUDA", 0)
    println("Device selected: ", device)
elseif BACKEND == "CPU"
    device = Flux.get_device("CPU")
end

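# Generate training data: z = x + i*y with x in [1, 10] and y in [0, 1]
function generate_data(N)
    x = rand(1.0:0.1:10.0, N)    # real parts in [1, 10]
    y = rand(0.0:0.01:1.0, N)    # imaginary parts in [0, 1]
    f = sin.(x .+ im .* y)       # target values sin(z)
    fprime = cos.(x .+ im .* y)  # target derivatives cos(z)
    # Return 2×N matrices: row 1 = real part, row 2 = imaginary part
    return copy(hcat(x, y)'), copy(hcat(real(f), imag(f))'), copy(hcat(real(fprime), imag(fprime))')
end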

# Define the model
model = Chain(
    Dense(2, 2048, leakyrelu),
    Dense(2048, 2048, leakyrelu),
    Dense(2048, 2)
) |> device

# Kernel to fill the complex derivatives
function fill_deriv!(deriv, Jacob, NN)
    idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    i = 2 * (idx - 1) + 1
    if i <= NN
        real_part = 0.5 * (Jacob[i, i] + Jacob[i+1, i+1])
        imag_part = 0.5 * (Jacob[i+1, i] - Jacob[i, i+1])
        deriv[div(i, 2) + 1] = ComplexF32(real_part, imag_part)
    end
    return
end

# Adapt the `complex_derivative` function for GPU
function complex_derivative(model::Chain, z::CuArray{Float32, 2})
    jac = x -> Zygote.jacobian(model, x)
    Jacob = jac(z)[1]
    NN = size(Jacob, 2)
    deriv = CUDA.fill(ComplexF32(0.0, 0.0), div(NN, 2))

    # Use a loop on the GPU (kernel launch not written yet, so `deriv` stays zero)

    return deriv
end

# CPU version, now without in-place writes into `deriv`
function complex_derivative(model::Chain, z::Array)
    jac = x -> Zygote.jacobian(model, x)
    Jacob = jac(z)[1]
    NN = size(Jacob, 2)
    # deriv = Vector{ComplexF32}(undef, div(NN, 2))
    deriv = map(1:div(NN, 2)) do i
        real_part = 0.5 * (Jacob[2*i-1, 2*i-1] + Jacob[2*i, 2*i])
        imag_part = 0.5 * (Jacob[2*i, 2*i-1] - Jacob[2*i-1, 2*i])
        ComplexF32(real_part, imag_part)
    end
    return deriv
end

# Define the loss function
function myloss(model, x, f, fprime)
    y = model(x)
    yprime = complex_derivative(model, x)
    return Flux.mse(y, f) + Flux.mse([real.(yprime) imag.(yprime)]', fprime)
end

# Define the optimizer
optim = Flux.setup(Adam(), model)

# Generate training data
N = 3000
x_train, f_train, fprime_train = generate_data(N)
x_train = x_train |> device
f_train = f_train |> device
fprime_train = fprime_train |> device

# To check that myloss is working:
println("test myloss: ", myloss(model, x_train[:, 1:10], f_train[:, 1:10], fprime_train[:, 1:10]))

# Training loop
epochs = 1000
batch_size = 64

P = Flux.DataLoader((x_train, f_train, fprime_train), batchsize=batch_size, shuffle=true)

# Train for one epoch and return the mean batch loss
function train_epoch!(model, optim, P)
    total = 0.0
    for (x, f, fprime) in P
        loss, grads = Flux.withgradient(m -> myloss(m, x, f, fprime), model)
        Flux.update!(optim, model, grads[1])
        total += loss
    end
    return total / length(P)
end

# Training process
global min_loss = Inf
@showprogress color=:blue for epoch in 1:epochs
    train_loss = train_epoch!(model, optim, P)
    if train_loss < min_loss
        global min_loss = train_loss
        global best_model = deepcopy(model)
    end
    println("Epoch $epoch: Training loss = $train_loss")
end
println("Best Training loss = $min_loss")
``````

Here is the associated error when running on the CPU:

``````
ERROR: LoadError: Mutating arrays is not supported -- called copyto!(SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] _throw_mutation_error(f::Function, args::SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{…}}, true})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:70
[3] (::Zygote.var"#543#544"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{…}}, true}})(::Nothing)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:85
[4] (::Zygote.var"#2633#back#545"{Zygote.var"#543#544"{SubArray{Float64, 1, Matrix{…}, Tuple{…}, true}}})(Δ::Nothing)
[6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[7] withjacobian
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[9] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[10] #2169#back
[11] jacobian
[12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Matrix{…}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[13] #7
@ /mnt/CAPTAIN_HARLOCK/RECHERCHE/JULIA/MY_COUNTER_ROTATING/MyCounterRotating/test_function_fit_derivative_complex.jl:73 [inlined]
[14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Matrix{…}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[15] complex_derivative
@ /mnt/CAPTAIN_HARLOCK/RECHERCHE/JULIA/MY_COUNTER_ROTATING/MyCounterRotating/test_function_fit_derivative_complex.jl:74 [inlined]
[16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{ComplexF32})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[17] myloss
@ /mnt/CAPTAIN_HARLOCK/RECHERCHE/JULIA/MY_COUNTER_ROTATING/MyCounterRotating/test_function_fit_derivative_complex.jl:88 [inlined]
[18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[19] #11
@ /mnt/CAPTAIN_HARLOCK/RECHERCHE/JULIA/MY_COUNTER_ROTATING/MyCounterRotating/test_function_fit_derivative_complex.jl:118 [inlined]
[20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[21] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:213
[23] train_epoch!(P::MLUtils.DataLoader{Tuple{Matrix{…}, Matrix{…}, Matrix{…}}, Random._GLOBAL_RNG, Val{nothing}})
@ Main /mnt/CAPTAIN_HARLOCK/RECHERCHE/JULIA/MY_COUNTER_ROTATING/MyCounterRotating/test_function_fit_derivative_complex.jl:117
[24] top-level scope
@ /mnt/CAPTAIN_HARLOCK/RECHERCHE/JULIA/MY_COUNTER_ROTATING/MyCounterRotating/test_function_fit_derivative_complex.jl:128
in expression starting at /mnt/CAPTAIN_HARLOCK/RECHERCHE/JULIA/MY_COUNTER_ROTATING/MyCounterRotating/test_function_fit_derivative_complex.jl:127
Some type information was truncated. Use `show(err)` to see complete types.
``````

I hope this helps to find the solution. If necessary, I can add the output of `show(err)`.

My fear is that it won't be possible to use the Jacobian in the loss, which would be very bad news for me, since my project is based on this approach.

All suggestions are welcome.

I have changed my code once more to avoid mutating arrays. Here is the code:

``````julia
using Revise
using Flux
using Flux: params
using CUDA
using Statistics
using Plots
using Zygote
using ProgressMeter

# Setting up the backend
BACKEND = "CPU"
Flux.gpu_backend!(BACKEND)  # "CPU" or "CUDA"; do not forget to restart Julia after changing it

if BACKEND == "CUDA"
    CUDA.device!(0)
    device = Flux.get_device("CUDA", 0)
    println("Device selected: ", device)
elseif BACKEND == "CPU"
    device = Flux.get_device("CPU")
end

# Generate training data: complex inputs stored as a 1×N matrix
function generate_data(N)
    x = rand(ComplexF32, (1, N))
    f = sin.(x)
    fprime = cos.(x)
    return x, f, fprime
end

# Complex in, complex out: split into real/imag rows, apply the MLP, recombine
model = Chain(
    x -> vcat(real(x), imag(x)),
    Dense(2, 2048, leakyrelu),
    Dense(2048, 2048, leakyrelu),
    Dense(2048, 2),
    x -> complex.(reshape(x[1, :], (1, size(x, 2))), reshape(x[2, :], (1, size(x, 2))))
) |> device

# Wirtinger derivative ∂f/∂z from two pullbacks, seeded with 1 and im
function complex_derivative(f, x)
    return ComplexF32.(conj.(0.5 .* (Zygote.pullback(f, x)[2](1.0)[1] .- im * Zygote.pullback(f, x)[2](im)[1])))
end

# Define the loss function
function myloss(model, x, f, fprime)
    y = model(x)
    yprime = complex_derivative(x -> model(x), x)
    return Flux.mse(y, f) + Flux.mse(yprime, fprime)
end

# Define the optimizer
optim = Flux.setup(Adam(), model)

# Generate training data
N = 3000
x_train, f_train, fprime_train = generate_data(N)
x_train = x_train |> device
f_train = f_train |> device
fprime_train = fprime_train |> device

# To check that myloss is working
println("test myloss: ", myloss(model, x_train[:, 1:10], f_train[:, 1:10], fprime_train[:, 1:10]))

# Training loop
epochs = 1000
batch_size = 64

P = Flux.DataLoader((x_train, f_train, fprime_train), batchsize=batch_size, shuffle=true)

# Train for one epoch on all batches
function train_epoch!(model, optim, P)
    for (x, f, fprime) in P
        loss, grads = Flux.withgradient(m -> myloss(m, x, f, fprime), model)
        Flux.update!(optim, model, grads[1])
    end
end

# Training process
global min_loss = Inf
@showprogress color=:blue for epoch in 1:epochs
    train_epoch!(model, optim, P)
    train_loss = myloss(model, x_train, f_train, fprime_train)
    if train_loss < min_loss
        global min_loss = train_loss
        global best_model = deepcopy(model)
    end
    println("Epoch $epoch: Training loss = $train_loss")
end
println("Best Training loss = $min_loss")
``````

And the associated error:

``````
julia> include("fit_derivative_complex.jl")
┌ Info: GPU backend is already set to: CPU.
└ No need to do anything else.
test myloss: 1.7619199
ERROR: LoadError: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(1), Base.OneTo(64)), must have singleton at dim 2
Stacktrace:
[1] promote_shape
@ ./indices.jl:183 [inlined]
[2] _promote_tuple_shape
@ ./iterators.jl:394 [inlined]
[3] axes
@ ./iterators.jl:391 [inlined]
[4] _tryaxes
@ ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:171 [inlined]
[5] map
@ ./tuple.jl:291 [inlined]
[6] ∇map(cx::Zygote.Context{false}, f::Base.var"#4#5"{Zygote.var"#1366#1372"}, args::Base.Iterators.Zip{Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:189
[7] _pullback(cx::Zygote.Context{false}, ::typeof(collect), g::Base.Generator{Base.Iterators.Zip{…}, Base.var"#4#5"{…}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:230
[8] map
@ ./abstractarray.jl:3409 [inlined]
[9] _pullback(::Zygote.Context{false}, ::typeof(map), ::Zygote.var"#1366#1372", ::Matrix{Tuple{…}}, ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[11] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#∇broadcasted#1371"{Tuple{…}, Matrix{…}, Val{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[12] #4115#back
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[14] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[15] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#291#292"{Tuple{…}, Zygote.var"#4115#back#1360"{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[16] #2169#back
[17] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[18] Pullback
[19] Pullback
@ ~/Documents/RECHERCHE/JULIA/COUNTER_ROTATING/CounterRotating/fit_derivative_complex.jl:36 [inlined]
[20] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[21] Pullback
@ ~/.julia/packages/Flux/CUn7U/src/layers/basic.jl:53 [inlined]
[22] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[23] Pullback
@ ~/.julia/packages/Flux/CUn7U/src/layers/basic.jl:51 [inlined]
[24] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[25] Pullback
@ ~/Documents/RECHERCHE/JULIA/COUNTER_ROTATING/CounterRotating/fit_derivative_complex.jl:52 [inlined]
[26] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[27] #75
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91 [inlined]
[28] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[29] complex_derivative
@ ~/Documents/RECHERCHE/JULIA/COUNTER_ROTATING/CounterRotating/fit_derivative_complex.jl:46 [inlined]
[30] _pullback(::Zygote.Context{false}, ::typeof(complex_derivative), ::var"#9#10"{Chain{Tuple{…}}}, ::Matrix{ComplexF32})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[31] myloss
@ ~/Documents/RECHERCHE/JULIA/COUNTER_ROTATING/CounterRotating/fit_derivative_complex.jl:52 [inlined]
[32] _pullback(::Zygote.Context{…}, ::typeof(myloss), ::Chain{…}, ::Matrix{…}, ::Matrix{…}, ::Matrix{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[33] #11
@ ~/Documents/RECHERCHE/JULIA/COUNTER_ROTATING/CounterRotating/fit_derivative_complex.jl:80 [inlined]
[34] _pullback(ctx::Zygote.Context{false}, f::var"#11#12"{Matrix{…}, Matrix{…}, Matrix{…}}, args::Chain{Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[35] pullback(f::Function, cx::Zygote.Context{false}, args::Chain{Tuple{var"#5#7", Dense{…}, Dense{…}, Dense{…}, var"#6#8"}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
[36] pullback
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
[37] withgradient(f::Function, args::Chain{Tuple{var"#5#7", Dense{…}, Dense{…}, Dense{…}, var"#6#8"}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:205
[38] train_epoch!(P::MLUtils.DataLoader{Tuple{Matrix{…}, Matrix{…}, Matrix{…}}, Random._GLOBAL_RNG, Val{nothing}})
@ Main ~/Documents/RECHERCHE/JULIA/COUNTER_ROTATING/CounterRotating/fit_derivative_complex.jl:79
[39] top-level scope
@ ~/Documents/RECHERCHE/JULIA/COUNTER_ROTATING/CounterRotating/fit_derivative_complex.jl:91
in expression starting at /home/baptiste/Documents/RECHERCHE/JULIA/COUNTER_ROTATING/CounterRotating/fit_derivative_complex.jl:90
Some type information was truncated. Use `show(err)` to see complete types.
``````

Any help would be greatly appreciated. Thank you in advance.
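The combination used in `complex_derivative`, `conj(0.5 * (back(1) - im * back(im)))`, can be checked without any AD package. The sketch below assumes that a pullback with seed Δ returns the gradient of real(conj(Δ) * f(z)), taken as ∂/∂x + i ∂/∂y, which is assumed here to match Zygote's complex-number convention. It approximates that pullback by central finite differences (hypothetical helpers `fd_pullback` and `wirtinger_z`) and uses the non-holomorphic test function f(z) = z² + conj(z), whose Wirtinger derivative ∂f/∂z is 2z.

``````julia
# Assumed convention: back(Δ) = ∂g/∂x + im*∂g/∂y with g(z) = real(conj(Δ) * f(z)).
# The pullback is approximated here by central finite differences.
function fd_pullback(f, z::Complex, Δ::Number; h = 1e-6)
    g(w) = real(conj(Δ) * f(w))
    gx = (g(z + h) - g(z - h)) / (2h)            # ∂g/∂x
    gy = (g(z + im * h) - g(z - im * h)) / (2h)  # ∂g/∂y
    return complex(gx, gy)
end

# Same combination as in complex_derivative: conj(0.5 * (back(1) - im * back(im)))
wirtinger_z(f, z) = conj(0.5 * (fd_pullback(f, z, 1.0) - im * fd_pullback(f, z, 1.0im)))

testf(z) = z^2 + conj(z)   # not holomorphic: ∂f/∂z = 2z, ∂f/∂z̄ = 1
z0 = 0.7 - 0.4im
println(wirtinger_z(testf, z0), " vs exact ", 2 * z0)
println(wirtinger_z(sin, z0), " vs exact ", cos(z0))
``````

Under this convention the seeds 1 and im isolate the gradients of the real and imaginary parts of f, so the formula returns ∂f/∂z even for a network that is not holomorphic.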

Zygote doesn't handle nested derivatives well. See the "Nested Automatic Differentiation" page of the Lux.jl docs for how to handle this with Lux. I haven't tested complex derivatives with this nested AD, but if you encounter any problems, open an issue and I can take a look.

I have tried your suggestion. Here is the updated code:

``````julia
using Revise
using Lux
using CUDA
using Statistics
using Plots
using Zygote
using ForwardDiff
using ProgressMeter
using Random

# Generate training data
function generate_data(N)
    x = rand(ComplexF32, (1, N))
    f = sin.(x)
    fprime = cos.(x)
    return x, f, fprime
end

model = Chain(
    x -> vcat(real(x), imag(x)),
    Dense(2 => 2048, leakyrelu),
    Dense(2048 => 2048, leakyrelu),
    Dense(2048 => 2),
    x -> complex.(reshape(x[1, :], (1, size(x, 2))), reshape(x[2, :], (1, size(x, 2))))
)

function complex_derivative(f, x)
    return ComplexF32.(conj.(0.5 .* (Zygote.pullback(f, x)[2](1.0)[1] .- im * Zygote.pullback(f, x)[2](im)[1])))
end

# Define the loss function
function myloss(model, x, ps, st, f, fprime)
    smodel = Lux.StatefulLuxLayer{true}(model, ps, st)
    y = smodel(x)
    yprime = complex_derivative(x -> smodel(x), x)
    loss_emp = sum(abs2, y .- f)
    loss_reg = sum(abs2, yprime .- fprime)
    return loss_emp + loss_reg
end

ps, st = Lux.setup(Xoshiro(0), model)

# Generate training data
N = 3000
x_train, f_train, fprime_train = generate_data(N)
x_train = x_train
f_train = f_train
fprime_train = fprime_train

# To check that myloss function is working
println("test myloss: ", myloss(model, x_train[:, 1:10], ps, st, f_train[:, 1:10], fprime_train[:, 1:10]))

_, ∂x, ∂ps, _, _ = Zygote.gradient(myloss, model, x_train[:, 1:10], ps, st, f_train[:, 1:10], fprime_train[:, 1:10])
``````

It returns this error:

``````
ERROR: LoadError: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(1), Base.OneTo(10)), must have singleton at dim 2
Stacktrace:
[1] promote_shape
@ ./indices.jl:183 [inlined]
[2] _promote_tuple_shape
@ ./iterators.jl:394 [inlined]
[3] axes
@ ./iterators.jl:391 [inlined]
[4] _tryaxes
@ ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:171 [inlined]
[5] map
@ ./tuple.jl:291 [inlined]
[6] ∇map(cx::Zygote.Context{false}, f::Base.var"#4#5"{Zygote.var"#1366#1372"}, args::Base.Iterators.Zip{Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:189
[7] _pullback(cx::Zygote.Context{false}, ::typeof(collect), g::Base.Generator{Base.Iterators.Zip{…}, Base.var"#4#5"{…}})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:230
[8] map
@ ./abstractarray.jl:3409 [inlined]
[9] _pullback(::Zygote.Context{false}, ::typeof(map), ::Zygote.var"#1366#1372", ::Matrix{Tuple{…}}, ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[11] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#∇broadcasted#1371"{Tuple{…}, Matrix{…}, Val{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[12] #4115#back
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[14] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[15] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#291#292"{Tuple{…}, Zygote.var"#4115#back#1360"{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[16] #2169#back
[17] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[18] Pullback
[19] Pullback
@ /mnt/CAPTAIN_HARLOCK/RECHERCHE/JULIA/MY_COUNTER_ROTATING/MyCounterRotating/test_function_fit_derivative_complex_lux_autodiff.jl:25 [inlined]
[20] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[21] Pullback
@ ~/.julia/packages/Lux/JXc6P/src/layers/basic.jl:269 [inlined]
[22] Pullback
@ ~/.julia/packages/Lux/JXc6P/src/layers/basic.jl:256 [inlined]
[23] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[24] Pullback
@ ~/.julia/packages/LuxCore/biwfu/src/LuxCore.jl:168 [inlined]
[25] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[26] Pullback
@ ~/.julia/packages/Lux/JXc6P/src/layers/containers.jl:0 [inlined]
[27] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[28] Pullback
@ ~/.julia/packages/Lux/JXc6P/src/layers/containers.jl:510 [inlined]
[29] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[30] Pullback
@ ~/.julia/packages/LuxCore/biwfu/src/LuxCore.jl:168 [inlined]
[31] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[32] Pullback
@ ~/.julia/packages/Lux/JXc6P/src/helpers/stateful.jl:111 [inlined]
[33] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[34] Pullback
@ ~/.julia/packages/Lux/JXc6P/src/helpers/stateful.jl:111 [inlined]
[35] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[36] Pullback
@ /mnt/CAPTAIN_HARLOCK/RECHERCHE/JULIA/MY_COUNTER_ROTATING/MyCounterRotating/test_function_fit_derivative_complex_lux_autodiff.jl:37 [inlined]
[37] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[38] #75
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91 [inlined]
[39] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[40] complex_derivative
@ /mnt/CAPTAIN_HARLOCK/RECHERCHE/JULIA/MY_COUNTER_ROTATING/MyCounterRotating/test_function_fit_derivative_complex_lux_autodiff.jl:30 [inlined]
[41] _pullback(::Zygote.Context{…}, ::typeof(complex_derivative), ::var"#9#10"{…}, ::Matrix{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[42] myloss
@ /mnt/CAPTAIN_HARLOCK/RECHERCHE/JULIA/MY_COUNTER_ROTATING/MyCounterRotating/test_function_fit_derivative_complex_lux_autodiff.jl:37 [inlined]
[43] _pullback(::Zygote.Context{…}, ::typeof(myloss), ::Chain{…}, ::Matrix{…}, ::@NamedTuple{…}, ::@NamedTuple{…}, ::Matrix{…}, ::Matrix{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[44] pullback(::Function, ::Zygote.Context{false}, ::Chain{@NamedTuple{…}, Nothing}, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
[45] pullback(::Function, ::Chain{@NamedTuple{…}, Nothing}, ::Matrix{ComplexF32}, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88
[46] gradient(::Function, ::Chain{@NamedTuple{…}, Nothing}, ::Matrix{ComplexF32}, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:147
[47] top-level scope
``````

I think it does not like the fact that my loss function, and in particular the `complex_derivative` part, includes a combination of Jacobian terms.

``````julia
using Lux, Statistics, Zygote, ForwardDiff, Random

# Generate training data
function generate_data(N)
    x = rand(ComplexF32, (1, N))
    f = sin.(x)
    fprime = cos.(x)
    return x, f, fprime
end

model = Chain(x -> vcat(real(x), imag(x)),
    Dense(2 => 2048, leakyrelu),
    Dense(2048 => 2048, leakyrelu),
    Dense(2048 => 2),
    x -> complex.(reshape(x[1, :], (1, size(x, 2))), reshape(x[2, :], (1, size(x, 2)))))

# Use Lux's vector_jacobian_product so that Lux can switch to nested AD
function complex_derivative(f, x)
    real_part = vector_jacobian_product(f, AutoZygote(), x, eltype(x)(1))
    comp_part = vector_jacobian_product(f, AutoZygote(), x, eltype(x)(im))
    return ComplexF32.(conj.(0.5f0 .* (real_part .- im .* comp_part)))
end

# Define the loss function
function myloss(model, x, ps, st, f, fprime)
    smodel = Lux.StatefulLuxLayer{true}(model, ps, st)
    y = smodel(x)
    yprime = complex_derivative(smodel, x)
    loss_emp = sum(abs2, y .- f)
    loss_reg = sum(abs2, yprime .- fprime)
    return loss_emp + loss_reg
end

ps, st = Lux.setup(Xoshiro(0), model)

# Generate training data
N = 10
x_train, f_train, fprime_train = generate_data(N)

# To check that myloss function is working
println("test myloss: ",
    myloss(model, x_train[:, 1:10], ps, st, f_train[:, 1:10], fprime_train[:, 1:10]))

_, ∂x, ∂ps, _, _ = Zygote.gradient(
    myloss, model, x_train[:, 1:10], ps, st, f_train[:, 1:10], fprime_train[:, 1:10])
``````

This is a modified version of your code that lets Lux use nested AD. However, it very quickly hits:

``````
ERROR: ArgumentError: Cannot create a dual over scalar type ComplexF32. If the type behaves as a scalar, define ForwardDiff.can_dual(::Type{ComplexF32}) = true.
Stacktrace:
⋮ internal @ ForwardDiff
[2] ForwardDiff.Dual{ForwardDiff.Tag{Lux.var"#213#215"{typeof(pullback), StatefulLuxLayer{true, Chain{…}, @NamedTuple{…}, @NamedTuple{…}}, ComplexF32}, ComplexF32}, ComplexF32, 1}(value::ComplexF32, partials::ForwardDiff.Partials{1, ComplexF32})
@ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:18
[3] (::Lux.var"#bfn#233"{ComplexF32, ForwardDiff.Tag{Lux.var"#213#215"{typeof(pullback), StatefulLuxLayer{true, Chain{…}, @NamedTuple{…}, @NamedTuple{…}}, ComplexF32}, ComplexF32}})(xᵢ::ComplexF32, uᵢ::Tuple{ComplexF32})
⋮ internal @ Unknown
[8] materialize
[9] __dualify
[10] __forwarddiff_jvp(f::Lux.var"#213#215"{…}, x::Matrix{…}, Δx::Matrix{…}, y::@NamedTuple{…})
@ Lux /mnt/research/lux/Lux.jl/src/forwarddiff/jvp.jl:5
[11] (::Lux.var"#212#214"{typeof(pullback), StatefulLuxLayer{…}, Matrix{…}, @NamedTuple{…}, ComplexF32, Matrix{…}})(Δ_::Matrix{ComplexF32})
[12] ZBack
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
[13] __vector_jacobian_product_impl
[14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Matrix{ComplexF32})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[15] vector_jacobian_product
@ /mnt/research/lux/Lux.jl/src/helpers/autodiff.jl:36 [inlined]
[16] (::Zygote.Pullback{Tuple{typeof(vector_jacobian_product), StatefulLuxLayer{true, Chain{…}, @NamedTuple{…}, @NamedTuple{…}}, AutoZygote, Matrix{ComplexF32}, ComplexF32}, Any})(Δ::Matrix{ComplexF32})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[17] complex_derivative
@ ./REPL[35]:3 [inlined]
[18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Matrix{ComplexF32})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[19] myloss
@ ./REPL[42]:5 [inlined]
⋮ internal @ Zygote
[22] gradient(::Function, ::Chain{@NamedTuple{layer_1::WrappedFunction{…}, layer_2::Dense{…}, layer_3::Dense{…}, layer_4::Dense{…}, layer_5::WrappedFunction{…}}, Nothing}, ::Matrix{ComplexF32}, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
[23] top-level scope
@ REPL[47]:1
[24] top-level scope
@ none:1
Use `err` to retrieve the full stack trace.
Some type information was truncated. Use `show(err)` to see complete types.
``````

I am not familiar enough with complex differentiation to know how to fix this. The closest issue I could find in ForwardDiff is "Support for real-valued function with complex arguments" (JuliaDiff/ForwardDiff.jl#498 on GitHub).
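One option the thread has not tested, sketched here only as an assumption: since ForwardDiff refuses complex duals, keep every array that the AD sees real, as the first version of the code did with its 2×N inputs, and convert to and from complex numbers only outside the differentiated path. The hypothetical helpers `to_real_pairs` and `from_real_pairs` below only show that packing, mirroring the first and last layers of the `Chain`; the complex derivative would then be assembled from the real Jacobian entries as in the original `complex_derivative`.

``````julia
# Hypothetical helpers: represent a 1×N complex row as a real 2×N matrix and back,
# the same split/recombine the first and last layers of the Chain perform.
to_real_pairs(z::AbstractMatrix{<:Complex}) = vcat(real(z), imag(z))
from_real_pairs(X::AbstractMatrix{<:Real}) = complex.(X[1:1, :], X[2:2, :])

z = ComplexF32[1.0+0.5im 2.0-0.25im 3.5+0.0im]   # 1×3 complex input
X = to_real_pairs(z)                              # 2×3 Float32 matrix, safe for real-only AD
@show size(X) eltype(X)
@show from_real_pairs(X) == z
``````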