-
-
Notifications
You must be signed in to change notification settings - Fork 94
[Fix] Image simulation bugs #103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
5f56ef9
a22d365
d9fc82d
c2a9559
d9c44fe
b88fe7c
607d3c1
8243ce8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -504,6 +504,58 @@ def render_psf_map(self, img_obj, depth=DEPTH, psf_grid=7, psf_ks=PSF_KS): | |||||||||||||||||||||||||||||||||||||||||
| # ------------------------------------------- | ||||||||||||||||||||||||||||||||||||||||||
| # Simulate 3D scene | ||||||||||||||||||||||||||||||||||||||||||
| # ------------------------------------------- | ||||||||||||||||||||||||||||||||||||||||||
| def _sample_depth_layers(self, depth_min, depth_max, num_depth): | ||||||||||||||||||||||||||||||||||||||||||
| """Sample depth layers centered on the focal plane in disparity space. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| If the lens has a `calc_focal_plane` method, samples are split around the | ||||||||||||||||||||||||||||||||||||||||||
| focal plane so that it is always an explicit sample point. Otherwise falls | ||||||||||||||||||||||||||||||||||||||||||
| back to uniform disparity sampling. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||
| depth_min (float): Minimum (nearest) depth in mm (positive). | ||||||||||||||||||||||||||||||||||||||||||
| depth_max (float): Maximum (farthest) depth in mm (positive). | ||||||||||||||||||||||||||||||||||||||||||
| num_depth (int): Number of depth layers to sample. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||
| tuple: (disp_ref, depths_ref) where disp_ref has shape (num_depth,) in | ||||||||||||||||||||||||||||||||||||||||||
| disparity space and depths_ref = -1/disp_ref (negative, for PSF). | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| # Try to get focal depth from the lens | ||||||||||||||||||||||||||||||||||||||||||
| if hasattr(self, 'calc_focal_plane'): | ||||||||||||||||||||||||||||||||||||||||||
| focal_depth = abs(self.calc_focal_plane()) # positive mm | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| focal_depth = None | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if focal_depth is not None: | ||||||||||||||||||||||||||||||||||||||||||
| # Extend range to include the focal depth | ||||||||||||||||||||||||||||||||||||||||||
| depth_min_ext = min(float(depth_min), focal_depth) | ||||||||||||||||||||||||||||||||||||||||||
| depth_max_ext = max(float(depth_max), focal_depth) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| disp_near = 1.0 / depth_min_ext # large disparity = near | ||||||||||||||||||||||||||||||||||||||||||
| disp_far = 1.0 / depth_max_ext # small disparity = far | ||||||||||||||||||||||||||||||||||||||||||
| focal_disp = 1.0 / focal_depth | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Allocate samples proportionally to range on each side | ||||||||||||||||||||||||||||||||||||||||||
| near_range = disp_near - focal_disp | ||||||||||||||||||||||||||||||||||||||||||
| far_range = focal_disp - disp_far | ||||||||||||||||||||||||||||||||||||||||||
| total_range = near_range + far_range | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if total_range < 1e-10: | ||||||||||||||||||||||||||||||||||||||||||
| disp_ref = torch.full((num_depth,), focal_disp).to(self.device) | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| n_far = max(1, round((num_depth - 1) * far_range / total_range)) | ||||||||||||||||||||||||||||||||||||||||||
| n_near = num_depth - 1 - n_far | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| far_disps = torch.linspace(disp_far, focal_disp, n_far + 1) # includes focal | ||||||||||||||||||||||||||||||||||||||||||
| near_disps = torch.linspace(focal_disp, disp_near, n_near + 1)[1:] # exclude duplicate focal | ||||||||||||||||||||||||||||||||||||||||||
| disp_ref = torch.cat([far_disps, near_disps]).to(self.device) | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+544
to
+551
|
||||||||||||||||||||||||||||||||||||||||||
| disp_ref = torch.full((num_depth,), focal_disp).to(self.device) | |
| else: | |
| n_far = max(1, round((num_depth - 1) * far_range / total_range)) | |
| n_near = num_depth - 1 - n_far | |
| far_disps = torch.linspace(disp_far, focal_disp, n_far + 1) # includes focal | |
| near_disps = torch.linspace(focal_disp, disp_near, n_near + 1)[1:] # exclude duplicate focal | |
| disp_ref = torch.cat([far_disps, near_disps]).to(self.device) | |
| disp_ref = torch.full((num_depth,), focal_disp, dtype=self.dtype, device=self.device) | |
| else: | |
| # Special case: with only one depth sample, place it at the focal plane. | |
| if num_depth == 1: | |
| disp_ref = torch.full((1,), focal_disp, dtype=self.dtype, device=self.device) | |
| else: | |
| n_far = max(1, round((num_depth - 1) * far_range / total_range)) | |
| n_near = num_depth - 1 - n_far | |
| far_disps = torch.linspace(disp_far, focal_disp, n_far + 1) # includes focal | |
| near_disps = torch.linspace(focal_disp, disp_near, n_near + 1)[1:] # exclude duplicate focal | |
| disp_ref = torch.cat([far_disps, near_disps]).to(self.device) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -88,6 +88,9 @@ def conv_psf_map(img, psf_map): | |||||||||||||||||||||||||||
| pad_w_right = ks // 2 | ||||||||||||||||||||||||||||
| img_pad = F.pad(img, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Pre-flip entire PSF map once (instead of flipping each PSF inside the loop) | ||||||||||||||||||||||||||||
| psf_map_flipped = torch.flip(psf_map, dims=(-2, -1)) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Render image patch by patch | ||||||||||||||||||||||||||||
| img_render = torch.zeros_like(img) | ||||||||||||||||||||||||||||
| for i in range(grid_h): | ||||||||||||||||||||||||||||
|
|
@@ -99,7 +102,7 @@ def conv_psf_map(img, psf_map): | |||||||||||||||||||||||||||
| w_high = ((j + 1) * W) // grid_w | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # PSF, [C, 1, ks, ks] | ||||||||||||||||||||||||||||
| psf = torch.flip(psf_map[i, j], dims=(-2, -1)).unsqueeze(1) | ||||||||||||||||||||||||||||
| psf = psf_map_flipped[i, j].unsqueeze(1) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Consider overlap to avoid boundary artifacts | ||||||||||||||||||||||||||||
| img_pad_patch = img_pad[ | ||||||||||||||||||||||||||||
|
|
@@ -121,39 +124,114 @@ def conv_psf_map_depth_interp(img, depth, psf_map, psf_depths, interp_mode="dept | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||
| img: (B, 3, H, W), [0, 1] | ||||||||||||||||||||||||||||
| depth: (B, 1, H, W), [0, 1] | ||||||||||||||||||||||||||||
| depth: (B, 1, H, W), (-inf, 0) | ||||||||||||||||||||||||||||
| psf_map: (grid_h, grid_w, num_depth, 3, ks, ks) | ||||||||||||||||||||||||||||
| psf_depths: (num_depth). Used to interpolate psf_map. | ||||||||||||||||||||||||||||
| psf_depths: (num_depth). (-inf, 0). Used to interpolate psf_map. | ||||||||||||||||||||||||||||
| interp_mode: "depth" or "disparity". If "disparity", weights are calculated based on disparity (1/depth). | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||
| img_render: (B, 3, H, W), [0, 1] | ||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||
| assert interp_mode in ["depth", "disparity"], f"interp_mode must be 'depth' or 'disparity', got {interp_mode}" | ||||||||||||||||||||||||||||
| assert depth.min() < 0 and depth.max() < 0, f"depth must be negative, got {depth.min()} and {depth.max()}" | ||||||||||||||||||||||||||||
| assert psf_depths.min() < 0 and psf_depths.max() < 0, f"psf_depths must be negative, got {psf_depths.min()} and {psf_depths.max()}" | ||||||||||||||||||||||||||||
|
Comment on lines
+136
to
+137
|
||||||||||||||||||||||||||||
| assert depth.min() < 0 and depth.max() < 0, f"depth must be negative, got {depth.min()} and {depth.max()}" | |
| assert psf_depths.min() < 0 and psf_depths.max() < 0, f"psf_depths must be negative, got {psf_depths.min()} and {psf_depths.max()}" | |
| depth_min = depth.min().item() | |
| depth_max = depth.max().item() | |
| psf_depths_min = psf_depths.min().item() | |
| psf_depths_max = psf_depths.max().item() | |
| assert (depth_min >= 0 and depth_max >= 0) or (depth_min <= 0 and depth_max <= 0), ( | |
| f"depth values must be consistently signed (all <= 0 or all >= 0), got min={depth_min}, max={depth_max}" | |
| ) | |
| assert (psf_depths_min >= 0 and psf_depths_max >= 0) or (psf_depths_min <= 0 and psf_depths_max <= 0), ( | |
| f"psf_depths values must be consistently signed (all <= 0 or all >= 0), " | |
| f"got min={psf_depths_min}, max={psf_depths_max}" | |
| ) |
Copilot
AI
Feb 6, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The clamping logic assumes psf_depths is sorted in ascending order (most negative to least negative). However, the function doesn't validate this assumption. If psf_depths is passed in descending order or unsorted, the clamping and searchsorted operations will produce incorrect results. Consider adding an assertion to verify that psf_depths is sorted, or use min() and max() instead of indexing to be order-agnostic.
| depth_flat = depth.flatten(1) # [B, H*W] | |
| depth_flat = depth.flatten(1) # [B, H*W] | |
| # Validate psf_depths ordering: required for clamp and searchsorted to behave correctly. | |
| if psf_depths.ndim != 1: | |
| raise ValueError("psf_depths must be a 1D tensor of sorted depths.") | |
| if not torch.all(psf_depths[1:] >= psf_depths[:-1]): | |
| raise ValueError( | |
| "psf_depths must be sorted in ascending order (most negative to least negative)." | |
| ) |
Copilot
AI
Feb 6, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The depth clamping at line 157 uses DELTA (1e-6) offset from the psf_depths boundaries. However, if psf_depths values are very close together (e.g., closer than 2*DELTA), this clamping could potentially clamp depth values outside the actual range of psf_depths. While unlikely in practice, consider validating that psf_depths has adequate spacing or using a relative epsilon instead of a fixed DELTA.
| depth_flat = depth_flat.clamp(psf_depths[0] + DELTA, psf_depths[-1] - DELTA) | |
| # Use a spacing-aware epsilon for clamping instead of a fixed DELTA. | |
| # This avoids pushing values outside the actual sampled psf_depths range | |
| # when depths are very closely spaced. | |
| if num_depths > 1: | |
| spacings = psf_depths[1:] - psf_depths[:-1] | |
| min_spacing = torch.min(spacings) | |
| eps = 0.5 * torch.clamp(min_spacing, min=0.0) | |
| else: | |
| eps = torch.zeros((), dtype=psf_depths.dtype, device=psf_depths.device) | |
| depth_flat = depth_flat.clamp(psf_depths[0] + eps, psf_depths[-1] - eps) |
Copilot
AI
Feb 6, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new assertions requiring negative depth values will break existing tests. The tests in test/test_psf.py (lines 128-181) use positive depth values (e.g., torch.linspace(0, 1, 5) at line 138, torch.linspace(1.0, 2.0, 5) at line 166). These tests will fail with the new assertions. Consider either updating the tests to use negative depth values or relaxing the assertions to allow both positive and negative depths with appropriate internal conversion.
Copilot
AI
Feb 6, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The clamping logic assumes psf_depths is sorted in ascending order (most negative to least negative). However, the function doesn't validate this assumption. If psf_depths is passed in descending order or unsorted, the clamping and searchsorted operations will produce incorrect results. Consider adding an assertion to verify that psf_depths is sorted, or use min() and max() instead of indexing to be order-agnostic.
| depth_flat = depth.flatten(1) # shape [B, H*W] | |
| depth_flat = depth.flatten(1) # shape [B, H*W] | |
| # Ensure psf_depths is a 1D tensor sorted in ascending order, as required by | |
| # the clamping logic and torch.searchsorted below. | |
| assert psf_depths.dim() == 1, f"psf_depths must be 1D, got shape {tuple(psf_depths.shape)}" | |
| assert torch.all(psf_depths[1:] >= psf_depths[:-1]), "psf_depths must be sorted in ascending order" |
Copilot
AI
Feb 6, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The depth clamping at line 291 uses DELTA (1e-6) offset from the psf_depths boundaries. However, if psf_depths values are very close together (e.g., closer than 2*DELTA), this clamping could potentially clamp depth values outside the actual range of psf_depths. While unlikely in practice, consider validating that psf_depths has adequate spacing or using a relative epsilon instead of a fixed DELTA.
| depth_flat = depth_flat.clamp(psf_depths[0] + DELTA, psf_depths[-1] - DELTA) | |
| # Use an epsilon that respects the actual spacing of psf_depths. | |
| # This ensures we don't clamp outside the effective range when depths are very close. | |
| eps = DELTA | |
| if psf_depths.numel() > 1: | |
| diffs = psf_depths[1:] - psf_depths[:-1] | |
| # Use at most half the minimum positive spacing, capped by DELTA. | |
| min_diff = torch.min(torch.abs(diffs)) | |
| max_eps = (min_diff * 0.5).item() | |
| if max_eps > 0: | |
| eps = min(DELTA, max_eps) | |
| depth_flat = depth_flat.clamp(psf_depths[0] + eps, psf_depths[-1] - eps) |
Copilot
AI
Feb 6, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new assertions at lines 517-518 enforce that PSF map grids and kernels must be square. While this is reasonable for most cases, the error messages could be more helpful by suggesting what the user should do. Consider adding guidance like 'PSF map grid must be square. If you have a non-square grid, please crop or pad it to square dimensions before calling this function.'
| assert grid_h == grid_w, f"PSF map grid must be square, got {grid_h}x{grid_w}" | |
| assert ks_h == ks_w, f"PSF kernel must be square, got {ks_h}x{ks_w}" | |
| assert grid_h == grid_w, ( | |
| f"PSF map grid must be square, got {grid_h}x{grid_w}. " | |
| "If you have a non-square grid, please crop or pad it to square " | |
| "dimensions before calling interp_psf_map()." | |
| ) | |
| assert ks_h == ks_w, ( | |
| f"PSF kernel must be square, got {ks_h}x{ks_w}. " | |
| "If you have a non-square kernel, please crop or pad it to square " | |
| "dimensions before calling interp_psf_map()." | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tensors created on lines 549-550 (far_disps and near_disps) are not explicitly moved to the correct device before concatenation. While line 551 does call
.to(self.device)on the concatenated result, it would be more efficient and clearer to create the tensors on the correct device initially. Consider usingtorch.linspace(...).to(self.device)on lines 549-550, or pass a device parameter to linspace if supported.