Flux PINN 1D Burgers

Hi! I am trying to construct PINN to solve 1D Burgers equation in Flux.jl without using NeuralPDE.jl. The velocity field u(t,x) is defined by the net_u([t,x]) and the batch for training consists of N_u=32 points corresponding to the initial and boundary conditions and N_f=32 points inside the calculation domain t \in [0,1], \; x \in [-1,1], which are stacked together. Accordingly, the loss function loss(x,y) is the sum of two terms, corresponding to initial+boundary points and to the internal points, where the Burgers equation f(t,x) = u_t(t,x) + u(t,x)*u_x(t,x) - nu*u_xx(t,x) = 0 should be fullfilled. Below is my code

using Flux, Zygote, Statistics

net_u = Chain(Dense(2 => 20, tanh), Dense(20 => 20, tanh), Dense(20 => 20, tanh), Dense(20 => 20, tanh),
        Dense(20 => 1))

u(t,x) = net_u([t,x])[1]
u_t(t,x) = Zygote.gradient((x)->u(x[1],x[2]), [t, x])[1][1]
u_x(t,x) = Zygote.gradient((x)->u(x[1],x[2]), [t, x])[1][2]
u_xx(t,x) = Zygote.hessian((x)->u(x[1],x[2]), [t, x])[2,2]

nu = 0.01/pi
f(t,x) = u_t(t,x) + u(t,x)*u_x(t,x) - nu*u_xx(t,x)

# generating data points
N_u = 32
N_f = 32

function get_batch()
    tx_train = zeros(2, N_u+N_f)
    u_train = zeros(N_u+N_f)

    for i in 1:trunc(Int, N_u/4)
        tx_train[2,i]=-rand()
        u_train[i]=-sin(pi*tx_train[2,i])
    end

    for i in trunc(Int, N_u/4):trunc(Int, N_u/2)
        tx_train[2,i]=rand()
        u_train[i]=-sin(pi*tx_train[2,i])
    end

    for i in trunc(Int, N_u/2):trunc(Int, 3*N_u/4)
        tx_train[1,i]=rand()
        tx_train[2,i]=1.0
        u_train[i]=0.0
    end

    for i in trunc(Int, 3*N_u/4):N_u
        tx_train[1,i]=rand()
        tx_train[2,i]=-1.0
        u_train[i]=0.0
    end

    for i in N_u+1:N_u+N_f
        tx_train[1,i]=rand()
        tx_train[2,i]=2.0*rand()-1.0
    end
    
    return (tx_train, u_train)
end

function loss(x,y)
    loss_u = Flux.Losses.mse(y[1:N_u]',net_u(x[:,1:N_u]))
    loss_f = mean(f.(x[1,N_u+1:N_u+N_f],x[2,N_u+1:N_u+N_f]).^2)  
    return loss_u+loss_f
end

# train and return losses
function run_training(network, epoch)
    losses=[]
    pars = Flux.params(network)  # contains references to arrays in model
    opt = Flux.Adam()    # will store optimiser momentum, etc.
    
    for _ in 1:epoch        
        x,y = get_batch()
        data = Flux.DataLoader((x, y), batchsize=N_u+N_f)   
        Flux.train!(loss, pars, data, opt)        
        push!(losses, loss(x, y))
    end
    return losses
end

losses = run_training(net_u, 1)

which results in error:

Mutating arrays is not supported -- called setindex!(Matrix{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations


Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] _throw_mutation_error(f::Function, args::Matrix{Float64})
    @ Zygote C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\lib\array.jl:86
  [3] (::Zygote.var"#391#392"{Matrix{Float64}})(#unused#::Nothing)
    @ Zygote C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\lib\array.jl:98
  [4] (::Zygote.var"#2488#back#393"{Zygote.var"#391#392"{Matrix{Float64}}})(Δ::Nothing)
    @ Zygote C:\Users\parfe\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
  [5] Pullback
    @ C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\lib\forward.jl:31 [inlined]
  [6] (::typeof(∂(forward_jacobian)))(Δ::Tuple{Nothing, Zygote.OneElement{Float64, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
    @ Zygote C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\compiler\interface2.jl:0
  [7] Pullback
    @ C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\lib\forward.jl:44 [inlined]
  [8] Pullback
    @ C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\lib\forward.jl:42 [inlined]
  [9] Pullback
    @ C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\lib\grad.jl:64 [inlined]
 [10] (::typeof(∂(hessian_dual)))(Δ::Zygote.OneElement{Float64, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\compiler\interface2.jl:0
 [11] Pullback
    @ C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\lib\grad.jl:62 [inlined]
 [12] Pullback
    @ .\In[4]:9 [inlined]
 [13] (::typeof(∂(u_xx)))(Δ::Float64)
    @ Zygote C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\compiler\interface2.jl:0
 [14] Pullback
    @ .\In[4]:12 [inlined]
 [15] (::typeof(∂(f)))(Δ::Float64)
    @ Zygote C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\compiler\interface2.jl:0
 [16] #938
    @ C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\lib\broadcast.jl:205 [inlined]
 [17] #4
    @ .\generator.jl:36 [inlined]
 [18] iterate
    @ .\generator.jl:47 [inlined]
 [19] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{Tuple{Float64, typeof(∂(f))}}, Vector{Float64}}}, Base.var"#4#5"{Zygote.var"#938#944"}})
    @ Base .\array.jl:787
 [20] map
    @ .\abstractarray.jl:3055 [inlined]
 [21] (::Zygote.var"#∇broadcasted#943"{Tuple{Vector{Float64}, Vector{Float64}}, Vector{Tuple{Float64, typeof(∂(f))}}, Val{3}})(ȳ::Vector{Float64})
    @ Zygote C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\lib\broadcast.jl:205
 [22] #3885#back
    @ C:\Users\parfe\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [23] #208
    @ C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\lib\lib.jl:206 [inlined]
 [24] #2066#back
    @ C:\Users\parfe\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [25] Pullback
    @ .\broadcast.jl:1304 [inlined]
 [26] Pullback
    @ .\In[5]:41 [inlined]
 [27] (::typeof(∂(loss)))(Δ::Float64)
    @ Zygote C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\compiler\interface2.jl:0
 [28] #208
    @ C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\lib\lib.jl:206 [inlined]
 [29] #2066#back
    @ C:\Users\parfe\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [30] Pullback
    @ C:\Users\parfe\.julia\packages\Flux\v79Am\src\optimise\train.jl:143 [inlined]
 [31] (::typeof(∂(λ)))(Δ::Float64)
    @ Zygote C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\compiler\interface2.jl:0
 [32] (::Zygote.var"#99#100"{Params{Zygote.Buffer{Any, Vector{Any}}}, typeof(∂(λ)), Zygote.Context{true}})(Δ::Float64)
    @ Zygote C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\compiler\interface.jl:389
 [33] withgradient(f::Function, args::Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote C:\Users\parfe\.julia\packages\Zygote\g2w9o\src\compiler\interface.jl:133
 [34] macro expansion
    @ C:\Users\parfe\.julia\packages\Flux\v79Am\src\optimise\train.jl:142 [inlined]
 [35] macro expansion
    @ C:\Users\parfe\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
 [36] train!(loss::Function, ps::Params{Zygote.Buffer{Any, Vector{Any}}}, data::MLUtils.DataLoader{Tuple{Matrix{Float64}, Vector{Float64}}, Random._GLOBAL_RNG, Val{nothing}}, opt::Adam; cb::Flux.Optimise.var"#38#41")
    @ Flux.Optimise C:\Users\parfe\.julia\packages\Flux\v79Am\src\optimise\train.jl:140
 [37] train!(loss::Function, ps::Params{Zygote.Buffer{Any, Vector{Any}}}, data::MLUtils.DataLoader{Tuple{Matrix{Float64}, Vector{Float64}}, Random._GLOBAL_RNG, Val{nothing}}, opt::Adam)
    @ Flux.Optimise C:\Users\parfe\.julia\packages\Flux\v79Am\src\optimise\train.jl:136
 [38] run_training(network::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, epoch::Int64)
    @ Main .\In[5]:54
 [39] top-level scope
    @ In[6]:1

but works if the loss function contains only the first part loss(x,y)=loss_u. What is wrong here and how can I resolve the problem?