I’m trying to implement a Universal Differential Equation (UDE) based on a 2D diffusion PDE with an integrated neural network. I’m computing the diffusivity of a 2D matrix
H with a forward scheme, and once I reach the final state of
H I compute the loss function and try to backpropagate the whole thing using
Zygote.pullback() in order to update the weights of the neural network.
Since I am updating (overwriting) the
H matrix at each time step, Zygote rightfully complains that “Mutating arrays is not supported”. I have tried to work around this by using
Zygote.Buffer, but I still keep getting the error. I have also read most posts related to this issue, but I haven’t been able to find a solution to this. I’m using Julia 1.6.
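For context, the error is easy to reproduce in isolation. Here is a minimal sketch (hypothetical, not my actual model) showing the kind of in-place write Zygote rejects:

```julia
using Zygote

# Hypothetical minimal example (not my real code): any setindex! on a plain
# array inside the differentiated function triggers the error.
function bad_step(H)
    H[1, 1] = 0.0        # in-place mutation
    return sum(H)
end

# Throws "Mutating arrays is not supported"
Zygote.gradient(bad_step, ones(2, 2))
```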
Here is an overview of what my code does (I have omitted many parts, as the original code does many more things that I judge irrelevant to this issue):
```julia
# Train the UDE
# We get the model parameters to be trained
ps_UA = Flux.params(UA) # get the parameters of the NN
#data = Flux.Data.DataLoader((X, Y), batchsize=hyparams.batchsize, shuffle=false)
data = Dict("X"=>X, "Y"=>Y) # Get the input data
# Train the UDE for a given number of epochs
@epochs hyparams.epochs hybrid_train!(loss, ps_UA, data, opt, H, p, t, t₁)
```
```julia
function hybrid_train!(loss, ps_UA, data, opt, H, p, t, t₁)
    # Some previous code here
    # Retrieve model parameters
    ps_UA = Params(ps_UA)
    # back_UA is a method that computes the product of the gradient so far with its argument
    train_loss_UA, back_UA = Zygote.pullback(() -> loss(data, H, p, t, t₁), ps_UA)
    # Callback to track the training
    callback(train_loss_UA)
    # Apply back_UA() to the correct type of 1.0 to get the gradient of the loss
    gs_UA = back_UA(one(train_loss_UA))
    # Insert whatever code that needs the gradient here
    Flux.update!(opt, ps_UA, gs_UA)
end
```
It is based on a loss function which calls the forward scheme:
```julia
function loss(data, H, p, t, t₁)
    l_H = 0.0f0
    # Compute l_H as the difference between the simulated H with UA(x) and H_ref
    iceflow!(H, UA, p, t, t₁)
    l_H = sqrt(Flux.Losses.mse(H, H_ref["H"][end]; agg=mean))
    return l_H
end
```
The forward scheme itself is defined here (this is a simplified version of the code):
```julia
function iceflow!(H, UA, p, t, t₁)
    # Manual explicit forward scheme implementation
    while t < t₁
        # Long chunk of code which computes the flux F in a staggered grid
        # Δt is defined to be stable
        dHdt = (F .+ inn(MB[:,:,year])) .* Δt # dH for time step t
        # Use Zygote.Buffer in order to mutate the matrix while still being
        # able to compute gradients
        H_buff = Zygote.Buffer(H)
        H_buff[2:end-1, 2:end-1] .= max.(0.0, inn(H_buff) .+ dHdt) # CODE FAILS HERE
        H = copy(H_buff)
        t += Δt
    end
end
```
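For what it's worth, here is my understanding of how Zygote.Buffer is supposed to be used, as a minimal sketch with hypothetical names (`step_buffered` is not in my code). Two things seem to matter: `Zygote.Buffer(H)` allocates storage but does not copy H's values, so every cell has to be written before `copy()`; and a Buffer is deliberately not an `AbstractArray`, so helpers like `inn()` that expect a plain array may not accept it:

```julia
using Zygote

# Hypothetical minimal sketch of the Buffer pattern (not my actual scheme).
function step_buffered(H, dHdt)
    B = Zygote.Buffer(H)
    # Seed every cell first: Zygote.Buffer(H) does NOT copy H's contents
    for i in eachindex(H)
        B[i] = H[i]
    end
    # Update the interior, reading from the original H rather than from B
    for i in 2:size(H, 1)-1, j in 2:size(H, 2)-1
        B[i, j] = max(0.0, H[i, j] + dHdt[i-1, j-1])
    end
    return copy(B)   # copy() freezes the buffer and returns a real Array
end

H0 = ones(4, 4)
dHdt = fill(-0.5, 2, 2)              # interior-sized update
l(H) = sum(step_buffered(H, dHdt))
g = Zygote.gradient(l, H0)[1]        # gradient flows through the buffer
```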
So my problem is that I’m overwriting the
H matrix at each time step, but Zygote cannot perform the pullback on that. Is there any clean workaround to this?
Zygote.Buffer doesn’t seem to help here. Thanks in advance!
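Alternatively, would a mutation-free rewrite of the time loop work, building a fresh H each step instead of writing into it? Here is a hedged sketch under my own assumptions; `pad_zeros` is a hypothetical helper (not from my code) that embeds the interior-sized update into a full-sized array using only non-mutating concatenation:

```julia
using Zygote

# Hypothetical helper: embed an interior-sized update into a zero border,
# using only non-mutating concatenation (which Zygote can differentiate).
function pad_zeros(A)
    mi, ni = size(A)
    return [zeros(1, ni + 2);
            zeros(mi, 1) A zeros(mi, 1);
            zeros(1, ni + 2)]
end

# Mutation-free forward scheme sketch: each step returns a brand-new H,
# so Zygote never sees a setindex! call.
function iceflow(H, dHdt_interior, nsteps)
    for _ in 1:nsteps
        H = max.(0.0, H .+ pad_zeros(dHdt_interior))
    end
    return H
end

H0 = ones(4, 4)
dHdt = fill(0.1, 2, 2)                  # interior-sized, as on a staggered grid
l(H) = sum(iceflow(H, dHdt, 3))
g = Zygote.gradient(l, H0)[1]           # gradient flows through all steps
```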