Here is a minimal working example that solves a 3x3 matrix system using the StaticArrays.jl package. The solver runs fine with `Tsit5` but crashes with `Rodas4`. It is not clear to me how that is possible, since `Rodas4` is a stiff solver and `Tsit5` is not. The crash has to do with the variable `t` getting corrupted. Here is the code:
```julia
# ME to optimize Giesekus
using DifferentialEquations
using StaticArrays
using InteractiveUtils

function dudt_giesekus(u, p, t, gradv)
    # Destructure the parameters
    η0, τ, α = p
    # Governing equations are for components of the 3x3 stress tensor
    σ = SA_F32[u[1] u[4] 0.; u[4] 0. 0.; 0. 0. u[3]]
    # Rate-of-strain (symmetric) and vorticity (antisymmetric) tensors
    ∇v = SA_F32[0. 0. 0. ; gradv(t) 0. 0. ; 0. 0. 0.]
    D = 0.5 .* (∇v + transpose(∇v))
    T1 = (η0/τ) .* D
    T2 = (transpose(∇v) * σ) + (σ * ∇v)
    coef = α / (τ * η0)
    F = coef * (σ * σ)
    du = -σ / τ + T1 + T2 - F  # 9 equations (static matrix)
end

function run()
    du = [0. 0. 0.; 0. 0. 0.; 0. 0. 0.]
    u = [0. 0. 0.; 0. 0. 0.; 0. 0. 0.]
    u0 = SA[0. 0. 0.; 0. 0. 0.; 0. 0. 0.]
    t = 0.
    p = SA[1.,1.,1.]
    dct = Dict{Any, Any}()
    dct[:γ_protoc] = convert(Vector{Float32}, [1, 2, 1, 2, 1, 2, 1, 2])
    dct[:ω_protoc] = convert(Vector{Float32}, [1, 1, 0.5, 0.5, 2., 2., 1/3., 1/3.])
    γ_protoc = dct[:γ_protoc] # The type should now be correctly inferred on the LHS
    ω_protoc = dct[:ω_protoc]
    dudt_giesekus(u, p, t, cos)
    tspan = (0., 5.)
    σ0 = SA[0. 0. 0.; 0. 0. 0.; 0. 0. 0.]
    # Memory allocation is 12k per call to solve(). WHY?
    v21_protoc = (t) -> γ_protoc[1] * cos(ω_protoc[1]*t)
    dudt(u,p,t) = dudt_giesekus(u, p, t, v21_protoc)
    prob_giesekus = ODEProblem(dudt, σ0, tspan, p)
    sol_giesekus = solve(prob_giesekus, Tsit5())
    #sol_giesekus = solve(prob_giesekus, Rodas4())
end

run()
```
The error message when I run with `Rodas4` is as follows (any help is appreciated):
```
ERROR: LoadError: TypeError: in new, expected NTuple{9, Float32}, got a value of type Tuple{Float32, Float32, Float32, ForwardDiff.Dual{Nothing, Float32, 1}, Float32, Float32, Float32, Float32, Float32}
Stacktrace:
[1] SArray at /Users/erlebach/.julia/packages/StaticArraysCore/U2Z1K/src/StaticArraysCore.jl:113
[2] SArray at /Users/erlebach/.julia/packages/StaticArraysCore/U2Z1K/src/StaticArraysCore.jl:117
[3] Type at /Users/erlebach/.julia/packages/StaticArrays/pTgFe/src/convert.jl:163
[4] _SA_typed_hvcat at /Users/erlebach/.julia/packages/StaticArrays/pTgFe/src/initializers.jl:56
[5] typed_hvcat at /Users/erlebach/.julia/packages/StaticArrays/pTgFe/src/initializers.jl:60
[6] dudt_giesekus at /Users/erlebach/src/2022/rude/giesekus/GE_rude.jl/optimized_code/MWE.jl:14
[7] dudt at /Users/erlebach/src/2022/rude/giesekus/GE_rude.jl/optimized_code/MWE.jl:47
[8] ODEFunction at /Users/erlebach/.julia/packages/SciMLBase/gTrkJ/src/scimlfunctions.jl:2404
[9] TimeDerivativeWrapper at /Users/erlebach/.julia/packages/SciMLBase/gTrkJ/src/function_wrappers.jl:23
[10] derivative at /Users/erlebach/.julia/packages/ForwardDiff/vXysl/src/derivative.jl:14
[11] perform_step! at /Users/erlebach/.julia/packages/OrdinaryDiffEq/W3SVv/src/perform_step/rosenbrock_perform_step.jl:929
[12] perform_step! at /Users/erlebach/.julia/packages/OrdinaryDiffEq/W3SVv/src/perform_step/rosenbrock_perform_step.jl:898
[13] solve! at /Users/erlebach/.julia/packages/OrdinaryDiffEq/W3SVv/src/solve.jl:520
[14] #__solve#626 at /Users/erlebach/.julia/packages/OrdinaryDiffEq/W3SVv/src/solve.jl:6
[15] __solve at /Users/erlebach/.julia/packages/OrdinaryDiffEq/W3SVv/src/solve.jl:1
[16] #solve_call#22 at /Users/erlebach/.julia/packages/DiffEqBase/JH4gt/src/solve.jl:509
[17] solve_call at /Users/erlebach/.julia/packages/DiffEqBase/JH4gt/src/solve.jl:479
[18] #solve_up#29 at /Users/erlebach/.julia/packages/DiffEqBase/JH4gt/src/solve.jl:932
[19] solve_up at /Users/erlebach/.julia/packages/DiffEqBase/JH4gt/src/solve.jl:905
[20] #solve#27 at /Users/erlebach/.julia/packages/DiffEqBase/JH4gt/src/solve.jl:842
[21] solve at /Users/erlebach/.julia/packages/DiffEqBase/JH4gt/src/solve.jl:832
[22] run at /Users/erlebach/src/2022/rude/giesekus/GE_rude.jl/optimized_code/MWE.jl:51
[23] top-level scope at /Users/erlebach/src/2022/rude/giesekus/GE_rude.jl/optimized_code/MWE.jl:54
[24] eval at ./boot.jl:368
```
Time is supposed to be a float. Here is what happens when I print `t` inside the function `dudt_giesekus`:
```
t = 0.0
t = 0.0
t = 1.0e-6
t = Dual{ForwardDiff.Tag{SciMLBase.TimeDerivativeWrapper{ODEFunction{false,SciMLBase.AutoSpecialize,…}, StaticArraysCore.SMatrix{3, 3, Float64, 9}, StaticArraysCore.SVector{3, Float64}}, Float64}}(0.0,1.0)
ERROR: LoadError: TypeError: in new, expected NTuple{9, Float32}, got a value of type Tuple{Float32, Float32, Float32, ForwardDiff.Dual{Nothing, Float32, 1}, Float32, Float32, Float32, Float32, Float32}
```
How is it possible that the `Float64` is transformed into a `Dual` number?
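For reference, frames [9] and [10] of the stack trace suggest where the `Dual` comes from: inside the Rosenbrock `perform_step!`, a `TimeDerivativeWrapper` is handed to `ForwardDiff.derivative`, which evaluates the right-hand side with `t` replaced by a `ForwardDiff.Dual`. `Tsit5` is an explicit method and never differentiates the right-hand side, which would explain why only `Rodas4` fails. The sketch below tries to reproduce this outside the solver. It assumes ForwardDiff.jl has been added to the environment (it is not listed in the Project.toml), and `rhs_f32` / `rhs_generic` are hypothetical stand-ins for the `∇v` line of `dudt_giesekus`, not code from the MWE.

```julia
using StaticArrays, ForwardDiff

# Hypothetical stand-in for the ∇v line: SA_F32 hardcodes the element type
# to Float32, so the matrix has no room for a ForwardDiff.Dual entry.
rhs_f32(t) = SA_F32[0 0 0; cos(t) 0 0; 0 0 0]

# Hypothetical type-generic variant: every entry has the same type as t,
# so a Dual t simply produces an SMatrix of Duals.
function rhs_generic(t)
    z = zero(t)  # zero of the same type as t (Float64 or Dual)
    return SA[z z z; cos(t) z z; z z z]
end

# Roughly what the Rosenbrock step does for a non-autonomous problem:
# differentiate the right-hand side with respect to t, passing t as a Dual.
dfdt = ForwardDiff.derivative(rhs_generic, 0.0)  # entry [2,1] is -sin(0.0)

# The Float32 version is expected to throw while building the SA_F32 matrix.
failed = try
    ForwardDiff.derivative(rhs_f32, 0.0)
    false
catch err
    println("Float32 version failed: ", typeof(err))
    true
end
```

If this is indeed the cause, the `σ = SA_F32[...]` line would presumably fail the same way when `Rodas4` builds the Jacobian with respect to `u`, since that is also done with `Dual` numbers by default.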
`Project.toml` file, Julia 1.8.5:
```toml
ArgMacros = "dbc42088-9de8-42a0-8ec8-2cd114e1ea3e"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DaemonMode = "d749ddd5-2b29-4920-8305-6ff5a704e36e"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
```
Thanks!