Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions pyathena/aio/s3fs/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pyathena.aio.common import WithAsyncFetch
from pyathena.common import CursorIterator
from pyathena.error import OperationalError, ProgrammingError
from pyathena.filesystem.s3_async import AioS3FileSystem
from pyathena.model import AthenaQueryExecution
from pyathena.s3fs.converter import DefaultS3FSTypeConverter
from pyathena.s3fs.result_set import AthenaS3FSResultSet, CSVReaderType
Expand All @@ -16,11 +17,12 @@


class AioS3FSCursor(WithAsyncFetch):
"""Native asyncio cursor that reads CSV results via S3FileSystem.
"""Native asyncio cursor that reads CSV results via AioS3FileSystem.

Uses ``asyncio.to_thread()`` for result set creation and fetch operations
because ``AthenaS3FSResultSet`` lazily streams rows from S3 via a CSV
reader, making fetch calls blocking I/O.
Uses ``AioS3FileSystem`` for S3 operations, which replaces
``ThreadPoolExecutor`` parallelism with ``asyncio.gather`` +
``asyncio.to_thread``. Fetch operations are wrapped in
``asyncio.to_thread()`` because CSV reading is blocking I/O.

Example:
>>> async with await pyathena.aio_connect(...) as conn:
Expand Down Expand Up @@ -127,6 +129,7 @@ async def execute( # type: ignore[override]
arraysize=self.arraysize,
retry_config=self._retry_config,
csv_reader=self._csv_reader,
filesystem_class=AioS3FileSystem,
**kwargs,
)
else:
Expand Down
53 changes: 31 additions & 22 deletions pyathena/filesystem/s3.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import itertools
import logging
import mimetypes
import os.path
import re
from concurrent.futures import Future, as_completed
from concurrent.futures.thread import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from multiprocessing import cpu_count
Expand All @@ -23,6 +21,7 @@
from fsspec.utils import tokenize

import pyathena
from pyathena.filesystem.s3_executor import S3Executor, S3ThreadPoolExecutor
from pyathena.filesystem.s3_object import (
S3CompleteMultipartUpload,
S3MultipartUpload,
Expand Down Expand Up @@ -686,6 +685,20 @@ def _delete_object(
**request,
)

def _create_executor(self, max_workers: int) -> S3Executor:
    """Build the executor used to fan out parallel S3 operations.

    The default strategy is a thread pool. Subclasses may override this
    hook to substitute a different execution model (for example, an
    asyncio-driven executor).

    Args:
        max_workers: Maximum number of parallel workers.

    Returns:
        An S3Executor instance.
    """
    executor: S3Executor = S3ThreadPoolExecutor(max_workers=max_workers)
    return executor

def _delete_objects(
self, bucket: str, paths: List[str], max_workers: Optional[int] = None, **kwargs
) -> None:
Expand All @@ -703,7 +716,7 @@ def _delete_objects(
object_.update({"VersionId": version_id})
delete_objects.append(object_)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
with self._create_executor(max_workers=max_workers) as executor:
fs = []
for delete in [
delete_objects[i : i + self.DELETE_OBJECTS_MAX_KEYS]
Expand Down Expand Up @@ -861,7 +874,7 @@ def _copy_object_with_multipart_upload(
**kwargs,
)
parts = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
with self._create_executor(max_workers=max_workers) as executor:
fs = [
executor.submit(
self._upload_part_copy,
Expand Down Expand Up @@ -1106,6 +1119,7 @@ def _open(
mode,
version_id=None,
max_workers=max_workers,
executor=self._create_executor(max_workers=max_workers),
block_size=block_size,
cache_type=cache_type,
autocommit=autocommit,
Expand Down Expand Up @@ -1256,6 +1270,7 @@ def __init__(
mode: str = "rb",
version_id: Optional[str] = None,
max_workers: int = (cpu_count() or 1) * 5,
executor: Optional[S3Executor] = None,
block_size: int = S3FileSystem.DEFAULT_BLOCK_SIZE,
cache_type: str = "bytes",
autocommit: bool = True,
Expand All @@ -1265,7 +1280,7 @@ def __init__(
**kwargs,
) -> None:
self.max_workers = max_workers
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._executor: S3Executor = executor or S3ThreadPoolExecutor(max_workers=max_workers)
self.s3_additional_kwargs = s3_additional_kwargs if s3_additional_kwargs else {}

super().__init__(
Expand Down Expand Up @@ -1481,24 +1496,18 @@ def _fetch_range(self, start: int, end: int) -> bytes:
start, end, max_workers=self.max_workers, worker_block_size=self.blocksize
)
if len(ranges) > 1:
object_ = self._merge_objects(
list(
self._executor.map(
lambda bucket, key, ranges, version_id, kwargs: self.fs._get_object(
bucket=bucket,
key=key,
ranges=ranges,
version_id=version_id,
**kwargs,
),
itertools.repeat(self.bucket),
itertools.repeat(self.key),
ranges,
itertools.repeat(self.version_id),
itertools.repeat(self.s3_additional_kwargs),
)
futures = [
self._executor.submit(
self.fs._get_object,
bucket=self.bucket,
key=self.key,
ranges=r,
version_id=self.version_id,
**self.s3_additional_kwargs,
)
)
for r in ranges
]
object_ = self._merge_objects([f.result() for f in as_completed(futures)])
else:
object_ = self.fs._get_object(
self.bucket,
Expand Down
Loading