Hi all, I posted here earlier about minimising GPU memory usage when modelling a system. Since then I’ve tidied up my code a bit and managed to reduce memory usage, but it still runs out of memory when I try to scan over a range of parameters.
I’m setting ‘save_on = false’, which according to the documentation should stop the solver from saving intermediate states, and using a callback to copy the solution to the CPU at certain points in time. My understanding is that if there is enough memory to get through a few timesteps, running for an arbitrarily long time shouldn’t need any extra memory, since no intermediate timesteps are being saved.
What actually happens is that for large grid sizes it runs fine for a long time (100-200 seconds of simulated time), and then runs out of memory and throws an error.
I’m basically wondering if there is another way to ensure that the solver isn’t saving any intermediate information, or if there’s a way to see what is taking up all of this memory, so that I could potentially use a callback to manually clear the GPU memory during solving?
using FFTW, CUDA, DifferentialEquations, LinearAlgebra, Plots

# Apply the k^2 operator in Fourier space: dψ = ifft(k2 .* fft(ψ))
function kfunc_opt!(dψ,ψ)
    mul!(dψ,Pf,ψ)
    dψ .*= k2
    Pi!*dψ
    return nothing
end

function GPE!(dψ,ψ,var,t) # GPE Equation
    kfunc_opt!(dψ,ψ)
    @. dψ = -(im + γ)*(0.5*dψ + (V_0 + abs2(ψ) - 1)*ψ)
end

function GPU_Solve(save_array,EQ!, ψ, tspan)
    savepoints = tspan[2:end]
    condition(u, t, integrator) = t ∈ savepoints
    function affect!(integrator)
        push!(save_array, Array(integrator.u)) # copy the current state to the CPU
    end
    push!(save_array, Array(ψ)) # store the initial state
    cb = DiscreteCallback(condition, affect!)
    i = 1
    prob = ODEProblem(EQ!,ψ,(tspan[1],tspan[end]))
    solve(prob, callback=cb, tstops = savepoints, save_on=false)
end

L = 8
M = 60
x = LinRange(-L,L,M) |> cu; dx = x[2] - x[1]
kx = fftfreq(M,2π/dx) |> collect |> cu; dkx = kx[2] - kx[1]
k2 = kx.^2 .+ kx'.^2 .+ reshape(kx,(1,1,M)).^2;
V_0 = 0.3*[i^2 + j^2 + k^2 for i in x, j in x, k in x] |> cu;
const Pf = Float32(dx^3/(2π)^1.5)*plan_fft((cu(rand(M,M,M) + im*rand(M,M,M))));
const Pi! = Float32(M^3*dkx^3/(2π)^1.5)*plan_ifft!((cu(rand(M,M,M) + im*rand(M,M,M))));
γ = 0.05
tspan = LinRange(0.0,10,50);
CUDA.memory_status()

res_GS = []
GPU_Solve(res_GS,GPE!,(cu(randn(M,M,M) + im*randn(M,M,M))),tspan);

begin
    t = 3 # Change this to look at different times
    heatmap(abs2.(res_GS[t][:,:,30])) |> display
end
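To make the "callback that clears GPU memory" idea concrete, this is roughly what I have in mind. It is only a sketch and I don't know whether it addresses the real cause: `memory_cleanup_callback` and its `every` and `verbose` keywords are names I made up, and the only CUDA.jl calls it relies on are `CUDA.functional()`, `CUDA.memory_status()` and `CUDA.reclaim()`. I also set `save_positions = (false, false)` on it, because as far as I can tell from the callback documentation the default is `(true, true)`, and I'm not sure whether `save_on = false` covers the saves a callback triggers.

```julia
using CUDA, DifferentialEquations

# Hypothetical helper (not part of CUDA.jl or DifferentialEquations.jl):
# every `every` accepted steps, report GPU memory use, run Julia's garbage
# collector, and ask CUDA.jl to release memory cached in its pool.
function memory_cleanup_callback(; every = 100, verbose = true)
    nsteps = Ref(0)  # counts how often the solver has checked the condition
    condition(u, t, integrator) = (nsteps[] += 1) % every == 0
    function affect!(integrator)
        if CUDA.functional()               # skip the CUDA calls when no GPU is usable
            verbose && CUDA.memory_status()
            GC.gc()                        # collect CuArrays that are no longer referenced
            CUDA.reclaim()                 # hand cached pool memory back to the driver
        else
            GC.gc()
        end
        u_modified!(integrator, false)     # the state itself was not changed
    end
    # Assumption: saving around a callback is governed by `save_positions`,
    # so it is switched off explicitly here.
    return DiscreteCallback(condition, affect!; save_positions = (false, false))
end

cb_mem = memory_cleanup_callback(every = 200)
```

It would be combined with the saving callback inside `GPU_Solve` via `CallbackSet(cb, cb_mem)` and passed as `callback = ...` to `solve`. A full `GC.gc()` is slow, which is why it only runs every `every` steps. If the reported usage keeps climbing even with this in place, I'd take that as a hint that something is still holding references to GPU arrays rather than the pool just caching memory.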