Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions deeplens/geolens.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,10 +1611,27 @@ def calc_foclen(self):
# Trace a paraxial chief ray, shape [1, 1, num_rays, 3]
paraxial_fov = 0.01
paraxial_fov_deg = float(np.rad2deg(paraxial_fov))

# 1. Trace on-axis parallel rays to find paraxial focus z (equivalent to infinite focus)
ray_axis = self.sample_parallel(
fov_x=0.0, fov_y=0.0, entrance_pupil=False, scale_pupil=0.2
)
ray_axis, _ = self.trace(ray_axis)
valid_axis = ray_axis.is_valid > 0
t = -(ray_axis.d[valid_axis, 0] * ray_axis.o[valid_axis, 0]
+ ray_axis.d[valid_axis, 1] * ray_axis.o[valid_axis, 1]) / (
ray_axis.d[valid_axis, 0] ** 2 + ray_axis.d[valid_axis, 1] ** 2
)
focus_z = ray_axis.o[valid_axis, 2] + t * ray_axis.d[valid_axis, 2]
focus_z = focus_z[~torch.isnan(focus_z) & (focus_z > 0)]
paraxial_focus_z = float(torch.mean(focus_z))

# 2. Trace off-axis paraxial ray to paraxial focus, measure image height
ray = self.sample_parallel(
fov_x=0.0, fov_y=paraxial_fov_deg, entrance_pupil=False, scale_pupil=0.2
)
ray = self.trace2sensor(ray)
ray, _ = self.trace(ray)
ray = ray.prop_to(paraxial_focus_z)

# Compute the effective focal length
paraxial_imgh = (ray.o[:, 1] * ray.is_valid).sum() / ray.is_valid.sum()
Expand All @@ -1625,6 +1642,8 @@ def calc_foclen(self):
# Compute the back focal length
self.bfl = self.d_sensor.item() - self.surfaces[-1].d.item()

return eff_foclen

@torch.no_grad()
def calc_numerical_aperture(self, n=1.0):
"""Compute numerical aperture (NA).
Expand All @@ -1638,7 +1657,6 @@ def calc_numerical_aperture(self, n=1.0):
Reference:
[1] https://en.wikipedia.org/wiki/Numerical_aperture
"""
breakpoint()
return n * math.sin(math.atan(1 / 2 / self.fnum))
# return n / (2 * self.fnum)

Expand Down
62 changes: 56 additions & 6 deletions deeplens/lens.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,58 @@ def render_psf_map(self, img_obj, depth=DEPTH, psf_grid=7, psf_ks=PSF_KS):
# -------------------------------------------
# Simulate 3D scene
# -------------------------------------------
def _sample_depth_layers(self, depth_min, depth_max, num_depth):
"""Sample depth layers centered on the focal plane in disparity space.

If the lens has a `calc_focal_plane` method, samples are split around the
focal plane so that it is always an explicit sample point. Otherwise falls
back to uniform disparity sampling.

Args:
depth_min (float): Minimum (nearest) depth in mm (positive).
depth_max (float): Maximum (farthest) depth in mm (positive).
num_depth (int): Number of depth layers to sample.

Returns:
tuple: (disp_ref, depths_ref) where disp_ref has shape (num_depth,) in
disparity space and depths_ref = -1/disp_ref (negative, for PSF).
"""
# Try to get focal depth from the lens
if hasattr(self, 'calc_focal_plane'):
focal_depth = abs(self.calc_focal_plane()) # positive mm
else:
focal_depth = None

if focal_depth is not None:
# Extend range to include the focal depth
depth_min_ext = min(float(depth_min), focal_depth)
depth_max_ext = max(float(depth_max), focal_depth)

disp_near = 1.0 / depth_min_ext # large disparity = near
disp_far = 1.0 / depth_max_ext # small disparity = far
focal_disp = 1.0 / focal_depth

# Allocate samples proportionally to range on each side
near_range = disp_near - focal_disp
far_range = focal_disp - disp_far
total_range = near_range + far_range

if total_range < 1e-10:
disp_ref = torch.full((num_depth,), focal_disp).to(self.device)
else:
n_far = max(1, round((num_depth - 1) * far_range / total_range))
n_near = num_depth - 1 - n_far

far_disps = torch.linspace(disp_far, focal_disp, n_far + 1) # includes focal
near_disps = torch.linspace(focal_disp, disp_near, n_near + 1)[1:] # exclude duplicate focal
Comment on lines +549 to +550
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tensors created on lines 549-550 (far_disps and near_disps) are not explicitly moved to the correct device before concatenation. While line 551 does call .to(self.device) on the concatenated result, it would be more efficient and clearer to create the tensors on the correct device initially. Consider using torch.linspace(...).to(self.device) on lines 549-550, or pass a device parameter to linspace if supported.

Suggested change
far_disps = torch.linspace(disp_far, focal_disp, n_far + 1) # includes focal
near_disps = torch.linspace(focal_disp, disp_near, n_near + 1)[1:] # exclude duplicate focal
far_disps = torch.linspace(disp_far, focal_disp, n_far + 1, device=self.device) # includes focal
near_disps = torch.linspace(focal_disp, disp_near, n_near + 1, device=self.device)[1:] # exclude duplicate focal

Copilot uses AI. Check for mistakes.
disp_ref = torch.cat([far_disps, near_disps]).to(self.device)
Comment on lines +544 to +551
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When num_depth is 1, the allocation logic at lines 546-547 will result in n_far=1 and n_near=0. Line 549 will create a single-element tensor with focal_disp, and line 550 will create an empty tensor (linspace with 1 element, then [1:] slices it away). This will result in disp_ref having only 1 element, which is correct, but the logic could be clearer. Consider handling the num_depth==1 case explicitly for clarity.

Suggested change
disp_ref = torch.full((num_depth,), focal_disp).to(self.device)
else:
n_far = max(1, round((num_depth - 1) * far_range / total_range))
n_near = num_depth - 1 - n_far
far_disps = torch.linspace(disp_far, focal_disp, n_far + 1) # includes focal
near_disps = torch.linspace(focal_disp, disp_near, n_near + 1)[1:] # exclude duplicate focal
disp_ref = torch.cat([far_disps, near_disps]).to(self.device)
disp_ref = torch.full((num_depth,), focal_disp, dtype=self.dtype, device=self.device)
else:
# Special case: with only one depth sample, place it at the focal plane.
if num_depth == 1:
disp_ref = torch.full((1,), focal_disp, dtype=self.dtype, device=self.device)
else:
n_far = max(1, round((num_depth - 1) * far_range / total_range))
n_near = num_depth - 1 - n_far
far_disps = torch.linspace(disp_far, focal_disp, n_far + 1) # includes focal
near_disps = torch.linspace(focal_disp, disp_near, n_near + 1)[1:] # exclude duplicate focal
disp_ref = torch.cat([far_disps, near_disps]).to(self.device)

Copilot uses AI. Check for mistakes.
else:
# Fallback: uniform disparity sampling
disp_ref = torch.linspace(1.0 / float(depth_max), 1.0 / float(depth_min), num_depth).to(self.device)

depths_ref = -1.0 / disp_ref
return disp_ref, depths_ref

def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs):
"""Render RGBD image.

Expand Down Expand Up @@ -538,8 +590,7 @@ def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs):
interp_mode = kwargs.get("interp_mode", "disparity")

# Calculate PSF at different depths, (num_depth, 3, ks, ks)
disp_ref = torch.linspace(1.0/depth_max, 1.0/depth_min, num_depth).to(self.device)
depths_ref = - 1.0 / disp_ref # Convert to negative for PSF calculation
disp_ref, depths_ref = self._sample_depth_layers(depth_min, depth_max, num_depth)

points = torch.stack(
[
Expand All @@ -552,7 +603,7 @@ def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs):
psfs = self.psf_rgb(points=points, ks=psf_ks) # (num_depth, 3, ks, ks)

# Image simulation
img_render = conv_psf_depth_interp(img_obj, depth_map, psfs, depths_ref, interp_mode=interp_mode)
img_render = conv_psf_depth_interp(img_obj, -depth_map, psfs, depths_ref, interp_mode=interp_mode)
return img_render

elif method == "psf_map":
Expand All @@ -565,8 +616,7 @@ def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs):
interp_mode = kwargs.get("interp_mode", "disparity")

# Calculate PSF map at different depths (convert to negative for PSF calculation)
disp_ref = torch.linspace(1.0/depth_max, 1.0/depth_min, num_depth).to(self.device)
depths_ref = -1.0 / disp_ref # Convert to negative for PSF calculation
disp_ref, depths_ref = self._sample_depth_layers(depth_min, depth_max, num_depth)

psf_maps = []
for depth in tqdm(depths_ref):
Expand All @@ -578,7 +628,7 @@ def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs):

# Image simulation
img_render = conv_psf_map_depth_interp(
img_obj, depth_map, psf_map, depths_ref, interp_mode=interp_mode
img_obj, -depth_map, psf_map, depths_ref, interp_mode=interp_mode
)
return img_render

Expand Down
137 changes: 111 additions & 26 deletions deeplens/optics/psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ def conv_psf_map(img, psf_map):
pad_w_right = ks // 2
img_pad = F.pad(img, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect")

# Pre-flip entire PSF map once (instead of flipping each PSF inside the loop)
psf_map_flipped = torch.flip(psf_map, dims=(-2, -1))

# Render image patch by patch
img_render = torch.zeros_like(img)
for i in range(grid_h):
Expand All @@ -99,7 +102,7 @@ def conv_psf_map(img, psf_map):
w_high = ((j + 1) * W) // grid_w

# PSF, [C, 1, ks, ks]
psf = torch.flip(psf_map[i, j], dims=(-2, -1)).unsqueeze(1)
psf = psf_map_flipped[i, j].unsqueeze(1)

# Consider overlap to avoid boundary artifacts
img_pad_patch = img_pad[
Expand All @@ -121,39 +124,114 @@ def conv_psf_map_depth_interp(img, depth, psf_map, psf_depths, interp_mode="dept

Args:
img: (B, 3, H, W), [0, 1]
depth: (B, 1, H, W), [0, 1]
depth: (B, 1, H, W), (-inf, 0)
psf_map: (grid_h, grid_w, num_depth, 3, ks, ks)
psf_depths: (num_depth). Used to interpolate psf_map.
psf_depths: (num_depth). (-inf, 0). Used to interpolate psf_map.
interp_mode: "depth" or "disparity". If "disparity", weights are calculated based on disparity (1/depth).

Returns:
img_render: (B, 3, H, W), [0, 1]
"""
assert interp_mode in ["depth", "disparity"], f"interp_mode must be 'depth' or 'disparity', got {interp_mode}"
assert depth.min() < 0 and depth.max() < 0, f"depth must be negative, got {depth.min()} and {depth.max()}"
assert psf_depths.min() < 0 and psf_depths.max() < 0, f"psf_depths must be negative, got {psf_depths.min()} and {psf_depths.max()}"
Comment on lines +136 to +137
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new assertions requiring negative depth values will break existing tests. The tests in test/test_psf.py (lines 187-213) use positive depth values (e.g., torch.linspace(0, 1, 5) at line 195, torch.linspace(1.0, 2.0, 5) at line 208). These tests will fail with the new assertions. Consider either updating the tests to use negative depth values or relaxing the assertions to allow both positive and negative depths with appropriate internal conversion.

Suggested change
assert depth.min() < 0 and depth.max() < 0, f"depth must be negative, got {depth.min()} and {depth.max()}"
assert psf_depths.min() < 0 and psf_depths.max() < 0, f"psf_depths must be negative, got {psf_depths.min()} and {psf_depths.max()}"
depth_min = depth.min().item()
depth_max = depth.max().item()
psf_depths_min = psf_depths.min().item()
psf_depths_max = psf_depths.max().item()
assert (depth_min >= 0 and depth_max >= 0) or (depth_min <= 0 and depth_max <= 0), (
f"depth values must be consistently signed (all <= 0 or all >= 0), got min={depth_min}, max={depth_max}"
)
assert (psf_depths_min >= 0 and psf_depths_max >= 0) or (psf_depths_min <= 0 and psf_depths_max <= 0), (
f"psf_depths values must be consistently signed (all <= 0 or all >= 0), "
f"got min={psf_depths_min}, max={psf_depths_max}"
)

Copilot uses AI. Check for mistakes.

B, C, H, W = img.shape
grid_h, grid_w, num_depths, C_psf, ks, _ = psf_map.shape
assert C_psf == C, f"PSF map channels ({C_psf}) must match image channels ({C})."

# Render image patch by patch, reusing conv_psf_depth_interp for each patch

# Pad the full image once to avoid boundary artifacts at patch seams.
# Without this, each patch would be padded independently (reflecting within
# its own boundary), producing visible seams at grid boundaries.
pad_h_left = (ks - 1) // 2
pad_h_right = ks // 2
pad_w_left = (ks - 1) // 2
pad_w_right = ks // 2
img_pad = F.pad(img, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect")

# Pre-flip entire PSF map once: [grid_h, grid_w, num_depths, C, ks, ks]
psf_map_flipped = torch.flip(psf_map, dims=(-2, -1))

# Pre-compute depth interpolation weights (shared across all patches)
depth_flat = depth.flatten(1) # [B, H*W]
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The clamping logic assumes psf_depths is sorted in ascending order (most negative to least negative). However, the function doesn't validate this assumption. If psf_depths is passed in descending order or unsorted, the clamping and searchsorted operations will produce incorrect results. Consider adding an assertion to verify that psf_depths is sorted, or use min() and max() instead of indexing to be order-agnostic.

Suggested change
depth_flat = depth.flatten(1) # [B, H*W]
depth_flat = depth.flatten(1) # [B, H*W]
# Validate psf_depths ordering: required for clamp and searchsorted to behave correctly.
if psf_depths.ndim != 1:
raise ValueError("psf_depths must be a 1D tensor of sorted depths.")
if not torch.all(psf_depths[1:] >= psf_depths[:-1]):
raise ValueError(
"psf_depths must be sorted in ascending order (most negative to least negative)."
)

Copilot uses AI. Check for mistakes.
depth_flat = depth_flat.clamp(psf_depths[0] + DELTA, psf_depths[-1] - DELTA)
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The depth clamping at line 157 uses DELTA (1e-6) offset from the psf_depths boundaries. However, if psf_depths values are very close together (e.g., closer than 2*DELTA), this clamping could potentially clamp depth values outside the actual range of psf_depths. While unlikely in practice, consider validating that psf_depths has adequate spacing or using a relative epsilon instead of a fixed DELTA.

Suggested change
depth_flat = depth_flat.clamp(psf_depths[0] + DELTA, psf_depths[-1] - DELTA)
# Use a spacing-aware epsilon for clamping instead of a fixed DELTA.
# This avoids pushing values outside the actual sampled psf_depths range
# when depths are very closely spaced.
if num_depths > 1:
spacings = psf_depths[1:] - psf_depths[:-1]
min_spacing = torch.min(spacings)
eps = 0.5 * torch.clamp(min_spacing, min=0.0)
else:
eps = torch.zeros((), dtype=psf_depths.dtype, device=psf_depths.device)
depth_flat = depth_flat.clamp(psf_depths[0] + eps, psf_depths[-1] - eps)

Copilot uses AI. Check for mistakes.
indices = torch.searchsorted(psf_depths, depth_flat, right=True) # [B, H*W]
indices = indices.clamp(1, num_depths - 1)
idx0 = indices - 1
idx1 = indices

d0 = psf_depths[idx0] # [B, H*W]
d1 = psf_depths[idx1]

if interp_mode == "depth":
denom = d1 - d0
denom[denom == 0] = 1e-6
w1 = (depth_flat - d0) / denom
else:
disp_flat = 1.0 / depth_flat
disp0 = 1.0 / d0
disp1 = 1.0 / d1
denom = disp1 - disp0
denom[denom == 0] = 1e-6
w1 = (disp_flat - disp0) / denom

w0 = 1 - w1

# Reshape weight indices to spatial layout for patch extraction
idx0_spatial = idx0.view(B, H, W)
idx1_spatial = idx1.view(B, H, W)
w0_spatial = w0.view(B, H, W)
w1_spatial = w1.view(B, H, W)

# Render image patch by patch
img_render = torch.zeros_like(img)
for i in range(grid_h):
h_low = (i * H) // grid_h
h_high = ((i + 1) * H) // grid_h
patch_h = h_high - h_low

for j in range(grid_w):
for j in range(grid_w):
w_low = (j * W) // grid_w
w_high = ((j + 1) * W) // grid_w
patch_w = w_high - w_low

# Extract overlapping patch from pre-padded image (no per-patch padding needed)
img_pad_patch = img_pad[
:, :,
h_low : h_high + pad_h_left + pad_h_right,
w_low : w_high + pad_w_left + pad_w_right,
]

# Expand patch for all depths: [B, C, patch_h+pad, patch_w+pad] -> [B, num_depths*C, ...]
img_patch_expanded = img_pad_patch.unsqueeze(1).expand(B, num_depths, C, -1, -1).reshape(
B, num_depths * C, img_pad_patch.shape[2], img_pad_patch.shape[3]
)

# PSF kernels for this grid cell: [num_depths*C, 1, ks, ks]
psf_stacked = psf_map_flipped[i, j].reshape(num_depths * C, 1, ks, ks)

# Grouped convolution -> [B, num_depths*C, patch_h, patch_w]
patch_blur = F.conv2d(img_patch_expanded, psf_stacked, groups=num_depths * C)

# Reshape to [num_depths, B, C, patch_h, patch_w]
patch_blur = patch_blur.reshape(B, num_depths, C, patch_h, patch_w).permute(1, 0, 2, 3, 4)

# Extract pre-computed weights for this patch
patch_idx0 = idx0_spatial[:, h_low:h_high, w_low:w_high].reshape(B, patch_h * patch_w)
patch_idx1 = idx1_spatial[:, h_low:h_high, w_low:w_high].reshape(B, patch_h * patch_w)
patch_w0 = w0_spatial[:, h_low:h_high, w_low:w_high].reshape(B, patch_h * patch_w)
patch_w1 = w1_spatial[:, h_low:h_high, w_low:w_high].reshape(B, patch_h * patch_w)

# Build per-depth weight tensor for this patch
weights = torch.zeros(num_depths, B, patch_h * patch_w, device=img.device, dtype=img.dtype)
weights.scatter_add_(0, patch_idx0.unsqueeze(0).long(), patch_w0.unsqueeze(0))
weights.scatter_add_(0, patch_idx1.unsqueeze(0).long(), patch_w1.unsqueeze(0))
weights = weights.view(num_depths, B, 1, patch_h, patch_w)

# Extract image and depth patches
img_patch = img[:, :, h_low:h_high, w_low:w_high]
depth_patch = depth[:, :, h_low:h_high, w_low:w_high]

# Extract PSF kernels for this patch at all depths
psf_kernels = psf_map[i, j, :, :, :, :] # [num_depths, C, ks, ks]

# Use conv_psf_depth_interp for this patch
render_patch = conv_psf_depth_interp(img_patch, depth_patch, psf_kernels, psf_depths, interp_mode)
# Apply depth-interpolation weights
render_patch = torch.sum(patch_blur * weights, dim=0)
img_render[:, :, h_low:h_high, w_low:w_high] = render_patch

return img_render


Expand All @@ -164,15 +242,17 @@ def conv_psf_depth_interp(img, depth, psf_kernels, psf_depths, interp_mode="dept

Args:
img: (B, 3, H, W), [0, 1]
depth: (B, 1, H, W), [0, 1]
depth: (B, 1, H, W), (-inf, 0)
psf_kernels: (num_depth, 3, ks, ks)
psf_depths: (num_depth). Used to interpolate psf_kernels.
psf_depths: (num_depth). (-inf, 0). Used to interpolate psf_kernels.
interp_mode: "depth" or "disparity". If "disparity", weights are calculated based on disparity (1/depth).

Returns:
img_blur: (B, 3, H, W), [0, 1]
"""
assert interp_mode in ["depth", "disparity"], f"interp_mode must be 'depth' or 'disparity', got {interp_mode}"
assert depth.min() < 0 and depth.max() < 0, f"depth must be negative, got {depth.min()} and {depth.max()}"
assert psf_depths.min() < 0 and psf_depths.max() < 0, f"psf_depths must be negative, got {psf_depths.min()} and {psf_depths.max()}"
Comment on lines +254 to +255
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new assertions requiring negative depth values will break existing tests. The tests in test/test_psf.py (lines 128-181) use positive depth values (e.g., torch.linspace(0, 1, 5) at line 138, torch.linspace(1.0, 2.0, 5) at line 166). These tests will fail with the new assertions. Consider either updating the tests to use negative depth values or relaxing the assertions to allow both positive and negative depths with appropriate internal conversion.

Copilot uses AI. Check for mistakes.

# assert img.device != torch.device("cpu"), "Image must be on GPU"
num_depths, _, ks, _ = psf_kernels.shape
Expand All @@ -182,19 +262,20 @@ def conv_psf_depth_interp(img, depth, psf_kernels, psf_depths, interp_mode="dept
# =================================
B, C, H, W = img.shape

# Expand img: [B, C, H, W] -> [B, num_depths, C, H, W] -> [B, num_depths*C, H, W]
img_expanded = img.unsqueeze(1).expand(B, num_depths, C, H, W).reshape(B, num_depths * C, H, W)

# Prepare PSF kernel: [num_depths, C, ks, ks] -> [num_depths*C, 1, ks, ks]
# Flip the PSF because F.conv2d uses cross-correlation
psf_stacked = torch.flip(psf_kernels, [-2, -1]).reshape(num_depths * C, 1, ks, ks)

# Padding (following conv_psf logic for even/odd kernel sizes)

# Pad before expand: pad [B, C, H, W] first (C channels), then expand to num_depths*C
# This reduces padding work by a factor of num_depths
pad_h_left = (ks - 1) // 2
pad_h_right = ks // 2
pad_w_left = (ks - 1) // 2
pad_w_right = ks // 2
img_padded = F.pad(img_expanded, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect")
img_padded_small = F.pad(img, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect")

# Expand padded img: [B, C, H+pad, W+pad] -> [B, num_depths*C, H+pad, W+pad]
img_padded = img_padded_small.unsqueeze(1).expand(B, num_depths, C, -1, -1).reshape(B, num_depths * C, img_padded_small.shape[2], img_padded_small.shape[3])

# Grouped convolution: each of the num_depths*C channels is convolved with its own kernel
imgs_blur = F.conv2d(img_padded, psf_stacked, groups=num_depths * C) # [B, num_depths*C, H, W]
Expand All @@ -207,7 +288,7 @@ def conv_psf_depth_interp(img, depth, psf_kernels, psf_depths, interp_mode="dept
# =================================
B, _, H, W = depth.shape
depth_flat = depth.flatten(1) # shape [B, H*W]
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The clamping logic assumes psf_depths is sorted in ascending order (most negative to least negative). However, the function doesn't validate this assumption. If psf_depths is passed in descending order or unsorted, the clamping and searchsorted operations will produce incorrect results. Consider adding an assertion to verify that psf_depths is sorted, or use min() and max() instead of indexing to be order-agnostic.

Suggested change
depth_flat = depth.flatten(1) # shape [B, H*W]
depth_flat = depth.flatten(1) # shape [B, H*W]
# Ensure psf_depths is a 1D tensor sorted in ascending order, as required by
# the clamping logic and torch.searchsorted below.
assert psf_depths.dim() == 1, f"psf_depths must be 1D, got shape {tuple(psf_depths.shape)}"
assert torch.all(psf_depths[1:] >= psf_depths[:-1]), "psf_depths must be sorted in ascending order"

Copilot uses AI. Check for mistakes.
depth_flat = depth_flat.clamp(min(psf_depths) + DELTA, max(psf_depths) - DELTA)
depth_flat = depth_flat.clamp(psf_depths[0] + DELTA, psf_depths[-1] - DELTA)
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The depth clamping at line 291 uses DELTA (1e-6) offset from the psf_depths boundaries. However, if psf_depths values are very close together (e.g., closer than 2*DELTA), this clamping could potentially clamp depth values outside the actual range of psf_depths. While unlikely in practice, consider validating that psf_depths has adequate spacing or using a relative epsilon instead of a fixed DELTA.

Suggested change
depth_flat = depth_flat.clamp(psf_depths[0] + DELTA, psf_depths[-1] - DELTA)
# Use an epsilon that respects the actual spacing of psf_depths.
# This ensures we don't clamp outside the effective range when depths are very close.
eps = DELTA
if psf_depths.numel() > 1:
diffs = psf_depths[1:] - psf_depths[:-1]
# Use at most half the minimum positive spacing, capped by DELTA.
min_diff = torch.min(torch.abs(diffs))
max_eps = (min_diff * 0.5).item()
if max_eps > 0:
eps = min(DELTA, max_eps)
depth_flat = depth_flat.clamp(psf_depths[0] + eps, psf_depths[-1] - eps)

Copilot uses AI. Check for mistakes.
indices = torch.searchsorted(psf_depths, depth_flat, right=True) # shape [B, H*W]
indices = indices.clamp(1, num_depths - 1)
idx0 = indices - 1
Expand Down Expand Up @@ -432,7 +513,11 @@ def interp_psf_map(psf_map, grid_old, grid_new):
) # .reshape(grid_old, grid_old, C, ks, ks)
elif len(psf_map.shape) == 5:
# [grid_old, grid_old, C, ks, ks]
grid_old, grid_old, C, ks, ks = psf_map.shape
grid_h, grid_w, C, ks_h, ks_w = psf_map.shape
assert grid_h == grid_w, f"PSF map grid must be square, got {grid_h}x{grid_w}"
assert ks_h == ks_w, f"PSF kernel must be square, got {ks_h}x{ks_w}"
Comment on lines +517 to +518
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new assertions at lines 517-518 enforce that PSF map grids and kernels must be square. While this is reasonable for most cases, the error messages could be more helpful by suggesting what the user should do. Consider adding guidance like 'PSF map grid must be square. If you have a non-square grid, please crop or pad it to square dimensions before calling this function.'

Suggested change
assert grid_h == grid_w, f"PSF map grid must be square, got {grid_h}x{grid_w}"
assert ks_h == ks_w, f"PSF kernel must be square, got {ks_h}x{ks_w}"
assert grid_h == grid_w, (
f"PSF map grid must be square, got {grid_h}x{grid_w}. "
"If you have a non-square grid, please crop or pad it to square "
"dimensions before calling interp_psf_map()."
)
assert ks_h == ks_w, (
f"PSF kernel must be square, got {ks_h}x{ks_w}. "
"If you have a non-square kernel, please crop or pad it to square "
"dimensions before calling interp_psf_map()."
)

Copilot uses AI. Check for mistakes.
grid_old = grid_h
ks = ks_h
psf_map_interp = psf_map
else:
raise ValueError(
Expand Down
Loading