I’m trying to implement a Universal Differential Equation (UDE) based on a 2D diffusion PDE with an embedded neural network. I compute the diffusivity of a 2D matrix H with an explicit forward scheme, and once I reach the final state of H I compute the loss and backpropagate through the whole simulation using Zygote.pullback() in order to update the weights of the neural network.
Since I am updating (overwriting) the H matrix at each time step, Zygote rightfully complains that “Mutating arrays is not supported”. I have tried to work around this with Zygote.Buffer, but I still get the error. I have also read most posts related to this issue, but I haven’t been able to find a solution. I’m using Julia 1.6.
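For reference, my understanding of the documented Zygote.Buffer pattern is: write into the Buffer with plain indexed assignment, then copy() it back into a regular array. A minimal toy sketch of that pattern (double_interior is made up for illustration, not part of my code):

using Zygote

# Toy function, made up for illustration: overwrite the interior of a vector
function double_interior(x)
    n = length(x)
    buf = Zygote.Buffer(x)        # same shape/eltype as x; contents start undefined
    buf[:] = x                    # initialize every entry from x
    buf[2:n-1] = 2 .* x[2:n-1]    # plain indexed assignment into the Buffer
    return copy(buf)              # copy() freezes the Buffer into a normal array
end

Zygote.gradient(x -> sum(double_interior(x)), rand(5))  # works: no mutation error

One difference I notice: the Zygote docs only ever use plain indexed assignment (=) into a Buffer, whereas my failing line further down uses a broadcasted assignment (.=).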
Here is an overview of what my code does (I have omitted many parts of the original code that I judge irrelevant to this issue):
# Train the UDE
# We get the model parameters to be trained
ps_UA = Flux.params(UA) # get the parameters of the NN
#data = Flux.Data.DataLoader((X, Y), batchsize=hyparams.batchsize, shuffle=false)
data = Dict("X"=>X, "Y"=>Y) # Get the input data
# Train the UDE for a given number of epochs
@epochs hyparams.epochs hybrid_train!(loss, ps_UA, data, opt, H, p, t, t₁)
Which calls
function hybrid_train!(loss, ps_UA, data, opt, H, p, t, t₁)
    # Some previous code here
    # ps_UA is already a Zygote.Params (returned by Flux.params above),
    # so no re-wrapping is needed here.
    # back_UA is a function that, when seeded, returns the gradients of the
    # loss with respect to ps_UA (a vector-Jacobian product).
    train_loss_UA, back_UA = Zygote.pullback(() -> loss(data, H, p, t, t₁), ps_UA)
    # Callback to track the training
    callback(train_loss_UA)
    # Seed back_UA with one(train_loss_UA) to get the gradients of the loss
    gs_UA = back_UA(one(train_loss_UA))
    # Update the NN weights with the computed gradients
    Flux.update!(opt, ps_UA, gs_UA)
end
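For context, the pullback/update mechanics above follow the implicit-parameters pattern from the Flux and Zygote docs. A minimal self-contained sketch of just that part (the Dense model, input x, and optimiser are made-up stand-ins for UA, my data, and opt):

using Flux, Zygote

m   = Dense(3, 1)             # toy model standing in for UA
x   = rand(Float32, 3, 10)    # dummy batch of inputs
ps  = Flux.params(m)          # implicit parameters, as in hybrid_train! above
opt = Descent(0.01)

l, back = Zygote.pullback(() -> sum(abs2, m(x)), ps)  # loss value + pullback closure
gs = back(one(l))             # seed with 1 to get dloss/dparams
Flux.update!(opt, ps, gs)     # one optimisation step

So I believe the gradient plumbing itself is standard; the mutation error only shows up once the loss runs the forward scheme below.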
The loss function itself runs the forward scheme:
function loss(data, H, p, t, t₁)
    # Compute l_H as the RMSE between the simulated H (using UA) and H_ref
    # (UA and H_ref are globals in the full code)
    H = iceflow!(H, UA, p, t, t₁)  # capture the returned state: H is rebound, not mutated
    l_H = sqrt(Flux.Losses.mse(H, H_ref["H"][end]; agg=mean))
    return l_H
end
Where the forward scheme is defined here (this is a simplified version of the code):
function iceflow!(H, UA, p,t,t₁)
    # Manual explicit forward scheme implementation
    while t < t₁
        # Long chunk of code which computes the flux F in a staggered grid
        # Δt is defined to be stable
        dHdt = (F .+ inn(MB[:,:,year])) .* Δt  # dH for time step t
        # Use Zygote.Buffer in order to mutate the matrix while still being able to compute gradients
        H_buff = Zygote.Buffer(H)
        H_buff[2:end - 1,2:end - 1] .= max.(0.0, inn(H_buff) .+ dHdt)  # CODE FAILS HERE
        H = copy(H_buff)  # copy() freezes the Buffer back into a regular array
        
        t += Δt
    end
    return H  # H is rebound (not mutated in place) each step, so return the final state
end
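To make the failing pattern easier to see in isolation, here is a self-contained sketch of the same update step with made-up stand-ins (a 6×6 state H and a constant 4×4 increment dH in place of the flux computation):

using Zygote

function step(H, dH)
    H_buff = Zygote.Buffer(H)
    H_buff[:, :] = H                                        # plain indexed assignment: fine
    H_buff[2:5, 2:5] .= max.(0.0, H_buff[2:5, 2:5] .+ dH)   # mirrors the failing line above;
                                                            # I expect the same mutation error here
    return copy(H_buff)
end

H₀ = rand(6, 6)
Zygote.gradient(h -> sum(step(h, ones(4, 4))), H₀)

Another thing I notice while writing this sketch: Zygote.Buffer(H) starts out uninitialized (it does not copy H), yet my real code reads inn(H_buff) before anything has been written into the Buffer.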
So my problem is that I’m overwriting the H matrix at each time step, and Zygote cannot perform the pullback through that. Is there any clean workaround? Zygote.Buffer doesn’t seem to help here. Thanks in advance!