Dear all,
For a SciML problem in quantum mechanics (strongly inspired by what I saw in the @ChrisRackauckas talk in Grenoble), I need to differentiate the loss function. Here is a minimal example of the kind of problem I want to solve. In this very crude example, I try to train a network to represent a complex function. The network has two inputs (the real and imaginary parts of a complex time) and two outputs (the real and imaginary parts of the complex value of the function I want to represent).
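For example, a single sample z = 3.0 + 0.4im enters the network as the column [3.0, 0.4] and the corresponding target is [real(sin(z)), imag(sin(z))] (a minimal sketch of the convention used in generate_data below):
z = 3.0 + 0.4im
input  = Float32[real(z), imag(z)]              # what the network sees
target = Float32[real(sin(z)), imag(sin(z))]    # what it should reproduce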
The loss is a linear combination of two terms: a comparison between the network outputs and the values of the true function, and a comparison between the complex derivative of the network and the true derivative.
The complex derivative is constructed from the Jacobian of the network.
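Concretely, for each sample the network maps (x, y) = (Re z, Im z) to (u, v) ≈ (Re f, Im f), and I take
f'(z) ≈ 1/2 (∂u/∂x + ∂v/∂y) + i 1/2 (∂v/∂x - ∂u/∂y),
i.e. the symmetrised Cauchy-Riemann combination of the four entries of the 2×2 Jacobian block of that sample; this is exactly what fill_deriv! and the CPU complex_derivative below compute.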
Here is the code:
using Revise
using Flux
using CUDA
using Flux.Optimise: ADAM
using Statistics
using Plots
using Zygote
using ProgressMeter
ENV["CUDA_VISIBLE_DEVICES"] = "0,1"   # set as an environment variable so CUDA.jl picks it up
BACKEND = "CPU"
Flux.gpu_backend!(BACKEND) # CPU or CUDA -- do not forget to restart julia
if BACKEND == "CUDA"
CUDA.device!(0)
device = Flux.get_device("CUDA", 0)
println("Device selected: ", device)
elseif BACKEND == "CPU"
device = Flux.get_device("CPU")
end
# Generate training data
function generate_data(N)
    x = rand(1.0:0.1:10.0, N)       # real parts, in [1, 10]
    y = rand(0.0:0.01:1.0, N)       # imaginary parts, in [0, 1]
    f = sin.(x .+ im .* y)          # true values f(z) = sin(z)
    fprime = cos.(x .+ im .* y)     # true derivatives f'(z) = cos(z)
    # return 2×N matrices: inputs, targets, target derivatives
    return copy(hcat(x, y)'), copy(hcat(real(f), imag(f))'), copy(hcat(real(fprime), imag(fprime))')
end
# Define the model
model = Chain(
    Dense(2, 2048, leakyrelu),
    Dense(2048, 2048, leakyrelu),
    # Dense(512, 512, relu),
    # Dense(512, 256, relu),
    Dense(2048, 2)
) |> device
jac = x -> Zygote.jacobian(model, x)
# Kernel to fill the complex derivatives
function fill_deriv!(deriv, Jacob, NN)
    idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    i = 2 * (idx - 1) + 1
    if i <= NN
        # combine the 2×2 Jacobian block of this sample into the complex derivative
        real_part = 0.5 * (Jacob[i, i] + Jacob[i+1, i+1])
        imag_part = 0.5 * (Jacob[i+1, i] - Jacob[i, i+1])
        deriv[div(i, 2) + 1] = ComplexF32(real_part, imag_part)
    end
    return
end
# Adapt the `complex_derivative` function for GPU
function complex_derivative(z::CuArray{Float32, 2})
    Jacob = jac(z)[1]
    NN = size(Jacob, 2)
    deriv = CUDA.fill(ComplexF32(0.0, 0.0), div(NN, 2))
    # launch the kernel on the GPU
    threads = min(div(NN, 2), 1024)
    blocks = cld(div(NN, 2), threads)
    CUDA.@sync @cuda threads=threads blocks=blocks fill_deriv!(deriv, Jacob, NN)
    return deriv
end
function complex_derivative(z::Array)
    Jacob = jac(z)[1]
    NN = size(Jacob, 2)
    deriv = Vector{ComplexF32}(undef, div(NN, 2))
    for i in 1:div(NN, 2)
        real_part = 0.5 * (Jacob[2*i-1, 2*i-1] + Jacob[2*i, 2*i])
        imag_part = 0.5 * (Jacob[2*i, 2*i-1] - Jacob[2*i-1, 2*i])
        deriv[i] = ComplexF32(real_part, imag_part)
    end
    return deriv
end
# Define the loss function
function myloss(model, x, f, fprime)
    y = model(x)
    yprime = complex_derivative(x)
    return Flux.mse(y, f) + Flux.mse([real.(yprime) imag.(yprime)]', fprime)
end
# Define the optimizer
optimizer = Flux.AdaMax()
opt_state = Flux.setup(optimizer, model)   # optimiser state required by the explicit-gradient update!
# Generate training data
N = 3000
x_train, f_train,fprime_train = generate_data(N)
x_train = x_train |> device
f_train = f_train |> device
fprime_train = fprime_train |> device
# To check that myloss function is working :
println("test myloss :", myloss(model,x_train[:,1:10],f_train[:,1:10],fprime_train[:,1:10]))
# Training loop
epochs = 1000
batch_size = 64
P = Flux.DataLoader((x_train, f_train,fprime_train), batchsize=batch_size, shuffle=true)
# Train the model for one epoch (loops over all batches)
function train_epoch!(P)
    local loss
    for (x, f, fprime) in P
        loss, grads = Flux.withgradient(model) do m
            myloss(m, x, f, fprime)
        end
        Flux.update!(opt_state, model, grads[1])
    end
    return loss   # loss of the last batch of the epoch
end
# Training process
global min_loss = Inf
@showprogress color=:blue for epoch in 1:epochs
    train_loss = train_epoch!(P)
    if train_loss < min_loss
        global min_loss = train_loss
        global best_model = deepcopy(model)
    end
    println("Epoch $epoch: Training loss = $train_loss")
end
println("Best Training loss = $min_loss")
When running on the GPU, I get the following error:
ERROR: LoadError: llvmcall must be compiled to be called
This message is a bit obscure. When I run on the CPU instead, I get the following error, which matches what I have read about calling Zygote.jacobian inside the loss function:
ERROR: LoadError: Mutating arrays is not supported -- called setindex!(Vector{ComplexF32}, …)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= …)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
Limitations · Zygote
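If I understand the message correctly, the culprit is the setindex! inside my complex_derivative. A mutation-free rewrite of the CPU method could look like the sketch below (complex_derivative_nomut is just a name I made up, and I have not verified that this is enough, since I suspect the nested Zygote.jacobian call is the deeper problem):
function complex_derivative_nomut(z::Array)
    Jacob = jac(z)[1]
    NN = size(Jacob, 2)
    # build the result with a comprehension instead of writing into a preallocated vector
    return [ComplexF32(0.5 * (Jacob[2i-1, 2i-1] + Jacob[2i, 2i]),
                       0.5 * (Jacob[2i, 2i-1] - Jacob[2i-1, 2i])) for i in 1:div(NN, 2)]
end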
Is there a way to fix this? For my real problem I really do need the derivative of the network with respect to the complex input time, and I am now completely stuck!
I read some posts suggesting the use of destructure, but I have to admit that I did not manage to make it work.
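For reference, what I understood from those posts is that Flux.destructure splits the model into a flat parameter vector plus a reconstruction function, roughly as in the sketch below (apply is just a name of mine), but I do not see how to combine this with the Jacobian with respect to the input:
ps, re = Flux.destructure(model)   # ps is a flat vector of parameters; re(ps) rebuilds the model
apply(ps, x) = re(ps)(x)           # the network as a pure function of parameters and input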
If someone has any suggestion, I would greatly appreciate it. Thanks in advance.
Baptiste