diff --git a/deeplens/geolens.py b/deeplens/geolens.py index 6d2a7cb..3526745 100644 --- a/deeplens/geolens.py +++ b/deeplens/geolens.py @@ -1611,10 +1611,27 @@ def calc_foclen(self): # Trace a paraxial chief ray, shape [1, 1, num_rays, 3] paraxial_fov = 0.01 paraxial_fov_deg = float(np.rad2deg(paraxial_fov)) + + # 1. Trace on-axis parallel rays to find paraxial focus z (equivalent to infinite focus) + ray_axis = self.sample_parallel( + fov_x=0.0, fov_y=0.0, entrance_pupil=False, scale_pupil=0.2 + ) + ray_axis, _ = self.trace(ray_axis) + valid_axis = ray_axis.is_valid > 0 + t = -(ray_axis.d[valid_axis, 0] * ray_axis.o[valid_axis, 0] + + ray_axis.d[valid_axis, 1] * ray_axis.o[valid_axis, 1]) / ( + ray_axis.d[valid_axis, 0] ** 2 + ray_axis.d[valid_axis, 1] ** 2 + ) + focus_z = ray_axis.o[valid_axis, 2] + t * ray_axis.d[valid_axis, 2] + focus_z = focus_z[~torch.isnan(focus_z) & (focus_z > 0)] + paraxial_focus_z = float(torch.mean(focus_z)) + + # 2. Trace off-axis paraxial ray to paraxial focus, measure image height ray = self.sample_parallel( fov_x=0.0, fov_y=paraxial_fov_deg, entrance_pupil=False, scale_pupil=0.2 ) - ray = self.trace2sensor(ray) + ray, _ = self.trace(ray) + ray = ray.prop_to(paraxial_focus_z) # Compute the effective focal length paraxial_imgh = (ray.o[:, 1] * ray.is_valid).sum() / ray.is_valid.sum() @@ -1625,6 +1642,8 @@ def calc_foclen(self): # Compute the back focal length self.bfl = self.d_sensor.item() - self.surfaces[-1].d.item() + return eff_foclen + @torch.no_grad() def calc_numerical_aperture(self, n=1.0): """Compute numerical aperture (NA). @@ -1638,7 +1657,6 @@ def calc_numerical_aperture(self, n=1.0): Reference: [1] https://en.wikipedia.org/wiki/Numerical_aperture """ - breakpoint() return n * math.sin(math.atan(1 / 2 / self.fnum)) # return n / (2 * self.fnum) diff --git a/deeplens/lens.py b/deeplens/lens.py index d186ae4..f4ba587 100644 --- a/deeplens/lens.py +++ b/deeplens/lens.py @@ -504,6 +504,58 @@ def render_psf_map(self, img_obj, depth=DEPTH, psf_grid=7, psf_ks=PSF_KS): # ------------------------------------------- # Simulate 3D scene # ------------------------------------------- + def _sample_depth_layers(self, depth_min, depth_max, num_depth): + """Sample depth layers centered on the focal plane in disparity space. + + If the lens has a `calc_focal_plane` method, samples are split around the + focal plane so that it is always an explicit sample point. Otherwise falls + back to uniform disparity sampling. + + Args: + depth_min (float): Minimum (nearest) depth in mm (positive). + depth_max (float): Maximum (farthest) depth in mm (positive). + num_depth (int): Number of depth layers to sample. + + Returns: + tuple: (disp_ref, depths_ref) where disp_ref has shape (num_depth,) in + disparity space and depths_ref = -1/disp_ref (negative, for PSF). + """ + # Try to get focal depth from the lens + if hasattr(self, 'calc_focal_plane'): + focal_depth = abs(self.calc_focal_plane()) # positive mm + else: + focal_depth = None + + if focal_depth is not None: + # Extend range to include the focal depth + depth_min_ext = min(float(depth_min), focal_depth) + depth_max_ext = max(float(depth_max), focal_depth) + + disp_near = 1.0 / depth_min_ext # large disparity = near + disp_far = 1.0 / depth_max_ext # small disparity = far + focal_disp = 1.0 / focal_depth + + # Allocate samples proportionally to range on each side + near_range = disp_near - focal_disp + far_range = focal_disp - disp_far + total_range = near_range + far_range + + if total_range < 1e-10: + disp_ref = torch.full((num_depth,), focal_disp).to(self.device) + else: + n_far = max(1, round((num_depth - 1) * far_range / total_range)) + n_near = num_depth - 1 - n_far + + far_disps = torch.linspace(disp_far, focal_disp, n_far + 1) # includes focal + near_disps = torch.linspace(focal_disp, disp_near, n_near + 1)[1:] # exclude duplicate focal + disp_ref = torch.cat([far_disps, near_disps]).to(self.device) + else: + # Fallback: uniform disparity sampling + disp_ref = torch.linspace(1.0 / float(depth_max), 1.0 / float(depth_min), num_depth).to(self.device) + + depths_ref = -1.0 / disp_ref + return disp_ref, depths_ref + def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs): """Render RGBD image. @@ -538,8 +590,7 @@ def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs): interp_mode = kwargs.get("interp_mode", "disparity") # Calculate PSF at different depths, (num_depth, 3, ks, ks) - disp_ref = torch.linspace(1.0/depth_max, 1.0/depth_min, num_depth).to(self.device) - depths_ref = - 1.0 / disp_ref # Convert to negative for PSF calculation + disp_ref, depths_ref = self._sample_depth_layers(depth_min, depth_max, num_depth) points = torch.stack( [ @@ -552,7 +603,7 @@ def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs): psfs = self.psf_rgb(points=points, ks=psf_ks) # (num_depth, 3, ks, ks) # Image simulation - img_render = conv_psf_depth_interp(img_obj, depth_map, psfs, depths_ref, interp_mode=interp_mode) + img_render = conv_psf_depth_interp(img_obj, -depth_map, psfs, depths_ref, interp_mode=interp_mode) return img_render elif method == "psf_map": @@ -565,8 +616,7 @@ def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs): interp_mode = kwargs.get("interp_mode", "disparity") # Calculate PSF map at different depths (convert to negative for PSF calculation) - disp_ref = torch.linspace(1.0/depth_max, 1.0/depth_min, num_depth).to(self.device) - depths_ref = -1.0 / disp_ref # Convert to negative for PSF calculation + disp_ref, depths_ref = self._sample_depth_layers(depth_min, depth_max, num_depth) psf_maps = [] for depth in tqdm(depths_ref): @@ -578,7 +628,7 @@ def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs): # Image simulation img_render = conv_psf_map_depth_interp( - img_obj, depth_map, psf_map, depths_ref, interp_mode=interp_mode + img_obj, -depth_map, psf_map, depths_ref, interp_mode=interp_mode ) return img_render diff --git a/deeplens/optics/psf.py b/deeplens/optics/psf.py index 4594ad1..fb450d5 100644 --- a/deeplens/optics/psf.py +++ b/deeplens/optics/psf.py @@ -88,6 +88,9 @@ def conv_psf_map(img, psf_map): pad_w_right = ks // 2 img_pad = F.pad(img, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect") + # Pre-flip entire PSF map once (instead of flipping each PSF inside the loop) + psf_map_flipped = torch.flip(psf_map, dims=(-2, -1)) + # Render image patch by patch img_render = torch.zeros_like(img) for i in range(grid_h): @@ -99,7 +102,7 @@ def conv_psf_map(img, psf_map): w_high = ((j + 1) * W) // grid_w # PSF, [C, 1, ks, ks] - psf = torch.flip(psf_map[i, j], dims=(-2, -1)).unsqueeze(1) + psf = psf_map_flipped[i, j].unsqueeze(1) # Consider overlap to avoid boundary artifacts img_pad_patch = img_pad[ @@ -121,39 +124,114 @@ def conv_psf_map_depth_interp(img, depth, psf_map, psf_depths, interp_mode="dept Args: img: (B, 3, H, W), [0, 1] - depth: (B, 1, H, W), [0, 1] + depth: (B, 1, H, W), (-inf, 0) psf_map: (grid_h, grid_w, num_depth, 3, ks, ks) - psf_depths: (num_depth). Used to interpolate psf_map. + psf_depths: (num_depth). (-inf, 0). Used to interpolate psf_map. interp_mode: "depth" or "disparity". If "disparity", weights are calculated based on disparity (1/depth). Returns: img_render: (B, 3, H, W), [0, 1] """ + assert interp_mode in ["depth", "disparity"], f"interp_mode must be 'depth' or 'disparity', got {interp_mode}" + assert depth.min() < 0 and depth.max() < 0, f"depth must be negative, got {depth.min()} and {depth.max()}" + assert psf_depths.min() < 0 and psf_depths.max() < 0, f"psf_depths must be negative, got {psf_depths.min()} and {psf_depths.max()}" + B, C, H, W = img.shape grid_h, grid_w, num_depths, C_psf, ks, _ = psf_map.shape assert C_psf == C, f"PSF map channels ({C_psf}) must match image channels ({C})." - - # Render image patch by patch, reusing conv_psf_depth_interp for each patch + + # Pad the full image once to avoid boundary artifacts at patch seams. + # Without this, each patch would be padded independently (reflecting within + # its own boundary), producing visible seams at grid boundaries. + pad_h_left = (ks - 1) // 2 + pad_h_right = ks // 2 + pad_w_left = (ks - 1) // 2 + pad_w_right = ks // 2 + img_pad = F.pad(img, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect") + + # Pre-flip entire PSF map once: [grid_h, grid_w, num_depths, C, ks, ks] + psf_map_flipped = torch.flip(psf_map, dims=(-2, -1)) + + # Pre-compute depth interpolation weights (shared across all patches) + depth_flat = depth.flatten(1) # [B, H*W] + depth_flat = depth_flat.clamp(psf_depths[0] + DELTA, psf_depths[-1] - DELTA) + indices = torch.searchsorted(psf_depths, depth_flat, right=True) # [B, H*W] + indices = indices.clamp(1, num_depths - 1) + idx0 = indices - 1 + idx1 = indices + + d0 = psf_depths[idx0] # [B, H*W] + d1 = psf_depths[idx1] + + if interp_mode == "depth": + denom = d1 - d0 + denom[denom == 0] = 1e-6 + w1 = (depth_flat - d0) / denom + else: + disp_flat = 1.0 / depth_flat + disp0 = 1.0 / d0 + disp1 = 1.0 / d1 + denom = disp1 - disp0 + denom[denom == 0] = 1e-6 + w1 = (disp_flat - disp0) / denom + + w0 = 1 - w1 + + # Reshape weight indices to spatial layout for patch extraction + idx0_spatial = idx0.view(B, H, W) + idx1_spatial = idx1.view(B, H, W) + w0_spatial = w0.view(B, H, W) + w1_spatial = w1.view(B, H, W) + + # Render image patch by patch img_render = torch.zeros_like(img) for i in range(grid_h): h_low = (i * H) // grid_h h_high = ((i + 1) * H) // grid_h + patch_h = h_high - h_low - for j in range(grid_w): + for j in range(grid_w): w_low = (j * W) // grid_w w_high = ((j + 1) * W) // grid_w + patch_w = w_high - w_low + + # Extract overlapping patch from pre-padded image (no per-patch padding needed) + img_pad_patch = img_pad[ + :, :, + h_low : h_high + pad_h_left + pad_h_right, + w_low : w_high + pad_w_left + pad_w_right, + ] + + # Expand patch for all depths: [B, C, patch_h+pad, patch_w+pad] -> [B, num_depths*C, ...] + img_patch_expanded = img_pad_patch.unsqueeze(1).expand(B, num_depths, C, -1, -1).reshape( + B, num_depths * C, img_pad_patch.shape[2], img_pad_patch.shape[3] + ) + + # PSF kernels for this grid cell: [num_depths*C, 1, ks, ks] + psf_stacked = psf_map_flipped[i, j].reshape(num_depths * C, 1, ks, ks) + + # Grouped convolution -> [B, num_depths*C, patch_h, patch_w] + patch_blur = F.conv2d(img_patch_expanded, psf_stacked, groups=num_depths * C) + + # Reshape to [num_depths, B, C, patch_h, patch_w] + patch_blur = patch_blur.reshape(B, num_depths, C, patch_h, patch_w).permute(1, 0, 2, 3, 4) + + # Extract pre-computed weights for this patch + patch_idx0 = idx0_spatial[:, h_low:h_high, w_low:w_high].reshape(B, patch_h * patch_w) + patch_idx1 = idx1_spatial[:, h_low:h_high, w_low:w_high].reshape(B, patch_h * patch_w) + patch_w0 = w0_spatial[:, h_low:h_high, w_low:w_high].reshape(B, patch_h * patch_w) + patch_w1 = w1_spatial[:, h_low:h_high, w_low:w_high].reshape(B, patch_h * patch_w) + + # Build per-depth weight tensor for this patch + weights = torch.zeros(num_depths, B, patch_h * patch_w, device=img.device, dtype=img.dtype) + weights.scatter_add_(0, patch_idx0.unsqueeze(0).long(), patch_w0.unsqueeze(0)) + weights.scatter_add_(0, patch_idx1.unsqueeze(0).long(), patch_w1.unsqueeze(0)) + weights = weights.view(num_depths, B, 1, patch_h, patch_w) - # Extract image and depth patches - img_patch = img[:, :, h_low:h_high, w_low:w_high] - depth_patch = depth[:, :, h_low:h_high, w_low:w_high] - - # Extract PSF kernels for this patch at all depths - psf_kernels = psf_map[i, j, :, :, :, :] # [num_depths, C, ks, ks] - - # Use conv_psf_depth_interp for this patch - render_patch = conv_psf_depth_interp(img_patch, depth_patch, psf_kernels, psf_depths, interp_mode) + # Apply depth-interpolation weights + render_patch = torch.sum(patch_blur * weights, dim=0) img_render[:, :, h_low:h_high, w_low:w_high] = render_patch - + return img_render @@ -164,15 +242,17 @@ def conv_psf_depth_interp(img, depth, psf_kernels, psf_depths, interp_mode="dept Args: img: (B, 3, H, W), [0, 1] - depth: (B, 1, H, W), [0, 1] + depth: (B, 1, H, W), (-inf, 0) psf_kernels: (num_depth, 3, ks, ks) - psf_depths: (num_depth). Used to interpolate psf_kernels. + psf_depths: (num_depth). (-inf, 0). Used to interpolate psf_kernels. interp_mode: "depth" or "disparity". If "disparity", weights are calculated based on disparity (1/depth). Returns: img_blur: (B, 3, H, W), [0, 1] """ assert interp_mode in ["depth", "disparity"], f"interp_mode must be 'depth' or 'disparity', got {interp_mode}" + assert depth.min() < 0 and depth.max() < 0, f"depth must be negative, got {depth.min()} and {depth.max()}" + assert psf_depths.min() < 0 and psf_depths.max() < 0, f"psf_depths must be negative, got {psf_depths.min()} and {psf_depths.max()}" # assert img.device != torch.device("cpu"), "Image must be on GPU" num_depths, _, ks, _ = psf_kernels.shape @@ -182,19 +262,20 @@ def conv_psf_depth_interp(img, depth, psf_kernels, psf_depths, interp_mode="dept # ================================= B, C, H, W = img.shape - # Expand img: [B, C, H, W] -> [B, num_depths, C, H, W] -> [B, num_depths*C, H, W] - img_expanded = img.unsqueeze(1).expand(B, num_depths, C, H, W).reshape(B, num_depths * C, H, W) - # Prepare PSF kernel: [num_depths, C, ks, ks] -> [num_depths*C, 1, ks, ks] # Flip the PSF because F.conv2d uses cross-correlation psf_stacked = torch.flip(psf_kernels, [-2, -1]).reshape(num_depths * C, 1, ks, ks) - - # Padding (following conv_psf logic for even/odd kernel sizes) + + # Pad before expand: pad [B, C, H, W] first (C channels), then expand to num_depths*C + # This reduces padding work by a factor of num_depths pad_h_left = (ks - 1) // 2 pad_h_right = ks // 2 pad_w_left = (ks - 1) // 2 pad_w_right = ks // 2 - img_padded = F.pad(img_expanded, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect") + img_padded_small = F.pad(img, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect") + + # Expand padded img: [B, C, H+pad, W+pad] -> [B, num_depths*C, H+pad, W+pad] + img_padded = img_padded_small.unsqueeze(1).expand(B, num_depths, C, -1, -1).reshape(B, num_depths * C, img_padded_small.shape[2], img_padded_small.shape[3]) # Grouped convolution: each of the num_depths*C channels is convolved with its own kernel imgs_blur = F.conv2d(img_padded, psf_stacked, groups=num_depths * C) # [B, num_depths*C, H, W] @@ -207,7 +288,7 @@ def conv_psf_depth_interp(img, depth, psf_kernels, psf_depths, interp_mode="dept # ================================= B, _, H, W = depth.shape depth_flat = depth.flatten(1) # shape [B, H*W] - depth_flat = depth_flat.clamp(min(psf_depths) + DELTA, max(psf_depths) - DELTA) + depth_flat = depth_flat.clamp(psf_depths[0] + DELTA, psf_depths[-1] - DELTA) indices = torch.searchsorted(psf_depths, depth_flat, right=True) # shape [B, H*W] indices = indices.clamp(1, num_depths - 1) idx0 = indices - 1 @@ -432,7 +513,11 @@ def interp_psf_map(psf_map, grid_old, grid_new): ) # .reshape(grid_old, grid_old, C, ks, ks) elif len(psf_map.shape) == 5: # [grid_old, grid_old, C, ks, ks] - grid_old, grid_old, C, ks, ks = psf_map.shape + grid_h, grid_w, C, ks_h, ks_w = psf_map.shape + assert grid_h == grid_w, f"PSF map grid must be square, got {grid_h}x{grid_w}" + assert ks_h == ks_w, f"PSF kernel must be square, got {ks_h}x{ks_w}" + grid_old = grid_h + ks = ks_h psf_map_interp = psf_map else: raise ValueError( diff --git a/docs/api/geolens.rst b/docs/api/geolens.rst index b25f57e..a9b34d5 100644 --- a/docs/api/geolens.rst +++ b/docs/api/geolens.rst @@ -126,7 +126,7 @@ Ray Sampling Grid Sampling ~~~~~~~~~~~~~ -.. py:method:: GeoLens.sample_grid_rays(depth=float("inf"), num_grid=(11, 11), num_rays=16384, wvln=0.58756180, uniform_fov=True, sample_more_off_axis=False, scale_pupil=1.0) +.. py:method:: GeoLens.sample_grid_rays(depth=float("inf"), num_grid=(11, 11), num_rays=16384, wvln=0.587, uniform_fov=True, sample_more_off_axis=False, scale_pupil=1.0) Sample grid rays from object space for PSF map or spot diagram analysis. @@ -147,7 +147,7 @@ Grid Sampling :return: Ray object with shape [num_grid[1], num_grid[0], num_rays, 3] :rtype: Ray -.. py:method:: GeoLens.sample_radial_rays(num_field=5, depth=float("inf"), num_rays=2048, wvln=0.589) +.. py:method:: GeoLens.sample_radial_rays(num_field=5, depth=float("inf"), num_rays=16384, wvln=0.587) Sample radial (meridional, y-direction) rays at different field angles. @@ -165,7 +165,7 @@ Grid Sampling Point Source Sampling ~~~~~~~~~~~~~~~~~~~~~ -.. py:method:: GeoLens.sample_from_points(points=[[0.0, 0.0, -10000.0]], num_rays=2048, wvln=0.589, scale_pupil=1.0) +.. py:method:: GeoLens.sample_from_points(points=[[0.0, 0.0, -10000.0]], num_rays=16384, wvln=0.587, scale_pupil=1.0) Sample rays from point sources at absolute 3D coordinates. @@ -183,7 +183,7 @@ Point Source Sampling Parallel & Angular Sampling ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. py:method:: GeoLens.sample_parallel(fov_x=[0.0], fov_y=[0.0], num_rays=512, wvln=0.589, entrance_pupil=True, depth=-1.0, scale_pupil=1.0) +.. py:method:: GeoLens.sample_parallel(fov_x=[0.0], fov_y=[0.0], num_rays=1024, wvln=0.587, entrance_pupil=True, depth=-1.0, scale_pupil=1.0) Sample parallel rays at given field angles. @@ -204,7 +204,7 @@ Parallel & Angular Sampling :return: Ray object with shape [len(fov_y), len(fov_x), num_rays, 3] :rtype: Ray -.. py:method:: GeoLens.sample_point_source(fov_x=[0.0], fov_y=[0.0], depth=-10000.0, num_rays=2048, wvln=0.589, entrance_pupil=True, scale_pupil=1.0) +.. py:method:: GeoLens.sample_point_source(fov_x=[0.0], fov_y=[0.0], depth=-20000.0, num_rays=16384, wvln=0.587, entrance_pupil=True, scale_pupil=1.0) Sample point source rays at given field angles and depth. @@ -228,7 +228,7 @@ Parallel & Angular Sampling Sensor Sampling ~~~~~~~~~~~~~~~ -.. py:method:: GeoLens.sample_sensor(spp=64, wvln=0.589, sub_pixel=False) +.. py:method:: GeoLens.sample_sensor(spp=64, wvln=0.587, sub_pixel=False) Sample backward rays from sensor pixels for ray-tracing rendering. @@ -257,7 +257,7 @@ Helper Methods :return: Sampled points with shape [*shape, 3] :rtype: torch.Tensor -.. py:method:: GeoLens.sample_ring_arm_rays(num_ring=8, num_arm=8, spp=2048, depth=-10000.0, wvln=0.589, scale_pupil=1.0, sample_more_off_axis=True) +.. py:method:: GeoLens.sample_ring_arm_rays(num_ring=8, num_arm=8, spp=2048, depth=-20000.0, wvln=0.587, scale_pupil=1.0, sample_more_off_axis=True) Sample rays using ring-arm pattern for optimization (from ``GeoLensOptim``). @@ -346,7 +346,7 @@ Image Rendering Main Rendering ~~~~~~~~~~~~~~ -.. py:method:: GeoLens.render(img_obj, depth=-10000.0, method="ray_tracing", **kwargs) +.. py:method:: GeoLens.render(img_obj, depth=-20000.0, method="ray_tracing", **kwargs) Differentiable image simulation through the lens. @@ -366,7 +366,7 @@ Main Rendering Ray Tracing Rendering ~~~~~~~~~~~~~~~~~~~~~ -.. py:method:: GeoLens.render_raytracing(img, depth=-10000.0, spp=64, vignetting=False) +.. py:method:: GeoLens.render_raytracing(img, depth=-20000.0, spp=64, vignetting=False) Render RGB image using ray tracing. @@ -381,7 +381,7 @@ Ray Tracing Rendering :return: Rendered image [N, 3, H, W] :rtype: torch.Tensor -.. py:method:: GeoLens.render_raytracing_mono(img, wvln, depth=-10000.0, spp=64, vignetting=False) +.. py:method:: GeoLens.render_raytracing_mono(img, wvln, depth=-20000.0, spp=64, vignetting=False) Render monochrome image using ray tracing. @@ -418,7 +418,7 @@ Ray Tracing Rendering Post-Processing ~~~~~~~~~~~~~~~ -.. py:method:: GeoLens.unwarp(img, depth=-10000.0, num_grid=128, crop=True, flip=True) +.. py:method:: GeoLens.unwarp(img, depth=-20000.0, num_grid=128, crop=True, flip=True) Unwarp rendered images to correct distortion. @@ -435,7 +435,7 @@ Post-Processing :return: Unwarped image [N, C, H, W] :rtype: torch.Tensor -.. py:method:: GeoLens.analysis_rendering(img_org, save_name=None, depth=-10000.0, spp=64, unwarp=False, noise=0.0, method="ray_tracing", show=False) +.. py:method:: GeoLens.analysis_rendering(img_org, save_name=None, depth=-20000.0, spp=64, unwarp=False, noise=0.0, method="ray_tracing", show=False) Render image and compute PSNR/SSIM for analysis. @@ -461,7 +461,7 @@ Post-Processing PSF Calculation --------------- -.. py:method:: GeoLens.psf(points, ks=64, wvln=0.589, spp=None, recenter=True, model="geometric") +.. py:method:: GeoLens.psf(points, ks=64, wvln=0.587, spp=None, recenter=True, model="geometric") Calculate Point Spread Function (PSF) using different models. @@ -480,7 +480,7 @@ PSF Calculation :return: PSF tensor [ks, ks] or [N, ks, ks] :rtype: torch.Tensor -.. py:method:: GeoLens.psf_geometric(points, ks=64, wvln=0.589, spp=2048, recenter=True) +.. py:method:: GeoLens.psf_geometric(points, ks=64, wvln=0.587, spp=16384, recenter=True) Calculate incoherent geometric PSF using ray tracing. @@ -497,23 +497,23 @@ PSF Calculation :return: PSF tensor :rtype: torch.Tensor -.. py:method:: GeoLens.psf_coherent(points, ks=64, wvln=0.589, spp=1000000, recenter=True) +.. py:method:: GeoLens.psf_coherent(points, ks=64, wvln=0.587, spp=16777216, recenter=True) Calculate coherent PSF by propagating pupil field to sensor (Ray-Wave model). Alias for ``psf_pupil_prop``. -.. py:method:: GeoLens.psf_pupil_prop(points, ks=64, wvln=0.589, spp=1000000, recenter=True) +.. py:method:: GeoLens.psf_pupil_prop(points, ks=64, wvln=0.587, spp=16777216, recenter=True) Calculate coherent PSF by propagating pupil field to sensor using ASM. :param points: Point source positions :param ks: Kernel size :param wvln: Wavelength - :param spp: Sample rays (typically 1M) + :param spp: Sample rays (typically 16777216, ~16.8M, 2^24) :param recenter: Recenter PSF :return: PSF patch :rtype: torch.Tensor -.. py:method:: GeoLens.psf_huygens(points, ks=64, wvln=0.589, spp=1000000, recenter=True) +.. py:method:: GeoLens.psf_huygens(points, ks=64, wvln=0.587, spp=16777216, recenter=True) Calculate Huygens PSF by treating every exit-pupil ray as a secondary spherical wave source. @@ -525,7 +525,7 @@ PSF Calculation :return: Huygens PSF patch :rtype: torch.Tensor -.. py:method:: GeoLens.psf_map(depth=-10000.0, grid=(7, 7), ks=64, spp=2048, wvln=0.589, recenter=True) +.. py:method:: GeoLens.psf_map(depth=-20000.0, grid=(7, 7), ks=64, spp=16384, wvln=0.587, recenter=True) Calculate PSF map at different field positions. @@ -555,7 +555,7 @@ PSF Calculation :return: PSF centers [..., 2] :rtype: torch.Tensor -.. py:method:: GeoLens.psf_coherent(points, ks=64, wvln=0.589, spp=1000000, recenter=True) +.. py:method:: GeoLens.psf_coherent(points, ks=64, wvln=0.587, spp=16777216, recenter=True) Calculate coherent PSF using ray-wave model. Alias for ``psf_pupil_prop``. @@ -565,14 +565,14 @@ PSF Calculation :type ks: int :param wvln: Wavelength :type wvln: float - :param spp: Sample rays (>= 1M recommended) + :param spp: Sample rays (>= 16777216, ~16.8M, 2^24 recommended) :type spp: int :param recenter: Recenter PSF using chief ray :type recenter: bool :return: PSF patch [ks, ks] :rtype: torch.Tensor -.. py:method:: GeoLens.pupil_field(points, wvln=0.589, spp=1000000, recenter=True) +.. py:method:: GeoLens.pupil_field(points, wvln=0.587, spp=16777216, recenter=True) Calculate complex wavefront at exit pupil using coherent ray tracing. Only single-point input is supported. @@ -580,7 +580,7 @@ PSF Calculation :type points: torch.Tensor or list :param wvln: Wavelength in micrometers :type wvln: float - :param spp: Samples (>= 1M required) + :param spp: Samples (>= 16777216, ~16.8M, 2^24 required) :type spp: int :param recenter: Recenter PSF using chief ray :type recenter: bool @@ -593,7 +593,7 @@ Optical Analysis (GeoLensEval) Spot Diagrams ~~~~~~~~~~~~~ -.. py:method:: GeoLens.draw_spot_radial(save_name='./lens_spot_radial.png', num_fov=5, depth=float("inf"), num_rays=16384, wvln_list=[0.656, 0.588, 0.486], show=False) +.. py:method:: GeoLens.draw_spot_radial(save_name='./lens_spot_radial.png', num_fov=5, depth=float("inf"), num_rays=16384, wvln_list=[0.656, 0.587, 0.486], show=False) Draw spot diagrams along meridional direction. @@ -610,7 +610,7 @@ Spot Diagrams :param show: Display plot :type show: bool -.. py:method:: GeoLens.draw_spot_map(save_name='./lens_spot_map.png', num_grid=5, depth=-20000.0, num_rays=16384, wvln_list=[0.656, 0.588, 0.486], show=False) +.. py:method:: GeoLens.draw_spot_map(save_name='./lens_spot_map.png', num_grid=5, depth=-20000.0, num_rays=16384, wvln_list=[0.656, 0.587, 0.486], show=False) Draw spot diagram grid. @@ -641,7 +641,7 @@ Spot Diagrams RMS Error Maps ~~~~~~~~~~~~~~ -.. py:method:: GeoLens.rms_map_rgb(num_grid=32, depth=-10000.0) +.. py:method:: GeoLens.rms_map_rgb(num_grid=32, depth=-20000.0) Calculate RGB RMS spot error map. @@ -652,7 +652,7 @@ RMS Error Maps :return: RMS error map for each RGB channel :rtype: torch.Tensor -.. py:method:: GeoLens.rms_map(num_grid=32, depth=-10000.0, wvln=0.589) +.. py:method:: GeoLens.rms_map(num_grid=32, depth=-20000.0, wvln=0.587) Calculate RMS spot error map for single wavelength. @@ -668,7 +668,7 @@ RMS Error Maps Distortion ~~~~~~~~~~ -.. py:method:: GeoLens.calc_distortion_2D(rfov, wvln=0.58756180, plane='meridional', ray_aiming=True) +.. py:method:: GeoLens.calc_distortion_2D(rfov, wvln=0.587, plane='meridional', ray_aiming=True) Calculate distortion at a specific field angle. @@ -683,7 +683,7 @@ Distortion :return: Distortion at the specific field angle :rtype: float or numpy.ndarray -.. py:method:: GeoLens.draw_distortion_radial(rfov, save_name=None, num_points=21, wvln=0.58756180, plane='meridional', ray_aiming=True, show=False) +.. py:method:: GeoLens.draw_distortion_radial(rfov, save_name=None, num_points=21, wvln=0.587, plane='meridional', ray_aiming=True, show=False) Draw distortion curve vs field angle (Zemax-style). @@ -702,7 +702,7 @@ Distortion :param show: Display plot :type show: bool -.. py:method:: GeoLens.distortion_map(num_grid=16, depth=-10000.0) +.. py:method:: GeoLens.distortion_map(num_grid=16, depth=-20000.0) Compute distortion map for grid_sample. @@ -722,7 +722,7 @@ Distortion :return: Normalized distortion center positions [..., 2]. x, y in [-1, 1] :rtype: torch.Tensor -.. py:method:: GeoLens.draw_distortion(filename=None, num_grid=16, depth=-10000.0) +.. py:method:: GeoLens.draw_distortion(filename=None, num_grid=16, depth=-20000.0) Visualize distortion map. @@ -736,7 +736,7 @@ Distortion MTF Analysis ~~~~~~~~~~~~ -.. py:method:: GeoLens.mtf(fov, wvln=0.589) +.. py:method:: GeoLens.mtf(fov, wvln=0.587) Calculate Modulation Transfer Function at field of view. @@ -777,7 +777,7 @@ MTF Analysis Vignetting ~~~~~~~~~~ -.. py:method:: GeoLens.vignetting(depth=-10000.0, num_grid=64) +.. py:method:: GeoLens.vignetting(depth=-20000.0, num_grid=64) Compute vignetting map. @@ -788,7 +788,7 @@ Vignetting :return: Vignetting map :rtype: torch.Tensor -.. py:method:: GeoLens.draw_vignetting(filename=None, depth=-10000.0, resolution=512) +.. py:method:: GeoLens.draw_vignetting(filename=None, depth=-20000.0, resolution=512) Visualize vignetting effect. @@ -911,7 +911,7 @@ Field of View Focal & Sensor Planes ~~~~~~~~~~~~~~~~~~~~~ -.. py:method:: GeoLens.calc_focal_plane(wvln=0.589) +.. py:method:: GeoLens.calc_focal_plane(wvln=0.587) Calculate focus distance in object space by backward tracing. @@ -1308,7 +1308,7 @@ Visualization (GeoLensVis) 2D Ray Sampling ~~~~~~~~~~~~~~~ -.. py:method:: GeoLens.sample_parallel_2D(fov=0.0, num_rays=7, wvln=0.589, plane="meridional", entrance_pupil=True, depth=0.0) +.. py:method:: GeoLens.sample_parallel_2D(fov=0.0, num_rays=7, wvln=0.587, plane="meridional", entrance_pupil=True, depth=0.0) Sample 2D parallel rays for layout visualization. @@ -1327,7 +1327,7 @@ Visualization (GeoLensVis) :return: 2D ray object :rtype: Ray -.. py:method:: GeoLens.sample_point_source_2D(fov=0.0, num_rays=7, wvln=0.589, plane="meridional", depth=-10000.0) +.. py:method:: GeoLens.sample_point_source_2D(fov=0.0, num_rays=7, wvln=0.587, plane="meridional", depth=-20000.0) Sample 2D point source rays. @@ -1444,4 +1444,3 @@ References 1. Xinge Yang, Qiang Fu, and Wolfgang Heidrich, "Curriculum learning for ab initio deep learned refractive optics," Nature Communications 2024. 2. Jun Dai, Liqun Chen, Xinge Yang, Yuyao Hu, Jinwei Gu, Tianfan Xue, "Tolerance-Aware Deep Optics," arXiv:2502.04719, 2025. - diff --git a/docs/api/lens.rst b/docs/api/lens.rst index 581890a..ee661dd 100644 --- a/docs/api/lens.rst +++ b/docs/api/lens.rst @@ -6,14 +6,17 @@ This section documents the lens classes and their methods. Base Lens Class --------------- -.. py:class:: Lens(device=None, dtype=torch.float32) +.. py:class:: Lens(dtype=torch.float32, device=None) Base class for all lens systems in DeepLens. - :param device: Device to use ('cuda' or 'cpu') :param dtype: Data type for computations (default: torch.float32) + :param device: Device to use ('cuda' or 'cpu') + + .. note:: + Prefer keyword arguments when constructing lenses; positional arguments follow (dtype, device). - .. py:method:: psf(points, wvln=0.589, ks=64, **kwargs) + .. py:method:: psf(points, wvln=0.587, ks=64, **kwargs) Compute monochrome point PSF. This function should be differentiable. @@ -50,7 +53,7 @@ Base Lens Class :type depth_map: torch.Tensor :param method: Rendering method - 'psf_map', 'psf_patch', or 'psf_pixel' :type method: str - :param kwargs: Additional arguments (interp_mode, psf_grid, psf_ks, etc.) + :param kwargs: Additional arguments (interp_mode, depth_min, depth_max, num_depth, psf_grid, psf_ks, patch_center) :type kwargs: dict :return: Rendered image tensor [B, C, H, W] :rtype: torch.Tensor @@ -64,6 +67,13 @@ Base Lens Class :param sensor_res: Sensor resolution (W, H) in pixels :type sensor_res: tuple + .. py:method:: set_sensor_res(sensor_res) + + Set sensor resolution while keeping sensor radius unchanged. + + :param sensor_res: Sensor resolution (W, H) in pixels + :type sensor_res: tuple + .. py:method:: to(device) Move lens to specified device. @@ -71,12 +81,6 @@ Base Lens Class :param device: 'cuda' or 'cpu' :return: self - .. py:method:: parameters() - - Get optimizable parameters. - - :return: Iterator of torch.nn.Parameter - GeoLens ------- @@ -139,7 +143,7 @@ GeoLens :param ray: Input Ray object :return: Output Ray object - .. py:method:: sample_parallel_2D(fov=0.0, num_rays=7, wvln=0.589, plane='meridional', entrance_pupil=True, depth=0.0) + .. py:method:: sample_parallel_2D(fov=0.0, num_rays=7, wvln=0.587, plane='meridional', entrance_pupil=True, depth=0.0) Sample 2D parallel rays for layout visualization. @@ -158,7 +162,7 @@ GeoLens :return: Ray object with shape [num_rays, 3] :rtype: Ray - .. py:method:: sample_point_source(fov_x=[0.0], fov_y=[0.0], depth=-20000.0, num_rays=16384, wvln=0.589, entrance_pupil=True, scale_pupil=1.0) + .. py:method:: sample_point_source(fov_x=[0.0], fov_y=[0.0], depth=-20000.0, num_rays=16384, wvln=0.587, entrance_pupil=True, scale_pupil=1.0) Sample point source rays from object space with given field angles. @@ -179,7 +183,7 @@ GeoLens :return: Ray object with shape [len(fov_y), len(fov_x), num_rays, 3] :rtype: Ray - .. py:method:: sample_from_points(points=[[0.0, 0.0, -10000.0]], num_rays=16384, wvln=0.589, scale_pupil=1.0) + .. py:method:: sample_from_points(points=[[0.0, 0.0, -10000.0]], num_rays=16384, wvln=0.587, scale_pupil=1.0) Sample rays from point sources at absolute 3D coordinates. @@ -194,19 +198,19 @@ GeoLens :return: Sampled rays with shape [*points.shape[:-1], num_rays, 3] :rtype: Ray - .. py:method:: set_optimizer_params(params_dict) + .. py:method:: get_optimizer_params(lrs=[1e-4, 1e-4, 1e-2, 1e-4], decay=0.01, optim_mat=False, optim_surf_range=None) - Configure which parameters to optimize. + Get optimizer parameter groups for lens optimization. - :param params_dict: Dictionary with keys 'radius', 'thickness', 'conic', 'ai', 'material' + :param lrs: Learning rates for [thickness, curvature, conic, aspheric coefficients] + :param decay: Decay factor for higher-order aspheric coefficients + :param optim_mat: Whether to optimize material properties + :param optim_surf_range: Optional surface indices to optimize Example:: - lens.set_optimizer_params({ - 'radius': True, - 'thickness': True, - 'ai': True - }) + params = lens.get_optimizer_params(lrs=[1e-4, 1e-4, 1e-2, 1e-4]) + optimizer = torch.optim.Adam(params) .. py:method:: calc_foclen() @@ -522,10 +526,10 @@ Basic Usage ) # Calculate PSF - psf = lens.psf(depth=1000, spp=2048) + psf = lens.psf(points=[0.0, 0.0, -1000.0], spp=2048) # Render image - img_rendered = lens.render(img, depth=1000) + img_rendered = lens.render(img, depth=-1000) Lens Optimization ^^^^^^^^^^^^^^^^^ @@ -536,15 +540,18 @@ Lens Optimization from deeplens.optics import SpotLoss # Setup optimization - lens.set_optimizer_params({'radius': True, 'thickness': True}) - optimizer = optim.Adam(lens.parameters(), lr=0.01) + params = lens.get_optimizer_params(lrs=[1e-4, 1e-4, 1e-2, 1e-4]) + optimizer = optim.Adam(params) loss_fn = SpotLoss() # Optimization loop for i in range(1000): optimizer.zero_grad() - ray = lens.sample_point_source(depth=1e4, M=256) - ray_out = lens.trace(ray) + ray = lens.sample_point_source( + depth=-1e4, + num_rays=256, # 256 vs 16384 for faster iterations (lower sampling accuracy per step) + ) + ray_out, _ = lens.trace(ray) loss = loss_fn(ray_out) + lens.loss_constraint() loss.backward() optimizer.step() @@ -554,15 +561,15 @@ Fast PSF Prediction .. code-block:: python + import torch from deeplens import PSFNetLens # Load pre-trained model - lens = PSFNetLens( - ckpt_path='./ckpts/psfnet/PSFNet_ef50mm_f1.8_ps10um.pth' - ) + lens = PSFNetLens(lens_path='./datasets/lenses/camera/ef50mm_f1.8.json') + lens.load_net('./ckpts/psfnet/PSFNet_ef50mm_f1.8_ps10um.pth') # Fast PSF calculation - psf = lens.psf(depth=1000, field=[0, 0.5]) + psf = lens.psf_rgb(points=torch.tensor([[0.0, 0.5, -1000.0]]), ks=64).squeeze(0) # 100x faster than ray tracing! @@ -590,4 +597,3 @@ See Also * :doc:`../user_guide/lens_systems` - Detailed lens system guide * :doc:`../tutorials` - Step-by-step tutorials * :doc:`../examples/automated_lens_design` - Optimization examples - diff --git a/docs/contributing.rst b/docs/contributing.rst index 8b939bc..cf5d18b 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -162,7 +162,7 @@ Write tests for new features: def test_psf_calculation(): """Test PSF calculation.""" lens = GeoLens(filename='./datasets/lenses/camera/ef50mm_f1.8.json') - psf = lens.psf(depth=1000, spp=256) + psf = lens.psf(points=[0.0, 0.0, -1000.0], spp=256) assert psf.shape[0] == 1 # Single channel assert psf.sum() > 0 # Non-zero PSF @@ -436,4 +436,3 @@ See Also * :doc:`code_of_conduct` - Community guidelines * `GitHub Repository `_ * `Join Slack `_ - diff --git a/docs/examples/image_simulation.rst b/docs/examples/image_simulation.rst index 1928fc4..3ed382d 100644 --- a/docs/examples/image_simulation.rst +++ b/docs/examples/image_simulation.rst @@ -398,11 +398,12 @@ Light Field Rendering offset_x = (i - viewpoints//2) * 0.5 offset_y = (j - viewpoints//2) * 0.5 - # Render with offset - img_view = lens.render_with_offset( + # Render with offset (normalized patch center) + img_view = lens.render( img_tensor, - depth=1000, - offset=[offset_x, offset_y] + depth=-1000, + method="psf_patch", + patch_center=(offset_x, offset_y) ) light_field.append(img_view) @@ -420,13 +421,15 @@ For large images: .. code-block:: python + from deeplens.optics.psf import conv_psf + def render_tiled(img, depth, lens, tile_size=256, overlap=32): """Memory-efficient tile-based rendering.""" B, C, H, W = img.shape output = torch.zeros_like(img) # Calculate PSF once - psf = lens.psf(depth=depth, spp=1024) + psf = lens.psf_rgb(points=torch.tensor([[0.0, 0.0, -depth]]), spp=1024) for i in range(0, H, tile_size - overlap): for j in range(0, W, tile_size - overlap): @@ -436,7 +439,7 @@ For large images: tile = img[:, :, i1:i2, j1:j2] # Render tile - tile_rendered = lens.convolve_with_psf(tile, psf) + tile_rendered = conv_psf(tile, psf) # Blend into output output[:, :, i1:i2, j1:j2] = tile_rendered @@ -526,4 +529,3 @@ See Also * :doc:`../user_guide/lens_systems` - Lens system details * :doc:`../user_guide/sensors` - Sensor simulation * Example script: ``7_image_simulation.py`` - diff --git a/docs/user_guide/neural_networks.rst b/docs/user_guide/neural_networks.rst index f8589da..732a437 100644 --- a/docs/user_guide/neural_networks.rst +++ b/docs/user_guide/neural_networks.rst @@ -398,7 +398,10 @@ Create custom datasets: field = torch.rand(2) * 2 - 1 # [-1, 1] # Generate PSF - psf = self.lens.psf(depth=depth.item(), field=field.tolist()) + points = torch.tensor( + [[field[0].item(), field[1].item(), -depth.item()]] + ) + psf = self.lens.psf(points=points) return depth, field, psf @@ -410,6 +413,7 @@ Joint Lens-Network Optimization .. code-block:: python + import torch from deeplens import GeoLens from deeplens.network import UNet @@ -418,19 +422,18 @@ Joint Lens-Network Optimization network = UNet(in_channels=3, out_channels=3).cuda() # Enable lens optimization - lens.set_optimizer_params({'radius': True, 'thickness': True}) + lens_params = lens.get_optimizer_params(lrs=[1e-4, 1e-4, 1e-2, 1e-4]) # Combined optimizer - optimizer = torch.optim.Adam([ - {'params': lens.parameters(), 'lr': 1e-3}, - {'params': network.parameters(), 'lr': 1e-4} - ]) + optimizer = torch.optim.Adam( + lens_params + [{'params': network.parameters(), 'lr': 1e-4}] + ) # Training loop for epoch in range(100): for img_clean in dataloader: # Forward through lens - img_degraded = lens.render(img_clean, depth=1000) + img_degraded = lens.render(img_clean, depth=-1000) # Restore with network img_restored = network(img_degraded) @@ -647,4 +650,3 @@ Next Steps * Learn about :doc:`lens_systems` for optical system design * Check :doc:`../tutorials` for training workflows * Explore :doc:`../api/network` for detailed API reference -