diff --git a/docs/changes/27.maintenance.rst b/docs/changes/27.maintenance.rst new file mode 100644 index 0000000..7adabb6 --- /dev/null +++ b/docs/changes/27.maintenance.rst @@ -0,0 +1 @@ +Removed cupy_to_torch and torch_to_cupy functions and use `array_namespace.from_dlpack` instead to convert arrays diff --git a/src/radioft/finufft/finufft.py b/src/radioft/finufft/finufft.py index 7f827f9..35829a4 100644 --- a/src/radioft/finufft/finufft.py +++ b/src/radioft/finufft/finufft.py @@ -14,16 +14,6 @@ } -def cupy_to_torch(x_cp): - """CuPy array -> Torch tensor on GPU (Zero-Copy via DLPack)""" - return torch.utils.dlpack.from_dlpack(x_cp.toDlpack()) - - -def torch_to_cupy(x_torch): - """Torch tensor -> CuPy array on GPU (Zero-Copy via DLPack)""" - return cp.fromDlpack(torch.utils.dlpack.to_dlpack(x_torch)) - - class CupyFinufft: """Wraper to use Finufft Type 3d3 for radio interferometry data.""" @@ -132,18 +122,18 @@ def nufft( and non-uniform target coordinates. """ # Sky coordinates (Image domain - lmn coordinates) - source_l = torch_to_cupy(l_coords / self.px_size).astype(cp.float64) - source_m = torch_to_cupy(m_coords / self.px_size).astype(cp.float64) - source_n = torch_to_cupy((n_coords - 1) / self.px_size).astype(cp.float64) + source_l = cp.from_dlpack(l_coords / self.px_size).astype(cp.float64) + source_m = cp.from_dlpack(m_coords / self.px_size).astype(cp.float64) + source_n = cp.from_dlpack((n_coords - 1) / self.px_size).astype(cp.float64) # Antenna coordinates (Fourier Domain - uvw coordinates) - target_u = torch_to_cupy(2 * pi * (u_coords.flatten() * self.px_size)).astype( + target_u = cp.from_dlpack(2 * pi * (u_coords.flatten() * self.px_size)).astype( cp.float64 ) - target_v = torch_to_cupy(2 * pi * (v_coords.flatten() * self.px_size)).astype( + target_v = cp.from_dlpack(2 * pi * (v_coords.flatten() * self.px_size)).astype( cp.float64 ) - target_w = torch_to_cupy(2 * pi * (w_coords.flatten() * self.px_size)).astype( + target_w = cp.from_dlpack(2 * pi * (w_coords.flatten() * self.px_size)).astype( cp.float64 ) @@ -156,14 +146,15 @@ def nufft( ) coord_outside = cp.where(cp.any(outside_bounds, axis=1))[0] if outside_bounds.any(): - warnings.warning( + warnings.warn( f"Some of the {', '.join(itemgetter(*coord_outside.get())(uvw_map))} " "coordinates lie outside the constructed image. This can lead to " - "cufinufft errors." + "cufinufft errors.", + stacklevel=2, ) # Values at source position (Source intensities) - c_values = torch_to_cupy(sky_values.flatten()).astype(cp.complex128) + c_values = cp.from_dlpack(sky_values.flatten()).astype(cp.complex128) result = self.ft( source_l, @@ -175,7 +166,7 @@ def nufft( target_w, ) - visibilities = cupy_to_torch(result) + visibilities = torch.from_dlpack(result) return visibilities @@ -193,13 +184,13 @@ def inufft( and non-uniform target coordinates. """ # Antenna coordinates (Fourier Domain - uvw coordinates) - source_u = torch_to_cupy(2 * pi * (u_coords.flatten() * self.px_size)).astype( + source_u = cp.from_dlpack(2 * pi * (u_coords.flatten() * self.px_size)).astype( cp.float64 ) - source_v = torch_to_cupy(2 * pi * (v_coords.flatten() * self.px_size)).astype( + source_v = cp.from_dlpack(2 * pi * (v_coords.flatten() * self.px_size)).astype( cp.float64 ) - source_w = torch_to_cupy(2 * pi * (w_coords.flatten() * self.px_size)).astype( + source_w = cp.from_dlpack(2 * pi * (w_coords.flatten() * self.px_size)).astype( cp.float64 ) @@ -222,23 +213,24 @@ def inufft( coord_outside = cp.where(cp.any(outside_bounds, axis=1))[0] if outside_bounds.any(): - warnings.warning( + warnings.warn( f"Some of the {', '.join(itemgetter(*coord_outside.get())(uvw_map))} " "coordinates lie outside the constructed image. This can lead to " - "cufinufft errors." + "cufinufft errors.", + stacklevel=2, ) # Fourier coeficients at antenna positions (Visibilities) - c_values = torch_to_cupy(visibilities.flatten()).astype(cp.complex128) + c_values = cp.from_dlpack(visibilities.flatten()).astype(cp.complex128) # Normalize visibility values by dividing by their bin counts # This means visibilities that fall into the same bin are averaged c_values_normalized = c_values / visibility_weights # Sky coordinates (Image domain - lmn coordinates) - target_l = torch_to_cupy(l_coords.flatten() / self.px_size).astype(cp.float64) - target_m = torch_to_cupy(m_coords.flatten() / self.px_size).astype(cp.float64) - target_n = torch_to_cupy((n_coords.flatten() - 1) / self.px_size).astype( + target_l = cp.from_dlpack(l_coords.flatten() / self.px_size).astype(cp.float64) + target_m = cp.from_dlpack(m_coords.flatten() / self.px_size).astype(cp.float64) + target_n = cp.from_dlpack((n_coords.flatten() - 1) / self.px_size).astype( cp.float64 ) @@ -255,6 +247,6 @@ def inufft( / self.px_scaling ) - sky_intensities = cupy_to_torch(result) + sky_intensities = torch.from_dlpack(result) return sky_intensities