Hi! I am trying to train a Universal Differential Equation (UDE) using implicit automatic differentiation with Zygote. I am doing this manually, without using DifferentialEquations.jl. A large part of my code (not included here) consists of numerically solving a differential equation. Some parameters of the differential equation are given by neural networks, and I want to take gradients with respect to the parameters of those networks. The pullback of the loss function is calculated as

```
loss_UA, back_UA = Zygote.pullback(() -> loss(H, UA, p, t, t₁), ps_UA)
```

where `loss` computes the mean square error between the real and simulated solutions of the differential equation, and the differentiation is with respect to the parameters of the neural network, `ps_UA`. Then, I compute the gradient simply as

```
∇_UA = back_UA(1)
```
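
Here `ps_UA` holds the implicit parameters of the network, built roughly like this (a minimal sketch of our setup; `UA` is the Flux `Chain` containing the neural network):

```
ps_UA = Flux.params(UA)
```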

The problem I am having is that `∇_UA` seems to be empty, or it is some kind of data structure I don’t understand. When I display it, I get this:

```
>  ∇_UA
```

I have no way of analyzing what Zygote is returning as the gradient. Furthermore, I can run

```
opt = ADAM(0.1)
Flux.update!(opt, ps_UA, ∇_UA)
```

without any error but this does not change the value of `ps_UA`.

Any idea of what could be the problem here? What kind of data structure are the gradients?

Thank you in advance for any help!


A little extra context on this issue (we are working on it together). Applying the obtained gradients using `Flux.update!(opt, ps_UA, ∇_UA)` does in fact return an error, but it is a different one depending on whether the parameters are implicit or explicit:

With implicit parameters, using `loss_UA, back_UA = Zygote.pullback(() -> loss(H, UA, p, t, t₁), ps_UA)` as stated by @facusapienza above, we get the following error:

`Flux.update!(opt, ps_UA, ∇_UA)`

```
ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
[5] instantiate
[6] materialize!
[7] materialize!
@ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/optimisers.jl:181
@ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:23
@ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:29
[11] top-level scope
@ none:1
[12] eval
@ ./boot.jl:360 [inlined]
[13] interpret(command::String, mod::Module, locals::Dict{Symbol, Any})
```

I’m not sure I understand the error, but it looks like the gradient is somehow empty?

On the other hand, if we compute the pullback with explicit parameters, passing a full Flux Chain (`UA`) as:

```
# Leaky ReLU as activation function
leakyrelu(x, a=0.01) = max(a*x, x)

# Constrains A within physically plausible values
relu_A(x) = min(max(1.58e-17, x), 1.58e-16)

# Define the network 1->10->5->1
UA = Chain(
    Dense(1, 10, initb = Flux.zeros),
    BatchNorm(10, leakyrelu),
    Dense(10, 5, initb = Flux.zeros),
    BatchNorm(5, leakyrelu),
    Dense(5, 1, relu_A, initb = Flux.zeros)
)

loss_UA, back_UA = Zygote.pullback(UA -> loss(H, UA, p, t, t₁), UA)
```

When applying this gradient to the NN parameters we get a different error:

`Flux.update!(opt, ps_UA, ∇_UA)`

```
ERROR: MethodError: no method matching getindex(::Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Vector{Float64}, Vector{Float64}, Nothing}}, Base.RefValue{Any}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, Nothing}}, Base.RefValue{Any}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, Nothing}}}}}}, ::Matrix{Float32})
Closest candidates are:
getindex(::Tuple, ::Int64) at tuple.jl:29
getindex(::Tuple, ::Real) at tuple.jl:30
getindex(::Tuple, ::Colon) at tuple.jl:33
...
Stacktrace:
[1] update!(opt::ADAM, xs::Params, gs::Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Vector{Float64}, Vector{Float64}, Nothing}}, Base.RefValue{Any}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, Nothing}}, Base.RefValue{Any}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, Nothing}}}}}})
@ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:28
[2] top-level scope
@ none:1
[3] eval
@ ./boot.jl:360 [inlined]
[4] interpret(command::String, mod::Module, locals::Dict{Symbol, Any})
```

Explicitly passing the Flux Chain to the pullback returns a different gradient structure, which does look empty (all the entries are zero):

```
@show ∇_UA

((layers = ((weight = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], bias = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], σ = nothing), Base.RefValue{Any}((λ = nothing, β = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], γ = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = 0.0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing)), (weight = [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], bias = [0.0, 0.0, 0.0, 0.0, 0.0], σ = nothing), Base.RefValue{Any}((λ = nothing, β = [0.0, 0.0, 0.0, 0.0, 0.0], γ = [0.0, 0.0, 0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = 0.0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing)), (weight = [0.0 0.0 … 0.0 0.0], bias = [0.0], σ = nothing)),),)
```

There’s hardly any information in the documentation about the gradient structure and how to debug these errors. The documentation recommends debugging with `_pullback()` and applying `back()` to trace errors, but this doesn’t return any error in our case. We obtain gradients, but they just seem to be empty, and the code crashes once we try to update the NN parameters. Any ideas on where the issue might come from?
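
For reference, the debugging pattern from the docs that we tried looks roughly like this (a sketch):

```
y, back = Zygote._pullback(() -> loss(H, UA, p, t, t₁))
∇ = back(1)   # runs the reverse pass; in our case this completes without errors
```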


The implicit `Grads` is not necessarily empty. It just has an awful `show` method. You can see the actual content with `∇_UA.grads`. I would run the following for the implicit case:

```
for p in ps_UA
    @show p, ∇_UA[p]
end
```

Make sure that at each print the shape of the parameter and its gradient are aligned.

Just based on Line 8 of the implicit stack trace, it looks like you probably have a row/column vector misalignment. We just need to do a little digging to figure out from where.

The explicit form will not work with Flux’s built-in optimizers. But you can try Optimisers.jl (under development right now) which supports them.
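
For reference, in current versions of Optimisers.jl the explicit workflow looks roughly like this (a sketch; the exact API may still change while the package is under development):

```
using Flux, Zygote, Optimisers

model = Chain(Dense(1, 10, relu), Dense(10, 1))
state = Optimisers.setup(Optimisers.Adam(0.01), model)      # per-parameter optimiser state

# Explicit gradient: a nested NamedTuple mirroring the model structure
grads, = Zygote.gradient(m -> sum(m(rand(Float32, 1, 5))), model)

# update returns the new optimiser state and the updated model
state, model = Optimisers.update(state, model, grads)
```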


Thanks a lot for the reply. This is interesting: showing the parameters and their gradients only seems to work with implicit parameters, as you said; it fails when passing explicit parameters. The documentation recommends not using implicit parameters unless needed, and it seems to offer two alternatives based on explicit parameters. How come it doesn’t work in our case?

Concerning the sizes of the parameters and the gradients, funnily enough, when I display them like this:

```
for ps in ps_UA
    @show ps, ∇_UA[ps]
    println("size ps: ", size(ps))
    println("size ∇_UA[p]: ", size(∇_UA[ps]))

    println("type ps: ", typeof(ps))
    println("type ∇_UA[p]: ", typeof(∇_UA[ps]))
end
```

All the sizes seem to match, except for the first parameter, which has shape `(10,1)` while its gradient has shape `(10,)`. Nonetheless, indexing the gradients with the parameters via `∇_UA[ps]` works. There also seems to be a type mismatch: the parameters are `Float32` while the gradients are `Float64`:

```
(ps, ∇_UA[ps]) = (Float32[0.30798444; 0.52923024; -0.38974404; -0.44488695; -0.007705779; -0.46781763; 0.30756167; -0.73545146; -0.60995233; -0.12095741], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (10, 1)
size ∇_UA[p]: (10,)
type ps: Matrix{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (10,)
size ∇_UA[p]: (10,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (10,)
size ∇_UA[p]: (10,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (10,)
size ∇_UA[p]: (10,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[-0.49900624 0.1739867 -0.50695115 -0.012244531 0.47494888 0.3235831 0.021464685 -0.48237133 -0.586465 -0.38168398; -0.40725645 -0.26300266 -0.14521688 0.020233944 -0.07136398 -0.56981426 -0.05533645 0.16115816 -0.4485389 -0.56794554; -0.37900218 -0.08815088 0.10154217 0.558363 -0.22744176 0.12258495 0.18857977 -0.16126387 -0.45260283 -0.54091734; -0.47956002 -0.27310026 -0.43743765 0.032916818 0.095131814 -0.6059501 -0.40490097 0.43668085 -0.31058735 -0.21437271; -0.031416014 0.21674222 0.485597 -0.3657828 -0.24838457 0.52909964 0.44705272 0.16652822 0.5047817 -0.5061942], [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0])
size ps: (5, 10)
size ∇_UA[p]: (5, 10)
type ps: Matrix{Float32}
type ∇_UA[p]: Matrix{Float64}
(ps, ∇_UA[ps]) = (Float32[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (5,)
size ∇_UA[p]: (5,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (5,)
size ∇_UA[p]: (5,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (5,)
size ∇_UA[p]: (5,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[-0.65684795 0.9604726 -0.8283665 0.7575054 0.6725607], [0.0 0.0 0.0 0.0 0.0])
size ps: (1, 5)
size ∇_UA[p]: (1, 5)
type ps: Matrix{Float32}
type ∇_UA[p]: Matrix{Float64}
(ps, ∇_UA[ps]) = (Float32[0.0], [0.0])
size ps: (1,)
size ∇_UA[p]: (1,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
Hit `@infiltrate` in hybrid_train!(loss::typeof(loss), UA::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, BatchNorm{var"#leakyrelu#180", Vector{Float32}, Float32, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, BatchNorm{var"#leakyrelu#180", Vector{Float32}, Float32, Vector{Float32}}, Dense{var"#relu_A#181", Matrix{Float32}, Vector{Float32}}}}, opt::ADAM, H::Matrix{Float32}, p::Tuple{Int64, Int64, Float64, Float64, Matrix{Float32}, Array{Float64, 3}, Array{Float32, 3}, Vector{Any}, Float64, Int64}, t::Int64, t₁::Float64) at iceflow.jl:79:
```

Still, when I do `Flux.update!(opt, ps_UA, ∇_UA)` I get the same error as mentioned previously:

```
ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
[5] instantiate
[6] materialize!
[7] materialize!
@ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/optimisers.jl:181
@ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:23
@ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:29
```

How come I can index the gradients with the parameters in the loop, but they don’t work with `Flux.update!()`?

I have also tried applying the gradients inside the for loop doing `Flux.update!(opt, ps, ∇_UA)`, but I get another error:

```
ERROR: MethodError: no method matching +(::Float64, ::Vector{Float64})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
+(::Union{Float16, Float32, Float64}, ::BigFloat) at mpfr.jl:392
+(::FillArrays.Zeros{T, N, Axes} where Axes, ::AbstractArray{V, N}) where {T, V, N} at /Users/Bolib001/.julia/packages/FillArrays/cVkp8/src/fillalgebra.jl:180
...
Stacktrace:
[3] getindex
[4] macro expansion
[5] macro expansion
@ ./simdloop.jl:77 [inlined]
[6] copyto!
[7] copyto!
[8] materialize!
[9] materialize!
@ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/optimisers.jl:179
@ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:23
```

Don’t upconvert to Float64 here. `min(max(1.58f-17, x), 1.58f-16)` if you want Float32, etc.
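
For instance, Float32-friendly versions of the activations above would look like this (a sketch; note the `f` literals):

```
leakyrelu(x, a = 0.01f0) = max(a * x, x)
relu_A(x) = min(max(1.58f-17, x), 1.58f-16)
```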


MWE?

As Chris already mentioned, your code implicitly promotes the output of `leakyrelu` etc. to `Float64` because you use double precision constants. Do something like `1f-2` etc.

This is the source of your error. If you look at the error message, it says your dimensions mismatch: `(10,1)` has two dimensions but `(10,)` only has one. This is happening because the gradient for your first layer comes back as just a vector, while Flux creates the weight as a matrix. I would characterize this as a Zygote bug. Can you file an issue on Zygote.jl and link this Discourse thread?

That snippet should be updated. Flux’s optimizers don’t understand structured explicit gradients like the kind you get when you differentiate the model explicitly. For now, if you want to use explicit gradients, then you need to use the optimizers in Optimisers.jl.

Explicit and implicit parameters have totally different structures, so the for loop that I wrote will not work to inspect the explicit structure. The explicit structure will be nested tuples, but `Grads` is a single dictionary-like object.
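
To illustrate the difference, here is a rough sketch of the two structures, reusing the `loss` call from above:

```
# Implicit: Grads is a dictionary-like object indexed by the parameter arrays themselves.
θ = Flux.params(UA)
gs_implicit = Zygote.gradient(() -> loss(H, UA, p, t, t₁), θ)
gs_implicit[first(θ)]            # gradient for the first parameter array

# Explicit: nested (named) tuples mirroring the model structure.
gs_explicit, = Zygote.gradient(m -> loss(H, m, p, t, t₁), UA)
gs_explicit.layers[1].weight     # gradient of the first Dense layer's weight
```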


I have opened an issue for this problem. For now I cannot provide an MWE, as I’m working with a rather big model and it would require a lot of work to get down to a simple standalone example. I hope this will still be enough.

I have tried flattening the parameters in order to get the same size as for the gradients, but then I cannot apply the gradients, as I’m getting an error. @darsnack Is there any way to manually patch the obtained gradients in order to get a workaround for this issue? I’d like to try to get this working without having to wait for a Zygote fix.

I have made a PR with a small update to the documentation with these details. I hope this can help new users avoid wasting time as I did.

Thanks a lot again for your help!

I haven’t tried this, but you might be able to patch one of the functions in the above stacktrace to reshape, such as (`update!` has fewer methods than `apply!`, so less chance of ambiguities):

```
Flux.Optimise.update!(opt, x::AbstractMatrix, Δ::AbstractVector) = update!(opt, x, reshape(Δ, size(x)))
```

The problem might also go away in Julia 1.7, since Base was fixed to allow e.g. `ones(3) .= rand(3,1)` without error. Although it’s a little confusing that this is `apply!(o, x::Matrix, Δ::Vector)` and not `x::Vector, Δ::Matrix`.


Great, as you suggested, your patch did the trick! I just had to specify the full path for both calls to `update!`:

`Flux.Optimise.update!(opt, x::AbstractMatrix, Δ::AbstractVector) = Flux.Optimise.update!(opt, x, reshape(Δ, size(x)))`

Thanks again for your help Michael! I will post this as well in the PR to confirm the source of the issue.

The gradients are still all 0s though, but that’s another issue (and probably my fault).

Make sure you’re actually using the parameter vectors and not some copy of them.
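
One quick check (a sketch): the arrays collected by `Flux.params` should be the very same objects stored in the model, not copies:

```
θ = Flux.params(UA)
first(θ) === UA[1].weight   # `===` checks object identity; this should be true
```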

Thanks @ChrisRackauckas for your suggestion. I checked this and it doesn’t seem to be the problem.

Something I observed is that if I use the `relu_A(x)` activation function,

```
relu_A(x) = min(max(minA, x), maxA)
```

where `minA` and `maxA` are the lower and upper bounds of the activation function, the gradients are zero. This makes sense in the regions where the output of the neural network takes the values `minA` or `maxA` (i.e., the flat regions of the activation function): there, the output of the NN, and hence the loss function, is locally constant. However, if I use any other activation function with non-vanishing gradients, such as

```
sigmoid_A(x) = minA + (maxA - minA) / ( 1 + exp(-x) )
```

then the gradient is filled with `NaN`s.

At which value is the derivative a NaN and why?

When I compute the gradient of the loss function with respect to the parameters of the NN using `loss_UA, back_UA = Zygote.pullback(() -> loss(H, UA, p, t, t₁), ps_UA)` with a simple NN with the `sigmoid_A()` activation function,

```
sigmoid_A(x) = minA + (maxA - minA) / ( 1 + exp(-x) )

UA = Chain(
    Dense(1, 1, sigmoid_A, initb = Flux.zeros)
)
```

I always obtain a gradient with just `NaN`s:

```
(ps, ∇_UA[ps]) = (Float32[0.77069396], [NaN])
size ps: (1, 1)
size ∇_UA[p]: (1, 1)
type ps: Matrix{Float32}
type ∇_UA[p]: Matrix{Float64}
(ps, ∇_UA[ps]) = (Float32[0.0], [NaN])
size ps: (1,)
size ∇_UA[p]: (1,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
```

If I add layers with more parameters, the gradients with respect to these are also `NaN`s. This behavior is not observed if I use `relu_A()`, which I find a little bit weird, since the derivative of the sigmoid is easy to compute and the rest of the chain rule should run without caring about the activation function I am using for the NN.
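
To narrow this down, one thing I could try is differentiating the tiny network on its own, bypassing the differential-equation solve, to check whether the `NaN` already shows up there (a sketch; `minA` and `maxA` are the same bounds used in `relu_A`):

```
using Flux, Zygote

minA, maxA = 1.58e-17, 1.58e-16   # bounds as in relu_A above
sigmoid_A(x) = minA + (maxA - minA) / (1 + exp(-x))

UA_test = Chain(Dense(1, 1, sigmoid_A))
θ_test = Flux.params(UA_test)

gs = Zygote.gradient(() -> sum(UA_test([0.5f0]')), θ_test)
for p in θ_test
    @show gs[p]
end
```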

That looks like a bug. @oxinabox, could it be the auto-forward stuff? Try changing the NaN-safe mode of ForwardDiff: Advanced Usage Guide · ForwardDiff
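
In recent ForwardDiff versions that mode can be switched on as a package preference, roughly like this (a sketch; older versions require editing the constant described in the linked guide):

```
using Preferences, ForwardDiff
set_preferences!(ForwardDiff, "nansafe_mode" => true)   # takes effect after restarting Julia
```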


We are yet to actually deploy any auto-forward stuff in ChainRules; we just made and tested that it was possible.

But concurrently with that work, @mcabbott did add some auto-forward stuff to Zygote itself, for broadcasting only, which is probably what is being hit in this case when applying the activation function. They might know.

Well, together with @facusapienza we’ve been working on a simpler MWE based on the heat equation to make this issue reproducible. We have managed to reproduce the size mismatch error for the Flux model parameters. As I said in my previous post, once I apply the patch by @mcabbott this issue goes away, but there are still remaining problems with the propagation of the parameters and the model outside the `Zygote.pullback()` function.

In this MWE, based on a 2D heat equation, the neural network works correctly inside `Zygote.pullback()`, but somehow, when the neural network (`UA`) is used outside the pullback, it always returns `NaN`s. I have added some logs in the MWE to illustrate this. To make this even weirder, the model parameters outside the pullback seem to be fine, and they can be correctly updated, but the Flux model is somehow broken. It looks like the model outside the pullback is no longer linked to the implicit parameters.

Here is the MWE:

```
using LinearAlgebra
using Statistics
using Zygote
using Flux
using Flux: @epochs
using Tullio

#### Parameters
nx, ny = 100, 100 # Size of the grid
Δx, Δy = 1, 1
Δt = 0.01
t₁ = 1

D₀ = 1
tolnl = 1e-4
itMax = 100
damp = 0.85
dτsc   = 1.0/3.0
ϵ     = 1e-4            # small number
cfl  = max(Δx^2,Δy^2)/4.1

A₀ = 1
ρ = 9
g = 9.81
n = 3
p = (Δx, Δy, Δt, t₁, ρ, g, n)  # we add extra parameters for the nonlinear diffusivity

### Reference dataset for the heat Equations
T₀ = [ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / 300 ) for i in 1:nx, j in 1:ny ];
T₁ = copy(T₀);

#######   FUNCTIONS   ############

# Utility functions
@views avg(A) = 0.25 * ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )

@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )

@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )

### Functions to generate reference dataset to train UDE

function Heat_nonlinear(T, A, p)

Δx, Δy, Δt, t₁, ρ, g, n = p

#### NEW CODE TO BREAK
dTdx = diff(T, dims=1) / Δx
dTdy = diff(T, dims=2) / Δy
∇T = sqrt.(avg_y(dTdx).^2 .+ avg_x(dTdy).^2)

D = A .* avg(T) .* ∇T

dTdx_edges = diff(T[:,2:end - 1], dims=1) / Δx
dTdy_edges = diff(T[2:end - 1,:], dims=2) / Δy

Fx = -avg_y(D) .* dTdx_edges
Fy = -avg_x(D) .* dTdy_edges

F = .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy)

dτ = dτsc * min.( 10.0 , 1.0./(1.0/Δt .+ 1.0./(cfl./(ϵ .+ avg(D)))))

return F, dτ

end

# Fake law to create reference dataset and to be learnt by the NN
fakeA(t) = A₀ * exp(2t)

### Heat equation based on a fake A parameter function to compute the diffusivity
function heatflow_nonlinear(T, fA, p, fake, tol=Inf)

Δx, Δy, Δt, t₁, ρ, g, n = p

total_iter = 0
t = 0

while t < t₁

iter = 1
err = 2 * tolnl
Hold = copy(T)
dTdt = zeros(nx, ny)
err = Inf

if fake
A = fA(t)  # compute the fake A value involved in the nonlinear diffusivity
else
# Compute A with the NN once per time step
A = fA([t]')[1]  # compute A parameter involved in the diffusivity
end

while iter < itMax+1 && tol <= err

Err = copy(T)

F, dτ = Heat_nonlinear(T, A, p)

dTdt_ = copy(dTdt)
@tullio dTdt[i,j] := dTdt_[i,j]*damp + ResT[i,j]

T_ = copy(T)

Zygote.ignore() do
Err .= Err .- T
err = maximum(Err)
end

iter += 1
total_iter += 1

end

t += Δt

end

if(!fake)
println("Values of UA in heatflow_nonlinear: ", fA([0., .5, 1.]')) # Simulations here are correct
end

return T

end

# Patch suggested by Michael Abbott needed in order to correctly retrieve gradients
Flux.Optimise.update!(opt, x::AbstractMatrix, Δ::AbstractVector) = Flux.Optimise.update!(opt, x, reshape(Δ, size(x)))

function train(loss, p)

leakyrelu(x, a=0.01) = max(a*x, x)
relu(x) = max(0, x)

UA = Chain(
Dense(1,10,initb = Flux.glorot_normal),
BatchNorm(10, leakyrelu),
Dense(10,5,initb = Flux.glorot_normal),
BatchNorm(5, leakyrelu),
Dense(5,1, relu, initb = Flux.glorot_normal)
)

opt = RMSProp()
losses = []
@epochs 10 hybrid_train_NN!(loss, UA, p, opt, losses)

println("Values of UA in train(): ", UA([0., .5, 1.]'))

return UA, losses

end

function hybrid_train_NN!(loss, UA, p, opt, losses)

T = T₀
θ = Flux.params(UA)
println("Values of UA in hybrid_train BEFORE: ", UA([0., .5, 1.]'))
loss_UA, back_UA = Zygote.pullback(() -> loss(T, UA, p), θ)
push!(losses, loss_UA)

∇_UA = back_UA(one(loss_UA))

for ps in θ
end

println("θ: ", θ) # parameters are NOT NaNs
println("Values of UA in hybrid_train AFTER: ", UA([0., .5, 1.]')) # Simulations here are all NaNs

Flux.Optimise.update!(opt, θ, ∇_UA)

end

function loss_NN(T, UA, p, λ=1)

T = heatflow_nonlinear(T, UA, p, false)
l_cost = sqrt(Flux.Losses.mse(T, T_ref; agg=mean))

return l_cost
end

#######################

########################################
#####  TRAIN 2D HEAT EQUATION PDE  #####
########################################

T₂ = copy(T₀)
# Reference temperature dataset
T_ref = heatflow_nonlinear(T₂, fakeA, p, true, 1e-1)

# Train heat equation UDE
UA_trained, losses = train(loss_NN, p)
```

Is this a bug? If so, I’ll also open an issue for this. The use of implicit parameters for this is pretty confusing.