Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 147 additions & 48 deletions src/cdtools/models/fancy_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def __init__(self,
exponentiate_obj=False,
phase_only=False,
dtype=t.float32,
obj_view_crop=0
obj_view_crop=0,
near_field=False,
angular_spectrum_propagator=None,
inv_angular_spectrum_propagator=None,
):

super(FancyPtycho, self).__init__()
Expand Down Expand Up @@ -79,6 +82,25 @@ def __init__(self,
self.register_buffer('phase_only',
t.as_tensor(phase_only, dtype=bool))

self.register_buffer('near_field',
t.as_tensor(near_field, dtype=bool))

if angular_spectrum_propagator is None:
self.angular_spectrum_propagator = None
else:
self.register_buffer(
'angular_spectrum_propagator',
t.as_tensor(angular_spectrum_propagator, dtype=t.complex64)
)

if inv_angular_spectrum_propagator is None:
self.inv_angular_spectrum_propagator = None
else:
self.register_buffer(
'inv_angular_spectrum_propagator',
t.as_tensor(inv_angular_spectrum_propagator, dtype=t.complex64)
)

# Not sure how to make this a buffer...
self.units = units

Expand Down Expand Up @@ -217,6 +239,7 @@ def from_dataset(cls,
phase_only=False,
obj_view_crop=None,
obj_padding=200,
near_field=False,
):

wavelength = dataset.wavelength
Expand All @@ -234,16 +257,86 @@ def from_dataset(cls,

dataset.get_as(*get_as_args[0], **get_as_args[1])

# Then, generate the probe geometry from the dataset
ewg = tools.initializers.exit_wave_geometry
obj_basis = ewg(
det_basis,
det_shape,
wavelength,
distance,
oversampling=oversampling,
)
if not near_field:
# Then, generate the probe geometry from the dataset
ewg = tools.initializers.exit_wave_geometry
obj_basis = ewg(
det_basis,
det_shape,
wavelength,
distance,
oversampling=oversampling,
)

probe = tools.initializers.SHARP_style_probe(
dataset,
propagation_distance=propagation_distance,
oversampling=oversampling,
)
angular_spectrum_propagator=None
inv_angular_spectrum_propagator=None

else:
if propagation_distance is None or propagation_distance==0:
# In this case, we assume that we're genuinely in a near
# field geometry, such that z_eff = z and there is no
# magnification
obj_basis = t.as_tensor(det_basis) / oversampling
angular_spectrum_propagator = \
tools.propagators.generate_generalized_angular_spectrum_propagator(
[d*oversampling for d in det_shape],
obj_basis,
wavelength,
np.array([0,0,distance]),
)
inv_angular_spectrum_propagator = \
t.conj(angular_spectrum_propagator)
inv_angular_spectrum_propagator_init = t.conj(
tools.propagators.generate_generalized_angular_spectrum_propagator(
det_shape,
obj_basis,
wavelength,
np.array([0,0,distance]),
)
)
else:
# In this case, we assume that we're in a projection geometry
# with a z_eff based on propagation_distance and a nonzero
# magnification
M = (propagation_distance + distance) / propagation_distance
z_eff = distance / M

obj_basis = t.as_tensor(det_basis) / (oversampling * M)
angular_spectrum_propagator = \
tools.propagators.generate_generalized_angular_spectrum_propagator(
[d * oversampling for d in det_shape],
obj_basis,
wavelength,
np.array([0,0,z_eff]),
)
inv_angular_spectrum_propagator = t.conj(
angular_spectrum_propagator)
inv_angular_spectrum_propagator_init = t.conj(
tools.propagators.generate_generalized_angular_spectrum_propagator(
det_shape,
obj_basis,
wavelength,
np.array([0,0,z_eff]),
)
)

backward_propagator = lambda wavefields: \
tools.propagators.near_field(
wavefields,
inv_angular_spectrum_propagator_init
)

probe = tools.initializers.SHARP_style_near_field_probe(
dataset,
backward_propagator=backward_propagator,
oversampling=oversampling,
)

if hasattr(dataset, 'sample_info') and \
dataset.sample_info is not None and \
'orientation' in dataset.sample_info:
Expand Down Expand Up @@ -276,21 +369,6 @@ def from_dataset(cls,
padding=obj_padding,
)

# Finally, initialize the probe and object using this information
if probe_shape is None:
probe = tools.initializers.SHARP_style_probe(
dataset,
propagation_distance=propagation_distance,
oversampling=oversampling,
)
else:
probe = tools.initializers.gaussian_probe(
dataset,
obj_basis,
probe_shape,
propagation_distance=propagation_distance,
)

if hasattr(dataset, 'background') and dataset.background is not None:
background = t.sqrt(dataset.background)
else:
Expand Down Expand Up @@ -389,24 +467,35 @@ def from_dataset(cls,
else:
probe_support = None

return cls(wavelength, det_geo, obj_basis, probe, obj,
surface_normal=surface_normal,
min_translation=min_translation,
translation_offsets=translation_offsets,
weights=Ws, mask=mask, background=background,
translation_scale=translation_scale,
saturation=saturation,
probe_basis=probe_basis,
probe_support=probe_support,
fourier_probe=fourier_probe,
oversampling=oversampling,
loss=loss, units=units,
probe_fourier_shifts=probe_fourier_shifts,
simulate_probe_translation=simulate_probe_translation,
simulate_finite_pixels=simulate_finite_pixels,
phase_only=phase_only,
exponentiate_obj=exponentiate_obj,
obj_view_crop=obj_view_crop)
return cls(
wavelength,
det_geo,
obj_basis,
probe,
obj,
surface_normal=surface_normal,
min_translation=min_translation,
translation_offsets=translation_offsets,
weights=Ws,
mask=mask,
background=background,
translation_scale=translation_scale,
saturation=saturation,
probe_basis=probe_basis,
probe_support=probe_support,
fourier_probe=fourier_probe,
oversampling=oversampling,
loss=loss, units=units,
probe_fourier_shifts=probe_fourier_shifts,
simulate_probe_translation=simulate_probe_translation,
simulate_finite_pixels=simulate_finite_pixels,
phase_only=phase_only,
exponentiate_obj=exponentiate_obj,
obj_view_crop=obj_view_crop,
near_field=near_field,
angular_spectrum_propagator=angular_spectrum_propagator,
inv_angular_spectrum_propagator=inv_angular_spectrum_propagator,
)


def interaction(self, index, translations, *args):
Expand Down Expand Up @@ -506,16 +595,26 @@ def interaction(self, index, translations, *args):
probe_support=self.probe_support)

return exit_waves


def forward_propagator(self, wavefields):
return tools.propagators.far_field(wavefields)
if self.near_field:
return tools.propagators.near_field(
wavefields, self.angular_spectrum_propagator
)
else:
return tools.propagators.far_field(wavefields)


def backward_propagator(self, wavefields):
return tools.propagators.inverse_far_field(wavefields)

if self.near_field:
return tools.propagators.near_field(
wavefields, self.inverse_angular_spectrum_propagator
)
else:
return tools.propagators.inverse_far_field(wavefields)


def measurement(self, wavefields):
return tools.measurements.quadratic_background(
wavefields,
Expand Down Expand Up @@ -735,7 +834,7 @@ def get_probes(idx):
values=values,
fig=fig,
units=self.units,
basis=self.obj_basis,
basis=self.probe_basis,
nanomap_colorbar_title='Total Probe Intensity',
cmap=cmap,
**kwargs),
Expand Down
16 changes: 10 additions & 6 deletions src/cdtools/tools/image_processing/image_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,19 +321,23 @@ def convolve_1d(image, kernel, dim=0, fftshift_kernel=True):
return conv_im


def fourier_upsample(ims, preserve_mean=False):
def fourier_upsample(ims, upsample_factor=2, preserve_mean=False):
# If preserve_mean is true, it preserves the mean pixel intensity
# otherwise, it preserves the total summed intensity
upsampled = t.zeros(ims.shape[:-2]+(2*ims.shape[-2],2*ims.shape[-1]),

upsampled = t.zeros(ims.shape[:-2]+(upsample_factor*ims.shape[-2],
upsample_factor*ims.shape[-1]),
dtype=ims.dtype,
device=ims.device)
left = [ims.shape[-2]//2,ims.shape[-1]//2]
right = [ims.shape[-2]//2+ims.shape[-2],
ims.shape[-1]//2+ims.shape[-1]]

left = [((upsample_factor-1)*ims.shape[-2])//2,
((upsample_factor-1)*ims.shape[-1])//2]
right = [left[0]+ims.shape[-2],
left[1]+ims.shape[-1]]

upsampled[...,left[0]:right[0],left[1]:right[1]] = propagators.far_field(ims)
if preserve_mean:
upsampled *= 2
upsampled *= upsample_factor
return propagators.inverse_far_field(upsampled)


Expand Down
Loading