Workaround to "Mutating arrays is not supported" with Zygote and UDEs

I’m trying to implement a Universal Differential Equation (UDE) based on a 2D diffusion PDE with an integrated neural network. I’m computing the diffusivity of a 2D matrix H with a forward scheme, and once I reach the final state of H I compute the loss function and try to backpropagate through the whole thing using Zygote.pullback() in order to update the weights of the neural network.

Since I am updating (overwriting) the H matrix at each time step, Zygote rightfully complains that “Mutating arrays is not supported”. I have tried to work around this by using Zygote.Buffer, but I still keep getting the error. I have also read most posts related to this issue, but I haven’t been able to find a solution to this. I’m using Julia 1.6.

Here is an overview of what my code does (I have omitted many parts, as the original code does many more things that I judge irrelevant to this issue):

# Train the UDE
# We get the model parameters to be trained
ps_UA = Flux.params(UA) # get the parameters of the NN
#data = Flux.Data.DataLoader((X, Y), batchsize=hyparams.batchsize, (X, Y), shuffle=false)
data = Dict("X"=>X, "Y"=>Y) # Get the input data
# Train the UDE for a given number of epochs
@epochs hyparams.epochs hybrid_train!(loss, ps_UA, data, opt, H, p, t, t₁)

Which calls

function hybrid_train!(loss, ps_UA, data, opt, H, p, t, t₁)
    # Some previous code here
    # Retrieve model parameters
    ps_UA = Params(ps_UA)

    # back is a method that computes the product of the gradient so far with its argument.
    train_loss_UA, back_UA = Zygote.pullback(() -> loss(data, H, p, t, t₁), ps_UA)
    # Callback to track the training
    callback(train_loss_UA)
    # Apply back() to the correct type of 1.0 to get the gradient of loss.
    gs_UA = back_UA(one(train_loss_UA))
    # Insert whatever code you want here that needs the gradient.
    Flux.update!(opt, ps_UA, gs_UA)
end

Based on a loss function which contains the forward scheme

function loss(data, H, p, t, t₁)
    l_H = 0.0f0

    # Compute l_H as the difference between the simulated H with UA(x) and H_ref
    iceflow!(H, UA, p,t,t₁)
    l_H = sqrt(Flux.Losses.mse(H, H_ref["H"][end]; agg=mean))
    return l_H
end

Where the forward scheme is defined here (this is a simplified version of the code):

function iceflow!(H, UA, p,t,t₁)
    # Manual explicit forward scheme implementation
    while t < t₁
        # Long chunk of code which computes the flux F in a staggered grid
        # Δt is defined to be stable
        dHdt = (F .+ inn(MB[:,:,year])) .* Δt  # dH for time step t

        # Use Zygote.Buffer in order to mutate matrix while being able to compute gradients
        H_buff = Zygote.Buffer(H)
        H_buff[2:end - 1,2:end - 1] .= max.(0.0, inn(H_buff) .+ dHdt)  # CODE FAILS HERE
        H = copy(H_buff)
        
        t += Δt
    end
end

So my problem is that I’m overwriting the H matrix at each time step, but Zygote cannot perform the pullback on that. Is there any clean workaround to this? Zygote.Buffer doesn’t seem to help here. Thanks in advance!

Zygote.Buffer only supports scalar indexing IIRC and no broadcast expressions.
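
To illustrate the point about indexing, here is a minimal sketch of Zygote.Buffer used with element-wise writes and no broadcast assignment. The function name double_each is hypothetical and only serves the example; it is not part of the original code.

```julia
using Zygote

# Hypothetical helper: fills a Buffer one element at a time,
# then converts it back to a regular array with copy().
function double_each(x)
    buf = Zygote.Buffer(x)          # same shape and element type as x
    for i in eachindex(x)
        buf[i] = 2 * x[i]           # scalar write, no `.=` broadcast
    end
    return copy(buf)                # the Buffer must not be used after this
end

x = [1.0, 2.0, 3.0]
g = Zygote.gradient(x -> sum(double_each(x)), x)[1]
```

Every element is written before the buffer is copied, and nothing is read back from it, which is the usage pattern the Buffer is designed for.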

Note that it’s more stable to only solve for the interior of the PDE, so you might want to just change that part of the discretization approach which would make the issue go away.

Hi @ChrisRackauckas, thank you for your answer! I am also following this issue. What do you mean by “only solve for the interior of the PDE”, and why would that solve this issue? Thank you!

It’s another way of stating ghost points. See the definition here:

http://diffeqoperators.sciml.ai/dev/operators/derivative_operators/#Derivative-Operators

What it amounts to is that you only need the H_buff[2:end - 1,2:end - 1] .= max.(0.0, inn(H_buff) .+ dHdt) equation, in which case you don’t need the buffer because you can instead just return max.(0.0, H .+ dHdt).
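
A minimal sketch of that non-mutating pattern, assuming a toy tendency term in place of the real flux computation: the names tendency, simulate, and θ are hypothetical, and H here is a plain vector rather than the 2D ice thickness matrix.

```julia
using Zygote

# Hypothetical stand-in for the flux computation: simple linear decay.
tendency(H, θ) = -θ .* H

# Each step returns a new array and rebinds H; nothing is overwritten
# in place, so Zygote can differentiate through the loop.
function simulate(H, θ, Δt, nsteps)
    for _ in 1:nsteps
        H = max.(0.0, H .+ Δt .* tendency(H, θ))
    end
    return H
end

H0 = [1.0, 2.0, 3.0]
loss(θ) = sum(simulate(H0, θ, 0.1, 5))
g = Zygote.gradient(loss, 0.5)[1]
```

The trade-off is one allocation per time step, which is usually acceptable for small grids but can become costly for large ones.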

Thanks Chris, that’s a good point. However, I don’t quite see how to get away with keeping only the interior of the PDE. For each iteration (time step) I lose one row and column in the matrix due to gradient computations. So at the end of each iteration I somehow need to move back to the original matrix size in order to avoid my matrix shrinking at each time step.

DiffEqOperators.jl seems to be able to handle ghost points within its functions, but I would like to keep my code structure as is. Is there any way to avoid this issue or am I missing something?

Boundary conditions are extrapolation conditions by which you make the matrix an operator of the interior. Take a read of that page again with that in mind.
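
One way to read this, sketched in 1D with zero Dirichlet boundaries: the state holds only interior values, the boundary condition supplies the ghost points, and the stencil applied to the extended array returns something of the same size as the interior, so nothing shrinks between time steps. The function names pad_dirichlet0 and laplacian are hypothetical, and this hand-written version is an assumption about the idea, not DiffEqOperators.jl code.

```julia
using Zygote

# Extend the interior values with the ghost points implied by u = 0 on the boundary.
pad_dirichlet0(u) = vcat(zeros(eltype(u), 1), u, zeros(eltype(u), 1))

# Second-order centered Laplacian acting on interior values only.
# The output has the same length as u, so repeated steps keep the size fixed.
function laplacian(u, Δx)
    ue = pad_dirichlet0(u)
    return (ue[1:end-2] .- 2 .* ue[2:end-1] .+ ue[3:end]) ./ Δx^2
end

u = [1.0, 2.0, 3.0]
Lu = laplacian(u, 1.0)
g = Zygote.gradient(v -> sum(laplacian(v, 1.0)), u)[1]
```

The padding is built with vcat rather than by writing into a preallocated array, so there is no mutation for Zygote to complain about. The same idea carries over to 2D by padding along each dimension.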

OK, so I’m currently trying to rewrite my 2D staggered grid diffusion PDE using DiffEqOperators.jl in the hope that it will work more smoothly with Zygote.pullback. I’ve seen that some features are very recent, so the documentation is quite tough to follow.

Here’s what I have now:

∂x = CenteredDifference{1}(1, 2, Δx, size(H,1))
∂y = CenteredDifference{2}(1, 2, Δx, size(H,2))
∂²x = CenteredDifference{1}(2, 2, Δx, size(H,1))
∂²y = CenteredDifference{2}(2, 2, Δx, size(H,2))

Qx, Qy = MultiDimBC(Dirichlet0BC(eltype(H)), size(H))

I’ve seen it is possible to combine the operators for multiple dimensions like ∂ = ∂x + ∂y and Q = compose(Qx, Qy), but that seems to only work for square matrices (size(H,1) == size(H,2)). That’s not the case for my matrix. Is there a workaround for that? I’m currently getting an error with ∂ = ∂x + ∂y.

And besides that, I’m still getting another error when working with separate dimensions:

dHdx = ∂x*H*Qx

Fails during H*Qx with:

ERROR: MethodError: no method matching *(::Matrix{Float32}, ::MultiDimDirectionalBC{Float32, RobinBC{Float32, Vector{Float32}}, 1, 2, 1})
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
  *(::StridedMatrix{T}, ::StridedVector{S}) where {T<:Union{Float32, Float64, ComplexF32, ComplexF64}, S<:Real} at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:44
  *(::StridedMatrix{var"#s832"} where var"#s832"<:Union{Float32, Float64}, ::StridedMatrix{var"#s831"} where var"#s831"<:Union{Float32, Float64, ComplexF32, ComplexF64}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:158

@ChrisRackauckas What is the correct way to work with boundary conditions in non-square (size(H,1) != size(H,2)) 2D Arrays? Thanks again!

PS: Once I get this working I’ll try to update the documentation with a basic example in order to make it more intuitive for new users.

The right way to do PDEs is to mutate and use EnzymeVJP, which is now supported in DiffEqSensitivity v6.52, so I’d call that the resolution to this.
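
A rough sketch of what that can look like, assuming the PDE is expressed as an in-place ODE right-hand side and solved with OrdinaryDiffEq: the function diffusion!, the single parameter D, and the loss are hypothetical placeholders rather than the ice flow model from this thread. DiffEqSensitivity was later renamed SciMLSensitivity, so the package name may need adjusting on newer installations.

```julia
using OrdinaryDiffEq, DiffEqSensitivity, Zygote

# Hypothetical in-place 1D diffusion RHS with zero Dirichlet boundaries.
# It mutates du freely; Enzyme computes the vector-Jacobian products.
function diffusion!(du, u, p, t)
    D = p[1]
    n = length(u)
    for i in 1:n
        left  = i == 1 ? zero(eltype(u)) : u[i-1]
        right = i == n ? zero(eltype(u)) : u[i+1]
        du[i] = D * (left - 2 * u[i] + right)
    end
    return nothing
end

u0 = [0.0, 1.0, 0.0]
p0 = [0.1]
prob = ODEProblem(diffusion!, u0, (0.0, 1.0), p0)

function loss(p)
    sol = solve(prob, Tsit5(); p = p, saveat = 0.1,
                sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()))
    return sum(abs2, Array(sol))
end

grad = Zygote.gradient(loss, p0)[1]
```

Zygote only sees the call to solve here; the mutation inside diffusion! is handled by the adjoint method with Enzyme, which is why the “Mutating arrays is not supported” error does not arise in this setup.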