Construct discrete Bloch sum using GPU

I’m taking my first plunge into the world of GPU computing. What I’m trying to do is construct a Bloch sum of a wavefunction that is defined over 3 unit cells. The idea is that, for each point in space, I take the original wavefunction (points in the following code), shift it to the neighbouring unit cells (since it is considered periodic), multiply the shifted function by the correct phase factor, and add the result to the original one.

Of course I’m not actually shifting the wavefunction; I’m finding the indices of the overlapping sections. For each shift, I find the starting indices of the overlapping sections for both the original and the shifted wavefunction (done in the find_start function). Then I run through all the points until one of the two indices goes out of bounds.
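
As a minimal sketch of that index bookkeeping in one dimension, assuming the shift is a whole number of grid points `s` (the helper name `overlap_ranges` is hypothetical and not part of the code further down):

```julia
# Overlap of a length-n grid with a copy of itself shifted by s grid points.
# Returns (dest, src): out[dest] receives points[src].
# Both ranges are empty when abs(s) >= n.
function overlap_ranges(n::Int, s::Int)
    dest = max(1, 1 + s):min(n, n + s)
    src  = max(1, 1 - s):min(n, n - s)
    return dest, src
end

# Shift a 9-point grid to the right by one cell of 3 points.
dest, src = overlap_ranges(9, 3)
```

Both ranges always have the same length, so walking them in lockstep is equivalent to incrementing two indices until one of them goes out of bounds.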

I feel like this would not be a bad thing to do on a GPU; however, I can’t seem to figure out the best way to do it. Could anyone point me to good resources for implementing something like this, or give me an idea of how to do it?

Relevant code:

using LinearAlgebra  # dot and norm live here on Julia 0.7 and later

struct Point3D{T<:AbstractFloat}
  x::T
  y::T
  z::T
end

# `wfc` is the wavefunction object the body refers to (its type is not shown here);
# `points` holds its complex values on the grid. Arithmetic on Point3D is assumed
# to be defined elsewhere.
function construct_bloch_sum(wfc, points::Array{Complex{T},3}, k::Vector{T}) where T<:AbstractFloat
  out = zeros(Complex{T},size(points))
  dim_a, dim_b, dim_c = size(points)
  for R1=-1:1,R2=-1:1,R3=-1:1
    # lattice vector of the neighbouring cell and its Bloch phase factor
    R = R1*wfc.cell[1]+R2*wfc.cell[2]+R3*wfc.cell[3]
    c = exp(-2im*pi*dot(k,[R1,R2,R3]))
    # starting indices of the overlap in `out` (ind1) and in `points` (ind2)
    ind1,ind2 = find_start(wfc,R,div(dim_a,3))
    i3 = ind1[3]
    j3 = ind2[3]
    while i3 <= dim_c && j3 <= dim_c
      i2 = ind1[2]
      j2 = ind2[2]
      while i2 <= dim_b && j2 <= dim_b
        i1 = ind1[1]
        j1 = ind2[1]
        while i1 <= dim_a && j1 <= dim_a
          out[i1,i2,i3] += c*points[j1,j2,j3]
          i1 += 1
          j1 += 1
        end
        i2 += 1
        j2 += 1
      end
      i3 += 1
      j3 += 1
    end
  end
  return out
end

function find_start(points,R,partitions)::Tuple{NTuple{3,Int64},NTuple{3,Int64}}
  part_1D = round(Int,cbrt(partitions))  # partitions^(1/3) is not exact in floating point
  dim_a, dim_b, dim_c = size(points)
  stride_a = div(dim_a,part_1D)
  stride_b = div(dim_b,part_1D)
  stride_c = div(dim_c,part_1D)
  # one anchor position per partition, and the same anchors shifted by R
  anchors = [points[a,b,c].p for a=1:stride_a:dim_a,b=1:stride_b:dim_b,c=1:stride_c:dim_c]
  shifted_anchors = [p-R for p in anchors]
  # look for an anchor that coincides with a shifted anchor
  for i in CartesianIndices(anchors), j in CartesianIndices(shifted_anchors)
    if norm(anchors[i]-shifted_anchors[j])<0.00001
      tmp1 = Tuple(i)  # replaces ind2sub, which was removed in Julia 1.0
      tmp2 = Tuple(j)
      ind1 = ((tmp1[1]-1)*stride_a+1,(tmp1[2]-1)*stride_b+1,(tmp1[3]-1)*stride_c+1)
      ind2 = ((tmp2[1]-1)*stride_a+1,(tmp2[2]-1)*stride_b+1,(tmp2[3]-1)*stride_c+1)
      return ind1,ind2
    end
  end
  error("no overlapping anchor found for shift $R")
end
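
Because each shift copies one contiguous block, the nested while loops can also be written as a single broadcast over array views per shift. The following is only a sketch under stated assumptions: the grid spans 3 unit cells along each axis, every shift is a whole number of grid points, and `points` already holds plain complex values. The names `overlap` and `bloch_sum_blocks` are hypothetical.

```julia
# Hypothetical helper: overlap of a length-n axis with a copy shifted by s points.
# Returns (dest, src) so that out[dest] accumulates points[src].
overlap(n, s) = (max(1, 1 + s):min(n, n + s), max(1, 1 - s):min(n, n - s))

function bloch_sum_blocks(points::AbstractArray{<:Complex,3}, k)
    out = fill!(similar(points), 0)
    na, nb, nc = size(points)
    # Assumption: the grid spans 3 unit cells along each axis.
    ca, cb, cc = na ÷ 3, nb ÷ 3, nc ÷ 3
    for R1 in -1:1, R2 in -1:1, R3 in -1:1
        # Bloch phase factor for this shift, with k in fractional coordinates
        c = convert(eltype(points), cis(-2pi * (k[1]*R1 + k[2]*R2 + k[3]*R3)))
        d1, s1 = overlap(na, R1 * ca)
        d2, s2 = overlap(nb, R2 * cb)
        d3, s3 = overlap(nc, R3 * cc)
        # One fused broadcast per shift instead of three nested index loops
        @views out[d1, d2, d3] .+= c .* points[s1, s2, s3]
    end
    return out
end
```

Since this form relies only on `similar`, `fill!`, views, and broadcasting, it should in principle carry over to GPU array types that support those operations (for example a `CuArray` from CUDA.jl), although that is untested here, and 27 small broadcasts may not outperform the CPU on small grids.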

Any help at all is much appreciated!

Hi Louis,

Not sure what to do with the GPUs, but your code looks like a slow Fourier transform to me. If I’m not mistaken, you can figure out the Fourier coefficients of your WF easily from your input data. Then you just do a global FFT. I use this in some code for Wannier functions, although your setup might be different?
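
One possible reading of this, sketched in 1D and not necessarily the approach used in the code mentioned above: if the wavefunction is treated as periodic over a supercell of `N` cells, the Bloch sums for all k = m/N are a discrete Fourier transform along the cell index. Note that the periodic wrap-around is an assumption that differs from the truncation at the array bounds in the posted loop. The function names are hypothetical, and the DFT is written as a plain matrix product so that no FFT package is needed.

```julia
# w: values on a 1D supercell of N cells with n points each, assumed periodic
# over the supercell. Returns psi[x, m+1], the Bloch sum at k = m/N, m = 0..N-1.
function bloch_sums_dft(w::AbstractVector{<:Complex}, N::Int)
    n = length(w) ÷ N
    W = reshape(w, n, N)                 # W[x, c+1]: point x of cell c
    phases = [cis(2pi * m * c / N) for c in 0:N-1, m in 0:N-1]
    return W * phases                    # unnormalized DFT along the cell axis
end

# Reference: explicit sum over shifts R with periodic wrap-around,
# psi(x) = sum_R exp(-2*pi*i*m*R/N) * w(x - R*n).
function bloch_sum_direct(w, N, m)
    n = length(w) ÷ N
    L = length(w)
    return [sum(cis(-2pi * m * R / N) * w[mod(x - R*n - 1, L) + 1] for R in 0:N-1)
            for x in 1:n]
end

w = [complex(sin(j), cos(2j)) for j in 1:12]   # arbitrary test data, 3 cells of 4 points
psi = bloch_sums_dft(w, 3)
```

The matrix product uses the positive-exponent convention, so with an FFT library it would correspond to an unnormalized backward transform along the second dimension (in FFTW.jl that should be `bfft(W, 2)`).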

Yes, this is pretty much exactly what I’m trying to do. I didn’t know how to do a global Fourier transform because, although figuring out the components is easy, how do I figure out which parts of the wavefunction overlap with which? Could you point me to the code in which you’re using the global FFT?

Will contact you in PM (your original question is still valid)