Thanks Tim, yeah, that’s a pretty important thing to have missed. Just benchmarking memory throughput I guess.
I actually used that toy kernel to test a customization I had made to the @cuda macro to support selecting 32-bit indexing. I wanted to verify that the custom switch was overridden, so that 64-bit indexing occurred whenever the number of elements exceeded 2^32.
I used it as an example because the generated code is tiny.
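Roughly, the switch logic I was testing looks like this (a hand-written sketch with invented names, not the actual macro change):

```julia
# Hypothetical sketch: prefer Int32 indexing when the user asks for it,
# but force Int64 whenever the element count could overflow an Int32.
index_type(A; want32::Bool=true) =
    (want32 && length(A) <= typemax(Int32)) ? Int32 : Int64
```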
A better example may be @omlins’ diffusion3D example (slightly amended):
```julia
function diffusion3D_step!(T2, T, Ci, lam, dt, _dx, _dy, _dz)
    ix = (blockIdx().x-1i32) * blockDim().x + threadIdx().x
    iy = (blockIdx().y-1i32) * blockDim().y + threadIdx().y
    T_ix_iy_izm1 = 0.0f0            # z-1 slice value, carried in registers
    T_ix_iy_iz   = 0.0f0            # current slice value
    T_ix_iy_izp1 = T[ix,iy,1i32]    # z+1 slice value
    nx, ny, nz = size(T2)
    for iz = 1i32:nz
        # rotate the three z-slice registers instead of re-reading T
        T_ix_iy_izm1 = T_ix_iy_iz
        T_ix_iy_iz   = T_ix_iy_izp1
        T_ix_iy_izp1 = iz<nz ? T[ix,iy,iz+1i32] : 0.0f0
        if (ix>1i32 && ix<nx && iy>1i32 && iy<ny && iz>1i32 && iz<nz)
            T2[ix,iy,iz] = T_ix_iy_iz + dt*(Ci[ix,iy,iz]*(
                - ((-lam*(T[ix+1i32,iy,iz] - T_ix_iy_iz)*_dx) - (-lam*(T_ix_iy_iz - T[ix-1i32,iy,iz])*_dx))*_dx
                - ((-lam*(T[ix,iy+1i32,iz] - T_ix_iy_iz)*_dy) - (-lam*(T_ix_iy_iz - T[ix,iy-1i32,iz])*_dy))*_dy
                - ((-lam*(T_ix_iy_izp1 - T_ix_iy_iz)*_dz) - (-lam*(T_ix_iy_iz - T_ix_iy_izm1)*_dz))*_dz
                ))
        end
    end
    return
end
```
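For completeness, here is roughly how it gets launched (a sketch; sizes, thread counts and parameter values are illustrative, not my benchmark configuration):

```julia
using CUDA

nx = ny = nz = 512
T   = CUDA.rand(Float32, nx, ny, nz)
T2  = similar(T)
Ci  = CUDA.rand(Float32, nx, ny, nz)
lam, dt = 1.0f0, 1.0f-4
_dx = _dy = _dz = Float32(nx - 1)   # inverse grid spacings

# 2D grid over (x, y); each thread marches over z inside the kernel.
threads = (32, 8)
blocks  = (cld(nx, threads[1]), cld(ny, threads[2]))
@cuda threads=threads blocks=blocks diffusion3D_step!(T2, T, Ci, lam, dt, _dx, _dy, _dz)
```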
So, at least we have array indexing varying over the z axis.
This is also going to be memory bound. Nsight Compute shows something like 12-14% compute throughput and a possible 14% improvement if, e.g., stalls can be reduced. OK, we could try to improve the code, but isn’t the question whether or not 32-bit indexing can improve performance by itself?
Any 32-bit indexing example is, by definition, going to access arrays. In the example above, the T array is accessed by the 32-bit indices ix, iy and iz.
We have a flat 64-bit address space, so all accesses will be at address `&T + n*(ix + (iy + iz*ny)*nx + c)`, where n is the element size in bytes and c is a constant.
In the above example we access flattened arrays with offsets:
- `n*(ix + (iy + iz*ny)*nx)`
- `n*(ix + (iy + iz*ny)*nx - 1)`
- `n*(ix + (iy + iz*ny)*nx + 1)`
The compute requirement for this seems minimal and is likely swamped by the latency of global array accesses.
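To make the constant concrete: for 1-based, column-major Julia arrays, c = -(1 + nx + nx*ny). A quick CPU sanity check of the formula (my own check, not from the kernel):

```julia
# The flattened offset must agree with Julia's built-in linear indexing.
nx, ny, nz = 4, 5, 6
A = rand(Float32, nx, ny, nz)
ix, iy, iz = 2, 3, 4
c = -(1 + nx + nx*ny)
@assert A[ix, iy, iz] == A[ix + (iy + iz*ny)*nx + c + 1]  # +1 converts the offset back to a 1-based index
```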
If I use 32-bit indexing for this example, the LLVM IR shows 32-bit operations throughout:
```llvm
define ptx_kernel void @_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0_({ i64, i32 } %state, { i8 addrspace(1)*, i64, [3 x i32], i64 } %0, { i8 addrspace(1)*, i64, [3 x i32], i64 } %1, { i8 addrspace(1)*, i64, [3 x i32], i64 } %2, float %3, float %4, float %5, float %6, float %7) local_unnamed_addr {
conversion:
%.fca.2.0.extract37 = extractvalue { i8 addrspace(1)*, i64, [3 x i32], i64 } %0, 2, 0
%.fca.2.1.extract38 = extractvalue { i8 addrspace(1)*, i64, [3 x i32], i64 } %0, 2, 1
%.fca.2.2.extract39 = extractvalue { i8 addrspace(1)*, i64, [3 x i32], i64 } %0, 2, 2
%.fca.0.extract11 = extractvalue { i8 addrspace(1)*, i64, [3 x i32], i64 } %1, 0
%.fca.2.0.extract13 = extractvalue { i8 addrspace(1)*, i64, [3 x i32], i64 } %1, 2, 0
%.fca.2.1.extract14 = extractvalue { i8 addrspace(1)*, i64, [3 x i32], i64 } %1, 2, 1
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [3 x i32], i64 } %2, 2, 0
%.fca.2.1.extract = extractvalue { i8 addrspace(1)*, i64, [3 x i32], i64 } %2, 2, 1
%8 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%9 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%10 = mul i32 %8, %9
%11 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%12 = add i32 %10, %11
%13 = add i32 %12, 1
%14 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
%15 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
%16 = mul nuw nsw i32 %14, %15
%17 = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
%18 = add nuw nsw i32 %16, %17
%19 = add nuw nsw i32 %18, 1
%20 = bitcast i8 addrspace(1)* %.fca.0.extract11 to float addrspace(1)*
%value_phi = call i32 @llvm.smax.i32(i32 %.fca.2.2.extract39, i32 0)
%21 = icmp slt i32 %.fca.2.2.extract39, 1
br i1 %21, label %L477, label %L109.preheader
L109.preheader: ; preds = %conversion
%22 = mul i32 %.fca.2.0.extract13, %18
%23 = add i32 %12, %22
%24 = getelementptr inbounds float, float addrspace(1)* %20, i32 %23
%25 = load float, float addrspace(1)* %24, align 4
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [3 x i32], i64 } %2, 0
%.fca.0.extract35 = extractvalue { i8 addrspace(1)*, i64, [3 x i32], i64 } %0, 0
%26 = icmp slt i32 %13, 2
%27 = icmp sge i32 %13, %.fca.2.0.extract37
%28 = icmp eq i32 %18, 0
%29 = icmp sge i32 %19, %.fca.2.1.extract38
%30 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%31 = fneg float %3
%32 = add i32 %12, -1
%33 = add nsw i32 %18, -1
%34 = bitcast i8 addrspace(1)* %.fca.0.extract35 to float addrspace(1)*
%35 = select i1 %26, i1 true, i1 %27
%brmerge = select i1 %35, i1 true, i1 %28
%36 = select i1 %brmerge, i1 true, i1 %29
br label %L109
L109: ; preds = %L466, %L109.preheader
%value_phi4 = phi i32 [ %93, %L466 ], [ 1, %L109.preheader ]
%value_phi6 = phi float [ %value_phi8, %L466 ], [ %25, %L109.preheader ]
%value_phi7 = phi float [ %value_phi6, %L466 ], [ 0.000000e+00, %L109.preheader ]
%.not = icmp sge i32 %value_phi4, %.fca.2.2.extract39
br i1 %.not, label %L160, label %L115
L115: ; preds = %L109
%37 = mul i32 %value_phi4, %.fca.2.1.extract14
%reass.add = add i32 %18, %37
%reass.mul = mul i32 %reass.add, %.fca.2.0.extract13
%38 = add i32 %12, %reass.mul
%39 = getelementptr inbounds float, float addrspace(1)* %20, i32 %38
%40 = load float, float addrspace(1)* %39, align 4
br label %L160
L160: ; preds = %L115, %L109
%value_phi8 = phi float [ %40, %L115 ], [ 0.000000e+00, %L109 ]
%41 = icmp ult i32 %value_phi4, 2
%or.cond57 = select i1 %36, i1 true, i1 %41
%brmerge58 = or i1 %or.cond57, %.not
br i1 %brmerge58, label %L466, label %L173
L173: ; preds = %L160
%42 = add nsw i32 %value_phi4, -1
%43 = mul i32 %42, %.fca.2.1.extract
%reass.add43 = add i32 %18, %43
%reass.mul44 = mul i32 %reass.add43, %.fca.2.0.extract
%44 = add i32 %12, %reass.mul44
%45 = getelementptr inbounds float, float addrspace(1)* %30, i32 %44
%46 = load float, float addrspace(1)* %45, align 4
%47 = mul i32 %42, %.fca.2.1.extract14
%reass.add45 = add i32 %18, %47
%reass.mul46 = mul i32 %reass.add45, %.fca.2.0.extract13
%48 = add i32 %reass.mul46, %13
%49 = getelementptr inbounds float, float addrspace(1)* %20, i32 %48
%50 = load float, float addrspace(1)* %49, align 4
%51 = fsub float %50, %value_phi6
%52 = fmul float %51, %31
%53 = fmul float %52, %5
%54 = add i32 %32, %reass.mul46
%55 = getelementptr inbounds float, float addrspace(1)* %20, i32 %54
%56 = load float, float addrspace(1)* %55, align 4
%57 = fsub float %value_phi6, %56
%58 = fmul float %57, %31
%59 = fmul float %58, %5
%60 = fsub float %53, %59
%61 = fneg float %60
%62 = fmul float %61, %5
%reass.add49 = add i32 %47, %19
%reass.mul50 = mul i32 %reass.add49, %.fca.2.0.extract13
%63 = add i32 %12, %reass.mul50
%64 = getelementptr inbounds float, float addrspace(1)* %20, i32 %63
%65 = load float, float addrspace(1)* %64, align 4
%66 = fsub float %65, %value_phi6
%67 = fmul float %66, %31
%68 = fmul float %67, %6
%reass.add51 = add i32 %33, %47
%reass.mul52 = mul i32 %reass.add51, %.fca.2.0.extract13
%69 = add i32 %12, %reass.mul52
%70 = getelementptr inbounds float, float addrspace(1)* %20, i32 %69
%71 = load float, float addrspace(1)* %70, align 4
%72 = fsub float %value_phi6, %71
%73 = fmul float %72, %31
%74 = fmul float %73, %6
%75 = fsub float %68, %74
%76 = fmul float %75, %6
%77 = fsub float %62, %76
%78 = fsub float %value_phi8, %value_phi6
%79 = fmul float %78, %31
%80 = fmul float %79, %7
%81 = fsub float %value_phi6, %value_phi7
%82 = fmul float %81, %31
%83 = fmul float %82, %7
%84 = fsub float %80, %83
%85 = fmul float %84, %7
%86 = fsub float %77, %85
%87 = fmul float %46, %86
%88 = fmul float %87, %4
%89 = fadd float %value_phi6, %88
%90 = mul i32 %.fca.2.1.extract38, %42
%reass.add53 = add i32 %18, %90
%reass.mul54 = mul i32 %reass.add53, %.fca.2.0.extract37
%91 = add i32 %12, %reass.mul54
%92 = getelementptr inbounds float, float addrspace(1)* %34, i32 %91
store float %89, float addrspace(1)* %92, align 4
br label %L466
L466: ; preds = %L173, %L160
%.not42.not = icmp eq i32 %value_phi4, %value_phi
%93 = add nuw i32 %value_phi4, 1
br i1 %.not42.not, label %L477, label %L109
L477: ; preds = %L466, %conversion
ret void
}
```
The PTX code does too, except at the points where array addresses are required, which must be 64-bit: that is where the mul.wide.s32 (32×32→64-bit multiply) and add.s64 instructions below come from.
```
//
// Generated by LLVM NVPTX Back-End
//
.version 8.5
.target sm_89
.address_size 64
// .globl _Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0_ // -- Begin function _Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0_
// @_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0_
.visible .entry _Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0_(
.param .align 8 .b8 _Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_0[16],
.param .align 8 .b8 _Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_1[40],
.param .align 8 .b8 _Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_2[40],
.param .align 8 .b8 _Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_3[40],
.param .f32 _Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_4,
.param .f32 _Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_5,
.param .f32 _Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_6,
.param .f32 _Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_7,
.param .f32 _Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_8
)
{
.reg .pred %p<15>;
.reg .b32 %r<71>;
.reg .f32 %f<51>;
.reg .b64 %rd<26>;
// %bb.0: // %conversion
ld.param.u32 %r42, [_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_1+24];
setp.lt.s32 %p2, %r42, 1;
@%p2 bra $L__BB0_7;
// %bb.1: // %L109.preheader
ld.param.v2.u32 {%r43, %r44}, [_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_3+16];
ld.param.u64 %rd8, [_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_3];
ld.param.v2.u32 {%r40, %r41}, [_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_1+16];
ld.param.u64 %rd5, [_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_1];
ld.param.f32 %f11, [_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_8];
ld.param.f32 %f10, [_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_7];
ld.param.f32 %f9, [_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_6];
ld.param.f32 %f8, [_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_5];
ld.param.f32 %f7, [_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_4];
ld.param.u64 %rd1, [_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_2];
ld.param.v2.u32 {%r4, %r5}, [_Z17diffusion3D_step_13CuDeviceArrayI7Float32Li3ELi1E5Int32ES2_S2_S0_S0_S0_S0_S0__param_2+16];
mov.u32 %r46, %ctaid.x;
mov.u32 %r47, %ntid.x;
mul.lo.s32 %r8, %r46, %r47;
mov.u32 %r9, %tid.x;
add.s32 %r10, %r8, %r9;
add.s32 %r11, %r10, 1;
mov.u32 %r48, %ctaid.y;
mov.u32 %r49, %ntid.y;
mul.lo.s32 %r12, %r48, %r49;
mov.u32 %r13, %tid.y;
add.s32 %r14, %r12, %r13;
add.s32 %r15, %r14, 1;
max.s32 %r16, %r42, 0;
mul.lo.s32 %r51, %r4, %r14;
add.s32 %r52, %r10, %r51;
mul.wide.s32 %rd11, %r52, 4;
add.s64 %rd12, %rd1, %rd11;
ld.global.f32 %f48, [%rd12];
setp.lt.s32 %p3, %r11, 2;
setp.ge.s32 %p4, %r11, %r40;
setp.eq.s32 %p5, %r14, 0;
setp.ge.s32 %p6, %r15, %r41;
neg.f32 %f2, %f7;
or.pred %p7, %p3, %p4;
or.pred %p8, %p7, %p5;
or.pred %p1, %p8, %p6;
mad.lo.s32 %r69, %r40, %r14, %r9;
mul.lo.s32 %r18, %r41, %r40;
mad.lo.s32 %r68, %r43, %r14, %r9;
mul.lo.s32 %r20, %r44, %r43;
add.s32 %r53, %r14, -1;
mad.lo.s32 %r67, %r4, %r53, %r9;
mul.lo.s32 %r22, %r5, %r4;
add.s32 %r66, %r9, %r51;
mad.lo.s32 %r65, %r4, %r15, %r9;
add.s32 %r54, %r5, %r13;
add.s32 %r55, %r54, %r12;
mad.lo.s32 %r64, %r4, %r55, %r9;
mov.f32 %f12, 0f00000000;
mov.u32 %r70, 0;
mov.f32 %f49, %f12;
bra.uni $L__BB0_2;
$L__BB0_6: // %L466
// in Loop: Header=BB0_2 Depth=1
add.s32 %r69, %r69, %r18;
add.s32 %r68, %r68, %r20;
add.s32 %r67, %r67, %r22;
add.s32 %r66, %r66, %r22;
add.s32 %r65, %r65, %r22;
add.s32 %r64, %r64, %r22;
setp.ne.s32 %p14, %r16, %r70;
mov.f32 %f49, %f3;
@%p14 bra $L__BB0_2;
bra.uni $L__BB0_7;
$L__BB0_2: // %L109
// =>This Inner Loop Header: Depth=1
mov.f32 %f3, %f48;
add.s32 %r70, %r70, 1;
setp.ge.s32 %p9, %r70, %r42;
mov.f32 %f48, %f12;
@%p9 bra $L__BB0_4;
// %bb.3: // %L115
// in Loop: Header=BB0_2 Depth=1
add.s32 %r56, %r8, %r64;
mul.wide.s32 %rd13, %r56, 4;
add.s64 %rd4, %rd1, %rd13;
ld.global.f32 %f48, [%rd4];
$L__BB0_4: // %L160
// in Loop: Header=BB0_2 Depth=1
setp.lt.u32 %p11, %r70, 2;
or.pred %p12, %p1, %p11;
or.pred %p13, %p12, %p9;
@%p13 bra $L__BB0_6;
// %bb.5: // %L173
// in Loop: Header=BB0_2 Depth=1
add.s32 %r57, %r8, %r68;
mul.wide.s32 %rd14, %r57, 4;
add.s64 %rd15, %rd8, %rd14;
ld.global.f32 %f14, [%rd15];
add.s32 %r58, %r8, %r66;
add.s32 %r59, %r58, 1;
mul.wide.s32 %rd16, %r59, 4;
add.s64 %rd17, %rd1, %rd16;
ld.global.f32 %f15, [%rd17];
sub.f32 %f16, %f15, %f3;
mul.f32 %f17, %f16, %f2;
mul.f32 %f18, %f17, %f9;
add.s32 %r60, %r58, -1;
mul.wide.s32 %rd18, %r60, 4;
add.s64 %rd19, %rd1, %rd18;
ld.global.f32 %f19, [%rd19];
sub.f32 %f20, %f3, %f19;
mul.f32 %f21, %f20, %f2;
mul.f32 %f22, %f21, %f9;
sub.f32 %f23, %f18, %f22;
neg.f32 %f24, %f23;
mul.f32 %f25, %f24, %f9;
add.s32 %r61, %r8, %r65;
mul.wide.s32 %rd20, %r61, 4;
add.s64 %rd21, %rd1, %rd20;
ld.global.f32 %f26, [%rd21];
sub.f32 %f27, %f26, %f3;
mul.f32 %f28, %f27, %f2;
mul.f32 %f29, %f28, %f10;
add.s32 %r62, %r8, %r67;
mul.wide.s32 %rd22, %r62, 4;
add.s64 %rd23, %rd1, %rd22;
ld.global.f32 %f30, [%rd23];
sub.f32 %f31, %f3, %f30;
mul.f32 %f32, %f31, %f2;
mul.f32 %f33, %f32, %f10;
sub.f32 %f34, %f29, %f33;
mul.f32 %f35, %f34, %f10;
sub.f32 %f36, %f25, %f35;
sub.f32 %f37, %f48, %f3;
mul.f32 %f38, %f37, %f2;
mul.f32 %f39, %f38, %f11;
sub.f32 %f40, %f3, %f49;
mul.f32 %f41, %f40, %f2;
mul.f32 %f42, %f41, %f11;
sub.f32 %f43, %f39, %f42;
mul.f32 %f44, %f43, %f11;
sub.f32 %f45, %f36, %f44;
mul.f32 %f46, %f14, %f45;
fma.rn.f32 %f47, %f46, %f8, %f3;
add.s32 %r63, %r8, %r69;
mul.wide.s32 %rd24, %r63, 4;
add.s64 %rd25, %rd5, %rd24;
st.global.f32 [%rd25], %f47;
bra.uni $L__BB0_6;
$L__BB0_7: // %L477
ret;
// -- End function
}
```
The SASS code shows 29 registers being used in the 32-bit indexing case versus 38 registers for 64-bit indexing, so at least register usage is lower. The 64-bit SASS leans on the uniform register file and LEA address calculations, versus IMAD instructions for the 32-bit version. Run times using @omlins’ setup don’t appear to differ by much at all between 32-bit and 64-bit indexing.
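For anyone wanting to reproduce the comparison, CUDA.jl’s reflection utilities are enough; a sketch, assuming the arrays and launch configuration from the example above:

```julia
# Compile without launching, then query the per-thread register count.
kernel = @cuda launch=false diffusion3D_step!(T2, T, Ci, lam, dt, _dx, _dy, _dz)
@show CUDA.registers(kernel)

# @device_code_sass dumps the SASS itself for inspection:
# @device_code_sass @cuda launch=false diffusion3D_step!(T2, T, Ci, lam, dt, _dx, _dy, _dz)

# Crude timing; BenchmarkTools.@btime with CUDA.@sync gives cleaner numbers.
t = CUDA.@elapsed @cuda threads=threads blocks=blocks diffusion3D_step!(T2, T, Ci, lam, dt, _dx, _dy, _dz)
```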
Here is what Nsight Compute shows at the summary level (IDs 4-8 are the 64-bit indexing runs, IDs 13-17 the 32-bit indexing runs):
I’m struggling to see any obvious benefit from using 32-bit indexing, but I’m happy to try it out on other kernels if people can supply them.