Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions src/pyrtools/pyramids/SteerablePyramidFreq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class SteerablePyramidFreq(SteerablePyramidBase):
The squared radial functions tile the Fourier plane with a raised-cosine
falloff. Angular functions are cos(theta- k*pi/order+1)^(order).

Note that reconstruction will not be exact if the image has an odd shape (due to
boundary-handling issues) or if the pyramid is complex with order=0.

Notes
-----
Transform described in [1]_, filter kernel design described in [2]_.
Expand All @@ -30,7 +33,7 @@ class SteerablePyramidFreq(SteerablePyramidBase):
2d image upon which to construct to the pyramid.
height : 'auto' or `int`.
The height of the pyramid. If 'auto', will automatically determine based on the size of
`image`.
`image`. If an int, must be non-negative. When height=0, only returns the residuals.
order : `int`.
The Gaussian derivative order used for the steerable filters. Default value is 3.
Note that to achieve steerability the minimum number of orientation is `order` + 1,
Expand All @@ -52,7 +55,8 @@ class SteerablePyramidFreq(SteerablePyramidBase):
Human-readable string specifying the type of pyramid. For base class, is None.
pyr_coeffs : `dict`
Dictionary containing the coefficients of the pyramid. Keys are `(level, band)` tuples and
values are 1d or 2d numpy arrays (same number of dimensions as the input image)
values are 1d or 2d numpy arrays (same number of dimensions as the input image),
running from fine to coarse.
pyr_size : `dict`
Dictionary containing the sizes of the pyramid coefficients. Keys are `(level, band)`
tuples and values are tuples.
Expand All @@ -66,6 +70,7 @@ class SteerablePyramidFreq(SteerablePyramidBase):
Oct 1995.
.. [2] A Karasaridis and E P Simoncelli, "A Filter Design Technique for Steerable Pyramid
Image Transforms", ICASSP, Atlanta, GA, May 1996.

"""
def __init__(self, image, height='auto', order=3, twidth=1, is_complex=False):
# in the Fourier domain, there's only one choice for how do edge-handling: circular. to
Expand All @@ -78,24 +83,35 @@ def __init__(self, image, height='auto', order=3, twidth=1, is_complex=False):
self.filters = {}
self.order = int(order)

if (image.shape[0] % 2 != 0) or (image.shape[1] % 2 != 0):
warnings.warn("Reconstruction will not be perfect with odd-sized images")

if self.order == 0 and self.is_complex:
raise ValueError(
"Complex pyramid cannot have order=0! See "
"https://github.com/plenoptic-org/plenoptic/issues/326 "
"for an explanation."
)

# we can't use the base class's _set_num_scales method because the max height is calculated
# slightly differently
max_ht = np.floor(np.log2(min(self.image.shape))) - 2
if height == 'auto' or height is None:
self.num_scales = int(max_ht)
elif height > max_ht:
raise Exception("Cannot build pyramid higher than %d levels." % (max_ht))
raise ValueError("Cannot build pyramid higher than %d levels." % (max_ht))
elif height < 0:
raise ValueError("Height must be a non-negative int.")
else:
self.num_scales = int(height)

if self.order > 15 or self.order < 0:
raise Exception("order must be an integer in the range [0,15]. Truncating.")
raise ValueError("order must be an integer in the range [0,15].")

self.num_orientations = int(order + 1)

if twidth <= 0:
warnings.warn("twidth must be positive. Setting to 1.")
twidth = 1
raise ValueError("twidth must be positive.")
twidth = int(twidth)

dims = np.array(self.image.shape)
Expand Down Expand Up @@ -220,8 +236,7 @@ def recon_pyr(self, levels='all', bands='all', twidth=1):

"""
if twidth <= 0:
warnings.warn("twidth must be positive. Setting to 1.")
twidth = 1
raise ValueError("twidth must be positive.")

recon_keys = self._recon_keys(levels, bands)

Expand Down
2 changes: 1 addition & 1 deletion src/pyrtools/pyramids/pyramid.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _set_num_scales(self, filter_name, height, extra_height=0):
if height == 'auto':
self.num_scales = max_ht
elif height > max_ht:
raise Exception("Cannot build pyramid higher than %d levels." % (max_ht))
raise ValueError("Cannot build pyramid higher than %d levels." % (max_ht))
else:
self.num_scales = int(height)

Expand Down