Why does my Flux model return all NaN?

Hi friends here,

I would like to ask a question about Flux training: I am trying to understand why my model returns NaN values after a failed training step.

For example, I have the following model, whose input shape is (120, 56, 1):

julia> myModel = Chain(
            Conv((7, ), 56 => 32, stride = 3),  
            Flux.flatten,  
            Dense(38 * 32 => 10, identity),  
            BatchNorm(10, relu)
        )
julia> x = rand(Float32,120,56,1)
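
For reference, Flux.outputsize can be used to check where the 38 * 32 in the Dense layer comes from, without pushing any data through the layers:

Flux.outputsize(myModel[1], (120, 56, 1))   # (38, 32, 1): length 38 after the strided Conv, 32 channels, batch 1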

Here is the output of the model:

julia> myModel(x) 
10-element Vector{Float32}:
 -0.0013220371
  0.0051790997
  0.042023115
 -0.031484906
 -0.037755977
....

Then I have an objective function, but due to a dimension mismatch it throws an error: my input is (120, 56, 1) while the model output is (10, 1). From my understanding, the model cannot be trained with such a loss function.

julia> loss(x) = logitbinarycrossentropy(myModel(x), x, agg=sum) # it will raise an error
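
To make the mismatch explicit: logitbinarycrossentropy broadcasts the prediction against the target element-wise, and these two shapes don't line up:

size(myModel(x))   # (10, 1) -- the model output
size(x)            # (120, 56, 1) -- the target, since the loss reuses the input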

So when I try to train this model, it immediately throws an error, which is expected:

# this will raise an error
julia> Flux.train!(loss, Flux.params(myModel), [x], Adam(0.0001)) 

However, if I then rerun the model, it outputs all NaN values. This is confusing: the training step errors out and never completes, so why does it affect my original model?

julia> myModel(x)
10-element Vector{Float32}:
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN

Any comments are appreciated. Thank you for your attention!

Just checking layer by layer, it seems to come from the BatchNorm, and looking at the fields of the BatchNorm it is the running variance that is suddenly NaN all over.
I'm not exactly sure how the BatchNorm is updated, but I think it happens during the forward pass. So since the error happens in the loss calculation, you probably got a full forward pass done before hitting the error, and that pass changed the BatchNorm's running statistics. I'm not sure why this would create NaNs though.

julia> tmp = myModel[1](x)
38×32×1 Array{Float32, 3}:
[:, :, 1] =
 -0.216846    0.836771   -0.164834    0.0958836   0.538185  …   0.515274   -0.146586    -0.956131    0.628205
 -0.655194   -0.2469      0.753487   -0.246044    0.384564      0.343249   -0.440123    -0.227641    0.34813
  0.261229    0.505907    0.139729    0.0175025   0.411264      0.431131   -0.0311406   -0.480904    0.369305
  0.0517392   0.0546288  -0.128075    0.13172     0.644657     -0.41286     0.165092    -0.574947    0.29692
 -1.13784     0.664485    0.210838   -0.325501    0.703631      0.351165   -0.30469     -0.708115    0.0410308
 -1.21508     0.322345    0.216984   -0.245615    0.282585  …   0.192009    0.121902    -0.316177    0.532268
 -0.310966   -0.213768    0.324011    0.691767    0.322875      0.210432    0.123857    -0.934422    0.146396
 -0.721712    0.431708    0.828699   -0.184289    0.955514      0.685203   -0.256821    -0.484859   -0.221852
 -1.01505     0.675787    0.633824   -0.296752    0.79166       0.147934   -0.351413    -0.048238    0.409103
 -0.0491967  -0.362716    0.569628   -0.132194    0.849446      0.13531    -0.939218    -0.602757    0.336303
 -0.682288    0.683156    0.144198    0.36092     0.348951  …  -0.170675   -0.332979     0.0922967  -0.0652037
 -0.578487    0.726059    0.68295     0.19393     0.419121      0.410701   -0.327695    -1.42942    -0.0726877
 -0.351213    0.24157     0.579865    0.370011    0.370588     -0.220557   -0.180216    -0.289107   -0.116528
 -0.32464     0.258404    0.455044    0.0673226   0.565907      0.515696   -0.245075    -0.283426    0.280826
 -0.560559    0.869204    0.531793    0.472433    0.427659     -0.343646    0.280007    -1.0279      1.06505
 -0.239349    0.45405     0.094598   -0.327334    0.74495   …  -0.0906876  -0.144499    -1.15908     0.348169
  0.118142    0.0938911  -0.0305605   0.180625    0.213752      0.548597   -0.12773     -0.482545    0.266652
 -0.551988   -0.167326    0.138988   -0.0738249   0.494136      0.320206    0.207258    -0.196603    0.126249
 -0.437397    1.01912     0.91061     0.353962    1.11248      -0.0917978   0.69427     -1.01596     0.0453018
 -0.487832    0.851307   -0.01276     0.372752    0.78544       0.103224    0.0258075   -1.24862     0.236722
 -0.371546    1.11405     0.138101   -0.0605819   0.683237  …  -0.0827831  -0.844868    -0.76218     0.162526
 -0.324308    0.412406   -0.54857    -0.370068    0.686435     -0.121307    0.127933    -0.139145    0.170808
  0.190516    0.294839    0.380076    0.840861    0.361272      0.445421    0.628682    -0.882952   -0.035326
 -0.670033    0.356724    0.116671    0.204454    0.918058      0.416338   -0.25526     -0.453965    0.0441106
 -0.631428    1.00532     0.464934    0.348853    1.01267       0.522718   -0.12757     -0.948507    0.359225
 -0.495404    0.374752   -0.194241    0.354911    0.875992  …   0.213015   -0.125746    -0.946267   -0.619776
 -0.131325    0.048355    0.572054    0.224516   -0.125437     -0.155194    0.358692    -0.614002    0.183253
 -0.619126    0.389369    0.367157   -0.266041    0.2691       -0.307645   -0.177976    -0.869357   -0.140114
 -1.03964     0.542192   -0.18858     0.294489    1.30662       0.456471   -0.815276    -0.853876    0.252516
 -1.34283     0.60476     0.267019    0.462968    0.850572      0.716962    0.34762     -0.380839   -0.0589569
 -0.199921    0.0195927   0.517672    0.13763     0.881167  …   0.414399   -0.0953214   -1.17149    -0.06685
 -0.632891    0.490813    0.339849    0.448696    0.810454      0.042796   -0.460004    -0.137017    0.429834
 -0.675497    0.740635    0.342982    0.0972508   0.628799     -0.102711    0.351601    -1.21269    -0.0786517
 -0.488251    0.412987    0.434148    0.0963691   0.76427       0.196123   -0.256912    -0.736296    0.173753
  0.0767523   0.31352    -0.311657    0.468733    1.02594       0.566601    0.00272077  -0.14009     0.0268976
 -0.901064    0.31836     0.260267   -0.394814    0.405657  …   0.0119985  -0.207358    -0.199947    0.690202
  0.064047    0.34082    -0.182415    0.277961    1.11691       0.668527   -0.599959    -0.591749    0.0351928
 -0.536      -0.280775    0.253179    0.548145    0.353008     -0.283707    0.127862    -0.352093    0.698897

julia> tmp = myModel[2](tmp)
1216×1 Matrix{Float32}:
 -0.21684615
 -0.65519375
  0.2612291
  0.051739205
 -1.1378447
 -1.2150772
 -0.3109662
 -0.7217121
 -1.0150516
 -0.049196705
 -0.68228847
 -0.578487
 -0.35121307
 -0.32463968
 -0.5605589
 -0.2393494
  0.11814194
 -0.5519884
 -0.4373966
 -0.48783237
 -0.37154633
 -0.32430825
  0.1905157
 -0.6700327
 -0.63142776
 -0.49540403
 -0.13132462
 -0.6191261
 -1.0396357
 -1.3428266
 -0.19992118
 -0.6328913
 -0.6754972
 -0.48825085
  0.07675231
 -0.90106356
  0.064047046
 -0.536
  0.83677065
 -0.24690045
  0.5059075
  0.054628797
  0.66448516
  0.32234538
 -0.2137678
  ⋮
 -1.2126876
 -0.73629576
 -0.14008987
 -0.19994697
 -0.59174865
 -0.35209286
  0.6282047
  0.3481297
  0.36930525
  0.2969197
  0.041030847
  0.53226835
  0.14639598
 -0.22185217
  0.4091034
  0.33630335
 -0.06520373
 -0.072687656
 -0.11652782
  0.28082615
  1.0650545
  0.34816924
  0.26665246
  0.12624937
  0.04530175
  0.23672172
  0.16252577
  0.17080754
 -0.035326034
  0.04411064
  0.3592252
 -0.61977637
  0.18325284
 -0.14011371
  0.25251567
 -0.05895686
 -0.066849984
  0.4298343
 -0.07865165
  0.17375296
  0.026897624
  0.6902021
  0.035192758
  0.69889677

julia> tmp = myModel[3](tmp)
10×1 Matrix{Float32}:
  1.0776123
 -0.012326896
 -1.5732218
 -0.605291
  0.9401836
 -1.581678
  0.89541256
 -0.8995581
 -0.50464994
 -0.6450879

julia> tmp = myModel[4](tmp)
10×1 Matrix{Float32}:
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN

julia> myModel[4].μ
10-element Vector{Float32}:
  0.10776123
 -0.0012326896
 -0.15732218
 -0.0605291
  0.09401836
 -0.15816781
  0.08954126
 -0.089955814
 -0.050464995
 -0.06450879

julia> myModel[4].σ²
10-element Vector{Float32}:
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
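
To check that it really is the forward pass inside the gradient call that overwrites the running statistics, here is a minimal sketch on a fresh BatchNorm (sum(...) is just a dummy loss to trigger the pass). As far as I can tell, Flux scales the batch variance by m/(m-1) when it updates the running variance, and with a single sample that divides by zero:

using Flux

bn = BatchNorm(10, relu)        # fresh layer: running μ = 0, running σ² = 1
xb = randn(Float32, 10, 1)      # a batch of one, like x in the original post

# the forward pass inside the gradient updates bn.μ and bn.σ² as a side effect
Flux.gradient(() -> sum(bn(xb)), Flux.params(bn))

any(isnan, bn.σ²)               # should now be true: the running variance is poisoned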

With only one training example, the variance in the BatchNorm becomes zero, causing the NaNs.

You need a larger batch size, e.g. x = rand(Float32, 120, 56, 2).
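
For example (a sketch, with myModel2 and x2 as fresh names; the existing model has to be rebuilt, because the NaN running statistics won't recover through the momentum update):

# rebuild the model so the BatchNorm starts from clean running statistics
myModel2 = Chain(
    Conv((7,), 56 => 32, stride = 3),
    Flux.flatten,
    Dense(38 * 32 => 10, identity),
    BatchNorm(10, relu)
)

x2 = rand(Float32, 120, 56, 2)   # batch size 2 instead of 1

# dummy loss, just to trigger one training-mode forward pass
Flux.gradient(() -> sum(myModel2(x2)), Flux.params(myModel2))

all(isfinite, myModel2[4].σ²)    # should be true: the variance update is well defined for m ≥ 2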

Also, don't forget to switch between training and inference mode with Flux.trainmode! and Flux.testmode!, otherwise the BN layer may keep updating its running statistics during inference.
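
Something like this (a sketch; ŷ is just a placeholder name):

Flux.trainmode!(myModel2)   # use batch statistics and update μ / σ² while training
# ... training loop here ...
Flux.testmode!(myModel2)    # freeze the running statistics
ŷ = myModel2(x2)            # inference no longer touches the BatchNorm state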

Or maybe InstanceNorm instead of BatchNorm is what you want.

Edit: Actually, the running variance is already NaN itself, as seen in Albin's post above. It doesn't make a difference to the outcome though.
