Autograd through loss function with derivatives in

Hi all, I am relatively new to Julia, coming from a C/C++ and Python background. I am loving Julia, but I am struggling to build a simple PINN (physics-informed neural network) model using Zygote/Flux (I also tried ForwardDiff). I am starting from one of the standard examples from Raissi et al. (Simple PINN NLS), but I simplified the loss function for testing purposes. I would like to define a loss function that contains derivatives of the network outputs with respect to the input features, and then use `Zygote.gradient` to find the gradients of that loss with respect to the model parameters.
Here is the definition of the network:

```julia
using Flux, Zygote, Statistics   # Statistics provides `mean`

# define layers
fc1 = Dense(2, 20, tanh);
fc2 = Dense(20, 2, identity);

# create chain and set optimizer
model = Chain(fc1, fc2);
pars = Flux.params(model);
opt = ADAM();
```
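Flux's `Dense` layers treat each column of the input as one sample. Judging from the loss function below, `model` is meant to take a 2×N matrix whose first row holds x and second row holds t, and to return a 2×N matrix whose rows are u and v; the `[1,:]` / `[2,:]` indexing relies on this layout. A minimal sketch with a hypothetical set of N points on the line t = 0 (`N`, `xs` and `X_demo` are illustrative names, not from the original post):

```julia
using Flux

# Same architecture as above: 2 inputs (x, t) -> 20 hidden -> 2 outputs (u, v)
model = Chain(Dense(2, 20, tanh), Dense(20, 2, identity))

# Hypothetical collocation points: N values of x, all at t = 0
N = 5
xs = collect(range(-5f0, 5f0; length = N))   # Float32 to match the layer parameters
X_demo = vcat(xs', zeros(Float32, 1, N))     # 2×N: row 1 = x, row 2 = t

out = model(X_demo)    # 2×N: row 1 = u, row 2 = v, one column per point
u_vals = out[1, :]     # u at each of the N points
v_vals = out[2, :]     # v at each of the N points
```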

Here is my loss function code:

```julia
function loss_fn(model, X)
    # X is 2×N: row 1 = x, row 2 = t, one column per collocation point
    u(x) = model(x)[1, :]
    v(x) = model(x)[2, :]
    # One inner gradient per output gives both input derivatives at every point
    gu = Zygote.gradient(a -> sum(u(a)), X)[1]
    gv = Zygote.gradient(a -> sum(v(a)), X)[1]
    ∂u∂x, ∂u∂t = gu[1, :], gu[2, :]
    ∂v∂x, ∂v∂t = gv[1, :], gv[2, :]
    loss = mean(u(X) .^ 2) + mean(v(X) .^ 2) + mean(∂u∂x .^ 2) +
           mean(∂u∂t .^ 2) + mean(∂v∂x .^ 2) + mean(∂v∂t .^ 2)
    return loss
end
```
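The `sum` inside each inner `Zygote.gradient` call works because column j of the output depends only on column j of the input, so the gradient of `sum(u(X))` with respect to `X` holds, in column j, the derivatives of u at point j alone. A small sketch that checks this against the full Jacobian from `Zygote.jacobian` for a freshly initialised model (the random input `X` and `N` are assumptions for illustration):

```julia
using Flux, Zygote

model = Chain(Dense(2, 20, tanh), Dense(20, 2, identity))
u(x) = model(x)[1, :]

N = 4
X = randn(Float32, 2, N)    # hypothetical input: row 1 = x, row 2 = t

# Derivatives of u at every point via the sum trick (one reverse pass)
g = Zygote.gradient(a -> sum(u(a)), X)[1]    # 2×N
dudx = g[1, :]
dudt = g[2, :]

# Full Jacobian: J[j, k] = ∂u_j / ∂vec(X)[k], size N × 2N
J = Zygote.jacobian(u, X)[1]
lin = LinearIndices(X)
dudx_jac = [J[j, lin[1, j]] for j in 1:N]
dudt_jac = [J[j, lin[2, j]] for j in 1:N]

# Cross terms ∂u_j / ∂X[:, k] with k ≠ j are expected to be zero
cross_max = maximum(abs(J[j, lin[i, k]]) for j in 1:N for k in 1:N for i in 1:2 if j != k)
```

This is also why the trick only applies to pointwise networks like this one: if the outputs mixed information across columns, the summed gradient would no longer separate into per-point derivatives.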

Here is the line I am using to compute the gradients of the loss function with respect to the model parameters:

```julia
grads = Flux.gradient(() -> loss_fn(model, x0_t0), pars)
```

But when I run the above code I get a "mutating arrays" error. I know other colleagues have proposed using a combination of ReverseDiff and `Zygote.gradient` to fix this, but honestly I am not sure what that would look like. Any help would be appreciated.

It looks like using ReverseDiff over Zygote works; see this discussion post for details. Here is my loss function, using Zygote for the inner derivatives with respect to the input:

```julia
using Zygote, Statistics

# Single dense layer with tanh activation: 2 inputs -> 2 outputs (u, v)
layer(W, b, x) = tanh.(W * x .+ b)

function loss(W, b, x)
    u(x) = layer(W, b, x)[1, :]
    v(x) = layer(W, b, x)[2, :]
    ∂u∂x = Zygote.gradient(a -> sum(u(a)), x)[1][1, :]
    ∂v∂x = Zygote.gradient(a -> sum(v(a)), x)[1][1, :]
    mean(∂u∂x .^ 2) + mean(∂v∂x .^ 2)  # just a simple loss function, not physical
end
```
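For this one-layer model the inner derivative has a closed form, which gives an easy sanity check on what the inner `Zygote.gradient` call returns. With z = W*x .+ b, the derivative of u at each point is W[1,1] * (1 - tanh(z[1])^2), and the derivative of v is W[2,1] * (1 - tanh(z[2])^2). A sketch comparing the two, with random `W`, `b` and a hypothetical 2×N input `X` (all assumptions for illustration):

```julia
using Zygote

layer(W, b, x) = tanh.(W * x .+ b)

W, b = rand(2, 2), rand(2)
N = 6
X = randn(2, N)    # hypothetical input: row 1 = x, row 2 = t

# Inner derivatives via Zygote, as in the loss function above
dudx_ad = Zygote.gradient(a -> sum(layer(W, b, a)[1, :]), X)[1][1, :]
dvdx_ad = Zygote.gradient(a -> sum(layer(W, b, a)[2, :]), X)[1][1, :]

# Closed form: d/dx tanh(z_i) = (1 - tanh(z_i)^2) * W[i, 1]
Z = W * X .+ b
dudx_exact = W[1, 1] .* (1 .- tanh.(Z[1, :]) .^ 2)
dvdx_exact = W[2, 1] .* (1 .- tanh.(Z[2, :]) .^ 2)
```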

Here is how I compute the gradient of the loss function with respect to the weights and biases, using a compiled ReverseDiff tape:

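```julia
using ReverseDiff

const W, b = rand(2, 2), rand(2)
const input = (W, b)
# Record the loss (which calls Zygote internally) on a ReverseDiff tape and compile it
const compiled_tape = ReverseDiff.compile(ReverseDiff.GradientTape((a, b) -> loss(a, b, x0_t0), input))
# Gradients are written into copies of the inputs: (∂loss/∂W, ∂loss/∂b)
a = ReverseDiff.gradient!(map(copy, input), compiled_tape, input)
```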
```
2×2 Array{Float64,2}:
 0.118311  0.0
 0.126966  0.0
```
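One way to double-check a nested-AD result like this is a central finite-difference approximation of the same gradient, since that only needs the first-order loss function itself. The sketch below redefines `layer` and `loss` from above and uses a hypothetical input `x_ic` whose t row is all zeros, as it would be for initial-condition points at t = 0. With such an input the loss does not depend on `W[:, 2]`, so that column of the gradient is exactly zero. Whether this explains the zero column in the output above depends on what `x0_t0` actually contains, which the post does not show; `fd_grad_W`, `h` and `x_ic` are illustrative names.

```julia
using Zygote, Statistics

layer(W, b, x) = tanh.(W * x .+ b)

function loss(W, b, x)
    u(x) = layer(W, b, x)[1, :]
    v(x) = layer(W, b, x)[2, :]
    ∂u∂x = Zygote.gradient(a -> sum(u(a)), x)[1][1, :]
    ∂v∂x = Zygote.gradient(a -> sum(v(a)), x)[1][1, :]
    mean(∂u∂x .^ 2) + mean(∂v∂x .^ 2)
end

# Hypothetical helper: central finite-difference gradient of loss w.r.t. W (b, x fixed)
function fd_grad_W(W, b, x; h = 1e-6)
    G = zeros(size(W))
    for i in eachindex(W)
        Wp = copy(W); Wp[i] += h
        Wm = copy(W); Wm[i] -= h
        G[i] = (loss(Wp, b, x) - loss(Wm, b, x)) / (2h)
    end
    return G
end

W, b = rand(2, 2), rand(2)
N = 8
x_ic = vcat(collect(range(-1.0, 1.0; length = N))', zeros(1, N))  # hypothetical points at t = 0
G = fd_grad_W(W, b, x_ic)    # compare with the first element of the ReverseDiff result
```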