Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions memery/streamlit_app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Builtins
from pathlib import Path
from PIL import Image
from io import StringIO
from io import StringIO, BytesIO
import sys
import argparse
from threading import current_thread
from contextlib import contextmanager
import zipfile

# Local
from memery.core import Memery
Expand Down Expand Up @@ -104,18 +105,58 @@ def search(root, text_query, negative_text_query, image_query, image_display_zon
with st_stdout('info'):
ranked = memery.query_flow(root, text_query, negative_text_query, image_query) # Modified line
ims_to_display = {}
full_paths = {} # Store full paths for download functionality
size = sizes[size_choice]
for o in ranked[:num_images]:
name = o.replace(path, '')
try:
ims_to_display[name] = Image.open(o).convert('RGB')
full_paths[name] = o # Store the full path
except Exception as e:
with skipped_files_box:
st.warning(f'Skipping bad file: {name}\ndue to {type(e)}')
pass
with image_display_zone:
# Add Download All button at the top if there are results
if ims_to_display:
# Create a zip file in memory with all the images
zip_buffer = BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
for name, img in ims_to_display.items():
original_file_path = full_paths[name]
# Use just the filename for the zip entry to avoid path issues
filename = Path(original_file_path).name
with open(original_file_path, 'rb') as f:
zip_file.writestr(filename, f.read())

zip_buffer.seek(0)
st.download_button(
label="⬇ Download All as ZIP",
data=zip_buffer.getvalue(),
file_name="memery_results.zip",
mime="application/zip",
key="download_all"
)

if captions_on:
st.image([o for o in ims_to_display.values()], width=size, channels='RGB', caption=[o for o in ims_to_display.keys()])
# Display images with captions and download buttons
cols = st.columns(min(3, len(ims_to_display))) # Create columns for layout
for idx, (name, img) in enumerate(ims_to_display.items()):
col_idx = idx % min(3, len(ims_to_display))
with cols[col_idx]:
st.image(img, width=size, channels='RGB', caption=name)
# Add download button for each image - use full path to get original file
original_file_path = full_paths[name]
with open(original_file_path, 'rb') as file:
file_bytes = file.read()
filename = Path(original_file_path).name
st.download_button(
label="⬇ Download",
data=file_bytes,
file_name=filename,
mime=f"image/{Path(original_file_path).suffix[1:]}",
key=f"download_{idx}"
)
else:
st.image([o for o in ims_to_display.values()], width=sizes[size_choice], channels='RGB')

Expand Down