DifferentialEquations.jl EnsembleThreads: crash w/increase in problem size

DifferentialEquations.jl has facilities for running the solver in parallel over different parameters; see the docs. In further testing of prior code, I added the option to run parallel threads using EnsembleThreads(), as documented on the above docs page. The following code crashes as the size of the problem increases, for reasons that I cannot find.

To run the code:

import FrenchFlagThreads as fft

Of the following commands, the first four work and the next two crash; the first lines of each crash report are shown. Note also that in the REPL with Revise, one of the first four commands sometimes fails the first time but then works when rerun immediately, which suggests some sort of compilation issue as an additional problem.

fft.main(S=3, B=256, threads=false, ensemble=false)
fft.main(S=3, B=256, threads=false, ensemble=true)
fft.main(S=3, B=256, threads=true, ensemble=true)
fft.main(S=7, B=256, threads=false, ensemble=false)

Those first four work (apart from the need to rerun them sometimes because of the REPL or compilation issue).

fft.main(S=7, B=256, threads=false, ensemble=true)

fails w/output:

0.25135885013308756
0.2627287236602456
0.2510416786361418
0.25157018945790177
ERROR: BoundsError: attempt to access 13-element Vector{Vector{Float64}} at index [14]
Stacktrace:
  [1] getindex
    @ ./array.jl:805 [inlined]
  [2] (::DiffEqSensitivity.var"#253#259"{Vector{Matrix{Float64}}, Vector{Vector{Float64}}, Vector{Float64}})(i::Int64)
    @ DiffEqSensitivity /opt/julia/packages/DiffEqSensitivity/uakCr/src/concrete_solve.jl:527
  [3] _mapreduce(f::DiffEqSensitivity.var"#253#259"{Vector{Matrix{Float64}}, Vector{Vector{Float64}}, Vector{Float64}}, op::typeof(Base.add_sum), #unused#::IndexLinear, A::Base.OneTo{Int64})
    @ Base ./reduce.jl:411


fft.main(S=7, B=256, threads=true, ensemble=true)

fails with output:

0.25135885013308756
0.2627287236602456
0.2510416786361418
0.25157018945790177
ERROR: TaskFailedException
Stacktrace:
  [1] wait
    @ ./task.jl:322 [inlined]
  [2] threading_run(func::Function)
    @ Base.Threads ./threadingconstructs.jl:34
  [3] macro expansion
    @ ./threadingconstructs.jl:93 [inlined]
  [4] tmap(::Function, ::Vector{typeof(∂(λ))}, ::Vararg{Any, N} where N)
    @ SciMLBase /opt/julia/packages/SciMLBase/x3z0g/src/ensemble/basic_ensemble_solve.jl:220

Here is the code:

module FrenchFlagThreads

using Random: seed!
using RandomNumbers.Xorshifts
using DifferentialEquations, DiffEqFlux
using DiffEqSensitivity, OrdinaryDiffEq, ForwardDiff
using Statistics

# Module-level random number generator used by newbatch() and main().
# Declared `const` so the binding has a fixed type; the generator's internal
# state is still mutated in place by seed!/rand/randn.
const r = Xoroshiro128Plus(0x1234567890abcdef)

const final_t		= 3.0			# final time
const tspan			= (0.0,final_t)	# range of times
# Solver tolerances are read by loss() on every solve; `const` keeps them from
# being untyped globals and makes them consistent with final_t/tspan above.
const atol			= 1e-6			# abs tol for diffeq
const rtol			= 1e-4			# rel tol for diffeq
####################################################################

"""
    phi(x)

Logistic sigmoid `1 / (1 + exp(-x))`, applied elementwise when `x` is an array.
"""
function phi(x)
	return @. 1.0 / (exp(-x) + 1.0)
end

"""
    ode!(ds, s, w, t, input)

In-place right-hand side: overwrite `ds` with `phi(w*s) - s + input`.
The time argument `t` is accepted but not used. Returns `ds`.
"""
function ode!(ds, s, w, t, input)
	activation = phi(w * s)
	ds .= activation .- s .+ input
	return ds
end

"""
    newbatch(B, S; stochastic=false)

Build one training batch of `B` elements with `S` state variables each.

Returns `(input, initial, output)`:
- `input`: `B`×`S` matrix whose first column holds the per-element input value and
  whose remaining columns are zero.
- `initial`: `B`×`S` matrix with every entry equal to 0.1.
- `output`: length-`B` vector, 1.0 where the input value lies in `[0.5, 1.5]`, else 0.0.

With `stochastic=true` the input values are drawn as `2 * rand` from the module-level
generator `r`; otherwise they are `B` evenly spaced values from 0 to 2.
"""
function newbatch(B, S; stochastic=false)
	initial = fill(0.1, B, S)
	# materialize the range so both branches yield a plain Vector
	inputvec = stochastic ? 2.0 .* rand(r, B) : collect(range(0, 2, length=B))
	# inputs are [B,S]: column 1 carries the input value, the other columns are zero
	padding = zeros(B, S - 1)
	input = hcat(inputvec, padding)
	# Indexed comprehension on purpose: the original author noted that the equivalent
	# `map` over the values fails because of a bug in autodiff.
	output = [(0.5 <= inputvec[k] <= 1.5) ? 1.0 : 0.0 for k in 1:B]
	return input, initial, output
end

"""
    loss(p, B, S, stoch_batch; threads=false, use_ensemble=false)

Mean squared error between the scaled final state of the last species and the target
output, over a batch of `B` ODE solves.

A batch is built with `newbatch(B, S, stochastic=stoch_batch)`. One ODE per batch row
is integrated over `tspan` with `Tsit5()`, using weights `w(p, S)` and that row's input
and initial state. With `use_ensemble=true` the solves go through an `EnsembleProblem`
(`EnsembleThreads()` if `threads` is true, otherwise `EnsembleSerial()`); otherwise
they run in a plain loop, and `threads` is not consulted.
"""
function loss(p, B, S, stoch_batch; threads=false, use_ensemble=false)
	input, initial, output = newbatch(B, S, stochastic=stoch_batch)
	# Accumulator for the final states; one column per batch element because Julia is
	# column major. Declared at function level since it is assigned inside `for` loops.
	local endstates
	if use_ensemble
		template = ODEProblem((ds, s, p, t) -> ode!(ds, s, w(p,S), t, input[1,:]),
							initial[1,:], tspan, p)
		# swap in the input row and initial state belonging to trajectory i
		function nth_problem(prob, i, repeat)
			return remake(prob, f = (ds, s, p, t) -> ode!(ds, s, w(p,S), t, input[i,:]),
							u0 = initial[i,:])
		end

		batch_prob = EnsembleProblem(template, prob_func = nth_problem)
		if threads
			ensemble_alg = EnsembleThreads()
		else
			ensemble_alg = EnsembleSerial()
		end
		sols = solve(batch_prob, Tsit5(), ensemble_alg, saveon = false, trajectories = B,
					reltol=rtol,abstol=atol)
		# sols has length B; the final-time state of each solution is at [end]
		endstates = sols[1][end]
		for k in 2:B
			endstates = hcat(endstates, sols[k][end])
		end
	else
		for k in 1:B
			single = ODEProblem((ds, s, p, t) -> ode!(ds, s, w(p,S), t, input[k,:]),
							initial[k,:], tspan, p)
			sol = solve(single, Tsit5(), saveon = false, reltol=rtol,abstol=atol)
			if k == 1
				endstates = sol[end]
			else
				endstates = hcat(endstates, sol[end])
			end
		end
	end
	# The last row holds the last species of every batch element; treat it as the
	# phenotype and compare it, scaled by a(p), with the expected output.
	phenotype = endstates[end,:]
	return mean(abs2.(a(p) .* phenotype .- output))
end

"""
    callback(p, l)

Display the current loss value `l` and return `false`. The parameter vector `p` is
accepted but not used.
"""
callback(p, l) = (display(l); false)

"""
    w(p, S)

Weight matrix: every entry of `p` except the last, reshaped to `S`×`S`.
"""
function w(p, S)
	weights = p[1:end-1]
	return reshape(weights, S, S)
end
"""
    a(p)

Output scale factor: the last entry of the parameter vector `p`.
"""
function a(p)
	return p[end]
end
	
# parameters p:
# a = p[end] = 1.0
# w = reshape(p[1:end-1],S,S)
"""
    main(; B=4, S=3, stoch_batch=false, rnd_entropy=false, threads=false, ensemble=false)

Train the parameter vector with `DiffEqFlux.sciml_train` and return the fitted weight
matrix `w(p, S)`.

# Keywords
- `B`: batch size passed to `loss`.
- `S`: number of state variables; the parameter vector has `S*S + 1` entries.
- `stoch_batch`: forwarded to `loss` to choose random or evenly spaced inputs.
- `rnd_entropy`: if true, reseed the module RNG `r` with `seed!(r)`; otherwise use a
  fixed seed.
- `threads`, `ensemble`: forwarded to `loss` as `threads` and `use_ensemble`.

The parameters start as `S*S` weights drawn as `0.1 * randn` plus a final entry of 1.0,
and are optimized with `ADAM(0.1)` for at most 50 iterations. The optimizer result and
the loss recomputed at the final parameters are printed.
"""
function main(;B::Int=4, S::Int=3, stoch_batch::Bool=false,
			rnd_entropy::Bool=false, threads::Bool=false, ensemble::Bool=false)
	rnd_entropy ? seed!(r) : seed!(r, 0x1234567437ab88ef)
	p = ones(S*S+1)
	p[1:end-1] = 0.1 .* randn(r, S*S)
	# try without ADAM arg, should then be ADAM -> BFGS by default
	# or do sequence with ADAM then BFGS
	opt_result = DiffEqFlux.sciml_train(p -> loss(p, B, S, stoch_batch;
			 threads=threads, use_ensemble=ensemble),
			 p, ADAM(0.1); cb = callback, maxiters = 50)
	finalp = opt_result.u

	println()
	print(opt_result)
	println()
	print("Final param = ", typeof(finalp))
	println()
	# Recompute the loss with the same solver configuration used during training.
	# The original call omitted use_ensemble, so it always took the non-ensemble path
	# and the threads flag had no effect. The [1] is kept so this still works if loss
	# is later changed to return extra values.
	final_loss = loss(finalp, B, S, stoch_batch;
				threads=threads, use_ensemble=ensemble)[1]
	println("Loss calculated from final param = ", final_loss)
	return w(finalp,S)
end
main(S=3, B=256, threads=false, ensemble=true)

pretty reliably fails for me on 1.7.0-rc3 with

ERROR: LoadError: BoundsError: attempt to access 12-element Vector{Vector{Float64}} at index [13]
main(S=7, B=256, threads=false, ensemble=false)

seems to work.

Open an issue.

1 Like

Posted as issue in GitHub: SciML/DifferentialEquations.jl, link here