diff --git a/README.md b/README.md
index f4e73bb..8b5be79 100644
--- a/README.md
+++ b/README.md
@@ -89,6 +89,7 @@ If you provide a path to a `.txt` file instead of a video file to the `input` ar
 The `.txt` file should contain one video path per line.
 In batch mode, you must specify an output file using `--output`, which will be populated with `video_basename,score` for each video.
 The `--output_all_stats` flag is ignored in batch mode.
+If you need all statistics in batch mode, use `--batch_json_output` to write the results as a JSON array with the complete statistics and a `video_name` key identifying the source video.
 
 For example, if `video_list.txt` contains:
 ```
@@ -105,12 +106,31 @@ This will create `batch_results.txt` with content like:
 Gaming_1080P-0ce6_orig.mp4,3.880362033843994
 ```
 
+To obtain all statistics in JSON format, use the `--batch_json_output` flag:
+```bash
+python uvq_inference.py video_list.txt --model_version 1.5 --batch_json_output --output batch_results.txt
+```
+
+This will create `batch_results.txt` with content like:
+```json
+[
+  {
+    "uvq1p5_score": 3.880362033843994,
+    "per_frame_scores": [4.021927833557129, 4.013788223266602, 4.110747814178467, 4.142043113708496, 4.1536993980407715, 4.147506237030029, 4.149798393249512, 4.149064064025879, 4.149083137512207, 4.133814811706543, 3.5636682510375977, 3.8045108318328857, 3.630220413208008, 3.6495614051818848, 3.6260201930999756, 3.6136975288391113, 3.5050578117370605, 3.7031033039093018, 3.676196575164795, 3.663726806640625],
+    "frame_indices": [0, 30, 60, 90, 120, 150, 180, 210, 240, 270, 300, 330, 360, 390, 420, 450, 480, 510, 540, 570],
+    "video_name": "Gaming_1080P-0ce6_orig.mp4"
+  }
+]
+```
+
 #### Optional Arguments
 
 * `--transpose`: Transpose the video before processing (e.g., for portrait videos).
 * `--output OUTPUT`: Path to save the output scores to a file. Scores will be saved in JSON format.
 * `--device DEVICE`: Device to run inference on (e.g., `cpu` or `cuda`).
 * `--fps FPS`: (UVQ 1.5 only) Frames per second to sample. Default is 1. Use -1 to sample all frames.
+* `--chunk_size_frames FRAMES`: (UVQ 1.5 only) Number of frames to process at once during inference. If you run out of memory, reduce this number. Default is 16.
+* `--batch_json_output`: If specified, outputs batch results in JSON format, including per-frame scores, instead of just the overall mean score.
 * `--output_all_stats`: If specified, print all stats in JSON format to stdout.
 * `--ffmpeg_path`: Path to FFmpeg executable (default: `ffmpeg`).
 * `--ffprobe_path`: Path to FFprobe executable (default: `ffprobe`).
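Since the batch JSON output is a plain JSON array, downstream tooling can consume it directly. Below is a minimal sketch of reading the file written by `--batch_json_output`; the `batch_results.txt` path is just the `--output` value from the README example, and the keys match the sample entry above:

```python
import json

# Load the JSON array written by --batch_json_output; the path is whatever
# was passed to --output ("batch_results.txt" in the README example).
with open("batch_results.txt") as f:
    results = json.load(f)

for entry in results:
    print(entry["video_name"], entry["uvq1p5_score"])
    # per_frame_scores and frame_indices are parallel lists: the i-th score
    # belongs to the frame at the i-th sampled index.
    scores = entry["per_frame_scores"]
    worst = min(scores)
    worst_frame = entry["frame_indices"][scores.index(worst)]
    print(f"  lowest per-frame score {worst:.3f} at frame {worst_frame}")
```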
diff --git a/utils/probe.py b/utils/probe.py
index ec7b485..80991fd 100644
--- a/utils/probe.py
+++ b/utils/probe.py
@@ -39,7 +39,7 @@ def get_dimensions(
     result = subprocess.run(
         cmd,
         stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT,
+        stderr=subprocess.PIPE,
         check=True,
         text=True,
     )
@@ -72,7 +72,7 @@ def get_nb_frames(video_path, ffprobe_path="ffprobe") -> int | None:
     result = subprocess.run(
         cmd,
         stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT,
+        stderr=subprocess.PIPE,
         check=True,
         text=True,
     )
@@ -103,7 +103,7 @@ def get_r_frame_rate(video_path, ffprobe_path="ffprobe") -> int | None:
     result = subprocess.run(
         cmd,
         stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT,
+        stderr=subprocess.PIPE,
         check=True,
         text=True,
     )
@@ -137,7 +137,7 @@ def get_video_duration(video_path, ffprobe_path="ffprobe") -> float | None:
     ]
     try:
         result = subprocess.run(
-            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=True
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
         )
         duration = float(result.stdout)
         return duration
diff --git a/utils/video_reader.py b/utils/video_reader.py
index 50ed956..33e89d8 100644
--- a/utils/video_reader.py
+++ b/utils/video_reader.py
@@ -157,7 +157,7 @@ def load_video_1p0(
     return video, video_resized
 
 
-def load_video_1p5(
+def yield_video_1p5_chunks(
     filepath: str,
     video_length: int,
     transpose: bool = False,
@@ -165,8 +165,9 @@ def load_video_1p5(
     video_height: int = 1080,
     video_width: int = 1920,
     ffmpeg_path: str = "ffmpeg",
-) -> tuple[np.ndarray, int]:
-    """Load input video for UVQ 1.5.
+    chunk_size_frames: int = 16,
+):
+    """Yields chunks of the video as numpy arrays.
 
     Args:
         filepath: Path to the video file.
@@ -175,10 +176,12 @@ def load_video_1p5(
         video_fps: Frames per second to sample for inference.
         video_height: Height of the video to resize to.
         video_width: Width of the video to resize to.
-
-    Returns:
-        A tuple containing the loaded video as a numpy array and the number of
-        real frames.
+        chunk_size_frames: Number of frames to yield per chunk.
+
+    Yields:
+        A tuple containing:
+            - A chunk of the loaded video as a numpy array (batch, 1, h, w, c).
+            - The number of real frames in the entire video (computed before the first chunk is yielded).
""" video_channel = 3 # Rotate video if requested @@ -208,41 +211,114 @@ def load_video_1p5( raise error # For video, the entire video is divided into 1s chunks in 5 fps - with open(temp_filename, "rb") as rgb_file: - single_frame_size = video_width * video_height * video_channel - full_decode_size = video_length * video_fps * single_frame_size - rgb_file.seek(0, 2) - rgb_file_size = rgb_file.tell() - rgb_file.seek(0) - num_real_frames = rgb_file_size // single_frame_size - assert rgb_file_size >= single_frame_size, ( - f"Decoding failed to output a single frame: {rgb_file_size} <" - f" {single_frame_size}" - ) - if rgb_file_size < full_decode_size: - logging.warning( - "Decoding may be truncated: %d bytes (%d frames) < %d bytes (%d" - " frames), or video length (%ds) may be too incorrect", - rgb_file_size, - rgb_file_size / single_frame_size, - full_decode_size, - full_decode_size / single_frame_size, - video_length, + try: + with open(temp_filename, "rb") as rgb_file: + single_frame_size = video_width * video_height * video_channel + full_decode_size = video_length * video_fps * single_frame_size + rgb_file.seek(0, 2) + rgb_file_size = rgb_file.tell() + rgb_file.seek(0) + num_real_frames = rgb_file_size // single_frame_size + assert rgb_file_size >= single_frame_size, ( + f"Decoding failed to output a single frame: {rgb_file_size} <" + f" {single_frame_size}" ) - - rgb = _extend_array(bytearray(rgb_file.read()), full_decode_size) - video = ( - np.reshape( - np.frombuffer(rgb, "uint8"), - (video_length, int(video_fps), video_height, video_width, 3), + + if rgb_file_size < full_decode_size: + logging.warning( + "Decoding may be truncated: %d bytes (%d frames) < %d bytes (%d" + " frames), or video length (%ds) may be too incorrect", + rgb_file_size, + rgb_file_size / single_frame_size, + full_decode_size, + full_decode_size / single_frame_size, + video_length, ) - / 255.0 - - 0.5 - ) * 2 - # Delete temp files - os.close(fd) - os.remove(temp_filename) - logging.info("Load %s done successfully.", filepath) + chunk_size_bytes = chunk_size_frames * single_frame_size + + # Read and yield chunks + read_frames = 0 + while read_frames < num_real_frames: + chunk_bytes = rgb_file.read(chunk_size_bytes) + if not chunk_bytes: + break + + # Handle partial chunks (e.g. end of file) + # We read len(chunk_bytes). We simply divide by single_frame_size. + # If there are leftovers < single_frame_size (partial frame), we ignore them. + current_chunk_frames = len(chunk_bytes) // single_frame_size + + if current_chunk_frames == 0: + break + + # Truncate to valid frames bytes + valid_bytes = current_chunk_frames * single_frame_size + if len(chunk_bytes) > valid_bytes: + logging.warning("Read partial frame at end of file, truncating.") + chunk_bytes = chunk_bytes[:valid_bytes] + + if current_chunk_frames == 0: + break + + video_chunk = ( + np.reshape( + np.frombuffer(chunk_bytes, "uint8"), + (current_chunk_frames, 1, video_height, video_width, 3), + ).astype(np.float32) + / 255.0 + - 0.5 + ) * 2 + + yield video_chunk, num_real_frames + read_frames += current_chunk_frames + + finally: + # Delete temp files + os.close(fd) + if os.path.exists(temp_filename): + os.remove(temp_filename) + logging.info("Load %s done successfully.", filepath) + + +def load_video_1p5( + filepath: str, + video_length: int, + transpose: bool = False, + video_fps: int = 1, + video_height: int = 1080, + video_width: int = 1920, + ffmpeg_path: str = "ffmpeg", +) -> tuple[np.ndarray, int]: + """Load input video for UVQ 1.5. 
+
+    Note: This loads the entire video into memory. Use yield_video_1p5_chunks
+    for large videos.
+
+    Args:
+        filepath: Path to the video file.
+        video_length: Length of the video in seconds.
+        transpose: Whether to transpose the video.
+        video_fps: Frames per second to sample for inference.
+        video_height: Height of the video to resize to.
+        video_width: Width of the video to resize to.
+        ffmpeg_path: Path to the FFmpeg executable.
+
+    Returns:
+        A tuple containing the loaded video as a numpy array and the number of
+        real frames.
+    """
+    chunks = []
+    num_real_frames = 0
+    for chunk, n_frames in yield_video_1p5_chunks(
+        filepath,
+        video_length,
+        transpose,
+        video_fps,
+        video_height,
+        video_width,
+        ffmpeg_path,
+    ):
+        chunks.append(chunk)
+        num_real_frames = n_frames
+
+    if not chunks:
+        return np.array([]), 0
+    # Reconstruct the full video array; shape is (total_frames, 1, h, w, 3).
+    video = np.concatenate(chunks, axis=0)
+    return video, num_real_frames
diff --git a/uvq1p5_pytorch/utils/uvq1p5.py b/uvq1p5_pytorch/utils/uvq1p5.py
index cd58717..1813cf8 100644
--- a/uvq1p5_pytorch/utils/uvq1p5.py
+++ b/uvq1p5_pytorch/utils/uvq1p5.py
@@ -109,6 +109,7 @@ def infer(
         fps: int = 1,
         orig_fps: float | None = None,
         ffmpeg_path: str = "ffmpeg",
+        chunk_size_frames: int = 16,
     ) -> dict[str, Any]:
         """Runs UVQ 1.5 inference on a video file.
 
@@ -119,38 +120,40 @@ def infer(
             fps: Frames per second to sample for inference.
             orig_fps: Original frames per second of the video, used for frame
               index calculation.
+            chunk_size_frames: Number of frames to process in each chunk
+              during inference.
 
         Returns:
             A dictionary containing the overall UVQ 1.5 score, per-frame scores, and
             frame indices.
         """
-        video_1080p, _ = self.load_video(
+        predictions = []
+
+        # Use the generator to process the video in chunks.
+        for video_chunk, _ in video_reader.yield_video_1p5_chunks(
             video_filename,
             video_length,
             transpose,
-            fps=fps,
+            video_fps=fps,
             ffmpeg_path=ffmpeg_path,
-        )
-        num_seconds, read_fps, c, h, w = video_1080p.shape
-        # reshape to (num_seconds * fps, 1, 3, h, w) to process all frames
-        num_frames = num_seconds * read_fps
-        video_1080p = video_1080p.reshape(num_frames, 1, c, h, w)
-
-        batch_size = 24
-        if num_frames > batch_size:  # if video is longer than batch size, run inference in batches to avoid OOM
-            predictions = []
-            with torch.inference_mode():
-                for i in range(0, num_frames, batch_size):
-                    batch = video_1080p[i : i + batch_size]
-                    prediction_batch = self.uvq1p5_core(batch)
-                    predictions.append(prediction_batch)
-            prediction = torch.cat(predictions, dim=0)
-        else:
-            with torch.inference_mode():
-                prediction = self.uvq1p5_core(video_1080p)
+            chunk_size_frames=chunk_size_frames,
+        ):
+            # Chunks arrive as (n, 1, h, w, c); the model expects (n, 1, c, h, w).
+            video_chunk_torch = torch.from_numpy(video_chunk).float()
+            video_chunk_torch = video_chunk_torch.permute(0, 1, 4, 2, 3)
+
+            with torch.inference_mode():
+                batch = video_chunk_torch.to(next(self.parameters()).device)
+                prediction_batch = self.uvq1p5_core(batch)
+                predictions.append(prediction_batch)
+
+        if not predictions:
+            raise ValueError(f"No frames were read from {video_filename}")
+
+        prediction = torch.cat(predictions, dim=0)
         video_score = torch.mean(prediction).item()
-        frame_scores = prediction.numpy().flatten().tolist()
+        frame_scores = prediction.cpu().numpy().flatten().tolist()
 
         if orig_fps:
             frame_indices = [
diff --git a/uvq_inference.py b/uvq_inference.py
index d4f74f4..06474da 100644
--- a/uvq_inference.py
+++ b/uvq_inference.py
@@ -96,6 +96,7 @@ def run_batch_inference(args):
                 fps=fps_to_use,
                 orig_fps=orig_fps,
                 ffmpeg_path=args.ffmpeg_path,
+                chunk_size_frames=args.chunk_size_frames,
             )
             score = results["uvq1p5_score"]
         elif args.model_version == "1.0":
@@ -105,7 +106,12 @@ def run_batch_inference(args):
                 video_path,
                 transpose_flag,
             )
             score = float(results["compression_content_distortion"])
-            results_to_write.append(f"{os.path.basename(video_path)},{score}")
+
+            if args.batch_json_output:
+                results["video_name"] = os.path.basename(video_path)
+                results_to_write.append(results)
+            else:
+                results_to_write.append(f"{os.path.basename(video_path)},{score}")
         except Exception as e:
             print(f"Error processing {video_path}: {e}")
@@ -114,9 +120,13 @@ def run_batch_inference(args):
             output_dir = os.path.dirname(args.output)
             if output_dir and not os.path.exists(output_dir):
                 os.makedirs(output_dir)
-            with open(args.output, "w") as f_out:
-                for line in results_to_write:
-                    f_out.write(line + "\n")
+
+            if args.batch_json_output:
+                write_dict_to_file(results_to_write, args.output)
+            else:
+                with open(args.output, "w") as f_out:
+                    for line in results_to_write:
+                        f_out.write(line + "\n")
             print(f"Batch inference complete. Results saved to {args.output}")
         except IOError as e:
             print(f"Error writing to output file {args.output}: {e}")
@@ -167,6 +177,7 @@ def run_single_inference(args):
             fps=fps,
             orig_fps=orig_fps,
             ffmpeg_path=args.ffmpeg_path,
+            chunk_size_frames=args.chunk_size_frames,
         )
     elif args.model_version == "1.0":
         uvq_inference = uvq1p0.UVQ1p0()
@@ -266,6 +277,18 @@ def setup_parser():
         help="Frames per second to sample for UVQ1.5. -1 to sample all frames."
         " Ignored for UVQ1.0.",
     )
+    parser.add_argument(
+        "--chunk_size_frames",
+        type=int,
+        default=16,
+        help="Number of frames to process in each chunk during inference.",
+    )
+    parser.add_argument(
+        "--batch_json_output",
+        action="store_true",
+        help="If specified, outputs batch results in JSON format, including"
+        " per-frame scores, instead of just the overall mean score.",
+    )
     parser.add_argument(
         "--output_all_stats",
         action="store_true",
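For reference, the chunked-read pattern that `yield_video_1p5_chunks` implements can be shown in isolation. The following is a minimal, self-contained sketch rather than the repository's code (the path, dimensions, and helper name are illustrative): raw RGB24 frames are read a fixed number at a time, a trailing partial frame is dropped, and each chunk is reshaped and normalized to [-1, 1] before being yielded, so at most `chunk_size_frames` frames are ever held in memory.

```python
import numpy as np

HEIGHT, WIDTH, CHANNELS = 1080, 1920, 3
FRAME_SIZE = HEIGHT * WIDTH * CHANNELS  # bytes per raw RGB24 frame


def iter_rgb_chunks(path, chunk_size_frames=16):
    """Yield (n, 1, H, W, C) float32 chunks normalized to [-1, 1]."""
    chunk_size_bytes = chunk_size_frames * FRAME_SIZE
    with open(path, "rb") as f:
        while True:
            data = f.read(chunk_size_bytes)
            n_frames = len(data) // FRAME_SIZE
            if n_frames == 0:
                break  # EOF, or only a partial frame remains.
            # Drop any trailing partial frame, then reshape and normalize.
            data = data[: n_frames * FRAME_SIZE]
            chunk = (
                np.frombuffer(data, np.uint8)
                .reshape(n_frames, 1, HEIGHT, WIDTH, CHANNELS)
                .astype(np.float32)
            )
            yield (chunk / 255.0 - 0.5) * 2
```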