diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index eeceffa..9548391 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,99 +2,99 @@ name: CI/CD Pipeline on: push: - branches: [ main, master ] + branches: [main, master] pull_request: - branches: [ main, master ] + branches: [main, master] jobs: app-tests: runs-on: ubuntu-latest - + steps: - - uses: actions/checkout@v3 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.11' - cache: 'pip' - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - # Install PyTorch CPU version first with the correct index - pip install torch==2.5.1+cpu torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu - # Install CI-specific requirements (CPU versions) - if [ -f app/requirements-ci.txt ]; then pip install -r app/requirements-ci.txt; fi - # Install test requirements - if [ -f app/tests/test_requirements.txt ]; then pip install -r app/tests/test_requirements.txt; fi - # Install SAM model (may be needed for imports, even though tests use mocks) - pip install git+https://github.com/facebookresearch/segment-anything.git - - - name: Set test mode environment variable - run: | - echo "SAT_ANNOTATOR_TEST_MODE=1" >> $GITHUB_ENV - - - name: Run tests - run: | - cd app/tests - python run_unittests.py + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # Install PyTorch CPU version first with the correct index + pip install torch==2.5.1+cpu torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu + # Install CI-specific requirements (CPU versions) + if [ -f app/requirements-ci.txt ]; then pip install -r app/requirements-ci.txt; fi + # Install test requirements + if [ -f app/tests/test_requirements.txt ]; then pip install -r app/tests/test_requirements.txt; fi + # Install SAM model (may be needed for imports, even though tests use mocks) + pip install git+https://github.com/facebookresearch/segment-anything.git + + - name: Set test mode environment variable + run: | + echo "SAT_ANNOTATOR_TEST_MODE=1" >> $GITHUB_ENV + + - name: Run tests + run: | + cd app/tests + python run_unittests.py web-validation: runs-on: ubuntu-latest - + steps: - - uses: actions/checkout@v3 - - - name: Validate HTML files - run: | - echo "Validating web static files..." - # Check if main HTML files exist - test -f web/index.html || (echo "Missing index.html" && exit 1) - test -f web/styles.css || (echo "Missing styles.css" && exit 1) - test -d web/js || (echo "Missing js directory" && exit 1) - echo "Web files validation passed" - - - name: Check JavaScript syntax - run: | - echo "Checking JavaScript syntax..." - # Use Python to check basic JavaScript syntax if Node.js is available - if command -v node >/dev/null 2>&1; then - for js_file in web/js/*.js; do - if [ -f "$js_file" ]; then - echo "Checking $js_file..." - node -c "$js_file" || (echo "Syntax error in $js_file" && exit 1) - fi - done - else - echo "Node.js not available, skipping JS syntax check..." - fi - echo "JavaScript validation completed" + - uses: actions/checkout@v3 + + - name: Validate HTML files + run: | + echo "Validating web static files..." + # Check if main HTML files exist + test -f web/index.html || (echo "Missing index.html" && exit 1) + test -f web/styles.css || (echo "Missing styles.css" && exit 1) + test -d web/js || (echo "Missing js directory" && exit 1) + echo "Web files validation passed" + + - name: Check JavaScript syntax + run: | + echo "Checking JavaScript syntax..." + # Use Python to check basic JavaScript syntax if Node.js is available + if command -v node >/dev/null 2>&1; then + for js_file in web/js/*.js; do + if [ -f "$js_file" ]; then + echo "Checking $js_file..." + node -c "$js_file" || (echo "Syntax error in $js_file" && exit 1) + fi + done + else + echo "Node.js not available, skipping JS syntax check..." + fi + echo "JavaScript validation completed" docker-build: runs-on: ubuntu-latest needs: [app-tests, web-validation] if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master') - + steps: - - uses: actions/checkout@v3 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - - - name: Build and test Docker images - run: | - echo "Building Docker images..." - docker compose build - echo "Docker build completed successfully" - + - uses: actions/checkout@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + + - name: Build and test Docker images + run: | + echo "Building Docker images..." + docker compose build + echo "Docker build completed successfully" + # Uncomment and configure the following if you want to push to Docker Hub or another registry # - name: Login to Docker Hub # uses: docker/login-action@v2 # with: # username: ${{ secrets.DOCKERHUB_USERNAME }} # password: ${{ secrets.DOCKERHUB_TOKEN }} - # + # # - name: Push Docker images # run: | - # docker compose push \ No newline at end of file + # docker compose push diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000..b1c7fb5 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,11 @@ +{ + "semi": true, + "trailingComma": "es5", + "singleQuote": true, + "printWidth": 80, + "tabWidth": 2, + "useTabs": false, + "bracketSpacing": true, + "arrowParens": "avoid", + "endOfLine": "lf" +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..b24297c --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,33 @@ +{ + "editor.formatOnSave": true, + "editor.defaultFormatter": "esbenp.prettier-vscode", + "editor.tabSize": 2, + "editor.insertSpaces": true, + "editor.detectIndentation": false, + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.tabSize": 4, + "editor.formatOnSave": true + }, + "[javascript]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[typescript]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[json]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[html]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[css]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[yaml]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[markdown]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + } +} diff --git a/README.md b/README.md index a1a9d35..9c128dc 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ This project is sponsored by the Egyptian Space Agency (EgSA). ## Features **Current:** + - RESTful API built with FastAPI - Session-based in-memory storage for image metadata (no database required) - File upload endpoint for satellite imagery @@ -31,6 +32,7 @@ This project is sponsored by the Egyptian Space Agency (EgSA). - Smart caching system for repeated segmentation operations **Planned:** + - Multiple prompt types (box, points, text) - Manual annotation tools with intuitive UI - Export annotations in additional formats (Shapefile) @@ -40,10 +42,10 @@ This project is sponsored by the Egyptian Space Agency (EgSA). - **Backend**: Python, FastAPI - **Storage**: Session-based in-memory storage -- **AI Models**: +- **AI Models**: - Segment Anything Model (SAM) - PyTorch with CUDA support -- **Image Processing**: +- **Image Processing**: - OpenCV - Pillow (PIL) - **Containerization**: Docker (optional) @@ -64,33 +66,39 @@ This project is sponsored by the Egyptian Space Agency (EgSA). ### Quick Start with Docker (Recommended) 1. Clone the repository: + ```bash git clone https://github.com/yourusername/sat-annotator.git cd sat-annotator ``` 2. Build and run with Docker: + ```bash docker-compose up --build ``` - *Note: The SAM model will be automatically downloaded during the Docker build process.* + +_Note: The SAM model will be automatically downloaded during the Docker build process._ 3. Access the application at `http://localhost:8000` ### Local Development Setup 1. Clone the repository: + ```bash git clone https://github.com/yourusername/sat-annotator.git cd sat-annotator ``` 2. Download the SAM model (required for local development): + - Download the SAM model checkpoint: [sam_vit_h_4b8939.pth](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) - Create a `models/` directory in the project root if it doesn't exist - Place the downloaded file in the `models/` directory 3. Set up Python environment: + ```bash # Create and activate a virtual environment python -m venv venv @@ -101,15 +109,17 @@ pip install -r app/requirements.txt pip install git+https://github.com/facebookresearch/segment-anything.git ``` - **Note on PyTorch versions:** - - `requirements.txt`: Contains CUDA version of PyTorch for local development with GPU acceleration - - `requirements-ci.txt`: Contains CPU version of PyTorch for CI/testing environments - - If you don't have CUDA support, install PyTorch CPU version first: - ```bash - pip install torch==2.5.1+cpu torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu - ``` +**Note on PyTorch versions:** + +- `requirements.txt`: Contains CUDA version of PyTorch for local development with GPU acceleration +- `requirements-ci.txt`: Contains CPU version of PyTorch for CI/testing environments +- If you don't have CUDA support, install PyTorch CPU version first: + ```bash + pip install torch==2.5.1+cpu torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu + ``` 4. Run the application: + ```bash uvicorn app.main:app --reload ``` @@ -121,6 +131,7 @@ uvicorn app.main:app --reload The frontend is served directly by the FastAPI backend as static files. No separate Node.js setup is required. **Docker vs Local Development:** + - **Docker**: SAM model downloads automatically during build process - **Local Development**: Manual model download required (as shown above) @@ -183,7 +194,7 @@ sat-annotator/ These directories are created automatically by the application: - **`uploads/`**: Stores user-uploaded satellite images -- **`annotations/`**: Stores AI-generated and manual annotation JSON files +- **`annotations/`**: Stores AI-generated and manual annotation JSON files - **`logs/`** & **`app/logs/`**: Application log files for debugging - **`models/`**: Contains the SAM AI model (auto-downloaded in Docker) @@ -209,11 +220,13 @@ These directories are created automatically by the application: The application provides a comprehensive REST API for programmatic access: #### Health Check + ```bash curl http://localhost:8000/health ``` Expected response: + ```json { "status": "healthy" @@ -223,16 +236,19 @@ Expected response: #### Session Management ##### Get Session Information + ```bash curl http://localhost:8000/api/session-info/ ``` ##### Clear Session Data + ```bash curl -X DELETE http://localhost:8000/api/session/ ``` ##### Export Session Data + ```bash curl -X POST http://localhost:8000/api/export-session/ ``` @@ -240,6 +256,7 @@ curl -X POST http://localhost:8000/api/export-session/ #### Image Management ##### Upload an Image + ```bash curl -X POST http://localhost:8000/api/upload-image/ \ -H "Content-Type: multipart/form-data" \ @@ -247,6 +264,7 @@ curl -X POST http://localhost:8000/api/upload-image/ \ ``` Expected response: + ```json { "success": true, @@ -264,16 +282,19 @@ Expected response: ``` ##### Retrieve All Images + ```bash curl http://localhost:8000/api/images/ ``` ##### Get Specific Image + ```bash curl http://localhost:8000/api/images/{image_id}/ ``` ##### Delete Image and Annotations + ```bash curl -X DELETE http://localhost:8000/api/images/{image_id} ``` @@ -281,6 +302,7 @@ curl -X DELETE http://localhost:8000/api/images/{image_id} #### AI Segmentation ##### Preprocess Image for Segmentation + ```bash curl -X POST http://localhost:8000/api/preprocess/ \ -H "Content-Type: application/json" \ @@ -288,6 +310,7 @@ curl -X POST http://localhost:8000/api/preprocess/ \ ``` ##### Generate Point-Based Segmentation + ```bash curl -X POST http://localhost:8000/api/segment/ \ -H "Content-Type: application/json" \ @@ -300,10 +323,16 @@ curl -X POST http://localhost:8000/api/segment/ \ ``` Expected response: + ```json { "success": true, - "polygon": [[0.1, 0.2], [0.3, 0.2], [0.3, 0.4], [0.1, 0.4]], + "polygon": [ + [0.1, 0.2], + [0.3, 0.2], + [0.3, 0.4], + [0.1, 0.4] + ], "annotation_id": "annotation-uuid", "label": "Building", "confidence": 0.92 @@ -313,6 +342,7 @@ Expected response: #### Annotation Management ##### Create Manual Annotation + ```bash curl -X POST http://localhost:8000/api/annotations/ \ -H "Content-Type: application/json" \ @@ -325,6 +355,7 @@ curl -X POST http://localhost:8000/api/annotations/ \ ``` ##### Update Annotation + ```bash curl -X PUT http://localhost:8000/api/annotations/{annotation_id} \ -H "Content-Type: application/json" \ @@ -335,11 +366,13 @@ curl -X PUT http://localhost:8000/api/annotations/{annotation_id} \ ``` ##### Delete Annotation + ```bash curl -X DELETE http://localhost:8000/api/annotations/{annotation_id} ``` ##### Get Image Annotations + ```bash curl http://localhost:8000/api/annotations/{image_id} ``` @@ -347,11 +380,13 @@ curl http://localhost:8000/api/annotations/{image_id} ### Testing the API #### Root Endpoint + ```bash curl http://localhost:8000/ ``` Expected response: + ```json { "message": "Welcome to the Satellite Image Annotation Tool" @@ -362,24 +397,25 @@ For comprehensive API examples, see the **API Reference** section above. ## API Documentation -| Endpoint | Method | Description | -|----------|--------|-------------| -| `/health` | GET | Health check for container orchestration | -| `/api/upload-image/` | POST | Upload satellite imagery (TIFF, PNG, JPG) | -| `/api/images/` | GET | Retrieve all uploaded images | -| `/api/images/{id}/` | GET | Get specific image by ID | -| `/api/images/{id}` | DELETE | Delete image and associated annotations | -| `/api/preprocess/` | POST | Prepare image for AI segmentation | -| `/api/segment/` | POST | Generate AI segmentation from point | -| `/api/annotations/` | POST | Create manual annotation | -| `/api/annotations/{id}` | PUT | Update existing annotation | -| `/api/annotations/{id}` | DELETE | Delete annotation | -| `/api/annotations/{image_id}` | GET | Get all annotations for image | -| `/api/session-info/` | GET | Get current session information | -| `/api/session/` | DELETE | Clear all session data | -| `/api/export-session/` | POST | Export session data as JSON | +| Endpoint | Method | Description | +| ----------------------------- | ------ | ----------------------------------------- | +| `/health` | GET | Health check for container orchestration | +| `/api/upload-image/` | POST | Upload satellite imagery (TIFF, PNG, JPG) | +| `/api/images/` | GET | Retrieve all uploaded images | +| `/api/images/{id}/` | GET | Get specific image by ID | +| `/api/images/{id}` | DELETE | Delete image and associated annotations | +| `/api/preprocess/` | POST | Prepare image for AI segmentation | +| `/api/segment/` | POST | Generate AI segmentation from point | +| `/api/annotations/` | POST | Create manual annotation | +| `/api/annotations/{id}` | PUT | Update existing annotation | +| `/api/annotations/{id}` | DELETE | Delete annotation | +| `/api/annotations/{image_id}` | GET | Get all annotations for image | +| `/api/session-info/` | GET | Get current session information | +| `/api/session/` | DELETE | Clear all session data | +| `/api/export-session/` | POST | Export session data as JSON | **Interactive Documentation:** + - Swagger UI: `http://localhost:8000/docs` - ReDoc: `http://localhost:8000/redoc` @@ -393,6 +429,7 @@ The application uses session-based in-memory storage for temporary data manageme - **Temporary Files**: Session data is cleared when the application restarts Benefits: + - No database setup required for quick deployment - Simplified development and testing - Stateless application design @@ -424,4 +461,4 @@ Created by ... --- -*Note: This project is under active development. Features and API endpoints are subject to change.* \ No newline at end of file +_Note: This project is under active development. Features and API endpoints are subject to change._ diff --git a/app/main.py b/app/main.py index 41d25c6..4019f4f 100644 --- a/app/main.py +++ b/app/main.py @@ -19,21 +19,24 @@ try: # For uvicorn from root directory from app.routers import session_images, session_segmentation + logger.info("Using app.routers imports") except ImportError: try: # For running directly from app directory from routers import session_images, session_segmentation + logger.info("Using direct routers imports") except ImportError as e: logger.error(f"Import error: {e}") # Final fallback - try with explicit path manipulation sys.path.insert(0, os.path.dirname(app_dir)) from app.routers import session_images, session_segmentation + logger.info("Using fallback app.routers imports") # Determine if we're running in Docker or locally -in_docker = os.path.exists('/.dockerenv') +in_docker = os.path.exists("/.dockerenv") # Set paths based on environment base_path = Path("/app") if in_docker else Path(".") @@ -60,11 +63,13 @@ allow_headers=["*"], ) + # Health check endpoint for Docker container orchestration @app.get("/health") def health_check(): return {"status": "healthy"} + # Include session-based routers app.include_router(session_images.router, prefix="/api", tags=["images"]) app.include_router(session_segmentation.router, prefix="/api", tags=["segmentation"]) @@ -76,10 +81,16 @@ def health_check(): if frontend_dir.exists() and (frontend_dir / "index.html").exists(): app.mount("/", StaticFiles(directory=str(frontend_dir), html=True), name="frontend") else: + @app.get("/") def read_root(): - return {"message": "Welcome to the Satellite Image Annotation Tool (API Only Mode)"} - + return { + "message": "Welcome to the Satellite Image Annotation Tool (API Only Mode)" + } + @app.get("/frontend-status") def frontend_status(): - return {"status": "not_mounted", "message": "Frontend not found. Make sure the web directory contains index.html."} \ No newline at end of file + return { + "status": "not_mounted", + "message": "Frontend not found. Make sure the web directory contains index.html.", + } diff --git a/app/routers/session_images.py b/app/routers/session_images.py index add1bb1..6eaae39 100644 --- a/app/routers/session_images.py +++ b/app/routers/session_images.py @@ -9,32 +9,32 @@ router = APIRouter() + @router.post("/upload-image/", response_model=UploadResponse) async def upload_image( file: UploadFile = File(...), - session_manager: SessionManager = Depends(get_session_manager) + session_manager: SessionManager = Depends(get_session_manager), ): """ Upload a satellite image file for annotation. - + Supports JPG, PNG, TIFF and GeoTIFF formats. - """ # Validate file type + """ # Validate file type if not file: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="No file provided" + status_code=status.HTTP_400_BAD_REQUEST, detail="No file provided" ) - + if not validate_image_file(file): raise HTTPException( status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, - detail="File type not supported. Please upload JPG, PNG, TIFF or GeoTIFF" + detail="File type not supported. Please upload JPG, PNG, TIFF or GeoTIFF", ) - + # Process and save the file try: file_info = await save_upload_file(file) - + # Save to session store session_id = session_manager.session_id session_image = session_store.add_image( @@ -42,9 +42,9 @@ async def upload_image( file_name=file_info["original_filename"], file_path=file_info["path"], resolution=file_info["resolution"], - source="user_upload" + source="user_upload", ) - # Convert SessionImage to the expected Image pydantic model format + # Convert SessionImage to the expected Image pydantic model format # Create an Image Pydantic model directly from the SessionImage attributes image = Image( image_id=session_image.image_id, @@ -53,37 +53,38 @@ async def upload_image( resolution=session_image.resolution, source=session_image.source, capture_date=session_image.capture_date, - created_at=session_image.created_at - ) # Image uploaded successfully - ready for immediate preprocessing + created_at=session_image.created_at, + ) # Image uploaded successfully - ready for immediate preprocessing logging.info(f"✓ Image uploaded successfully: {file_info['original_filename']}") return UploadResponse( success=True, message="File uploaded successfully. Ready for annotation.", - image=image + image=image, ) - + except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error uploading file: {str(e)}" + detail=f"Error uploading file: {str(e)}", ) + @router.get("/images/", response_model=List[Image]) def get_images( - skip: int = 0, + skip: int = 0, limit: int = 100, - session_manager: SessionManager = Depends(get_session_manager) + session_manager: SessionManager = Depends(get_session_manager), ): """Get list of uploaded images in the current session""" session_id = session_manager.session_id images = session_store.get_images(session_id, skip=skip, limit=limit) return images + @router.get("/images/{image_id}/", response_model=Image) def get_image( - image_id: str, - session_manager: SessionManager = Depends(get_session_manager) + image_id: str, session_manager: SessionManager = Depends(get_session_manager) ): """ Retrieve a specific image by its ID. @@ -96,115 +97,117 @@ def get_image( if not image: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Image with ID {image_id} not found" + detail=f"Image with ID {image_id} not found", ) return image + @router.delete("/images/{image_id}", response_model=dict) async def delete_image( - image_id: str, - session_manager: SessionManager = Depends(get_session_manager) + image_id: str, session_manager: SessionManager = Depends(get_session_manager) ): """ Delete an image and its associated annotations. """ session_id = session_manager.session_id - + # Get the image from session store image = session_store.get_image(session_id, image_id) if not image: raise HTTPException(status_code=404, detail="Image not found") - + try: # Delete the image file if it exists if os.path.exists(image.file_path): os.remove(image.file_path) - + # Delete all annotations associated with this image annotations = session_store.get_annotations(session_id, image_id) for annotation in annotations: if os.path.exists(annotation.file_path): os.remove(annotation.file_path) session_store.remove_annotation(session_id, annotation.annotation_id) - + # Remove image from session store success = session_store.remove_image(session_id, image_id) - + if success: - return {"success": True, "message": "Image and associated annotations deleted successfully"} + return { + "success": True, + "message": "Image and associated annotations deleted successfully", + } else: - raise HTTPException(status_code=500, detail="Failed to remove image from session") - + raise HTTPException( + status_code=500, detail="Failed to remove image from session" + ) + except Exception as e: raise HTTPException(status_code=500, detail=f"Error deleting image: {str(e)}") + @router.get("/session-info/") -def get_session_info( - session_manager: SessionManager = Depends(get_session_manager) -): +def get_session_info(session_manager: SessionManager = Depends(get_session_manager)): """Get information about the current session""" session_id = session_manager.session_id session_data = session_store.get_session(session_id) - + if session_data: return { "session_id": session_id, "images_count": len(session_data.get("images", {})), "annotations_count": len(session_data.get("annotations", {})), - "created_at": session_data.get("created_at") + "created_at": session_data.get("created_at"), } - + return { "session_id": session_id, "images_count": 0, "annotations_count": 0, - "created_at": None + "created_at": None, } + @router.delete("/session/") -def clear_session( - session_manager: SessionManager = Depends(get_session_manager) -): +def clear_session(session_manager: SessionManager = Depends(get_session_manager)): """Clear the current session data""" session_id = session_manager.session_id - + # Remove session from store if session_id in session_store.sessions: del session_store.sessions[session_id] - + # Clear the session cookie session_manager.clear_session() - + return {"message": "Session cleared successfully"} + @router.post("/export-session/") -def export_session( - session_manager: SessionManager = Depends(get_session_manager) -): +def export_session(session_manager: SessionManager = Depends(get_session_manager)): """Export the current session data""" session_id = session_manager.session_id session_data = session_store.export_session(session_id) - + if not session_data: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No session data found" + status_code=status.HTTP_404_NOT_FOUND, detail="No session data found" ) - + # Convert session data to a format suitable for export export_data = { "session_id": session_id, "created_at": session_data["created_at"].isoformat(), "images": [img.dict() for img in session_data["images"].values()], - "annotations": [ann.dict() for ann in session_data["annotations"].values()] + "annotations": [ann.dict() for ann in session_data["annotations"].values()], } - + return export_data + @router.get("/session-id/") def get_session_id_endpoint( - session_manager: SessionManager = Depends(get_session_manager) + session_manager: SessionManager = Depends(get_session_manager), ): """Get the current session ID for API calls""" session_id = session_manager.session_id diff --git a/app/routers/session_segmentation.py b/app/routers/session_segmentation.py index 0c61f3e..4a0478c 100644 --- a/app/routers/session_segmentation.py +++ b/app/routers/session_segmentation.py @@ -11,7 +11,11 @@ import logging from datetime import datetime import cv2 -from app.schemas.session_schemas import ManualAnnotationCreate, ManualAnnotationUpdate, AnnotationResponse +from app.schemas.session_schemas import ( + ManualAnnotationCreate, + ManualAnnotationUpdate, + AnnotationResponse, +) # Set up logging log_dir = Path("logs") @@ -19,22 +23,20 @@ log_filename = f"debug_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" logging.basicConfig( level=logging.INFO, # Changed from DEBUG to INFO - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler(log_dir / log_filename), - logging.StreamHandler() - ] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.FileHandler(log_dir / log_filename), logging.StreamHandler()], ) logger = logging.getLogger("segmentation_router") router = APIRouter() segmenter = SAMSegmenter() + def construct_image_path(stored_path): """Construct consistent image path for both preprocessing and segmentation""" # Determine if running in Docker - in_docker = os.path.exists('/.dockerenv') - + in_docker = os.path.exists("/.dockerenv") + if not os.path.isabs(stored_path): if in_docker: # Docker environment @@ -45,20 +47,24 @@ def construct_image_path(stored_path): else: # Local environment if stored_path.startswith("uploads/"): - base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + base_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) image_path = os.path.join(base_dir, stored_path) else: image_path = stored_path else: image_path = stored_path - + return image_path + class PointPrompt(BaseModel): image_id: str x: float y: float + class SegmentationResponse(BaseModel): success: bool polygon: List[List[float]] @@ -66,114 +72,129 @@ class SegmentationResponse(BaseModel): cached: bool = False processing_time: Optional[float] = None + class PreprocessRequest(BaseModel): image_id: str + class PreprocessResponse(BaseModel): success: bool message: str + @router.post("/segment/", response_model=SegmentationResponse) async def segment_from_point( - prompt: PointPrompt, - session_manager: SessionManager = Depends(get_session_manager) + prompt: PointPrompt, session_manager: SessionManager = Depends(get_session_manager) ): """Generate segmentation from a point click with timeout handling""" import asyncio import concurrent.futures - + session_id = session_manager.session_id - + # Debug: log the received coordinates - logger.debug(f"Received segmentation request: image_id={prompt.image_id}, x={prompt.x}, y={prompt.y}") - + logger.debug( + f"Received segmentation request: image_id={prompt.image_id}, x={prompt.x}, y={prompt.y}" + ) + image = session_store.get_image(session_id, prompt.image_id) if not image: raise HTTPException(status_code=404, detail="Image not found") - + try: import time + timings = {} op_start = time.time() # Determine if running in Docker - in_docker = os.path.exists('/.dockerenv') - timings['env_check'] = time.time() - op_start + in_docker = os.path.exists("/.dockerenv") + timings["env_check"] = time.time() - op_start # Define annotation directory t0 = time.time() if in_docker: annotation_dir = Path("/app/annotations") else: - base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + base_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) annotation_dir = Path(os.path.join(base_dir, "annotations")) annotation_dir.mkdir(exist_ok=True) - timings['annotation_dir'] = time.time() - t0 + timings["annotation_dir"] = time.time() - t0 # Get the image path from the session store t1 = time.time() stored_path = image.file_path image_path = construct_image_path(stored_path) - timings['construct_image_path'] = time.time() - t1 + timings["construct_image_path"] = time.time() - t1 # Check if file exists t2 = time.time() if not os.path.exists(image_path): raise FileNotFoundError(f"Image file not found at {image_path}") - timings['file_exists'] = time.time() - t2 + timings["file_exists"] = time.time() - t2 # Check if this is a new image or one we've already processed t3 = time.time() - is_cached = (image_path == segmenter.current_image_path and image_path in segmenter.cache) - timings['cache_check'] = time.time() - t3 + is_cached = ( + image_path == segmenter.current_image_path and image_path in segmenter.cache + ) + timings["cache_check"] = time.time() - t3 - logger.info(f"Processing image: {image_path}, cached: {is_cached}, current: {segmenter.current_image_path}") + logger.info( + f"Processing image: {image_path}, cached: {is_cached}, current: {segmenter.current_image_path}" + ) def run_segmentation(): op_times = {} - op_times['start'] = time.time() + op_times["start"] = time.time() # OPTIMIZED: Only set image if it's not already the current image if segmenter.current_image_path != image_path: logger.info(f"Setting new image in SAM: {image_path}") t_set = time.time() height, width = segmenter.set_image(image_path) - op_times['set_image'] = time.time() - t_set + op_times["set_image"] = time.time() - t_set else: logger.info(f"Using already-set image (instant segmentation!)") t_cache = time.time() if image_path in segmenter.cache: - height, width = segmenter.cache[image_path]['image_size'] + height, width = segmenter.cache[image_path]["image_size"] else: logger.warning(f"Image not in cache, falling back to set_image") height, width = segmenter.set_image(image_path) - op_times['cache_lookup'] = time.time() - t_cache + op_times["cache_lookup"] = time.time() - t_cache pixel_x = int(prompt.x * width) pixel_y = int(prompt.y * height) - logger.info(f"Click at coordinates: ({pixel_x}, {pixel_y}) for image size: {width}x{height}") + logger.info( + f"Click at coordinates: ({pixel_x}, {pixel_y}) for image size: {width}x{height}" + ) # Get mask from point (GPU accelerated) t_mask = time.time() mask = segmenter.predict_from_point([pixel_x, pixel_y]) - op_times['mask_generation'] = time.time() - t_mask + op_times["mask_generation"] = time.time() - t_mask logger.info(f"Mask generation time: {op_times['mask_generation']:.3f}s") # Convert mask to polygon immediately t_poly = time.time() polygon = segmenter.mask_to_polygon(mask) - op_times['polygon_conversion'] = time.time() - t_poly - logger.info(f"Polygon conversion time: {op_times['polygon_conversion']:.3f}s") + op_times["polygon_conversion"] = time.time() - t_poly + logger.info( + f"Polygon conversion time: {op_times['polygon_conversion']:.3f}s" + ) if not polygon: raise ValueError("Could not generate polygon from mask") - op_times['total'] = time.time() - op_times['start'] + op_times["total"] = time.time() - op_times["start"] logger.info(f"Segmentation operation timings: {op_times}") # Find the slowest step (excluding 'start' and 'total') slowest_step = None slowest_time = 0.0 for k, v in op_times.items(): - if k not in ('start', 'total') and v > slowest_time: + if k not in ("start", "total") and v > slowest_time: slowest_step = k slowest_time = v if slowest_step: @@ -185,30 +206,30 @@ def run_segmentation(): future = executor.submit(run_segmentation) try: polygon, is_cached, seg_timings = await asyncio.wait_for( - asyncio.wrap_future(future), - timeout=30.0 + asyncio.wrap_future(future), timeout=30.0 ) except asyncio.TimeoutError: raise HTTPException( status_code=408, - detail="Segmentation timeout - image may still be processing..." + detail="Segmentation timeout - image may still be processing...", ) # Save JSON t_save = time.time() - annotation_path = annotation_dir / f"annotation_{session_id}_{image.image_id}_{len(polygon)}.json" + annotation_path = ( + annotation_dir + / f"annotation_{session_id}_{image.image_id}_{len(polygon)}.json" + ) with open(annotation_path, "w") as f: - json.dump({ - "type": "Feature", - "geometry": { - "type": "Polygon", - "coordinates": [polygon] + json.dump( + { + "type": "Feature", + "geometry": {"type": "Polygon", "coordinates": [polygon]}, + "properties": {"cached": is_cached}, }, - "properties": { - "cached": is_cached - } - }, f) - timings['save_json'] = time.time() - t_save + f, + ) + timings["save_json"] = time.time() - t_save # Add annotation to session store t_ann = time.time() @@ -216,11 +237,13 @@ def run_segmentation(): session_id=session_id, image_id=image.image_id, file_path=str(annotation_path), - auto_generated=True + auto_generated=True, ) - timings['add_annotation'] = time.time() - t_ann + timings["add_annotation"] = time.time() - t_ann - logger.info(f"Generated segmentation with {len(polygon)} points, cached: {is_cached}") + logger.info( + f"Generated segmentation with {len(polygon)} points, cached: {is_cached}" + ) total_processing_time = time.time() - op_start logger.info(f"Total segmentation processing time: {total_processing_time:.3f}s") logger.info(f"Step timings: {timings}") @@ -232,7 +255,7 @@ def run_segmentation(): annotation_id=annotation.annotation_id if annotation else None, cached=is_cached, processing_time=total_processing_time, - timings={**timings, **seg_timings} + timings={**timings, **seg_timings}, ) except HTTPException: @@ -241,98 +264,110 @@ def run_segmentation(): logger.error(f"Error in segmentation: {str(e)}", exc_info=True) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error generating segmentation: {str(e)}" + detail=f"Error generating segmentation: {str(e)}", ) + @router.get("/masks/{session_id}/{image_id}/{mask_type}") async def get_mask_image(session_id: str, image_id: str, mask_type: str): """Serve mask or overlay image for download or display.""" if mask_type not in ["mask", "overlay"]: - raise HTTPException(status_code=400, detail="Invalid mask_type. Use 'mask' or 'overlay'.") + raise HTTPException( + status_code=400, detail="Invalid mask_type. Use 'mask' or 'overlay'." + ) mask_path = Path(f"/app/annotations/{mask_type}_{session_id}_{image_id}.png") if not mask_path.exists(): raise HTTPException(status_code=404, detail=f"{mask_type} image not found") - return FileResponse(mask_path, media_type="image/png", filename=f"{mask_type}_{session_id}_{image_id}.png") + return FileResponse( + mask_path, + media_type="image/png", + filename=f"{mask_type}_{session_id}_{image_id}.png", + ) + @router.get("/annotations/{image_id}") async def get_image_annotations( - image_id: str, - session_manager: SessionManager = Depends(get_session_manager) + image_id: str, session_manager: SessionManager = Depends(get_session_manager) ): """Get all annotations for a specific image""" session_id = session_manager.session_id image = session_store.get_image(session_id, image_id) if not image: raise HTTPException(status_code=404, detail="Image not found") - + annotations = session_store.get_annotations(session_id, image_id) - + result = [] for ann in annotations: try: file_path = ann.file_path if os.path.exists(file_path): - with open(file_path, 'r') as f: - json_data = json.load(f) # DEBUG: Log what we're loading - if json_data.get('features') and len(json_data['features']) > 0: - feature = json_data['features'][0] - if feature.get('geometry', {}).get('coordinates'): - coords = feature['geometry']['coordinates'] - - result.append({ - "annotation_id": ann.annotation_id, - "created_at": ann.created_at, - "auto_generated": ann.auto_generated, - "data": json_data - }) + with open(file_path, "r") as f: + json_data = json.load(f) # DEBUG: Log what we're loading + if json_data.get("features") and len(json_data["features"]) > 0: + feature = json_data["features"][0] + if feature.get("geometry", {}).get("coordinates"): + coords = feature["geometry"]["coordinates"] + + result.append( + { + "annotation_id": ann.annotation_id, + "created_at": ann.created_at, + "auto_generated": ann.auto_generated, + "data": json_data, + } + ) except Exception as e: logger.error(f"Error loading annotation {ann.annotation_id}: {e}") continue - + return result + @router.post("/clear-cache/{image_id}") async def clear_image_cache( - image_id: str, - session_manager: SessionManager = Depends(get_session_manager) + image_id: str, session_manager: SessionManager = Depends(get_session_manager) ): """Clear the segmentation cache for a specific image""" session_id = session_manager.session_id image = session_store.get_image(session_id, image_id) - + if not image: raise HTTPException(status_code=404, detail="Image not found") - + # Use unified path construction image_path = construct_image_path(image.file_path) - + segmenter.clear_cache(image_path) - + return {"success": True, "message": f"Cache cleared for image {image_id}"} + @router.post("/annotations/", response_model=AnnotationResponse) async def save_manual_annotation( annotation_data: ManualAnnotationCreate, - session_manager: SessionManager = Depends(get_session_manager) + session_manager: SessionManager = Depends(get_session_manager), ): """Save a manual annotation""" session_id = session_manager.session_id - + # Verify the image exists image = session_store.get_image(session_id, annotation_data.image_id) if not image: raise HTTPException(status_code=404, detail="Image not found") - + try: # Define annotation directory - in_docker = os.path.exists('/.dockerenv') + in_docker = os.path.exists("/.dockerenv") if in_docker: annotation_dir = Path("/app/annotations") else: - base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + base_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) annotation_dir = Path(os.path.join(base_dir, "annotations")) annotation_dir.mkdir(exist_ok=True) - + # Create JSON format for the annotation json_data = { "type": "FeatureCollection", @@ -343,140 +378,156 @@ async def save_manual_annotation( "label": annotation_data.label, "type": annotation_data.type, "source": annotation_data.source, - "created": datetime.now().isoformat() - }, "geometry": { + "created": datetime.now().isoformat(), + }, + "geometry": { "type": "Polygon", - "coordinates": [annotation_data.polygon] # Fix: Direct polygon array, not nested - } + "coordinates": [ + annotation_data.polygon + ], # Fix: Direct polygon array, not nested + }, } - ] + ], } - # Save annotation file - annotation_path = annotation_dir / f"manual_{session_id}_{annotation_data.image_id}_{annotation_data.id}.json" + # Save annotation file + annotation_path = ( + annotation_dir + / f"manual_{session_id}_{annotation_data.image_id}_{annotation_data.id}.json" + ) with open(annotation_path, "w") as f: json.dump(json_data, f, indent=2) - + # Add annotation to session store annotation = session_store.add_annotation( session_id, annotation_data.image_id, annotation_id=annotation_data.id, file_path=str(annotation_path), - auto_generated=False + auto_generated=False, ) - + return AnnotationResponse( success=True, message="Annotation saved successfully", - annotation_id=annotation.annotation_id if annotation else annotation_data.id + annotation_id=( + annotation.annotation_id if annotation else annotation_data.id + ), ) - + except Exception as e: logger.error(f"Error saving annotation: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error saving annotation: {str(e)}" + detail=f"Error saving annotation: {str(e)}", ) + @router.put("/annotations/{annotation_id}", response_model=AnnotationResponse) async def update_annotation( annotation_id: str, update_data: ManualAnnotationUpdate, - session_manager: SessionManager = Depends(get_session_manager) + session_manager: SessionManager = Depends(get_session_manager), ): """Update an existing annotation""" session_id = session_manager.session_id - + # Get the annotation from session store annotation = session_store.get_annotation(session_id, annotation_id) if not annotation: raise HTTPException(status_code=404, detail="Annotation not found") - - try: # Load existing annotation data + + try: # Load existing annotation data if not os.path.exists(annotation.file_path): raise HTTPException(status_code=404, detail="Annotation file not found") - - with open(annotation.file_path, 'r') as f: + + with open(annotation.file_path, "r") as f: json_data = json.load(f) - + # Update the data if json_data.get("features"): feature = json_data["features"][0] - + if update_data.polygon: - feature["geometry"]["coordinates"] = [[[point[0], point[1]] for point in update_data.polygon]] - + feature["geometry"]["coordinates"] = [ + [[point[0], point[1]] for point in update_data.polygon] + ] + if update_data.label: feature["properties"]["label"] = update_data.label - + # Update modified timestamp feature["properties"]["modified"] = datetime.now().isoformat() - # Save updated annotation + # Save updated annotation with open(annotation.file_path, "w") as f: json.dump(json_data, f, indent=2) - + return AnnotationResponse( success=True, message="Annotation updated successfully", - annotation_id=annotation_id + annotation_id=annotation_id, ) - + except Exception as e: logger.error(f"Error updating annotation: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error updating annotation: {str(e)}" + detail=f"Error updating annotation: {str(e)}", ) + @router.delete("/annotations/{annotation_id}", response_model=AnnotationResponse) async def delete_annotation( - annotation_id: str, - session_manager: SessionManager = Depends(get_session_manager) + annotation_id: str, session_manager: SessionManager = Depends(get_session_manager) ): """Delete an annotation""" session_id = session_manager.session_id - - logger.info(f"Attempting to delete annotation {annotation_id} from session {session_id}") - + + logger.info( + f"Attempting to delete annotation {annotation_id} from session {session_id}" + ) + # List all annotations in session for debugging all_annotations = session_store.get_annotations(session_id) - logger.info(f"Available annotations in session: {[ann.annotation_id for ann in all_annotations]}") - + logger.info( + f"Available annotations in session: {[ann.annotation_id for ann in all_annotations]}" + ) + # Get the annotation from session store annotation = session_store.get_annotation(session_id, annotation_id) if not annotation: logger.error(f"Annotation {annotation_id} not found in session {session_id}") raise HTTPException(status_code=404, detail="Annotation not found") - + logger.info(f"Found annotation to delete: {annotation.annotation_id}") - + try: # Delete the annotation file if it exists if os.path.exists(annotation.file_path): os.remove(annotation.file_path) logger.info(f"Deleted annotation file: {annotation.file_path}") - + # Remove from session store success = session_store.remove_annotation(session_id, annotation_id) logger.info(f"Removed annotation from session store: {success}") - + return AnnotationResponse( success=True, message="Annotation deleted successfully", - annotation_id=annotation_id + annotation_id=annotation_id, ) - + except Exception as e: logger.error(f"Error deleting annotation: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error deleting annotation: {str(e)}" + detail=f"Error deleting annotation: {str(e)}", ) + @router.post("/preprocess/", response_model=PreprocessResponse) async def preprocess_image( request: PreprocessRequest, - session_manager: SessionManager = Depends(get_session_manager) + session_manager: SessionManager = Depends(get_session_manager), ): """Pre-generate embeddings for faster segmentation""" try: @@ -484,41 +535,40 @@ async def preprocess_image( session_data = session_store.get_session(session_manager.session_id) if not session_data: raise HTTPException( - status_code=404, - detail=f"Session {session_manager.session_id} not found. Please refresh the page to create a new session." + status_code=404, + detail=f"Session {session_manager.session_id} not found. Please refresh the page to create a new session.", ) - + # Get image from session image = session_store.get_image(session_manager.session_id, request.image_id) if not image: raise HTTPException( - status_code=404, - detail=f"Image {request.image_id} not found in session {session_manager.session_id}" - ) # Handle both absolute and relative paths for image_path + status_code=404, + detail=f"Image {request.image_id} not found in session {session_manager.session_id}", + ) # Handle both absolute and relative paths for image_path image_path = construct_image_path(image.file_path) - + # Check if file exists if not os.path.exists(image_path): logger.error(f"File does not exist at: {image_path}") raise FileNotFoundError(f"Image file not found at {image_path}") - + success = segmenter.preprocess_image(image_path) - + if success: logger.info(f"Successfully preprocessed image {request.image_id}") return PreprocessResponse( - success=True, - message="Image preprocessed successfully" + success=True, message="Image preprocessed successfully" ) else: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to preprocess image" + detail="Failed to preprocess image", ) - + except Exception as e: logger.error(f"Error preprocessing image {request.image_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error preprocessing image: {str(e)}" - ) \ No newline at end of file + detail=f"Error preprocessing image: {str(e)}", + ) diff --git a/app/schemas/session_schemas.py b/app/schemas/session_schemas.py index c0b4349..74ec155 100644 --- a/app/schemas/session_schemas.py +++ b/app/schemas/session_schemas.py @@ -2,43 +2,52 @@ from typing import Optional, List, Any from datetime import datetime + class ImageBase(BaseModel): file_name: str file_path: str resolution: Optional[str] = None source: Optional[str] = None + class ImageCreate(ImageBase): pass + class Image(ImageBase): image_id: str # Now using UUID string instead of int capture_date: datetime created_at: datetime + class UploadResponse(BaseModel): success: bool message: str image: Optional[Image] = None + class AnnotationBase(BaseModel): file_path: str auto_generated: bool = False - + + class AnnotationCreate(AnnotationBase): image_id: str - + + class Annotation(AnnotationBase): annotation_id: str image_id: str created_at: datetime - + + class SessionInfo(BaseModel): session_id: str image_count: int annotation_count: int created_at: datetime + # Manual annotation schemas for frontend requests class ManualAnnotationCreate(BaseModel): image_id: str @@ -48,10 +57,12 @@ class ManualAnnotationCreate(BaseModel): label: str source: str = "manual" + class ManualAnnotationUpdate(BaseModel): polygon: Optional[List[List[float]]] = None label: Optional[str] = None + class AnnotationResponse(BaseModel): success: bool message: str diff --git a/app/storage/session_manager.py b/app/storage/session_manager.py index cd35ac1..b3acbdc 100644 --- a/app/storage/session_manager.py +++ b/app/storage/session_manager.py @@ -8,22 +8,25 @@ # Session timeout (7 days) SESSION_TIMEOUT_DAYS = 7 + def generate_session_id() -> str: """Generate a unique session ID""" return str(uuid.uuid4()) + async def get_session_id(request: Request) -> str: """ Get the session ID from the request cookie or create a new one. This function should be used as a dependency in FastAPI routes. """ session_id = request.cookies.get(SESSION_COOKIE_NAME) - + if not session_id: session_id = generate_session_id() - + return session_id + def set_session_cookie(response: Response, session_id: str) -> None: """ Set the session cookie in the response. @@ -35,21 +38,22 @@ def set_session_cookie(response: Response, session_id: str) -> None: value=session_id, httponly=True, # Prevent JavaScript access for enhanced security against XSS samesite="lax", - path="/" + path="/", # No expires parameter = session cookie (cleared on browser close/reload) ) + class SessionManager: """ A manager for handling session-related operations in route handlers. Simplifies session management in API endpoints. """ - + def __init__(self, request: Request, response: Response): self.request = request self.response = response self._session_id: Optional[str] = None - + @property def session_id(self) -> str: """Get or initialize the session ID""" @@ -59,13 +63,13 @@ def session_id(self) -> str: self._session_id = generate_session_id() set_session_cookie(self.response, self._session_id) return self._session_id - + def clear_session(self) -> None: """Remove the session cookie""" self.response.delete_cookie( key=SESSION_COOKIE_NAME, httponly=True, # Match the httponly setting from set_session_cookie - path="/" + path="/", ) diff --git a/app/storage/session_store.py b/app/storage/session_store.py index 708e97d..daeea03 100644 --- a/app/storage/session_store.py +++ b/app/storage/session_store.py @@ -28,135 +28,159 @@ class SessionStore: In-memory session-based storage for images and annotations. This replaces the database for storing metadata. """ + def __init__(self): # Dictionary to store active sessions self.sessions: Dict[str, Dict] = {} - + def create_session(self, session_id: str) -> None: """Create a new session if it doesn't exist""" if session_id not in self.sessions: self.sessions[session_id] = { "images": {}, "annotations": {}, - "created_at": datetime.now() + "created_at": datetime.now(), } - + def get_session(self, session_id: str) -> Optional[Dict]: """Get session data by ID""" return self.sessions.get(session_id) - - def add_image(self, session_id: str, file_name: str, file_path: str, - resolution: Optional[str] = None, source: Optional[str] = None) -> SessionImage: + + def add_image( + self, + session_id: str, + file_name: str, + file_path: str, + resolution: Optional[str] = None, + source: Optional[str] = None, + ) -> SessionImage: """Add image to session and return the created image object""" self.create_session(session_id) - + image_id = str(uuid.uuid4()) image = SessionImage( image_id=image_id, file_name=file_name, file_path=file_path, resolution=resolution, - source=source or "user_upload" ) - + source=source or "user_upload", + ) + self.sessions[session_id]["images"][image_id] = image return image - - def get_images(self, session_id: str, skip: int = 0, limit: int = 100) -> List[SessionImage]: + + def get_images( + self, session_id: str, skip: int = 0, limit: int = 100 + ) -> List[SessionImage]: """Get all images in a session with pagination""" if session_id not in self.sessions: return [] - + images = list(self.sessions[session_id]["images"].values()) # Apply pagination - return images[skip:skip + limit] - + return images[skip : skip + limit] + def get_image(self, session_id: str, image_id: str) -> Optional[SessionImage]: """Get specific image by ID""" if session_id not in self.sessions: return None return self.sessions[session_id]["images"].get(image_id) - - def add_annotation(self, session_id: str, image_id: str, file_path: str, - auto_generated: bool = False, model_id: Optional[str] = None, - annotation_id: Optional[str] = None) -> Optional[SessionAnnotation]: + + def add_annotation( + self, + session_id: str, + image_id: str, + file_path: str, + auto_generated: bool = False, + model_id: Optional[str] = None, + annotation_id: Optional[str] = None, + ) -> Optional[SessionAnnotation]: """Add annotation to session and return the created annotation object""" - if session_id not in self.sessions or image_id not in self.sessions[session_id]["images"]: + if ( + session_id not in self.sessions + or image_id not in self.sessions[session_id]["images"] + ): return None - + # Use provided annotation_id or generate a new one if not annotation_id: annotation_id = str(uuid.uuid4()) - + annotation = SessionAnnotation( annotation_id=annotation_id, image_id=image_id, file_path=file_path, auto_generated=auto_generated, - model_id=model_id + model_id=model_id, ) - + self.sessions[session_id]["annotations"][annotation_id] = annotation return annotation - - def get_annotations(self, session_id: str, image_id: Optional[str] = None) -> List[SessionAnnotation]: + + def get_annotations( + self, session_id: str, image_id: Optional[str] = None + ) -> List[SessionAnnotation]: """Get annotations, optionally filtered by image_id""" if session_id not in self.sessions: return [] - + annotations = list(self.sessions[session_id]["annotations"].values()) - + # Filter by image_id if provided if image_id: annotations = [a for a in annotations if a.image_id == image_id] - + return annotations - - def get_annotation(self, session_id: str, annotation_id: str) -> Optional[SessionAnnotation]: + + def get_annotation( + self, session_id: str, annotation_id: str + ) -> Optional[SessionAnnotation]: """Get specific annotation by ID""" if session_id not in self.sessions: return None - + return self.sessions[session_id]["annotations"].get(annotation_id) - + def remove_annotation(self, session_id: str, annotation_id: str) -> bool: """Remove annotation from session and return True if successful""" if session_id not in self.sessions: return False - + if annotation_id in self.sessions[session_id]["annotations"]: del self.sessions[session_id]["annotations"][annotation_id] return True return False - + def remove_image(self, session_id: str, image_id: str) -> bool: """Remove image from session and return True if successful""" if session_id not in self.sessions: return False - + if image_id in self.sessions[session_id]["images"]: # Remove the image del self.sessions[session_id]["images"][image_id] - + # Remove any annotations associated with this image self.sessions[session_id]["annotations"] = { - k: v for k, v in self.sessions[session_id]["annotations"].items() + k: v + for k, v in self.sessions[session_id]["annotations"].items() if v.image_id != image_id } - + return True return False - + def delete_session(self, session_id: str) -> bool: """Delete a session and return True if successful""" if session_id in self.sessions: del self.sessions[session_id] return True return False - + def export_session(self, session_id: str) -> Optional[Dict]: """Export session data as a dictionary""" return self.get_session(session_id) - + def import_session(self, session_id: str, data: Dict) -> bool: """Import session data from a dictionary""" if "images" in data and "annotations" in data: diff --git a/app/tests/generate_test_requirements.py b/app/tests/generate_test_requirements.py index d0666b3..7755261 100644 --- a/app/tests/generate_test_requirements.py +++ b/app/tests/generate_test_requirements.py @@ -13,29 +13,29 @@ # Testing frameworks "httpx==0.28.1", "coverage==7.8.0", - # FastAPI related "fastapi>=0.115.8", "starlette>=0.45.3", "uvicorn>=0.27.0", "python-multipart>=0.0.7", - # Image processing - "pillow==11.2.1" + "pillow==11.2.1", ] + def main(): """Generate requirements file.""" # Get the directory of this script script_dir = Path(__file__).resolve().parent requirements_path = script_dir / "test_requirements.txt" - + with open(requirements_path, "w") as f: f.write("# Test requirements for sat-annotator backend\n") for req in REQUIREMENTS: f.write(f"{req}\n") - + print(f"Created {requirements_path}") + if __name__ == "__main__": main() diff --git a/app/tests/mocks.py b/app/tests/mocks.py index e715398..bbe826a 100644 --- a/app/tests/mocks.py +++ b/app/tests/mocks.py @@ -1,6 +1,7 @@ """ Mocks for tests to avoid loading the actual SAM model """ + import os import sys from pathlib import Path @@ -12,24 +13,28 @@ mock_torch = MagicMock() mock_cv2 = MagicMock() + # Setup mock for SAM model registry class MockSAMModel: def __init__(self, checkpoint=None): self.checkpoint = checkpoint - + def to(self, device=None): return self + def mock_model_constructor(checkpoint): return MockSAMModel(checkpoint) + sam_registry = { - 'vit_h': mock_model_constructor, - 'vit_l': mock_model_constructor, - 'vit_b': mock_model_constructor, + "vit_h": mock_model_constructor, + "vit_l": mock_model_constructor, + "vit_b": mock_model_constructor, } mock_segment_anything.sam_model_registry = sam_registry + # Setup mock for SamPredictor class MockSamPredictor: def __init__(self, model): @@ -37,11 +42,11 @@ def __init__(self, model): self.image = None self.current_image_path = None self.cache = {} - + def set_image(self, image): self.image = image - return getattr(image, 'shape', (768, 1024))[:2] - + return getattr(image, "shape", (768, 1024))[:2] + def predict(self, point_coords, point_labels, multimask_output=True): h, w = 768, 1024 masks = np.zeros((3, h, w), dtype=bool) @@ -50,6 +55,7 @@ def predict(self, point_coords, point_labels, multimask_output=True): logits = np.zeros((3, h, w)) return masks, scores, logits + mock_segment_anything.SamPredictor = MockSamPredictor # Setup mock for torch @@ -64,8 +70,12 @@ def predict(self, point_coords, point_labels, multimask_output=True): mock_cv2.imread = lambda path: mock_image mock_cv2.cvtColor = lambda img, code: img mock_cv2.findContours = lambda mask, mode, method: ( - [np.array([[[400, 300]], [[600, 300]], [[600, 500]], [[400, 500]]], dtype=np.int32)], - None + [ + np.array( + [[[400, 300]], [[600, 300]], [[600, 500]], [[400, 500]]], dtype=np.int32 + ) + ], + None, ) mock_cv2.RETR_EXTERNAL = 0 mock_cv2.CHAIN_APPROX_SIMPLE = 1 @@ -80,12 +90,13 @@ def predict(self, point_coords, point_labels, multimask_output=True): mock_pil.Image.open.return_value.__enter__.return_value.width = 1024 mock_pil.Image.open.return_value.__enter__.return_value.height = 768 + def apply_mocks(): """Apply the mocks to sys.modules""" - sys.modules['segment_anything'] = mock_segment_anything - sys.modules['torch'] = mock_torch - sys.modules['cv2'] = mock_cv2 - sys.modules['PIL'] = mock_pil - + sys.modules["segment_anything"] = mock_segment_anything + sys.modules["torch"] = mock_torch + sys.modules["cv2"] = mock_cv2 + sys.modules["PIL"] = mock_pil + # Set environment variable to indicate we're in test mode os.environ["SAT_ANNOTATOR_TEST_MODE"] = "1" diff --git a/app/tests/run_unittests.py b/app/tests/run_unittests.py index e8a79d6..7f79900 100644 --- a/app/tests/run_unittests.py +++ b/app/tests/run_unittests.py @@ -16,38 +16,38 @@ def main(): """Run all unittest-based tests.""" # Get the directory of this script script_dir = Path(__file__).resolve().parent - + # Get the app directory (parent of tests directory) app_dir = script_dir.parent - + # Add the app directory to sys.path if str(app_dir) not in sys.path: sys.path.insert(0, str(app_dir)) - + # Set environment variable to indicate we're in test mode os.environ["SAT_ANNOTATOR_TEST_MODE"] = "1" - + # Get a list of unittest files unittest_files = [ "unittest_session_store.py", - "unittest_session_manager.py", + "unittest_session_manager.py", "unittest_image_processing.py", "unittest_main_api.py", "unittest_sam_segmenter.py", "unittest_session_images_api.py", - "unittest_segmentation_api.py" + "unittest_segmentation_api.py", ] - + # Import and run each unittest file separately results = [] all_tests = 0 failures = 0 errors = 0 - + for test_file in unittest_files: print(f"\n===== Running {test_file} =====") test_path = script_dir / test_file - + # Run the test file as a subprocess with timeout try: # Run with real-time output for better debugging @@ -56,92 +56,100 @@ def main(): [sys.executable, str(test_path)], capture_output=True, text=True, - timeout=30 # Add timeout to catch hanging tests + timeout=30, # Add timeout to catch hanging tests ) # Print output if result.stdout: print(result.stdout) else: print("(No output)") - + if result.stderr: print("ERRORS:") print(result.stderr) - + # Determine test status stdout_content = result.stdout or "" stderr_content = result.stderr or "" - + if "OK" in stdout_content: status = "PASS" elif "FAILED" in stdout_content or "ERROR" in stdout_content: status = "FAIL" else: status = "UNKNOWN" - + except subprocess.TimeoutExpired: print("ERROR: Test timed out after 30 seconds") status = "TIMEOUT" errors += 1 result = None - # Extract test counts if we have results + # Extract test counts if we have results if result: try: test_count = 0 fail_count = 0 error_count = 0 - + # Get safe content to work with stdout_content = result.stdout or "" stderr_content = result.stderr or "" - + # Process test counts with multiple approaches for reliability - + # First check for explicit "Ran X tests" in stdout (standard unittest output) - ran_match = re.search(r'Ran (\d+) test', stdout_content) + ran_match = re.search(r"Ran (\d+) test", stdout_content) if ran_match: test_count = int(ran_match.group(1)) # Also check output for "Tests run: X" (our custom output) else: - tests_run_match = re.search(r'Tests run: (\d+)', stdout_content) + tests_run_match = re.search(r"Tests run: (\d+)", stdout_content) if tests_run_match: test_count = int(tests_run_match.group(1)) else: # Count lines ending with "... ok" which typically appear for each test - ok_matches = re.findall(r'\.+ ok$', stdout_content, re.MULTILINE) + ok_matches = re.findall( + r"\.+ ok$", stdout_content, re.MULTILINE + ) if ok_matches: test_count = len(ok_matches) else: # Look for pattern "test_name (...) ... ok" - test_ok_matches = re.findall(r'test_\w+\s+\([^)]+\).*ok', stdout_content) + test_ok_matches = re.findall( + r"test_\w+\s+\([^)]+\).*ok", stdout_content + ) if test_ok_matches: test_count = len(test_ok_matches) else: # Look for test_* methods in output as last resort - method_matches = re.findall(r'test_\w+', stdout_content) + method_matches = re.findall(r"test_\w+", stdout_content) if method_matches: - test_count = len(set(method_matches)) # Use set to remove duplicates - + test_count = len( + set(method_matches) + ) # Use set to remove duplicates + # Check for failures - fail_matches = re.search(r'[Ff]ailures[=:]?\s*(\d+)', stdout_content) + fail_matches = re.search(r"[Ff]ailures[=:]?\s*(\d+)", stdout_content) if fail_matches: fail_count = int(fail_matches.group(1)) - + # Check for errors - error_matches = re.search(r'[Ee]rrors[=:]?\s*(\d+)', stdout_content) + error_matches = re.search(r"[Ee]rrors[=:]?\s*(\d+)", stdout_content) if error_matches: error_count = int(error_matches.group(1)) - + # Update overall counts all_tests += test_count failures += fail_count errors += error_count - + # Add to results results.append((test_file, status, test_count)) - - print(f"Detected {test_count} tests, {fail_count} failures, {error_count} errors") - + + print( + f"Detected {test_count} tests, {fail_count} failures, {error_count} errors" + ) + except Exception as e: print(f"Error parsing test results: {e}") # Still add the file to results with unknown status @@ -149,7 +157,7 @@ def main(): else: # Add the timed out file results.append((test_file, "TIMEOUT", 0)) - + # Print summary print("\n\n===== TEST SUMMARY =====") print(f"{'Test File':<32} {'Status':<10} {'Tests':<10}") @@ -160,7 +168,7 @@ def main(): print(f"Total tests: {all_tests}") print(f"Failures: {failures}") print(f"Errors: {errors}") - + # Overall status if all_tests == 0: print("\nWARNING: No test results were detected!") @@ -168,7 +176,7 @@ def main(): print(f"\nALL {all_tests} TESTS PASSED!") else: print(f"\nTESTS FAILED: {failures + errors} issues found in {all_tests} tests.") - + # Return success if all tests passed return 0 if failures == 0 and errors == 0 else 1 diff --git a/app/tests/unittest_image_processing.py b/app/tests/unittest_image_processing.py index 9956237..3c8b7ea 100644 --- a/app/tests/unittest_image_processing.py +++ b/app/tests/unittest_image_processing.py @@ -21,81 +21,75 @@ os.environ["SAT_ANNOTATOR_TEST_MODE"] = "1" # Mock PIL before importing our app code -sys.modules['PIL'] = MagicMock() +sys.modules["PIL"] = MagicMock() # Import application code from utils.image_processing import validate_image_file + class MockUploadFile: """Mock class for FastAPI's UploadFile""" + def __init__(self, filename, content_type, content=None): self.filename = filename self.content_type = content_type self.file = io.BytesIO(content or b"mock content") - + async def read(self): self.file.seek(0) return self.file.read() + class TestImageProcessing(unittest.TestCase): """Tests for image processing utilities""" - + def test_validate_image_file(self): """Test validating image file types""" # Test valid image types - valid_types = [ - "image/jpeg", - "image/png", - "image/tiff", - "image/geotiff" - ] - + valid_types = ["image/jpeg", "image/png", "image/tiff", "image/geotiff"] + for content_type in valid_types: mock_file = MockUploadFile("test.jpg", content_type) self.assertTrue( validate_image_file(mock_file), - f"validate_image_file should return True for {content_type}" + f"validate_image_file should return True for {content_type}", ) - + # Test invalid image types - invalid_types = [ - "application/pdf", - "text/plain", - "application/octet-stream" - ] - + invalid_types = ["application/pdf", "text/plain", "application/octet-stream"] + for content_type in invalid_types: mock_file = MockUploadFile("test.pdf", content_type) self.assertFalse( validate_image_file(mock_file), - f"validate_image_file should return False for {content_type}" + f"validate_image_file should return False for {content_type}", ) - - @patch('builtins.open', MagicMock()) - @patch('os.path.getsize', MagicMock(return_value=1024)) - @patch('pathlib.Path.mkdir', MagicMock()) - @patch('PIL.Image.open') + + @patch("builtins.open", MagicMock()) + @patch("os.path.getsize", MagicMock(return_value=1024)) + @patch("pathlib.Path.mkdir", MagicMock()) + @patch("PIL.Image.open") async def test_save_upload_file(self, mock_pil_open): """Test saving an uploaded file""" # Import here to avoid circular import issues with the mocks from utils.image_processing import save_upload_file - + # Mock PIL image dimensions mock_img = MagicMock() mock_img.width = 1024 mock_img.height = 768 mock_pil_open.return_value.__enter__.return_value = mock_img - + # Create a mock file mock_file = MockUploadFile( "test.jpg", "image/jpeg", - content=b"\xff\xd8\xff\xe0JFIF" # JPEG file signature + content=b"\xff\xd8\xff\xe0JFIF", # JPEG file signature ) - + # Call the function file_info = await save_upload_file(mock_file) - + # Verify the result self.assertIn("filename", file_info) self.assertIn("original_filename", file_info) @@ -106,28 +100,34 @@ async def test_save_upload_file(self, mock_pil_open): self.assertEqual(file_info["size"], 1024) self.assertEqual(file_info["content_type"], "image/jpeg") + if __name__ == "__main__": # Run synchronous tests directly print("Running image processing tests...") print(f"App path: {app_path}") print(f"sys.path: {sys.path}") - + sync_suite = unittest.TestSuite() - sync_methods = [m for m in dir(TestImageProcessing) if m.startswith('test_') and not m.startswith('test_save')] - + sync_methods = [ + m + for m in dir(TestImageProcessing) + if m.startswith("test_") and not m.startswith("test_save") + ] + for method in sync_methods: print(f"Adding test method: {method}") sync_suite.addTest(TestImageProcessing(method)) - + print("Running synchronous tests...") runner = unittest.TextTestRunner(verbosity=2) runner.run(sync_suite) - + # Run async tests manually print("\nRunning asynchronous tests...") import asyncio + test_case = TestImageProcessing() - + # Run test_save_upload_file print("\ntest_save_upload_file:") try: @@ -136,4 +136,5 @@ async def test_save_upload_file(self, mock_pil_open): except Exception as e: print(f"ERROR: {e}") import traceback + traceback.print_exc() diff --git a/app/tests/unittest_main_api.py b/app/tests/unittest_main_api.py index e882713..9f633da 100644 --- a/app/tests/unittest_main_api.py +++ b/app/tests/unittest_main_api.py @@ -21,61 +21,68 @@ # Import mocks before importing any app code from mocks import apply_mocks + apply_mocks() # We'll create a FastAPI app for testing from fastapi import FastAPI from fastapi.testclient import TestClient + class TestMainAPI(unittest.TestCase): """Tests for the main FastAPI application""" - + def setUp(self): """Set up test environment before each test""" # Create a simple FastAPI app for testing self.app = FastAPI(title="Satellite Image Annotation Tool") - + # Add a health check endpoint @self.app.get("/health") def health_check(): return {"status": "healthy"} + # Add root endpoint (when frontend is not mounted) @self.app.get("/") def read_root(): - return {"message": "Welcome to the Satellite Image Annotation Tool (API Only Mode)"} - + return { + "message": "Welcome to the Satellite Image Annotation Tool (API Only Mode)" + } + # Add frontend status endpoint @self.app.get("/frontend-status") def frontend_status(): return { - "status": "not_mounted", - "message": "Frontend not mounted. Static files should be served from the web directory." + "status": "not_mounted", + "message": "Frontend not mounted. Static files should be served from the web directory.", } - + # Create a test client self.client = TestClient(self.app) - + def test_app_creation(self): """Test that the FastAPI app is created successfully""" # Ensure the app exists and has the expected attributes self.assertIsNotNone(self.app) self.assertEqual(self.app.title, "Satellite Image Annotation Tool") - + def test_health_endpoint(self): """Test the health check endpoint""" response = self.client.get("/health") self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), {"status": "healthy"}) - + def test_root_endpoint(self): """Test the root endpoint when frontend is not mounted""" response = self.client.get("/") self.assertEqual(response.status_code, 200) self.assertEqual( - response.json(), - {"message": "Welcome to the Satellite Image Annotation Tool (API Only Mode)"} + response.json(), + { + "message": "Welcome to the Satellite Image Annotation Tool (API Only Mode)" + }, ) - + def test_frontend_status_endpoint(self): """Test the frontend status endpoint""" response = self.client.get("/frontend-status") @@ -83,25 +90,26 @@ def test_frontend_status_endpoint(self): self.assertEqual( response.json(), { - "status": "not_mounted", - "message": "Frontend not mounted. Static files should be served from the web directory." - } + "status": "not_mounted", + "message": "Frontend not mounted. Static files should be served from the web directory.", + }, ) + if __name__ == "__main__": print("Running main API tests...") print(f"App path: {app_path}") print(f"sys.path: {sys.path}") print("Test methods:") for method in dir(TestMainAPI): - if method.startswith('test_'): + if method.startswith("test_"): print(f" - {method}") - + # Create test suite explicitly suite = unittest.TestSuite() for method in dir(TestMainAPI): - if method.startswith('test_'): + if method.startswith("test_"): suite.addTest(TestMainAPI(method)) - + runner = unittest.TextTestRunner(verbosity=2) runner.run(suite) diff --git a/app/tests/unittest_sam_segmenter.py b/app/tests/unittest_sam_segmenter.py index bd538a6..39dc9f5 100644 --- a/app/tests/unittest_sam_segmenter.py +++ b/app/tests/unittest_sam_segmenter.py @@ -22,29 +22,31 @@ # Import mocks before importing any app code from mocks import apply_mocks + apply_mocks() # Now import the segmenter code from utils.sam_model import SAMSegmenter + class TestSAMSegmenter(unittest.TestCase): """Tests for SAM segmentation model wrapper""" - + def setUp(self): """Set up the test environment""" # Create a mock segmenter instance self.segmenter = SAMSegmenter() - + # Create test data self.test_image_path = "/fake/path/image.jpg" self.test_point = [500, 400] - + def test_segmenter_initialization(self): """Test that the segmenter is initialized correctly""" self.assertIsNotNone(self.segmenter) self.assertEqual(self.segmenter.cache, {}) self.assertIsNone(self.segmenter.current_image_path) - + @patch("cv2.imread") @patch("cv2.cvtColor") def test_set_image(self, mock_cvtcolor, mock_imread): @@ -53,16 +55,18 @@ def test_set_image(self, mock_cvtcolor, mock_imread): mock_img = np.zeros((768, 1024, 3), dtype=np.uint8) mock_imread.return_value = mock_img mock_cvtcolor.return_value = mock_img - + # Call the method result = self.segmenter.set_image(self.test_image_path) - + # Check the results self.assertEqual(result, (768, 1024)) self.assertEqual(self.segmenter.current_image_path, self.test_image_path) self.assertIn(self.test_image_path, self.segmenter.cache) - self.assertEqual(self.segmenter.cache[self.test_image_path]['image_size'], (768, 1024)) - + self.assertEqual( + self.segmenter.cache[self.test_image_path]["image_size"], (768, 1024) + ) + @patch("cv2.imread") @patch("cv2.cvtColor") def test_predict_from_point(self, mock_cvtcolor, mock_imread): @@ -71,96 +75,106 @@ def test_predict_from_point(self, mock_cvtcolor, mock_imread): mock_img = np.zeros((768, 1024, 3), dtype=np.uint8) mock_imread.return_value = mock_img mock_cvtcolor.return_value = mock_img - + # Set up the predictor mock self.segmenter.predictor.predict = MagicMock() - + # Create a mock mask mock_mask = np.zeros((768, 1024), dtype=bool) mock_mask[300:500, 400:600] = True # Add a rectangle - + # Set up the predict return value - masks = np.array([mock_mask, np.zeros_like(mock_mask), np.zeros_like(mock_mask)]) + masks = np.array( + [mock_mask, np.zeros_like(mock_mask), np.zeros_like(mock_mask)] + ) scores = np.array([0.95, 0.5, 0.3]) self.segmenter.predictor.predict.return_value = (masks, scores, None) - + # Set the image first self.segmenter.set_image(self.test_image_path) - + # Call the method result = self.segmenter.predict_from_point(self.test_point) - + # Check the results self.assertEqual(result.shape, (768, 1024)) - self.assertTrue(np.array_equal(result[300:500, 400:600], np.ones((200, 200), dtype=np.uint8) * 255)) - self.assertTrue(np.array_equal(result[0:300, 0:400], np.zeros((300, 400), dtype=np.uint8))) - + self.assertTrue( + np.array_equal( + result[300:500, 400:600], np.ones((200, 200), dtype=np.uint8) * 255 + ) + ) + self.assertTrue( + np.array_equal(result[0:300, 0:400], np.zeros((300, 400), dtype=np.uint8)) + ) + # Check that the result was cached point_key = tuple(self.test_point) - self.assertIn(point_key, self.segmenter.cache[self.test_image_path]['masks']) - + self.assertIn(point_key, self.segmenter.cache[self.test_image_path]["masks"]) + @patch("cv2.findContours") def test_mask_to_polygon(self, mock_findcontours): """Test converting a mask to polygon coordinates""" # Create a test mask mask = np.zeros((768, 1024), dtype=np.uint8) mask[300:500, 400:600] = 255 - + # Create a mock contour contours = [np.array([[[400, 300]], [[600, 300]], [[600, 500]], [[400, 500]]])] mock_findcontours.return_value = (contours, None) - + # Call the method result = self.segmenter.mask_to_polygon(mask) - + # Check the result self.assertEqual(len(result), 4) # Should have 4 points for a rectangle self.assertEqual(result[0], [400, 300]) self.assertEqual(result[1], [600, 300]) self.assertEqual(result[2], [600, 500]) self.assertEqual(result[3], [400, 500]) - + def test_clear_cache(self): """Test clearing the segmenter cache""" # Set up test data in the cache self.segmenter.cache = { "image1.jpg": {"image_size": (100, 100), "masks": {}}, - "image2.jpg": {"image_size": (200, 200), "masks": {}} + "image2.jpg": {"image_size": (200, 200), "masks": {}}, } self.segmenter.current_image_path = "image1.jpg" - + # Clear one specific image self.segmenter.clear_cache("image1.jpg") - + # Check the result self.assertNotIn("image1.jpg", self.segmenter.cache) self.assertIn("image2.jpg", self.segmenter.cache) self.assertIsNone(self.segmenter.current_image_path) - + # Clear all cache self.segmenter.clear_cache() - + # Check the result self.assertEqual(self.segmenter.cache, {}) + if __name__ == "__main__": print("Running SAM segmenter tests...") print(f"App path: {app_path}") print(f"sys.path: {sys.path}") - + try: # Create test suite explicitly suite = unittest.TestSuite() for method in dir(TestSAMSegmenter): - if method.startswith('test_'): + if method.startswith("test_"): print(f"Adding test method: {method}") suite.addTest(TestSAMSegmenter(method)) - + # Run tests with a time limit runner = unittest.TextTestRunner(verbosity=2) runner.run(suite) - + except Exception as e: print(f"Error in test execution: {e}") import traceback + traceback.print_exc() diff --git a/app/tests/unittest_segmentation_api.py b/app/tests/unittest_segmentation_api.py index ce69874..42ce5db 100644 --- a/app/tests/unittest_segmentation_api.py +++ b/app/tests/unittest_segmentation_api.py @@ -23,6 +23,7 @@ # Import mocks before importing any app code from mocks import apply_mocks + apply_mocks() # Import FastAPI testing components @@ -34,190 +35,199 @@ from storage.session_store import SessionStore, session_store from storage.session_manager import SESSION_COOKIE_NAME + class TestSegmentationAPI(unittest.TestCase): """Tests for segmentation API endpoints""" - + def setUp(self): """Set up test environment before each test""" # Reset session store session_store.sessions = {} - + # Create a test session ID self.test_session_id = str(uuid.uuid4()) session_store.create_session(self.test_session_id) - + # Add a test image to the session self.test_image = session_store.add_image( session_id=self.test_session_id, file_name="test.jpg", file_path="uploads/test.jpg", - resolution="1024x768" + resolution="1024x768", ) - + # Create a FastAPI app self.app = FastAPI() - + # Mock SAM segmenter self.mock_segmenter = MagicMock() self.mock_segmenter.set_image.return_value = (768, 1024) - + # Create a mock mask mock_mask = np.zeros((768, 1024), dtype=np.uint8) mock_mask[300:500, 400:600] = 255 self.mock_segmenter.predict_from_point.return_value = mock_mask - + # Mock polygon output self.mock_segmenter.mask_to_polygon.return_value = [ - [400, 300], [600, 300], [600, 500], [400, 500] + [400, 300], + [600, 300], + [600, 500], + [400, 500], ] - + # Add test route for point-based segmentation @self.app.post("/api/segment/point") def segment_from_point(request_data: dict): image_id = request_data.get("image_id") point = request_data.get("point") - + # Check if image exists if image_id not in session_store.sessions[self.test_session_id]["images"]: - return JSONResponse(status_code=404, content={"success": False, "detail": "Image not found"}) - + return JSONResponse( + status_code=404, + content={"success": False, "detail": "Image not found"}, + ) + # Get the image image = session_store.sessions[self.test_session_id]["images"][image_id] - + # Mock segmentation using our mock segmenter try: self.mock_segmenter.set_image(image.file_path) mask = self.mock_segmenter.predict_from_point(point) polygon = self.mock_segmenter.mask_to_polygon(mask) - + # Return successful response return { "success": True, "mask_url": f"/api/segment/mask/{image_id}", - "polygon": polygon + "polygon": polygon, } except Exception as e: return JSONResponse( status_code=500, - content={"success": False, "detail": f"Segmentation failed: {str(e)}"} + content={ + "success": False, + "detail": f"Segmentation failed: {str(e)}", + }, ) - + # Add test route for retrieving a mask @self.app.get("/api/segment/mask/{image_id}") def get_mask(image_id: str): # Just return a success response for testing purposes return {"success": True, "image_id": image_id} - + # Create test client self.client = TestClient(self.app) - + def test_segment_from_point(self): """Test segmenting an image from a point prompt""" # Prepare test data - test_data = { - "image_id": self.test_image.image_id, - "point": [512, 384] - } - + test_data = {"image_id": self.test_image.image_id, "point": [512, 384]} + # Call the endpoint response = self.client.post( "/api/segment/point", json=test_data, - cookies={SESSION_COOKIE_NAME: self.test_session_id} + cookies={SESSION_COOKIE_NAME: self.test_session_id}, ) - + # Check response self.assertEqual(response.status_code, 200) data = response.json() self.assertTrue(data["success"]) self.assertIn("mask_url", data) self.assertIn("polygon", data) - + # Check mask URL format expected_mask_url = f"/api/segment/mask/{self.test_image.image_id}" self.assertEqual(data["mask_url"], expected_mask_url) - + # Check polygon data self.assertEqual(len(data["polygon"]), 4) # Should be a rectangle with 4 points self.assertEqual(data["polygon"][0], [400, 300]) self.assertEqual(data["polygon"][1], [600, 300]) self.assertEqual(data["polygon"][2], [600, 500]) self.assertEqual(data["polygon"][3], [400, 500]) - + # Verify the segmenter was called correctly self.mock_segmenter.set_image.assert_called_once_with(self.test_image.file_path) - self.mock_segmenter.predict_from_point.assert_called_once_with(test_data["point"]) + self.mock_segmenter.predict_from_point.assert_called_once_with( + test_data["point"] + ) self.mock_segmenter.mask_to_polygon.assert_called_once() - + def test_segment_invalid_image(self): """Test segmenting an image that doesn't exist""" # Prepare test data with a non-existent image ID - test_data = { - "image_id": str(uuid.uuid4()), - "point": [512, 384] - } - + test_data = {"image_id": str(uuid.uuid4()), "point": [512, 384]} + # Call the endpoint response = self.client.post( "/api/segment/point", json=test_data, - cookies={SESSION_COOKIE_NAME: self.test_session_id} + cookies={SESSION_COOKIE_NAME: self.test_session_id}, ) - + # Check response self.assertEqual(response.status_code, 404) data = response.json() self.assertFalse(data["success"]) self.assertIn("detail", data) self.assertIn("Image not found", data["detail"]) - + def test_get_mask(self): """Test retrieving a mask""" # Call the endpoint response = self.client.get( f"/api/segment/mask/{self.test_image.image_id}", - cookies={SESSION_COOKIE_NAME: self.test_session_id} + cookies={SESSION_COOKIE_NAME: self.test_session_id}, ) - + # Check response self.assertEqual(response.status_code, 200) data = response.json() self.assertTrue(data["success"]) self.assertEqual(data["image_id"], self.test_image.image_id) + if __name__ == "__main__": print("Running segmentation API tests...") print(f"App path: {app_path}") print(f"sys.path: {sys.path}") - + try: # Create test suite explicitly - test_methods = [m for m in dir(TestSegmentationAPI) if m.startswith('test_')] + test_methods = [m for m in dir(TestSegmentationAPI) if m.startswith("test_")] print(f"Found {len(test_methods)} test methods:") - + suite = unittest.TestSuite() for method in test_methods: print(f" - {method}") suite.addTest(TestSegmentationAPI(method)) - + # Run tests with clear output runner = unittest.TextTestRunner(verbosity=2) results = runner.run(suite) - + # Print explicit summary print(f"\nTests run: {results.testsRun}") print(f"Failures: {len(results.failures)}") print(f"Errors: {len(results.errors)}") print("DONE") - + # Force flush output to ensure it's captured import sys + sys.stdout.flush() - + except Exception as e: print(f"Error in test execution: {e}") import traceback + traceback.print_exc() - + print("Test execution completed") sys.stdout.flush() diff --git a/app/tests/unittest_session_images_api.py b/app/tests/unittest_session_images_api.py index ffde1c1..c3e1ff3 100644 --- a/app/tests/unittest_session_images_api.py +++ b/app/tests/unittest_session_images_api.py @@ -23,6 +23,7 @@ # Import mocks before importing any app code from mocks import apply_mocks + apply_mocks() # Import FastAPI testing components @@ -34,45 +35,50 @@ from storage.session_store import SessionStore, session_store from storage.session_manager import SessionManager, SESSION_COOKIE_NAME + class MockUploadFile: """Mock for FastAPI's UploadFile""" + def __init__(self, filename, content_type="image/jpeg", content=None): self.filename = filename self.content_type = content_type self.file = io.BytesIO(content or b"mock image content") - + async def read(self): self.file.seek(0) return self.file.read() - + def __enter__(self): return self - + def __exit__(self, exc_type, exc_val, exc_tb): pass + class TestSessionImagesAPI(unittest.TestCase): """Tests for session_images API endpoints""" - + def setUp(self): """Set up test environment before each test""" # Reset session store session_store.sessions = {} - + # Create a test session ID self.test_session_id = str(uuid.uuid4()) session_store.create_session(self.test_session_id) - + # Create a FastAPI app self.app = FastAPI() - + # Add test route for uploading images @self.app.post("/api/images/upload") async def upload_image(file: UploadFile = File(...)): # Mock implementation of the upload endpoint - if not file.content_type.startswith('image/'): - return JSONResponse(status_code=400, content={"detail": "File must be an image"}) - + if not file.content_type.startswith("image/"): + return JSONResponse( + status_code=400, content={"detail": "File must be an image"} + ) + # Create mock file info file_info = { "filename": f"{uuid.uuid4()}.jpg", @@ -80,23 +86,20 @@ async def upload_image(file: UploadFile = File(...)): "path": f"uploads/{uuid.uuid4()}.jpg", "resolution": "1024x768", "size": 1024, - "content_type": file.content_type + "content_type": file.content_type, } - + # Add image to session image = session_store.add_image( session_id=self.test_session_id, file_name=file_info["original_filename"], file_path=file_info["path"], - resolution=file_info["resolution"] + resolution=file_info["resolution"], ) - - return { - "success": True, - "image_id": image.image_id, - "file_info": file_info - } - # Add test route for getting images + + return {"success": True, "image_id": image.image_id, "file_info": file_info} + + # Add test route for getting images @self.app.get("/api/images") def get_images(): images = session_store.get_images(self.test_session_id) @@ -107,49 +110,59 @@ def get_images(): except AttributeError: # Fall back to old method for compatibility return {"images": [img.dict() for img in images]} - # Add test route for deleting an image + + # Add test route for deleting an image @self.app.delete("/api/images/{image_id}") def delete_image(image_id: str): # Implement removal directly since SessionStore doesn't have remove_image if self.test_session_id not in session_store.sessions: - return JSONResponse(status_code=404, content={"success": False, "detail": "Session not found"}) - + return JSONResponse( + status_code=404, + content={"success": False, "detail": "Session not found"}, + ) + if "images" not in session_store.sessions[self.test_session_id]: - return JSONResponse(status_code=404, content={"success": False, "detail": "No images in session"}) - + return JSONResponse( + status_code=404, + content={"success": False, "detail": "No images in session"}, + ) + if image_id in session_store.sessions[self.test_session_id]["images"]: # Remove the image from the session del session_store.sessions[self.test_session_id]["images"][image_id] return {"success": True} else: - return JSONResponse(status_code=404, content={"success": False, "detail": "Image not found"}) - + return JSONResponse( + status_code=404, + content={"success": False, "detail": "Image not found"}, + ) + # Create test client self.client = TestClient(self.app) - + @patch("builtins.open", MagicMock()) @patch("os.path.getsize", MagicMock(return_value=1024)) @patch("pathlib.Path.mkdir", MagicMock()) def test_upload_image(self): """Test uploading an image""" # Create a mock file - file_content = b'\xff\xd8\xff' + b'\x00' * 100 # JPEG signature + padding + file_content = b"\xff\xd8\xff" + b"\x00" * 100 # JPEG signature + padding mock_file = MockUploadFile(filename="test.jpg", content=file_content) - + # Mock PIL image - with patch('PIL.Image.open') as mock_pil_open: + with patch("PIL.Image.open") as mock_pil_open: mock_img = MagicMock() mock_img.width = 1024 mock_img.height = 768 mock_pil_open.return_value.__enter__.return_value = mock_img - + # Call the endpoint response = self.client.post( "/api/images/upload", files={"file": ("test.jpg", file_content, "image/jpeg")}, - cookies={SESSION_COOKIE_NAME: self.test_session_id} + cookies={SESSION_COOKIE_NAME: self.test_session_id}, ) - + # Check response self.assertEqual(response.status_code, 200) data = response.json() @@ -157,44 +170,43 @@ def test_upload_image(self): self.assertIn("image_id", data) self.assertIn("file_info", data) self.assertEqual(data["file_info"]["original_filename"], "test.jpg") - + # Verify image was added to session store images = session_store.get_images(self.test_session_id) self.assertEqual(len(images), 1) self.assertEqual(images[0].image_id, data["image_id"]) - + def test_upload_invalid_image(self): """Test uploading an invalid file type""" # Create a text file instead of an image - file_content = b'This is not an image' - + file_content = b"This is not an image" + # Call the endpoint response = self.client.post( "/api/images/upload", files={"file": ("test.txt", file_content, "text/plain")}, - cookies={SESSION_COOKIE_NAME: self.test_session_id} + cookies={SESSION_COOKIE_NAME: self.test_session_id}, ) - + # Check response self.assertEqual(response.status_code, 400) self.assertIn("detail", response.json()) self.assertIn("image", response.json()["detail"]) - + # Verify nothing was added to session store images = session_store.get_images(self.test_session_id) self.assertEqual(len(images), 0) - + def test_get_images_empty(self): """Test getting images when none exist""" response = self.client.get( - "/api/images", - cookies={SESSION_COOKIE_NAME: self.test_session_id} + "/api/images", cookies={SESSION_COOKIE_NAME: self.test_session_id} ) - + self.assertEqual(response.status_code, 200) data = response.json() self.assertEqual(data["images"], []) - + def test_get_images(self): """Test getting images from session""" # Add some test images to the session @@ -202,32 +214,31 @@ def test_get_images(self): session_id=self.test_session_id, file_name="test1.jpg", file_path="uploads/test1.jpg", - resolution="1024x768" + resolution="1024x768", ) - + image2 = session_store.add_image( session_id=self.test_session_id, file_name="test2.jpg", file_path="uploads/test2.jpg", - resolution="1920x1080" + resolution="1920x1080", ) - + # Call the endpoint response = self.client.get( - "/api/images", - cookies={SESSION_COOKIE_NAME: self.test_session_id} + "/api/images", cookies={SESSION_COOKIE_NAME: self.test_session_id} ) - + # Check response self.assertEqual(response.status_code, 200) data = response.json() self.assertEqual(len(data["images"]), 2) - + # Verify image data image_ids = [img["image_id"] for img in data["images"]] self.assertIn(image1.image_id, image_ids) self.assertIn(image2.image_id, image_ids) - + def test_delete_image(self): """Test deleting an image""" # Add a test image to the session @@ -235,66 +246,68 @@ def test_delete_image(self): session_id=self.test_session_id, file_name="test.jpg", file_path="uploads/test.jpg", - resolution="1024x768" + resolution="1024x768", ) - + # Call delete endpoint response = self.client.delete( f"/api/images/{image.image_id}", - cookies={SESSION_COOKIE_NAME: self.test_session_id} + cookies={SESSION_COOKIE_NAME: self.test_session_id}, ) - + # Check response self.assertEqual(response.status_code, 200) data = response.json() self.assertTrue(data["success"]) - + # Verify image was removed from session store images = session_store.get_images(self.test_session_id) self.assertEqual(len(images), 0) - + def test_delete_nonexistent_image(self): """Test deleting an image that doesn't exist""" # Call delete endpoint with random UUID response = self.client.delete( f"/api/images/{uuid.uuid4()}", - cookies={SESSION_COOKIE_NAME: self.test_session_id} + cookies={SESSION_COOKIE_NAME: self.test_session_id}, ) - + # Check response self.assertEqual(response.status_code, 404) data = response.json() self.assertFalse(data["success"]) self.assertIn("detail", data) + if __name__ == "__main__": print("Running session images API tests...") print(f"App path: {app_path}") print(f"sys.path: {sys.path}") - + try: # Create test suite explicitly - test_methods = [m for m in dir(TestSessionImagesAPI) if m.startswith('test_')] + test_methods = [m for m in dir(TestSessionImagesAPI) if m.startswith("test_")] print(f"Found {len(test_methods)} test methods:") - + suite = unittest.TestSuite() for method in test_methods: print(f" - {method}") suite.addTest(TestSessionImagesAPI(method)) - + # Run tests with clear output runner = unittest.TextTestRunner(verbosity=2) results = runner.run(suite) - + # Print explicit summary print(f"\nTests run: {results.testsRun}") print(f"Failures: {len(results.failures)}") print(f"Errors: {len(results.errors)}") print("DONE") - + except Exception as e: print(f"Error in test execution: {e}") import traceback + traceback.print_exc() - + print("Test execution completed") diff --git a/app/tests/unittest_session_manager.py b/app/tests/unittest_session_manager.py index 342a9cc..c97314f 100644 --- a/app/tests/unittest_session_manager.py +++ b/app/tests/unittest_session_manager.py @@ -21,43 +21,44 @@ # Import application code from storage.session_manager import ( - SessionManager, - generate_session_id, + SessionManager, + generate_session_id, get_session_id, - set_session_cookie, - SESSION_COOKIE_NAME + set_session_cookie, + SESSION_COOKIE_NAME, ) + # Mock classes for FastAPI request/response class MockRequest: def __init__(self, cookies=None): self.cookies = cookies or {} + class MockResponse: def __init__(self): self.cookies = {} self._deleted_cookies = {} - def set_cookie(self, key, value, expires=None, httponly=False, samesite=None, path=None): + def set_cookie( + self, key, value, expires=None, httponly=False, samesite=None, path=None + ): self.cookies[key] = { "key": key, "value": value, "expires": expires, "httponly": httponly, "samesite": samesite, - "path": path + "path": path, } - + def delete_cookie(self, key, httponly=False, path=None): - self._deleted_cookies[key] = { - "key": key, - "httponly": httponly, - "path": path - } + self._deleted_cookies[key] = {"key": key, "httponly": httponly, "path": path} + class TestSessionManager(unittest.TestCase): """Tests for SessionManager functionality""" - + def test_generate_session_id(self): """Test generating a session ID produces a valid UUID""" session_id = generate_session_id() @@ -67,19 +68,19 @@ def test_generate_session_id(self): self.assertEqual(str(uuid_obj), session_id) except ValueError: self.fail("generate_session_id did not produce a valid UUID") - + async def test_get_session_id_existing(self): """Test retrieving an existing session ID from cookies""" test_session_id = str(uuid.uuid4()) request = MockRequest(cookies={SESSION_COOKIE_NAME: test_session_id}) - + result = await get_session_id(request) self.assertEqual(result, test_session_id) - + async def test_get_session_id_new(self): """Test that a new session ID is generated if none exists""" request = MockRequest() - + result = await get_session_id(request) # Verify it's a valid UUID try: @@ -87,38 +88,42 @@ async def test_get_session_id_new(self): self.assertEqual(str(uuid_obj), result) except ValueError: self.fail("get_session_id did not produce a valid UUID") - + def test_set_session_cookie(self): """Test that session cookie is properly set in response""" response = MockResponse() test_session_id = str(uuid.uuid4()) - + # Call the function set_session_cookie(response, test_session_id) # Verify cookie was set self.assertIn(SESSION_COOKIE_NAME, response.cookies) - self.assertEqual(response.cookies[SESSION_COOKIE_NAME]["value"], test_session_id) - self.assertTrue(response.cookies[SESSION_COOKIE_NAME]["httponly"]) # True to prevent JavaScript access for enhanced security against XSS + self.assertEqual( + response.cookies[SESSION_COOKIE_NAME]["value"], test_session_id + ) + self.assertTrue( + response.cookies[SESSION_COOKIE_NAME]["httponly"] + ) # True to prevent JavaScript access for enhanced security against XSS self.assertEqual(response.cookies[SESSION_COOKIE_NAME]["samesite"], "lax") self.assertEqual(response.cookies[SESSION_COOKIE_NAME]["path"], "/") - + def test_session_manager(self): """Test SessionManager gets or generates a session ID""" # Test with existing session test_session_id = str(uuid.uuid4()) request = MockRequest(cookies={SESSION_COOKIE_NAME: test_session_id}) response = MockResponse() - + manager = SessionManager(request, response) self.assertEqual(manager.session_id, test_session_id) - + # Test with new session request = MockRequest() response = MockResponse() - + manager = SessionManager(request, response) session_id = manager.session_id - + # Verify ID was generated and cookie was set self.assertIsNotNone(session_id) try: @@ -126,46 +131,53 @@ def test_session_manager(self): self.assertEqual(str(uuid_obj), session_id) except ValueError: self.fail("SessionManager did not produce a valid UUID") - + self.assertIn(SESSION_COOKIE_NAME, response.cookies) self.assertEqual(response.cookies[SESSION_COOKIE_NAME]["value"], session_id) - + def test_session_manager_clear_session(self): """Test that SessionManager.clear_session removes the cookie""" request = MockRequest() response = MockResponse() - + manager = SessionManager(request, response) manager.clear_session() - + # Verify cookie was deleted self.assertIn(SESSION_COOKIE_NAME, response._deleted_cookies) - self.assertTrue(response._deleted_cookies[SESSION_COOKIE_NAME]["httponly"]) # True to match set_session_cookie + self.assertTrue( + response._deleted_cookies[SESSION_COOKIE_NAME]["httponly"] + ) # True to match set_session_cookie self.assertEqual(response._deleted_cookies[SESSION_COOKIE_NAME]["path"], "/") + if __name__ == "__main__": import asyncio - + print("Running tests for session_manager.py") print(f"App path: {app_path}") print(f"sys.path: {sys.path}") - + # Create a test suite with only synchronous tests sync_suite = unittest.TestSuite() - sync_methods = [m for m in dir(TestSessionManager) if m.startswith('test_') and not m.startswith('test_get_session_id')] + sync_methods = [ + m + for m in dir(TestSessionManager) + if m.startswith("test_") and not m.startswith("test_get_session_id") + ] for method in sync_methods: print(f"Adding test method: {method}") sync_suite.addTest(TestSessionManager(method)) - + # Run synchronous tests print("Running synchronous tests...") runner = unittest.TextTestRunner(verbosity=2) runner.run(sync_suite) - + # Run async tests manually print("\nRunning asynchronous tests...") test_case = TestSessionManager() - + # Run test_get_session_id_existing print("\ntest_get_session_id_existing:") try: @@ -174,8 +186,9 @@ def test_session_manager_clear_session(self): except Exception as e: print(f"ERROR: {e}") import traceback + traceback.print_exc() - + # Run test_get_session_id_new print("\ntest_get_session_id_new:") try: @@ -184,4 +197,5 @@ def test_session_manager_clear_session(self): except Exception as e: print(f"ERROR: {e}") import traceback + traceback.print_exc() diff --git a/app/tests/unittest_session_store.py b/app/tests/unittest_session_store.py index 9e877e6..636b4cb 100644 --- a/app/tests/unittest_session_store.py +++ b/app/tests/unittest_session_store.py @@ -23,64 +23,69 @@ # Import application code from storage.session_store import SessionStore, SessionImage, SessionAnnotation + class TestSessionStore(unittest.TestCase): """Tests for SessionStore functionality""" - + def setUp(self): """Set up before each test""" self.store = SessionStore() self.session_id = str(uuid.uuid4()) - + def test_session_creation(self): """Test creating a session""" self.store.create_session(self.session_id) self.assertIn(self.session_id, self.store.sessions) self.assertIn("images", self.store.sessions[self.session_id]) self.assertIn("annotations", self.store.sessions[self.session_id]) - + def test_add_image(self): """Test adding an image to a session""" image = self.store.add_image( session_id=self.session_id, file_name="test.jpg", file_path="uploads/test.jpg", - resolution="1024x768" + resolution="1024x768", ) - + self.assertIn(self.session_id, self.store.sessions) self.assertIn(image.image_id, self.store.sessions[self.session_id]["images"]) - + # Test getting images images = self.store.get_images(self.session_id) self.assertEqual(len(images), 1) self.assertEqual(images[0].file_name, "test.jpg") - + def test_add_annotation(self): """Test adding an annotation to a session""" # Add an image first image = self.store.add_image( session_id=self.session_id, file_name="test.jpg", - file_path="uploads/test.jpg" + file_path="uploads/test.jpg", ) - + # Add an annotation annotation = self.store.add_annotation( session_id=self.session_id, image_id=image.image_id, - file_path="annotations/test.json" + file_path="annotations/test.json", ) - + self.assertIsNotNone(annotation) - self.assertIn(annotation.annotation_id, self.store.sessions[self.session_id]["annotations"]) - + self.assertIn( + annotation.annotation_id, + self.store.sessions[self.session_id]["annotations"], + ) + # Test getting annotations annotations = self.store.get_annotations(self.session_id) self.assertEqual(len(annotations), 1) self.assertEqual(annotations[0].file_path, "annotations/test.json") + if __name__ == "__main__": print("Running tests...") print(f"App path: {app_path}") print(f"sys.path: {sys.path}") - unittest.main(argv=['first-arg', '-v']) + unittest.main(argv=["first-arg", "-v"]) diff --git a/app/utils/image_processing.py b/app/utils/image_processing.py index 96e378a..8cca334 100644 --- a/app/utils/image_processing.py +++ b/app/utils/image_processing.py @@ -9,57 +9,62 @@ logger = logging.getLogger(__name__) # Determine if we're running in Docker or locally -in_docker = os.path.exists('/.dockerenv') +in_docker = os.path.exists("/.dockerenv") # Create upload directory with the appropriate path if in_docker: UPLOAD_DIR = Path("/app/uploads") else: # For local development, use a path relative to the project root - UPLOAD_DIR = Path(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "uploads")) + UPLOAD_DIR = Path( + os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "uploads" + ) + ) UPLOAD_DIR.mkdir(exist_ok=True) + async def save_upload_file(file: UploadFile) -> dict: """Save an uploaded file to the upload directory and convert TIFF to PNG if needed.""" # Generate unique filename file_extension = os.path.splitext(file.filename)[1].lower() temp_filename = f"{uuid.uuid4()}{file_extension}" - + # Create temporary file path temp_file_path = UPLOAD_DIR / temp_filename - + # Save the original file temporarily contents = await file.read() with open(temp_file_path, "wb") as f: f.write(contents) - + # Check if we need to convert TIFF to PNG for browser compatibility final_file_path = temp_file_path final_filename = temp_filename - - if file_extension in ['.tif', '.tiff']: + + if file_extension in [".tif", ".tiff"]: # Convert TIFF to PNG for browser compatibility png_filename = f"{uuid.uuid4()}.png" png_file_path = UPLOAD_DIR / png_filename - + try: with Image.open(temp_file_path) as img: # Convert to RGB if necessary (some TIFFs might be in different color modes) - if img.mode not in ('RGB', 'RGBA'): - img = img.convert('RGB') + if img.mode not in ("RGB", "RGBA"): + img = img.convert("RGB") # Save as PNG - img.save(png_file_path, 'PNG') - + img.save(png_file_path, "PNG") + # Remove the temporary TIFF file os.remove(temp_file_path) - # Use the PNG file as the final file + # Use the PNG file as the final file final_file_path = png_file_path final_filename = png_filename except Exception as e: # If conversion fails, keep the original TIFF file logger.warning(f"Failed to convert TIFF to PNG: {e}") - + # Get image dimensions and resolution resolution = None try: @@ -67,10 +72,10 @@ async def save_upload_file(file: UploadFile) -> dict: resolution = f"{img.width}x{img.height}" except Exception: # Not a valid image or PIL cannot read it - pass + pass # Get file size file_size = os.path.getsize(final_file_path) - + # Store only the filename for the path to make it work in both Docker and local environments # This will be served from the /uploads/ route return { @@ -79,11 +84,12 @@ async def save_upload_file(file: UploadFile) -> dict: "size": file_size, "content_type": file.content_type, "path": f"uploads/{final_filename}", # Use relative path for consistent access - "resolution": resolution + "resolution": resolution, } + def validate_image_file(file: UploadFile) -> bool: """Validate if the file is a supported image format.""" content_type = file.content_type valid_types = ["image/jpeg", "image/png", "image/tiff", "image/geotiff"] - return content_type in valid_types \ No newline at end of file + return content_type in valid_types diff --git a/app/utils/sam_model.py b/app/utils/sam_model.py index 201003e..ff0face 100644 --- a/app/utils/sam_model.py +++ b/app/utils/sam_model.py @@ -11,12 +11,15 @@ # Set up logging for SAM model logger = logging.getLogger(__name__) + class SAMSegmenter: - def __init__(self): # Enhanced GPU detection and setup + def __init__(self): # Enhanced GPU detection and setup if torch.cuda.is_available(): - self.device = torch.device('cuda') + self.device = torch.device("cuda") logger.info(f"CUDA available! Using GPU: {torch.cuda.get_device_name(0)}") - logger.info(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB") + logger.info( + f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB" + ) # Set optimal GPU settings for SAM real-time performance torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -24,36 +27,40 @@ def __init__(self): # Enhanced GPU detection and setup # Enable memory optimization torch.cuda.empty_cache() else: - self.device = torch.device('cpu') + self.device = torch.device("cpu") logger.warning("CUDA not available, using CPU (will be slower)") # CPU optimizations torch.set_num_threads(4) # Limit CPU threads for better responsiveness - + logger.info(f"SAM Model will run on: {self.device}") # Check if running in Docker or locally - in_docker = os.path.exists('/.dockerenv') + in_docker = os.path.exists("/.dockerenv") base_path = Path("/app") if in_docker else Path(".") - + self.sam_checkpoint = str(base_path / "models/sam_vit_h_4b8939.pth") - self.model_type = "vit_h" + self.model_type = "vit_h" if not Path(self.sam_checkpoint).exists(): - raise FileNotFoundError(f"SAM checkpoint not found at {self.sam_checkpoint}. Please download it from https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth") - + raise FileNotFoundError( + f"SAM checkpoint not found at {self.sam_checkpoint}. Please download it from https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" + ) + logger.info(f"Loading SAM model from: {self.sam_checkpoint}") self.sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint) - + # Move to device and optimize for inference self.sam.to(device=self.device) # Note: Removed half-precision to avoid dtype mismatch issues # if self.device.type == 'cuda': # self.sam = self.sam.half() # logger.info("Using half-precision (FP16) for faster GPU inference") - + self.predictor = SamPredictor(self.sam) logger.info("SAM model loaded successfully") - + # Cache for storing image embeddings and masks self.cache: Dict[str, Dict] = {} - self.current_image_path = None # Add thread lock to prevent concurrent access issues + self.current_image_path = ( + None # Add thread lock to prevent concurrent access issues + ) self._lock = threading.Lock() logger.info("Thread synchronization enabled for multi-image processing") @@ -64,39 +71,45 @@ def set_image(self, image_path): image = cv2.imread(image_path) if image is None: raise ValueError(f"Could not load image from {image_path}") - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Check if we have cached embeddings for this image + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Check if we have cached embeddings for this image if image_path in self.cache: logger.debug(f"Found cached embeddings for {Path(image_path).name}") - + # Only re-set the image if it's different from the current one if self.current_image_path != image_path: - logger.debug(f"Re-setting predictor to cached image: {Path(image_path).name}") + logger.debug( + f"Re-setting predictor to cached image: {Path(image_path).name}" + ) with torch.no_grad(): self.predictor.set_image(image) self.current_image_path = image_path self._last_set_image = image_path else: - logger.debug(f"Image already loaded - instant segmentation ready!") - - return self.cache[image_path]['image_size'] - + logger.debug( + f"Image already loaded in predictor - embeddings cached" + ) + + return self.cache[image_path]["image_size"] + logger.info(f"Loading and processing new image: {Path(image_path).name}") logger.debug(f"Image size: {image.shape[1]}x{image.shape[0]} pixels") - # Generate embeddings on GPU (this is the heavy computation) + # Generate embeddings on GPU (this is the heavy computation) with torch.no_grad(): # Disable gradients for faster inference self.predictor.set_image(image) - + logger.debug(f"Image embeddings generated on {self.device}") - + # Store in cache self.cache[image_path] = { - 'image_size': image.shape[:2], # (height, width) - 'masks': {}, # Will store generated masks - 'embeddings': None # This is implicitly stored in the predictor + "image_size": image.shape[:2], # (height, width) + "masks": {}, # Will store generated masks + "embeddings": None, # This is implicitly stored in the predictor } self.current_image_path = image_path self._last_set_image = image_path - + return image.shape[:2] # Return height, width def preprocess_image(self, image_path): @@ -104,32 +117,38 @@ def preprocess_image(self, image_path): with self._lock: # Check if we've already processed this image if image_path in self.cache: - logger.debug(f"Image {Path(image_path).name} already has cached embeddings") + logger.debug( + f"Image {Path(image_path).name} already has cached embeddings" + ) return True - + try: - logger.info(f"Pre-processing image for faster segmentation: {Path(image_path).name}") - + logger.info( + f"Pre-processing image for faster segmentation: {Path(image_path).name}" + ) + # Load and process the image image = cv2.imread(image_path) if image is None: raise ValueError(f"Could not load image from {image_path}") - + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - + # Generate embeddings on GPU using the main predictor with torch.no_grad(): - self.predictor.set_image(image) # Store in cache and set as current image + self.predictor.set_image( + image + ) # Store in cache and set as current image self.cache[image_path] = { - 'image_size': image.shape[:2], # (height, width) - 'masks': {}, # Will store generated masks - 'embeddings': None # This is implicitly stored in the predictor + "image_size": image.shape[:2], # (height, width) + "masks": {}, # Will store generated masks + "embeddings": None, # This is implicitly stored in the predictor } self.current_image_path = image_path - self._last_set_image = image_path + self._last_set_image = image_path logger.info(f"Pre-processing complete for {Path(image_path).name}") return True - + except Exception as e: logger.error(f"Error pre-processing image {Path(image_path).name}: {e}") return False @@ -138,64 +157,83 @@ def predict_from_point(self, point_coords, point_labels=None): """Generate mask from a point prompt, using cache if available with thread safety""" # First check cache without lock for performance point_key = tuple(point_coords) - if (self.current_image_path and - self.current_image_path in self.cache and - point_key in self.cache[self.current_image_path]['masks']): - return self.cache[self.current_image_path]['masks'][point_key] - + if ( + self.current_image_path + and self.current_image_path in self.cache + and point_key in self.cache[self.current_image_path]["masks"] + ): + return self.cache[self.current_image_path]["masks"][point_key] + with self._lock: if self.current_image_path is None: - raise ValueError("No image set for segmentation. Call set_image() first.") - + raise ValueError( + "No image set for segmentation. Call set_image() first." + ) + # Ensure we have the correct image set in the predictor if self.current_image_path not in self.cache: - raise ValueError(f"Image {self.current_image_path} not found in cache. Call set_image() first.") - + raise ValueError( + f"Image {self.current_image_path} not found in cache. Call set_image() first." + ) + # Re-set the image if it's not the current one (safety check) # This ensures the predictor has the correct embeddings - if hasattr(self, '_last_set_image') and self._last_set_image != self.current_image_path: - logger.debug(f"Re-setting image in predictor for thread safety: {Path(self.current_image_path).name}") + if ( + hasattr(self, "_last_set_image") + and self._last_set_image != self.current_image_path + ): + logger.debug( + f"Re-setting image in predictor for thread safety: {Path(self.current_image_path).name}" + ) image = cv2.imread(self.current_image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) with torch.no_grad(): self.predictor.set_image(image) - + self._last_set_image = self.current_image_path - # Double-check cache (in case another thread added it) - if point_key in self.cache[self.current_image_path]['masks']: - logger.debug(f"Using cached mask for point {point_coords} (added by another thread)") - return self.cache[self.current_image_path]['masks'][point_key] - - logger.debug(f"Generating new mask for point {point_coords} on {self.device}") - + # Double-check cache (in case another thread added it) + if point_key in self.cache[self.current_image_path]["masks"]: + logger.debug( + f"Using cached mask for point {point_coords} (added by another thread)" + ) + return self.cache[self.current_image_path]["masks"][point_key] + + logger.debug( + f"Generating new mask for point {point_coords} on {self.device}" + ) + try: # Generate new mask with performance optimizations point_coords_array = np.array([point_coords]) if point_labels is None: point_labels = np.array([1]) # 1 indicates a foreground point - + # Ensure inputs are the right data type for GPU point_coords_array = point_coords_array.astype(np.float32) point_labels = point_labels.astype(np.int32) - + # Use GPU optimization if available with torch.no_grad(): # Disable gradient computation for faster inference masks, scores, _ = self.predictor.predict( point_coords=point_coords_array, point_labels=point_labels, - multimask_output=True + multimask_output=True, ) - + best_mask_idx = np.argmax(scores) - mask = masks[best_mask_idx].astype(np.uint8) * 255 # Convert to 8-bit mask - - logger.debug(f"Mask generated successfully (confidence: {scores[best_mask_idx]:.3f})") - + mask = ( + masks[best_mask_idx].astype(np.uint8) * 255 + ) # Convert to 8-bit mask + + logger.debug( + f"Mask generated successfully (confidence: {scores[best_mask_idx]:.3f})" + ) + # Cache the result - self.cache[self.current_image_path]['masks'][point_key] = mask - + self.cache[self.current_image_path]["masks"][point_key] = mask + return mask - + except Exception as e: logger.error(f"Error generating mask: {e}") # Clear CUDA cache if error occurs @@ -208,17 +246,17 @@ def mask_to_polygon(self, mask): contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) if not contours: return None - + largest_contour = max(contours, key=cv2.contourArea) polygon = largest_contour.squeeze().tolist() - + if not isinstance(polygon[0], list): polygon = [polygon] - + # Get image dimensions for normalization if self.current_image_path and self.current_image_path in self.cache: - height, width = self.cache[self.current_image_path]['image_size'] - + height, width = self.cache[self.current_image_path]["image_size"] + # Normalize coordinates to 0-1 range normalized_polygon = [] for point in polygon: @@ -226,9 +264,9 @@ def mask_to_polygon(self, mask): normalized_x = point[0] / width normalized_y = point[1] / height normalized_polygon.append([normalized_x, normalized_y]) - + return normalized_polygon - + return polygon def clear_cache(self, image_path=None): diff --git a/docker-compose.yml b/docker-compose.yml index 766f04a..62dd5ea 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,7 +4,7 @@ services: context: . dockerfile: Dockerfile.app ports: - - "8000:8000" + - '8000:8000' volumes: - ./app:/app/app - ./uploads:/app/uploads @@ -12,13 +12,22 @@ services: environment: - ENVIRONMENT=development - SESSION_SECRET=your_session_secret_key_here - command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] + command: + [ + 'uvicorn', + 'app.main:app', + '--host', + '0.0.0.0', + '--port', + '8000', + '--reload', + ] frontend: build: context: . dockerfile: Dockerfile.web ports: - - "8080:8080" + - '8080:8080' volumes: - ./web:/web working_dir: /web diff --git a/docs/sat-annotator-desseration.pdf b/docs/sat-annotator-desseration.pdf new file mode 100644 index 0000000..60ef2c2 Binary files /dev/null and b/docs/sat-annotator-desseration.pdf differ diff --git a/docs/sat-annotator.pdf b/docs/sat-annotator.pdf new file mode 100644 index 0000000..3651e7b Binary files /dev/null and b/docs/sat-annotator.pdf differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..578ce40 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,5 @@ +[tool.black] +line-length = 88 +target-version = ['py38'] +include = '\.pyi?$' +skip-string-normalization = false diff --git a/web/index.html b/web/index.html index 5f78a8b..3d27478 100644 --- a/web/index.html +++ b/web/index.html @@ -1,248 +1,408 @@ - + -
- + + +