Unrecognized gradient using Zygote for AD with Universal Differential Equations

Together with @facusapienza, we have been working on a simpler MWE based on the heat equation to make this issue easier to reproduce. We have managed to reproduce the size mismatch error for the Flux model parameters. As I said in my previous post, once I apply the patch by @mcabbott (included below, just before the train function), that error goes away, but there are still problems with how the parameters and the model propagate outside the Zygote.pullback() function.

In this MWE, based on a 2D heat equation, the neural network works correctly inside Zygote.pullback(), but when the same network (UA) is evaluated outside the pullback it always returns NaNs. I have added some logs in the MWE to illustrate this. To make this even weirder, the model parameters outside the pullback look fine and can be updated correctly, yet the Flux model itself is somehow broken. It looks as if the model outside the pullback is no longer linked to its implicit parameters.

Here is the MWE:

using LinearAlgebra
using Statistics
using Zygote
using PaddedViews
using Flux
using Flux: @epochs
using Tullio

#### Parameters
nx, ny = 100, 100 # Size of the grid
Δx, Δy = 1, 1
Δt = 0.01
t₁ = 1

D₀ = 1
tolnl = 1e-4
itMax = 100
damp = 0.85
dτsc = 1.0/3.0
ϵ    = 1e-4                 # small number
cfl  = max(Δx^2, Δy^2)/4.1

A₀ = 1
ρ = 9
g = 9.81
n = 3
p = (Δx, Δy, Δt, t₁, ρ, g, n)  # we add extra parameters for the nonlinear diffusivity

### Reference dataset for the heat Equations
T₀ = [ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / 300 ) for i in 1:nx, j in 1:ny ];
T₁ = copy(T₀);

#######   FUNCTIONS   ############

# Utility functions: averages onto staggered grids
@views avg(A) = 0.25 * ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )  # corner average, (m-1)×(n-1) for an m×n input

@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )  # average along x, (m-1)×n

@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )  # average along y, m×(n-1)

### Functions to generate reference dataset to train UDE

function Heat_nonlinear(T, A, p)
   
    Δx, Δy, Δt, t₁, ρ, g, n = p
    
    #### New code that triggers the issue (nonlinear diffusivity)
    dTdx = diff(T, dims=1) / Δx
    dTdy = diff(T, dims=2) / Δy
    ∇T = sqrt.(avg_y(dTdx).^2 .+ avg_x(dTdy).^2)  # norm of the temperature gradient on the staggered grid

    D = A .* avg(T) .* ∇T  # nonlinear diffusivity

    dTdx_edges = diff(T[:,2:end - 1], dims=1) / Δx
    dTdy_edges = diff(T[2:end - 1,:], dims=2) / Δy
   
    Fx = -avg_y(D) .* dTdx_edges  # diffusive flux in x
    Fy = -avg_x(D) .* dTdy_edges  # diffusive flux in y

    F = .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy)  # flux divergence

    dτ = dτsc * min.( 10.0 , 1.0./(1.0/Δt .+ 1.0./(cfl./(ϵ .+ avg(D)))))  # local pseudo-timestep for the iterative solver
    
    return F, dτ
 
end

# Fake law to create reference dataset and to be learnt by the NN
fakeA(t) = A₀ * exp(2t)

### Heat equation based on a fake A parameter function to compute the diffusivity
function heatflow_nonlinear(T, fA, p, fake, tol=Inf)
   
    Δx, Δy, Δt, t₁, ρ, g, n = p
    
    total_iter = 0
    t = 0
    
    while t < t₁
        
        iter = 1
        err = Inf
        Hold = copy(T)
        dTdt = zeros(nx, ny)

        if fake
            A = fA(t)  # compute the fake A value involved in the nonlinear diffusivity
        else
            # Compute A with the NN once per time step
            A = fA([t]')[1]  # compute A parameter involved in the diffusivity
        end

        
        while iter < itMax+1 && tol <= err
            
            Err = copy(T)
            
            F, dτ = Heat_nonlinear(T, A, p)

            @tullio ResT[i,j] := -(T[i,j] - Hold[i,j])/Δt + F[pad(i-1,1,1),pad(j-1,1,1)] 
            
            dTdt_ = copy(dTdt)
            @tullio dTdt[i,j] := dTdt_[i,j]*damp + ResT[i,j]

            T_ = copy(T)
            @tullio T[i,j] := max(0.0, T_[i,j] + dTdt[i,j]*dτ[pad(i-1,1,1),pad(j-1,1,1)])
            
            Zygote.ignore() do
                Err .= Err .- T
                err = maximum(Err)
            end 
            
            iter += 1
            total_iter += 1
            
        end
        
        t += Δt
        
    end

    if !fake
        println("Values of UA in heatflow_nonlinear: ", fA([0., .5, 1.]')) # Simulations here are correct
    end
    
    return T
    
end

# Patch suggested by @mcabbott, needed in order to correctly retrieve gradients:
# reshape a Vector gradient to the Matrix shape of its parameter before calling update!
Flux.Optimise.update!(opt, x::AbstractMatrix, Δ::AbstractVector) = Flux.Optimise.update!(opt, x, reshape(Δ, size(x)))

function train(loss, p)
    
    # Local activation functions (note: these shadow Flux's own leakyrelu/relu inside train)
    leakyrelu(x, a=0.01) = max(a*x, x)
    relu(x) = max(0, x)

    UA = Chain(
        Dense(1,10,initb = Flux.glorot_normal), 
        BatchNorm(10, leakyrelu),
        Dense(10,5,initb = Flux.glorot_normal), 
        BatchNorm(5, leakyrelu),
        Dense(5,1, relu, initb = Flux.glorot_normal) 
    )

    opt = RMSProp()
    losses = []
    @epochs 10 hybrid_train_NN!(loss, UA, p, opt, losses)
    
    println("Values of UA in train(): ", UA([0., .5, 1.]'))
    
    return UA, losses
    
end

function hybrid_train_NN!(loss, UA, p, opt, losses)
    
    T = T₀
    θ = Flux.params(UA)
    println("Values of UA in hybrid_train BEFORE: ", UA([0., .5, 1.]'))
    loss_UA, back_UA = Zygote.pullback(() -> loss(T, UA, p), θ)
    push!(losses, loss_UA)
   
    ∇_UA = back_UA(one(loss_UA))

    for ps in θ
       println("Gradients ∇_UA[ps]: ", ∇_UA[ps])
    end
    
    println("θ: ", θ) # parameters are NOT NaNs
    println("Values of UA in hybrid_train AFTER: ", UA([0., .5, 1.]')) # Simulations here are all NaNs
    
    Flux.Optimise.update!(opt, θ, ∇_UA)
    
end


function loss_NN(T, UA, p, λ=1)

    T = heatflow_nonlinear(T, UA, p, false)
    l_cost = sqrt(Flux.Losses.mse(T, T_ref; agg=mean))

    return l_cost 
end

#######################

########################################
#####  TRAIN 2D HEAT EQUATION PDE  #####
########################################

T₂ = copy(T₀)
# Reference temperature dataset
T_ref = heatflow_nonlinear(T₂, fakeA, p, true, 1e-1)

# Train heat equation UDE
UA_trained, losses = train(loss_NN, p)
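
For what it's worth, here is a quick extra check that could help narrow this down (just a sketch I am adding for illustration, not part of the logs above): it looks at whether the NaNs are already in the tracked parameters or only appear when the model is evaluated, and whether the BatchNorm mode plays a role.

# Extra diagnostic (sketch only, not part of the original logs): are the NaNs in the
# tracked parameters themselves, or do they only appear when evaluating the model?
θ_check = Flux.params(UA_trained)
println("Any NaN in parameters: ", any(p -> any(isnan, p), θ_check))
println("Output in train mode:  ", UA_trained([0., .5, 1.]'))
Flux.testmode!(UA_trained)   # BatchNorm layers now use their stored running statistics
println("Output in test mode:   ", UA_trained([0., .5, 1.]'))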

Is this a bug? If so, I'll also open an issue for it. The use of implicit parameters for this is pretty confusing.
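
For comparison, this is roughly how I would write the same training step with explicit parameters via Flux.destructure instead of implicit Flux.params. This is only a sketch under that assumption; I have not verified that it avoids the problem.

# Sketch of an explicit-parameter version using Flux.destructure (not verified against
# this issue): the model is rebuilt from a flat parameter vector inside the loss,
# so no implicit parameters are involved.
θ_flat, re = Flux.destructure(UA_trained)       # flat parameter vector + rebuilding function
loss_flat(θ) = loss_NN(copy(T₀), re(θ), p)      # rebuild the model from θ inside the loss
l, back_flat = Zygote.pullback(loss_flat, θ_flat)
∇θ = back_flat(one(l))[1]                       # gradient w.r.t. the flat parameter vector
Flux.Optimise.update!(RMSProp(), θ_flat, ∇θ)    # update the explicit parameters in place
UA_explicit = re(θ_flat)                        # model rebuilt from the updated parameters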

Thanks again in advance!