Skip to content

Commit 9a154a4

Browse files
Kin-Zhang and yanconglin committed
hotfix(chamfer): fix #22; prevent deadlock from early thread exit before __syncthreads()
--------- Co-authored-by: yanconglin <yanconglin@users.noreply.github.com>
1 parent 650383b commit 9a154a4

File tree

3 files changed: 34 additions and 29 deletions

assets/cuda/chamfer3D/chamfer3D.cu

Lines changed: 30 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -32,20 +32,22 @@
3232

3333
__global__ void NmDistanceKernel(const int pc0_n, const float *pc0_xyz, const int pc1_n, const float *pc1_xyz, float *result, int *result_i){
3434
int tid = blockIdx.x * blockDim.x + threadIdx.x;
35-
36-
if (tid >= pc0_n) return;
37-
38-
float x0 = pc0_xyz[tid * 3 + 0];
39-
float y0 = pc0_xyz[tid * 3 + 1];
40-
float z0 = pc0_xyz[tid * 3 + 2];
41-
42-
__shared__ float shared_pc1[THREADS_PER_BLOCK * 3];
4335

36+
// Load point coordinates for valid threads
37+
float x0, y0, z0;
4438
int best_i = -1;
4539
float best = 1e20;
40+
41+
if (tid < pc0_n) {
42+
x0 = pc0_xyz[tid * 3 + 0];
43+
y0 = pc0_xyz[tid * 3 + 1];
44+
z0 = pc0_xyz[tid * 3 + 2];
45+
}
46+
47+
__shared__ float shared_pc1[THREADS_PER_BLOCK * 3];
4648

4749
for (int i = 0; i < pc1_n; i += THREADS_PER_BLOCK) {
48-
// Copy a block of pc1 to shared memory
50+
// All threads cooperate to load shared memory
4951
int pc1_idx = i + threadIdx.x;
5052
if (pc1_idx < pc1_n) {
5153
shared_pc1[threadIdx.x * 3 + 0] = pc1_xyz[pc1_idx * 3 + 0];
@@ -54,26 +56,30 @@ __global__ void NmDistanceKernel(const int pc0_n, const float *pc0_xyz, const in
5456
}
5557

5658
__syncthreads();
57-
58-
// Compute the distance between pc0[tid] and the points in shared_pc1
59-
int num_elems = min(THREADS_PER_BLOCK, pc1_n - i);
60-
for (int j = 0; j < num_elems; j++) {
61-
float x1 = shared_pc1[j * 3 + 0];
62-
float y1 = shared_pc1[j * 3 + 1];
63-
float z1 = shared_pc1[j * 3 + 2];
64-
float d = (x1 - x0) * (x1 - x0) + (y1 - y0) * (y1 - y0) + (z1 - z0) * (z1 - z0);
65-
if (d < best) {
66-
best = d;
67-
best_i = j + i;
59+
60+
// Only valid threads compute distances
61+
if (tid < pc0_n) {
62+
int num_elems = min(THREADS_PER_BLOCK, pc1_n - i);
63+
for (int j = 0; j < num_elems; j++) {
64+
float x1 = shared_pc1[j * 3 + 0];
65+
float y1 = shared_pc1[j * 3 + 1];
66+
float z1 = shared_pc1[j * 3 + 2];
67+
float d = (x1 - x0) * (x1 - x0) + (y1 - y0) * (y1 - y0) + (z1 - z0) * (z1 - z0);
68+
if (d < best) {
69+
best = d;
70+
best_i = j + i;
71+
}
6872
}
6973
}
7074

7175
__syncthreads();
7276
}
73-
74-
// done with this thread in tid in pc_0, save the result to global memory
75-
atomicExch(&result[tid], best);
76-
atomicExch(&result_i[tid], best_i);
77+
78+
// Only valid threads write results
79+
if (tid < pc0_n) {
80+
result[tid] = best;
81+
result_i[tid] = best_i;
82+
}
7783
}
7884

7985
int chamfer_cuda_forward(const at::Tensor &pc0, const at::Tensor &pc1, at::Tensor &dist0, at::Tensor &dist1, at::Tensor &idx0, at::Tensor &idx1)

assets/cuda/chamfer3D/setup.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,9 @@
22
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
33

44
extra_compile_args = {
5-
'cxx': ['-DCCCL_IGNORE_DEPRECATED_CUDA_BELOW_12'],
6-
'nvcc': ['-DCCCL_IGNORE_DEPRECATED_CUDA_BELOW_12'],
5+
'cxx': ['-DCCCL_IGNORE_DEPRECATED_CUDA_BELOW_12', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'],
6+
'nvcc': ['-DCCCL_IGNORE_DEPRECATED_CUDA_BELOW_12', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'],
77
}
8-
98
setup(
109
name='chamfer3D',
1110
ext_modules=[
@@ -21,5 +20,5 @@
2120
cmdclass={
2221
'build_ext': BuildExtension
2322
},
24-
version='1.0.2'
23+
version='1.0.5'
2524
)

assets/tests/chamferdis_speed_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -76,8 +76,8 @@ def faiss_knn(pc1, pc2):
7676
start_time = time.time()
7777
loss0, _ = my_chamfer_fn(pc0.unsqueeze(0), pc1.unsqueeze(0), truncate_dist=False)
7878

79-
print(f"Pytorch3d Chamfer Distance Cal time: {(time.time() - start_time)*1000:.3f} ms")
8079
print("loss: ", loss0)
80+
print(f"Pytorch3d Chamfer Distance Cal time: {(time.time() - start_time)*1000:.3f} ms")
8181
print()
8282

8383
if MMCV_TEST:

0 commit comments

Comments
 (0)