I’m trying to implement a Universal Differential Equation (UDE) based on a 2D diffusion PDE with an integrated neural network. I’m computing the diffusivity of a 2D matrix
H with a forward scheme, and once I reach the final state of
H I compute the loss function and try to backpropagate the whole thing using
Zygote.pullback() in order to update the weights of the neural network.
Since I am updating (overwriting) the
H matrix at each time step, Zygote rightfully complains that “Mutating arrays is not supported”. I have tried to work around this by using
Zygote.Buffer, but I still keep getting the error. I have also read most posts related to this issue, but I haven’t been able to find a solution to this. I’m using Julia 1.6.
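For context, the error is easy to reproduce in isolation. Here is a minimal sketch (hypothetical, not my actual model) showing the kind of in-place write Zygote rejects:

```julia
using Zygote

# Hypothetical minimal example (not my real code): any setindex! on a plain
# array inside the differentiated function triggers the error.
function bad_step(H)
    H[1, 1] = 0.0        # in-place mutation
    return sum(H)
end

# Throws "Mutating arrays is not supported"
Zygote.gradient(bad_step, ones(2, 2))
```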
Here is an overview of what my code does (I have omitted many parts, as the original code does many more things that I judge irrelevant to this issue):
```julia
# Train the UDE
# We get the model parameters to be trained
ps_UA = Flux.params(UA) # get the parameters of the NN
#data = Flux.Data.DataLoader((X, Y), batchsize=hyparams.batchsize, shuffle=false)
data = Dict("X"=>X, "Y"=>Y) # Get the input data
# Train the UDE for a given number of epochs
@epochs hyparams.epochs hybrid_train!(loss, ps_UA, data, opt, H, p, t, t₁)
```
```julia
function hybrid_train!(loss, ps_UA, data, opt, H, p, t, t₁)
    # Some previous code here
    # Retrieve model parameters
    ps_UA = Params(ps_UA)
    # back_UA is a method that computes the product of the gradient so far with its argument
    train_loss_UA, back_UA = Zygote.pullback(() -> loss(data, H, p, t, t₁), ps_UA)
    # Callback to track the training
    callback(train_loss_UA)
    # Apply back_UA() to the correct type of 1.0 to get the gradient of the loss
    gs_UA = back_UA(one(train_loss_UA))
    # Insert whatever code that needs the gradient here
    Flux.update!(opt, ps_UA, gs_UA)
end
```
It is based on a loss function which calls the forward scheme:
```julia
function loss(data, H, p, t, t₁)
    l_H = 0.0f0
    # Compute l_H as the difference between the simulated H with UA(x) and H_ref
    iceflow!(H, UA, p, t, t₁)
    l_H = sqrt(Flux.Losses.mse(H, H_ref["H"][end]; agg=mean))
    return l_H
end
```
The forward scheme itself is defined here (this is a simplified version of the code):
```julia
function iceflow!(H, UA, p, t, t₁)
    # Manual explicit forward scheme implementation
    while t < t₁
        # Long chunk of code which computes the flux F in a staggered grid
        # Δt is defined to be stable
        dHdt = (F .+ inn(MB[:,:,year])) .* Δt # dH for time step t
        # Use Zygote.Buffer in order to mutate the matrix while still being
        # able to compute gradients
        H_buff = Zygote.Buffer(H)
        H_buff[2:end-1, 2:end-1] .= max.(0.0, inn(H_buff) .+ dHdt) # CODE FAILS HERE
        H = copy(H_buff)
        t += Δt
    end
end
```
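For what it's worth, here is my understanding of how Zygote.Buffer is supposed to be used, as a minimal sketch with hypothetical names (`step_buffered` is not in my code). Two things seem to matter: `Zygote.Buffer(H)` allocates storage but does not copy H's values, so every cell has to be written before `copy()`; and a Buffer is deliberately not an `AbstractArray`, so helpers like `inn()` that expect a plain array may not accept it:

```julia
using Zygote

# Hypothetical minimal sketch of the Buffer pattern (not my actual scheme).
function step_buffered(H, dHdt)
    B = Zygote.Buffer(H)
    # Seed every cell first: Zygote.Buffer(H) does NOT copy H's contents
    for i in eachindex(H)
        B[i] = H[i]
    end
    # Update the interior, reading from the original H rather than from B
    for i in 2:size(H, 1)-1, j in 2:size(H, 2)-1
        B[i, j] = max(0.0, H[i, j] + dHdt[i-1, j-1])
    end
    return copy(B)   # copy() freezes the buffer and returns a real Array
end

H0 = ones(4, 4)
dHdt = fill(-0.5, 2, 2)              # interior-sized update
l(H) = sum(step_buffered(H, dHdt))
g = Zygote.gradient(l, H0)[1]        # gradient flows through the buffer
```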
So my problem is that I’m overwriting the
H matrix at each time step, but Zygote cannot perform the pullback on that. Is there any clean workaround to this?
Zygote.Buffer doesn’t seem to help here. Thanks in advance!
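Alternatively, would a mutation-free rewrite of the time loop work, building a fresh H each step instead of writing into it? Here is a hedged sketch under my own assumptions; `pad_zeros` is a hypothetical helper (not from my code) that embeds the interior-sized update into a full-sized array using only non-mutating concatenation:

```julia
using Zygote

# Hypothetical helper: embed an interior-sized update into a zero border,
# using only non-mutating concatenation (which Zygote can differentiate).
function pad_zeros(A)
    mi, ni = size(A)
    return [zeros(1, ni + 2);
            zeros(mi, 1) A zeros(mi, 1);
            zeros(1, ni + 2)]
end

# Mutation-free forward scheme sketch: each step returns a brand-new H,
# so Zygote never sees a setindex! call.
function iceflow(H, dHdt_interior, nsteps)
    for _ in 1:nsteps
        H = max.(0.0, H .+ pad_zeros(dHdt_interior))
    end
    return H
end

H0 = ones(4, 4)
dHdt = fill(0.1, 2, 2)                  # interior-sized, as on a staggered grid
l(H) = sum(iceflow(H, dHdt, 3))
g = Zygote.gradient(l, H0)[1]           # gradient flows through all steps
```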