Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changes/27.maintenance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Removed cupy_to_torch and torch_to_cupy functions and use `array_namespace.from_dlpack` instead to convert arrays
52 changes: 22 additions & 30 deletions src/radioft/finufft/finufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,6 @@
}


def cupy_to_torch(x_cp):
"""CuPy array -> Torch tensor on GPU (Zero-Copy via DLPack)"""
return torch.utils.dlpack.from_dlpack(x_cp.toDlpack())


def torch_to_cupy(x_torch):
"""Torch tensor -> CuPy array on GPU (Zero-Copy via DLPack)"""
return cp.fromDlpack(torch.utils.dlpack.to_dlpack(x_torch))


class CupyFinufft:
"""Wraper to use Finufft Type 3d3 for radio interferometry data."""

Expand Down Expand Up @@ -132,18 +122,18 @@ def nufft(
and non-uniform target coordinates.
"""
# Sky coordinates (Image domain - lmn coordinates)
source_l = torch_to_cupy(l_coords / self.px_size).astype(cp.float64)
source_m = torch_to_cupy(m_coords / self.px_size).astype(cp.float64)
source_n = torch_to_cupy((n_coords - 1) / self.px_size).astype(cp.float64)
source_l = cp.from_dlpack(l_coords / self.px_size).astype(cp.float64)
source_m = cp.from_dlpack(m_coords / self.px_size).astype(cp.float64)
source_n = cp.from_dlpack((n_coords - 1) / self.px_size).astype(cp.float64)

# Antenna coordinates (Fourier Domain - uvw coordinates)
target_u = torch_to_cupy(2 * pi * (u_coords.flatten() * self.px_size)).astype(
target_u = cp.from_dlpack(2 * pi * (u_coords.flatten() * self.px_size)).astype(
cp.float64
)
target_v = torch_to_cupy(2 * pi * (v_coords.flatten() * self.px_size)).astype(
target_v = cp.from_dlpack(2 * pi * (v_coords.flatten() * self.px_size)).astype(
cp.float64
)
target_w = torch_to_cupy(2 * pi * (w_coords.flatten() * self.px_size)).astype(
target_w = cp.from_dlpack(2 * pi * (w_coords.flatten() * self.px_size)).astype(
cp.float64
)

Expand All @@ -156,14 +146,15 @@ def nufft(
)
coord_outside = cp.where(cp.any(outside_bounds, axis=1))[0]
if outside_bounds.any():
warnings.warning(
warnings.warn(
f"Some of the {', '.join(itemgetter(*coord_outside.get())(uvw_map))} "
"coordinates lie outside the constructed image. This can lead to "
"cufinufft errors."
"cufinufft errors.",
stacklevel=2,
)

# Values at source position (Source intensities)
c_values = torch_to_cupy(sky_values.flatten()).astype(cp.complex128)
c_values = cp.from_dlpack(sky_values.flatten()).astype(cp.complex128)

result = self.ft(
source_l,
Expand All @@ -175,7 +166,7 @@ def nufft(
target_w,
)

visibilities = cupy_to_torch(result)
visibilities = torch.from_dlpack(result)

return visibilities

Expand All @@ -193,13 +184,13 @@ def inufft(
and non-uniform target coordinates.
"""
# Antenna coordinates (Fourier Domain - uvw coordinates)
source_u = torch_to_cupy(2 * pi * (u_coords.flatten() * self.px_size)).astype(
source_u = cp.from_dlpack(2 * pi * (u_coords.flatten() * self.px_size)).astype(
cp.float64
)
source_v = torch_to_cupy(2 * pi * (v_coords.flatten() * self.px_size)).astype(
source_v = cp.from_dlpack(2 * pi * (v_coords.flatten() * self.px_size)).astype(
cp.float64
)
source_w = torch_to_cupy(2 * pi * (w_coords.flatten() * self.px_size)).astype(
source_w = cp.from_dlpack(2 * pi * (w_coords.flatten() * self.px_size)).astype(
cp.float64
)

Expand All @@ -222,23 +213,24 @@ def inufft(
coord_outside = cp.where(cp.any(outside_bounds, axis=1))[0]

if outside_bounds.any():
warnings.warning(
warnings.warn(
f"Some of the {', '.join(itemgetter(*coord_outside.get())(uvw_map))} "
"coordinates lie outside the constructed image. This can lead to "
"cufinufft errors."
"cufinufft errors.",
stacklevel=2,
)

# Fourier coeficients at antenna positions (Visibilities)
c_values = torch_to_cupy(visibilities.flatten()).astype(cp.complex128)
c_values = cp.from_dlpack(visibilities.flatten()).astype(cp.complex128)

# Normalize visibility values by dividing by their bin counts
# This means visibilities that fall into the same bin are averaged
c_values_normalized = c_values / visibility_weights

# Sky coordinates (Image domain - lmn coordinates)
target_l = torch_to_cupy(l_coords.flatten() / self.px_size).astype(cp.float64)
target_m = torch_to_cupy(m_coords.flatten() / self.px_size).astype(cp.float64)
target_n = torch_to_cupy((n_coords.flatten() - 1) / self.px_size).astype(
target_l = cp.from_dlpack(l_coords.flatten() / self.px_size).astype(cp.float64)
target_m = cp.from_dlpack(m_coords.flatten() / self.px_size).astype(cp.float64)
target_n = cp.from_dlpack((n_coords.flatten() - 1) / self.px_size).astype(
cp.float64
)

Expand All @@ -255,6 +247,6 @@ def inufft(
/ self.px_scaling
)

sky_intensities = cupy_to_torch(result)
sky_intensities = torch.from_dlpack(result)

return sky_intensities