Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions TESTS/unitTests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,6 +1460,32 @@ def test_animshow_fail_n_frames(self):
with self.assertRaises(Exception):
fig = pt.animshow([vid1, vid2], as_html5=False)._fig


class TestPyrshow(unittest.TestCase):

def test_pyrshow_1d(self):
signal = np.random.rand(256)
pyr = pt.pyramids.GaussianPyramid(signal)
pt.pyrshow(pyr.pyr_coeffs)

def test_pyrshow_1d_weird_shape(self):
# unlike 2d pyrshow, 1d pyrshow works with any shapes
signal = np.random.rand(255)
pyr = pt.pyramids.GaussianPyramid(signal)
pt.pyrshow(pyr.pyr_coeffs)

def test_pyrshow_2d(self):
img = np.random.rand(256, 256)
pyr = pt.pyramids.GaussianPyramid(img)
pt.pyrshow(pyr.pyr_coeffs)

def test_pyrshow_2d_shape_err(self):
img = np.random.rand(255, 255)
pyr = pt.pyramids.GaussianPyramid(img)
with self.assertRaises(ValueError):
pt.pyrshow(pyr.pyr_coeffs)


def main():
unittest.main()

Expand Down
83 changes: 68 additions & 15 deletions src/pyrtools/tools/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,8 @@ def colormap_range(image, vrange='indep1', cmap=None):
return vrange_list, cmap


def find_zooms(images, video=False):
"""find the zooms necessary to display a list of images
def _check_shapes(images, video=False):
"""Helper function to check whether images can be zoomed in appropriately.

this convenience function takes a list of images and finds out if they can all be displayed at
the same size. for this to be the case, there must be an integer for each image such that the
Expand All @@ -341,28 +341,68 @@ def find_zooms(images, video=False):

Returns
-------
zooms : `list`
list of integers showing how much each image needs to be zoomed
max_shape : `tuple`
2-tuple of integers, showing the shape of the largest image in the list

Raises
------
ValueError :
If the images cannot be zoomed to the same. that is, if there is not an integer
for each image such that the image can be multiplied by that integer to be the
same size as the biggest image.
"""
def check_shape_1d(shapes):
max_shape = np.max(shapes)
for s in shapes:
if not (max_shape % s) == 0:
raise Exception("All images must be able to be 'zoomed in' to the largest image."
"That is, the largest image must be a scalar multiple of all "
"images.")
raise ValueError("All images must be able to be 'zoomed in' to the largest image."
"That is, the largest image must be a scalar multiple of all "
"images.")
return max_shape

if video:
time_dim = 1
else:
time_dim = 0
max_shape = []
for i in range(2):
max_shape.append(check_shape_1d([img.shape[i+time_dim] for img in images]))
return max_shape


def find_zooms(images, video=False):
"""find the zooms necessary to display a list of images

Arguments
---------
images : `list`
list of numpy arrays to check the size of. In practice, these are 1d or 2d, but can in
principle be any number of dimensions
video: bool, optional (default False)
handling signals in both space and time or only space.

Returns
-------
zooms : `list`
list of integers showing how much each image needs to be zoomed
max_shape : `tuple`
2-tuple of integers, showing the shape of the largest image in the list

Raises
------
ValueError :
If the images cannot be zoomed to the same. that is, if there is not an integer
for each image such that the image can be multiplied by that integer to be the
same size as the biggest image.
ValueError :
If the two image dimensions require different levels of zoom (e.g., if the
height must be zoomed by 2 but the width must be zoomed by 3).

"""
max_shape = _check_shapes(images, video)
if video:
time_dim = 1
else:
time_dim = 0
zooms = []
for img in images:
# this checks that there's only one unique value in the list
Expand All @@ -373,8 +413,8 @@ def check_shape_1d(shapes):
# the first two non-time dimensions (so we'll ignore the RGBA channel
# if any image has that)
if len(set([s // img.shape[i+time_dim] for i, s in enumerate(max_shape)])) > 1:
raise Exception("Both height and width must be multiplied by same amount but got "
"image shape {} and max_shape {}!".format(img.shape, max_shape))
raise ValueError("Both height and width must be multiplied by same amount but got "
"image shape {} and max_shape {}!".format(img.shape, max_shape))
zooms.append(max_shape[0] // img.shape[0])
return zooms, max_shape

Expand Down Expand Up @@ -839,9 +879,6 @@ def animate_video(t):
def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1, show_residuals=True, **kwargs):
"""Display the coefficients of the pyramid in an orderly fashion

NOTE: this currently only works for 2d signals. we still need to figure out how to handle 1D
signals.

Arguments
---------
pyr_coeffs : `dict`
Expand Down Expand Up @@ -894,8 +931,6 @@ def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1
# pasting all coefficients into a giant array.
# and the steerable pyramids have a num_orientations attribute

# TODO make list of different elements in each dim
# then only loop through those - see below line 655
num_scales = np.max(np.array([k for k in pyr_coeffs.keys() if isinstance(k, tuple)])[:,0]) + 1
num_orientations = np.max(np.array([k for k in pyr_coeffs.keys() if isinstance(k, tuple)])[:,1]) + 1

Expand Down Expand Up @@ -939,4 +974,22 @@ def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1
ax.set_title(titles[i])
return fig
else:
try:
_check_shapes(imgs)
except ValueError:
err_scales = num_scales
residual_err_msg = ""
shapes = [(imgs[0].shape[0]/ 2**i, imgs[0].shape[1] / 2**i) for i in range(err_scales)]
err_msg = [f"scale {i} shape: {sh}" for i, sh in enumerate(shapes)]
if show_residuals:
err_scales += 1
residual_err_msg = ", plus 1 (for the residual lowpass)"
shape = (imgs[0].shape[0]/ int(2**err_scales), imgs[0].shape[1] / int(2**err_scales))
err_msg.append(f"residual lowpass shape: {shape}")
err_msg = "\n".join(err_msg)
raise ValueError("In order to correctly display pyramid coefficients, the shape of"
f" the initial image must be evenly divisible by two {err_scales} "
"times, where this number is the height of the "
f"pyramid{residual_err_msg}. "
f"Instead, found:\n{err_msg}")
return imshow(imgs, vrange=vrange, col_wrap=col_wrap, zoom=zoom, title=titles, **kwargs)
Loading