From 4a9261d4e056b026bce37218cacd26f8bd4259ab Mon Sep 17 00:00:00 2001 From: Boston Walker Date: Wed, 27 Nov 2024 11:33:43 -0500 Subject: [PATCH 1/3] feat: Use csv.DictReader to parse header fields (msto#19) --- dataclass_io/_lib/assertions.py | 5 +++-- dataclass_io/_lib/file.py | 11 ++++++++++- dataclass_io/reader.py | 7 ++++++- dataclass_io/writer.py | 4 ++++ tests/test_reader.py | 27 +++++++++++++++++++++++++++ 5 files changed, 50 insertions(+), 4 deletions(-) diff --git a/dataclass_io/_lib/assertions.py b/dataclass_io/_lib/assertions.py index 36dac4c..fcb8373 100644 --- a/dataclass_io/_lib/assertions.py +++ b/dataclass_io/_lib/assertions.py @@ -100,6 +100,7 @@ def assert_file_header_matches_dataclass( dataclass_type: type[DataclassInstance], delimiter: str, comment_prefix: str, + quoting: int, ) -> None: """ Check that the specified file has a header and its fields match those of the provided dataclass. @@ -107,11 +108,11 @@ def assert_file_header_matches_dataclass( header: FileHeader | None if isinstance(file, Path): with file.open("r") as fin: - header = get_header(fin, delimiter=delimiter, comment_prefix=comment_prefix) + header = get_header(fin, delimiter=delimiter, comment_prefix=comment_prefix, quoting=quoting) else: pos = file.tell() try: - header = get_header(file, delimiter=delimiter, comment_prefix=comment_prefix) + header = get_header(file, delimiter=delimiter, comment_prefix=comment_prefix, quoting=quoting) finally: file.seek(pos) diff --git a/dataclass_io/_lib/file.py b/dataclass_io/_lib/file.py index 1d9ac8c..0a182d2 100644 --- a/dataclass_io/_lib/file.py +++ b/dataclass_io/_lib/file.py @@ -1,3 +1,4 @@ +from csv import DictReader from dataclasses import dataclass from enum import Enum from enum import unique @@ -68,6 +69,7 @@ def get_header( reader: ReadableFileHandle, delimiter: str, comment_prefix: str, + quoting: int, ) -> Optional[FileHeader]: """ Read the header from an open file. 
@@ -85,6 +87,7 @@ def get_header( Args: reader: An open, readable file handle. comment_char: The character which indicates the start of a comment line. + quoting: Quoting style (enum value from Python csv package). Returns: A `FileHeader` containing the field names and any preceding lines. @@ -103,6 +106,12 @@ def get_header( else: return None - fieldnames = line.strip().split(delimiter) + ''' + msto#19 Read header fields + + Use csv.DictReader because RFC4180 is tricky to implement correctly + ''' + header_reader = DictReader([line], delimiter=delimiter, quoting=quoting) + fieldnames = header_reader.fieldnames return FileHeader(preface=preface, fieldnames=fieldnames) diff --git a/dataclass_io/reader.py b/dataclass_io/reader.py index 44ed599..96f3fcc 100644 --- a/dataclass_io/reader.py +++ b/dataclass_io/reader.py @@ -1,3 +1,4 @@ +import csv from contextlib import contextmanager from csv import DictReader from pathlib import Path @@ -27,6 +28,7 @@ def __init__( dataclass_type: type[DataclassInstance], delimiter: str = "\t", comment_prefix: str = "#", + quoting: int = csv.QUOTE_MINIMAL, **kwds: Any, ) -> None: """ @@ -35,6 +37,7 @@ def __init__( dataclass_type: Dataclass type. delimiter: The input file delimiter. comment_prefix: The prefix for any comment/preface rows preceding the header row. + quoting: Quoting style (enum value from Python csv package). dataclass_type: Dataclass type. 
Raises: @@ -46,17 +49,19 @@ def __init__( dataclass_type=dataclass_type, delimiter=delimiter, comment_prefix=comment_prefix, + quoting=quoting, ) self._dataclass_type = dataclass_type self._fin = fin self._header = get_header( - reader=self._fin, delimiter=delimiter, comment_prefix=comment_prefix + reader=self._fin, delimiter=delimiter, comment_prefix=comment_prefix, quoting=quoting ) self._reader = DictReader( f=self._fin, fieldnames=fieldnames(dataclass_type), delimiter=delimiter, + quoting=quoting, ) def __iter__(self) -> "DataclassReader": diff --git a/dataclass_io/writer.py b/dataclass_io/writer.py index 1d71b9a..8522cc2 100644 --- a/dataclass_io/writer.py +++ b/dataclass_io/writer.py @@ -1,3 +1,4 @@ +import csv from contextlib import contextmanager from csv import DictWriter from dataclasses import asdict @@ -126,6 +127,7 @@ def open( overwrite: bool = True, delimiter: str = "\t", comment_prefix: str = "#", + quoting: int = csv.QUOTE_MINIMAL, **kwds: Any, ) -> Iterator["DataclassWriter"]: """ @@ -146,6 +148,7 @@ def open( comment_prefix: The prefix for any comment/preface rows preceding the header row. (This argument is ignored when `mode="write"`. It is used when `mode="append"` to validate that the existing file's header matches the specified dataclass.) + quoting: Quoting style (enum value from Python csv package). **kwds: Additional keyword arguments to be passed to the `DataclassWriter` constructor. 
Yields: @@ -178,6 +181,7 @@ def open( dataclass_type=dataclass_type, delimiter=delimiter, comment_prefix=comment_prefix, + quoting=quoting, ) fout = filepath.open(write_mode.abbreviation) diff --git a/tests/test_reader.py b/tests/test_reader.py index 6e1309e..c7eb482 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -35,3 +35,30 @@ class FakeDataclass: assert isinstance(rows[0], FakeDataclass) assert rows[0].foo == "abc" assert rows[0].bar == 1 + + +def test_read_csv_with_header_quotes(tmp_path: Path) -> None: + """ + Test that having quotes around column names in header row doesn't break anything + https://github.com/msto/dataclass_io/issues/19 + """ + fpath = tmp_path / "test.txt" + + @dataclass + class FakeDataclass: + id: str + title: str + + test_csv = [ + '"id"\t"title"\n', + '"fake"\t"A fake object"\n', + '"also_fake"\t"Another fake object"\n', + ] + + with fpath.open("w") as f: + f.writelines(test_csv) + + # Parse CSV using DataclassReader + with DataclassReader.open(fpath, FakeDataclass) as reader: + for fake_object in reader: + print(fake_object) From cbc2c2a05f9d1f3c4a472edcd3f097baf434f9b6 Mon Sep 17 00:00:00 2001 From: Matt Stone Date: Sun, 8 Dec 2024 20:56:32 -0500 Subject: [PATCH 2/3] feat: pass kwargs through to DictReader and DictWriter --- dataclass_io/_lib/assertions.py | 8 ++++---- dataclass_io/_lib/file.py | 16 ++++++---------- dataclass_io/reader.py | 15 ++++++++------- dataclass_io/writer.py | 18 ++++++++---------- 4 files changed, 26 insertions(+), 31 deletions(-) diff --git a/dataclass_io/_lib/assertions.py b/dataclass_io/_lib/assertions.py index fcb8373..3891110 100644 --- a/dataclass_io/_lib/assertions.py +++ b/dataclass_io/_lib/assertions.py @@ -4,6 +4,7 @@ from os import access from os import stat from pathlib import Path +from typing import Any from dataclass_io._lib.dataclass_extensions import DataclassInstance from dataclass_io._lib.dataclass_extensions import fieldnames @@ -98,9 +99,8 @@ def assert_file_is_appendable( 
def assert_file_header_matches_dataclass( file: Path | ReadableFileHandle, dataclass_type: type[DataclassInstance], - delimiter: str, comment_prefix: str, - quoting: int, + **kwargs: Any, ) -> None: """ Check that the specified file has a header and its fields match those of the provided dataclass. @@ -108,11 +108,11 @@ def assert_file_header_matches_dataclass( header: FileHeader | None if isinstance(file, Path): with file.open("r") as fin: - header = get_header(fin, delimiter=delimiter, comment_prefix=comment_prefix, quoting=quoting) + header = get_header(reader=fin, comment_prefix=comment_prefix, **kwargs) else: pos = file.tell() try: - header = get_header(file, delimiter=delimiter, comment_prefix=comment_prefix, quoting=quoting) + header = get_header(reader=file, comment_prefix=comment_prefix, **kwargs) finally: file.seek(pos) diff --git a/dataclass_io/_lib/file.py b/dataclass_io/_lib/file.py index 0a182d2..0d7251a 100644 --- a/dataclass_io/_lib/file.py +++ b/dataclass_io/_lib/file.py @@ -67,9 +67,8 @@ class FileHeader: def get_header( reader: ReadableFileHandle, - delimiter: str, comment_prefix: str, - quoting: int, + **kwargs: Any, ) -> Optional[FileHeader]: """ Read the header from an open file. @@ -87,7 +86,7 @@ def get_header( Args: reader: An open, readable file handle. comment_char: The character which indicates the start of a comment line. - quoting: Quoting style (enum value from Python csv package). + **kwargs: Additional keyword arguments to pass to `csv.DictReader`. Returns: A `FileHeader` containing the field names and any preceding lines. 
@@ -106,12 +105,9 @@ def get_header( else: return None - ''' - msto#19 Read header fields - - Use csv.DictReader because RFC4180 is tricky to implement correctly - ''' - header_reader = DictReader([line], delimiter=delimiter, quoting=quoting) - fieldnames = header_reader.fieldnames + # msto#19 Read header fields + # Use csv.DictReader because RFC4180 is tricky to implement correctly + header_reader = DictReader([line], **kwargs) + fieldnames = list(header_reader.fieldnames) return FileHeader(preface=preface, fieldnames=fieldnames) diff --git a/dataclass_io/reader.py b/dataclass_io/reader.py index 96f3fcc..d953694 100644 --- a/dataclass_io/reader.py +++ b/dataclass_io/reader.py @@ -1,4 +1,3 @@ -import csv from contextlib import contextmanager from csv import DictReader from pathlib import Path @@ -26,10 +25,9 @@ def __init__( self, fin: ReadableFileHandle, dataclass_type: type[DataclassInstance], - delimiter: str = "\t", comment_prefix: str = "#", - quoting: int = csv.QUOTE_MINIMAL, - **kwds: Any, + delimiter: str = "\t", + **kwargs: Any, ) -> None: """ Args: @@ -49,19 +47,22 @@ def __init__( dataclass_type=dataclass_type, delimiter=delimiter, comment_prefix=comment_prefix, - quoting=quoting, + **kwargs, ) self._dataclass_type = dataclass_type self._fin = fin self._header = get_header( - reader=self._fin, delimiter=delimiter, comment_prefix=comment_prefix, quoting=quoting + reader=self._fin, + delimiter=delimiter, + comment_prefix=comment_prefix, + **kwargs, ) self._reader = DictReader( f=self._fin, fieldnames=fieldnames(dataclass_type), delimiter=delimiter, - quoting=quoting, + **kwargs, ) def __iter__(self) -> "DataclassReader": diff --git a/dataclass_io/writer.py b/dataclass_io/writer.py index 8522cc2..9f76f5f 100644 --- a/dataclass_io/writer.py +++ b/dataclass_io/writer.py @@ -1,4 +1,3 @@ -import csv from contextlib import contextmanager from csv import DictWriter from dataclasses import asdict @@ -32,7 +31,7 @@ def __init__( include_fields: list[str] | None = 
None, exclude_fields: list[str] | None = None, write_header: bool = True, - **kwds: Any, + **kwargs: Any, ) -> None: """ Args: @@ -66,6 +65,7 @@ def __init__( f=self._fout, fieldnames=self._fieldnames, delimiter=delimiter, + **kwargs, ) # TODO: permit writing comment/preface rows before header @@ -125,10 +125,9 @@ def open( dataclass_type: type[DataclassInstance], mode: str = "write", overwrite: bool = True, - delimiter: str = "\t", comment_prefix: str = "#", - quoting: int = csv.QUOTE_MINIMAL, - **kwds: Any, + delimiter: str = "\t", + **kwargs: Any, ) -> Iterator["DataclassWriter"]: """ Open a new `DataclassWriter` from a file path. @@ -144,12 +143,11 @@ def open( `exclude_fields`. overwrite: If `True`, and `mode="write"`, the file specified at `path` will be overwritten if it exists. - delimiter: The output file delimiter. comment_prefix: The prefix for any comment/preface rows preceding the header row. (This argument is ignored when `mode="write"`. It is used when `mode="append"` to validate that the existing file's header matches the specified dataclass.) - quoting: Quoting style (enum value from Python csv package). - **kwds: Additional keyword arguments to be passed to the `DataclassWriter` constructor. + delimiter: The output file delimiter. + **kwds: Additional keyword arguments to be passed to `csv.DictWriter`. Yields: A `DataclassWriter` instance. 
@@ -181,7 +179,7 @@ def open( dataclass_type=dataclass_type, delimiter=delimiter, comment_prefix=comment_prefix, - quoting=quoting, + **kwargs, ) fout = filepath.open(write_mode.abbreviation) @@ -190,7 +188,7 @@ def open( fout=fout, dataclass_type=dataclass_type, write_header=(write_mode is WriteMode.WRITE), # Skip header when appending - **kwds, + **kwargs, ) finally: fout.close() From e154fbd3b599b0580c8062dd5449d8569a5fad1e Mon Sep 17 00:00:00 2001 From: Matt Stone Date: Sun, 8 Dec 2024 21:03:03 -0500 Subject: [PATCH 3/3] test: Add test data --- tests/conftest.py | 10 ++++++++++ tests/data/reader_should_parse_quotes.tsv | 3 +++ tests/test_reader.py | 19 ++++++------------- 3 files changed, 19 insertions(+), 13 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/data/reader_should_parse_quotes.tsv diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1ad8db4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +from pathlib import Path + +import pytest + + +@pytest.fixture(scope="session") +def datadir() -> Path: + """Path to the test data directory.""" + + return Path(__file__).parent / "data" diff --git a/tests/data/reader_should_parse_quotes.tsv b/tests/data/reader_should_parse_quotes.tsv new file mode 100644 index 0000000..acad7e4 --- /dev/null +++ b/tests/data/reader_should_parse_quotes.tsv @@ -0,0 +1,3 @@ +"id" "title" +"fake" "A fake object" +"also_fake" "Another fake object" diff --git a/tests/test_reader.py b/tests/test_reader.py index c7eb482..9d4228f 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -37,28 +37,21 @@ class FakeDataclass: assert rows[0].bar == 1 -def test_read_csv_with_header_quotes(tmp_path: Path) -> None: +def test_reader_should_parse_quotes(datadir: Path) -> None: """ Test that having quotes around column names in header row doesn't break anything https://github.com/msto/dataclass_io/issues/19 """ - fpath = tmp_path / "test.txt" + fpath = datadir / 
"reader_should_parse_quotes.tsv" @dataclass class FakeDataclass: id: str title: str - test_csv = [ - '"id"\t"title"\n', - '"fake"\t"A fake object"\n', - '"also_fake"\t"Another fake object"\n', - ] - - with fpath.open("w") as f: - f.writelines(test_csv) - # Parse CSV using DataclassReader with DataclassReader.open(fpath, FakeDataclass) as reader: - for fake_object in reader: - print(fake_object) + records = [record for record in reader] + + assert records[0] == FakeDataclass(id="fake", title="A fake object") + assert records[1] == FakeDataclass(id="also_fake", title="Another fake object")