From 6740653e4c980f29aa6405238a288a6124e7d508 Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Thu, 20 Feb 2025 11:47:27 -0500 Subject: [PATCH] updates pyrshow error message, adds tests --- TESTS/unitTests.py | 26 +++++++++++ src/pyrtools/tools/display.py | 83 ++++++++++++++++++++++++++++------- 2 files changed, 94 insertions(+), 15 deletions(-) diff --git a/TESTS/unitTests.py b/TESTS/unitTests.py index 6a233d0..4a90787 100755 --- a/TESTS/unitTests.py +++ b/TESTS/unitTests.py @@ -1460,6 +1460,32 @@ def test_animshow_fail_n_frames(self): with self.assertRaises(Exception): fig = pt.animshow([vid1, vid2], as_html5=False)._fig + +class TestPyrshow(unittest.TestCase): + + def test_pyrshow_1d(self): + signal = np.random.rand(256) + pyr = pt.pyramids.GaussianPyramid(signal) + pt.pyrshow(pyr.pyr_coeffs) + + def test_pyrshow_1d_weird_shape(self): + # unlike 2d pyrshow, 1d pyrshow works with any shapes + signal = np.random.rand(255) + pyr = pt.pyramids.GaussianPyramid(signal) + pt.pyrshow(pyr.pyr_coeffs) + + def test_pyrshow_2d(self): + img = np.random.rand(256, 256) + pyr = pt.pyramids.GaussianPyramid(img) + pt.pyrshow(pyr.pyr_coeffs) + + def test_pyrshow_2d_shape_err(self): + img = np.random.rand(255, 255) + pyr = pt.pyramids.GaussianPyramid(img) + with self.assertRaises(ValueError): + pt.pyrshow(pyr.pyr_coeffs) + + def main(): unittest.main() diff --git a/src/pyrtools/tools/display.py b/src/pyrtools/tools/display.py index 4f03098..3fc5c88 100644 --- a/src/pyrtools/tools/display.py +++ b/src/pyrtools/tools/display.py @@ -324,8 +324,8 @@ def colormap_range(image, vrange='indep1', cmap=None): return vrange_list, cmap -def find_zooms(images, video=False): - """find the zooms necessary to display a list of images +def _check_shapes(images, video=False): + """Helper function to check whether images can be zoomed in appropriately. this convenience function takes a list of images and finds out if they can all be displayed at the same size. for this to be the case, there must be an integer for each image such that the @@ -341,21 +341,24 @@ def find_zooms(images, video=False): Returns ------- - zooms : `list` - list of integers showing how much each image needs to be zoomed max_shape : `tuple` 2-tuple of integers, showing the shape of the largest image in the list + Raises + ------ + ValueError : + If the images cannot be zoomed to the same. that is, if there is not an integer + for each image such that the image can be multiplied by that integer to be the + same size as the biggest image. """ def check_shape_1d(shapes): max_shape = np.max(shapes) for s in shapes: if not (max_shape % s) == 0: - raise Exception("All images must be able to be 'zoomed in' to the largest image." - "That is, the largest image must be a scalar multiple of all " - "images.") + raise ValueError("All images must be able to be 'zoomed in' to the largest image." + "That is, the largest image must be a scalar multiple of all " + "images.") return max_shape - if video: time_dim = 1 else: @@ -363,6 +366,43 @@ def check_shape_1d(shapes): max_shape = [] for i in range(2): max_shape.append(check_shape_1d([img.shape[i+time_dim] for img in images])) + return max_shape + + +def find_zooms(images, video=False): + """find the zooms necessary to display a list of images + + Arguments + --------- + images : `list` + list of numpy arrays to check the size of. In practice, these are 1d or 2d, but can in + principle be any number of dimensions + video: bool, optional (default False) + handling signals in both space and time or only space. + + Returns + ------- + zooms : `list` + list of integers showing how much each image needs to be zoomed + max_shape : `tuple` + 2-tuple of integers, showing the shape of the largest image in the list + + Raises + ------ + ValueError : + If the images cannot be zoomed to the same. that is, if there is not an integer + for each image such that the image can be multiplied by that integer to be the + same size as the biggest image. + ValueError : + If the two image dimensions require different levels of zoom (e.g., if the + height must be zoomed by 2 but the width must be zoomed by 3). + + """ + max_shape = _check_shapes(images, video) + if video: + time_dim = 1 + else: + time_dim = 0 zooms = [] for img in images: # this checks that there's only one unique value in the list @@ -373,8 +413,8 @@ def check_shape_1d(shapes): # the first two non-time dimensions (so we'll ignore the RGBA channel # if any image has that) if len(set([s // img.shape[i+time_dim] for i, s in enumerate(max_shape)])) > 1: - raise Exception("Both height and width must be multiplied by same amount but got " - "image shape {} and max_shape {}!".format(img.shape, max_shape)) + raise ValueError("Both height and width must be multiplied by same amount but got " + "image shape {} and max_shape {}!".format(img.shape, max_shape)) zooms.append(max_shape[0] // img.shape[0]) return zooms, max_shape @@ -839,9 +879,6 @@ def animate_video(t): def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1, show_residuals=True, **kwargs): """Display the coefficients of the pyramid in an orderly fashion - NOTE: this currently only works for 2d signals. we still need to figure out how to handle 1D - signals. - Arguments --------- pyr_coeffs : `dict` @@ -894,8 +931,6 @@ def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1 # pasting all coefficients into a giant array. # and the steerable pyramids have a num_orientations attribute - # TODO make list of different elements in each dim - # then only loop through those - see below line 655 num_scales = np.max(np.array([k for k in pyr_coeffs.keys() if isinstance(k, tuple)])[:,0]) + 1 num_orientations = np.max(np.array([k for k in pyr_coeffs.keys() if isinstance(k, tuple)])[:,1]) + 1 @@ -939,4 +974,22 @@ def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1 ax.set_title(titles[i]) return fig else: + try: + _check_shapes(imgs) + except ValueError: + err_scales = num_scales + residual_err_msg = "" + shapes = [(imgs[0].shape[0]/ 2**i, imgs[0].shape[1] / 2**i) for i in range(err_scales)] + err_msg = [f"scale {i} shape: {sh}" for i, sh in enumerate(shapes)] + if show_residuals: + err_scales += 1 + residual_err_msg = ", plus 1 (for the residual lowpass)" + shape = (imgs[0].shape[0]/ int(2**err_scales), imgs[0].shape[1] / int(2**err_scales)) + err_msg.append(f"residual lowpass shape: {shape}") + err_msg = "\n".join(err_msg) + raise ValueError("In order to correctly display pyramid coefficients, the shape of" + f" the initial image must be evenly divisible by two {err_scales} " + "times, where this number is the height of the " + f"pyramid{residual_err_msg}. " + f"Instead, found:\n{err_msg}") return imshow(imgs, vrange=vrange, col_wrap=col_wrap, zoom=zoom, title=titles, **kwargs)