diff --git a/memery/streamlit_app.py b/memery/streamlit_app.py index 9e4fbc2..a12f986 100644 --- a/memery/streamlit_app.py +++ b/memery/streamlit_app.py @@ -1,11 +1,12 @@ # Builtins from pathlib import Path from PIL import Image -from io import StringIO +from io import StringIO, BytesIO import sys import argparse from threading import current_thread from contextlib import contextmanager +import zipfile # Local from memery.core import Memery @@ -104,18 +105,58 @@ def search(root, text_query, negative_text_query, image_query, image_display_zon with st_stdout('info'): ranked = memery.query_flow(root, text_query, negative_text_query, image_query) # Modified line ims_to_display = {} + full_paths = {} # Store full paths for download functionality size = sizes[size_choice] for o in ranked[:num_images]: name = o.replace(path, '') try: ims_to_display[name] = Image.open(o).convert('RGB') + full_paths[name] = o # Store the full path except Exception as e: with skipped_files_box: st.warning(f'Skipping bad file: {name}\ndue to {type(e)}') pass with image_display_zone: + # Add Download All button at the top if there are results + if ims_to_display: + # Create a zip file in memory with all the images + zip_buffer = BytesIO() + with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file: + for name, img in ims_to_display.items(): + original_file_path = full_paths[name] + # Use just the filename for the zip entry to avoid path issues + filename = Path(original_file_path).name + with open(original_file_path, 'rb') as f: + zip_file.writestr(filename, f.read()) + + zip_buffer.seek(0) + st.download_button( + label="⬇ Download All as ZIP", + data=zip_buffer.getvalue(), + file_name="memery_results.zip", + mime="application/zip", + key="download_all" + ) + if captions_on: - st.image([o for o in ims_to_display.values()], width=size, channels='RGB', caption=[o for o in ims_to_display.keys()]) + # Display images with captions and download buttons + cols = st.columns(min(3, len(ims_to_display))) # Create columns for layout + for idx, (name, img) in enumerate(ims_to_display.items()): + col_idx = idx % min(3, len(ims_to_display)) + with cols[col_idx]: + st.image(img, width=size, channels='RGB', caption=name) + # Add download button for each image - use full path to get original file + original_file_path = full_paths[name] + with open(original_file_path, 'rb') as file: + file_bytes = file.read() + filename = Path(original_file_path).name + st.download_button( + label="⬇ Download", + data=file_bytes, + file_name=filename, + mime=f"image/{Path(original_file_path).suffix[1:]}", + key=f"download_{idx}" + ) else: st.image([o for o in ims_to_display.values()], width=sizes[size_choice], channels='RGB')