From a3600bd615bca4dc071cd2103ce890c8690b8fc4 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Thu, 6 Apr 2023 17:08:59 -0500 Subject: [PATCH] Reuse existing Cuda streams --- ffcv/loader/epoch_iterator.py | 17 ++++++++++------- ffcv/loader/loader.py | 13 +++++++++---- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/ffcv/loader/epoch_iterator.py b/ffcv/loader/epoch_iterator.py index 1be7cdfa..b50e2b56 100644 --- a/ffcv/loader/epoch_iterator.py +++ b/ffcv/loader/epoch_iterator.py @@ -3,6 +3,7 @@ from queue import Queue, Full from contextlib import nullcontext from typing import Sequence, TYPE_CHECKING +from os import environ import torch as ch @@ -12,13 +13,15 @@ if TYPE_CHECKING: from .loader import Loader - + IS_CUDA = ch.cuda.is_available() QUASIRANDOM_ERROR_MSG = '''Not enough memory; try setting quasi-random ordering (`OrderOption.QUASI_RANDOM`) in the dataloader constructor's `order` argument. ''' +ADDITIONAL_BATCHES_AHEAD = int(environ.get('FFCV_ADDITIONAL_BATCHES_AHEAD', "2")) + def select_buffer(buffer, batch_slot, count): """Util function to select the relevent subpart of a buffer for a given batch_slot and batch size""" @@ -59,12 +62,11 @@ def __init__(self, loader: 'Loader', order: Sequence[int]): self.storage_state = self.memory_context.state - self.cuda_streams = [(ch.cuda.Stream() if IS_CUDA else None) - for _ in range(self.loader.batches_ahead + 2)] + self.cuda_streams = self.loader.cuda_streams self.memory_allocations = self.loader.graph.allocate_memory( self.loader.batch_size, - self.loader.batches_ahead + 2 + self.loader.batches_ahead + ADDITIONAL_BATCHES_AHEAD ) self.start() @@ -80,7 +82,7 @@ def run(self): ixes = next(self.iter_ixes) slot = self.current_batch_slot self.current_batch_slot = ( - slot + 1) % (self.loader.batches_ahead + 2) + slot + 1) % (self.loader.batches_ahead + ADDITIONAL_BATCHES_AHEAD) result = self.run_pipeline(b_ix, ixes, slot, events[slot]) # print("RES", b_ix, "ready") to_output = (slot, result) @@ -101,7 +103,8 @@ def run(self): # Therefore batch_slot - batch_ahead must have all it's work submitted # We will record an event of all the work submitted on the main stream # and make sure no one overwrite the data until they are done - just_finished_slot = (slot - self.loader.batches_ahead - 1) % (self.loader.batches_ahead + 2) + just_finished_slot = ((slot - self.loader.batches_ahead - 1) % + (self.loader.batches_ahead + ADDITIONAL_BATCHES_AHEAD)) # print("JFS", just_finished_slot) event = ch.cuda.Event() event.record(self.current_stream) @@ -173,4 +176,4 @@ def close(self): def __del__(self): self.close() - + diff --git a/ffcv/loader/loader.py b/ffcv/loader/loader.py index 2a9af03e..c1c7dba9 100644 --- a/ffcv/loader/loader.py +++ b/ffcv/loader/loader.py @@ -16,7 +16,7 @@ import torch as ch import numpy as np -from .epoch_iterator import EpochIterator +from .epoch_iterator import EpochIterator, IS_CUDA, ADDITIONAL_BATCHES_AHEAD from ..reader import Reader from ..traversal_order.base import TraversalOrder from ..traversal_order import Random, Sequential, QuasiRandom @@ -137,6 +137,7 @@ def __init__(self, self.distributed: bool = distributed self.code = None self.recompile = recompile + self.cuda_streams = None if self.num_workers < 1: self.num_workers = cpu_count() @@ -167,7 +168,7 @@ def __init__(self, self.pipelines = {} self.pipeline_specs = {} self.field_name_to_f_ix = {} - + custom_pipeline_specs = {} # Creating PipelineSpec objects from the pipeline dict passed @@ -206,7 +207,7 @@ def __init__(self, self.graph = Graph(self.pipeline_specs, self.reader.handlers, self.field_name_to_f_ix, self.reader.metadata, memory_read) - + self.generate_code() self.first_traversal_order = self.next_traversal_order() @@ -223,6 +224,10 @@ def __iter__(self): if self.code is None or self.recompile: self.generate_code() + if self.cuda_streams is None: + self.cuda_streams = [(ch.cuda.Stream() if IS_CUDA else None) + for _ in range(self.batches_ahead + ADDITIONAL_BATCHES_AHEAD)] + return EpochIterator(self, selected_order) def filter(self, field_name:str, condition: Callable[[Any], bool]) -> 'Loader': @@ -274,5 +279,5 @@ def __len__(self): def generate_code(self): queries, code = self.graph.collect_requirements() self.code = self.graph.codegen_all(code) - +