Merged
74 changes: 74 additions & 0 deletions CLAUDE.md
@@ -31,6 +31,8 @@ The project supports different cursor implementations for various use cases:
## Development Guidelines

### Code Style and Quality

#### Commands
```bash
# Format code (auto-fix imports and format)
make fmt
@@ -51,12 +53,67 @@ make tox
make docs
```

#### Docstring Style
Use Google-style docstrings for all public methods and complex internal methods:

```python
def method_name(self, param1: str, param2: Optional[int] = None) -> List[str]:
"""Brief description of what the method does.

Longer description if needed, explaining the method's behavior,
edge cases, or important details.

Args:
param1: Description of the first parameter.
param2: Description of the optional parameter.

Returns:
Description of the return value.

Raises:
ValueError: When invalid parameters are provided.
"""
```

### Testing Requirements

#### General Guidelines
1. **Unit Tests**: All new features must include unit tests
2. **Integration Tests**: Test actual AWS Athena interactions when modifying query execution logic
3. **SQLAlchemy Compliance**: Ensure SQLAlchemy dialect tests pass when modifying dialect code
4. **Mock AWS Services**: Use `moto` or similar for testing AWS interactions without real resources

#### Writing Tests
- Place tests in `tests/pyathena/` mirroring the source structure
- Use pytest fixtures for common setup (see `conftest.py`)
- Test both success and error cases
- For filesystem operations, test edge cases like empty results, missing files, etc.

Example test structure:
```python
def test_find_maxdepth(self, fs):
"""Test find with maxdepth parameter."""
# Setup test data
dir_ = f"s3://{ENV.s3_staging_bucket}/test_path"
fs.touch(f"{dir_}/file0.txt")
fs.touch(f"{dir_}/level1/file1.txt")

# Test maxdepth=0
result = fs.find(dir_, maxdepth=0)
assert len(result) == 1
assert fs._strip_protocol(f"{dir_}/file0.txt") in result

# Test edge cases and error conditions
with pytest.raises(ValueError):
fs.find("s3://", maxdepth=0)
```

#### Test Organization
- Group related tests in classes (e.g., `TestS3FileSystem`)
- Use descriptive test names that explain what is being tested
- Keep tests focused and independent
- Clean up test data after each test when using real AWS resources

### Common Development Tasks

#### Adding a New Feature
@@ -94,6 +151,8 @@ pyathena/
│ └── requirements.py # SQLAlchemy requirements
└── filesystem/ # S3 filesystem abstractions
├── s3.py # S3FileSystem implementation (fsspec compatible)
└── s3_object.py # S3 object representations
```

### Important Implementation Details
@@ -115,6 +174,21 @@ pyathena/
- Follow DB API 2.0 exception hierarchy
- Provide meaningful error messages that include Athena query IDs when available
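
A minimal sketch of the query-ID-bearing error message pattern (the class and attribute names here are illustrative, not pyathena's actual exception hierarchy):

```python
from typing import Optional


class OperationalError(Exception):
    """DB API 2.0-style exception that carries the Athena query id (sketch)."""

    def __init__(self, message: str, query_id: Optional[str] = None) -> None:
        self.query_id = query_id
        if query_id:
            # Embedding the query id lets users look up the failed run
            # in the Athena console or via the AWS CLI.
            message = f"{message} (query_id={query_id})"
        super().__init__(message)


err = OperationalError("Query failed", query_id="abc-123")
```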

#### S3 FileSystem Operations
- `S3FileSystem` implements fsspec's `AbstractFileSystem` interface
- Key methods include `ls()`, `find()`, `get()`, `put()`, `rm()`, etc.
- `find()` method supports:
  - `maxdepth`: Limits directory traversal depth (uses a recursive, level-by-level approach for efficiency)
- `withdirs`: Controls whether directories are included in results (default: False)
- Cache management uses `(path, delimiter)` as key to handle different listing modes
- Always extract reusable logic into helper methods (e.g., `_extract_parent_directories()`)
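
The `(path, delimiter)` cache key described above can be illustrated with a minimal sketch (the class and method names are hypothetical, not the actual `S3FileSystem` internals):

```python
from typing import Dict, List, Optional, Tuple


class ListingCache:
    """Toy delimiter-aware directory cache (illustrative only)."""

    def __init__(self) -> None:
        self._cache: Dict[Tuple[str, str], List[str]] = {}

    def get(self, path: str, delimiter: str) -> Optional[List[str]]:
        # ("bucket/prefix", "/") and ("bucket/prefix", "") are distinct keys,
        # so a shallow (delimited) listing never shadows a recursive one.
        return self._cache.get((path, delimiter))

    def put(self, path: str, delimiter: str, entries: List[str]) -> None:
        self._cache[(path, delimiter)] = entries


cache = ListingCache()
cache.put("bucket/prefix", "/", ["bucket/prefix/dir1"])          # shallow listing
cache.put("bucket/prefix", "", ["bucket/prefix/dir1/file.txt"])  # recursive listing
```

Keying on the delimiter is what lets `ls()` (delimiter `"/"`) and `find()` (empty delimiter) share one cache without returning each other's results.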

When implementing filesystem methods:
1. **Consider s3fs compatibility** - Many users migrate from s3fs, so matching its behavior is important
2. **Optimize for S3's API** - Use an empty delimiter for full recursive listings (one paginated request series) and `delimiter="/"` for level-by-level traversal
3. **Handle edge cases** - Empty paths, trailing slashes, bucket-only paths
4. **Test with real S3** - Mock tests may not catch S3-specific behaviors
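
The edge cases in point 3 might be handled along these lines (a hypothetical helper for illustration; the real implementation uses its own `parse_path()` and `_strip_protocol()` methods):

```python
from typing import Optional, Tuple


def parse_s3_path(path: str) -> Tuple[str, Optional[str]]:
    """Split an S3 URI into (bucket, key), handling common edge cases."""
    if path.startswith("s3://"):
        path = path[len("s3://"):]
    path = path.rstrip("/")  # trailing slash: "bucket/dir/" -> "bucket/dir"
    if not path:
        # Empty or bucket-less path: nothing sensible to list.
        raise ValueError("Cannot traverse all files in S3.")
    bucket, _, key = path.partition("/")
    return bucket, key or None  # bucket-only path yields key=None
```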

### Performance Considerations
1. **Result Caching**: Utilize Athena's result reuse feature (engine v3) when possible
2. **Batch Operations**: Support `executemany()` for bulk operations
208 changes: 157 additions & 51 deletions pyathena/filesystem/s3.py
@@ -250,55 +250,57 @@ def _ls_dirs(
bucket, key, version_id = self.parse_path(path)
if key:
prefix = f"{key}/{prefix if prefix else ''}"
if path not in self.dircache or refresh:
files: List[S3Object] = []
while True:
request: Dict[Any, Any] = {
"Bucket": bucket,
"Prefix": prefix,
"Delimiter": delimiter,
}
if next_token:
request.update({"ContinuationToken": next_token})
if max_keys:
request.update({"MaxKeys": max_keys})
response = self._call(
self._client.list_objects_v2,
**request,
)
files.extend(
S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=c["Prefix"][:-1].rstrip("/"),
version_id=version_id,
)
for c in response.get("CommonPrefixes", [])

# Create a cache key that includes the delimiter
cache_key = (path, delimiter)
if cache_key in self.dircache and not refresh:
return cast(List[S3Object], self.dircache[cache_key])

files: List[S3Object] = []
while True:
request: Dict[Any, Any] = {
"Bucket": bucket,
"Prefix": prefix,
"Delimiter": delimiter,
}
if next_token:
request.update({"ContinuationToken": next_token})
if max_keys:
request.update({"MaxKeys": max_keys})
response = self._call(
self._client.list_objects_v2,
**request,
)
files.extend(
S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=c["Prefix"][:-1].rstrip("/"),
version_id=version_id,
)
files.extend(
S3Object(
init=c,
type=S3ObjectType.S3_OBJECT_TYPE_FILE,
bucket=bucket,
key=c["Key"],
)
for c in response.get("Contents", [])
for c in response.get("CommonPrefixes", [])
)
files.extend(
S3Object(
init=c,
type=S3ObjectType.S3_OBJECT_TYPE_FILE,
bucket=bucket,
key=c["Key"],
)
next_token = response.get("NextContinuationToken")
if not next_token:
break
if files:
self.dircache[path] = files
else:
cache = self.dircache[path]
files = cache if isinstance(cache, list) else [cache]
for c in response.get("Contents", [])
)
next_token = response.get("NextContinuationToken")
if not next_token:
break
if files:
self.dircache[cache_key] = files
return files

def ls(
@@ -396,27 +398,131 @@ def info(self, path: str, **kwargs) -> S3Object:
)
raise FileNotFoundError(path)

def find(
def _extract_parent_directories(
self, files: List[S3Object], bucket: str, base_key: Optional[str]
) -> List[S3Object]:
"""Extract parent directory objects from file paths.

When listing files without delimiter, S3 doesn't return directory entries.
This method creates directory objects by analyzing file paths.

Args:
files: List of S3Object instances representing files.
bucket: S3 bucket name.
base_key: Base key path to calculate relative paths from.

Returns:
List of S3Object instances representing directories.
"""
dirs = set()
base_key = base_key.rstrip("/") if base_key else ""

for f in files:
if f.key and f.type == S3ObjectType.S3_OBJECT_TYPE_FILE:
# Extract directory paths from file paths
f_key = f.key
if base_key and f_key.startswith(base_key + "/"):
relative_path = f_key[len(base_key) + 1 :]
elif not base_key:
relative_path = f_key
else:
continue

# Get all parent directories
parts = relative_path.split("/")
for i in range(1, len(parts)):
if base_key:
dir_path = base_key + "/" + "/".join(parts[:i])
else:
dir_path = "/".join(parts[:i])
dirs.add(dir_path)

# Create S3Object instances for directories
directory_objects = []
for dir_path in dirs:
dir_obj = S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=dir_path,
version_id=None,
)
directory_objects.append(dir_obj)

return directory_objects

def _find(
self,
path: str,
maxdepth: Optional[int] = None,
withdirs: Optional[bool] = None,
detail: bool = False,
**kwargs,
) -> Union[Dict[str, S3Object], List[str]]:
# TODO: Support maxdepth and withdirs
) -> List[S3Object]:
path = self._strip_protocol(path)
if path in ["", "/"]:
raise ValueError("Cannot traverse all files in S3.")
bucket, key, _ = self.parse_path(path)
prefix = kwargs.pop("prefix", "")

# When maxdepth is specified, use a recursive approach with delimiter
if maxdepth is not None:
result: List[S3Object] = []

# List files and directories at current level
current_items = self._ls_dirs(path, prefix=prefix, delimiter="/")

for item in current_items:
if item.type == S3ObjectType.S3_OBJECT_TYPE_FILE:
# Add files
result.append(item)
elif item.type == S3ObjectType.S3_OBJECT_TYPE_DIRECTORY:
# Add directory if withdirs is True
if withdirs:
result.append(item)

# Recursively explore subdirectory if depth allows
if maxdepth > 0:
sub_path = f"s3://{bucket}/{item.key}"
sub_results = self._find(
sub_path, maxdepth=maxdepth - 1, withdirs=withdirs, **kwargs
)
result.extend(sub_results)

return result

# For unlimited depth, use the original approach (get all files at once)
files = self._ls_dirs(path, prefix=prefix, delimiter="")
if not files and key:
try:
files = [self.info(path)]
except FileNotFoundError:
files = []

# If withdirs is True, we need to derive directories from file paths
if withdirs:
files.extend(self._extract_parent_directories(files, bucket, key))

# Filter directories if withdirs is False (default)
if withdirs is False or withdirs is None:
files = [f for f in files if f.type != S3ObjectType.S3_OBJECT_TYPE_DIRECTORY]

return files

def find(
self,
path: str,
maxdepth: Optional[int] = None,
withdirs: Optional[bool] = None,
detail: bool = False,
**kwargs,
) -> Union[Dict[str, S3Object], List[str]]:
files = self._find(path=path, maxdepth=maxdepth, withdirs=withdirs, **kwargs)
if detail:
return {f.name: f for f in files}
return [f.name for f in files]