diff --git a/src/pyrtools/pyramids/SteerablePyramidFreq.py b/src/pyrtools/pyramids/SteerablePyramidFreq.py index c05797a..eb6dffc 100644 --- a/src/pyrtools/pyramids/SteerablePyramidFreq.py +++ b/src/pyrtools/pyramids/SteerablePyramidFreq.py @@ -20,6 +20,9 @@ class SteerablePyramidFreq(SteerablePyramidBase): The squared radial functions tile the Fourier plane with a raised-cosine falloff. Angular functions are cos(theta- k*pi/order+1)^(order). + Note that reconstruction will not be exact if the image has an odd shape (due to + boundary-handling issues) or if the pyramid is complex with order=0. + Notes ----- Transform described in [1]_, filter kernel design described in [2]_. @@ -30,7 +33,7 @@ class SteerablePyramidFreq(SteerablePyramidBase): 2d image upon which to construct to the pyramid. height : 'auto' or `int`. The height of the pyramid. If 'auto', will automatically determine based on the size of - `image`. + `image`. If an int, must be non-negative. When height=0, only returns the residuals. order : `int`. The Gaussian derivative order used for the steerable filters. Default value is 3. Note that to achieve steerability the minimum number of orientation is `order` + 1, @@ -52,7 +55,8 @@ class SteerablePyramidFreq(SteerablePyramidBase): Human-readable string specifying the type of pyramid. For base class, is None. pyr_coeffs : `dict` Dictionary containing the coefficients of the pyramid. Keys are `(level, band)` tuples and - values are 1d or 2d numpy arrays (same number of dimensions as the input image) + values are 1d or 2d numpy arrays (same number of dimensions as the input image), + running from fine to coarse. pyr_size : `dict` Dictionary containing the sizes of the pyramid coefficients. Keys are `(level, band)` tuples and values are tuples. @@ -66,6 +70,7 @@ class SteerablePyramidFreq(SteerablePyramidBase): Oct 1995. .. [2] A Karasaridis and E P Simoncelli, "A Filter Design Technique for Steerable Pyramid Image Transforms", ICASSP, Atlanta, GA, May 1996. + """ def __init__(self, image, height='auto', order=3, twidth=1, is_complex=False): # in the Fourier domain, there's only one choice for how do edge-handling: circular. to @@ -78,24 +83,35 @@ def __init__(self, image, height='auto', order=3, twidth=1, is_complex=False): self.filters = {} self.order = int(order) + if (image.shape[0] % 2 != 0) or (image.shape[1] % 2 != 0): + warnings.warn("Reconstruction will not be perfect with odd-sized images") + + if self.order == 0 and self.is_complex: + raise ValueError( + "Complex pyramid cannot have order=0! See " + "https://github.com/plenoptic-org/plenoptic/issues/326 " + "for an explanation." + ) + # we can't use the base class's _set_num_scales method because the max height is calculated # slightly differently max_ht = np.floor(np.log2(min(self.image.shape))) - 2 if height == 'auto' or height is None: self.num_scales = int(max_ht) elif height > max_ht: - raise Exception("Cannot build pyramid higher than %d levels." % (max_ht)) + raise ValueError("Cannot build pyramid higher than %d levels." % (max_ht)) + elif height < 0: + raise ValueError("Height must be a non-negative int.") else: self.num_scales = int(height) if self.order > 15 or self.order < 0: - raise Exception("order must be an integer in the range [0,15]. Truncating.") + raise ValueError("order must be an integer in the range [0,15].") self.num_orientations = int(order + 1) if twidth <= 0: - warnings.warn("twidth must be positive. Setting to 1.") - twidth = 1 + raise ValueError("twidth must be positive.") twidth = int(twidth) dims = np.array(self.image.shape) @@ -220,8 +236,7 @@ def recon_pyr(self, levels='all', bands='all', twidth=1): """ if twidth <= 0: - warnings.warn("twidth must be positive. Setting to 1.") - twidth = 1 + raise ValueError("twidth must be positive.") recon_keys = self._recon_keys(levels, bands) diff --git a/src/pyrtools/pyramids/pyramid.py b/src/pyrtools/pyramids/pyramid.py index 9670582..5f216e6 100644 --- a/src/pyrtools/pyramids/pyramid.py +++ b/src/pyrtools/pyramids/pyramid.py @@ -96,7 +96,7 @@ def _set_num_scales(self, filter_name, height, extra_height=0): if height == 'auto': self.num_scales = max_ht elif height > max_ht: - raise Exception("Cannot build pyramid higher than %d levels." % (max_ht)) + raise ValueError("Cannot build pyramid higher than %d levels." % (max_ht)) else: self.num_scales = int(height)