diff --git a/engine.py b/engine.py
index 6686176..1998780 100644
--- a/engine.py
+++ b/engine.py
@@ -399,6 +399,55 @@ def synthesize(
     return None, None
 
 
+def unload_model() -> bool:
+    """
+    Unloads the current model and releases all GPU memory.
+    Does NOT reload the model - use reload_model() for that.
+
+    Returns:
+        bool: True if the model was unloaded successfully, False otherwise.
+    """
+    global chatterbox_model, MODEL_LOADED, model_device, loaded_model_type, loaded_model_class_name
+
+    logger.info("Initiating model unload sequence...")
+
+    # 1. Unload existing model
+    if chatterbox_model is not None:
+        logger.info("Unloading TTS model from memory...")
+        del chatterbox_model
+        chatterbox_model = None
+
+    # 2. Reset state flags
+    MODEL_LOADED = False
+    model_device = None
+    loaded_model_type = None
+    loaded_model_class_name = None
+
+    # 3. Force Python Garbage Collection
+    gc.collect()
+    logger.info("Python garbage collection completed.")
+
+    # 4. Clear GPU Cache (CUDA)
+    if torch.cuda.is_available():
+        logger.info("Clearing CUDA cache...")
+        torch.cuda.empty_cache()
+
+    # 5. Clear GPU Cache (MPS - Apple Silicon)
+    # hasattr guard: torch.backends.mps does not exist on older PyTorch builds,
+    # and the AttributeError handler below only covers empty_cache() itself.
+    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        try:
+            torch.mps.empty_cache()
+            logger.info("Cleared MPS cache.")
+        except AttributeError:
+            logger.debug(
+                "torch.mps.empty_cache() not available in this PyTorch version."
+            )
+
+    logger.info("Model unloaded and GPU memory released.")
+    return True
+
+
 def reload_model() -> bool:
     """
     Unloads the current model, clears GPU memory, and reloads the model
diff --git a/server.py b/server.py
index 761b512..15ed4ed 100644
--- a/server.py
+++ b/server.py
@@ -585,6 +585,35 @@ async def restart_server_endpoint():
     )
 
 
+@app.post("/api/unload", tags=["Configuration"])
+async def unload_model_endpoint():
+    """
+    Unloads the TTS model and releases all CUDA/GPU memory.
+    The model will need to be reloaded (via /restart_server) before TTS requests can be processed.
+    """
+    logger.info("Request received for /api/unload (Model Unload).")
+
+    try:
+        success = engine.unload_model()
+
+        if success:
+            logger.info("Model successfully unloaded and GPU memory released.")
+            return {"status": "unloaded"}
+        else:
+            error_msg = "Model unload failed. Check logs for details."
+            logger.error(error_msg)
+            raise HTTPException(status_code=500, detail=error_msg)
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Critical error during model unload: {e}", exc_info=True)
+        raise HTTPException(
+            status_code=500,
+            detail=f"Internal server error during model unload: {str(e)}",
+        )
+
+
 # --- UI Helper API Endpoints ---
 @app.get("/get_reference_files", response_model=List[str], tags=["UI Helpers"])
 async def get_reference_files_api():