Julia runs slower after many loop iterations (solve ODE problem)

Hello, I need to run optimizations for many groups of ODE problems. The first few groups each take about 20 minutes to optimize, but after several optimizations have completed, the time per group gradually increases to more than 1 hour. I have tried calling GC.gc(true) to free memory, but it does not restore the original computational efficiency.
My code is as follows:

using Lux
using DifferentialEquations
using SciMLSensitivity
using StableRNGs
using Zygote
using Statistics
using ComponentArrays
using BenchmarkTools
using Interpolations
using ProgressMeter
using Optimization, OptimizationOptimisers
using JLD2
using CSV, DataFrames, Dates

include("models/m50.jl")

"""
    build_and_train_models(itpfuncs, norm_funcs, norm_input, etnn_target, qnn_target,
                           train_x, train_y, train_timepoints,
                           exphydro_params, exphydro_init_states)

Pre-train the ET and Q networks on their targets, then train the combined M50 model.

The two networks are first fitted separately with `train_NN` (rows `[1, 2, 3]` of
`norm_input` against `etnn_target`, rows `[2, 4]` against `qnn_target`). The resulting
parameters initialise the M50 model returned by `build_M50`, which is trained for 75
iterations against `train_y`.

Returns, in order: the pre-trained ET parameters, the pre-trained Q parameters, the ET
loss record, the Q loss record, the single-trajectory M50 model function, the trained
M50 parameters and the M50 loss record.
"""
function build_and_train_models(itpfuncs, norm_funcs,
	norm_input, etnn_target, qnn_target,
	train_x, train_y, train_timepoints,
	exphydro_params, exphydro_init_states,
)
	default_ps = get_default_params_s(hidd_dims = 16)
	et_net, q_net = build_M50_NNs(16)
	# Only the single-trajectory model (first return value of build_M50) is used here.
	hybrid_model = first(build_M50(
		itpfuncs, norm_funcs,
		exphydro_params = Vector(exphydro_params),
		exphydro_initstates = exphydro_init_states,
	))

	#* stage 1: pre-train the ET network
	et_inputs = norm_input[[1, 2, 3], :]
	et_pretrained, et_history = train_NN(
		(x, p) -> LuxCore.stateless_apply(et_net, x, p),
		(et_inputs, etnn_target),
		default_ps[:et],
	)

	#* stage 2: pre-train the Q network
	q_inputs = norm_input[[2, 4], :]
	q_pretrained, q_history = train_NN(
		(x, p) -> LuxCore.stateless_apply(q_net, x, p),
		(q_inputs, qnn_target),
		default_ps[:q],
	)

	#* stage 3: train the M50 model starting from the pre-trained weights
	hybrid_init = ComponentArray(et = et_pretrained, q = q_pretrained)
	forcing = permutedims(train_x)[[2, 3, 1], :]

	# squared error scaled by the spread of the observations around their mean
	scaled_sq_error = (obs, pred) -> sum((pred .- obs) .^ 2) / sum((obs .- mean(obs)) .^ 2)

	hybrid_trained, hybrid_history = train_NN(
		(x, p) -> hybrid_model(x, p, train_timepoints),
		(forcing, train_y),
		hybrid_init,
		loss_func = scaled_sq_error,
		max_N_iter = 75,
		optmzr = ADAM(0.01),
	)

	return et_pretrained, q_pretrained, et_history, q_history, hybrid_model, hybrid_trained, hybrid_history
end

"""
    main(basin_id, exphydro_params, exphydro_init_states)

Train the M50 model for one basin and write all artefacts to `data/m50/<basin_id>`.

Reads the cached forcing/observation arrays from `data/camelsus/<basin_id>.jld2` and
the simulated series from `data/exphydro/<basin_id>.csv`, builds the normalisation
and interpolation functions, calls `build_and_train_models`, then saves the
pre-trained parameters, the loss records, the trained parameters and the predictions
over the full data period.

If the output directory already exists, the basin is skipped and nothing is done.
Returns `nothing`.
"""
function main(basin_id, exphydro_params, exphydro_init_states)

	# Check if m50 output directory exists for this basin
	m50_dir = "data/m50/$(basin_id)"
	if isdir(m50_dir)
		@info "M50 output directory already exists for basin $(basin_id), skipping..."
		return
	end

	# Only the full-period and training arrays are needed below.
	camelsus_cache = load("data/camelsus/$(basin_id).jld2")
	data_x, data_y, data_timepoints = camelsus_cache["data_x"], camelsus_cache["data_y"], camelsus_cache["data_timepoints"]
	train_x, train_y, train_timepoints = camelsus_cache["train_x"], camelsus_cache["train_y"], camelsus_cache["train_timepoints"]

	exphydro_df = CSV.read("data/exphydro/$(basin_id).csv", DataFrame)

	snowpack_vec = exphydro_df[!, "snowpack"]
	soilwater_vec = exphydro_df[!, "soilwater"]
	et_vec = exphydro_df[!, "et"]
	qsim_vec = exphydro_df[!, "qsim"]

	lday_vec = collect(data_x[:, 1])
	prcp_vec = collect(data_x[:, 2])
	t_vec = collect(data_x[:, 3])

	#* prepare normalization
	s0_mean, s0_std = mean(snowpack_vec), std(snowpack_vec)
	s1_mean, s1_std = mean(soilwater_vec), std(soilwater_vec)

	lday_mean, lday_std = mean(lday_vec), std(lday_vec)
	p_mean, p_std = mean(prcp_vec), std(prcp_vec)
	t_mean, t_std = mean(t_vec), std(t_vec)

	#* define normalization functions
	norm_S0(x) = (x .- s0_mean) ./ s0_std
	norm_S1(x) = (x .- s1_mean) ./ s1_std
	norm_LDAY(x) = (x .- lday_mean) ./ lday_std
	norm_T(x) = (x .- t_mean) ./ t_std
	norm_P(x) = (x .- p_mean) ./ p_std

	#* define interpolation functions
	itp_method = SteffenMonotonicInterpolation()
	itp_Lday = interpolate(data_timepoints, lday_vec, itp_method)
	itp_P = interpolate(data_timepoints, prcp_vec, itp_method)
	itp_T = interpolate(data_timepoints, t_vec, itp_method)
	itpfuncs = (itp_P, itp_T, itp_Lday)

	#* prepare training samples
	# The training period is the leading `length(train_timepoints)` entries of each series.
	train_range = 1:length(train_timepoints)
	norm_input = permutedims(reduce(hcat, (
		norm_S0.(snowpack_vec[train_range]),
		norm_S1.(soilwater_vec[train_range]),
		norm_T.(t_vec[train_range]),
		norm_P.(prcp_vec[train_range]),
	)))
	etnn_target = reshape(log.(et_vec[train_range] ./ lday_vec[train_range]), 1, :)
	qnn_target = reshape(log.(qsim_vec[train_range]), 1, :)

	opt_en_ps, opt_q_ps, etnn_loss_recorder, qnn_loss_recorder, m50_func_s, opt_m50_ps, m50_loss_recorder = build_and_train_models(
		itpfuncs,
		(norm_S0, norm_S1, norm_T, norm_P),
		norm_input, etnn_target, qnn_target,
		train_x, train_y, train_timepoints,
		exphydro_params,
		exphydro_init_states,
	)

	# The directory is created only after training succeeds, so a failed run is retried.
	mkpath(m50_dir)
	save("$(m50_dir)/pretrain_ckpts.jld2", "opt_en_ps", opt_en_ps, "opt_q_ps", opt_q_ps)
	CSV.write("$(m50_dir)/etnn_loss_df.csv", etnn_loss_recorder)
	CSV.write("$(m50_dir)/qnn_loss_df.csv", qnn_loss_recorder)
	save("$(m50_dir)/train_records.jld2", "opt_m50_ps", opt_m50_ps)
	CSV.write("$(m50_dir)/m50_loss_df.csv", m50_loss_recorder)

	#* make prediction
	total_arr = permutedims(data_x)[[2, 3, 1], :]
	total_y_hat = m50_func_s(total_arr, opt_m50_ps, data_timepoints)
	predicted_df = DataFrame((pred = total_y_hat, obs = data_y))
	CSV.write("$(m50_dir)/predicted_df.csv", predicted_df)

	# Request a full collection before the caller moves on to the next basin.
	GC.gc(true)
end

I will briefly explain the code above. It builds a Neural-ODE problem and solves it with DifferentialEquations.jl. build_and_train_models handles model building and training, and main handles data processing and result storage. When the program runs, I call main repeatedly, each time with a different basin_id and its corresponding parameters.
My current workaround is to use a shell script that launches a fresh Julia process for each run, which avoids memory accumulation, but I would still prefer to use a for-loop so that the environment does not have to be activated again for every run.

Are you sure it’s code performance and not the training? If it’s different data, is the number of steps required by the solver increasing?

The length of each input data set is basically the same. I compared the run time for a given basin_id after many loop iterations with the run time for that same basin_id on its own: the former took more than an hour for 75 optimization iterations, while the latter took only 20 minutes. For a better understanding, here is the code that builds and trains this model:

"""
    M50_ATTR()

Return the name lists `(M50_INPUT, EXPHYDRO_PARAMS, M50_PARAMS)` as three freshly
allocated `Vector{Symbol}`s: the input variable names, the Exp-Hydro parameter names
and the names of the M50 parameter groups.
"""
function M50_ATTR()
	return (
		[:prcp, :temp, :lday],
		[:f, :Smax, :Qmax, :Df, :Tmax, :Tmin],
		[:exphydro, :et, :q, :est],
	)
end

# Smoothing step function: a differentiable stand-in for a hard 0/1 switch.
# It equals 0.5 at x = 0 and tends to 0 for negative x and to 1 for positive x.
step_fct(x) = (tanh(5.0 * x) + 1.0) * 0.5
# Snow precipitation: precipitation P weighted by the smooth switch on (Tmin - T).
Ps(P, T, Tmin) = step_fct(Tmin - T) * P
# Rain precipitation: precipitation P weighted by the smooth switch on (T - Tmin).
Pr(P, T, Tmin) = step_fct(T - Tmin) * P
# Snow melt: the smaller of the stored snow S0 and Df * (T - Tmax), switched on
# smoothly for T above Tmax and S0 above zero.
# `min` on the two scalars replaces `minimum([...])`, which allocated a temporary
# array on every call; this function is evaluated inside the ODE right-hand side.
M(S0, T, Df, Tmax) = step_fct(T - Tmax) * step_fct(S0) * min(S0, Df * (T - Tmax))
# Evapotranspiration: PET is a function of temperature T and day length Lday only;
# ET equals PET for S1 above Smax and PET scaled by S1 / Smax for S1 between 0 and Smax.
PET(T, Lday) = 29.8 * Lday * 0.611 * exp((17.3 * T) / (T + 237.3)) / (T + 273.2)
ET(S1, T, Lday, Smax) = step_fct(S1) * step_fct(S1 - Smax) * PET(T, Lday) + step_fct(S1) * step_fct(Smax - S1) * PET(T, Lday) * (S1 / Smax)
# Base flow: Qmax for S1 above Smax, decaying as exp(-f * (Smax - S1)) below it.
Qb(S1, f, Smax, Qmax) = step_fct(S1) * step_fct(S1 - Smax) * Qmax + step_fct(S1) * step_fct(Smax - S1) * Qmax * exp(-f * (Smax - S1))
# Peak flow: the excess of S1 over Smax, switched on smoothly for S1 above Smax.
Qs(S1, Smax) = step_fct(S1) * step_fct(S1 - Smax) * (S1 - Smax)


# Build a compact Lux model made of an LSTM cell followed by a sigmoid Dense layer.
#
# Arguments:
#   in_dims     - input width of the LSTM cell
#   hidden_dims - output width of the LSTM cell, and input width of the Dense layer
#   out_dims    - output width of the Dense layer
#
# The returned model accepts a 3-dimensional array and returns the Dense layer applied
# to the LSTM output of the last slice.
function LSTMCompact(in_dims, hidden_dims, out_dims)
	lstm_cell = LSTMCell(in_dims => hidden_dims)
	classifier = Dense(hidden_dims => out_dims, sigmoid)
	return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
		# Slice x along its second dimension and separate the first slice from the rest.
		# NOTE(review): presumably dimension 2 is the sequence/time axis - confirm.
		x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
		# The first slice is fed without a carry; the cell returns its output and a carry.
		y, carry = lstm_cell(x_init)
		# Each remaining slice is fed together with the carry from the previous step.
		for x in x_rest
			y, carry = lstm_cell((x, carry))
		end
		# Only the output of the last step reaches the Dense layer.
		@return classifier(y)
	end
end

"""
    build_initstate_estimator(hidd_dims = 16)

Build the initial-state estimator: an `LSTMCompact` model with 3 inputs, `hidd_dims`
hidden units and 2 outputs.

`hidd_dims` is now forwarded to the LSTM; it was previously ignored and the hidden
width was always 16. The default of 16 keeps existing calls unchanged.
"""
function build_initstate_estimator(hidd_dims = 16)
	return LSTMCompact(3, hidd_dims, 2)
end

"""
    build_M50_NNs(hidd_dims = 16)

Build the two feed-forward networks of the M50 model and return them as `(etnn, qnn)`.

Both are three-layer `Lux.Chain`s with a `tanh` first layer, two `leakyrelu` layers,
`hidd_dims` hidden units and a single output. They differ only in input width:
3 for the ET network and 2 for the Q network.
"""
function build_M50_NNs(hidd_dims = 16)
	# Shared layer layout; only the number of inputs varies between the two networks.
	make_net(n_inputs) = Lux.Chain(
		Lux.Dense(n_inputs, hidd_dims, tanh),
		Lux.Dense(hidd_dims, hidd_dims, leakyrelu),
		Lux.Dense(hidd_dims, 1, leakyrelu),
	)

	#* input is norm_snowpack, norm_soilwater, norm_temp, output is evap
	et_net = make_net(3)
	#* input is norm_soilwater, norm_prcp
	q_net = make_net(2)

	return et_net, q_net
end

"""
    get_default_params_s(; hidd_dims = 16, rng = StableRNG(42))

Return the default parameters of the single-trajectory M50 model as a
`ComponentVector` with the groups `exphydro` (six fixed values), `et` and `q`
(freshly initialised network parameters drawn from `rng`).

The initial-state estimator is also built and initialised from `rng`, but its
parameters are not part of the returned vector.
"""
function get_default_params_s(; hidd_dims = 16, rng = StableRNG(42))
	estimator = build_initstate_estimator(hidd_dims)
	et_net, q_net = build_M50_NNs(hidd_dims)

	# Lux.setup returns (parameters, states); only the parameters are kept.
	# The three calls stay in this order because each one draws from `rng`.
	et_init = first(Lux.setup(rng, et_net))
	q_init = first(Lux.setup(rng, q_net))
	first(Lux.setup(rng, estimator))  # result unused; call kept so `rng` advances as before

	exphydro_defaults = [0.01674478, 1709.461015, 18.46996175, 2.674548848, 0.175739196, -2.092959084]
	return ComponentVector(exphydro = exphydro_defaults, et = et_init, q = q_init)
end

"""
    get_default_params_m(; hidd_dims = 16, rng = StableRNG(42))

Return the default parameters of the multi-trajectory M50 model as a
`ComponentVector` with the groups `exphydro` (six fixed values), `et`, `q` and `est`
(freshly initialised parameters of the ET network, the Q network and the
initial-state estimator, drawn from `rng` in that order).
"""
function get_default_params_m(; hidd_dims = 16, rng = StableRNG(42))
	estimator = build_initstate_estimator(hidd_dims)
	et_net, q_net = build_M50_NNs(hidd_dims)

	# Lux.setup returns (parameters, states); only the parameters are kept.
	# The three calls stay in this order because each one draws from `rng`.
	et_init = first(Lux.setup(rng, et_net))
	q_init = first(Lux.setup(rng, q_net))
	est_init = first(Lux.setup(rng, estimator))

	exphydro_defaults = [0.01674478, 1709.461015, 18.46996175, 2.674548848, 0.175739196, -2.092959084]
	return ComponentVector(
		exphydro = exphydro_defaults,
		et = et_init,
		q = q_init,
		est = est_init,
	)
end

# Build the M50 model functions.
#
# Arguments:
#   itpfuncs  - tuple of three callables of time, unpacked as (p_itp, t_itp, l_itp)
#   normfuncs - tuple of four normalisation functions, unpacked as
#               (norm_S0, norm_S1, norm_T, norm_P)
# Keywords:
#   hidd_dims           - hidden width passed to the network builders
#   rng                 - random number generator used to set up the estimator states
#   exphydro_params     - indexable parameter collection; entries 4, 5 and 6 are read
#                         inside the ODE right-hand side
#   exphydro_initstates - initial state used by the single-trajectory model
#   est_norms           - indexable pair used by the multi-trajectory model to rescale
#                         the estimator output
#
# Returns the tuple (m50_model_s, m50_model_m): the single-trajectory model and the
# ensemble (multi-trajectory) model. Both close over the arguments above.
function build_M50(itpfuncs, normfuncs;
	hidd_dims = 16, rng = StableRNG(42),
	exphydro_params = ComponentVector(),
	exphydro_initstates = [],
	est_norms = []
)
	p_itp, t_itp, l_itp = itpfuncs
	norm_S0, norm_S1, norm_T, norm_P = normfuncs
	states_estimator = build_initstate_estimator(hidd_dims)
	ann_ET, ann_Q = build_M50_NNs(hidd_dims)
	# Only the layer states of the estimator are kept; its parameters come from the caller.
	_, est_st = Lux.setup(rng, states_estimator)
	states_est_func(x, p) = Lux.apply(states_estimator, x, p, est_st)[1]
	etnn_func(x, p) = LuxCore.stateless_apply(ann_ET, x, p)
	qnn_func(x, p) = LuxCore.stateless_apply(ann_Q, x, p)

	# In-place ODE right-hand side for the two states S[1] and S[2].
	# `p` must expose the `:et` and `:q` parameter groups.
	function M50_ODE_core!(dS, S, p, t)
		# NOTE(review): indices 6, 5, 4 match the positions of :Tmin, :Tmax, :Df in the
		# EXPHYDRO_PARAMS list returned by M50_ATTR - confirm the caller keeps that order.
		@views Tmin, Tmax, Df = exphydro_params[6], exphydro_params[5], exphydro_params[4]

		# Forcings evaluated at the current solver time.
		Lday, P, T = l_itp(t), p_itp(t), t_itp(t)

		@views S0, S1 = S[1], S[2]
		norm_s0, norm_s1 = norm_S0(S0), norm_S1(S1)

		# Network outputs; they enter the balance through exp(...) below.
		g_ET = etnn_func([norm_s0, norm_s1, norm_T(T)], view(p, :et))
		g_Q = qnn_func([norm_s1, norm_P(P)], view(p, :q))

		melting = M.(S0, T, Df, Tmax)

		# S[1]: snow precipitation minus melt.
		dS[1] = Ps(P, T, Tmin) - melting
		# S[2]: rain plus melt, minus the two network-driven outflow terms.
		dS[2] = Pr(P, T, Tmin) + melting - step_fct(S1) * Lday * exp(g_ET[1]) - step_fct(S1) * exp(g_Q[1])
	end

	# Solve one trajectory from `initstates` over (timesteps[1], timesteps[end]) and
	# return the solution, saved every 1.0 time units, as an array.
	function solve_prob(initstates, params, timesteps)
		prob = ODEProblem(
			M50_ODE_core!,
			initstates,
			Float64.((timesteps[1], timesteps[end])),
			params,
		)
		sol = solve(
			prob,
			BS3(),
			saveat = 1.0,
			reltol = 1e-3,
			abstol = 1e-3,
			sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()),
		)
		Array(sol)
	end

	# Solve `batchsize` trajectories with EnsembleThreads. Trajectory i starts from
	# column i of `initstates` and spans (timesteps[i][1], maximum(timesteps[i])).
	function solve_ensemble_prob(initstates, params, timesteps, batchsize = 10)
		# Template problem built from the first trajectory; prob_func replaces u0 and tspan.
		prob = ODEProblem(
			M50_ODE_core!,
			initstates[:, 1],
			Float64.((timesteps[1][1], maximum(timesteps[1]))),
			params,
		)

		function prob_func(prob, i, repeat)
			remake(prob, u0 = initstates[:, i], tspan = Float64.((timesteps[i][1], maximum(timesteps[i]))))
		end

		ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
		sol = solve(ensemble_prob, BS3(), EnsembleThreads(),
			saveat = 1.0,
			reltol = 1e-3,
			abstol = 1e-3,
			trajectories = batchsize,
			sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()),
		)
		sol_arr = Array(sol)
		sol_arr
	end

	# Single-trajectory model: solve the ODE from `exphydro_initstates`, then pass the
	# normalised second state and the normalised first row of `input` through the Q
	# network and return exp of its first output row.
	# NOTE(review): the hcat below needs the saved solution and `input` to have the same
	# number of time points - confirm `timesteps` is spaced to match saveat = 1.0.
	function m50_model_s(input, params, timesteps)
		sol_arr = solve_prob(exphydro_initstates, params, timesteps)
		norm_prcp_vec = norm_P.(view(input, 1, :))
		norm_s1_vec = norm_S1.(view(sol_arr, 2, :))
		Qout_ = exp.(view(qnn_func(permutedims([norm_s1_vec norm_prcp_vec]), params[:q]), 1, :))
		return Qout_
	end

	# Multi-trajectory model: estimate initial states from `history_data` with the
	# estimator, rescale them with `est_norms`, and return the ensemble solution array.
	function m50_model_m(input, params, timesteps; history_data = ones(3, 20, 10))
		s_init_prop = states_est_func(history_data, params[:est])
		# NOTE(review): presumably est_norms = (lower, upper) bounds of the states - confirm.
		s_init_pred = s_init_prop * (est_norms[2] .- est_norms[1]) .+ est_norms[1]
		sol_arr = solve_ensemble_prob(s_init_pred, params, timesteps)
		return sol_arr
	end

	return m50_model_s, m50_model_m
end

"""
    train_NN(nn_func, data, params; loss_func, optmzr = ADAM(0.01), max_N_iter = 1000)

Fit `params` so that `nn_func(x, params)` matches `y`, where `(x, y) = data`.

The objective `loss_func(y, nn_func(x, u))` (mean squared error by default) is
differentiated with Zygote and minimised with `optmzr` for at most `max_N_iter`
iterations. Every callback invocation appends `(iter, loss, time)` to a record and
advances a progress bar.

Returns `(u, record)`: the parameters from the optimisation result and the record as
a `DataFrame`.
"""
function train_NN(nn_func, data, params; loss_func = (y, y_hat) -> mean((y .- y_hat) .^ 2), optmzr = ADAM(0.01), max_N_iter = 1000)
	inputs, targets = data
	meter = Progress(max_N_iter, desc = "Training...")
	history = []

	# Optimization.jl calls the objective as (u, p); the second argument is not used here.
	objective = (u, _) -> loss_func(targets, nn_func(inputs, u))

	# Record the loss and tick the progress bar; returning false keeps the optimiser running.
	log_step = function (state, l)
		push!(history, (iter = state.iter, loss = l, time = now()))
		next!(meter)
		return false
	end

	optprob = Optimization.OptimizationProblem(
		Optimization.OptimizationFunction(objective, Optimization.AutoZygote()),
		params,
	)
	result = Optimization.solve(optprob, optmzr, maxiters = max_N_iter, callback = log_step)
	return result.u, DataFrame(history)
end

"""
    batch_train_NN(nn_func, dataloaders, params; loss_func, optmzr = ADAM(0.01), max_N_iter = 1000)

Fit `params` on mini-batches, where `dataloaders = (train_dataloader, val_dataloader)`.

The objective for a batch `(x, y)` is `loss_func(y, nn_func(x, u))` (mean squared
error by default), differentiated with Zygote and minimised with `optmzr` for at most
`max_N_iter` iterations.

After each optimiser step the callback sums the loss over all validation batches.
A counter starting at -1 is incremented whenever that sum exceeds the largest value
seen so far (which always happens on the first callback); once the counter exceeds 10
the callback requests a stop. Otherwise `(iter, train_loss, val_loss, time)` is
appended to the record and the progress bar advances.

Returns `(u, record)`: the parameters from the optimisation result and the record as
a `DataFrame`.
"""
function batch_train_NN(nn_func, dataloaders, params; loss_func = (y, y_hat) -> mean((y .- y_hat) .^ 2), optmzr = ADAM(0.01), max_N_iter = 1000)
	train_loader, val_loader = dataloaders
	meter = Progress(max_N_iter, desc = "Training...")
	history = []

	# Loss of the model with parameters `u` on a single (x, y) batch.
	function batch_loss(u, batch)
		x, y = batch
		return loss_func(y, nn_func(x, u))
	end

	# State shared between callback invocations.
	record_high_count = -1
	highest_val_loss = -Inf

	function on_step(state, l)
		# Despite being stored under `val_loss`, this is a sum over batches, not a mean.
		val_total = sum(map(batch -> batch_loss(state.u, batch), val_loader))

		if val_total > highest_val_loss
			record_high_count += 1
			highest_val_loss = val_total
		end
		# Returning true tells the optimiser to stop.
		record_high_count > 10 && return true

		push!(history, (iter = state.iter, train_loss = l, val_loss = val_total, time = now()))
		next!(meter)
		return false
	end

	optprob = Optimization.OptimizationProblem(
		Optimization.OptimizationFunction(batch_loss, Optimization.AutoZygote()),
		params,
		train_loader,
	)
	result = Optimization.solve(optprob, optmzr, maxiters = max_N_iter, callback = on_step)
	return result.u, DataFrame(history)
end

"""
    batch_train_M50(nn_func, dataloaders, params; loss_func, optmzr = ADAM(0.01), max_N_iter = 1000)

Mini-batch training entry point for the M50 model.

The training procedure is the same as `batch_train_NN`, so this function forwards all
arguments to it instead of keeping a second copy of the loop. `dataloaders` is
`(train_dataloader, val_dataloader)`; the default `loss_func` is mean squared error.

Returns `(u, record)`: the parameters from the optimisation result and a `DataFrame`
with the columns `iter`, `train_loss`, `val_loss` and `time`.
"""
function batch_train_M50(nn_func, dataloaders, params; loss_func = (y, y_hat) -> mean((y .- y_hat) .^ 2), optmzr = ADAM(0.01), max_N_iter = 1000)
	return batch_train_NN(
		nn_func, dataloaders, params;
		loss_func = loss_func, optmzr = optmzr, max_N_iter = max_N_iter,
	)
end

I’m not sure whether the step size is the same each time the ODE is solved, but I did keep the same reltol and abstol settings.

That’s not what matters; what matters is the number of steps. Check sol.stats.