I am following the example of solving an ODE using a PINN. However, on my machine it’s been 2+ hours and I am only on iteration 5 of 200… Does anyone know what could be going wrong here? Is it a matter of Flux vs Lux?
# Implementing PINNs for simple dynamical systems
using Flux, NeuralPDE, OptimizationOptimisers
using DifferentialEquations
# Let's define the simple ODE and solve it with a traditional solver for reference
function define_ode_prob()
    linear(u, p, t) = cos(2pi * t)
    tspan = (0.0, 1.0) ## <- IN THE TUTORIAL, THIS IS FLOAT32 ... is that the problem?
    u0 = 0.0
    prob = ODEProblem(linear, u0, tspan)
    sol = solve(prob, Tsit5())
    return prob, sol
end
prob, sol = define_ode_prob()
# Solve the ODE using a PINN (NNODE with a Flux chain)
function solve_using_NN(prob)
    chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
    opt = OptimizationOptimisers.Adam(0.1)
    #alg = NeuralPDE.NNODE(chain, opt)
    sol = solve(prob, NeuralPDE.NNODE(chain, opt), verbose = true, abstol = 1.0f-6,
                maxiters = 200)
end
solve_using_NN(prob)
The output so far, after 2+ hours, has been:
┌ Warning: Layer with Float32 parameters got Float64 input.
│ The input will be converted, but any earlier layers may be very slow.
│ layer = Dense(1 => 5, σ)
│ summary(x) = 1-element Vector{Float64}
└ @ Flux /home/affans/.julia/packages/Flux/jgpVj/src/layers/stateless.jl:60
Current loss is: 1.7668873007986452, Iteration: 1
Current loss is: 0.47748592920298255, Iteration: 2
Current loss is: 0.5351499668927292, Iteration: 3
Current loss is: 0.9421460439110038, Iteration: 4
Current loss is: 0.9918245187749526, Iteration: 5
Did you try using Lux like the tutorials show? Flux has a lot of weird issues, and it’s throwing warnings in your output there about being slow. The easiest thing to do is normally just to switch to Lux. I think we will do the conversion automatically internally soon (we set that up with DiffEqFlux already to fix a few issues), but right now the auto-conversion isn’t done.
Thanks Chris. It even says in the tutorial to use Lux, but I thought I could just stick with Flux given that it’s the go-to package for neural networks. Is that not the case anymore? Is Lux the preferred package?
It’s just easier to do things right. Technically, I think you need to add a |> f64 to the chain definition, but there’s this and some other places where weird type promotions come up. We couldn’t even figure out how to get DiffEqFlux.jl working on the latest versions of Flux.jl, so we resorted to just converting any Flux model to a Lux.jl model for the user in order to preserve backwards compatibility. For NeuralPDE.jl we haven’t set up those conversions yet, but we hope to get around to it in the near future. That would let someone pass in a Flux NN and we would just secretly convert it to Lux under the hood: it would simplify the internals a bunch and also make this kind of issue go away.
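For concreteness, here is a minimal sketch of the two workarounds hinted at above, assuming the chain and problem definitions from the first post. Either keeps the network parameters and the ODE state in the same precision, which is what the Flux warning is complaining about; neither has been tested against NNODE here.

# Option 1: promote the Flux parameters to Float64 to match the Float64 ODEProblem
chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1)) |> f64

# Option 2: keep the default Float32 parameters and define the problem in Float32 instead
linear(u, p, t) = cos(2pi * t)
prob = ODEProblem(linear, 0.0f0, (0.0f0, 1.0f0))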
Thanks!
I was able to switch from Flux to Lux and things seem to be working now.
using Lux, NeuralPDE, OptimizationOptimisers  # fresh session, with Lux replacing Flux

# Solve the ODE using a PINN (NNODE with a Lux chain)
function solve_using_NN(prob, maxiters = 5)
    chain = Chain(Dense(1, 5, σ), Dense(5, 1))
    opt = OptimizationOptimisers.Adam(0.1)
    alg = NeuralPDE.NNODE(chain, opt)
    sol = solve(prob, alg, verbose = true, abstol = 1.0f-6,
                maxiters = maxiters)
end
solve_using_NN(prob)
where prob is an ODEProblem. I did only 5 iterations, which took about 3.5 minutes on my CPU. I am just diving into using deep learning models, so I am not sure whether that is fast or not (for just 5 iterations).
EDIT: never mind, I guess there was initial compilation latency, because repeated calls to solve_using_NN are now near-instantaneous.
@time solve_using_NN(prob, 200)
0.222050 seconds (575.38 k allocations: 128.602 MiB, 23.56% gc time, 7.97% compilation time)
I wonder why the function is taking so long to compile.
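For completeness, here is a rough sketch of how the PINN result could be sanity-checked against the exact solution u(t) = sin(2πt)/(2π) of this ODE. The saveat keyword on the NNODE solve and the scalar layout of nn_sol.u are assumptions based on the usual SciML solve interface, not something confirmed in the thread above.

ts = 0.0:0.05:1.0
chain = Chain(Dense(1, 5, σ), Dense(5, 1))
opt = OptimizationOptimisers.Adam(0.1)
# saveat is assumed to be supported by the NNODE solve, as in other SciML solvers
nn_sol = solve(prob, NeuralPDE.NNODE(chain, opt); verbose = false,
               abstol = 1.0f-6, maxiters = 200, saveat = ts)

exact(t) = sin(2pi * t) / (2pi)        # closed-form solution of u' = cos(2πt), u(0) = 0
maximum(abs.(nn_sol.u .- exact.(ts)))  # max pointwise error on the grid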