I don’t know how to describe the problem in my project, but I got the following two simple examples after a series of tests.
First of all, the first test looks like this:
#test1.jl
using Integrals
using SciMLSensitivity
using Zygote
using SphericalHarmonics
function test1(t::Float64,D::Int64)
buf = zeros(D,D,D,D)
T = Zygote.Buffer(buf, D,D,D,D)
T[:,:,:,:]=buf[:,:,:,:]
for i in 1:D
for j in 1:D
for k in 1:D
for l in 1:D
l1,m1 = identity(i),0
l2,m2 = identity(j),0
l3,m3 = identity(k),0
l4,m4 = identity(l),0
T[i,j,k,l] = 2*pi * solve(IntegralProblem((x,t)->SphericalHarmonics.sphericalharmonic(x, 0, l1,m1)*SphericalHarmonics.sphericalharmonic(x, 0, l2,m2)*SphericalHarmonics.sphericalharmonic(x, 0, l3,m3)*SphericalHarmonics.sphericalharmonic(x, 0, l4,m4)*sin(x)*exp(t*cos(x)),0, pi,t),QuadGKJL(),abstol=1e-16)[1]
end
end
end
end
T=copy(T)
return sum(T)
end
function main(D::Int64)
t0 = 1.5
f(t) = test1(t,D)
@time g = gradient(t -> f(t), t0)[1]
return nothing
end
@time main(64)
This will not throw an error, but when I run test2.jl , it will throw an error in an hour or so:
#test2.jl
using Integrals
using SciMLSensitivity
using Zygote
using SphericalHarmonics
function test2(t::Float64,D::Int64)
buf = zeros(D,D,D,D)
T = Zygote.Buffer(buf, D,D,D,D)
T[:,:,:,:]=buf[:,:,:,:]
for i in 1:D
for j in 1:D
for k in 1:D
for l in 1:D
l1,m1 = floor(Int64,sqrt(i-1)),0
l2,m2 = floor(Int64,sqrt(j-1)),0
l3,m3 = floor(Int64,sqrt(k-1)),0
l4,m4 = floor(Int64,sqrt(l-1)),0
T[i,j,k,l] = 2*pi * solve(IntegralProblem((x,t)->SphericalHarmonics.sphericalharmonic(x, 0, l1,m1)*SphericalHarmonics.sphericalharmonic(x, 0, l2,m2)*SphericalHarmonics.sphericalharmonic(x, 0, l3,m3)*SphericalHarmonics.sphericalharmonic(x, 0, l4,m4)*sin(x)*exp(t*cos(x)),0, pi,t),QuadGKJL(),abstol=1e-16)[1]
end
end
end
end
T=copy(T)
return sum(T)
end
function main(D::Int64)
t0 = 1.5
f(t) = test2(t,D)
@time g = gradient(t -> f(t), t0)[1]
return nothing
end
@time main(64)
throw:
=>> PBS: job killed: vmem 58051637248 exceeded limit 57982058496
I’ve been running on multiple HPC platforms throwing similar out-of-memory errors.
There are only four lines of code between test1 and test2:
l1,m1 = identity(i),0 l1,m1 = floor(Int64,sqrt(i-1)),0
l2,m2 = identity(j),0 l2,m2 = floor(Int64,sqrt(j-1)),0
l3,m3 = identity(k),0 l3,m3 = floor(Int64,sqrt(k-1)),0
l4,m4 = identity(l),0 l4,m4 = floor(Int64,sqrt(l-1)),0
I don’t know what the problem is. I think the memory cost of test1 and test2 should be about the same. Do you have any good suggestions?