From 2f865a15336153b6f1162e6d13325cf1d72bf118 Mon Sep 17 00:00:00 2001
From: rka97
Date: Fri, 12 Dec 2025 01:47:03 +0000
Subject: [PATCH 1/5] ImageNet caching for faster dataset access (PyTorch)

---
 .../imagenet_pytorch/workload.py              | 111 +++++++++++++++++-
 algoperf/workloads/ogbg/workload.py           |   2 +-
 scoring/performance_profile.py                |   2 +-
 3 files changed, 109 insertions(+), 6 deletions(-)

diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
index d5366c60d..07f84975e 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -3,10 +3,13 @@
 import contextlib
 import functools
 import itertools
+import json
 import math
 import os
 import random
-from typing import Dict, Iterator, Optional, Tuple
+import time
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -14,7 +17,11 @@
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torchvision import transforms
-from torchvision.datasets.folder import ImageFolder
+from torchvision.datasets.folder import (
+  IMG_EXTENSIONS,
+  ImageFolder,
+  default_loader,
+)
 
 import algoperf.random_utils as prng
 from algoperf import data_utils, param_utils, pytorch_utils, spec
@@ -28,6 +35,100 @@
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
 
 
+class CachedImageFolder(ImageFolder):
+  """ImageFolder that caches the file listing to avoid repeated filesystem scans."""
+
+  def __init__(
+    self,
+    root: Union[str, Path],
+    cache_file: Optional[Union[str, Path]] = None,
+    transform: Optional[Callable] = None,
+    target_transform: Optional[Callable] = None,
+    loader: Callable[[str], Any] = default_loader,
+    is_valid_file: Optional[Callable[[str], bool]] = None,
+    allow_empty: bool = False,
+    rebuild_cache: bool = False,
+    cache_build_timeout_minutes: int = 30,
+  ):
+    self.root = os.path.expanduser(root)
+    self.transform = transform
+    self.target_transform = target_transform
+    self.loader = loader
+    self.extensions = IMG_EXTENSIONS if is_valid_file is None else None
+
+    # Default cache location: .cache_index.json in the root directory.
+    if cache_file is None:
+      cache_file = os.path.join(self.root, '.cache_index.json')
+    self.cache_file = cache_file
+
+    is_distributed = dist.is_available() and dist.is_initialized()
+    rank = dist.get_rank() if is_distributed else 0
+
+    cache_exists = os.path.exists(self.cache_file)
+    needs_rebuild = rebuild_cache or not cache_exists
+
+    if needs_rebuild:
+      # We only want one process to build the cache
+      # and others to wait for it to finish.
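+      # (Rank 0 builds and writes the index; the other ranks poll for the
+      # file in _wait_for_cache() and then sync on the barrier below.)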
+      if rank == 0:
+        self._build_and_save_cache(is_valid_file, allow_empty)
+      if is_distributed:
+        self._wait_for_cache(timeout_minutes=cache_build_timeout_minutes)
+        dist.barrier()
+
+    self._load_from_cache()
+
+    self.targets = [s[1] for s in self.samples]
+    self.imgs = self.samples
+
+  def _wait_for_cache(self, timeout_minutes: int):
+    """Poll for cache file to exist."""
+    timeout_seconds = timeout_minutes * 60
+    poll_interval = 5
+    elapsed = 0
+
+    while not os.path.exists(self.cache_file):
+      if elapsed >= timeout_seconds:
+        raise TimeoutError(
+          f'Timed out waiting for cache file after {timeout_minutes} minutes: {self.cache_file}'
+        )
+      time.sleep(poll_interval)
+      elapsed += poll_interval
+
+  def _load_from_cache(self):
+    """Load classes and samples from cache file."""
+    with open(os.path.abspath(self.cache_file), 'r') as f:
+      cache = json.load(f)
+    self.classes = cache['classes']
+    self.class_to_idx = cache['class_to_idx']
+    # Convert relative paths back to absolute.
+    self.samples = [
+      (os.path.join(self.root, rel_path), idx)
+      for rel_path, idx in cache['samples']
+    ]
+
+  def _build_and_save_cache(self, is_valid_file, allow_empty):
+    """Scan filesystem, build index, and save to cache."""
+    self.classes, self.class_to_idx = self.find_classes(self.root)
+    self.samples = self.make_dataset(
+      self.root,
+      class_to_idx=self.class_to_idx,
+      extensions=self.extensions,
+      is_valid_file=is_valid_file,
+      allow_empty=allow_empty,
+    )
+
+    cache = {
+      'classes': self.classes,
+      'class_to_idx': self.class_to_idx,
+      'samples': [
+        (os.path.relpath(path, self.root), idx) for path, idx in self.samples
+      ],
+    }
+    with open(os.path.abspath(self.cache_file), 'w') as f:
+      json.dump(cache, f)
+
+
 def imagenet_v2_to_torch(
   batch: Dict[str, spec.Tensor],
 ) -> Dict[str, spec.Tensor]:
@@ -119,8 +220,10 @@ def _build_dataset(
     )
 
     folder = 'train' if 'train' in split else 'val'
-    dataset = ImageFolder(
-      os.path.join(data_dir, folder), transform=transform_config
+    dataset = CachedImageFolder(
+      os.path.join(data_dir, folder),
+      transform=transform_config,
+      cache_file='.imagenet_cache_index.json',
     )
 
     if split == 'eval_train':
diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py
index 002576268..771b103a0 100644
--- a/algoperf/workloads/ogbg/workload.py
+++ b/algoperf/workloads/ogbg/workload.py
@@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int:
 
   @property
   def eval_period_time_sec(self) -> int:
-    return 452 # approx 25 evals
+    return 452  # approx 25 evals
 
   def _build_input_queue(
     self,
diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py
index b200c6865..043a65791 100644
--- a/scoring/performance_profile.py
+++ b/scoring/performance_profile.py
@@ -71,7 +71,7 @@
   'wer',
   'l1_loss',
   'loss',
-  'ppl'
+  'ppl',
 ]
 
 MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu']

From c2f443bff5d4ffe1a52db63f726cdb0f5e1e3382 Mon Sep 17 00:00:00 2001
From: rka97
Date: Thu, 18 Dec 2025 18:38:04 +0000
Subject: [PATCH 2/5] Refactor ImageNet input pipeline and augmentations

---
 .../imagenet_resnet/custom_tf_addons.py       | 456 +++++++++++++++
 .../imagenet_resnet/imagenet_jax/workload.py  |   7 +-
 .../imagenet_pytorch/randaugment.py           | 189 ------
 .../imagenet_pytorch/workload.py              | 221 ++-----
 .../workloads/imagenet_resnet/imagenet_v2.py  |   2 +-
 .../imagenet_resnet/input_pipeline.py         | 455 +++++++++++++++
 .../workloads/imagenet_resnet/randaugment.py  | 548 ++++++++++++++++++
 .../pytorch_nadamw_full_budget.py             |  18 +-
 8 files changed, 1506 insertions(+), 390 deletions(-)
 create mode 100644 algoperf/workloads/imagenet_resnet/custom_tf_addons.py
 delete mode 100644 algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py
 create mode 100644 algoperf/workloads/imagenet_resnet/input_pipeline.py
 create mode 100644 algoperf/workloads/imagenet_resnet/randaugment.py

diff --git a/algoperf/workloads/imagenet_resnet/custom_tf_addons.py b/algoperf/workloads/imagenet_resnet/custom_tf_addons.py
new file mode 100644
index 000000000..53368b384
--- /dev/null
+++ b/algoperf/workloads/imagenet_resnet/custom_tf_addons.py
@@ -0,0 +1,456 @@
+"""Custom image transform ops.
+
+Note: the following code is adapted from:
+https://github.com/tensorflow/addons/tree/master/tensorflow_addons/image
+"""
+
+from typing import List, Optional, Union
+
+import numpy as np
+import tensorflow as tf
+
+_IMAGE_DTYPES = {
+  tf.dtypes.uint8,
+  tf.dtypes.int32,
+  tf.dtypes.int64,
+  tf.dtypes.float16,
+  tf.dtypes.float32,
+  tf.dtypes.float64,
+}
+
+Number = Union[
+  float,
+  int,
+  np.float16,
+  np.float32,
+  np.float64,
+  np.int8,
+  np.int16,
+  np.int32,
+  np.int64,
+  np.uint8,
+  np.uint16,
+  np.uint32,
+  np.uint64,
+]
+
+TensorLike = Union[
+  List[Union[Number, list]],
+  tuple,
+  Number,
+  np.ndarray,
+  tf.Tensor,
+  tf.SparseTensor,
+  tf.Variable,
+]
+
+
+def get_ndims(image):
+  return image.get_shape().ndims or tf.rank(image)
+
+
+def to_4d_image(image):
+  """Convert 2/3/4D image to 4D image.
+
+  Args:
+    image: 2/3/4D `Tensor`.
+
+  Returns:
+    4D `Tensor` with the same type.
+  """
+  with tf.control_dependencies(
+    [
+      tf.debugging.assert_rank_in(
+        image, [2, 3, 4], message='`image` must be 2/3/4D tensor'
+      )
+    ]
+  ):
+    ndims = image.get_shape().ndims
+    if ndims is None:
+      return _dynamic_to_4d_image(image)
+    elif ndims == 2:
+      return image[None, :, :, None]
+    elif ndims == 3:
+      return image[None, :, :, :]
+    else:
+      return image
+
+
+def _dynamic_to_4d_image(image):
+  shape = tf.shape(image)
+  original_rank = tf.rank(image)
+  # 4D image => [N, H, W, C] or [N, C, H, W]
+  # 3D image => [1, H, W, C] or [1, C, H, W]
+  # 2D image => [1, H, W, 1]
+  left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
+  right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
+  new_shape = tf.concat(
+    [
+      tf.ones(shape=left_pad, dtype=tf.int32),
+      shape,
+      tf.ones(shape=right_pad, dtype=tf.int32),
+    ],
+    axis=0,
+  )
+  return tf.reshape(image, new_shape)
+
+
+def from_4d_image(image, ndims):
+  """Convert back to an image with `ndims` rank.
+
+  Args:
+    image: 4D `Tensor`.
+    ndims: The original rank of the image.
+
+  Returns:
+    `ndims`-D `Tensor` with the same type.
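+
+  For example, `from_4d_image(to_4d_image(img), get_ndims(img))` restores
+  the rank `img` had before batching.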
+  """
+  with tf.control_dependencies(
+    [tf.debugging.assert_rank(image, 4, message='`image` must be 4D tensor')]
+  ):
+    if isinstance(ndims, tf.Tensor):
+      return _dynamic_from_4d_image(image, ndims)
+    elif ndims == 2:
+      return tf.squeeze(image, [0, 3])
+    elif ndims == 3:
+      return tf.squeeze(image, [0])
+    else:
+      return image
+
+
+def _dynamic_from_4d_image(image, original_rank):
+  shape = tf.shape(image)
+  # 4D image <= [N, H, W, C] or [N, C, H, W]
+  # 3D image <= [1, H, W, C] or [1, C, H, W]
+  # 2D image <= [1, H, W, 1]
+  begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
+  end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
+  new_shape = shape[begin:end]
+  return tf.reshape(image, new_shape)
+
+
+def transform(
+  images: TensorLike,
+  transforms: TensorLike,
+  interpolation: str = 'nearest',
+  fill_mode: str = 'constant',
+  output_shape: Optional[list] = None,
+  name: Optional[str] = None,
+  fill_value: TensorLike = 0.0,
+) -> tf.Tensor:
+  """Applies the given transform(s) to the image(s).
+
+  Args:
+    images: A tensor of shape (num_images, num_rows, num_columns,
+      num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or
+      (num_rows, num_columns) (HW).
+    transforms: Projective transform matrix/matrices. A vector of length 8 or
+      tensor of size N x 8. If one row of transforms is
+      [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
+      `(x, y)` to a transformed *input* point
+      `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
+      where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
+      the transform mapping input points to output points. Note that
+      gradients are not backpropagated into transformation parameters.
+    interpolation: Interpolation mode.
+      Supported values: "nearest", "bilinear".
+    fill_mode: Points outside the boundaries of the input are filled according
+      to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
+      - *reflect*: `(d c b a | a b c d | d c b a)`
+        The input is extended by reflecting about the edge of the last pixel.
+      - *constant*: `(k k k k | a b c d | k k k k)`
+        The input is extended by filling all values beyond the edge with the
+        same constant value k = 0.
+      - *wrap*: `(a b c d | a b c d | a b c d)`
+        The input is extended by wrapping around to the opposite edge.
+      - *nearest*: `(a a a a | a b c d | d d d d)`
+        The input is extended by the nearest pixel.
+    fill_value: a float representing the value to be filled outside the
+      boundaries when `fill_mode` is "constant".
+    output_shape: Output dimension after the transform, [height, width].
+      If None, output is the same size as input image.
+    name: The name of the op.
+
+  Returns:
+    Image(s) with the same type and shape as `images`, with the given
+    transform(s) applied. Transformed coordinates outside of the input image
+    will be filled with zeros.
+
+  Raises:
+    TypeError: If `image` is an invalid type.
+    ValueError: If output shape is not 1-D int32 Tensor.
+  """
+  with tf.name_scope(name or 'transform'):
+    image_or_images = tf.convert_to_tensor(images, name='images')
+    transform_or_transforms = tf.convert_to_tensor(
+      transforms, name='transforms', dtype=tf.dtypes.float32
+    )
+    if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
+      raise TypeError('Invalid dtype %s.' % image_or_images.dtype)
+    images = to_4d_image(image_or_images)
+    original_ndims = get_ndims(image_or_images)
+
+    if output_shape is None:
+      output_shape = tf.shape(images)[1:3]
+
+    output_shape = tf.convert_to_tensor(
+      output_shape, tf.dtypes.int32, name='output_shape'
+    )
+
+    if not output_shape.get_shape().is_compatible_with([2]):
+      raise ValueError(
+        'output_shape must be a 1-D Tensor of 2 elements: new_height, new_width'
+      )
+
+    if len(transform_or_transforms.get_shape()) == 1:
+      transforms = transform_or_transforms[None]
+    elif transform_or_transforms.get_shape().ndims is None:
+      raise ValueError('transforms rank must be statically known')
+    elif len(transform_or_transforms.get_shape()) == 2:
+      transforms = transform_or_transforms
+    else:
+      transforms = transform_or_transforms
+      raise ValueError(
+        'transforms should have rank 1 or 2, but got rank %d'
+        % len(transforms.get_shape())
+      )
+
+    fill_value = tf.convert_to_tensor(
+      fill_value, dtype=tf.float32, name='fill_value'
+    )
+    output = tf.raw_ops.ImageProjectiveTransformV3(
+      images=images,
+      transforms=transforms,
+      output_shape=output_shape,
+      interpolation=interpolation.upper(),
+      fill_mode=fill_mode.upper(),
+      fill_value=fill_value,
+    )
+    return from_4d_image(output, original_ndims)
+
+
+def angles_to_projective_transforms(
+  angles: TensorLike,
+  image_height: TensorLike,
+  image_width: TensorLike,
+  name: Optional[str] = None,
+) -> tf.Tensor:
+  """Returns projective transform(s) for the given angle(s).
+
+  Args:
+    angles: A scalar angle to rotate all images by, or (for batches of
+      images) a vector with an angle to rotate each image in the batch. The
+      rank must be statically known (the shape is not `TensorShape(None)`).
+    image_height: Height of the image(s) to be transformed.
+    image_width: Width of the image(s) to be transformed.
+
+  Returns:
+    A tensor of shape (num_images, 8). Projective transforms which can be
+    given to `transform` op.
+  """
+  with tf.name_scope(name or 'angles_to_projective_transforms'):
+    angle_or_angles = tf.convert_to_tensor(
+      angles, name='angles', dtype=tf.dtypes.float32
+    )
+
+    if len(angle_or_angles.get_shape()) not in (0, 1):
+      raise ValueError('angles should have rank 0 or 1.')
+
+    if len(angle_or_angles.get_shape()) == 0:
+      angles = angle_or_angles[None]
+    else:
+      angles = angle_or_angles
+
+    cos_angles = tf.math.cos(angles)
+    sin_angles = tf.math.sin(angles)
+    x_offset = (
+      (image_width - 1)
+      - (cos_angles * (image_width - 1) - sin_angles * (image_height - 1))
+    ) / 2.0
+    y_offset = (
+      (image_height - 1)
+      - (sin_angles * (image_width - 1) + cos_angles * (image_height - 1))
+    ) / 2.0
+    num_angles = tf.shape(angles)[0]
+    return tf.concat(
+      values=[
+        cos_angles[:, None],
+        -sin_angles[:, None],
+        x_offset[:, None],
+        sin_angles[:, None],
+        cos_angles[:, None],
+        y_offset[:, None],
+        tf.zeros((num_angles, 2), tf.dtypes.float32),
+      ],
+      axis=1,
+    )
+
+
+def rotate_img(
+  images: TensorLike,
+  angles: TensorLike,
+  interpolation: str = 'nearest',
+  fill_mode: str = 'constant',
+  name: Optional[str] = None,
+  fill_value: TensorLike = 0.0,
+) -> tf.Tensor:
+  """Rotate image(s) counterclockwise by the passed angle(s) in radians.
+
+  Args:
+    images: A tensor of shape
+      `(num_images, num_rows, num_columns, num_channels)`
+      (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or
+      `(num_rows, num_columns)` (HW).
+    angles: A scalar angle to rotate all images by, or (if `images` has
+      rank 4) a vector of length num_images, with an angle for each image
+      in the batch.
+    interpolation: Interpolation mode. Supported values: "nearest",
+      "bilinear".
+    fill_mode: Points outside the boundaries of the input are filled according
+      to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
+      - *reflect*: `(d c b a | a b c d | d c b a)`
+        The input is extended by reflecting about the edge of the last pixel.
+      - *constant*: `(k k k k | a b c d | k k k k)`
+        The input is extended by filling all values beyond the edge with the
+        same constant value k = 0.
+      - *wrap*: `(a b c d | a b c d | a b c d)`
+        The input is extended by wrapping around to the opposite edge.
+      - *nearest*: `(a a a a | a b c d | d d d d)`
+        The input is extended by the nearest pixel.
+    fill_value: a float representing the value to be filled outside the
+      boundaries when `fill_mode` is "constant".
+    name: The name of the op.
+
+  Returns:
+    Image(s) with the same type and shape as `images`, rotated by the given
+    angle(s). Empty space due to the rotation will be filled with zeros.
+
+  Raises:
+    TypeError: If `images` is an invalid type.
+  """
+  with tf.name_scope(name or 'rotate'):
+    image_or_images = tf.convert_to_tensor(images)
+    if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
+      raise TypeError('Invalid dtype %s.' % image_or_images.dtype)
+    images = to_4d_image(image_or_images)
+    original_ndims = get_ndims(image_or_images)
+
+    image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None]
+    image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None]
+    output = transform(
+      images,
+      angles_to_projective_transforms(angles, image_height, image_width),
+      interpolation=interpolation,
+      fill_mode=fill_mode,
+      fill_value=fill_value,
+    )
+    return from_4d_image(output, original_ndims)
+
+
+def translations_to_projective_transforms(
+  translations: TensorLike, name: Optional[str] = None
+) -> tf.Tensor:
+  """Returns projective transform(s) for the given translation(s).
+
+  Args:
+    translations: A 2-element list representing `[dx, dy]` or a matrix of
+      2-element lists representing `[dx, dy]` to translate for each image
+      (for a batch of images). The rank must be statically known
+      (the shape is not `TensorShape(None)`).
+    name: The name of the op.
+  Returns:
+    A tensor of shape `(num_images, 8)` projective transforms which can be
+    given to `tfa.image.transform`.
+  """
+  with tf.name_scope(name or 'translations_to_projective_transforms'):
+    translation_or_translations = tf.convert_to_tensor(
+      translations, name='translations', dtype=tf.dtypes.float32
+    )
+    if translation_or_translations.get_shape().ndims is None:
+      raise TypeError(
+        'translation_or_translations rank must be statically known'
+      )
+
+    if len(translation_or_translations.get_shape()) not in (1, 2):
+      raise TypeError('Translations should have rank 1 or 2.')
+
+    if len(translation_or_translations.get_shape()) == 1:
+      translations = translation_or_translations[None]
+    else:
+      translations = translation_or_translations
+
+    num_translations = tf.shape(translations)[0]
+    # The translation matrix looks like:
+    #   [[1 0 -dx]
+    #    [0 1 -dy]
+    #    [0 0 1]]
+    # where the last entry is implicit.
+    # Translation matrices are always float32.
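+    # Flattened below into [a0, a1, a2, b0, b1, b2, c0, c1], the 8-vector
+    # layout that transform() expects.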
+    return tf.concat(
+      values=[
+        tf.ones((num_translations, 1), tf.dtypes.float32),
+        tf.zeros((num_translations, 1), tf.dtypes.float32),
+        -translations[:, 0, None],
+        tf.zeros((num_translations, 1), tf.dtypes.float32),
+        tf.ones((num_translations, 1), tf.dtypes.float32),
+        -translations[:, 1, None],
+        tf.zeros((num_translations, 2), tf.dtypes.float32),
+      ],
+      axis=1,
+    )
+
+
+@tf.function
+def translate(
+  images: TensorLike,
+  translations: TensorLike,
+  interpolation: str = 'nearest',
+  fill_mode: str = 'constant',
+  name: Optional[str] = None,
+  fill_value: TensorLike = 0.0,
+) -> tf.Tensor:
+  """Translate image(s) by the passed vector(s).
+
+  Args:
+    images: A tensor of shape
+      `(num_images, num_rows, num_columns, num_channels)` (NHWC),
+      `(num_rows, num_columns, num_channels)` (HWC), or
+      `(num_rows, num_columns)` (HW). The rank must be statically known (the
+      shape is not `TensorShape(None)`).
+    translations: A vector representing `[dx, dy]` or (if `images` has rank 4)
+      a matrix of length num_images, with a `[dx, dy]` vector for each image
+      in the batch.
+    interpolation: Interpolation mode. Supported values: "nearest",
+      "bilinear".
+    fill_mode: Points outside the boundaries of the input are filled according
+      to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
+      - *reflect*: `(d c b a | a b c d | d c b a)`
+        The input is extended by reflecting about the edge of the last pixel.
+      - *constant*: `(k k k k | a b c d | k k k k)`
+        The input is extended by filling all values beyond the edge with the
+        same constant value k = 0.
+      - *wrap*: `(a b c d | a b c d | a b c d)`
+        The input is extended by wrapping around to the opposite edge.
+      - *nearest*: `(a a a a | a b c d | d d d d)`
+        The input is extended by the nearest pixel.
+    fill_value: a float representing the value to be filled outside the
+      boundaries when `fill_mode` is "constant".
+    name: The name of the op.
+  Returns:
+    Image(s) with the same type and shape as `images`, translated by the
+    given vector(s). Empty space due to the translation will be filled with
+    zeros.
+  Raises:
+    TypeError: If `images` is an invalid type.
+  """
+  with tf.name_scope(name or 'translate'):
+    return transform(
+      images,
+      translations_to_projective_transforms(translations),
+      interpolation=interpolation,
+      fill_mode=fill_mode,
+      fill_value=fill_value,
+    )
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
index f73a1b26e..60a76e366 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
@@ -18,11 +18,8 @@
 
 from algoperf import jax_sharding_utils, param_utils, spec
 from algoperf import random_utils as prng
-from algoperf.workloads.imagenet_resnet import imagenet_v2
-from algoperf.workloads.imagenet_resnet.imagenet_jax import (
-  input_pipeline,
-  models,
-)
+from algoperf.workloads.imagenet_resnet import imagenet_v2, input_pipeline
+from algoperf.workloads.imagenet_resnet.imagenet_jax import models
 from algoperf.workloads.imagenet_resnet.workload import (
   BaseImagenetResNetWorkload,
 )
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py
deleted file mode 100644
index 1c5c0d952..000000000
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py
+++ /dev/null
@@ -1,189 +0,0 @@
-"""PyTorch implementation of RandAugmentation.
-
-Adapted from:
-https://pytorch.org/vision/stable/_modules/torchvision/transforms/autoaugment.html.
-"""
-
-import math
-from typing import Dict, List, Optional, Tuple
-
-import numpy as np
-import PIL
-import torch
-from torch import Tensor
-from torchvision.transforms import InterpolationMode
-from torchvision.transforms import functional as F
-
-from algoperf import spec
-
-
-def cutout(img: spec.Tensor, pad_size: int) -> spec.Tensor:
-  image_width, image_height = img.size
-  x0 = np.random.uniform(image_width)
-  y0 = np.random.uniform(image_height)
-
-  # Double the pad size to match Jax implementation.
-  pad_size = pad_size * 2
-  x0 = int(max(0, x0 - pad_size / 2.0))
-  y0 = int(max(0, y0 - pad_size / 2.0))
-  x1 = int(min(image_width, x0 + pad_size))
-  y1 = int(min(image_height, y0 + pad_size))
-  xy = (x0, y0, x1, y1)
-  img = img.copy()
-  PIL.ImageDraw.Draw(img).rectangle(xy, (128, 128, 128))
-  return img
-
-
-def solarize(img: spec.Tensor, threshold: float) -> spec.Tensor:
-  img = np.array(img)
-  new_img = np.where(img < threshold, img, 255.0 - img)
-  return PIL.Image.fromarray(new_img.astype(np.uint8))
-
-
-def solarize_add(img: spec.Tensor, addition: int = 0) -> spec.Tensor:
-  threshold = 128
-  img = np.array(img)
-  added_img = img.astype(np.int64) + addition
-  added_img = np.clip(added_img, 0, 255).astype(np.uint8)
-  new_img = np.where(img < threshold, added_img, img)
-  return PIL.Image.fromarray(new_img)
-
-
-def _apply_op(
-  img: spec.Tensor,
-  op_name: str,
-  magnitude: float,
-  interpolation: InterpolationMode,
-  fill: Optional[List[float]],
-) -> spec.Tensor:
-  if op_name == 'ShearX':
-    # Magnitude should be arctan(magnitude).
-    img = F.affine(
-      img,
-      angle=0.0,
-      translate=[0, 0],
-      scale=1.0,
-      shear=[math.degrees(math.atan(magnitude)), 0.0],
-      interpolation=interpolation,
-      fill=fill,
-      center=[0, 0],
-    )
-  elif op_name == 'ShearY':
-    # Magnitude should be arctan(magnitude).
-    img = F.affine(
-      img,
-      angle=0.0,
-      translate=[0, 0],
-      scale=1.0,
-      shear=[0.0, math.degrees(math.atan(magnitude))],
-      interpolation=interpolation,
-      fill=fill,
-      center=[0, 0],
-    )
-  elif op_name == 'TranslateX':
-    img = F.affine(
-      img,
-      angle=0.0,
-      translate=[int(magnitude), 0],
-      scale=1.0,
-      interpolation=interpolation,
-      shear=[0.0, 0.0],
-      fill=fill,
-    )
-  elif op_name == 'TranslateY':
-    img = F.affine(
-      img,
-      angle=0.0,
-      translate=[0, int(magnitude)],
-      scale=1.0,
-      interpolation=interpolation,
-      shear=[0.0, 0.0],
-      fill=fill,
-    )
-  elif op_name == 'Rotate':
-    img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
-  elif op_name == 'Brightness':
-    img = F.adjust_brightness(img, magnitude)
-  elif op_name == 'Color':
-    img = F.adjust_saturation(img, magnitude)
-  elif op_name == 'Contrast':
-    img = F.adjust_contrast(img, magnitude)
-  elif op_name == 'Sharpness':
-    img = F.adjust_sharpness(img, magnitude)
-  elif op_name == 'Posterize':
-    img = F.posterize(img, int(magnitude))
-  elif op_name == 'Cutout':
-    img = cutout(img, int(magnitude))
-  elif op_name == 'SolarizeAdd':
-    img = solarize_add(img, int(magnitude))
-  elif op_name == 'Solarize':
-    img = solarize(img, magnitude)
-  elif op_name == 'AutoContrast':
-    img = F.autocontrast(img)
-  elif op_name == 'Equalize':
-    img = F.equalize(img)
-  elif op_name == 'Invert':
-    img = F.invert(img)
-  elif op_name == 'Identity':
-    pass
-  else:
-    raise ValueError(f'The provided operator {op_name} is not recognized.')
-  return img
-
-
-def ops_space() -> Dict[str, Tuple[spec.Tensor, bool]]:
-  return {
-    # op_name: (magnitudes, signed)
-    'ShearX': (torch.tensor(0.3), True),
-    'ShearY': (torch.tensor(0.3), True),
-    'TranslateX': (torch.tensor(100), True),
-    'TranslateY': (torch.tensor(100), True),
-    'Rotate': (torch.tensor(30), True),
-    'Brightness': (torch.tensor(1.9), False),
-    'Color': (torch.tensor(1.9), False),
-    'Contrast': (torch.tensor(1.9), False),
-    'Sharpness': (torch.tensor(1.9), False),
-    'Posterize': (torch.tensor(4), False),
-    'Solarize': (torch.tensor(256), False),
-    'SolarizeAdd': (torch.tensor(110), False),
-    'AutoContrast': (torch.tensor(0.0), False),
-    'Equalize': (torch.tensor(0.0), False),
-    'Invert': (torch.tensor(0.0), False),
-    'Cutout': (torch.tensor(40.0), False),
-  }
-
-
-class RandAugment(torch.nn.Module):
-  def __init__(
-    self,
-    num_ops: int = 2,
-    interpolation: InterpolationMode = InterpolationMode.NEAREST,
-    fill: Optional[List[float]] = None,
-  ) -> None:
-    super().__init__()
-    self.num_ops = num_ops
-    self.interpolation = interpolation
-    self.fill = fill
-
-  def forward(self, img: spec.Tensor) -> spec.Tensor:
-    fill = self.fill if self.fill is not None else 128
-    channels, _, _ = F.get_dimensions(img)
-    if isinstance(img, Tensor):
-      if isinstance(fill, (int, float)):
-        fill = [float(fill)] * channels
-      elif fill is not None:
-        fill = [float(f) for f in fill]
-
-    op_meta = ops_space()
-    for _ in range(self.num_ops):
-      op_index = int(torch.randint(len(op_meta), (1,)).item())
-      op_name = list(op_meta.keys())[op_index]
-      magnitude, signed = op_meta[op_name]
-      magnitude = float(magnitude)
-      if signed and torch.randint(2, (1,)):
-        # With 50% prob turn the magnitude negative.
-        magnitude *= -1.0
-      img = _apply_op(
-        img, op_name, magnitude, interpolation=self.interpolation, fill=fill
-      )
-    return img
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
index 07f84975e..541d82165 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -3,30 +3,20 @@
 import contextlib
 import functools
 import itertools
-import json
 import math
-import os
-import random
-import time
-from pathlib import Path
-from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union
+from typing import Dict, Iterator, Optional, Tuple
 
+import jax
 import numpy as np
+import tensorflow_datasets as tfds
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torchvision import transforms
-from torchvision.datasets.folder import (
-  IMG_EXTENSIONS,
-  ImageFolder,
-  default_loader,
-)
 
 import algoperf.random_utils as prng
-from algoperf import data_utils, param_utils, pytorch_utils, spec
-from algoperf.workloads.imagenet_resnet import imagenet_v2
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch import randaugment
+from algoperf import param_utils, pytorch_utils, spec
+from algoperf.workloads.imagenet_resnet import imagenet_v2, input_pipeline
 from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import resnet50
 from algoperf.workloads.imagenet_resnet.workload import (
   BaseImagenetResNetWorkload,
@@ -35,100 +25,6 @@
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
 
 
-class CachedImageFolder(ImageFolder):
-  """ImageFolder that caches the file listing to avoid repeated filesystem scans."""
-
-  def __init__(
-    self,
-    root: Union[str, Path],
-    cache_file: Optional[Union[str, Path]] = None,
-    transform: Optional[Callable] = None,
-    target_transform: Optional[Callable] = None,
-    loader: Callable[[str], Any] = default_loader,
-    is_valid_file: Optional[Callable[[str], bool]] = None,
-    allow_empty: bool = False,
-    rebuild_cache: bool = False,
-    cache_build_timeout_minutes: int = 30,
-  ):
-    self.root = os.path.expanduser(root)
-    self.transform = transform
-    self.target_transform = target_transform
-    self.loader = loader
-    self.extensions = IMG_EXTENSIONS if is_valid_file is None else None
-
-    # Default cache location: .cache_index.json in the root directory.
-    if cache_file is None:
-      cache_file = os.path.join(self.root, '.cache_index.json')
-    self.cache_file = cache_file
-
-    is_distributed = dist.is_available() and dist.is_initialized()
-    rank = dist.get_rank() if is_distributed else 0
-
-    cache_exists = os.path.exists(self.cache_file)
-    needs_rebuild = rebuild_cache or not cache_exists
-
-    if needs_rebuild:
-      # We only want one process to build the cache
-      # and others to wait for it to finish.
-      if rank == 0:
-        self._build_and_save_cache(is_valid_file, allow_empty)
-      if is_distributed:
-        self._wait_for_cache(timeout_minutes=cache_build_timeout_minutes)
-        dist.barrier()
-
-    self._load_from_cache()
-
-    self.targets = [s[1] for s in self.samples]
-    self.imgs = self.samples
-
-  def _wait_for_cache(self, timeout_minutes: int):
-    """Poll for cache file to exist."""
-    timeout_seconds = timeout_minutes * 60
-    poll_interval = 5
-    elapsed = 0
-
-    while not os.path.exists(self.cache_file):
-      if elapsed >= timeout_seconds:
-        raise TimeoutError(
-          f'Timed out waiting for cache file after {timeout_minutes} minutes: {self.cache_file}'
-        )
-      time.sleep(poll_interval)
-      elapsed += poll_interval
-
-  def _load_from_cache(self):
-    """Load classes and samples from cache file."""
-    with open(os.path.abspath(self.cache_file), 'r') as f:
-      cache = json.load(f)
-    self.classes = cache['classes']
-    self.class_to_idx = cache['class_to_idx']
-    # Convert relative paths back to absolute.
-    self.samples = [
-      (os.path.join(self.root, rel_path), idx)
-      for rel_path, idx in cache['samples']
-    ]
-
-  def _build_and_save_cache(self, is_valid_file, allow_empty):
-    """Scan filesystem, build index, and save to cache."""
-    self.classes, self.class_to_idx = self.find_classes(self.root)
-    self.samples = self.make_dataset(
-      self.root,
-      class_to_idx=self.class_to_idx,
-      extensions=self.extensions,
-      is_valid_file=is_valid_file,
-      allow_empty=allow_empty,
-    )
-
-    cache = {
-      'classes': self.classes,
-      'class_to_idx': self.class_to_idx,
-      'samples': [
-        (os.path.relpath(path, self.root), idx) for path, idx in self.samples
-      ],
-    }
-    with open(os.path.abspath(self.cache_file), 'w') as f:
-      json.dump(cache, f)
-
-
 def imagenet_v2_to_torch(
   batch: Dict[str, spec.Tensor],
 ) -> Dict[str, spec.Tensor]:
@@ -177,8 +73,6 @@ def _build_dataset(
     use_mixup: bool = False,
     use_randaug: bool = False,
   ) -> Iterator[Dict[str, spec.Tensor]]:
-    del cache
-    del repeat_final_dataset
     if split == 'test':
      np_iter = imagenet_v2.get_imagenet_v2_iter(
        data_dir,
@@ -191,83 +85,48 @@
     )
     return map(imagenet_v2_to_torch, itertools.cycle(np_iter))
 
-    is_train = split == 'train'
-    normalize = transforms.Normalize(
-      mean=[i / 255.0 for i in self.train_mean],
-      std=[i / 255.0 for i in self.train_stddev],
-    )
-    if is_train:
-      transform_config = [
-        transforms.RandomResizedCrop(
-          self.center_crop_size,
-          scale=self.scale_ratio_range,
-          ratio=self.aspect_ratio_range,
-        ),
-        transforms.RandomHorizontalFlip(),
-      ]
-      if use_randaug:
-        transform_config.append(randaugment.RandAugment())
-      transform_config.extend([transforms.ToTensor(), normalize])
-      transform_config = transforms.Compose(transform_config)
-    else:
-      transform_config = transforms.Compose(
-        [
-          transforms.Resize(self.resize_size),
-          transforms.CenterCrop(self.center_crop_size),
-          transforms.ToTensor(),
-          normalize,
-        ]
-      )
-
-    folder = 'train' if 'train' in split else 'val'
-    dataset = CachedImageFolder(
-      os.path.join(data_dir, folder),
-      transform=transform_config,
-      cache_file='.imagenet_cache_index.json',
-    )
-
-    if split == 'eval_train':
-      indices = list(range(self.num_train_examples))
-      random.Random(int(data_rng[0])).shuffle(indices)
-      dataset = torch.utils.data.Subset(
-        dataset, indices[: self.num_eval_train_examples]
-      )
+    # Use shared TFDS-based input pipeline (same TFRecords as JAX)
+    ds_builder = tfds.builder('imagenet2012:5.1.0', data_dir=data_dir)
+    train = split == 'train'
 
-    sampler = None
+    # Calculate per-device batch size for DDP
     if USE_PYTORCH_DDP:
-      per_device_batch_size = global_batch_size // N_GPUS
-      ds_iter_batch_size = per_device_batch_size
+      batch_size = global_batch_size // N_GPUS
     else:
-      ds_iter_batch_size = global_batch_size
-    if USE_PYTORCH_DDP:
-      if is_train:
-        sampler = torch.utils.data.distributed.DistributedSampler(
-          dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True
-        )
-      else:
-        sampler = data_utils.DistributedEvalSampler(
-          dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False
-        )
-
-    dataloader = torch.utils.data.DataLoader(
-      dataset,
-      batch_size=ds_iter_batch_size,
-      shuffle=not USE_PYTORCH_DDP and is_train,
-      sampler=sampler,
-      num_workers=4 if is_train else self.eval_num_workers,
-      pin_memory=True,
-      drop_last=is_train,
-      persistent_workers=is_train,
-    )
-    dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
-    dataloader = data_utils.cycle(
-      dataloader,
-      custom_sampler=USE_PYTORCH_DDP,
+      batch_size = global_batch_size
+
+    ds = input_pipeline.create_split(
+      split,
+      ds_builder,
+      jax.tree.map(lambda x: x.astype(np.uint32), data_rng),
+      batch_size,
+      train=train,
+      image_size=self.center_crop_size,
+      resize_size=self.resize_size,
+      mean_rgb=self.train_mean,
+      stddev_rgb=self.train_stddev,
+      cache=not train if cache is None else cache,
+      repeat_final_dataset=repeat_final_dataset
+      if repeat_final_dataset is not None
+      else train,
+      aspect_ratio_range=self.aspect_ratio_range,
+      area_range=self.scale_ratio_range,
       use_mixup=use_mixup,
       mixup_alpha=0.2,
+      use_randaug=use_randaug,
+      image_format='NCHW',
+      threadpool_size=12 if USE_PYTORCH_DDP else 48,
     )
-    return dataloader
 
+    # Wrap to convert TF tensors to PyTorch tensors on device
+    def tf_to_pytorch_iter(tf_ds) -> Iterator[Dict[str, spec.Tensor]]:
+      for batch in tf_ds:
+        inputs = torch.from_numpy(batch['inputs'].numpy()).to(DEVICE)
+        targets = torch.from_numpy(batch['targets'].numpy()).to(
+          DEVICE, dtype=torch.long
+        )
+        yield {'inputs': inputs, 'targets': targets}
+
+    return tf_to_pytorch_iter(ds)
 
   def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     torch.random.manual_seed(rng[0])
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py
index 75f844b86..31d9751a5 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_v2.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py
@@ -9,7 +9,7 @@
 import tensorflow_datasets as tfds
 
 from algoperf import data_utils, spec
-from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline
+from algoperf.workloads.imagenet_resnet import input_pipeline
 
 
 def get_imagenet_v2_iter(
diff --git a/algoperf/workloads/imagenet_resnet/input_pipeline.py b/algoperf/workloads/imagenet_resnet/input_pipeline.py
new file mode 100644
index 000000000..843ed2424
--- /dev/null
+++ b/algoperf/workloads/imagenet_resnet/input_pipeline.py
@@ -0,0 +1,455 @@
+"""ImageNet input pipeline.
+
+Forked from the Flax example, which can be found here:
+https://github.com/google/flax/blob/main/examples/imagenet/input_pipeline.py.
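+
+This pipeline is shared by the JAX and PyTorch ImageNet workloads, so both
+frameworks read the same TFDS data with identical preprocessing.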
+""" + +import functools +from typing import Dict, Iterator, Tuple + +import jax +import tensorflow as tf +import tensorflow_datasets as tfds + +from algoperf import data_utils, spec +from algoperf.workloads.imagenet_resnet import randaugment + +TFDS_SPLIT_NAME = { + 'train': 'train', + 'eval_train': 'train', + 'validation': 'validation', +} + + +def _distorted_bounding_box_crop( + image_bytes: spec.Tensor, + rng: spec.RandomState, + bbox: spec.Tensor, + min_object_covered: float = 0.1, + aspect_ratio_range: Tuple[float, float] = (0.75, 1.33), + area_range: Tuple[float, float] = (0.05, 1.0), + max_attempts: int = 100, +) -> spec.Tensor: + """Generates cropped_image using one of the bboxes randomly distorted. + + See `tf.image.sample_distorted_bounding_box` for more documentation. + + Args: + image_bytes: `Tensor` of binary image data. + rng: a per-example, per-step unique RNG seed. + bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` + where each coordinate is [0, 1) and the coordinates are arranged + as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole + image. + min_object_covered: An optional `float`. Defaults to `0.1`. The cropped + area of the image must contain at least this fraction of any bounding + box supplied. + aspect_ratio_range: An optional list of `float`s. The cropped area of the + image must have an aspect ratio = width / height within this range. + area_range: An optional list of `float`s. The cropped area of the image + must contain a fraction of the supplied image within in this range. + max_attempts: An optional `int`. Number of attempts at generating a cropped + region of the image of the specified constraints. After `max_attempts` + failures, return the entire image. + + Returns: + cropped image `Tensor` + """ + shape = tf.io.extract_jpeg_shape(image_bytes) + bbox_begin, bbox_size, _ = tf.image.stateless_sample_distorted_bounding_box( + shape, + seed=rng, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True, + ) + + # Crop the image to the specified bounding box. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) + image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + return image + + +def resize(image: spec.Tensor, image_size: int) -> spec.Tensor: + """Resizes the image given the image size. + + Args: + image: `Tensor` of image data. + image_size: A size of the image to be reshaped. + + Returns: + Resized image 'Tensor'. 
+ """ + return tf.image.resize( + [image], [image_size, image_size], method=tf.image.ResizeMethod.BICUBIC + )[0] + + +def _at_least_x_are_equal(a: spec.Tensor, b: spec.Tensor, x: float) -> bool: + """At least `x` of `a` and `b` `Tensors` are equal.""" + match = tf.equal(a, b) + match = tf.cast(match, tf.int32) + return tf.greater_equal(tf.reduce_sum(match), x) + + +def _decode_and_random_crop( + image_bytes: spec.Tensor, + rng: spec.RandomState, + image_size: int, + aspect_ratio_range: Tuple[float, float], + area_range: Tuple[float, float], + resize_size: int, +) -> spec.Tensor: + """Make a random crop of image_size.""" + bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) + image = _distorted_bounding_box_crop( + image_bytes, + rng, + bbox, + min_object_covered=0.1, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=10, + ) + original_shape = tf.io.extract_jpeg_shape(image_bytes) + bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) + + image = tf.cond( + bad, + lambda: _decode_and_center_crop(image_bytes, image_size, resize_size), + lambda: resize(image, image_size), + ) + + return image + + +def _decode_and_center_crop( + image_bytes: spec.Tensor, image_size: int, resize_size: int +) -> spec.Tensor: + """Crops to center of image with padding then scales image_size.""" + shape = tf.io.extract_jpeg_shape(image_bytes) + image_height = shape[0] + image_width = shape[1] + + padded_center_crop_size = tf.cast( + ( + (image_size / resize_size) + * tf.cast(tf.minimum(image_height, image_width), tf.float32) + ), + tf.int32, + ) + + offset_height = ((image_height - padded_center_crop_size) + 1) // 2 + offset_width = ((image_width - padded_center_crop_size) + 1) // 2 + crop_window = tf.stack( + [ + offset_height, + offset_width, + padded_center_crop_size, + padded_center_crop_size, + ] + ) + image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + image = resize(image, image_size) + + return image + + +def normalize_image( + image: spec.Tensor, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], +) -> spec.Tensor: + image -= tf.constant(mean_rgb, shape=[1, 1, 3], dtype=image.dtype) + image /= tf.constant(stddev_rgb, shape=[1, 1, 3], dtype=image.dtype) + return image + + +def preprocess_for_train( + image_bytes: spec.Tensor, + rng: spec.RandomState, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + aspect_ratio_range: Tuple[float, float], + area_range: Tuple[float, float], + image_size: int, + resize_size: int, + dtype: tf.DType = tf.float32, + use_randaug: bool = False, + randaug_num_layers: int = 2, + randaug_magnitude: int = 10, +) -> spec.Tensor: + """Preprocesses the given image for training. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + rng: a per-example, per-step unique RNG seed. + dtype: data type of the image. + image_size: image size. + + Returns: + A preprocessed image `Tensor`. 
+ """ + rngs = tf.random.experimental.stateless_split(rng, 3) + + image = _decode_and_random_crop( + image_bytes, + rngs[0], + image_size, + aspect_ratio_range, + area_range, + resize_size, + ) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.stateless_random_flip_left_right(image, seed=rngs[1]) + + if use_randaug: + image = tf.cast(tf.clip_by_value(image, 0, 255), tf.uint8) + image = randaugment.distort_image_with_randaugment( + image, randaug_num_layers, randaug_magnitude, rngs[2] + ) + image = tf.cast(image, tf.float32) + image = normalize_image(image, mean_rgb, stddev_rgb) + image = tf.image.convert_image_dtype(image, dtype=dtype) + return image + + +def preprocess_for_eval( + image_bytes: spec.Tensor, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + image_size: int, + resize_size: int, + dtype: tf.DType = tf.float32, +) -> spec.Tensor: + """Preprocesses the given image for evaluation. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + dtype: data type of the image. + image_size: image size. + + Returns: + A preprocessed image `Tensor`. + """ + image = _decode_and_center_crop(image_bytes, image_size, resize_size) + image = tf.reshape(image, [image_size, image_size, 3]) + image = normalize_image(image, mean_rgb, stddev_rgb) + image = tf.image.convert_image_dtype(image, dtype=dtype) + return image + + +# Modified from +# github.com/google/init2winit/blob/master/init2winit/dataset_lib/ (cont. below) +# image_preprocessing.py. +def mixup_tf( + key: spec.RandomState, + inputs: spec.Tensor, + targets: spec.Tensor, + alpha: float = 0.2, +) -> Tuple[spec.Tensor, spec.Tensor]: + """Perform mixup https://arxiv.org/abs/1710.09412. + + NOTE: Code taken from https://github.com/google/big_vision with variables + renamed to match `mixup` in this file and logic to synchronize globally. + + Args: + key: The random key to use. + inputs: inputs to mix. + targets: targets to mix. + alpha: the beta/dirichlet concentration parameter, typically 0.1 or 0.2. + + Returns: + Mixed inputs and targets. + """ + key_a = tf.random.experimental.stateless_fold_in(key, 0) + key_b = tf.random.experimental.stateless_fold_in(key_a, 0) + + gamma_a = tf.random.stateless_gamma((1,), key_a, alpha) + gamma_b = tf.random.stateless_gamma((1,), key_b, alpha) + weight = tf.squeeze(gamma_a / (gamma_a + gamma_b)) + # Transform to one-hot targets. + targets = tf.one_hot(targets, 1000) + + inputs = weight * inputs + (1.0 - weight) * tf.roll(inputs, 1, axis=0) + targets = weight * targets + (1.0 - weight) * tf.roll(targets, 1, axis=0) + return inputs, targets + + +def create_split( + split, + dataset_builder, + rng, + global_batch_size, + train, + image_size, + resize_size, + mean_rgb, + stddev_rgb, + cache=False, + repeat_final_dataset=False, + aspect_ratio_range=(0.75, 4.0 / 3.0), + area_range=(0.08, 1.0), + use_mixup=False, + mixup_alpha=0.1, + use_randaug=False, + randaug_num_layers=2, + randaug_magnitude=10, + image_format='NHWC', + threadpool_size=48, +) -> Iterator[Dict[str, spec.Tensor]]: + """Creates a split from the ImageNet dataset using TensorFlow Datasets. + + Args: + image_format: Output image format. 'NHWC' (default, TensorFlow/JAX style) + or 'NCHW' (PyTorch style). When 'NCHW', images are transposed after + batching. 
+ """ + shuffle_rng, preprocess_rng, mixup_rng = jax.random.split(rng, 3) + + def decode_example(example_index, example): + dtype = tf.float32 + if train: + per_step_preprocess_rng = tf.random.experimental.stateless_fold_in( + tf.cast(preprocess_rng, tf.int64), example_index + ) + + image = preprocess_for_train( + example['image'], + per_step_preprocess_rng, + mean_rgb, + stddev_rgb, + aspect_ratio_range, + area_range, + image_size, + resize_size, + dtype, + use_randaug, + randaug_num_layers, + randaug_magnitude, + ) + else: + image = preprocess_for_eval( + example['image'], mean_rgb, stddev_rgb, image_size, resize_size, dtype + ) + return {'inputs': image, 'targets': example['label']} + + ds = dataset_builder.as_dataset( + split=TFDS_SPLIT_NAME[split], + decoders={ + 'image': tfds.decode.SkipDecoding(), + }, + ) + options = tf.data.Options() + options.threading.private_threadpool_size = threadpool_size + ds = ds.with_options(options) + + if cache: + ds = ds.cache() + + if train or split == 'eval_train': + ds = ds.repeat() + ds = ds.shuffle(16 * global_batch_size, seed=shuffle_rng[0]) + + # We call ds.enumerate() to get a globally unique per-example, per-step + # index that we can fold into the RNG seed. + ds = ds.enumerate() + ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) + ds = ds.batch(global_batch_size, drop_remainder=train) + + if use_mixup: + if train: + + def mixup_batch(batch_index, batch): + per_batch_mixup_rng = tf.random.experimental.stateless_fold_in( + mixup_rng, batch_index + ) + (inputs, targets) = mixup_tf( + per_batch_mixup_rng, + batch['inputs'], + batch['targets'], + alpha=mixup_alpha, + ) + batch['inputs'] = inputs + batch['targets'] = targets + return batch + + ds = ds.enumerate().map( + mixup_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) + else: + raise ValueError('Mixup can only be used for the training split.') + + if repeat_final_dataset: + ds = ds.repeat() + + # Transpose to NCHW format if requested (for PyTorch) + if image_format == 'NCHW': + + def transpose_batch(batch): + # [N, H, W, C] -> [N, C, H, W] + batch['inputs'] = tf.transpose(batch['inputs'], [0, 3, 1, 2]) + return batch + + ds = ds.map(transpose_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE) + elif image_format != 'NHWC': + raise ValueError(f"image_format must be 'NHWC' or 'NCHW', got {image_format}") + + ds = ds.prefetch(10) + + return ds + + +def create_input_iter( + split: str, + dataset_builder: tfds.core.dataset_builder.DatasetBuilder, + rng: spec.RandomState, + global_batch_size: int, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + image_size: int, + resize_size: int, + aspect_ratio_range: Tuple[float, float], + area_range: Tuple[float, float], + train: bool, + cache: bool, + repeat_final_dataset: bool, + use_mixup: bool, + mixup_alpha: float, + use_randaug: bool, +) -> Iterator[Dict[str, spec.Tensor]]: + ds = create_split( + split, + dataset_builder, + rng, + global_batch_size, + train=train, + image_size=image_size, + resize_size=resize_size, + mean_rgb=mean_rgb, + stddev_rgb=stddev_rgb, + cache=cache, + repeat_final_dataset=repeat_final_dataset, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + use_mixup=use_mixup, + mixup_alpha=mixup_alpha, + use_randaug=use_randaug, + ) + it = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) + + # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%. 
+  # TODO (kasimbeg): put on device
+  # it = jax_utils.prefetch_to_device(it, 2)
+
+  return iter(it)
diff --git a/algoperf/workloads/imagenet_resnet/randaugment.py b/algoperf/workloads/imagenet_resnet/randaugment.py
new file mode 100644
index 000000000..156beddec
--- /dev/null
+++ b/algoperf/workloads/imagenet_resnet/randaugment.py
@@ -0,0 +1,548 @@
+"""TensorFlow implementation of RandAugment.
+
+Adapted from:
+https://github.com/google/init2winit/blob/master/init2winit/dataset_lib/autoaugment.py.
+"""
+
+import inspect
+import math
+
+import tensorflow as tf
+
+from algoperf.workloads.imagenet_resnet.custom_tf_addons import (
+  rotate_img,
+  transform,
+  translate,
+)
+
+# This signifies the max integer that the controller RNN could predict for the
+# augmentation scheme.
+_MAX_LEVEL = 10.0
+
+
+def blend(image1, image2, factor):
+  """Blend image1 and image2 using 'factor'.
+
+  Factor can be above 0.0. A value of 0.0 means only image1 is used.
+  A value of 1.0 means only image2 is used. A value between 0.0 and
+  1.0 means we linearly interpolate the pixel values between the two
+  images. A value greater than 1.0 "extrapolates" the difference
+  between the two pixel values, and we clip the results to values
+  between 0 and 255.
+
+  Args:
+    image1: An image Tensor of type uint8.
+    image2: An image Tensor of type uint8.
+    factor: A floating point value above 0.0.
+
+  Returns:
+    A blended image Tensor of type uint8.
+  """
+  if factor == 0.0:
+    return tf.convert_to_tensor(image1)
+  if factor == 1.0:
+    return tf.convert_to_tensor(image2)
+
+  image1 = tf.cast(image1, tf.float32)
+  image2 = tf.cast(image2, tf.float32)
+
+  difference = image2 - image1
+  scaled = factor * difference
+
+  # Do addition in float.
+  temp = tf.cast(image1, tf.float32) + scaled
+
+  # Interpolate.
+  if 0.0 < factor < 1.0:
+    # Interpolation means we always stay within 0 and 255.
+    return tf.cast(temp, tf.uint8)
+
+  # Extrapolate.
+  return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8)
+
+
+def cutout(image, pad_size, replace=0):
+  """Apply cutout (https://arxiv.org/abs/1708.04552) to image.
+
+  This operation applies a (2*pad_size x 2*pad_size) mask of zeros to
+  a random location within `image`. The pixel values filled in will be of
+  the value `replace`. The location where the mask will be applied is
+  randomly chosen uniformly over the whole image.
+
+  Args:
+    image: An image Tensor of type uint8.
+    pad_size: Specifies how big the generated zero mask applied to the image
+      will be. The mask will be of size (2*pad_size x 2*pad_size).
+    replace: What pixel value to fill in the image in the area that has
+      the cutout mask applied to it.
+
+  Returns:
+    An image Tensor that is of type uint8.
+  """
+  image_height = tf.shape(image)[0]
+  image_width = tf.shape(image)[1]
+
+  # Sample the center location in the image where the zero mask will be
+  # applied.
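+  # Because the pads below are clamped at the image borders, a center
+  # sampled near an edge produces a partially clipped mask.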
+  cutout_center_height = tf.random.uniform(
+    shape=[], minval=0, maxval=image_height, dtype=tf.int32
+  )
+
+  cutout_center_width = tf.random.uniform(
+    shape=[], minval=0, maxval=image_width, dtype=tf.int32
+  )
+
+  lower_pad = tf.maximum(0, cutout_center_height - pad_size)
+  upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
+  left_pad = tf.maximum(0, cutout_center_width - pad_size)
+  right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
+
+  cutout_shape = [
+    image_height - (lower_pad + upper_pad),
+    image_width - (left_pad + right_pad),
+  ]
+  padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
+  mask = tf.pad(
+    tf.zeros(cutout_shape, dtype=image.dtype), padding_dims, constant_values=1
+  )
+  mask = tf.expand_dims(mask, -1)
+  mask = tf.tile(mask, [1, 1, 3])
+  image = tf.where(
+    tf.equal(mask, 0), tf.ones_like(image, dtype=image.dtype) * replace, image
+  )
+  return image
+
+
+def solarize(image, threshold=128):
+  """Solarize the input image(s)."""
+  return tf.where(image < threshold, image, 255 - image)
+
+
+def solarize_add(image, addition=0, threshold=128):
+  """Additive solarize the input image(s)."""
+  added_image = tf.cast(image, tf.int64) + addition
+  added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8)
+  return tf.where(image < threshold, added_image, image)
+
+
+def color(image, factor):
+  """Equivalent of PIL Color."""
+  degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
+  return blend(degenerate, image, factor)
+
+
+def contrast(image, factor):
+  """Equivalent of PIL Contrast."""
+  degenerate = tf.image.rgb_to_grayscale(image)
+  # Cast before calling tf.histogram.
+  degenerate = tf.cast(degenerate, tf.int32)
+
+  # Compute the grayscale histogram, then compute the mean pixel value,
+  # and create a constant image size of that value. Use that as the
+  # blending degenerate target of the original image.
+  hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
+  mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0
+  degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
+  degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
+  degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
+  return blend(degenerate, image, factor)
+
+
+def brightness(image, factor):
+  """Equivalent of PIL Brightness."""
+  degenerate = tf.zeros_like(image)
+  return blend(degenerate, image, factor)
+
+
+def posterize(image, bits):
+  """Equivalent of PIL Posterize."""
+  shift = 8 - bits
+  return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift)
+
+
+def rotate(image, degrees, replace):
+  """Rotates the image by degrees either clockwise or counterclockwise.
+
+  Args:
+    image: An image Tensor of type uint8.
+    degrees: Float, a scalar angle in degrees to rotate all images by. If
+      degrees is positive the image will be rotated clockwise, otherwise it
+      will be rotated counterclockwise.
+    replace: A one or three value 1D tensor to fill empty pixels caused by
+      the rotate operation.
+
+  Returns:
+    The rotated version of image.
+  """
+  # Convert from degrees to radians.
+  degrees_to_radians = math.pi / 180.0
+  radians = degrees * degrees_to_radians
+
+  # In practice, we should randomize the rotation degrees by flipping
+  # it negatively half the time, but that's done on 'degrees' outside
+  # of the function.
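+  # wrap() adds an all-ones fourth channel so unwrap() can detect the empty
+  # pixels the rotation introduces and fill them with `replace`.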
+  image = rotate_img(wrap(image), radians)
+  return unwrap(image, replace)
+
+
+def translate_x(image, pixels, replace):
+  """Equivalent of PIL Translate in X dimension."""
+  image = translate(wrap(image), [-pixels, 0])
+  return unwrap(image, replace)
+
+
+def translate_y(image, pixels, replace):
+  """Equivalent of PIL Translate in Y dimension."""
+  image = translate(wrap(image), [0, -pixels])
+  return unwrap(image, replace)
+
+
+def shear_x(image, level, replace):
+  """Equivalent of PIL Shearing in X dimension."""
+  # Shear parallel to x axis is a projective transform
+  # with a matrix form of:
+  #   [1 level
+  #    0 1].
+  image = transform(wrap(image), [1.0, level, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
+  return unwrap(image, replace)
+
+
+def shear_y(image, level, replace):
+  """Equivalent of PIL Shearing in Y dimension."""
+  # Shear parallel to y axis is a projective transform
+  # with a matrix form of:
+  #   [1 0
+  #    level 1].
+  image = transform(wrap(image), [1.0, 0.0, 0.0, level, 1.0, 0.0, 0.0, 0.0])
+  return unwrap(image, replace)
+
+
+def autocontrast(image):
+  """Implements Autocontrast function from PIL using TF ops.
+
+  Args:
+    image: A 3D uint8 tensor.
+
+  Returns:
+    The image after it has had autocontrast applied to it and will be of type
+    uint8.
+  """
+
+  def scale_channel(image):
+    """Scale the 2D image using the autocontrast rule."""
+    # A possibly cheaper version can be done using cumsum/unique_with_counts
+    # over the histogram values, rather than iterating over the entire image
+    # to compute mins and maxes.
+    lo = tf.cast(tf.reduce_min(image), tf.float32)
+    hi = tf.cast(tf.reduce_max(image), tf.float32)
+
+    # Scale the image, making the lowest value 0 and the highest value 255.
+    def scale_values(im):
+      scale = 255.0 / (hi - lo)
+      offset = -lo * scale
+      im = tf.cast(im, tf.float32) * scale + offset
+      im = tf.clip_by_value(im, 0.0, 255.0)
+      return tf.cast(im, tf.uint8)
+
+    result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image)
+    return result
+
+  # Assumes RGB for now. Scales each channel independently
+  # and then stacks the result.
+  s1 = scale_channel(image[:, :, 0])
+  s2 = scale_channel(image[:, :, 1])
+  s3 = scale_channel(image[:, :, 2])
+  image = tf.stack([s1, s2, s3], 2)
+  return image
+
+
+def sharpness(image, factor):
+  """Implements Sharpness function from PIL using TF ops."""
+  orig_image = image
+  image = tf.cast(image, tf.float32)
+  # Make image 4D for conv operation.
+  image = tf.expand_dims(image, 0)
+  # SMOOTH PIL Kernel.
+  kernel = (
+    tf.constant(
+      [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1]
+    )
+    / 13.0
+  )
+  # Tile across channel dimension.
+  kernel = tf.tile(kernel, [1, 1, 3, 1])
+  strides = [1, 1, 1, 1]
+  with tf.device('/cpu:0'):
+    # Some augmentation that uses depth-wise conv will cause crashing when
+    # training on GPU.
+    degenerate = tf.nn.depthwise_conv2d(
+      image, kernel, strides, padding='VALID', dilations=[1, 1]
+    )
+  degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
+  degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])
+
+  # For the borders of the resulting image, fill in the values of the
+  # original image.
+  mask = tf.ones_like(degenerate)
+  padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]])
+  padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]])
+  result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)
+
+  # Blend the final result.
+ return blend(result, orig_image, factor) + + +def equalize(image): + """Implements Equalize function from PIL using TF ops.""" + + def scale_channel(im, c): + """Scale the data in the channel to implement equalize.""" + im = tf.cast(im[:, :, c], tf.int32) + # Compute the histogram of the image channel. + histo = tf.histogram_fixed_width(im, [0, 255], nbins=256) + + # For the purposes of computing the step, filter out the nonzeros. + nonzero = tf.where(tf.not_equal(histo, 0)) + nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1]) + step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255 + + def build_lut(histo, step): + # Compute the cumulative sum, shifting by step // 2 + # and then normalization by step. + lut = (tf.cumsum(histo) + (step // 2)) // step + # Shift lut, prepending with 0. + lut = tf.concat([[0], lut[:-1]], 0) + # Clip the counts to be in range. This is done + # in the C code for image.point. + return tf.clip_by_value(lut, 0, 255) + + # If step is zero, return the original image. Otherwise, build + # lut from the full histogram and step and then index from it. + result = tf.cond( + tf.equal(step, 0), + lambda: im, + lambda: tf.gather(build_lut(histo, step), im), + ) + + return tf.cast(result, tf.uint8) + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. + s1 = scale_channel(image, 0) + s2 = scale_channel(image, 1) + s3 = scale_channel(image, 2) + image = tf.stack([s1, s2, s3], 2) + return image + + +def invert(image): + """Inverts the image pixels.""" + image = tf.convert_to_tensor(image) + return 255 - image + + +def wrap(image): + """Returns 'image' with an extra channel set to all 1s.""" + shape = tf.shape(image) + extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype) + extended = tf.concat([image, extended_channel], 2) + return extended + + +def unwrap(image, replace): + """Unwraps an image produced by wrap. + + Where there is a 0 in the last channel for every spatial position, + the rest of the three channels in that spatial dimension are grayed + (set to 128). Operations like translate and shear on a wrapped + Tensor will leave 0s in empty locations. Some transformations look + at the intensity of values to do preprocessing, and we want these + empty pixels to assume the 'average' value, rather than pure black. + + Args: + image: A 3D Image Tensor with 4 channels. + replace: A one or three value 1D tensor to fill empty pixels. + + Returns: + image: A 3D image Tensor with 3 channels. + """ + image_shape = tf.shape(image) + # Flatten the spatial dimensions. + flattened_image = tf.reshape(image, [-1, image_shape[2]]) + + # Find all pixels where the last channel is zero. + alpha_channel = tf.expand_dims(flattened_image[..., 3], axis=-1) + + replace = tf.concat([replace, tf.ones([1], image.dtype)], 0) + + # Where they are zero, fill them in with 'replace'. 
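+ # (The multiply below broadcasts the 4-vector 'replace' across all pixels;
+ # the 1 appended above keeps the alpha entry of replaced pixels non-zero.)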
+ flattened_image = tf.where( + tf.equal(alpha_channel, 0), + tf.ones_like(flattened_image, dtype=image.dtype) * replace, + flattened_image, + ) + + image = tf.reshape(flattened_image, image_shape) + image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3]) + return image + + +NAME_TO_FUNC = { + 'AutoContrast': autocontrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x, + 'TranslateY': translate_y, + 'Cutout': cutout, +} + + +def _randomly_negate_tensor(tensor): + """With 50% prob turn the tensor negative.""" + should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool) + final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor) + return final_tensor + + +def _rotate_level_to_arg(level): + level = (level / _MAX_LEVEL) * 30.0 + level = _randomly_negate_tensor(level) + return (level,) + + +def _enhance_level_to_arg(level): + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + + +def _shear_level_to_arg(level): + level = (level / _MAX_LEVEL) * 0.3 + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level,) + + +def _translate_level_to_arg(level, translate_const): + level = (level / _MAX_LEVEL) * float(translate_const) + # Flip level to negative with 50% chance. + level = _randomly_negate_tensor(level) + return (level,) + + +def level_to_arg(cutout_const, translate_const): + return { + 'AutoContrast': lambda level: (), + 'Equalize': lambda level: (), + 'Invert': lambda level: (), + 'Rotate': _rotate_level_to_arg, + 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),), + 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), + 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'Cutout': lambda level: (int((level / _MAX_LEVEL) * cutout_const),), + 'TranslateX': lambda level: _translate_level_to_arg(level, translate_const), + 'TranslateY': lambda level: _translate_level_to_arg(level, translate_const), + } + + +def _parse_policy_info( + name, prob, level, replace_value, cutout_const, translate_const +): + """Return the function that corresponds to `name` and update `level` param.""" + func = NAME_TO_FUNC[name] + args = level_to_arg(cutout_const, translate_const)[name](level) + + # Check to see if prob is passed into function. This is used for operations + # where we alter bboxes independently. + if 'prob' in inspect.getfullargspec(func)[0]: + args = tuple([prob] + list(args)) + + # Add in replace arg if it is required for the function that is being called. + if 'replace' in inspect.getfullargspec(func)[0]: + # Make sure replace is the final argument + assert 'replace' == inspect.getfullargspec(func)[0][-1] + args = tuple(list(args) + [replace_value]) + + return (func, prob, args) + + +def distort_image_with_randaugment(image, num_layers, magnitude, key): + """Applies the RandAugment policy to `image`. + + RandAugment is from the paper https://arxiv.org/abs/1909.13719, + + Args: + image: `Tensor` of shape [height, width, 3] representing an image. 
+ num_layers: Integer, the number of augmentation transformations to apply + sequentially to an image. Represented as (N) in the paper. Usually best + values will be in the range [1, 3]. + magnitude: Integer, shared magnitude across all augmentation operations. + Represented as (M) in the paper. Best values are usually in the range + [5, 30]. + key: an rng key from tf.random.experimental.stateless_fold_in. + + Returns: + The augmented version of `image`. + """ + replace_value = [128] * 3 + available_ops = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'Posterize', + 'Solarize', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateX', + 'TranslateY', + 'Cutout', + 'SolarizeAdd', + ] + + for layer_num in range(num_layers): + key = tf.random.experimental.stateless_fold_in(key, layer_num) + op_to_select = tf.random.stateless_uniform( + [], seed=key, maxval=len(available_ops), dtype=tf.int32 + ) + random_magnitude = float(magnitude) + with tf.name_scope('randaug_layer_{}'.format(layer_num)): + for i, op_name in enumerate(available_ops): + key = tf.random.experimental.stateless_fold_in(key, i) + prob = tf.random.stateless_uniform( + [], seed=key, minval=0.2, maxval=0.8, dtype=tf.float32 + ) + func, _, args = _parse_policy_info( + op_name, + prob, + random_magnitude, + replace_value, + cutout_const=40, + translate_const=100, + ) + image = tf.cond( + tf.equal(i, op_to_select), + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args + ), + lambda: image, + ) + return image diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 0b32199ba..bb13ed49e 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -301,25 +301,15 @@ def update_params( optimizer_state['scheduler'].step() # Log training metrics - loss, grad_norm, batch_size. 
- if global_step <= 100 or global_step % 500 == 0: + + if global_step % 100 == 0 and workload.metrics_logger is not None: with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 ) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, - global_step, - ) - logging.info( - '%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item(), + workload.metrics_logger.append_scalar_metrics( + {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step ) return (optimizer_state, current_param_container, new_model_state) From 36086241676d17098cc7bdb1073eaa9870648866 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 18 Dec 2025 18:45:13 +0000 Subject: [PATCH 3/5] add debug scripts --- debug/benchmark_dataloader_jax.py | 91 +++++++++++++ debug/benchmark_dataloader_pytorch.py | 133 ++++++++++++++++++ debug/benchmark_dataloaders.fish | 72 ++++++++++ debug/benchmark_model_jax.py | 188 ++++++++++++++++++++++++++ debug/benchmark_model_pytorch.py | 175 ++++++++++++++++++++++++ debug/benchmark_models.fish | 79 +++++++++++ 6 files changed, 738 insertions(+) create mode 100644 debug/benchmark_dataloader_jax.py create mode 100644 debug/benchmark_dataloader_pytorch.py create mode 100755 debug/benchmark_dataloaders.fish create mode 100644 debug/benchmark_model_jax.py create mode 100644 debug/benchmark_model_pytorch.py create mode 100755 debug/benchmark_models.fish diff --git a/debug/benchmark_dataloader_jax.py b/debug/benchmark_dataloader_jax.py new file mode 100644 index 000000000..268ab1646 --- /dev/null +++ b/debug/benchmark_dataloader_jax.py @@ -0,0 +1,91 @@ +"""Benchmark script for JAX ImageNet dataloader.""" + +import time + +import jax +import numpy as np +import tensorflow_datasets as tfds + +from algoperf.workloads.imagenet_resnet import input_pipeline + +# ImageNet constants (same as workload) +TRAIN_MEAN = (0.485 * 255, 0.456 * 255, 0.406 * 255) +TRAIN_STDDEV = (0.229 * 255, 0.224 * 255, 0.225 * 255) +CENTER_CROP_SIZE = 224 +RESIZE_SIZE = 256 +ASPECT_RATIO_RANGE = (0.75, 4.0 / 3.0) +SCALE_RATIO_RANGE = (0.08, 1.0) + + +def main(): + data_dir = '/home/ak4605/algoperf-data/imagenet/jax' + global_batch_size = 1024 + num_batches = 100 + + rng = jax.random.PRNGKey(0) + ds_builder = tfds.builder('imagenet2012:5.1.0', data_dir=data_dir) + + print(f'Creating JAX ImageNet dataloader...') + print(f'Batch size: {global_batch_size}') + print(f'Num devices: {jax.local_device_count()}') + + ds = input_pipeline.create_split( + split='train', + dataset_builder=ds_builder, + rng=rng, + global_batch_size=global_batch_size, + train=True, + image_size=CENTER_CROP_SIZE, + resize_size=RESIZE_SIZE, + mean_rgb=TRAIN_MEAN, + stddev_rgb=TRAIN_STDDEV, + cache=False, + repeat_final_dataset=True, + aspect_ratio_range=ASPECT_RATIO_RANGE, + area_range=SCALE_RATIO_RANGE, + use_mixup=False, + use_randaug=False, + image_format='NHWC', + ) + + ds_iter = iter(ds) + + # Warmup + print('Warming up...') + for i in range(5): + start = time.perf_counter() + batch = next(ds_iter) + end = time.perf_counter() + print(f' Warmup batch {i+1}/5: {(end - start)*1000:.2f}ms') + + print(f"Batch 'inputs' shape: {batch['inputs'].shape}") + + # Benchmark + print(f'Benchmarking {num_batches} batches...') + times = [] + for i in range(num_batches): + start = time.perf_counter() + batch = 
next(ds_iter) + # Force sync by accessing data + _ = np.asarray(batch['inputs'][0, 0, 0, 0]) + end = time.perf_counter() + times.append(end - start) + if (i + 1) % 20 == 0: + print(f' Batch {i+1}/{num_batches}: {times[-1]*1000:.2f}ms') + + times = np.array(times) + print(f'\n=== JAX DataLoader Results ===') + print(f'Mean time per batch: {times.mean()*1000:.2f}ms') + print(f'Std time per batch: {times.std()*1000:.2f}ms') + print(f'Min time per batch: {times.min()*1000:.2f}ms') + print(f'Max time per batch: {times.max()*1000:.2f}ms') + print(f'Throughput: {global_batch_size / times.mean():.2f} images/sec') + + # Print machine-readable results for the fish script + print(f'\n=== RESULTS ===') + print(f'MEAN_MS={times.mean()*1000:.2f}') + print(f'THROUGHPUT={global_batch_size / times.mean():.2f}') + + +if __name__ == '__main__': + main() diff --git a/debug/benchmark_dataloader_pytorch.py b/debug/benchmark_dataloader_pytorch.py new file mode 100644 index 000000000..03b3e2ce0 --- /dev/null +++ b/debug/benchmark_dataloader_pytorch.py @@ -0,0 +1,133 @@ +"""Benchmark script for PyTorch ImageNet dataloader using shared TFDS pipeline.""" + +import time + +import jax +import numpy as np +import tensorflow as tf +tf.config.set_visible_devices([], 'GPU') # Disable TF GPU usage +import tensorflow_datasets as tfds +import torch +import torch.distributed as dist + +from algoperf import pytorch_utils +from algoperf.workloads.imagenet_resnet import input_pipeline + +# ImageNet constants (same as workload) +TRAIN_MEAN = (0.485 * 255, 0.456 * 255, 0.406 * 255) +TRAIN_STDDEV = (0.229 * 255, 0.224 * 255, 0.225 * 255) +CENTER_CROP_SIZE = 224 +RESIZE_SIZE = 256 +ASPECT_RATIO_RANGE = (0.75, 4.0 / 3.0) +SCALE_RATIO_RANGE = (0.08, 1.0) + + +def main(): + USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() + + # Initialize DDP process group + if USE_PYTORCH_DDP: + torch.cuda.set_device(RANK) + dist.init_process_group('nccl') + + data_dir = '/home/ak4605/algoperf-data/imagenet/jax' + global_batch_size = 1024 + num_batches = 100 + + if RANK == 0: + print(f'Creating PyTorch ImageNet dataloader (shared TFDS pipeline)...') + print(f'Batch size: {global_batch_size}') + print(f'Num GPUs: {N_GPUS}') + print(f'USE_PYTORCH_DDP: {USE_PYTORCH_DDP}') + + # Calculate per-device batch size for DDP + if USE_PYTORCH_DDP: + batch_size = global_batch_size // N_GPUS + else: + batch_size = global_batch_size + + if RANK == 0: + print(f'Per-device batch size: {batch_size}') + + rng = jax.random.PRNGKey(0) + ds_builder = tfds.builder('imagenet2012:5.1.0', data_dir=data_dir) + + ds = input_pipeline.create_split( + split='train', + dataset_builder=ds_builder, + rng=rng, + global_batch_size=batch_size, + train=True, + image_size=CENTER_CROP_SIZE, + resize_size=RESIZE_SIZE, + mean_rgb=TRAIN_MEAN, + stddev_rgb=TRAIN_STDDEV, + cache=False, + repeat_final_dataset=True, + aspect_ratio_range=ASPECT_RATIO_RANGE, + area_range=SCALE_RATIO_RANGE, + use_mixup=False, + use_randaug=False, + image_format='NCHW', + threadpool_size=48 if USE_PYTORCH_DDP else 48, + ) + + ds_iter = iter(ds) + + def get_batch(): + batch = next(ds_iter) + inputs = torch.from_numpy(batch['inputs'].numpy()).to(DEVICE) + targets = torch.from_numpy(batch['targets'].numpy()).to(DEVICE, dtype=torch.long) + return {'inputs': inputs, 'targets': targets} + + # Warmup + if RANK == 0: + print('Warming up...') + for i in range(5): + start = time.perf_counter() + batch = get_batch() + end = time.perf_counter() + if RANK == 0: + print(f' Warmup batch {i+1}/5: {(end - 
start)*1000:.2f}ms') + + if RANK == 0: + print(f"Batch 'inputs' shape: {batch['inputs'].shape}") + + # Synchronize before benchmark + if USE_PYTORCH_DDP: + dist.barrier() + + # Benchmark + if RANK == 0: + print(f'Benchmarking {num_batches} batches...') + times = [] + for i in range(num_batches): + if USE_PYTORCH_DDP: + dist.barrier() + start = time.perf_counter() + batch = get_batch() + end = time.perf_counter() + times.append(end - start) + if RANK == 0 and (i + 1) % 20 == 0: + print(f' Batch {i+1}/{num_batches}: {times[-1]*1000:.2f}ms') + + times = np.array(times) + if RANK == 0: + print(f'\n=== PyTorch DataLoader Results ===') + print(f'Mean time per batch: {times.mean()*1000:.2f}ms') + print(f'Std time per batch: {times.std()*1000:.2f}ms') + print(f'Min time per batch: {times.min()*1000:.2f}ms') + print(f'Max time per batch: {times.max()*1000:.2f}ms') + print(f'Throughput: {global_batch_size / times.mean():.2f} images/sec') + + # Print machine-readable results for the fish script + print(f'\n=== RESULTS ===') + print(f'MEAN_MS={times.mean()*1000:.2f}') + print(f'THROUGHPUT={global_batch_size / times.mean():.2f}') + + if USE_PYTORCH_DDP: + dist.destroy_process_group() + + +if __name__ == '__main__': + main() diff --git a/debug/benchmark_dataloaders.fish b/debug/benchmark_dataloaders.fish new file mode 100755 index 000000000..a5e3e53fb --- /dev/null +++ b/debug/benchmark_dataloaders.fish @@ -0,0 +1,72 @@ +#!/usr/bin/env fish + +# Benchmark script to compare JAX vs PyTorch ImageNet dataloaders +# Usage: ./benchmark_dataloaders.fish + +set script_dir (dirname (status filename)) +set pytorch_output "$script_dir/benchmark_dataloader_pytorch.txt" +set jax_output "$script_dir/benchmark_dataloader_jax.txt" + +echo "=============================================" +echo "ImageNet DataLoader Benchmark" +echo "=============================================" +echo "" + +# Run PyTorch benchmark with DDP (4 processes) +echo ">>> Running PyTorch DataLoader Benchmark (DDP with 4 GPUs)..." +echo ">>> Activating conda environment: ap11_torch_latest" +conda activate ap11_torch_latest + +echo ">>> Output will be saved to: $pytorch_output" +torchrun --nproc_per_node=4 --standalone benchmark_dataloader_pytorch.py 2>&1 | tee $pytorch_output +set pytorch_status $status + +if test $pytorch_status -ne 0 + echo "PyTorch benchmark failed with status $pytorch_status" +end + +echo "" + +# Run JAX benchmark +echo ">>> Running JAX DataLoader Benchmark..." 
+echo ">>> Activating conda environment: ap11_jax" +conda activate ap11_jax + +echo ">>> Output will be saved to: $jax_output" +python benchmark_dataloader_jax.py 2>&1 | tee $jax_output +set jax_status $status + +if test $jax_status -ne 0 + echo "JAX benchmark failed with status $jax_status" +end + +echo "" + +# Extract results from output files +function extract_result + set file $argv[1] + set key $argv[2] + grep "^$key=" $file | sed "s/$key=//" +end + +# Parse PyTorch results +set pt_mean_ms (extract_result $pytorch_output "MEAN_MS") +set pt_throughput (extract_result $pytorch_output "THROUGHPUT") + +# Parse JAX results +set jax_mean_ms (extract_result $jax_output "MEAN_MS") +set jax_throughput (extract_result $jax_output "THROUGHPUT") + +echo "=============================================" +echo " RESULTS TABLE" +echo "=============================================" +echo "" +printf "%-25s %15s %15s\n" "" "PyTorch" "JAX" +echo "-------------------------------------------------------------" +printf "%-25s %12s ms %12s ms\n" "Mean Time per Batch" "$pt_mean_ms" "$jax_mean_ms" +printf "%-25s %12s/s %12s/s\n" "Throughput" "$pt_throughput" "$jax_throughput" +echo "-------------------------------------------------------------" +echo "" +echo "Note: Both use shared TFDS/TFRecords input pipeline" +echo " Batch size: 1024 (global)" +echo "" diff --git a/debug/benchmark_model_jax.py b/debug/benchmark_model_jax.py new file mode 100644 index 000000000..afebe1f92 --- /dev/null +++ b/debug/benchmark_model_jax.py @@ -0,0 +1,188 @@ +"""Benchmark script for JAX ImageNet ResNet50 model (forward + backward).""" + +import functools +import time + +import jax +import jax.numpy as jnp +import numpy as np +from flax import linen as nn +from flax.core import pop + +from algoperf import jax_sharding_utils +from algoperf.workloads.imagenet_resnet.imagenet_jax import models + +# Training config +BATCH_SIZE = 1024 +IMAGE_SIZE = 224 +NUM_CLASSES = 1000 +NUM_WARMUP = 10 +NUM_BENCHMARK = 100 + + +def main(): + print(f'=== JAX ResNet50 Model Benchmark ===') + print(f'Batch size: {BATCH_SIZE}') + print(f'Image size: {IMAGE_SIZE}') + print(f'Num devices: {jax.local_device_count()}') + print(f'Devices: {jax.devices()}') + + # Initialize model + print('\nInitializing model...') + rng = jax.random.PRNGKey(0) + model = models.ResNet50(num_classes=NUM_CLASSES, act=nn.relu, dtype=jnp.float32) + + input_shape = (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3) + variables = model.init({'params': rng}, jnp.ones(input_shape, jnp.float32)) + model_state, params = pop(variables, 'params') + + # Replicate params and model_state across devices (like the workload does) + params = jax.tree.map( + lambda x: jax.device_put(x, jax_sharding_utils.get_replicate_sharding()), + params, + ) + model_state = jax.tree.map( + lambda x: jax.device_put(x, jax_sharding_utils.get_replicate_sharding()), + model_state, + ) + + print(f'Model initialized. 
Param count: {sum(p.size for p in jax.tree.leaves(params)):,}') + + # Define forward pass (jit compiled with sharding) + @functools.partial( + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_batch_dim_sharding(), # inputs + ), + out_shardings=( + jax_sharding_utils.get_batch_dim_sharding(), # logits + jax_sharding_utils.get_replicate_sharding(), # new_model_state + ), + ) + def forward_fn(params, model_state, inputs): + variables = {'params': params, **model_state} + logits, new_model_state = model.apply( + variables, + inputs, + update_batch_norm=True, + mutable=['batch_stats'], + ) + return logits, new_model_state + + # Define loss function (called inside train_step, not jitted separately) + def loss_fn(params, model_state, inputs, targets): + variables = {'params': params, **model_state} + logits, new_model_state = model.apply( + variables, + inputs, + update_batch_norm=True, + mutable=['batch_stats'], + ) + one_hot_targets = jax.nn.one_hot(targets, NUM_CLASSES) + per_example_loss = -jnp.sum(one_hot_targets * jax.nn.log_softmax(logits, axis=-1), axis=-1) + loss = jnp.mean(per_example_loss) + return loss, new_model_state + + # Define forward + backward pass (jit compiled with sharding) + @functools.partial( + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_batch_dim_sharding(), # inputs + jax_sharding_utils.get_batch_dim_sharding(), # targets + ), + out_shardings=( + jax_sharding_utils.get_replicate_sharding(), # loss + jax_sharding_utils.get_replicate_sharding(), # grads + jax_sharding_utils.get_replicate_sharding(), # new_model_state + ), + ) + def train_step(params, model_state, inputs, targets): + (loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)( + params, model_state, inputs, targets + ) + return loss, grads, new_model_state + + # Generate random data and shard along batch dimension + print('Generating random data...') + data_rng = jax.random.PRNGKey(42) + inputs = jax.random.normal(data_rng, (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=jnp.float32) + targets = jax.random.randint(jax.random.PRNGKey(43), (BATCH_SIZE,), 0, NUM_CLASSES) + + # Shard inputs along batch dimension + inputs = jax.device_put(inputs, jax_sharding_utils.get_batch_dim_sharding()) + targets = jax.device_put(targets, jax_sharding_utils.get_batch_dim_sharding()) + + print(f'Input sharding: {inputs.sharding}') + + # Warmup forward pass + print(f'\nWarming up forward pass ({NUM_WARMUP} iterations)...') + for i in range(NUM_WARMUP): + start = time.perf_counter() + logits, _ = forward_fn(params, model_state, inputs) + logits.block_until_ready() + end = time.perf_counter() + print(f' Warmup {i+1}/{NUM_WARMUP}: {(end - start)*1000:.2f}ms') + + # Benchmark forward pass + print(f'\nBenchmarking forward pass ({NUM_BENCHMARK} iterations)...') + forward_times = [] + for i in range(NUM_BENCHMARK): + start = time.perf_counter() + logits, _ = forward_fn(params, model_state, inputs) + logits.block_until_ready() + end = time.perf_counter() + forward_times.append(end - start) + if (i + 1) % 20 == 0: + print(f' Batch {i+1}/{NUM_BENCHMARK}: {forward_times[-1]*1000:.2f}ms') + + forward_times = np.array(forward_times) + print(f'\n--- Forward Pass Results ---') + print(f'Mean: {forward_times.mean()*1000:.2f}ms') + print(f'Std: {forward_times.std()*1000:.2f}ms') + 
print(f'Min: {forward_times.min()*1000:.2f}ms') + print(f'Max: {forward_times.max()*1000:.2f}ms') + print(f'Throughput: {BATCH_SIZE / forward_times.mean():.2f} images/sec') + + # Warmup forward + backward pass + print(f'\nWarming up forward+backward pass ({NUM_WARMUP} iterations)...') + for i in range(NUM_WARMUP): + start = time.perf_counter() + loss, grads, new_model_state = train_step(params, model_state, inputs, targets) + loss.block_until_ready() + end = time.perf_counter() + print(f' Warmup {i+1}/{NUM_WARMUP}: {(end - start)*1000:.2f}ms') + + # Benchmark forward + backward pass + print(f'\nBenchmarking forward+backward pass ({NUM_BENCHMARK} iterations)...') + train_times = [] + for i in range(NUM_BENCHMARK): + start = time.perf_counter() + loss, grads, new_model_state = train_step(params, model_state, inputs, targets) + loss.block_until_ready() + end = time.perf_counter() + train_times.append(end - start) + if (i + 1) % 20 == 0: + print(f' Batch {i+1}/{NUM_BENCHMARK}: {train_times[-1]*1000:.2f}ms') + + train_times = np.array(train_times) + print(f'\n--- Forward+Backward Pass Results ---') + print(f'Mean: {train_times.mean()*1000:.2f}ms') + print(f'Std: {train_times.std()*1000:.2f}ms') + print(f'Min: {train_times.min()*1000:.2f}ms') + print(f'Max: {train_times.max()*1000:.2f}ms') + print(f'Throughput: {BATCH_SIZE / train_times.mean():.2f} images/sec') + + # Print machine-readable results for the fish script + print(f'\n=== RESULTS ===') + print(f'FORWARD_MEAN_MS={forward_times.mean()*1000:.2f}') + print(f'FORWARD_THROUGHPUT={BATCH_SIZE / forward_times.mean():.2f}') + print(f'TRAIN_MEAN_MS={train_times.mean()*1000:.2f}') + print(f'TRAIN_THROUGHPUT={BATCH_SIZE / train_times.mean():.2f}') + + +if __name__ == '__main__': + main() diff --git a/debug/benchmark_model_pytorch.py b/debug/benchmark_model_pytorch.py new file mode 100644 index 000000000..02ee4d5a9 --- /dev/null +++ b/debug/benchmark_model_pytorch.py @@ -0,0 +1,175 @@ +"""Benchmark script for PyTorch ImageNet ResNet50 model (forward + backward).""" + +import time + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP + +from algoperf import pytorch_utils +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import resnet50 + +# Training config +BATCH_SIZE = 1024 +IMAGE_SIZE = 224 +NUM_CLASSES = 1000 +NUM_WARMUP = 10 +NUM_BENCHMARK = 100 + + +def main(): + USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() + + # Initialize DDP process group + if USE_PYTORCH_DDP: + torch.cuda.set_device(RANK) + dist.init_process_group('nccl') + + if RANK == 0: + print(f'=== PyTorch ResNet50 Model Benchmark ===') + print(f'Global batch size: {BATCH_SIZE}') + print(f'Image size: {IMAGE_SIZE}') + print(f'Num GPUs: {N_GPUS}') + print(f'USE_PYTORCH_DDP: {USE_PYTORCH_DDP}') + print(f'Device: {DEVICE}') + + # Calculate per-device batch size + if USE_PYTORCH_DDP: + per_device_batch_size = BATCH_SIZE // N_GPUS + else: + per_device_batch_size = BATCH_SIZE + + if RANK == 0: + print(f'Per-device batch size: {per_device_batch_size}') + + # Initialize model + if RANK == 0: + print('\nInitializing model...') + + torch.manual_seed(0) + model = resnet50(act_fnc=torch.nn.ReLU(inplace=True)) + model.to(DEVICE) + + if USE_PYTORCH_DDP: + model = DDP(model, device_ids=[RANK], output_device=RANK) + + param_count = sum(p.numel() for p in model.parameters()) + if RANK == 0: + print(f'Model initialized. 
Param count: {param_count:,}') + + # Compile model with torch.compile for optimized performance + if RANK == 0: + print('Compiling model with torch.compile...') + model = torch.compile(model) + + # Generate random data (NCHW format for PyTorch) + if RANK == 0: + print('Generating random data...') + + torch.manual_seed(42 + RANK) # Different data per rank + inputs = torch.randn(per_device_batch_size, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE) + targets = torch.randint(0, NUM_CLASSES, (per_device_batch_size,), device=DEVICE) + + # Warmup forward pass (includes torch.compile compilation) + if RANK == 0: + print(f'\nWarming up forward pass ({NUM_WARMUP} iterations)...') + + model.eval() + with torch.no_grad(): + for i in range(NUM_WARMUP): + if USE_PYTORCH_DDP: + dist.barrier() + start = time.perf_counter() + logits = model(inputs) + torch.cuda.synchronize() + end = time.perf_counter() + if RANK == 0: + print(f' Warmup {i+1}/{NUM_WARMUP}: {(end - start)*1000:.2f}ms') + + # Benchmark forward pass + if RANK == 0: + print(f'\nBenchmarking forward pass ({NUM_BENCHMARK} iterations)...') + + forward_times = [] + with torch.no_grad(): + for i in range(NUM_BENCHMARK): + if USE_PYTORCH_DDP: + dist.barrier() + start = time.perf_counter() + logits = model(inputs) + torch.cuda.synchronize() + end = time.perf_counter() + forward_times.append(end - start) + if RANK == 0 and (i + 1) % 20 == 0: + print(f' Batch {i+1}/{NUM_BENCHMARK}: {forward_times[-1]*1000:.2f}ms') + + forward_times = np.array(forward_times) + if RANK == 0: + print(f'\n--- Forward Pass Results ---') + print(f'Mean: {forward_times.mean()*1000:.2f}ms') + print(f'Std: {forward_times.std()*1000:.2f}ms') + print(f'Min: {forward_times.min()*1000:.2f}ms') + print(f'Max: {forward_times.max()*1000:.2f}ms') + print(f'Throughput: {BATCH_SIZE / forward_times.mean():.2f} images/sec') + + # Warmup forward + backward pass (includes torch.compile compilation for backward) + if RANK == 0: + print(f'\nWarming up forward+backward pass ({NUM_WARMUP} iterations)...') + + model.train() + for i in range(NUM_WARMUP): + if USE_PYTORCH_DDP: + dist.barrier() + start = time.perf_counter() + logits = model(inputs) + loss = F.cross_entropy(logits, targets) + loss.backward() + torch.cuda.synchronize() + end = time.perf_counter() + model.zero_grad() + if RANK == 0: + print(f' Warmup {i+1}/{NUM_WARMUP}: {(end - start)*1000:.2f}ms') + + # Benchmark forward + backward pass + if RANK == 0: + print(f'\nBenchmarking forward+backward pass ({NUM_BENCHMARK} iterations)...') + + train_times = [] + for i in range(NUM_BENCHMARK): + if USE_PYTORCH_DDP: + dist.barrier() + start = time.perf_counter() + logits = model(inputs) + loss = F.cross_entropy(logits, targets) + loss.backward() + torch.cuda.synchronize() + end = time.perf_counter() + train_times.append(end - start) + model.zero_grad() + if RANK == 0 and (i + 1) % 20 == 0: + print(f' Batch {i+1}/{NUM_BENCHMARK}: {train_times[-1]*1000:.2f}ms') + + train_times = np.array(train_times) + if RANK == 0: + print(f'\n--- Forward+Backward Pass Results ---') + print(f'Mean: {train_times.mean()*1000:.2f}ms') + print(f'Std: {train_times.std()*1000:.2f}ms') + print(f'Min: {train_times.min()*1000:.2f}ms') + print(f'Max: {train_times.max()*1000:.2f}ms') + print(f'Throughput: {BATCH_SIZE / train_times.mean():.2f} images/sec') + + # Print machine-readable results for the fish script + print(f'\n=== RESULTS ===') + print(f'FORWARD_MEAN_MS={forward_times.mean()*1000:.2f}') + print(f'FORWARD_THROUGHPUT={BATCH_SIZE / forward_times.mean():.2f}') + 
print(f'TRAIN_MEAN_MS={train_times.mean()*1000:.2f}')
+ print(f'TRAIN_THROUGHPUT={BATCH_SIZE / train_times.mean():.2f}')
+
+ if USE_PYTORCH_DDP:
+ dist.destroy_process_group()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/debug/benchmark_models.fish b/debug/benchmark_models.fish
new file mode 100755
index 000000000..a298b9963
--- /dev/null
+++ b/debug/benchmark_models.fish
@@ -0,0 +1,79 @@
+#!/usr/bin/env fish
+
+# Benchmark script to compare JAX vs PyTorch ResNet50 model performance
+# Usage: ./benchmark_models.fish
+
+set script_dir (dirname (status filename))
+set pytorch_output "$script_dir/benchmark_model_pytorch.txt"
+set jax_output "$script_dir/benchmark_model_jax.txt"
+
+echo "============================================="
+echo "ResNet50 Model Benchmark (Forward + Backward)"
+echo "============================================="
+echo ""
+
+# Run PyTorch benchmark with DDP (4 processes)
+echo ">>> Running PyTorch Model Benchmark (DDP with 4 GPUs)..."
+echo ">>> Activating conda environment: ap11_torch_latest"
+conda activate ap11_torch_latest
+
+echo ">>> Output will be saved to: $pytorch_output"
+torchrun --nproc_per_node=4 --standalone benchmark_model_pytorch.py 2>&1 | tee $pytorch_output
+set pytorch_status $status
+
+if test $pytorch_status -ne 0
+ echo "PyTorch benchmark failed with status $pytorch_status"
+end
+
+echo ""
+
+# Run JAX benchmark
+echo ">>> Running JAX Model Benchmark..."
+echo ">>> Activating conda environment: ap11_jax"
+conda activate ap11_jax
+
+echo ">>> Output will be saved to: $jax_output"
+python benchmark_model_jax.py 2>&1 | tee $jax_output
+set jax_status $status
+
+if test $jax_status -ne 0
+ echo "JAX benchmark failed with status $jax_status"
+end
+
+echo ""
+
+# Extract results from output files
+function extract_result
+ set file $argv[1]
+ set key $argv[2]
+ grep "^$key=" $file | sed "s/$key=//"
+end
+
+# Parse PyTorch results
+set pt_forward_ms (extract_result $pytorch_output "FORWARD_MEAN_MS")
+set pt_forward_tp (extract_result $pytorch_output "FORWARD_THROUGHPUT")
+set pt_train_ms (extract_result $pytorch_output "TRAIN_MEAN_MS")
+set pt_train_tp (extract_result $pytorch_output "TRAIN_THROUGHPUT")
+
+# Parse JAX results
+set jax_forward_ms (extract_result $jax_output "FORWARD_MEAN_MS")
+set jax_forward_tp (extract_result $jax_output "FORWARD_THROUGHPUT")
+set jax_train_ms (extract_result $jax_output "TRAIN_MEAN_MS")
+set jax_train_tp (extract_result $jax_output "TRAIN_THROUGHPUT")
+
+echo "============================================="
+echo " RESULTS TABLE"
+echo "============================================="
+echo ""
+printf "%-25s %15s %15s\n" "" "PyTorch" "JAX"
+echo "-------------------------------------------------------------"
+printf "%-25s %12s ms %12s ms\n" "Forward Mean" "$pt_forward_ms" "$jax_forward_ms"
+printf "%-25s %12s/s %12s/s\n" "Forward Throughput" "$pt_forward_tp" "$jax_forward_tp"
+echo "-------------------------------------------------------------"
+printf "%-25s %12s ms %12s ms\n" "Train (Fwd+Bwd) Mean" "$pt_train_ms" "$jax_train_ms"
+printf "%-25s %12s/s %12s/s\n" "Train Throughput" "$pt_train_tp" "$jax_train_tp"
+echo "-------------------------------------------------------------"
+echo ""
+echo "Note: PyTorch uses torch.compile, JAX uses jax.jit"
+echo " Both use batch size 1024 (global)"
+echo ""

From cbeb5948f0f01cb0e79d245480ae42f55acf1ad0 Mon Sep 17 00:00:00 2001
From: rka97
Date: Sat, 27 Dec 2025 01:50:55 +0000
Subject: [PATCH 4/5] Move datasets/ to algoperf/datasets (otherwise it raises an error
because we now import the datasets library for the lm workload) --- {datasets => algoperf/datasets}/README.md | 14 ++-- .../datasets}/dataset_setup.py | 6 +- .../datasets}/librispeech_preprocess.py | 2 +- .../datasets}/librispeech_tokenizer.py | 0 .../imagenet_pytorch/workload.py | 5 +- .../imagenet_resnet/input_pipeline.py | 8 ++- .../pytorch_nadamw_full_budget.py | 4 +- debug/benchmark_dataloader_jax.py | 22 +++---- debug/benchmark_dataloader_pytorch.py | 37 ++++++----- debug/benchmark_model_jax.py | 66 ++++++++++++------- debug/benchmark_model_pytorch.py | 50 ++++++++------ docs/GETTING_STARTED.md | 2 +- submission_runner.py | 20 ++++++ 13 files changed, 143 insertions(+), 93 deletions(-) rename {datasets => algoperf/datasets}/README.md (98%) rename {datasets => algoperf/datasets}/dataset_setup.py (99%) rename {datasets => algoperf/datasets}/librispeech_preprocess.py (98%) rename {datasets => algoperf/datasets}/librispeech_tokenizer.py (100%) diff --git a/datasets/README.md b/algoperf/datasets/README.md similarity index 98% rename from datasets/README.md rename to algoperf/datasets/README.md index 1aeb83239..e58194edd 100644 --- a/datasets/README.md +++ b/algoperf/datasets/README.md @@ -24,7 +24,7 @@ This document provides instructions on downloading and preparing all datasets ut *TL;DR to download and prepare a dataset, run `dataset_setup.py`:* ```bash -python3 datasets/dataset_setup.py \ +python3 algoperf/datasets/dataset_setup.py \ --data_dir=~/data \ -- -- @@ -88,7 +88,7 @@ By default, a user will be prompted before any files are deleted. If you do not From `algorithmic-efficiency` run: ```bash -python3 datasets/dataset_setup.py \ +python3 algoperf/datasets/dataset_setup.py \ --data_dir $DATA_DIR \ --ogbg ``` @@ -124,7 +124,7 @@ In total, it should contain 13 files (via `find -type f | wc -l`) for a total of From `algorithmic-efficiency` run: ```bash -python3 datasets/dataset_setup.py \ +python3 algoperf/datasets/dataset_setup.py \ --data_dir $DATA_DIR \ --wmt ``` @@ -194,7 +194,7 @@ you should get an email containing the URLS for "knee_singlecoil_train", "knee_singlecoil_val" and "knee_singlecoil_test". ```bash -python3 datasets/dataset_setup.py \ +python3 algoperf/datasets/dataset_setup.py \ --data_dir $DATA_DIR \ --fastmri \ --fastmri_knee_singlecoil_train_url '' \ @@ -235,7 +235,7 @@ The ImageNet data pipeline differs between the PyTorch and JAX workloads. Therefore, you will have to specify the framework (either `pytorch` or `jax`) through the framework flag. 
```bash -python3 datasets/dataset_setup.py \ +python3 algoperf/datasets/dataset_setup.py \ --data_dir $DATA_DIR \ --imagenet \ --temp_dir $DATA_DIR/tmp \ @@ -349,7 +349,7 @@ In total, it should contain 20 files (via `find -type f | wc -l`) for a total of ### Criteo1TB ```bash -python3 datasets/dataset_setup.py \ +python3 algoperf/datasets/dataset_setup.py \ --data_dir $DATA_DIR \ --temp_dir $DATA_DIR/tmp \ --criteo1tb @@ -378,7 +378,7 @@ In total, it should contain 885 files (via `find -type f | wc -l`) for a total o To download, train a tokenizer and preprocess the librispeech dataset: ```bash -python3 datasets/dataset_setup.py \ +python3 algoperf/datasets/dataset_setup.py \ --data_dir $DATA_DIR \ --temp_dir $DATA_DIR/tmp \ --librispeech diff --git a/datasets/dataset_setup.py b/algoperf/datasets/dataset_setup.py similarity index 99% rename from datasets/dataset_setup.py rename to algoperf/datasets/dataset_setup.py index e110930cd..1da550ccb 100644 --- a/datasets/dataset_setup.py +++ b/algoperf/datasets/dataset_setup.py @@ -56,7 +56,7 @@ Example command: -python3 datasets/dataset_setup.py \ +python3 algoperf/datasets/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -73,8 +73,8 @@ from algoperf.workloads.wmt import tokenizer from algoperf.workloads.wmt.input_pipeline import normalize_feature_names -from datasets import librispeech_preprocess -from datasets import librispeech_tokenizer +from algoperf.datasets import librispeech_preprocess +from algoperf.datasets import librispeech_tokenizer import functools import os diff --git a/datasets/librispeech_preprocess.py b/algoperf/datasets/librispeech_preprocess.py similarity index 98% rename from datasets/librispeech_preprocess.py rename to algoperf/datasets/librispeech_preprocess.py index 1c216db46..7443156a1 100644 --- a/datasets/librispeech_preprocess.py +++ b/algoperf/datasets/librispeech_preprocess.py @@ -14,7 +14,7 @@ from absl import logging from pydub import AudioSegment -from datasets import librispeech_tokenizer +from algoperf.datasets import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets/librispeech_tokenizer.py b/algoperf/datasets/librispeech_tokenizer.py similarity index 100% rename from datasets/librispeech_tokenizer.py rename to algoperf/datasets/librispeech_tokenizer.py diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 541d82165..29ebd4cf8 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -94,7 +94,6 @@ def _build_dataset( batch_size = global_batch_size // N_GPUS else: batch_size = global_batch_size - ds = input_pipeline.create_split( split, @@ -107,7 +106,9 @@ def _build_dataset( mean_rgb=self.train_mean, stddev_rgb=self.train_stddev, cache=not train if cache is None else cache, - repeat_final_dataset=repeat_final_dataset if repeat_final_dataset is not None else train, + repeat_final_dataset=repeat_final_dataset + if repeat_final_dataset is not None + else train, aspect_ratio_range=self.aspect_ratio_range, area_range=self.scale_ratio_range, use_mixup=use_mixup, diff --git a/algoperf/workloads/imagenet_resnet/input_pipeline.py b/algoperf/workloads/imagenet_resnet/input_pipeline.py index 843ed2424..23bbf7027 100644 --- a/algoperf/workloads/imagenet_resnet/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/input_pipeline.py @@ -396,9 +396,13 @@ def 
transpose_batch(batch): batch['inputs'] = tf.transpose(batch['inputs'], [0, 3, 1, 2]) return batch - ds = ds.map(transpose_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE) + ds = ds.map( + transpose_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) elif image_format != 'NHWC': - raise ValueError(f"image_format must be 'NHWC' or 'NCHW', got {image_format}") + raise ValueError( + f"image_format must be 'NHWC' or 'NCHW', got {image_format}" + ) ds = ds.prefetch(10) diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index bb13ed49e..58606037f 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -5,7 +5,6 @@ import torch import torch.distributed.nn as dist_nn -from absl import logging from torch import Tensor from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR @@ -300,8 +299,7 @@ def update_params( optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() - # Log training metrics - loss, grad_norm, batch_size. - + # Log training metrics - loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] diff --git a/debug/benchmark_dataloader_jax.py b/debug/benchmark_dataloader_jax.py index 268ab1646..7abed2907 100644 --- a/debug/benchmark_dataloader_jax.py +++ b/debug/benchmark_dataloader_jax.py @@ -18,14 +18,14 @@ def main(): - data_dir = '/home/ak4605/algoperf-data/imagenet/jax' + data_dir = '/home/ak4605/data/imagenet/jax' global_batch_size = 1024 num_batches = 100 rng = jax.random.PRNGKey(0) ds_builder = tfds.builder('imagenet2012:5.1.0', data_dir=data_dir) - print(f'Creating JAX ImageNet dataloader...') + print('Creating JAX ImageNet dataloader...') print(f'Batch size: {global_batch_size}') print(f'Num devices: {jax.local_device_count()}') @@ -56,7 +56,7 @@ def main(): start = time.perf_counter() batch = next(ds_iter) end = time.perf_counter() - print(f' Warmup batch {i+1}/5: {(end - start)*1000:.2f}ms') + print(f' Warmup batch {i + 1}/5: {(end - start) * 1000:.2f}ms') print(f"Batch 'inputs' shape: {batch['inputs'].shape}") @@ -71,19 +71,19 @@ def main(): end = time.perf_counter() times.append(end - start) if (i + 1) % 20 == 0: - print(f' Batch {i+1}/{num_batches}: {times[-1]*1000:.2f}ms') + print(f' Batch {i + 1}/{num_batches}: {times[-1] * 1000:.2f}ms') times = np.array(times) - print(f'\n=== JAX DataLoader Results ===') - print(f'Mean time per batch: {times.mean()*1000:.2f}ms') - print(f'Std time per batch: {times.std()*1000:.2f}ms') - print(f'Min time per batch: {times.min()*1000:.2f}ms') - print(f'Max time per batch: {times.max()*1000:.2f}ms') + print('\n=== JAX DataLoader Results ===') + print(f'Mean time per batch: {times.mean() * 1000:.2f}ms') + print(f'Std time per batch: {times.std() * 1000:.2f}ms') + print(f'Min time per batch: {times.min() * 1000:.2f}ms') + print(f'Max time per batch: {times.max() * 1000:.2f}ms') print(f'Throughput: {global_batch_size / times.mean():.2f} images/sec') # Print machine-readable results for the fish script - print(f'\n=== RESULTS ===') - print(f'MEAN_MS={times.mean()*1000:.2f}') + print('\n=== RESULTS ===') + print(f'MEAN_MS={times.mean() * 1000:.2f}') print(f'THROUGHPUT={global_batch_size / times.mean():.2f}') diff --git a/debug/benchmark_dataloader_pytorch.py 
b/debug/benchmark_dataloader_pytorch.py index 03b3e2ce0..7e6418c5c 100644 --- a/debug/benchmark_dataloader_pytorch.py +++ b/debug/benchmark_dataloader_pytorch.py @@ -5,13 +5,14 @@ import jax import numpy as np import tensorflow as tf + tf.config.set_visible_devices([], 'GPU') # Disable TF GPU usage -import tensorflow_datasets as tfds -import torch -import torch.distributed as dist +import tensorflow_datasets as tfds # noqa: E402 +import torch # noqa: E402 +import torch.distributed as dist # noqa: E402 -from algoperf import pytorch_utils -from algoperf.workloads.imagenet_resnet import input_pipeline +from algoperf import pytorch_utils # noqa: E402 +from algoperf.workloads.imagenet_resnet import input_pipeline # noqa: E402 # ImageNet constants (same as workload) TRAIN_MEAN = (0.485 * 255, 0.456 * 255, 0.406 * 255) @@ -30,12 +31,12 @@ def main(): torch.cuda.set_device(RANK) dist.init_process_group('nccl') - data_dir = '/home/ak4605/algoperf-data/imagenet/jax' + data_dir = '/home/ak4605/data/imagenet/jax' global_batch_size = 1024 num_batches = 100 if RANK == 0: - print(f'Creating PyTorch ImageNet dataloader (shared TFDS pipeline)...') + print('Creating PyTorch ImageNet dataloader (shared TFDS pipeline)...') print(f'Batch size: {global_batch_size}') print(f'Num GPUs: {N_GPUS}') print(f'USE_PYTORCH_DDP: {USE_PYTORCH_DDP}') @@ -77,7 +78,9 @@ def main(): def get_batch(): batch = next(ds_iter) inputs = torch.from_numpy(batch['inputs'].numpy()).to(DEVICE) - targets = torch.from_numpy(batch['targets'].numpy()).to(DEVICE, dtype=torch.long) + targets = torch.from_numpy(batch['targets'].numpy()).to( + DEVICE, dtype=torch.long + ) return {'inputs': inputs, 'targets': targets} # Warmup @@ -88,7 +91,7 @@ def get_batch(): batch = get_batch() end = time.perf_counter() if RANK == 0: - print(f' Warmup batch {i+1}/5: {(end - start)*1000:.2f}ms') + print(f' Warmup batch {i + 1}/5: {(end - start) * 1000:.2f}ms') if RANK == 0: print(f"Batch 'inputs' shape: {batch['inputs'].shape}") @@ -109,20 +112,20 @@ def get_batch(): end = time.perf_counter() times.append(end - start) if RANK == 0 and (i + 1) % 20 == 0: - print(f' Batch {i+1}/{num_batches}: {times[-1]*1000:.2f}ms') + print(f' Batch {i + 1}/{num_batches}: {times[-1] * 1000:.2f}ms') times = np.array(times) if RANK == 0: - print(f'\n=== PyTorch DataLoader Results ===') - print(f'Mean time per batch: {times.mean()*1000:.2f}ms') - print(f'Std time per batch: {times.std()*1000:.2f}ms') - print(f'Min time per batch: {times.min()*1000:.2f}ms') - print(f'Max time per batch: {times.max()*1000:.2f}ms') + print('\n=== PyTorch DataLoader Results ===') + print(f'Mean time per batch: {times.mean() * 1000:.2f}ms') + print(f'Std time per batch: {times.std() * 1000:.2f}ms') + print(f'Min time per batch: {times.min() * 1000:.2f}ms') + print(f'Max time per batch: {times.max() * 1000:.2f}ms') print(f'Throughput: {global_batch_size / times.mean():.2f} images/sec') # Print machine-readable results for the fish script - print(f'\n=== RESULTS ===') - print(f'MEAN_MS={times.mean()*1000:.2f}') + print('\n=== RESULTS ===') + print(f'MEAN_MS={times.mean() * 1000:.2f}') print(f'THROUGHPUT={global_batch_size / times.mean():.2f}') if USE_PYTORCH_DDP: diff --git a/debug/benchmark_model_jax.py b/debug/benchmark_model_jax.py index afebe1f92..94b9bc943 100644 --- a/debug/benchmark_model_jax.py +++ b/debug/benchmark_model_jax.py @@ -21,7 +21,7 @@ def main(): - print(f'=== JAX ResNet50 Model Benchmark ===') + print('=== JAX ResNet50 Model Benchmark ===') print(f'Batch size: {BATCH_SIZE}') 
print(f'Image size: {IMAGE_SIZE}') print(f'Num devices: {jax.local_device_count()}') @@ -30,7 +30,9 @@ def main(): # Initialize model print('\nInitializing model...') rng = jax.random.PRNGKey(0) - model = models.ResNet50(num_classes=NUM_CLASSES, act=nn.relu, dtype=jnp.float32) + model = models.ResNet50( + num_classes=NUM_CLASSES, act=nn.relu, dtype=jnp.float32 + ) input_shape = (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3) variables = model.init({'params': rng}, jnp.ones(input_shape, jnp.float32)) @@ -46,7 +48,9 @@ def main(): model_state, ) - print(f'Model initialized. Param count: {sum(p.size for p in jax.tree.leaves(params)):,}') + print( + f'Model initialized. Param count: {sum(p.size for p in jax.tree.leaves(params)):,}' + ) # Define forward pass (jit compiled with sharding) @functools.partial( @@ -81,7 +85,9 @@ def loss_fn(params, model_state, inputs, targets): mutable=['batch_stats'], ) one_hot_targets = jax.nn.one_hot(targets, NUM_CLASSES) - per_example_loss = -jnp.sum(one_hot_targets * jax.nn.log_softmax(logits, axis=-1), axis=-1) + per_example_loss = -jnp.sum( + one_hot_targets * jax.nn.log_softmax(logits, axis=-1), axis=-1 + ) loss = jnp.mean(per_example_loss) return loss, new_model_state @@ -109,8 +115,12 @@ def train_step(params, model_state, inputs, targets): # Generate random data and shard along batch dimension print('Generating random data...') data_rng = jax.random.PRNGKey(42) - inputs = jax.random.normal(data_rng, (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=jnp.float32) - targets = jax.random.randint(jax.random.PRNGKey(43), (BATCH_SIZE,), 0, NUM_CLASSES) + inputs = jax.random.normal( + data_rng, (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=jnp.float32 + ) + targets = jax.random.randint( + jax.random.PRNGKey(43), (BATCH_SIZE,), 0, NUM_CLASSES + ) # Shard inputs along batch dimension inputs = jax.device_put(inputs, jax_sharding_utils.get_batch_dim_sharding()) @@ -125,7 +135,7 @@ def train_step(params, model_state, inputs, targets): logits, _ = forward_fn(params, model_state, inputs) logits.block_until_ready() end = time.perf_counter() - print(f' Warmup {i+1}/{NUM_WARMUP}: {(end - start)*1000:.2f}ms') + print(f' Warmup {i + 1}/{NUM_WARMUP}: {(end - start) * 1000:.2f}ms') # Benchmark forward pass print(f'\nBenchmarking forward pass ({NUM_BENCHMARK} iterations)...') @@ -137,50 +147,56 @@ def train_step(params, model_state, inputs, targets): end = time.perf_counter() forward_times.append(end - start) if (i + 1) % 20 == 0: - print(f' Batch {i+1}/{NUM_BENCHMARK}: {forward_times[-1]*1000:.2f}ms') + print( + f' Batch {i + 1}/{NUM_BENCHMARK}: {forward_times[-1] * 1000:.2f}ms' + ) forward_times = np.array(forward_times) - print(f'\n--- Forward Pass Results ---') - print(f'Mean: {forward_times.mean()*1000:.2f}ms') - print(f'Std: {forward_times.std()*1000:.2f}ms') - print(f'Min: {forward_times.min()*1000:.2f}ms') - print(f'Max: {forward_times.max()*1000:.2f}ms') + print('\n--- Forward Pass Results ---') + print(f'Mean: {forward_times.mean() * 1000:.2f}ms') + print(f'Std: {forward_times.std() * 1000:.2f}ms') + print(f'Min: {forward_times.min() * 1000:.2f}ms') + print(f'Max: {forward_times.max() * 1000:.2f}ms') print(f'Throughput: {BATCH_SIZE / forward_times.mean():.2f} images/sec') # Warmup forward + backward pass print(f'\nWarming up forward+backward pass ({NUM_WARMUP} iterations)...') for i in range(NUM_WARMUP): start = time.perf_counter() - loss, grads, new_model_state = train_step(params, model_state, inputs, targets) + loss, grads, new_model_state = train_step( + params, 
model_state, inputs, targets + ) loss.block_until_ready() end = time.perf_counter() - print(f' Warmup {i+1}/{NUM_WARMUP}: {(end - start)*1000:.2f}ms') + print(f' Warmup {i + 1}/{NUM_WARMUP}: {(end - start) * 1000:.2f}ms') # Benchmark forward + backward pass print(f'\nBenchmarking forward+backward pass ({NUM_BENCHMARK} iterations)...') train_times = [] for i in range(NUM_BENCHMARK): start = time.perf_counter() - loss, grads, new_model_state = train_step(params, model_state, inputs, targets) + loss, grads, new_model_state = train_step( + params, model_state, inputs, targets + ) loss.block_until_ready() end = time.perf_counter() train_times.append(end - start) if (i + 1) % 20 == 0: - print(f' Batch {i+1}/{NUM_BENCHMARK}: {train_times[-1]*1000:.2f}ms') + print(f' Batch {i + 1}/{NUM_BENCHMARK}: {train_times[-1] * 1000:.2f}ms') train_times = np.array(train_times) - print(f'\n--- Forward+Backward Pass Results ---') - print(f'Mean: {train_times.mean()*1000:.2f}ms') - print(f'Std: {train_times.std()*1000:.2f}ms') - print(f'Min: {train_times.min()*1000:.2f}ms') - print(f'Max: {train_times.max()*1000:.2f}ms') + print('\n--- Forward+Backward Pass Results ---') + print(f'Mean: {train_times.mean() * 1000:.2f}ms') + print(f'Std: {train_times.std() * 1000:.2f}ms') + print(f'Min: {train_times.min() * 1000:.2f}ms') + print(f'Max: {train_times.max() * 1000:.2f}ms') print(f'Throughput: {BATCH_SIZE / train_times.mean():.2f} images/sec') # Print machine-readable results for the fish script - print(f'\n=== RESULTS ===') - print(f'FORWARD_MEAN_MS={forward_times.mean()*1000:.2f}') + print('\n=== RESULTS ===') + print(f'FORWARD_MEAN_MS={forward_times.mean() * 1000:.2f}') print(f'FORWARD_THROUGHPUT={BATCH_SIZE / forward_times.mean():.2f}') - print(f'TRAIN_MEAN_MS={train_times.mean()*1000:.2f}') + print(f'TRAIN_MEAN_MS={train_times.mean() * 1000:.2f}') print(f'TRAIN_THROUGHPUT={BATCH_SIZE / train_times.mean():.2f}') diff --git a/debug/benchmark_model_pytorch.py b/debug/benchmark_model_pytorch.py index 02ee4d5a9..15aedfe46 100644 --- a/debug/benchmark_model_pytorch.py +++ b/debug/benchmark_model_pytorch.py @@ -28,7 +28,7 @@ def main(): dist.init_process_group('nccl') if RANK == 0: - print(f'=== PyTorch ResNet50 Model Benchmark ===') + print('=== PyTorch ResNet50 Model Benchmark ===') print(f'Global batch size: {BATCH_SIZE}') print(f'Image size: {IMAGE_SIZE}') print(f'Num GPUs: {N_GPUS}') @@ -69,8 +69,12 @@ def main(): print('Generating random data...') torch.manual_seed(42 + RANK) # Different data per rank - inputs = torch.randn(per_device_batch_size, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE) - targets = torch.randint(0, NUM_CLASSES, (per_device_batch_size,), device=DEVICE) + inputs = torch.randn( + per_device_batch_size, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE + ) + targets = torch.randint( + 0, NUM_CLASSES, (per_device_batch_size,), device=DEVICE + ) # Warmup forward pass (includes torch.compile compilation) if RANK == 0: @@ -86,7 +90,7 @@ def main(): torch.cuda.synchronize() end = time.perf_counter() if RANK == 0: - print(f' Warmup {i+1}/{NUM_WARMUP}: {(end - start)*1000:.2f}ms') + print(f' Warmup {i + 1}/{NUM_WARMUP}: {(end - start) * 1000:.2f}ms') # Benchmark forward pass if RANK == 0: @@ -103,15 +107,17 @@ def main(): end = time.perf_counter() forward_times.append(end - start) if RANK == 0 and (i + 1) % 20 == 0: - print(f' Batch {i+1}/{NUM_BENCHMARK}: {forward_times[-1]*1000:.2f}ms') + print( + f' Batch {i + 1}/{NUM_BENCHMARK}: {forward_times[-1] * 1000:.2f}ms' + ) forward_times = np.array(forward_times) if 
RANK == 0: - print(f'\n--- Forward Pass Results ---') - print(f'Mean: {forward_times.mean()*1000:.2f}ms') - print(f'Std: {forward_times.std()*1000:.2f}ms') - print(f'Min: {forward_times.min()*1000:.2f}ms') - print(f'Max: {forward_times.max()*1000:.2f}ms') + print('\n--- Forward Pass Results ---') + print(f'Mean: {forward_times.mean() * 1000:.2f}ms') + print(f'Std: {forward_times.std() * 1000:.2f}ms') + print(f'Min: {forward_times.min() * 1000:.2f}ms') + print(f'Max: {forward_times.max() * 1000:.2f}ms') print(f'Throughput: {BATCH_SIZE / forward_times.mean():.2f} images/sec') # Warmup forward + backward pass (includes torch.compile compilation for backward) @@ -130,11 +136,13 @@ def main(): end = time.perf_counter() model.zero_grad() if RANK == 0: - print(f' Warmup {i+1}/{NUM_WARMUP}: {(end - start)*1000:.2f}ms') + print(f' Warmup {i + 1}/{NUM_WARMUP}: {(end - start) * 1000:.2f}ms') # Benchmark forward + backward pass if RANK == 0: - print(f'\nBenchmarking forward+backward pass ({NUM_BENCHMARK} iterations)...') + print( + f'\nBenchmarking forward+backward pass ({NUM_BENCHMARK} iterations)...' + ) train_times = [] for i in range(NUM_BENCHMARK): @@ -149,22 +157,22 @@ def main(): train_times.append(end - start) model.zero_grad() if RANK == 0 and (i + 1) % 20 == 0: - print(f' Batch {i+1}/{NUM_BENCHMARK}: {train_times[-1]*1000:.2f}ms') + print(f' Batch {i + 1}/{NUM_BENCHMARK}: {train_times[-1] * 1000:.2f}ms') train_times = np.array(train_times) if RANK == 0: - print(f'\n--- Forward+Backward Pass Results ---') - print(f'Mean: {train_times.mean()*1000:.2f}ms') - print(f'Std: {train_times.std()*1000:.2f}ms') - print(f'Min: {train_times.min()*1000:.2f}ms') - print(f'Max: {train_times.max()*1000:.2f}ms') + print('\n--- Forward+Backward Pass Results ---') + print(f'Mean: {train_times.mean() * 1000:.2f}ms') + print(f'Std: {train_times.std() * 1000:.2f}ms') + print(f'Min: {train_times.min() * 1000:.2f}ms') + print(f'Max: {train_times.max() * 1000:.2f}ms') print(f'Throughput: {BATCH_SIZE / train_times.mean():.2f} images/sec') # Print machine-readable results for the fish script - print(f'\n=== RESULTS ===') - print(f'FORWARD_MEAN_MS={forward_times.mean()*1000:.2f}') + print('\n=== RESULTS ===') + print(f'FORWARD_MEAN_MS={forward_times.mean() * 1000:.2f}') print(f'FORWARD_THROUGHPUT={BATCH_SIZE / forward_times.mean():.2f}') - print(f'TRAIN_MEAN_MS={train_times.mean()*1000:.2f}') + print(f'TRAIN_MEAN_MS={train_times.mean() * 1000:.2f}') print(f'TRAIN_THROUGHPUT={BATCH_SIZE / train_times.mean():.2f}') if USE_PYTORCH_DDP: diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md index 0cc286099..eeca5a765 100644 --- a/docs/GETTING_STARTED.md +++ b/docs/GETTING_STARTED.md @@ -191,7 +191,7 @@ Users that wish to customize their images are invited to check and modify the `S ## Download the Data -The workloads in this benchmark use 6 different datasets across 8 workloads. You may choose to download some or all of the datasets as you are developing your submission, but your submission will be scored across all 8 workloads. For instructions on obtaining and setting up the datasets see [datasets/README](/datasets/README.md#dataset-setup). +The workloads in this benchmark use 6 different datasets across 8 workloads. You may choose to download some or all of the datasets as you are developing your submission, but your submission will be scored across all 8 workloads. For instructions on obtaining and setting up the datasets see [datasets/README](/algoperf/datasets/README.md#dataset-setup). 
## Develop your Submission
diff --git a/submission_runner.py b/submission_runner.py
index 552c99b79..0018258e3 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -312,6 +312,8 @@ def train_once(
 'accumulated_logging_time': 0,
 'last_step_end_time': None,
 }
+ # Step time tracking (separate from train_state to avoid checkpoint issues)
+ step_time_ema = None # EMA of step time in milliseconds
 global_step = 0
 eval_results = []
 preemption_count = 0
@@ -410,6 +412,24 @@
 train_step_end_time = get_time()

+ # Calculate step time and update EMA (includes data loading)
+ if train_state['last_step_end_time'] is not None:
+ current_step_time_ms = (
+ train_step_end_time - train_state['last_step_end_time']
+ ) * 1000.0
+ if step_time_ema is None:
+ step_time_ema = current_step_time_ms
+ else:
+ step_time_ema = 0.9 * step_time_ema + 0.1 * current_step_time_ms
+
+ # Log step time every 100 steps
+ # Note: global_step was already incremented, so log under (global_step - 1)
+ if (global_step - 1) % 100 == 0 and workload.metrics_logger is not None:
+ workload.metrics_logger.append_scalar_metrics(
+ {'step_time_ms': step_time_ema if step_time_ema is not None else 0.0},
+ global_step - 1,
+ )
+
 train_state['accumulated_submission_time'] += (
 train_step_end_time - train_state['last_step_end_time']
 )

From 1693b4f678db247ddb97ab9f172b6919bdb94071 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 8 Jan 2026 17:32:53 +0000
Subject: [PATCH 5/5] in docker startup script set imagenet data dir to jax sub dir for shared input pipeline

---
 docker/scripts/startup.sh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh
index 1cd676d2a..3689870b3 100644
--- a/docker/scripts/startup.sh
+++ b/docker/scripts/startup.sh
@@ -226,8 +226,8 @@ fi
 # Set data directory and bucket (bucket is only relevant in internal mode)
 if [[ "${DATASET}" == "imagenet" ]]; then
- DATA_DIR="${ROOT_DATA_DIR}/${DATASET}/${FRAMEWORK}"
- DATA_BUCKET="${ROOT_DATA_BUCKET}/${DATASET}/${FRAMEWORK}"
+ DATA_DIR="${ROOT_DATA_DIR}/${DATASET}/jax"
+ DATA_BUCKET="${ROOT_DATA_BUCKET}/${DATASET}/jax"
 elif [[ ! -z "${DATASET}" ]]; then
 DATA_DIR="${ROOT_DATA_DIR}/${DATASET}"
 DATA_BUCKET="${ROOT_DATA_BUCKET}/${DATASET}"
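
The step-time tracking added to `submission_runner.py` in PATCH 4/5 is an exponential moving average (EMA) of per-step wall-clock time, logged every 100 steps. A minimal standalone sketch of that logic follows; the 0.9/0.1 decay and the 100-step cadence come from the patch, while the toy loop, `fake_train_step`, and `print`-based logging are illustrative stand-ins for the real training loop and `metrics_logger`:

```python
import time

def update_step_time_ema(ema_ms, sample_ms, decay=0.9):
    # First sample initializes the EMA; afterwards blend with decay 0.9,
    # mirroring the step_time_ema update in train_once() above.
    if ema_ms is None:
        return sample_ms
    return decay * ema_ms + (1 - decay) * sample_ms

step_time_ema = None
last_step_end = None
for global_step in range(1, 301):
    time.sleep(0.005)  # stand-in for one training step (illustrative only)
    now = time.perf_counter()
    if last_step_end is not None:
        step_time_ema = update_step_time_ema(
            step_time_ema, (now - last_step_end) * 1000.0
        )
    last_step_end = now
    # The patch logs under (global_step - 1), since global_step has already
    # been incremented by the time the timing code runs.
    if (global_step - 1) % 100 == 0:
        print(f'step {global_step - 1}: step_time_ms={step_time_ema or 0.0:.2f}')
```

With a decay of 0.9, the most recent step contributes 10% of each logged value and older steps decay geometrically, so transient dataloader stalls are smoothed out while sustained slowdowns still show up in the logged metric.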