Unrecognized gradient using Zygote for AD with Universal Differential Equations

Hi! I am trying to train a Universal Differential Equation (UDE) using automatic differentiation with Zygote and implicit parameters. I am doing this manually, without using DifferentialEquations. A large part of my code (not included here) numerically solves a differential equation. Some parameters in the differential equation are neural networks, and I want to take gradients with respect to the parameters of those networks. The pullback of the loss function is calculated as

loss_UA, back_UA = Zygote.pullback(() -> loss(H, UA, p, t, t₁), ps_UA)

where loss computes the mean square error between the real and simulated solutions of the differential equation, and the differentiation is with respect to the parameters of the neural network, ps_UA. Then, I compute the gradient simply as

∇_UA = back_UA(1)

The problem I am having is that this object seems to be empty or it is some kind of data structure I don’t understand. When I access this object, I get this:

>  ∇_UA
Grads(...)

I have no way of analyzing what Zygote is returning from the gradient. Furthermore, I can run

opt = ADAM(0.1)
Flux.update!(opt, ps_UA, ∇_UA)

without any error but this does not change the value of ps_UA.

Any idea of what could be the problem here? What kind of data structure are the gradients?

Thank you in advance for any help!

A little extra context on this issue (we are working on it together). Applying the obtained gradients using Flux.update!(opt, ps_UA, ∇_UA) does in fact return an error, but a different one depending on whether the parameters are implicit or explicit:

With implicit parameters, using loss_UA, back_UA = Zygote.pullback(() -> loss(H, UA, p, t, t₁), ps_UA) as stated by @facusapienza above, we get the following error:

Flux.update!(opt, ps_UA, ∇_UA)

ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
  [1] check_broadcast_shape(#unused#::Tuple{}, Ashp::Tuple{Base.OneTo{Int64}})
    @ Base.Broadcast ./broadcast.jl:518
  [2] check_broadcast_shape(shp::Tuple{Base.OneTo{Int64}}, Ashp::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}})
    @ Base.Broadcast ./broadcast.jl:521
  [3] check_broadcast_axes
    @ ./broadcast.jl:523 [inlined]
  [4] check_broadcast_axes
    @ ./broadcast.jl:526 [inlined]
  [5] instantiate
    @ ./broadcast.jl:269 [inlined]
  [6] materialize!
    @ ./broadcast.jl:894 [inlined]
  [7] materialize!
    @ ./broadcast.jl:891 [inlined]
  [8] apply!(o::ADAM, x::Matrix{Float32}, Δ::Vector{Float64})
    @ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/optimisers.jl:181
  [9] update!(opt::ADAM, x::Matrix{Float32}, x̄::Vector{Float64})
    @ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:23
 [10] update!(opt::ADAM, xs::Params, gs::Zygote.Grads)
    @ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:29
 [11] top-level scope
    @ none:1
 [12] eval
    @ ./boot.jl:360 [inlined]
 [13] interpret(command::String, mod::Module, locals::Dict{Symbol, Any})

I’m not sure I understand the error, but it looks like the gradient is somehow empty?

On the other hand, if we compute the pullback with explicit parameters, passing a full Flux Chain (UA) as:

    # Leaky ReLu as activation function
    leakyrelu(x, a=0.01) = max(a*x, x)

    # Constrains A within physically plausible values
    relu_A(x) = min(max(1.58e-17, x), 1.58e-16)

    # Define the networks 1->10->5->1
    UA = Chain(
        Dense(1,10,initb = Flux.zeros), 
        BatchNorm(10, leakyrelu),
        Dense(10,5,initb = Flux.zeros), 
        BatchNorm(5, leakyrelu),
        Dense(5,1, relu_A, initb = Flux.zeros) 
    )

loss_UA, back_UA = Zygote.pullback(UA -> loss(H, UA, p, t, t₁), UA)

When applying this gradient to the NN parameters we get a different error:

Flux.update!(opt, ps_UA, ∇_UA)

ERROR: MethodError: no method matching getindex(::Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Vector{Float64}, Vector{Float64}, Nothing}}, Base.RefValue{Any}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, Nothing}}, Base.RefValue{Any}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, Nothing}}}}}}, ::Matrix{Float32})
Closest candidates are:
  getindex(::Tuple, ::Int64) at tuple.jl:29
  getindex(::Tuple, ::Real) at tuple.jl:30
  getindex(::Tuple, ::Colon) at tuple.jl:33
  ...
Stacktrace:
  [1] update!(opt::ADAM, xs::Params, gs::Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Vector{Float64}, Vector{Float64}, Nothing}}, Base.RefValue{Any}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, Nothing}}, Base.RefValue{Any}, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, Nothing}}}}}})
    @ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:28
  [2] top-level scope
    @ none:1
  [3] eval
    @ ./boot.jl:360 [inlined]
  [4] interpret(command::String, mod::Module, locals::Dict{Symbol, Any})

Explicitly passing the Flux Chain to the pullback returns a different gradient structure, which does look empty:

@show ∇_UA

((layers = ((weight = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], bias = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], σ = nothing), Base.RefValue{Any}((λ = nothing, β = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], γ = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = 0.0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing)), (weight = [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], bias = [0.0, 0.0, 0.0, 0.0, 0.0], σ = nothing), Base.RefValue{Any}((λ = nothing, β = [0.0, 0.0, 0.0, 0.0, 0.0], γ = [0.0, 0.0, 0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = 0.0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing)), (weight = [0.0 0.0 … 0.0 0.0], bias = [0.0], σ = nothing)),),)

There’s hardly any information in the documentation about the structure of gradients and how to debug these errors. The documentation recommends debugging with _pullback() and applying back() to trace errors, but this doesn’t raise any error in our case. We obtain gradients, but they seem to be empty, and the code crashes once we try to update the NN parameters. Any ideas on where the issue might come from?

Thanks a lot in advance!

The implicit Grads is not necessarily empty. It just has an awful show method. You can see the actual content with ∇_UA.grads. I would run the following for the implicit case:

for p in ps_UA
    @show p, ∇_UA[p]
end

Make sure that at each print the shape of the parameter and its gradient are aligned.
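
As a rough sketch of that check, here is a Zygote-only version with two hand-made arrays standing in for the network parameters. The names W, b, x and the toy loss are assumptions made up for illustration; they are not the model from this thread.

```julia
using Zygote

# Toy stand-ins for the network parameters (hypothetical, for illustration only)
W = rand(Float32, 3, 2)
b = zeros(Float32, 3)
x = rand(Float32, 2)

ps = Zygote.Params([W, b])

# Implicit-parameter pullback, same calling pattern as in the original post
loss_val, back = Zygote.pullback(() -> sum(abs2, W * x .+ b), ps)
gs = back(one(loss_val))   # a Zygote.Grads object

# Grads is indexed by the parameter arrays themselves
for p in ps
    g = gs[p]
    println("param ", size(p), " ", eltype(p), " | grad ", size(g), " ", eltype(g))
end
```

If a parameter and its gradient print different sizes, that pair is the one to look at first.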

Just based on Line 8 of the implicit stack trace, it looks like you probably have a row/column vector misalignment. We just need to do a little digging to figure out from where.

The explicit form will not work with Flux’s built-in optimizers. But you can try Optimisers.jl (under development right now) which supports them.

Thanks a lot for the reply. This is interesting, showing the parameters and their gradients only seems to work with implicit parameters, as you said. It does fail when passing explicit parameters. The documentation recommends not using implicit parameters unless it is needed, and it seems to offer two alternatives passing explicit parameters. How come it doesn’t work for our case?

Concerning the sizes of the parameters and the gradients, funnily enough, when I display them like this:

    for ps in ps_UA
        @show ps, ∇_UA[ps]
        println("size ps: ", size(ps))
        println("size ∇_UA[p]: ", size(∇_UA[ps]))

        println("type ps: ", typeof(ps))
        println("type ∇_UA[p]: ", typeof(∇_UA[ps]))
    end

All the sizes match except for the first parameter, which has shape (10,1) while its gradient has shape (10,). Nonetheless, indexing the gradient with ∇_UA[ps] works for every parameter. There is also a type mismatch, though: the parameters are Float32 and the gradients are Float64:

(ps, ∇_UA[ps]) = (Float32[0.30798444; 0.52923024; -0.38974404; -0.44488695; -0.007705779; -0.46781763; 0.30756167; -0.73545146; -0.60995233; -0.12095741], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (10, 1)
size ∇_UA[p]: (10,)
type ps: Matrix{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (10,)
size ∇_UA[p]: (10,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (10,)
size ∇_UA[p]: (10,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (10,)
size ∇_UA[p]: (10,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[-0.49900624 0.1739867 -0.50695115 -0.012244531 0.47494888 0.3235831 0.021464685 -0.48237133 -0.586465 -0.38168398; -0.40725645 -0.26300266 -0.14521688 0.020233944 -0.07136398 -0.56981426 -0.05533645 0.16115816 -0.4485389 -0.56794554; -0.37900218 -0.08815088 0.10154217 0.558363 -0.22744176 0.12258495 0.18857977 -0.16126387 -0.45260283 -0.54091734; -0.47956002 -0.27310026 -0.43743765 0.032916818 0.095131814 -0.6059501 -0.40490097 0.43668085 -0.31058735 -0.21437271; -0.031416014 0.21674222 0.485597 -0.3657828 -0.24838457 0.52909964 0.44705272 0.16652822 0.5047817 -0.5061942], [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0])
size ps: (5, 10)
size ∇_UA[p]: (5, 10)
type ps: Matrix{Float32}
type ∇_UA[p]: Matrix{Float64}
(ps, ∇_UA[ps]) = (Float32[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (5,)
size ∇_UA[p]: (5,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (5,)
size ∇_UA[p]: (5,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (5,)
size ∇_UA[p]: (5,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
(ps, ∇_UA[ps]) = (Float32[-0.65684795 0.9604726 -0.8283665 0.7575054 0.6725607], [0.0 0.0 0.0 0.0 0.0])
size ps: (1, 5)
size ∇_UA[p]: (1, 5)
type ps: Matrix{Float32}
type ∇_UA[p]: Matrix{Float64}
(ps, ∇_UA[ps]) = (Float32[0.0], [0.0])
size ps: (1,)
size ∇_UA[p]: (1,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}
Hit `@infiltrate` in hybrid_train!(loss::typeof(loss), UA::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, BatchNorm{var"#leakyrelu#180", Vector{Float32}, Float32, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, BatchNorm{var"#leakyrelu#180", Vector{Float32}, Float32, Vector{Float32}}, Dense{var"#relu_A#181", Matrix{Float32}, Vector{Float32}}}}, opt::ADAM, H::Matrix{Float32}, p::Tuple{Int64, Int64, Float64, Float64, Matrix{Float32}, Array{Float64, 3}, Array{Float32, 3}, Vector{Any}, Float64, Int64}, t::Int64, t₁::Float64) at iceflow.jl:79:

Still, when I do Flux.update!(opt, ps_UA, ∇_UA) I get the same error as mentioned previously:

ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
  [1] check_broadcast_shape(#unused#::Tuple{}, Ashp::Tuple{Base.OneTo{Int64}})
    @ Base.Broadcast ./broadcast.jl:518
  [2] check_broadcast_shape(shp::Tuple{Base.OneTo{Int64}}, Ashp::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}})
    @ Base.Broadcast ./broadcast.jl:521
  [3] check_broadcast_axes
    @ ./broadcast.jl:523 [inlined]
  [4] check_broadcast_axes
    @ ./broadcast.jl:526 [inlined]
  [5] instantiate
    @ ./broadcast.jl:269 [inlined]
  [6] materialize!
    @ ./broadcast.jl:894 [inlined]
  [7] materialize!
    @ ./broadcast.jl:891 [inlined]
  [8] apply!(o::ADAM, x::Matrix{Float32}, Δ::Vector{Float64})
    @ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/optimisers.jl:181
  [9] update!(opt::ADAM, x::Matrix{Float32}, x̄::Vector{Float64})
    @ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:23
 [10] update!(opt::ADAM, xs::Params, gs::Zygote.Grads)
    @ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:29

How come I seem to be able to apply the gradients to the parameters in the loop but they don’t work with Flux.update!()?

I have also tried applying the gradients inside the for loop doing Flux.update!(opt, ps, ∇_UA), but I get another error:

ERROR: MethodError: no method matching +(::Float64, ::Vector{Float64})
For element-wise addition, use broadcasting with dot syntax: scalar .+ array
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
  +(::Union{Float16, Float32, Float64}, ::BigFloat) at mpfr.jl:392
  +(::FillArrays.Zeros{T, N, Axes} where Axes, ::AbstractArray{V, N}) where {T, V, N} at /Users/Bolib001/.julia/packages/FillArrays/cVkp8/src/fillalgebra.jl:180
  ...
Stacktrace:
  [1] _broadcast_getindex_evalf
    @ ./broadcast.jl:648 [inlined]
  [2] _broadcast_getindex
    @ ./broadcast.jl:621 [inlined]
  [3] getindex
    @ ./broadcast.jl:575 [inlined]
  [4] macro expansion
    @ ./broadcast.jl:984 [inlined]
  [5] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [6] copyto!
    @ ./broadcast.jl:983 [inlined]
  [7] copyto!
    @ ./broadcast.jl:936 [inlined]
  [8] materialize!
    @ ./broadcast.jl:894 [inlined]
  [9] materialize!
    @ ./broadcast.jl:891 [inlined]
 [10] apply!(o::ADAM, x::Matrix{Float32}, Δ::Zygote.Grads)
    @ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/optimisers.jl:179
 [11] update!(opt::ADAM, x::Matrix{Float32}, x̄::Zygote.Grads)
    @ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:23

Thanks again for your help!

Don’t upconvert to Float64 here. min(max(1.58f-17, x), 1.58f-16) if you want Float32, etc.

MWE?

As Chris already mentioned, your code implicitly promotes the output of leakyrelu etc. to Float64 because you use double precision constants. Do something like 1f-2 etc.
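
A small sketch of the promotion in plain Julia. The function names leakyrelu64, leakyrelu32 and relu_A32 are made up here to keep the variants apart; only the literals differ from the definitions earlier in the thread.

```julia
# Float64 literal: the result is promoted to Float64 even for Float32 input
leakyrelu64(x, a=0.01) = max(a * x, x)

# Float32 literal: Float32 inputs stay Float32
leakyrelu32(x, a=0.01f0) = max(a * x, x)

# Same idea for the clamping activation
relu_A32(x) = min(max(1.58f-17, x), 1.58f-16)

println(typeof(leakyrelu64(1f0)))  # Float64
println(typeof(leakyrelu32(1f0)))  # Float32
println(typeof(relu_A32(1f0)))     # Float32
```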

This shape mismatch is the source of your error. If you look at the error message, it says that the dimensions mismatch: (10,1) has 2 dimensions but (10,) only has one. This is happening because your first layer's weight is effectively just a vector, but Flux creates a matrix. I would characterize this as a Zygote bug. Can you file an issue on Zygote.jl and link this discourse?

That documentation snippet should be updated. Flux’s optimizers don’t understand structured explicit gradients like the kind you get when you differentiate the model explicitly. For now, if you want to use explicit gradients, then you need to use the optimizers in Optimisers.jl.

Explicit and implicit parameters have totally different structures, so the for loop that I wrote will not work to inspect the explicit structure. The explicit structure will be nested tuples, but Grads is a single dictionary like object.
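
To make the difference concrete, here is a Zygote-only sketch that differentiates the same toy function both ways. The NamedTuple model and the arrays W, b, x are hypothetical stand-ins, not a Flux Chain, so the nesting is shallower than what you see in your output.

```julia
using Zygote

W = rand(Float32, 3, 2)
b = zeros(Float32, 3)
x = rand(Float32, 2)

# Explicit: differentiate with respect to the model object itself
model = (weight = W, bias = b)
g_explicit = Zygote.gradient(m -> sum(m.weight * x .+ m.bias), model)[1]
println(typeof(g_explicit))      # a NamedTuple mirroring the fields of `model`
println(size(g_explicit.weight))

# Implicit: differentiate with respect to a Params collection
ps = Zygote.Params([W, b])
g_implicit = Zygote.gradient(() -> sum(W * x .+ b), ps)
println(typeof(g_implicit))      # Zygote.Grads
println(size(g_implicit[W]))     # looked up by the array, not by field name
```

The numbers are the same in both cases; only the container differs, which is why code written for one form does not work on the other.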

I have opened an issue for this problem. For now I cannot provide a MWE, as I’m working with a rather big model and it would require a lot of work to get down to a simple standalone example. I hope this will still be enough.

I have tried flattening the parameters in order to get the same size as for the gradients, but then I cannot apply the gradients, as I’m getting an error. @darsnack Is there any way to manually patch the obtained gradients in order to get a workaround for this issue? I’d like to try to get this working without having to wait for a Zygote fix.

I have made a PR with a small update on the documentation with these details. I hope this can help new users to avoid wasting time as I did.

Thanks a lot again for your help!

I haven’t tried this, but you might be able to patch one of the functions in the above stacktrace to reshape, such as (update! has fewer methods than apply!, so less chance of ambiguities):

Flux.Optimise.update!(opt, x::AbstractMatrix, Δ::AbstractVector) = update!(opt, x, reshape(Δ, size(x)))

The problem might also go away in Julia 1.7, since Base was fixed to allow e.g. ones(3) .= rand(3,1) without error. Although it’s a little confusing that this is apply!(o, x::Matrix, Δ::Vector) and not x::Vector, Δ::Matrix.
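
For reference, a plain-Julia sketch of what the reshape in that patch does to the arrays. The sizes mirror the (10,1) parameter and (10,) gradient reported above, and the values are made up.

```julia
x = zeros(Float32, 10, 1)    # parameter shaped like the first Dense weight
Δ = collect(1.0:10.0)        # gradient that came back as a plain vector

Δr = reshape(Δ, size(x))     # same data, now with matching dimensions
println(size(Δr))            # (10, 1)

# The reshaped array shares memory with the original vector
Δr[1, 1] = -1.0
println(Δ[1])                # -1.0

# With matching shapes, an in-place update broadcasts without complaint
x .-= 0.1f0 .* Δr
```

Since reshape does not copy, in-place writes to the reshaped gradient are also visible through the original vector.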

Great, as you suggested, your patch did the trick! I just had to specify the full path for both calls to update!:

Flux.Optimise.update!(opt, x::AbstractMatrix, Δ::AbstractVector) = Flux.Optimise.update!(opt, x, reshape(Δ, size(x)))

Thanks again for your help Michael! I will post this as well in the PR to confirm the source of the issue.

The gradients are still all 0s though, but that’s another issue (and probably my fault).

Make sure you’re actually using the parameter vectors and not some copy of them.

Thanks @ChrisRackauckas for your suggestion. I checked this and it doesn’t seem to be the problem.

Something I observed is that if I use the relu_A(x) activation function,

relu_A(x) = min(max(minA, x), maxA),

where minA and maxA are the lower and upper bounds of the activation function, the gradients are zero. This makes sense in the regions where the output of the neural network takes the values minA or maxA (i.e., the flat regions of the activation function). There, the output of the NN, and hence the loss function, is locally constant. However, if I use any other activation function with non-vanishing gradients, such as

sigmoid_A(x) = minA + (maxA - minA) / ( 1 + exp(-x) )

then the gradient is filled with NaNs.
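
To illustrate the flat-region argument for relu_A, here is a plain-Julia sketch using the bounds from the network definition earlier in the thread. The sample inputs are an assumption: they are simply values of order 1, which is what a freshly initialised layer would typically produce.

```julia
minA, maxA = 1.58e-17, 1.58e-16   # bounds used for relu_A earlier in the thread
relu_A(x) = min(max(minA, x), maxA)

xs = [-1.0, -0.1, 0.0, 0.1, 1.0]  # assumed pre-activation values of order 1
for x in xs
    y = relu_A(x)
    region = y == minA ? "flat (lower bound)" : y == maxA ? "flat (upper bound)" : "linear"
    println("x = ", x, " -> ", y, "  ", region)
end
```

The linear part of relu_A is only about 1.4e-16 wide, so inputs of this magnitude all land on one of the two flat parts, where the derivative is zero.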

At which value is the derivative a NaN and why?

When I compute the gradient of the loss function with respect to the parameters of the NN using loss_UA, back_UA = Zygote.pullback(() -> loss(H, UA, p, t, t₁), ps_UA) with a simple NN with the sigmoid_A() activation function,

    sigmoid_A(x) = minA + (maxA - minA) / ( 1 + exp(-x) )

    UA = Chain(
        Dense(1,1, sigmoid_A, initb = Flux.zeros) 
    )

I always obtain a gradient with just NaNs:

(ps, ∇_UA[ps]) = (Float32[0.77069396], [NaN])
size ps: (1, 1)
size ∇_UA[p]: (1, 1)
type ps: Matrix{Float32}
type ∇_UA[p]: Matrix{Float64}
(ps, ∇_UA[ps]) = (Float32[0.0], [NaN])
size ps: (1,)
size ∇_UA[p]: (1,)
type ps: Vector{Float32}
type ∇_UA[p]: Vector{Float64}

If I add layers with more parameters, the gradient with respect to these is also NaNs. This behavior is not observed if I use relu_A(), which I find a little bit weird since the derivative of the sigmoid is easy to compute and the rest of the chain rule should run without caring about the activation function I am using for the NN.
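
As a sanity check on the activation itself, here is a plain-Julia sketch comparing the hand-derived derivative of sigmoid_A with a central finite difference. The helper names dsigmoid_A and fd, the step size, and the sample points are all choices made for this illustration. It only checks the activation in isolation over moderate inputs and says nothing about where the NaNs enter the full pipeline.

```julia
minA, maxA = 1.58e-17, 1.58e-16   # bounds quoted earlier in the thread

sigmoid_A(x) = minA + (maxA - minA) / (1 + exp(-x))

# Hand-derived derivative: (maxA - minA) * exp(-x) / (1 + exp(-x))^2
dsigmoid_A(x) = (maxA - minA) * exp(-x) / (1 + exp(-x))^2

# Central finite difference as an independent check
fd(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

for x in -5.0:2.5:5.0
    println("x = ", x, "  analytic = ", dsigmoid_A(x), "  finite diff = ", fd(sigmoid_A, x))
end
```

Over this range the derivative is tiny (of order 1e-17 or smaller) but finite and positive, so the activation alone would give very small gradients rather than NaNs.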

That looks like a bug. @oxinabox could it be the auto forward stuff? Try changing the NaN-safe mode of ForwardDiff, described in the Advanced Usage Guide of the ForwardDiff documentation.

We have yet to actually deploy any auto-forward stuff in ChainRules; we just made it and tested that it was possible.

But concurrently with that work, @mcabbott did add some auto-forward stuff to Zygote itself, for broadcast only, which is probably being hit in this case when applying the activation function. They might know.

Well, together with @facusapienza we’ve been working on creating a simpler MWE based on the heat equation to allow the reproduction of this issue. We have managed to reproduce the size mismatch error for the Flux model parameters. As I said in my previous post, once I apply the patch by @mcabbott, this issue goes away, but there are still remaining problems with the propagation of the parameters and the model outside the Zygote.pullback() function.

In this MWE based on a 2D heat equation, the neural network is correctly working inside Zygote.pullback(), but somehow when the neural network (UA) is used outside the pullback it always returns NaNs. I have added some logs in the MWE to illustrate this. To make this even weirder, the model parameters outside the pullback seem to be fine, and they can be correctly updated, but the Flux model is somehow broken. It looks like the model outside the pullback is no longer linked to the implicit parameters.

Here is the MWE:

using LinearAlgebra
using Statistics
using Zygote
using PaddedViews
using Flux
using Flux: @epochs
using Tullio

#### Parameters
nx, ny = 100, 100 # Size of the grid
Δx, Δy = 1, 1
Δt = 0.01
t₁ = 1

D₀ = 1
tolnl = 1e-4
itMax = 100
damp = 0.85
dτsc   = 1.0/3.0
ϵ     = 1e-4            # small number
cfl  = max(Δx^2,Δy^2)/4.1

A₀ = 1
ρ = 9
g = 9.81
n = 3
p = (Δx, Δy, Δt, t₁, ρ, g, n)  # we add extra parameters for the nonlinear diffusivity

### Reference dataset for the heat Equations
T₀ = [ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / 300 ) for i in 1:nx, j in 1:ny ];
T₁ = copy(T₀);

#######   FUNCTIONS   ############

# Utility functions
@views avg(A) = 0.25 * ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )

@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )

@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )

### Functions to generate reference dataset to train UDE

function Heat_nonlinear(T, A, p)
   
    Δx, Δy, Δt, t₁, ρ, g, n = p
    
    #### NEW CODE TO BREAK
    dTdx = diff(T, dims=1) / Δx
    dTdy = diff(T, dims=2) / Δy
    ∇T = sqrt.(avg_y(dTdx).^2 .+ avg_x(dTdy).^2)

    D = A .* avg(T) .* ∇T

    dTdx_edges = diff(T[:,2:end - 1], dims=1) / Δx
    dTdy_edges = diff(T[2:end - 1,:], dims=2) / Δy
   
    Fx = -avg_y(D) .* dTdx_edges
    Fy = -avg_x(D) .* dTdy_edges   
    
    F = .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) 

    dτ = dτsc * min.( 10.0 , 1.0./(1.0/Δt .+ 1.0./(cfl./(ϵ .+ avg(D)))))
    
    return F, dτ
 
end

# Fake law to create reference dataset and to be learnt by the NN
fakeA(t) = A₀ * exp(2t)

### Heat equation based on a fake A parameter function to compute the diffusivity
function heatflow_nonlinear(T, fA, p, fake, tol=Inf)
   
    Δx, Δy, Δt, t₁, ρ, g, n = p
    
    total_iter = 0
    t = 0
    
    while t < t₁
        
        iter = 1
        err = 2 * tolnl
        Hold = copy(T)
        dTdt = zeros(nx, ny)
        err = Inf 

        if fake
            A = fA(t)  # compute the fake A value involved in the nonlinear diffusivity
        else
            # Compute A with the NN once per time step
            A = fA([t]')[1]  # compute A parameter involved in the diffusivity
        end

        
        while iter < itMax+1 && tol <= err
            
            Err = copy(T)
            
            F, dτ = Heat_nonlinear(T, A, p)

            @tullio ResT[i,j] := -(T[i,j] - Hold[i,j])/Δt + F[pad(i-1,1,1),pad(j-1,1,1)] 
            
            dTdt_ = copy(dTdt)
            @tullio dTdt[i,j] := dTdt_[i,j]*damp + ResT[i,j]

            T_ = copy(T)
            #@tullio T[i,j] := max(0.0, T_[i,j] + dTdt[i,j]*dτ[pad(i-1,1,1),pad(j-1,1,1)]) 
            @tullio T[i,j] := max(0.0, T_[i,j] + dTdt[i,j]*dτ[pad(i-1,1,1),pad(j-1,1,1)])
            
            Zygote.ignore() do
                Err .= Err .- T
                err = maximum(Err)
            end 
            
            iter += 1
            total_iter += 1
            
        end
        
        t += Δt
        
    end

    if(!fake)
        println("Values of UA in heatflow_nonlinear: ", fA([0., .5, 1.]')) # Simulations here are correct
    end
    
    return T
    
end

# Patch suggested by Michael Abbott needed in order to correctly retrieve gradients
Flux.Optimise.update!(opt, x::AbstractMatrix, Δ::AbstractVector) = Flux.Optimise.update!(opt, x, reshape(Δ, size(x)))

function train(loss, p)
    
    leakyrelu(x, a=0.01) = max(a*x, x)
    relu(x) = max(0, x)

    UA = Chain(
        Dense(1,10,initb = Flux.glorot_normal), 
        BatchNorm(10, leakyrelu),
        Dense(10,5,initb = Flux.glorot_normal), 
        BatchNorm(5, leakyrelu),
        Dense(5,1, relu, initb = Flux.glorot_normal) 
    )

    opt = RMSProp()
    losses = []
    @epochs 10 hybrid_train_NN!(loss, UA, p, opt, losses)
    
    println("Values of UA in train(): ", UA([0., .5, 1.]'))
    
    return UA, losses
    
end

function hybrid_train_NN!(loss, UA, p, opt, losses)
    
    T = T₀
    θ = Flux.params(UA)
    println("Values of UA in hybrid_train BEFORE: ", UA([0., .5, 1.]'))
    loss_UA, back_UA = Zygote.pullback(() -> loss(T, UA, p), θ)
    push!(losses, loss_UA)
   
    ∇_UA = back_UA(one(loss_UA))

    for ps in θ
       println("Gradients ∇_UA[ps]: ", ∇_UA[ps])
    end
    
    println("θ: ", θ) # parameters are NOT NaNs
    println("Values of UA in hybrid_train AFTER: ", UA([0., .5, 1.]')) # Simulations here are all NaNs
    
    Flux.Optimise.update!(opt, θ, ∇_UA)
    
end


function loss_NN(T, UA, p, λ=1)

    T = heatflow_nonlinear(T, UA, p, false)
    l_cost = sqrt(Flux.Losses.mse(T, T_ref; agg=mean))

    return l_cost 
end

#######################

########################################
#####  TRAIN 2D HEAT EQUATION PDE  #####
########################################

T₂ = copy(T₀)
# Reference temperature dataset
T_ref = heatflow_nonlinear(T₂, fakeA, p, true, 1e-1)

# Train heat equation UDE
UA_trained, losses = train(loss_NN, p)

Is this a bug? If so, I’ll also open an issue for this. The use of implicit parameters for this is pretty confusing.

Thanks again in advance!

Wrong day to ask: Zygote had an update that broke downstream packages. DiffEqFlux is… in flux. https://github.com/SciML/DiffEqFlux.jl/pull/591 . We’re working with their devs to get that fixed hopefully in a day. But I wouldn’t assume you’ve done anything wrong until this gets cleared up.

Hmm, but I’m not using DiffEqFlux here. I’m just doing it manually using Zygote and differentiating with respect to the parameters of a Flux Chain. Is it still the case?

Yes