Flux.train! hangs after updating packages

Hello people,
I am currently trying to train a neural ODE network using a chemistry network from Catalyst.jl.
My code was working fine until I updated my packages. Unfortunately, I just ran Pkg.update() and didn't really pay attention to what was updated, so I can't really go back.

What is happening now is that Flux.train! seems to get stuck right away (no error, it just keeps running forever). Before the update, 100 training iterations took 90 s (on a very weak machine).

I am using Atom/Juno to write my code and this is showing up:
inf
Maybe someone with Atom/Juno experience knows what this means in my case?

I am using Julia 1.6.1, DiffEqFlux 1.41, Flux 0.12.4, DiffEqSensitivity 6.54.0 and DifferentialEquations 6.17.1.
It's hard to figure out what's wrong without error messages, and my code was working as it is before the updates.
Any ideas what's wrong or what's happening?

EDIT: maybe I should mention that I pieced my code together from the documentation and this paper.

Not sure this is the cause, but did you restart Julia after updating the packages?

Try obtaining gradients on just one batch of data and not using train!. That should give you a fast way to troubleshoot any issues.
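A minimal sketch of that check, assuming a toy model in Flux 0.12-style syntax: the network `model`, the batch `x`/`y` and the `loss` function are hypothetical stand-ins for the real neural ODE setup. It computes a single gradient with Zygote (the AD that Flux uses under the hood) on a flat parameter vector and times it; if this call alone hangs or returns non-finite values, `train!` itself is probably not the culprit.

```julia
using Flux, Zygote

# Hypothetical toy model and a single batch of data (stand-ins for the real setup).
model = Chain(Dense(2, 16, tanh), Dense(16, 2))
x = rand(Float32, 2, 8)   # 2 features × 8 samples
y = rand(Float32, 2, 8)

# Flat parameter vector plus a function that rebuilds the model from it.
p, re = Flux.destructure(model)

# Hypothetical loss on the batch, written as a function of the parameters.
loss(p) = Flux.Losses.mse(re(p)(x), y)

# One timed reverse-mode gradient instead of a full train! loop.
g = @time Zygote.gradient(loss, p)[1]

println("loss = ", loss(p), "; all gradient entries finite: ", all(isfinite, g))
```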


Yes, show the gradient calculation example as an MWE.


Okay, so I changed my code like this to look at the gradients and use update! instead, and now the training somehow works fine. I am happy with that, but it doesn't really help me understand why train! is not working for me. update! also gives much better results than train! did (back when it was working).

I guess this must be caused by my model, because the example code from the documentation using train! works just fine. But it's still strange that it started after updating my packages.
So I don't know how to make an MWE out of my code/model, or whether that makes sense.
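A sketch of the explicit "compute the gradient, then update" loop described above, on a hypothetical toy model rather than the original code. The step is a plain gradient-descent update so the block does not depend on a particular optimiser API; with Flux 0.12, `Flux.Optimise.update!(opt, p, g)` performs the step for an optimiser `opt` instead. The learning rate `η` is an assumed value.

```julia
using Flux, Zygote

# Hypothetical toy model and data; the target is simply y = 2x.
model = Chain(Dense(2, 16, tanh), Dense(16, 2))
x = rand(Float32, 2, 32)
y = 2 .* x

p, re = Flux.destructure(model)
loss(p) = Flux.Losses.mse(re(p)(x), y)

η = 0.05f0            # learning rate (an assumed value)
losses = Float32[]

for iter in 1:200
    g = Zygote.gradient(loss, p)[1]   # explicit gradient w.r.t. the flat parameters
    p .-= η .* g                      # in-place gradient-descent step on p
    push!(losses, loss(p))            # record the loss after the step
end

println("loss after first step: ", first(losses), ", after last step: ", last(losses))
```

Because every step is visible, it is easy to print the loss or inspect `g` at the iteration where things go wrong.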

You’re restructuring outside of the function:

Also, ForwardDiff.gradient is different from Zygote.gradient, which Flux uses under the hood. You may be able to reproduce the issue by using the latter instead.
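Both packages should return the same gradient values; they differ in how they compute them (ForwardDiff uses forward-mode dual numbers, Zygote uses reverse mode). A quick way to check that on a given function is to compute both and compare. The loss `f` below is a hypothetical stand-in; with the real model, the same comparison on one batch can show whether the Zygote path returns sensible values at all.

```julia
using ForwardDiff, Zygote

# Hypothetical scalar loss over a flat parameter vector.
f(p) = sum(abs2, tanh.(p) .* [1.0, 2.0, 3.0])

p = [0.1, -0.5, 2.0]

g_fd = ForwardDiff.gradient(f, p)   # forward mode (dual numbers)
g_zy = Zygote.gradient(f, p)[1]     # reverse mode, the AD that Flux uses

println("gradients agree: ", isapprox(g_fd, g_zy))
```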


Okay, sorry for the late response. So I replaced grad = ForwardDiff.gradient() with @time grad = Zygote.gradient()[1] and got this output:

1.0966585823313262e7
177.862870 seconds (287.27 M allocations: 20.921 GiB, 5.49% gc time, 32.76% compilation time)
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/2AoGt/src/initdt.jl:81
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/2AoGt/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/kCcpg/src/integrator_interface.jl:325
445024.28042782197
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/2AoGt/src/initdt.jl:81
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/2AoGt/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/kCcpg/src/integrator_interface.jl:325

Then the code hangs again. So, as you suspected, the problem is the gradient calculation (with Zygote): Zygote.gradient() needs 177 s to calculate the gradient, while ForwardDiff.gradient() needs only 0.02 s.
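The warnings above indicate that NaNs are reaching the ODE solver. One way to catch that at the iteration where it first appears is a guard around each update. `safe_step!` below is a hypothetical helper, not part of Flux or DiffEqFlux, that refuses to apply a gradient-descent step when the loss or any gradient entry is not finite.

```julia
# Hypothetical helper (not part of Flux): apply a gradient-descent step only if
# both the loss and every gradient entry are finite. Returns whether p was updated.
function safe_step!(p::AbstractVector, g::AbstractVector, lossval::Real; η = 0.01)
    if !isfinite(lossval) || !all(isfinite, g)
        @warn "Non-finite loss or gradient; skipping update" lossval
        return false
    end
    p .-= η .* g
    return true
end

# Usage: the first call is rejected because of the NaN, the second is applied.
p = [1.0, 2.0]
safe_step!(p, [NaN, 0.0], 1.0)
safe_step!(p, [1.0, 1.0], 0.5; η = 0.5)
println(p)   # prints [0.5, 1.5]
```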

Is that a problem, and which function are you talking about? I'm very new to this, so I'm not quite sure what you are pointing to, but I am very happy to learn how to do these things the best way.
All the example code I looked at does it in a similar way.
With p, re = Flux.destructure(NN) I obtain the initial network parameters and the function re to reconstruct the network, which I name re_NN in the line you're referencing.
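A small sketch of what `Flux.destructure` returns, using a hypothetical network `NN` in Flux 0.12-style syntax: `p` is a flat vector of all weights and biases, and `re` rebuilds a network from any vector of that length. Calling `re(p)` produces a network whose weights are taken from that particular `p`; a different vector gives a different network.

```julia
using Flux

# Hypothetical network in Flux 0.12-style syntax.
NN = Chain(Dense(3, 8, tanh), Dense(8, 3))

# p: flat vector of all weights and biases; re: rebuilds a network from such a vector.
p, re = Flux.destructure(NN)
println(length(p))            # 3*8 + 8 + 8*3 + 3 = 59 parameters

x = rand(Float32, 3)
m = re(p)                     # network rebuilt from this particular p
println(m(x) ≈ NN(x))         # same weights, same output

# A different parameter vector gives a different network: with all weights and
# biases set to zero, both layers output zeros.
println(re(zero(p))(x))
```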

I will look deeper into it.

If you don’t restructure in the ODE, the neural network is not a function of the parameters p.

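re_NN = re(p)  # restructured once, outside the ODE function: the weights are frozen at this p

function dudt!(du, u, p, t)  # scale the data
    # rebuild the network from the current p on every call, so du depends on p
    du .= re(p)(u) .* yscale / t_end
end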
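A sketch of why this matters for Zygote, using a hypothetical two-layer network and a random state `u` (Flux 0.12-style layer syntax assumed; `rhs_outside` and `rhs_inside` are illustrative names). When the network is rebuilt once outside the function, its output no longer depends on the argument `p`, so Zygote returns `nothing` as the gradient; rebuilding it from `p` inside the function lets gradients reach the parameters.

```julia
using Flux, Zygote

# Hypothetical network and state vector (Flux 0.12-style layer syntax).
NN = Chain(Dense(2, 4, tanh), Dense(4, 2))
p, re = Flux.destructure(NN)
u = rand(Float32, 2)

re_NN = re(p)                     # rebuilt once, outside: weights fixed at this p
rhs_outside(u, p) = re_NN(u)      # ignores its p argument
rhs_inside(u, p) = re(p)(u)       # rebuilds the network from p on every call

g_out = Zygote.gradient(q -> sum(rhs_outside(u, q)), p)[1]
g_in = Zygote.gradient(q -> sum(rhs_inside(u, q)), p)[1]

println(g_out === nothing)        # true: the output does not depend on p
println(any(!iszero, g_in))       # true: gradients reach the parameters
```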