diff --git a/uvq1p5_pytorch/utils/uvq1p5.py b/uvq1p5_pytorch/utils/uvq1p5.py index cd58717..a398fc1 100644 --- a/uvq1p5_pytorch/utils/uvq1p5.py +++ b/uvq1p5_pytorch/utils/uvq1p5.py @@ -109,6 +109,7 @@ def infer( fps: int = 1, orig_fps: float | None = None, ffmpeg_path: str = "ffmpeg", + device: str = "cpu", ) -> dict[str, Any]: """Runs UVQ 1.5 inference on a video file. @@ -119,6 +120,8 @@ def infer( fps: Frames per second to sample for inference. orig_fps: Original frames per second of the video, used for frame index calculation. + ffmpeg_path: Path to ffmpeg executable. + device: Device to run inference on (e.g., 'cpu' or 'cuda'). Returns: A dictionary containing the overall UVQ 1.5 score, per-frame scores, @@ -141,16 +144,16 @@ def infer( predictions = [] with torch.inference_mode(): for i in range(0, num_frames, batch_size): - batch = video_1080p[i : i + batch_size] + batch = video_1080p[i : i + batch_size].to(device) prediction_batch = self.uvq1p5_core(batch) predictions.append(prediction_batch) prediction = torch.cat(predictions, dim=0) else: with torch.inference_mode(): - prediction = self.uvq1p5_core(video_1080p) + prediction = self.uvq1p5_core(video_1080p.to(device)) video_score = torch.mean(prediction).item() - frame_scores = prediction.numpy().flatten().tolist() + frame_scores = prediction.cpu().numpy().flatten().tolist() if orig_fps: frame_indices = [ diff --git a/uvq_inference.py b/uvq_inference.py index d4f74f4..0031223 100644 --- a/uvq_inference.py +++ b/uvq_inference.py @@ -21,6 +21,7 @@ import json from typing import Any import tqdm +import torch from utils import probe @@ -96,6 +97,7 @@ def run_batch_inference(args): fps=fps_to_use, orig_fps=orig_fps, ffmpeg_path=args.ffmpeg_path, + device=args.device, ) score = results["uvq1p5_score"] elif args.model_version == "1.0": @@ -167,9 +169,12 @@ def run_single_inference(args): fps=fps, orig_fps=orig_fps, ffmpeg_path=args.ffmpeg_path, + device=args.device, ) elif args.model_version == "1.0": uvq_inference = uvq1p0.UVQ1p0() + if args.device == "cuda": + uvq_inference.cuda() # UVQ1.0 infer doesn't support fps or padding args. # It uses its own video reader, which has fixed 5 fps sampling. # If fps is passed for 1.0, it will be ignored by 1.0 infer method. @@ -201,6 +206,10 @@ def main(): parser = setup_parser() args = parser.parse_args() + if args.device == "cuda" and not torch.cuda.is_available(): + print("Error: CUDA is not available, please use --device cpu") + return + if args.input.endswith(".txt"): run_batch_inference(args) else: @@ -257,6 +266,7 @@ def setup_parser(): "--device", type=str, default="cpu", + choices=["cpu", "cuda"], help="Device to run inference on (e.g., 'cpu' or 'cuda').", ) parser.add_argument(