Flux/DiffEqFlux: error when a loss function does not use all the weights of the NN

Hi everyone, I have a question about training a Neural Network with the sciml_train function of DiffEqFlux.jl.

I want to train a Neural Stochastic Differential Equation where the loss explicitly depends on the Neural Network inside the SDE. In one part of the loss, a particular weight of the Neural Network is not used at all, and I believe this causes an error when I run my code. Below is an MWE that gives a similar error.


# Importing the used packages
using Flux, DiffEqFlux, LinearAlgebra

# Defining some constants for creating the Neural Network
input_size = 2
output_size = 2

# Creating a simple neural network consisting of a single linear (Dense) layer
nn_initial = Chain(Dense(input_size, output_size))

# Using destructure
Θ, nn = Flux.destructure(nn_initial)

# Defining the loss function
function loss_intermediate(NN, data_x, data_y)
    #diff = (NN(data_x) - data_y)
    loss_1 = 0 #0.5*dot(diff, diff)
    bias = NN.layers[1].bias
    loss_2 = 0.5*dot(bias, bias)
    return loss_1 + 0.1*loss_2 #+ 0.5*dot(NN.layers[1].W, NN.layers[1].W)
end

function loss(nn, Θ, data_x, data_y)
    NN = nn(Θ)  # rebuild the network from the flat parameter vector
    return loss_intermediate(NN, data_x, data_y)
end

# Create a function that generates data from a linear function and adds noise
function data_generator(n)
    data_x = randn(input_size,n)
    data_y = data_x + 0.01*randn(input_size,n)
    return data_x, data_y
end

# Getting the data
data_x, data_y = data_generator(50)

# Testing the loss function
testLossVal = loss(nn, Θ, data_x, data_y)
println("The value of the loss function at the start of training is: $testLossVal \n")

# Training everything using sciml_train
result = DiffEqFlux.sciml_train((p) -> loss(nn, p, data_x, data_y), Θ,
              ADAM(0.01), maxiters = 10)

println("The final weights are: \n")


LoadError: DimensionMismatch("array could not be broadcast to match destination")
 [1] check_broadcast_shape at .\broadcast.jl:520 [inlined]
 [2] check_broadcast_axes at .\broadcast.jl:523 [inlined]
 [3] check_broadcast_axes at .\broadcast.jl:527 [inlined]
 [4] instantiate at .\broadcast.jl:269 [inlined]
 [5] materialize! at .\broadcast.jl:848 [inlined]
 [6] materialize! at .\broadcast.jl:845 [inlined]
 [7] apply!(::ADAM, ::Array{Float32,1}, ::Array{Float64,1}) at C:\Users\Sven\.julia\packages\Flux\qp1gc\src\optimise\optimisers.jl:179
 [8] update!(::ADAM, ::Array{Float32,1}, ::Array{Float64,1}) at C:\Users\Sven\.julia\packages\Flux\qp1gc\src\optimise\train.jl:23
 [9] update!(::ADAM, ::Zygote.Params, ::Zygote.Grads) at C:\Users\Sven\.julia\packages\Flux\qp1gc\src\optimise\train.jl:29
 [10] __solve(::SciMLBase.OptimizationProblem{false,SciMLBase.OptimizationFunction{false,GalacticOptim.AutoZygote,SciMLBase.OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{var"#29#30"},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#148#158"{GalacticOptim.var"#147#157"{SciMLBase.OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{var"#29#30"},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#151#161"{GalacticOptim.var"#147#157"{SciMLBase.OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{var"#29#30"},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#156#166",Nothing,Nothing,Nothing},Array{Float32,1},SciMLBase.NullParameters,Nothing,Nothing,Nothing,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}}, ::ADAM, ::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\Sven\.julia\packages\GalacticOptim\Ts0Bu\src\solve.jl:103
 [11] #solve#468 at C:\Users\Sven\.julia\packages\SciMLBase\Z1NtH\src\solve.jl:3 [inlined]
 [12] sciml_train(::var"#29#30", ::Array{Float32,1}, ::ADAM, ::GalacticOptim.AutoZygote; lower_bounds::Nothing, upper_bounds::Nothing, kwargs::Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}) at C:\Users\Sven\.julia\packages\DiffEqFlux\GQl0U\src\train.jl:6
 [13] top-level scope at C:\Users\Sven\Documents\Master\Master_thesis\Julia\Code\TestRepository\TestFileDiscourse_question3.jl:43
 [14] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1088
in expression starting at C:\Users\Sven\Documents\Master\Master_thesis\Julia\Code\TestRepository\TestFileDiscourse_question3.jl:43

Possible cause of the error:
The code presented above does not work. However, with several small adjustments it does. For instance, if I add + 0.5*dot(NN.layers[1].W, NN.layers[1].W) to the loss function (changing nothing else), the code runs. Likewise, when I change loss_1 from 0 to 0.5*dot(diff, diff), where diff = (NN(data_x) - data_y), the code works as well. This makes me believe that the error is related to the fact that not ALL the parameters of the NN appear in the loss function.

How to resolve the above presented issue?


Zygote uses reverse-mode AD to take gradients, meaning it essentially starts from the return statement in loss and works its way backwards through the expression tree while calculating the appropriate pullbacks (see the Zygote documentation for details). At least that is the understanding I have settled on :slight_smile:

Bottom line: it only ever encounters the last two entries of Θ, which your function accesses by referencing NN.layers[1].bias (by the way, in Flux v0.11.3, which I am trying this on, the field is called b, not bias).
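For context on the stack trace: frames [7] and [8] show ADAM's apply!/update! receiving a Float32 parameter vector together with a Float64 gradient, and, per the explanation above, that gradient only covers the bias entries. The following plain-Julia sketch (no Flux involved; the sizes are assumed from the MWE and the update is a simplified stand-in, not Flux's actual ADAM code) shows that such an in-place broadcast update fails with the same exception type:

```julia
# Sizes assumed from the MWE: Dense(2, 2) has 4 weights and 2 biases.
θ = zeros(Float32, 6)   # flat parameter vector, like the one destructure returns
g = ones(Float64, 2)    # hypothetical gradient covering only the 2 bias entries
η = 0.01                # step size, as in ADAM(0.01)

# A simplified in-place update. The fused broadcast would need a length-2
# array to fill a length-6 destination, so Julia throws before writing
# anything into θ.
caught = try
    θ .-= η .* g
    nothing
catch err
    err
end

println(typeof(caught))  # DimensionMismatch
```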

You could either use forward-mode AD by wrapping the body of the loss function in Zygote.forwarddiff:

using Zygote  # provides Zygote.forwarddiff
function loss(nn, Θ, data_x, data_y)
    Zygote.forwarddiff(Θ) do Θ
        loss_intermediate(nn(Θ), data_x, data_y)
    end
end

or make sure that loss references all parameters of the network, e.g. by explicitly destructuring again:

function loss_intermediate(NN, data_x, data_y)
    bias = Flux.destructure(NN)[1][end-1:end]
    loss_2 = 0.5*dot(bias, bias)
    return 0.1*loss_2 #+ 0.5*dot(NN.layers[1].W, NN.layers[1].W)
end

The former is actually the more performant approach when the number of parameters is small.

Edit: the latter approach no longer works with Flux >= 0.12; instead, it yields an empty gradient for

gradient(p->loss_intermediate(nn(p),data_x,data_y), Θ)
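The forward-mode option can be tried in isolation. Below is a minimal sketch using only Zygote, with a hypothetical `bias_loss` on a plain 6-element vector standing in for `nn(Θ)` (assumption: entries 5 and 6 play the role of the biases). It does not reproduce the destructure problem itself; it only shows that wrapping the computation in `Zygote.forwarddiff` yields a gradient with one entry per parameter, zeros included.

```julia
using Zygote, LinearAlgebra

# Hypothetical stand-in for loss_intermediate(nn(Θ), ...): only the last two
# entries ("biases") of the flat parameter vector enter the loss.
bias_loss(θ) = 0.1 * 0.5 * dot(θ[end-1:end], θ[end-1:end])

θ = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Zygote.forwarddiff evaluates bias_loss with ForwardDiff dual numbers, so the
# derivative is taken with respect to every entry of θ.
g_fd = Zygote.gradient(p -> Zygote.forwarddiff(bias_loss, p), θ)[1]

println(g_fd)  # ≈ [0.0, 0.0, 0.0, 0.0, 0.5, 0.6]
```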

Thank you for your reply :smiley:. I think your explanation makes sense: my function only sees part of the parameters and cannot cope with this during backpropagation. In forward mode, all parameters are seen from the start, so this causes no trouble.

The forwarddiff approach works with both your version and my version of loss_intermediate. Without forwarddiff, using your version of loss_intermediate, I get the same error as before. I also tried the following:

function loss(nn, Θ, data_x, data_y)
    NN = nn(Θ)
    bias = NN.layers[1].bias
    loss_1 = 0.0
    loss_2 = 0.5*dot(bias.-1, bias.-1)
    return loss_1 + 0.1*loss_2
end

My reasoning was that in this case the function knows the full shape of Θ and might be able to backpropagate appropriately. Unfortunately, I get the same error again. Following your earlier argument, it makes sense that backpropagation again only ‘sees’ the bias parameters. Since your loss_intermediate function does not work on my setup (my Flux version is 0.12.1), do you have any other ideas for solving the problem with reverse-mode differentiation?

Did @frankschae take a look at this? I forget whether I messaged him about it :sweat_smile:

No, I haven’t seen this one, yet. I’ll have a look.

I think something like this:

# Defining the loss function
function loss_intermediate(bias, data_x, data_y)
    loss_1 = 0.0
    loss_2 = 0.5*dot(bias, bias)
    return loss_1 + 0.1*loss_2
end

function loss(nn, Θ, data_x, data_y)
    NN = nn(Θ)
    bias = @view Θ[end-1:end]  # take the biases directly from the parameter vector
    return loss_intermediate(bias, data_x, data_y)
end

would work. With a non-zero bias initialization:

nn_initial = Chain(Dense(input_size, output_size, bias=randn(2)))

The gradient for the MWE loss looks good:

Zygote.gradient((p) -> loss(nn, p, data_x, data_y), Θ)
i.e., [0, 0, 0, 0, 0.1*bias1, 0.1*bias2] (the factor 0.1 comes from the weighting of loss_2)
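The same check can be sketched without Flux. Below, a plain vector stands in for Θ (assumption: 4 weight entries followed by 2 bias entries, matching the [0, 0, 0, 0, ·, ·] layout above), and a hypothetical `view_loss` reads the biases through `@view`, as in the suggested loss. Reverse-mode Zygote then returns a gradient with one entry per parameter, with zeros for the unused weights.

```julia
using Zygote, LinearAlgebra

# Hypothetical flat parameter vector: 4 weights followed by 2 biases.
Θ = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Mirrors the suggested loss: the biases are a view into Θ itself, so the
# pullback scatters their gradient back into a full-length vector.
function view_loss(Θ)
    bias = @view Θ[end-1:end]
    return 0.1 * 0.5 * dot(bias, bias)
end

g = Zygote.gradient(view_loss, Θ)[1]
println(g)  # ≈ [0.0, 0.0, 0.0, 0.0, 0.5, 0.6]
```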

Thank you for your reply. With your loss function my code runs; only the bias = randn(2) part did not work (my Flux version is v0.12.1). I think I understand why your code works: bias is now taken directly from the parameter vector we want to optimize, and loss_intermediate only depends on bias, so the differentiation handles it properly.

However, why did the following not work?

I have the feeling it should work, since the function knows (similarly to your code) that bias is part of the parameters.

Are you still using destructure to flatten out the parameters before differentiating? The code above would work with implicit params (i.e. params(...)), but I’m not sure referencing parameters works with the parameter vector unless you reconstruct the model (using the second returned value from destructure).
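For comparison, Zygote's implicit-parameter API tracks gradients by array identity rather than by position in a flat vector. A minimal Zygote-only sketch, assuming plain arrays `W` and `b` stand in for the Dense layer's weight and bias (with Flux these would typically be collected via Flux.params instead):

```julia
using Zygote, LinearAlgebra

# Hypothetical stand-ins for the Dense layer's weight matrix and bias vector.
W = [1.0 2.0; 3.0 4.0]
b = [5.0, 6.0]

# Implicit parameters: gradients are keyed by the arrays themselves.
ps = Zygote.Params([W, b])

# The loss references b directly and never touches W.
gs = Zygote.gradient(() -> 0.1 * 0.5 * dot(b, b), ps)

println(gs[b])  # ≈ [0.5, 0.6]
println(gs[W])  # nothing, since W is not used in the loss
```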

I still use destructure: Θ, nn = Flux.destructure(nn_initial).

This is how I optimize:
result = DiffEqFlux.sciml_train((p) -> loss(nn, p, data_x, data_y), Θ, ADAM(0.01), maxiters = 100)

Hence, I indeed use the vector form of the weights (and not the params(...) form). What do you mean by: