diff --git a/src/cdtools/models/fancy_ptycho.py b/src/cdtools/models/fancy_ptycho.py index 1be40c68..dba678ac 100644 --- a/src/cdtools/models/fancy_ptycho.py +++ b/src/cdtools/models/fancy_ptycho.py @@ -40,7 +40,10 @@ def __init__(self, exponentiate_obj=False, phase_only=False, dtype=t.float32, - obj_view_crop=0 + obj_view_crop=0, + near_field=False, + angular_spectrum_propagator=None, + inv_angular_spectrum_propagator=None, ): super(FancyPtycho, self).__init__() @@ -79,6 +82,25 @@ def __init__(self, self.register_buffer('phase_only', t.as_tensor(phase_only, dtype=bool)) + self.register_buffer('near_field', + t.as_tensor(near_field, dtype=bool)) + + if angular_spectrum_propagator is None: + self.angular_spectrum_propagator = None + else: + self.register_buffer( + 'angular_spectrum_propagator', + t.as_tensor(angular_spectrum_propagator, dtype=t.complex64) + ) + + if inv_angular_spectrum_propagator is None: + self.inv_angular_spectrum_propagator = None + else: + self.register_buffer( + 'inv_angular_spectrum_propagator', + t.as_tensor(inv_angular_spectrum_propagator, dtype=t.complex64) + ) + # Not sure how to make this a buffer... self.units = units @@ -217,6 +239,7 @@ def from_dataset(cls, phase_only=False, obj_view_crop=None, obj_padding=200, + near_field=False, ): wavelength = dataset.wavelength @@ -234,16 +257,86 @@ def from_dataset(cls, dataset.get_as(*get_as_args[0], **get_as_args[1]) - # Then, generate the probe geometry from the dataset - ewg = tools.initializers.exit_wave_geometry - obj_basis = ewg( - det_basis, - det_shape, - wavelength, - distance, - oversampling=oversampling, - ) + if not near_field: + # Then, generate the probe geometry from the dataset + ewg = tools.initializers.exit_wave_geometry + obj_basis = ewg( + det_basis, + det_shape, + wavelength, + distance, + oversampling=oversampling, + ) + probe = tools.initializers.SHARP_style_probe( + dataset, + propagation_distance=propagation_distance, + oversampling=oversampling, + ) + angular_spectrum_propagator=None + inv_angular_spectrum_propagator=None + + else: + if propagation_distance is None or propagation_distance==0: + # In this case, we assume that we're genuinely in a near + # field geometry, such that z_eff = z and there is no + # magnification + obj_basis = t.as_tensor(det_basis) / oversampling + angular_spectrum_propagator = \ + tools.propagators.generate_generalized_angular_spectrum_propagator( + [d*oversampling for d in det_shape], + obj_basis, + wavelength, + np.array([0,0,distance]), + ) + inv_angular_spectrum_propagator = \ + t.conj(angular_spectrum_propagator) + inv_angular_spectrum_propagator_init = t.conj( + tools.propagators.generate_generalized_angular_spectrum_propagator( + det_shape, + obj_basis, + wavelength, + np.array([0,0,distance]), + ) + ) + else: + # In this case, we assume that we're in a projection geometry + # with a z_eff based on propagation_distance and a nonzero + # magnification + M = (propagation_distance + distance) / propagation_distance + z_eff = distance / M + + obj_basis = t.as_tensor(det_basis) / (oversampling * M) + angular_spectrum_propagator = \ + tools.propagators.generate_generalized_angular_spectrum_propagator( + [d * oversampling for d in det_shape], + obj_basis, + wavelength, + np.array([0,0,z_eff]), + ) + inv_angular_spectrum_propagator = t.conj( + angular_spectrum_propagator) + inv_angular_spectrum_propagator_init = t.conj( + tools.propagators.generate_generalized_angular_spectrum_propagator( + det_shape, + obj_basis, + wavelength, + np.array([0,0,z_eff]), + ) + ) + + backward_propagator = lambda wavefields: \ + tools.propagators.near_field( + wavefields, + inv_angular_spectrum_propagator_init + ) + + probe = tools.initializers.SHARP_style_near_field_probe( + dataset, + backward_propagator=backward_propagator, + oversampling=oversampling, + ) + if hasattr(dataset, 'sample_info') and \ dataset.sample_info is not None and \ 'orientation' in dataset.sample_info: @@ -276,21 +369,6 @@ def from_dataset(cls, padding=obj_padding, ) - # Finally, initialize the probe and object using this information - if probe_shape is None: - probe = tools.initializers.SHARP_style_probe( - dataset, - propagation_distance=propagation_distance, - oversampling=oversampling, - ) - else: - probe = tools.initializers.gaussian_probe( - dataset, - obj_basis, - probe_shape, - propagation_distance=propagation_distance, - ) - if hasattr(dataset, 'background') and dataset.background is not None: background = t.sqrt(dataset.background) else: @@ -389,24 +467,35 @@ def from_dataset(cls, else: probe_support = None - return cls(wavelength, det_geo, obj_basis, probe, obj, - surface_normal=surface_normal, - min_translation=min_translation, - translation_offsets=translation_offsets, - weights=Ws, mask=mask, background=background, - translation_scale=translation_scale, - saturation=saturation, - probe_basis=probe_basis, - probe_support=probe_support, - fourier_probe=fourier_probe, - oversampling=oversampling, - loss=loss, units=units, - probe_fourier_shifts=probe_fourier_shifts, - simulate_probe_translation=simulate_probe_translation, - simulate_finite_pixels=simulate_finite_pixels, - phase_only=phase_only, - exponentiate_obj=exponentiate_obj, - obj_view_crop=obj_view_crop) + return cls( + wavelength, + det_geo, + obj_basis, + probe, + obj, + surface_normal=surface_normal, + min_translation=min_translation, + translation_offsets=translation_offsets, + weights=Ws, + mask=mask, + background=background, + translation_scale=translation_scale, + saturation=saturation, + probe_basis=probe_basis, + probe_support=probe_support, + fourier_probe=fourier_probe, + oversampling=oversampling, + loss=loss, units=units, + probe_fourier_shifts=probe_fourier_shifts, + simulate_probe_translation=simulate_probe_translation, + simulate_finite_pixels=simulate_finite_pixels, + phase_only=phase_only, + exponentiate_obj=exponentiate_obj, + obj_view_crop=obj_view_crop, + near_field=near_field, + angular_spectrum_propagator=angular_spectrum_propagator, + inv_angular_spectrum_propagator=inv_angular_spectrum_propagator, + ) def interaction(self, index, translations, *args): @@ -506,16 +595,26 @@ def interaction(self, index, translations, *args): probe_support=self.probe_support) return exit_waves - + def forward_propagator(self, wavefields): - return tools.propagators.far_field(wavefields) + if self.near_field: + return tools.propagators.near_field( + wavefields, self.angular_spectrum_propagator + ) + else: + return tools.propagators.far_field(wavefields) def backward_propagator(self, wavefields): - return tools.propagators.inverse_far_field(wavefields) - + if self.near_field: + return tools.propagators.near_field( + wavefields, self.inverse_angular_spectrum_propagator + ) + else: + return tools.propagators.inverse_far_field(wavefields) + def measurement(self, wavefields): return tools.measurements.quadratic_background( wavefields, @@ -735,7 +834,7 @@ def get_probes(idx): values=values, fig=fig, units=self.units, - basis=self.obj_basis, + basis=self.probe_basis, nanomap_colorbar_title='Total Probe Intensity', cmap=cmap, **kwargs), diff --git a/src/cdtools/tools/image_processing/image_processing.py b/src/cdtools/tools/image_processing/image_processing.py index 8a9eb4b3..b774daa7 100644 --- a/src/cdtools/tools/image_processing/image_processing.py +++ b/src/cdtools/tools/image_processing/image_processing.py @@ -321,19 +321,23 @@ def convolve_1d(image, kernel, dim=0, fftshift_kernel=True): return conv_im -def fourier_upsample(ims, preserve_mean=False): +def fourier_upsample(ims, upsample_factor=2, preserve_mean=False): # If preserve_mean is true, it preserves the mean pixel intensity # otherwise, it preserves the total summed intensity - upsampled = t.zeros(ims.shape[:-2]+(2*ims.shape[-2],2*ims.shape[-1]), + + upsampled = t.zeros(ims.shape[:-2]+(upsample_factor*ims.shape[-2], + upsample_factor*ims.shape[-1]), dtype=ims.dtype, device=ims.device) - left = [ims.shape[-2]//2,ims.shape[-1]//2] - right = [ims.shape[-2]//2+ims.shape[-2], - ims.shape[-1]//2+ims.shape[-1]] + + left = [((upsample_factor-1)*ims.shape[-2])//2, + ((upsample_factor-1)*ims.shape[-1])//2] + right = [left[0]+ims.shape[-2], + left[1]+ims.shape[-1]] upsampled[...,left[0]:right[0],left[1]:right[1]] = propagators.far_field(ims) if preserve_mean: - upsampled *= 2 + upsampled *= upsample_factor return propagators.inverse_far_field(upsampled) diff --git a/src/cdtools/tools/initializers/initializers.py b/src/cdtools/tools/initializers/initializers.py index bdb91b9a..c3aa47ea 100644 --- a/src/cdtools/tools/initializers/initializers.py +++ b/src/cdtools/tools/initializers/initializers.py @@ -16,10 +16,17 @@ import numpy as np from functools import * -__all__ = ['exit_wave_geometry', 'calc_object_setup', 'gaussian', - 'gaussian_probe', 'SHARP_style_probe', 'STEM_style_probe', - 'RPI_spectral_init', - 'generate_subdominant_modes'] +__all__ = [ + 'exit_wave_geometry', + 'calc_object_setup', + 'gaussian', + 'gaussian_probe', + 'SHARP_style_probe', + 'SHARP_style_near_field_probe', + 'STEM_style_probe', + 'RPI_spectral_init', + 'generate_subdominant_modes' + ] def exit_wave_geometry(det_basis, det_shape, wavelength, distance, oversampling=1): """Returns an exit wave basis and shape, as well as a detector slice for the given detector geometry @@ -317,10 +324,20 @@ def SHARP_style_probe(dataset, propagation_distance=None, oversampling=1): probe_fft = t.tensor(np.sqrt(intensities)).to(dtype=t.complex64) probe_guess = inverse_far_field(probe_fft) + # Finally, place this probe in a full-sized array if there is oversampling + full_shape = [oversampling * s for s in shape] + large_probe_guess = t.zeros(full_shape, dtype=probe_guess.dtype) + left = full_shape[0]//2 - shape[0] // 2 + top = full_shape[1]//2 - shape[1] // 2 + large_probe_guess[left : left + shape[0], + top : top + shape[1]] = probe_guess + + if propagation_distance is not None: # First generate the propagation array probe_shape = t.as_tensor(tuple(probe_guess.shape)) + large_probe_shape = t.as_tensor(tuple(large_probe_guess.shape)) # Start by recalculating the probe basis from the given information det_basis = t.as_tensor(dataset.detector_geometry['basis']) @@ -331,26 +348,83 @@ def SHARP_style_probe(dataset, propagation_distance=None, oversampling=1): # Then package everything as it's needed probe_spacing = t.norm(probe_basis,dim=0).numpy() - probe_shape = probe_shape.numpy().astype(np.int32) + large_probe_shape = large_probe_shape.numpy().astype(np.int32) # And generate the propagator AS_prop = generate_angular_spectrum_propagator( - probe_shape, + large_probe_shape, probe_spacing, dataset.wavelength, propagation_distance) - probe_guess = near_field(probe_guess,AS_prop) + large_probe_guess = near_field(large_probe_guess,AS_prop) - # Finally, place this probe in a full-sized array if there is oversampling - full_shape = [oversampling * s for s in shape] - final_probe = t.zeros(full_shape, dtype=t.complex64) - left = full_shape[0]//2 - shape[0] // 2 - top = full_shape[1]//2 - shape[1] // 2 - final_probe[left : left + shape[0], - top : top + shape[1]] = probe_guess - return final_probe + return large_probe_guess + +def SHARP_style_near_field_probe(dataset, backward_propagator, oversampling=1): + """Generates a SHARP style probe guess from a dataset + + What we call the "SHARP" style probe guess is to take a mean of all + the diffraction patterns and use that as an initial guess of the + Fourier space distribution of the probe. We set all the phases to + zero, which would for many simple beams (like a zone plate) generate + a first guess of the probe that is very close to the focal spot of + the probe beam. + + Parameters + ---------- + dataset : Ptycho_2D_Dataset + The dataset to work from + backward_propagator : function + A propagator (typically angular spectrum) used to map from the detector plane to the sample plane + oversampling : int + Default 1, the width of the region of pixels in the wavefield to bin into a single detector pixel + + Returns + ------- + torch.Tensor + The complex-style tensor storing the probe guess + """ + + # NOTE: I don't love the way np and torch are mixed here, I think this + # function deserves some love. + + shape = dataset.patterns.shape[-2:] + + # to use the mask or not? + intensities = np.zeros([dim for dim in shape]) + + # Eventually, do something with the recorded intensities, if they exist + factors = [1 for idx in range(len(dataset))] + + for params, im in dataset: + if hasattr(dataset,'mask') and dataset.mask is not None: + intensities += (dataset.mask.cpu().numpy() * im.cpu().numpy() + / factors[params[0]]) + else: + intensities += im.cpu().numpy() / params[factors[0]] + + intensities /= len(dataset) + + # Subtract off a known background if it's stored + if hasattr(dataset, 'background') and dataset.background is not None: + intensities = np.clip( + intensities - dataset.background.cpu().numpy(), + a_min=0, + a_max=None, + ) + + probe_guess_det_plane = t.tensor(np.sqrt(intensities)).to(dtype=t.complex64) + probe_guess = backward_propagator(probe_guess_det_plane) + + if oversampling != 1: + probe_guess = image_processing.fourier_upsample( + probe_guess, + upsample_factor=oversampling, preserve_mean=False + ) + + return probe_guess def STEM_style_probe(dataset, shape, det_slice, convergence_semiangle, propagation_distance=None, oversampling=1):