3232
// For every point in pc0, find the closest point in pc1 by squared Euclidean
// distance.
//
// Parameters:
//   pc0_n     number of points in pc0
//   pc0_xyz   pc0 coordinates, read as packed triples: point p occupies
//             elements [3p, 3p+1, 3p+2]
//   pc1_n     number of points in pc1
//   pc1_xyz   pc1 coordinates, same packed layout as pc0_xyz
//   result    output, result[p] receives the smallest squared distance found
//             for pc0 point p (stays at 1e20f when pc1_n <= 0)
//   result_i  output, result_i[p] receives the pc1 index of that nearest point
//             (stays at -1 when pc1_n <= 0)
//
// Launch layout: 1D grid of 1D blocks, one thread per pc0 point; the grid must
// supply at least pc0_n threads. Uses THREADS_PER_BLOCK * 3 floats of static
// shared memory per block. Any blockDim.x works because the tile load below is
// strided by blockDim.x.
//
// Threads with tid >= pc0_n must NOT return early: every thread of the block
// has to reach both __syncthreads() calls in the tile loop, so out-of-range
// threads stay alive, help load the tiles, and only skip the distance
// computation and the final store.
__global__ void NmDistanceKernel(const int pc0_n, const float *pc0_xyz, const int pc1_n, const float *pc1_xyz, float *result, int *result_i) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const bool active = tid < pc0_n;

    // Query point owned by this thread. Inactive threads keep zeros that are
    // never read; the initialisation only avoids indeterminate values.
    float x0 = 0.0f, y0 = 0.0f, z0 = 0.0f;
    if (active) {
        x0 = pc0_xyz[tid * 3 + 0];
        y0 = pc0_xyz[tid * 3 + 1];
        z0 = pc0_xyz[tid * 3 + 2];
    }

    int best_i = -1;
    float best = 1e20f;  // float literal: avoids a double constant in a float kernel

    __shared__ float shared_pc1[THREADS_PER_BLOCK * 3];

    // Walk pc1 in tiles of up to THREADS_PER_BLOCK points.
    for (int i = 0; i < pc1_n; i += THREADS_PER_BLOCK) {
        // Number of valid points in this tile (the last tile may be partial).
        const int num_elems = min(THREADS_PER_BLOCK, pc1_n - i);

        // All threads of the block cooperate to copy the tile into shared
        // memory. The tile is contiguous in pc1_xyz, so it is copied as a flat
        // run of num_elems * 3 floats with consecutive threads touching
        // consecutive addresses.
        for (int k = threadIdx.x; k < num_elems * 3; k += blockDim.x) {
            shared_pc1[k] = pc1_xyz[i * 3 + k];
        }

        // The whole tile must be written before any thread reads it.
        __syncthreads();

        // Only threads that own a pc0 point scan the tile.
        if (active) {
            for (int j = 0; j < num_elems; j++) {
                const float dx = shared_pc1[j * 3 + 0] - x0;
                const float dy = shared_pc1[j * 3 + 1] - y0;
                const float dz = shared_pc1[j * 3 + 2] - z0;
                const float d = dx * dx + dy * dy + dz * dz;
                // Strict '<' keeps the lowest pc1 index when distances tie.
                if (d < best) {
                    best = d;
                    best_i = j + i;
                }
            }
        }

        // Every thread must be done reading before the next iteration
        // overwrites the tile.
        __syncthreads();
    }

    // Each active thread owns a distinct output slot, so plain stores suffice.
    if (active) {
        result[tid] = best;
        result_i[tid] = best_i;
    }
}
7884
7985int chamfer_cuda_forward (const at::Tensor &pc0, const at::Tensor &pc1, at::Tensor &dist0, at::Tensor &dist1, at::Tensor &idx0, at::Tensor &idx1)
0 commit comments