Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions ffcv/loader/epoch_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from queue import Queue, Full
from contextlib import nullcontext
from typing import Sequence, TYPE_CHECKING
from os import environ

import torch as ch

Expand All @@ -12,13 +13,15 @@

if TYPE_CHECKING:
from .loader import Loader

IS_CUDA = ch.cuda.is_available()

QUASIRANDOM_ERROR_MSG = '''Not enough memory; try setting quasi-random ordering
(`OrderOption.QUASI_RANDOM`) in the dataloader constructor's `order` argument.
'''

ADDITIONAL_BATCHES_AHEAD = int(environ.get('FFCV_ADDITIONAL_BATCHES_AHEAD', "2"))

def select_buffer(buffer, batch_slot, count):
"""Util function to select the relevent subpart of a buffer for a given
batch_slot and batch size"""
Expand Down Expand Up @@ -59,12 +62,11 @@ def __init__(self, loader: 'Loader', order: Sequence[int]):

self.storage_state = self.memory_context.state

self.cuda_streams = [(ch.cuda.Stream() if IS_CUDA else None)
for _ in range(self.loader.batches_ahead + 2)]
self.cuda_streams = self.loader.cuda_streams

self.memory_allocations = self.loader.graph.allocate_memory(
self.loader.batch_size,
self.loader.batches_ahead + 2
self.loader.batches_ahead + ADDITIONAL_BATCHES_AHEAD
)

self.start()
Expand All @@ -80,7 +82,7 @@ def run(self):
ixes = next(self.iter_ixes)
slot = self.current_batch_slot
self.current_batch_slot = (
slot + 1) % (self.loader.batches_ahead + 2)
slot + 1) % (self.loader.batches_ahead + ADDITIONAL_BATCHES_AHEAD)
result = self.run_pipeline(b_ix, ixes, slot, events[slot])
# print("RES", b_ix, "ready")
to_output = (slot, result)
Expand All @@ -101,7 +103,8 @@ def run(self):
# Therefore batch_slot - batch_ahead must have all it's work submitted
# We will record an event of all the work submitted on the main stream
# and make sure no one overwrite the data until they are done
just_finished_slot = (slot - self.loader.batches_ahead - 1) % (self.loader.batches_ahead + 2)
just_finished_slot = ((slot - self.loader.batches_ahead - 1) %
(self.loader.batches_ahead + ADDITIONAL_BATCHES_AHEAD))
# print("JFS", just_finished_slot)
event = ch.cuda.Event()
event.record(self.current_stream)
Expand Down Expand Up @@ -173,4 +176,4 @@ def close(self):

def __del__(self):
self.close()

13 changes: 9 additions & 4 deletions ffcv/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch as ch
import numpy as np

from .epoch_iterator import EpochIterator
from .epoch_iterator import EpochIterator, IS_CUDA, ADDITIONAL_BATCHES_AHEAD
from ..reader import Reader
from ..traversal_order.base import TraversalOrder
from ..traversal_order import Random, Sequential, QuasiRandom
Expand Down Expand Up @@ -137,6 +137,7 @@ def __init__(self,
self.distributed: bool = distributed
self.code = None
self.recompile = recompile
self.cuda_streams = None

if self.num_workers < 1:
self.num_workers = cpu_count()
Expand Down Expand Up @@ -167,7 +168,7 @@ def __init__(self,
self.pipelines = {}
self.pipeline_specs = {}
self.field_name_to_f_ix = {}

custom_pipeline_specs = {}

# Creating PipelineSpec objects from the pipeline dict passed
Expand Down Expand Up @@ -206,7 +207,7 @@ def __init__(self,
self.graph = Graph(self.pipeline_specs, self.reader.handlers,
self.field_name_to_f_ix, self.reader.metadata,
memory_read)

self.generate_code()
self.first_traversal_order = self.next_traversal_order()

Expand All @@ -223,6 +224,10 @@ def __iter__(self):
if self.code is None or self.recompile:
self.generate_code()

if self.cuda_streams is None:
self.cuda_streams = [(ch.cuda.Stream() if IS_CUDA else None)
for _ in range(self.batches_ahead + ADDITIONAL_BATCHES_AHEAD)]

return EpochIterator(self, selected_order)

def filter(self, field_name:str, condition: Callable[[Any], bool]) -> 'Loader':
Expand Down Expand Up @@ -274,5 +279,5 @@ def __len__(self):
def generate_code(self):
queries, code = self.graph.collect_requirements()
self.code = self.graph.codegen_all(code)