Hello, I need to run optimizations for many groups of ODE problems. The first few groups each take about 20 minutes to optimize, but after several groups have finished, the time per group gradually climbs to more than an hour. I have tried calling GC.gc(true) to clear some caches, but it does not restore the original computational efficiency.
My code is as follows:

```julia
using Lux
using DifferentialEquations
using SciMLSensitivity
using StableRNGs
using Zygote
using Statistics
using ComponentArrays
using BenchmarkTools
using Interpolations
using ProgressMeter
using Optimization, OptimizationOptimisers
using JLD2
using CSV, DataFrames, Dates
include("models/m50.jl")
function build_and_train_models(itpfuncs, norm_funcs,
        norm_input, etnn_target, qnn_target,
        train_x, train_y, train_timepoints,
        exphydro_params, exphydro_init_states)
    params = get_default_params_s(hidd_dims = 16)
    et_nn, q_nn = build_M50_NNs(16)
    m50_func_s, _ = build_M50(
        itpfuncs, norm_funcs,
        exphydro_params = Vector(exphydro_params),
        exphydro_initstates = exphydro_init_states,
    )
    et_apply = (x, p) -> LuxCore.stateless_apply(et_nn, x, p)
    q_apply = (x, p) -> LuxCore.stateless_apply(q_nn, x, p)
    #* pretrain the ET and Q neural networks separately
    opt_en_ps, etnn_loss_recorder = train_NN(et_apply,
        (norm_input[[1, 2, 3], :], etnn_target), params[:et])
    opt_q_ps, qnn_loss_recorder = train_NN(q_apply,
        (norm_input[[2, 4], :], qnn_target), params[:q])
    m50_init_pas = ComponentArray(et = opt_en_ps, q = opt_q_ps)
    train_arr = permutedims(train_x)[[2, 3, 1], :]
    #* define loss function
    loss_func(obs, pred) = sum((pred .- obs) .^ 2) / sum((obs .- mean(obs)) .^ 2)
    #* train the coupled M50 model
    opt_m50_ps, m50_loss_recorder = train_NN(
        (x, p) -> m50_func_s(x, p, train_timepoints),
        (train_arr, train_y), m50_init_pas,
        loss_func = loss_func, max_N_iter = 75, optmzr = ADAM(0.01),
    )
    return opt_en_ps, opt_q_ps, etnn_loss_recorder, qnn_loss_recorder, m50_func_s, opt_m50_ps, m50_loss_recorder
end
function main(basin_id, exphydro_params, exphydro_init_states)
    # Check if the m50 output directory already exists for this basin
    m50_dir = "data/m50/$(basin_id)"
    if isdir(m50_dir)
        @info "M50 output directory already exists for basin $(basin_id), skipping..."
        return
    end
    #* load cached basin data and exphydro simulation results
    camelsus_cache = load("data/camelsus/$(basin_id).jld2")
    data_x, data_y, data_timepoints = camelsus_cache["data_x"], camelsus_cache["data_y"], camelsus_cache["data_timepoints"]
    train_x, train_y, train_timepoints = camelsus_cache["train_x"], camelsus_cache["train_y"], camelsus_cache["train_timepoints"]
    test_x, test_y, test_timepoints = camelsus_cache["test_x"], camelsus_cache["test_y"], camelsus_cache["test_timepoints"]
    exphydro_df = CSV.read("data/exphydro/$(basin_id).csv", DataFrame)
    M50_INPUT, EXPHYDRO_PARAMS, M50_PARAMS = M50_ATTR()
    snowpack_vec = exphydro_df[!, "snowpack"]
    soilwater_vec = exphydro_df[!, "soilwater"]
    et_vec = exphydro_df[!, "et"]
    qsim_vec = exphydro_df[!, "qsim"]
    lday_vec = collect(data_x[:, 1])
    prcp_vec = collect(data_x[:, 2])
    t_vec = collect(data_x[:, 3])
    #* prepare normalization
    s0_mean, s0_std = mean(snowpack_vec), std(snowpack_vec)
    s1_mean, s1_std = mean(soilwater_vec), std(soilwater_vec)
    lday_mean, lday_std = mean(lday_vec), std(lday_vec)
    p_mean, p_std = mean(prcp_vec), std(prcp_vec)
    t_mean, t_std = mean(t_vec), std(t_vec)
    #* define normalization functions
    norm_S0(x) = (x .- s0_mean) ./ s0_std
    norm_S1(x) = (x .- s1_mean) ./ s1_std
    norm_LDAY(x) = (x .- lday_mean) ./ lday_std
    norm_T(x) = (x .- t_mean) ./ t_std
    norm_P(x) = (x .- p_mean) ./ p_std
    #* define interpolation functions
    itp_method = SteffenMonotonicInterpolation()
    itp_Lday = interpolate(data_timepoints, lday_vec, itp_method)
    itp_P = interpolate(data_timepoints, prcp_vec, itp_method)
    itp_T = interpolate(data_timepoints, t_vec, itp_method)
    itpfuncs = (itp_P, itp_T, itp_Lday)
    #* prepare training samples
    norm_input = permutedims(reduce(hcat, (
        norm_S0.(snowpack_vec[1:length(train_timepoints)]),
        norm_S1.(soilwater_vec[1:length(train_timepoints)]),
        norm_T.(t_vec[1:length(train_timepoints)]),
        norm_P.(prcp_vec[1:length(train_timepoints)]),
    )))
    etnn_target = reshape(log.(et_vec[1:length(train_timepoints)] ./ lday_vec[1:length(train_timepoints)]), 1, :)
    qnn_target = reshape(log.(qsim_vec[1:length(train_timepoints)]), 1, :)
    #* build and train the models
    opt_en_ps, opt_q_ps, etnn_loss_recorder, qnn_loss_recorder, m50_func_s, opt_m50_ps, m50_loss_recorder = build_and_train_models(
        itpfuncs,
        (norm_S0, norm_S1, norm_T, norm_P),
        norm_input, etnn_target, qnn_target,
        train_x, train_y, train_timepoints,
        exphydro_params,
        exphydro_init_states,
    )
    #* save checkpoints and loss records
    mkpath("data/m50/$(basin_id)")
    save("data/m50/$(basin_id)/pretrain_ckpts.jld2", "opt_en_ps", opt_en_ps, "opt_q_ps", opt_q_ps)
    CSV.write("data/m50/$(basin_id)/etnn_loss_df.csv", etnn_loss_recorder)
    CSV.write("data/m50/$(basin_id)/qnn_loss_df.csv", qnn_loss_recorder)
    save("data/m50/$(basin_id)/train_records.jld2", "opt_m50_ps", opt_m50_ps)
    CSV.write("data/m50/$(basin_id)/m50_loss_df.csv", m50_loss_recorder)
    #* make prediction
    total_arr = permutedims(data_x)[[2, 3, 1], :]
    total_y_hat = m50_func_s(total_arr, opt_m50_ps, data_timepoints)
    predicted_df = DataFrame((pred = total_y_hat, obs = data_y))
    CSV.write("data/m50/$(basin_id)/predicted_df.csv", predicted_df)
    #* force a full garbage collection before moving on to the next basin
    GC.gc(true)
end
```

A brief explanation of the code above: it builds a neural ODE problem and solves it with DifferentialEquations.jl. `build_and_train_models` handles model construction and training, while `main` handles data preparation and saving the results. When the script runs, I call `main` repeatedly, each time with a different `basin_id` and its corresponding parameters.
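Concretely, my driver loop looks roughly like this (a sketch: `basin_ids` and `load_exphydro_params` are placeholder names for my actual basin list and parameter lookup):

```julia
#* Sketch of the in-process driver over all basins (placeholder names).
for basin_id in basin_ids
    exphydro_params, exphydro_init_states = load_exphydro_params(basin_id)
    # `main` already finishes with GC.gc(true), yet the per-basin time still grows
    main(basin_id, exphydro_params, exphydro_init_states)
end
```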
The workaround I am trying now is to use the shell to batch-launch a new Julia process for each basin, which avoids the memory accumulation entirely, but I would still prefer a plain for-loop inside a single Julia session so that I do not have to re-activate the environment for every run.
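That workaround is roughly equivalent to the sketch below, here driven from Julia via `run`; `run_basin.jl` is a placeholder for a small script that calls `main` for the single basin id passed on the command line:

```julia
#* Sketch of the process-per-basin workaround (placeholder script name).
#* Each basin runs in a fresh Julia process, so memory is released when it exits.
for basin_id in basin_ids
    run(`julia --project=. run_basin.jl $(basin_id)`)
end
```

Each basin then gets a clean session, but every run pays the startup and environment activation cost again, which is what I would like to avoid.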