From 681e7496d543586b547665145beb16bbf09e84c7 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 19 Jul 2025 12:27:16 +0900 Subject: [PATCH] feat: implement maxdepth and withdirs options for S3FileSystem.find() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add maxdepth parameter to limit directory traversal depth - Uses recursive approach with delimiter="/" for efficiency - Compatible with s3fs behavior - Add withdirs parameter to include directories in results - Default is False (returns only files) - When True, includes directory entries - Improve cache management using (path, delimiter) tuple as key - Add comprehensive tests for both parameters Implementation details: - When maxdepth is specified, uses level-by-level traversal - When withdirs=True and delimiter="", derives directories from file paths - Ensures compatibility with s3fs conventions 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- CLAUDE.md | 74 ++++++++++ pyathena/filesystem/s3.py | 208 ++++++++++++++++++++------- tests/pyathena/filesystem/test_s3.py | 58 +++++++- 3 files changed, 288 insertions(+), 52 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index ded2bd73..cf968ba9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -31,6 +31,8 @@ The project supports different cursor implementations for various use cases: ## Development Guidelines ### Code Style and Quality + +#### Commands ```bash # Format code (auto-fix imports and format) make fmt @@ -51,12 +53,67 @@ make tox make docs ``` +#### Docstring Style +Use Google style docstrings for all public methods and complex internal methods: + +```python +def method_name(self, param1: str, param2: Optional[int] = None) -> List[str]: + """Brief description of what the method does. + + Longer description if needed, explaining the method's behavior, + edge cases, or important details. + + Args: + param1: Description of the first parameter. + param2: Description of the optional parameter. + + Returns: + Description of the return value. + + Raises: + ValueError: When invalid parameters are provided. + """ +``` + ### Testing Requirements + +#### General Guidelines 1. **Unit Tests**: All new features must include unit tests 2. **Integration Tests**: Test actual AWS Athena interactions when modifying query execution logic 3. **SQLAlchemy Compliance**: Ensure SQLAlchemy dialect tests pass when modifying dialect code 4. **Mock AWS Services**: Use `moto` or similar for testing AWS interactions without real resources +#### Writing Tests +- Place tests in `tests/pyathena/` mirroring the source structure +- Use pytest fixtures for common setup (see `conftest.py`) +- Test both success and error cases +- For filesystem operations, test edge cases like empty results, missing files, etc. + +Example test structure: +```python +def test_find_maxdepth(self, fs): + """Test find with maxdepth parameter.""" + # Setup test data + dir_ = f"s3://{ENV.s3_staging_bucket}/test_path" + fs.touch(f"{dir_}/file0.txt") + fs.touch(f"{dir_}/level1/file1.txt") + + # Test maxdepth=0 + result = fs.find(dir_, maxdepth=0) + assert len(result) == 1 + assert fs._strip_protocol(f"{dir_}/file0.txt") in result + + # Test edge cases and error conditions + with pytest.raises(ValueError): + fs.find("s3://", maxdepth=0) +``` + +#### Test Organization +- Group related tests in classes (e.g., `TestS3FileSystem`) +- Use descriptive test names that explain what is being tested +- Keep tests focused and independent +- Clean up test data after each test when using real AWS resources + ### Common Development Tasks #### Adding a New Feature @@ -94,6 +151,8 @@ pyathena/ │ └── requirements.py # SQLAlchemy requirements │ └── filesystem/ # S3 filesystem abstractions + ├── s3.py # S3FileSystem implementation (fsspec compatible) + └── s3_object.py # S3 object representations ``` ### Important Implementation Details @@ -115,6 +174,21 @@ pyathena/ - Follow DB API 2.0 exception hierarchy - Provide meaningful error messages that include Athena query IDs when available +#### S3 FileSystem Operations +- `S3FileSystem` implements fsspec's `AbstractFileSystem` interface +- Key methods include `ls()`, `find()`, `get()`, `put()`, `rm()`, etc. +- `find()` method supports: + - `maxdepth`: Limits directory traversal depth (uses recursive approach for efficiency) + - `withdirs`: Controls whether directories are included in results (default: False) +- Cache management uses `(path, delimiter)` as key to handle different listing modes +- Always extract reusable logic into helper methods (e.g., `_extract_parent_directories()`) + +When implementing filesystem methods: +1. **Consider s3fs compatibility** - Many users migrate from s3fs, so matching its behavior is important +2. **Optimize for S3's API** - Use delimiter="/" for recursive operations to minimize API calls +3. **Handle edge cases** - Empty paths, trailing slashes, bucket-only paths +4. **Test with real S3** - Mock tests may not catch S3-specific behaviors + ### Performance Considerations 1. **Result Caching**: Utilize Athena's result reuse feature (engine v3) when possible 2. **Batch Operations**: Support `executemany()` for bulk operations diff --git a/pyathena/filesystem/s3.py b/pyathena/filesystem/s3.py index 8a062570..de7dff52 100644 --- a/pyathena/filesystem/s3.py +++ b/pyathena/filesystem/s3.py @@ -250,55 +250,57 @@ def _ls_dirs( bucket, key, version_id = self.parse_path(path) if key: prefix = f"{key}/{prefix if prefix else ''}" - if path not in self.dircache or refresh: - files: List[S3Object] = [] - while True: - request: Dict[Any, Any] = { - "Bucket": bucket, - "Prefix": prefix, - "Delimiter": delimiter, - } - if next_token: - request.update({"ContinuationToken": next_token}) - if max_keys: - request.update({"MaxKeys": max_keys}) - response = self._call( - self._client.list_objects_v2, - **request, - ) - files.extend( - S3Object( - init={ - "ContentLength": 0, - "ContentType": None, - "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, - "ETag": None, - "LastModified": None, - }, - type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, - bucket=bucket, - key=c["Prefix"][:-1].rstrip("/"), - version_id=version_id, - ) - for c in response.get("CommonPrefixes", []) + + # Create a cache key that includes the delimiter + cache_key = (path, delimiter) + if cache_key in self.dircache and not refresh: + return cast(List[S3Object], self.dircache[cache_key]) + + files: List[S3Object] = [] + while True: + request: Dict[Any, Any] = { + "Bucket": bucket, + "Prefix": prefix, + "Delimiter": delimiter, + } + if next_token: + request.update({"ContinuationToken": next_token}) + if max_keys: + request.update({"MaxKeys": max_keys}) + response = self._call( + self._client.list_objects_v2, + **request, + ) + files.extend( + S3Object( + init={ + "ContentLength": 0, + "ContentType": None, + "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, + "ETag": None, + "LastModified": None, + }, + type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, + bucket=bucket, + key=c["Prefix"][:-1].rstrip("/"), + version_id=version_id, ) - files.extend( - S3Object( - init=c, - type=S3ObjectType.S3_OBJECT_TYPE_FILE, - bucket=bucket, - key=c["Key"], - ) - for c in response.get("Contents", []) + for c in response.get("CommonPrefixes", []) + ) + files.extend( + S3Object( + init=c, + type=S3ObjectType.S3_OBJECT_TYPE_FILE, + bucket=bucket, + key=c["Key"], ) - next_token = response.get("NextContinuationToken") - if not next_token: - break - if files: - self.dircache[path] = files - else: - cache = self.dircache[path] - files = cache if isinstance(cache, list) else [cache] + for c in response.get("Contents", []) + ) + next_token = response.get("NextContinuationToken") + if not next_token: + break + if files: + self.dircache[cache_key] = files return files def ls( @@ -396,27 +398,131 @@ def info(self, path: str, **kwargs) -> S3Object: ) raise FileNotFoundError(path) - def find( + def _extract_parent_directories( + self, files: List[S3Object], bucket: str, base_key: Optional[str] + ) -> List[S3Object]: + """Extract parent directory objects from file paths. + + When listing files without delimiter, S3 doesn't return directory entries. + This method creates directory objects by analyzing file paths. + + Args: + files: List of S3Object instances representing files. + bucket: S3 bucket name. + base_key: Base key path to calculate relative paths from. + + Returns: + List of S3Object instances representing directories. + """ + dirs = set() + base_key = base_key.rstrip("/") if base_key else "" + + for f in files: + if f.key and f.type == S3ObjectType.S3_OBJECT_TYPE_FILE: + # Extract directory paths from file paths + f_key = f.key + if base_key and f_key.startswith(base_key + "/"): + relative_path = f_key[len(base_key) + 1 :] + elif not base_key: + relative_path = f_key + else: + continue + + # Get all parent directories + parts = relative_path.split("/") + for i in range(1, len(parts)): + if base_key: + dir_path = base_key + "/" + "/".join(parts[:i]) + else: + dir_path = "/".join(parts[:i]) + dirs.add(dir_path) + + # Create S3Object instances for directories + directory_objects = [] + for dir_path in dirs: + dir_obj = S3Object( + init={ + "ContentLength": 0, + "ContentType": None, + "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, + "ETag": None, + "LastModified": None, + }, + type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, + bucket=bucket, + key=dir_path, + version_id=None, + ) + directory_objects.append(dir_obj) + + return directory_objects + + def _find( self, path: str, maxdepth: Optional[int] = None, withdirs: Optional[bool] = None, - detail: bool = False, **kwargs, - ) -> Union[Dict[str, S3Object], List[str]]: - # TODO: Support maxdepth and withdirs + ) -> List[S3Object]: path = self._strip_protocol(path) if path in ["", "/"]: raise ValueError("Cannot traverse all files in S3.") bucket, key, _ = self.parse_path(path) prefix = kwargs.pop("prefix", "") + # When maxdepth is specified, use a recursive approach with delimiter + if maxdepth is not None: + result: List[S3Object] = [] + + # List files and directories at current level + current_items = self._ls_dirs(path, prefix=prefix, delimiter="/") + + for item in current_items: + if item.type == S3ObjectType.S3_OBJECT_TYPE_FILE: + # Add files + result.append(item) + elif item.type == S3ObjectType.S3_OBJECT_TYPE_DIRECTORY: + # Add directory if withdirs is True + if withdirs: + result.append(item) + + # Recursively explore subdirectory if depth allows + if maxdepth > 0: + sub_path = f"s3://{bucket}/{item.key}" + sub_results = self._find( + sub_path, maxdepth=maxdepth - 1, withdirs=withdirs, **kwargs + ) + result.extend(sub_results) + + return result + + # For unlimited depth, use the original approach (get all files at once) files = self._ls_dirs(path, prefix=prefix, delimiter="") if not files and key: try: files = [self.info(path)] except FileNotFoundError: files = [] + + # If withdirs is True, we need to derive directories from file paths + if withdirs: + files.extend(self._extract_parent_directories(files, bucket, key)) + + # Filter directories if withdirs is False (default) + if withdirs is False or withdirs is None: + files = [f for f in files if f.type != S3ObjectType.S3_OBJECT_TYPE_DIRECTORY] + + return files + + def find( + self, + path: str, + maxdepth: Optional[int] = None, + withdirs: Optional[bool] = None, + detail: bool = False, + **kwargs, + ) -> Union[Dict[str, S3Object], List[str]]: + files = self._find(path=path, maxdepth=maxdepth, withdirs=withdirs, **kwargs) if detail: return {f.name: f for f in files} return [f.name for f in files] diff --git a/tests/pyathena/filesystem/test_s3.py b/tests/pyathena/filesystem/test_s3.py index dd514aa7..f737eb17 100644 --- a/tests/pyathena/filesystem/test_s3.py +++ b/tests/pyathena/filesystem/test_s3.py @@ -360,7 +360,6 @@ def test_info_file(self, fs): assert info.version_id == version_id def test_find(self, fs): - # TODO maxdepsth and withdirs options dir_ = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/filesystem/test_find" for i in range(5): fs.pipe(f"{dir_}/prefix/test_{i}", bytes(i)) @@ -384,6 +383,63 @@ def test_find(self, fs): ].name == fs._strip_protocol(f"{dir_}/prefix/test_1") assert test_1_detail[fs._strip_protocol(f"{dir_}/prefix/test_1")].size == 1 + def test_find_maxdepth(self, fs): + dir_ = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/filesystem/test_find_maxdepth" + # Create files at different depths + fs.touch(f"{dir_}/file0.txt") + fs.touch(f"{dir_}/level1/file1.txt") + fs.touch(f"{dir_}/level1/level2/file2.txt") + fs.touch(f"{dir_}/level1/level2/level3/file3.txt") + + # Test maxdepth=0 (only files in the root) + result = fs.find(dir_, maxdepth=0) + assert len(result) == 1 + assert fs._strip_protocol(f"{dir_}/file0.txt") in result + + # Test maxdepth=1 (files in root and level1) + result = fs.find(dir_, maxdepth=1) + assert len(result) == 2 + assert fs._strip_protocol(f"{dir_}/file0.txt") in result + assert fs._strip_protocol(f"{dir_}/level1/file1.txt") in result + + # Test maxdepth=2 (files in root, level1, and level2) + result = fs.find(dir_, maxdepth=2) + assert len(result) == 3 + assert fs._strip_protocol(f"{dir_}/level1/level2/file2.txt") in result + + # Test no maxdepth (all files) + result = fs.find(dir_) + assert len(result) == 4 + + def test_find_withdirs(self, fs): + dir_ = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/filesystem/test_find_withdirs" + # Create directory structure with files + fs.touch(f"{dir_}/file1.txt") + fs.touch(f"{dir_}/subdir1/file2.txt") + fs.touch(f"{dir_}/subdir1/subdir2/file3.txt") + fs.touch(f"{dir_}/subdir3/file4.txt") + + # Test default behavior (withdirs=False) + result = fs.find(dir_) + assert len(result) == 4 # Only files + for r in result: + assert r.endswith(".txt") + + # Test withdirs=True + result = fs.find(dir_, withdirs=True) + assert len(result) > 4 # Files and directories + + # Verify directories are included + dirs = [r for r in result if not r.endswith(".txt")] + assert len(dirs) > 0 + assert any("subdir1" in d for d in dirs) + assert any("subdir2" in d for d in dirs) + assert any("subdir3" in d for d in dirs) + + # Test withdirs=False explicitly + result = fs.find(dir_, withdirs=False) + assert len(result) == 4 # Only files + def test_du(self): # TODO pass