diff --git a/README.md b/README.md index a0abc6a..cd4758c 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,8 @@ The Python Shapefile Library (PyShp) reads and writes ESRI Shapefiles in pure Py - **Author**: [Joel Lawhead](https://github.com/GeospatialPython) - **Maintainers**: [Karim Bahgat](https://github.com/karimbahgat) -- **Version**: 2.3.1 -- **Date**: 28 July, 2022 +- **Version**: 3.0.0-alpha +- **Date**: 31 July, 2025 - **License**: [MIT](https://github.com/GeospatialPython/pyshp/blob/master/LICENSE.TXT) ## Contents @@ -93,6 +93,30 @@ part of your geospatial project. # Version Changes +## 3.0.0-alpha + +### Breaking Changes: +- Python 2 and Python 3.8 support dropped. +- Field info tuple is now a namedtuple (Field) instead of a list. +- Field type codes are now FieldType enum members. +- bbox, mbox and zbox attributes are all new Namedtuples. +- Writer does not mutate shapes. +- New custom subclasses for each shape type: Null, Multipatch, Point, Polyline, + Multipoint, and Polygon, plus the latter 4's M and Z variants (Reader and + Writer are still compatible with their base class, Shape, as before). +- Shape sub classes are creatable from, and serializable to bytes streams, + as per the shapefile spec. + +### Code quality +- Statically typed, and checked with Mypy +- Checked with Ruff. +- f-strings +- Remove Python 2 specific functions. +- Run doctests against wheels. +- Testing of wheels before publishing them +- pyproject.toml src layout +- Slow test marked. + ## 2.4.0 ### Breaking Change. Support for Python 2 and Pythons <= 3.8 to be dropped. 
@@ -406,7 +430,7 @@ and the bounding box area the shapefile covers: >>> len(sf) 663 >>> sf.bbox - (-122.515048, 37.652916, -122.327622, 37.863433) + BBox(xmin=-122.515048, ymin=37.652916, xmax=-122.327622, ymax=37.863433) Finally, if you would prefer to work with the entire shapefile in a different format, you can convert all of it to a GeoJSON dictionary, although you may lose @@ -553,45 +577,34 @@ in the shp geometry file and the dbf attribute file. The field names of a shapefile are available as soon as you read a shapefile. You can call the "fields" attribute of the shapefile as a Python list. Each -field is a Python list with the following information: +field is a Python namedtuple (Field) with the following information: - * Field name: the name describing the data at this column index. - * Field type: the type of data at this column index. Types can be: + * name: the name describing the data at this column index (a string). + * field_type: a FieldType enum member determining the type of data at this column index. Names can be: * "C": Characters, text. * "N": Numbers, with or without decimals. * "F": Floats (same as "N"). * "L": Logical, for boolean True/False values. * "D": Dates. * "M": Memo, has no meaning within a GIS and is part of the xbase spec instead. - * Field length: the length of the data found at this column index. Older GIS + * size: Field length: the length of the data found at this column index. Older GIS software may truncate this length to 8 or 11 characters for "Character" fields. - * Decimal length: the number of decimal places found in "Number" fields. + * decimal: Decimal length. The number of decimal places found in "Number" fields. + +A new field can be created directly from the type enum member etc., or as follows: + + >>> shapefile.Field.from_unchecked("Population", "N", 10,0) + Field(name="Population", field_type=FieldType.N, size=10, decimal=0) + +Using this method the conversion from string to enum is done automatically. 
To see the fields for the Reader object above (sf) call the "fields" attribute: - >>> fields = sf.fields - - >>> assert fields == [("DeletionFlag", "C", 1, 0), ["AREA", "N", 18, 5], - ... ["BKG_KEY", "C", 12, 0], ["POP1990", "N", 9, 0], ["POP90_SQMI", "N", 10, 1], - ... ["HOUSEHOLDS", "N", 9, 0], - ... ["MALES", "N", 9, 0], ["FEMALES", "N", 9, 0], ["WHITE", "N", 9, 0], - ... ["BLACK", "N", 8, 0], ["AMERI_ES", "N", 7, 0], ["ASIAN_PI", "N", 8, 0], - ... ["OTHER", "N", 8, 0], ["HISPANIC", "N", 8, 0], ["AGE_UNDER5", "N", 8, 0], - ... ["AGE_5_17", "N", 8, 0], ["AGE_18_29", "N", 8, 0], ["AGE_30_49", "N", 8, 0], - ... ["AGE_50_64", "N", 8, 0], ["AGE_65_UP", "N", 8, 0], - ... ["NEVERMARRY", "N", 8, 0], ["MARRIED", "N", 9, 0], ["SEPARATED", "N", 7, 0], - ... ["WIDOWED", "N", 8, 0], ["DIVORCED", "N", 8, 0], ["HSEHLD_1_M", "N", 8, 0], - ... ["HSEHLD_1_F", "N", 8, 0], ["MARHH_CHD", "N", 8, 0], - ... ["MARHH_NO_C", "N", 8, 0], ["MHH_CHILD", "N", 7, 0], - ... ["FHH_CHILD", "N", 7, 0], ["HSE_UNITS", "N", 9, 0], ["VACANT", "N", 7, 0], - ... ["OWNER_OCC", "N", 8, 0], ["RENTER_OCC", "N", 8, 0], - ... ["MEDIAN_VAL", "N", 7, 0], ["MEDIANRENT", "N", 4, 0], - ... ["UNITS_1DET", "N", 8, 0], ["UNITS_1ATT", "N", 7, 0], ["UNITS2", "N", 7, 0], - ... ["UNITS3_9", "N", 8, 0], ["UNITS10_49", "N", 8, 0], - ... 
["UNITS50_UP", "N", 8, 0], ["MOBILEHOME", "N", 7, 0]] + >>> sf.fields + [Field(name="DeletionFlag", field_type=FieldType.C, size=1, decimal=0), Field(name="AREA", field_type=FieldType.N, size=18, decimal=5), Field(name="BKG_KEY", field_type=FieldType.C, size=12, decimal=0), Field(name="POP1990", field_type=FieldType.N, size=9, decimal=0), Field(name="POP90_SQMI", field_type=FieldType.N, size=10, decimal=1), Field(name="HOUSEHOLDS", field_type=FieldType.N, size=9, decimal=0), Field(name="MALES", field_type=FieldType.N, size=9, decimal=0), Field(name="FEMALES", field_type=FieldType.N, size=9, decimal=0), Field(name="WHITE", field_type=FieldType.N, size=9, decimal=0), Field(name="BLACK", field_type=FieldType.N, size=8, decimal=0), Field(name="AMERI_ES", field_type=FieldType.N, size=7, decimal=0), Field(name="ASIAN_PI", field_type=FieldType.N, size=8, decimal=0), Field(name="OTHER", field_type=FieldType.N, size=8, decimal=0), Field(name="HISPANIC", field_type=FieldType.N, size=8, decimal=0), Field(name="AGE_UNDER5", field_type=FieldType.N, size=8, decimal=0), Field(name="AGE_5_17", field_type=FieldType.N, size=8, decimal=0), Field(name="AGE_18_29", field_type=FieldType.N, size=8, decimal=0), Field(name="AGE_30_49", field_type=FieldType.N, size=8, decimal=0), Field(name="AGE_50_64", field_type=FieldType.N, size=8, decimal=0), Field(name="AGE_65_UP", field_type=FieldType.N, size=8, decimal=0), Field(name="NEVERMARRY", field_type=FieldType.N, size=8, decimal=0), Field(name="MARRIED", field_type=FieldType.N, size=9, decimal=0), Field(name="SEPARATED", field_type=FieldType.N, size=7, decimal=0), Field(name="WIDOWED", field_type=FieldType.N, size=8, decimal=0), Field(name="DIVORCED", field_type=FieldType.N, size=8, decimal=0), Field(name="HSEHLD_1_M", field_type=FieldType.N, size=8, decimal=0), Field(name="HSEHLD_1_F", field_type=FieldType.N, size=8, decimal=0), Field(name="MARHH_CHD", field_type=FieldType.N, size=8, decimal=0), Field(name="MARHH_NO_C", 
field_type=FieldType.N, size=8, decimal=0), Field(name="MHH_CHILD", field_type=FieldType.N, size=7, decimal=0), Field(name="FHH_CHILD", field_type=FieldType.N, size=7, decimal=0), Field(name="HSE_UNITS", field_type=FieldType.N, size=9, decimal=0), Field(name="VACANT", field_type=FieldType.N, size=7, decimal=0), Field(name="OWNER_OCC", field_type=FieldType.N, size=8, decimal=0), Field(name="RENTER_OCC", field_type=FieldType.N, size=8, decimal=0), Field(name="MEDIAN_VAL", field_type=FieldType.N, size=7, decimal=0), Field(name="MEDIANRENT", field_type=FieldType.N, size=4, decimal=0), Field(name="UNITS_1DET", field_type=FieldType.N, size=8, decimal=0), Field(name="UNITS_1ATT", field_type=FieldType.N, size=7, decimal=0), Field(name="UNITS2", field_type=FieldType.N, size=7, decimal=0), Field(name="UNITS3_9", field_type=FieldType.N, size=8, decimal=0), Field(name="UNITS10_49", field_type=FieldType.N, size=8, decimal=0), Field(name="UNITS50_UP", field_type=FieldType.N, size=8, decimal=0), Field(name="MOBILEHOME", field_type=FieldType.N, size=7, decimal=0)] The first field of a dbf file is always a 1-byte field called "DeletionFlag", which indicates records that have been deleted but not removed. 
However, @@ -919,8 +932,8 @@ You can also add attributes using keyword arguments where the keys are field nam >>> w = shapefile.Writer('shapefiles/test/dtype') - >>> w.field('FIRST_FLD','C','40') - >>> w.field('SECOND_FLD','C','40') + >>> w.field('FIRST_FLD','C', 40) + >>> w.field('SECOND_FLD','C', 40) >>> w.null() >>> w.null() >>> w.record('First', 'Line') @@ -1375,7 +1388,7 @@ Shapefiles containing M-values can be examined in several ways: >>> r = shapefile.Reader('shapefiles/test/linem') >>> r.mbox # the lower and upper bound of M-values in the shapefile - [0.0, 3.0] + MBox(mmin=0.0, mmax=3.0) >>> r.shape(0).m # flat list of M-values [0.0, None, 3.0, None, 0.0, None, None] @@ -1408,7 +1421,7 @@ To examine a Z-type shapefile you can do: >>> r = shapefile.Reader('shapefiles/test/linez') >>> r.zbox # the lower and upper bound of Z-values in the shapefile - [0.0, 22.0] + ZBox(zmin=0.0, zmax=22.0) >>> r.shape(0).z # flat list of Z-values [18.0, 20.0, 22.0, 0.0, 0.0, 0.0, 0.0, 15.0, 13.0, 14.0] diff --git a/changelog.txt b/changelog.txt index 48a534a..45bfd76 100644 --- a/changelog.txt +++ b/changelog.txt @@ -1,9 +1,20 @@ VERSION 3.0.0-alpha -Python 2 and Python 3.8 support dropped + Breaking Changes: + * Python 2 and Python 3.8 support dropped. + * Field info tuple is now a namedtuple (Field) instead of a list. + * Field type codes are now FieldType enum members. + * bbox, mbox and zbox attributes are all new Namedtuples. + * Writer does not mutate shapes. + * New custom subclasses for each shape type: Null, Multipatch, Point, Polyline, + Multipoint, and Polygon, plus the latter 4's M and Z variants (Reader and + Writer are still compatible with their base class, Shape, as before). + * Shape sub classes are creatable from, and serializable to bytes streams, + as per the shapefile spec. -2025-07-22 Code quality + * Statically typed and checked with Mypy + * Checked with Ruff. * Type hints * f-strings * Remove Python 2 specific functions. 
diff --git a/pyproject.toml b/pyproject.toml index d3e0e89..ca8f667 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,9 @@ classifiers = [ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ] +dependencies = [ + "typing_extensions", +] [project.optional-dependencies] test = ["pytest"] diff --git a/src/shapefile.py b/src/shapefile.py index d0ec177..f002bdd 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -12,6 +12,7 @@ import array import doctest +import enum import io import logging import os @@ -29,7 +30,7 @@ Iterable, Iterator, Literal, - NoReturn, + NamedTuple, Optional, Protocol, Reversible, @@ -43,6 +44,8 @@ from urllib.parse import urlparse, urlunparse from urllib.request import Request, urlopen +from typing_extensions import Never, NotRequired, Self, TypeIs + # Create named logger logger = logging.getLogger(__name__) @@ -119,26 +122,112 @@ PointT = Union[Point2D, PointMT, PointZT] PointsT = list[PointT] -BBox = tuple[float, float, float, float] + +class BBox(NamedTuple): + xmin: float + ymin: float + xmax: float + ymax: float + + +class MBox(NamedTuple): + mmin: Optional[float] + mmax: Optional[float] + + +class ZBox(NamedTuple): + zmin: float + zmax: float + + +class WriteableBinStream(Protocol): + def write(self, b: bytes): ... # pylint: disable=redefined-outer-name + + +class ReadableBinStream(Protocol): + def read(self, size: int = -1): ... -class BinaryWritable(Protocol): - def write(self, data: bytes): ... +class WriteSeekableBinStream(Protocol): + def write(self, b: bytes): ... # pylint: disable=redefined-outer-name + def seek(self, offset: int, whence: int = 0): ... # pylint: disable=unused-argument + def tell(self): ... -class BinaryWritableSeekable(BinaryWritable): - def seek(self, i: int): ... # pylint: disable=unused-argument +class ReadSeekableBinStream(Protocol): + def seek(self, offset: int, whence: int = 0): ... # pylint: disable=unused-argument def tell(self): ... 
+ def read(self, size: int = -1): ... + + +class ReadWriteSeekableBinStream(Protocol): + def write(self, b: bytes): ... # pylint: disable=redefined-outer-name + def seek(self, offset: int, whence: int = 0): ... # pylint: disable=unused-argument + def tell(self): ... + def read(self, size: int = -1): ... # File name, file object or anything with a read() method that returns bytes. BinaryFileT = Union[str, IO[bytes]] -BinaryFileStreamT = Union[IO[bytes], io.BytesIO, BinaryWritableSeekable] +BinaryFileStreamT = Union[IO[bytes], io.BytesIO, WriteSeekableBinStream] + + +# https://en.wikipedia.org/wiki/.dbf#Database_records +class FieldType(enum.Enum): + # Use an ascii-encoded byte of the name, to save a decoding step. + C = "Character" # (str) + D = "Date" + F = "Floating point" + L = "Logical" # (bool) + M = "Memo" # Legacy. (10 digit str, starting block in an .dbt file) + N = "Numeric" # (int) + + +# Use functional syntax to have an attribute named type, a Python keyword +class Field(NamedTuple): + name: str + field_type: FieldType + size: int + decimal: int + + @classmethod + def from_unchecked( + cls, + name: str, + field_type: Union[str, FieldType] = FieldType.C, + size: int = 50, + decimal: int = 0, + ) -> Self: + if isinstance(field_type, str): + if field_type.upper() in FieldType.__members__: + field_type = FieldType[field_type.upper()] + else: + raise ShapefileException( + "type must be C,D,F,L,M,N, or a FieldType enum member. " + f"Got: {field_type=}. " + ) + + if field_type is FieldType.D: + size = 8 + decimal = 0 + elif field_type is FieldType.L: + size = 1 + decimal = 0 + + # A doctest in README.md previously passed in a string ('40') for size, + # so explicitly convert name to str, and size and decimal to ints. 
+ return cls( + name=str(name), field_type=field_type, size=int(size), decimal=int(decimal) + ) + + def __repr__(self) -> str: + return f'Field(name="{self.name}", field_type=FieldType.{self.field_type.name}, size={self.size}, decimal={self.decimal})' + -FieldTuple = tuple[str, str, int, int] -RecordValue = Union[ - bool, int, float, str, date -] # A Possible value in a Shapefile record, e.g. L, N, F, C, D types +RecordValueNotDate = Union[bool, int, float, str, date] + +# A Possible value in a Shapefile dbf record, i.e. L, N, M, F, C, or D types +RecordValue = Union[RecordValueNotDate, date] class HasGeoInterface(Protocol): @@ -227,59 +316,23 @@ class GeoJSONFeatureCollection(TypedDict): features: list[GeoJSONFeature] -class GeoJSONFeatureCollectionWithBBox(GeoJSONFeatureCollection, total=False): +class GeoJSONFeatureCollectionWithBBox(GeoJSONFeatureCollection): # bbox is optional # typing.NotRequired requires Python 3.11 # and we must support 3.9 (at least until October) # https://docs.python.org/3/library/typing.html#typing.Required # Is there a backport? - bbox: list[float] + bbox: NotRequired[list[float]] # Helpers -MISSING = [None, ""] +MISSING = (None, "") # Don't make a set, as user input may not be Hashable NODATA = -10e38 # as per the ESRI shapefile spec, only used for m-values. unpack_2_int32_be = Struct(">2i").unpack -def b( - v: Union[str, bytes], encoding: str = "utf-8", encodingErrors: str = "strict" -) -> bytes: - if isinstance(v, str): - # For python 3 encode str to bytes. - return v.encode(encoding, encodingErrors) - if isinstance(v, bytes): - # Already bytes. - return v - if v is None: - # Since we're dealing with text, interpret None as "" - return b"" - # Force string representation. - return str(v).encode(encoding, encodingErrors) - - -def u( - v: Union[str, bytes], encoding: str = "utf-8", encodingErrors: str = "strict" -) -> str: - if isinstance(v, bytes): - # For python 3 decode bytes to str. 
- return v.decode(encoding, encodingErrors) - if isinstance(v, str): - # Already str. - return v - if v is None: - # Since we're dealing with text, interpret None as "" - return "" - # Force string representation. - return bytes(v).decode(encoding, encodingErrors) - - -def is_string(v: Any) -> bool: - return isinstance(v, str) - - @overload def fsdecode_if_pathlike(path: os.PathLike) -> str: ... @overload @@ -337,7 +390,7 @@ def rewind(coords: Reversible[PointT]) -> PointsT: def ring_bbox(coords: PointsT) -> BBox: """Calculates and returns the bounding box of a ring.""" xs, ys = map(list, list(zip(*coords))[:2]) # ignore any z or m values - bbox = min(xs), min(ys), max(xs), max(ys) + bbox = BBox(xmin=min(xs), ymin=min(ys), xmax=max(xs), ymax=max(ys)) return bbox @@ -603,6 +656,24 @@ class _NoShapeTypeSentinel: class Shape: shapeType: int = NULL + _shapeTypes = frozenset( + [ + NULL, + POINT, + POINTM, + POINTZ, + POLYLINE, + POLYLINEM, + POLYLINEZ, + POLYGON, + POLYGONM, + POLYGONZ, + MULTIPOINT, + MULTIPOINTM, + MULTIPOINTZ, + MULTIPATCH, + ] + ) def __init__( self, @@ -838,19 +909,43 @@ def __repr__(self): return f"Shape #{self.__oid}: {self.shapeTypeName}" +S = TypeVar("S", bound=Shape) + + +def compatible_with(s: Shape, cls: type[S]) -> TypeIs[S]: + return s.shapeType in cls._shapeTypes + + +# pylint: disable=unused-argument +# Need unused arguments to keep the same call signature for +# different implementations of from_byte_stream and write_to_byte_stream class NullShape(Shape): # Shape.shapeType = NULL already, # to preserve handling of default args in Shape.__init__ # Repeated for clarity. 
shapeType = NULL + _shapeTypes = frozenset([NULL]) @classmethod - def from_byte_stream(cls, b_io, next_shape, oid=None, bbox=None): # pylint: disable=unused-argument + def from_byte_stream( + cls, + b_io: ReadSeekableBinStream, + next_shape: int, + oid: Optional[int] = None, + bbox: Optional[BBox] = None, + ) -> Self: # Shape.__init__ sets self.points = points or [] return cls(oid=oid) @staticmethod - def write_to_byte_stream(b_io, s, i, bbox, mbox, zbox): # pylint: disable=unused-argument + def write_to_byte_stream( + b_io: WriteableBinStream, + s: Shape, + i: int, + bbox: Optional[BBox], + mbox: Optional[MBox], + zbox: Optional[ZBox], + ) -> int: return 0 @@ -875,14 +970,18 @@ class _CanHaveBBox(Shape): ] ) - # Not a BBox because the legacy implementation was a list, not a 4-tuple. - bbox: Optional[Sequence[float]] = None + bbox: Optional[BBox] = None - def _set_bbox_from_byte_stream(self, b_io): - self.bbox = _Array[float]("d", unpack("<4d", b_io.read(32))) + def _get_set_bbox_from_byte_stream(self, b_io: ReadableBinStream) -> BBox: + self.bbox: BBox = BBox(*_Array[float]("d", unpack("<4d", b_io.read(32)))) + return self.bbox @staticmethod - def _write_bbox_to_byte_stream(b_io, i, bbox): + def _write_bbox_to_byte_stream( + b_io: WriteableBinStream, i: int, bbox: Optional[BBox] + ) -> int: + if not bbox or len(bbox) != 4: + raise ShapefileException(f"Four numbers required for bbox. Got: {bbox}") try: return b_io.write(pack("<4d", *bbox)) except error: @@ -891,20 +990,22 @@ def _write_bbox_to_byte_stream(b_io, i, bbox): ) @staticmethod - def _get_npoints_from_byte_stream(b_io): + def _get_npoints_from_byte_stream(b_io: ReadableBinStream) -> int: return unpack(" int: return b_io.write(pack(" int: + x_ys: list[float] = [] for point in s.points: x_ys.extend(point[:2]) try: @@ -914,33 +1015,38 @@ def _write_points_to_byte_stream(b_io, s, i): f"Failed to write points for record {i}. Expected floats." 
) - # pylint: disable=unused-argument @staticmethod - def _get_nparts_from_byte_stream(b_io): - return None + def _get_nparts_from_byte_stream(b_io: ReadableBinStream) -> int: + return 0 - def _set_parts_from_byte_stream(self, b_io, nParts): + def _set_parts_from_byte_stream(self, b_io: ReadableBinStream, nParts: int): pass - def _set_part_types_from_byte_stream(self, b_io, nParts): + def _set_part_types_from_byte_stream(self, b_io: ReadableBinStream, nParts: int): pass - def _set_zs_from_byte_stream(self, b_io, nPoints): + def _set_zs_from_byte_stream(self, b_io: ReadableBinStream, nPoints: int): pass - def _set_ms_from_byte_stream(self, b_io, nPoints, next_shape): + def _set_ms_from_byte_stream( + self, b_io: ReadSeekableBinStream, nPoints: int, next_shape: int + ): pass - # pylint: enable=unused-argument - @classmethod - def from_byte_stream(cls, b_io, next_shape, oid=None, bbox=None): + def from_byte_stream( + cls, + b_io: ReadSeekableBinStream, + next_shape: int, + oid: Optional[int] = None, + bbox: Optional[BBox] = None, + ) -> Optional[Self]: shape = cls(oid=oid) - shape._set_bbox_from_byte_stream(b_io) # pylint: disable=assignment-from-none + shape_bbox = shape._get_set_bbox_from_byte_stream(b_io) # if bbox specified and no overlap, skip this shape - if bbox is not None and not bbox_overlap(bbox, tuple(shape.bbox)): # pylint: disable=no-member + if bbox is not None and not bbox_overlap(bbox, shape_bbox): # because we stop parsing this shape, caller must skip to beginning of # next shape after we return (as done in f.seek(next_shape)) return None @@ -963,33 +1069,43 @@ def from_byte_stream(cls, b_io, next_shape, oid=None, bbox=None): return shape @staticmethod - def write_to_byte_stream(b_io, s, i, bbox, mbox, zbox): + def write_to_byte_stream( + b_io: WriteableBinStream, + s: Shape, + i: int, + bbox: Optional[BBox], + mbox: Optional[MBox], + zbox: Optional[ZBox], + ) -> int: # We use static methods here and below, # to support s only being an instance of a 
the # Shape base class (with shapeType set) # i.e. not necessarily one of our newer shape specific # sub classes. - n = _CanHaveBBox._write_bbox_to_byte_stream(b_io, i, bbox) + n = 0 + + if compatible_with(s, _CanHaveBBox): + n += _CanHaveBBox._write_bbox_to_byte_stream(b_io, i, bbox) - if s.shapeType in _CanHaveParts._shapeTypes: + if compatible_with(s, _CanHaveParts): n += _CanHaveParts._write_nparts_to_byte_stream(b_io, s) # Shape types with multiple points per record - if s.shapeType in _CanHaveBBox._shapeTypes: + if compatible_with(s, _CanHaveBBox): n += _CanHaveBBox._write_npoints_to_byte_stream(b_io, s) # Write part indexes. Includes MultiPatch - if s.shapeType in _CanHaveParts._shapeTypes: + if compatible_with(s, _CanHaveParts): n += _CanHaveParts._write_part_indices_to_byte_stream(b_io, s) - if s.shapeType == MULTIPATCH: + if compatible_with(s, MultiPatch): n += MultiPatch._write_part_types_to_byte_stream(b_io, s) # Write points for multiple-point records - if s.shapeType in _CanHaveBBox._shapeTypes: + if compatible_with(s, _CanHaveBBox): n += _CanHaveBBox._write_points_to_byte_stream(b_io, s, i) - if s.shapeType in _HasZ._shapeTypes: + if compatible_with(s, _HasZ): n += _HasZ._write_zs_to_byte_stream(b_io, s, i, zbox) - if s.shapeType in _HasM._shapeTypes: + if compatible_with(s, _HasM): n += _HasM._write_ms_to_byte_stream(b_io, s, i, mbox) return n @@ -1012,18 +1128,20 @@ class _CanHaveParts(_CanHaveBBox): ) @staticmethod - def _get_nparts_from_byte_stream(b_io): + def _get_nparts_from_byte_stream(b_io: ReadableBinStream) -> int: return unpack(" int: return b_io.write(pack(" int: return b_io.write(pack(f"<{len(s.parts)}i", *s.parts)) @@ -1034,21 +1152,25 @@ class Point(Shape): shapeType = POINT _shapeTypes = frozenset([POINT, POINTM, POINTZ]) - def _set_single_point_z_from_byte_stream(self, b_io): + def _set_single_point_z_from_byte_stream(self, b_io: ReadableBinStream): pass - def _set_single_point_m_from_byte_stream(self, b_io, next_shape): + def 
_set_single_point_m_from_byte_stream( + self, b_io: ReadSeekableBinStream, next_shape: int + ): pass @staticmethod - def _x_y_from_byte_stream(b_io): + def _x_y_from_byte_stream(b_io: ReadableBinStream): # Unpack _Array too x, y = _Array[float]("d", unpack("<2d", b_io.read(16))) # Convert to tuple return x, y @staticmethod - def _write_x_y_to_byte_stream(b_io, x, y, i): + def _write_x_y_to_byte_stream( + b_io: WriteableBinStream, x: float, y: float, i: int + ) -> int: try: return b_io.write(pack("<2d", x, y)) except error: @@ -1057,7 +1179,13 @@ def _write_x_y_to_byte_stream(b_io, x, y, i): ) @classmethod - def from_byte_stream(cls, b_io, next_shape, oid=None, bbox=None): + def from_byte_stream( + cls, + b_io: ReadSeekableBinStream, + next_shape: int, + oid: Optional[int] = None, + bbox: Optional[BBox] = None, + ) -> Optional[Self]: shape = cls(oid=oid) x, y = cls._x_y_from_byte_stream(b_io) @@ -1065,7 +1193,7 @@ def from_byte_stream(cls, b_io, next_shape, oid=None, bbox=None): if bbox is not None: # create bounding box for Point by duplicating coordinates # skip shape if no overlap with bounding box - if not bbox_overlap(bbox, (x, y, x, y)): + if not bbox_overlap(bbox, BBox(x, y, x, y)): return None shape.points = [(x, y)] @@ -1077,7 +1205,14 @@ def from_byte_stream(cls, b_io, next_shape, oid=None, bbox=None): return shape @staticmethod - def write_to_byte_stream(b_io, s, i, bbox, mbox, zbox): # pylint: disable=unused-argument + def write_to_byte_stream( + b_io: WriteableBinStream, + s: Shape, + i: int, + bbox: Optional[BBox], + mbox: Optional[MBox], + zbox: Optional[ZBox], + ) -> int: # Serialize a single point x, y = s.points[0][0], s.points[0][1] n = Point._write_x_y_to_byte_stream(b_io, x, y, i) @@ -1093,16 +1228,22 @@ def write_to_byte_stream(b_io, s, i, bbox, mbox, zbox): # pylint: disable=unuse return n +# pylint: enable=unused-argument + + class Polyline(_CanHaveParts): shapeType = POLYLINE + _shapeTypes = frozenset([POLYLINE, POLYLINEM, POLYLINEZ]) class 
Polygon(_CanHaveParts): shapeType = POLYGON + _shapeTypes = frozenset([POLYGON, POLYGONM, POLYGONZ]) class MultiPoint(_CanHaveBBox): shapeType = MULTIPOINT + _shapeTypes = frozenset([MULTIPOINT, MULTIPOINTM, MULTIPOINTZ]) class _HasM(_CanHaveBBox): @@ -1120,7 +1261,9 @@ class _HasM(_CanHaveBBox): ) m: Sequence[Optional[float]] - def _set_ms_from_byte_stream(self, b_io, nPoints, next_shape): + def _set_ms_from_byte_stream( + self, b_io: ReadSeekableBinStream, nPoints: int, next_shape: int + ): if next_shape - b_io.tell() >= 16: __mmin, __mmax = unpack("<2d", b_io.read(16)) # Measure values less than -10e38 are nodata values according to the spec @@ -1135,7 +1278,11 @@ def _set_ms_from_byte_stream(self, b_io, nPoints, next_shape): self.m = [None for _ in range(nPoints)] @staticmethod - def _write_ms_to_byte_stream(b_io, s, i, mbox): + def _write_ms_to_byte_stream( + b_io: WriteableBinStream, s: Shape, i: int, mbox: Optional[MBox] + ) -> int: + if not mbox or len(mbox) != 2: + raise ShapefileException(f"Two numbers required for mbox. Got: {mbox}") # Write m extremes and values # When reading a file, pyshp converts NODATA m values to None, so here we make sure to convert them back to NODATA # Note: missing m values are autoset to NODATA. @@ -1183,12 +1330,17 @@ class _HasZ(_CanHaveBBox): ) z: Sequence[float] - def _set_zs_from_byte_stream(self, b_io, nPoints): + def _set_zs_from_byte_stream(self, b_io: ReadableBinStream, nPoints: int): __zmin, __zmax = unpack("<2d", b_io.read(16)) # pylint: disable=unused-private-member self.z = _Array[float]("d", unpack(f"<{nPoints}d", b_io.read(nPoints * 8))) @staticmethod - def _write_zs_to_byte_stream(b_io, s, i, zbox): + def _write_zs_to_byte_stream( + b_io: WriteableBinStream, s: Shape, i: int, zbox: Optional[ZBox] + ) -> int: + if not zbox or len(zbox) != 2: + raise ShapefileException(f"Two numbers required for zbox. 
Got: {zbox}") + # Write z extremes and values # Note: missing z values are autoset to 0, but not sure if this is ideal. try: @@ -1216,22 +1368,27 @@ def _write_zs_to_byte_stream(b_io, s, i, zbox): class MultiPatch(_HasM, _HasZ, _CanHaveParts): shapeType = MULTIPATCH + _shapeTypes = frozenset([MULTIPATCH]) - def _set_part_types_from_byte_stream(self, b_io, nParts): + def _set_part_types_from_byte_stream(self, b_io: ReadableBinStream, nParts: int): self.partTypes = _Array[int]("i", unpack(f"<{nParts}i", b_io.read(nParts * 4))) @staticmethod - def _write_part_types_to_byte_stream(b_io, s): + def _write_part_types_to_byte_stream(b_io: WriteableBinStream, s: Shape) -> int: return b_io.write(pack(f"<{len(s.partTypes)}i", *s.partTypes)) class PointM(Point): shapeType = POINTM + _shapeTypes = frozenset([POINTM, POINTZ]) + # same default as in Writer.__shpRecord (if s.shapeType in (11, 21):) # PyShp encodes None m values as NODATA m = (None,) - def _set_single_point_m_from_byte_stream(self, b_io, next_shape): + def _set_single_point_m_from_byte_stream( + self, b_io: ReadSeekableBinStream, next_shape: int + ): if next_shape - b_io.tell() >= 8: (m,) = unpack(" int: # Write a single M value # Note: missing m values are autoset to NODATA. 
@@ -1285,36 +1444,41 @@ def _write_single_point_m_to_byte_stream(b_io, s, i): class PolylineM(Polyline, _HasM): shapeType = POLYLINEM + _shapeTypes = frozenset([POLYLINEM, POLYLINEZ]) class PolygonM(Polygon, _HasM): shapeType = POLYGONM + _shapeTypes = frozenset([POLYGONM, POLYGONZ]) class MultiPointM(MultiPoint, _HasM): shapeType = MULTIPOINTM + _shapeTypes = frozenset([MULTIPOINTM, MULTIPOINTZ]) + class PointZ(PointM): shapeType = POINTZ + _shapeTypes = frozenset([POINTZ]) + # same default as in Writer.__shpRecord (if s.shapeType == 11:) z: Sequence[float] = (0.0,) - def _set_single_point_z_from_byte_stream(self, b_io): + def _set_single_point_z_from_byte_stream(self, b_io: ReadableBinStream): self.z = tuple(unpack(" int: # Note: missing z values are autoset to 0, but not sure if this is ideal. - + z: float = 0.0 # then write value if hasattr(s, "z"): # if z values are stored in attribute try: - if not s.z: - # s.z = (0,) - z = 0 - else: + if s.z: z = s.z[0] except error: raise ShapefileException( @@ -1323,10 +1487,7 @@ def _write_single_point_z_to_byte_stream(b_io, s, i): else: # if z values are stored as 3rd dimension try: - if len(s.points[0]) < 3: - # s.points[0].append(0) - z = 0 - else: + if len(s.points[0]) >= 3 and s.points[0][2] is not None: z = s.points[0][2] except error: raise ShapefileException( @@ -1338,14 +1499,18 @@ def _write_single_point_z_to_byte_stream(b_io, s, i): class PolylineZ(PolylineM, _HasZ): shapeType = POLYLINEZ + _shapeTypes = frozenset([POLYLINEZ]) class PolygonZ(PolygonM, _HasZ): shapeType = POLYGONZ + _shapeTypes = frozenset([POLYGONZ]) + class MultiPointZ(MultiPointM, _HasZ): shapeType = MULTIPOINTZ + _shapeTypes = frozenset([MULTIPOINTZ]) SHAPE_CLASS_FROM_SHAPETYPE: dict[int, type[Union[NullShape, Point, _CanHaveBBox]]] = { @@ -1620,6 +1785,7 @@ def _assert_ext_is_supported(self, ext: str): assert ext in self.CONSTITUENT_FILE_EXTS def __init__( + # pylint: disable=unused-argument self, shapefile_path: Union[str, os.PathLike] = 
"", /, @@ -1629,7 +1795,9 @@ def __init__( shp: Union[_NoShpSentinel, Optional[BinaryFileT]] = _NoShpSentinel(), shx: Optional[BinaryFileT] = None, dbf: Optional[BinaryFileT] = None, - **kwargs, # pylint: disable=unused-argument + # Keep kwargs even though unused, to preserve PyShp 2.4 API + **kwargs, + # pylint: enable=unused-argument ): self.shp = None self.shx = None @@ -1640,7 +1808,7 @@ def __init__( self.shpLength: Optional[int] = None self.numRecords: Optional[int] = None self.numShapes: Optional[int] = None - self.fields: list[FieldTuple] = [] + self.fields: list[Field] = [] self.__dbfHdrLength = 0 self.__fieldLookup: dict[str, int] = {} self.encoding = encoding @@ -1648,7 +1816,7 @@ def __init__( # See if a shapefile name was passed as the first argument if shapefile_path: path = fsdecode_if_pathlike(shapefile_path) - if is_string(path): + if isinstance(path, str): if ".zip" in path: # Shapefile is inside a zipfile if path.count(".zip") > 1: @@ -2025,7 +2193,7 @@ def __restrictIndex(self, i: int) -> int: i = range(self.numRecords)[i] return i - def __shpHeader(self): + def __shpHeader(self) -> None: """Reads the header information from a .shp file.""" if not self.shp: raise ShapefileException( @@ -2040,17 +2208,18 @@ def __shpHeader(self): shp.seek(32) self.shapeType = unpack(" NODATA: - self.mbox.append(m) - else: - self.mbox.append(None) + # Measure values less than -10e38 are nodata values according to the spec + m_bounds = [ + float(m_bound) if m_bound >= NODATA else None + for m_bound in unpack("<2d", shp.read(16)) + ] + self.mbox = MBox(mmin=m_bounds[0], mmax=m_bounds[1]) def __shape( self, oid: Optional[int] = None, bbox: Optional[BBox] = None @@ -2070,7 +2239,7 @@ def __shape( # Read entire record into memory to avoid having to call # seek on the file afterwards - b_io = io.BytesIO(f.read(recLength_bytes)) + b_io: ReadSeekableBinStream = io.BytesIO(f.read(recLength_bytes)) b_io.seek(0) shapeType = unpack(" Optional[int]: """Returns the offset in a 
.shp file for a shape based on information @@ -2213,7 +2388,7 @@ def iterShapes(self, bbox: Optional[BBox] = None) -> Iterator[Optional[Shape]]: self.numShapes = i self._offsets = offsets - def __dbfHeader(self): + def __dbfHeader(self) -> None: """Reads a dbf header. Xbase-related code borrows heavily from ActiveState Python Cookbook Recipe 362715 by Raymond Hettinger""" if not self.dbf: @@ -2230,18 +2405,22 @@ def __dbfHeader(self): # read fields numFields = (self.__dbfHdrLength - 33) // 32 for __field in range(numFields): - fieldDesc = list(unpack("<11sc4xBB14x", dbf.read(32))) - name = 0 - idx = 0 - if b"\x00" in fieldDesc[name]: - idx = fieldDesc[name].index(b"\x00") + encoded_field_tuple: tuple[bytes, bytes, int, int] = unpack( + "<11sc4xBB14x", dbf.read(32) + ) + encoded_name, encoded_type_char, size, decimal = encoded_field_tuple + + if b"\x00" in encoded_name: + idx = encoded_name.index(b"\x00") else: - idx = len(fieldDesc[name]) - 1 - fieldDesc[name] = fieldDesc[name][:idx] - fieldDesc[name] = u(fieldDesc[name], self.encoding, self.encodingErrors) - fieldDesc[name] = fieldDesc[name].lstrip() - fieldDesc[1] = u(fieldDesc[1], "ascii") - self.fields.append(fieldDesc) + idx = len(encoded_name) - 1 + encoded_name = encoded_name[:idx] + name = encoded_name.decode(self.encoding, self.encodingErrors) + name = name.lstrip() + + field_type = FieldType[encoded_type_char.decode("ascii").upper()] + + self.fields.append(Field(name, field_type, size, decimal)) terminator = dbf.read(1) if terminator != b"\r": raise ShapefileException( @@ -2249,7 +2428,7 @@ def __dbfHeader(self): ) # insert deletion field at start - self.fields.insert(0, ("DeletionFlag", "C", 1, 0)) + self.fields.insert(0, Field("DeletionFlag", FieldType.C, 1, 0)) # store all field positions for easy lookups # note: fieldLookup gives the index position of a field inside Reader.fields @@ -2269,14 +2448,14 @@ def __recordFmt(self, fields: Optional[Container[str]] = None) -> tuple[str, int """ if 
self.numRecords is None: self.__dbfHeader() - structcodes = [f"{fieldinfo[2]}s" for fieldinfo in self.fields] + structcodes = [f"{fieldinfo.size}s" for fieldinfo in self.fields] if fields is not None: # only unpack specified fields, ignore others using padbytes (x) structcodes = [ code - if fieldinfo[0] in fields - or fieldinfo[0] == "DeletionFlag" # always unpack delflag - else f"{fieldinfo[2]}x" + if fieldinfo.name in fields + or fieldinfo.name == "DeletionFlag" # always unpack delflag + else f"{fieldinfo.size}x" for fieldinfo, code in zip(self.fields, structcodes) ] fmt = "".join(structcodes) @@ -2290,7 +2469,7 @@ def __recordFmt(self, fields: Optional[Container[str]] = None) -> tuple[str, int def __recordFields( self, fields: Optional[Iterable[str]] = None - ) -> tuple[list[FieldTuple], dict[str, int], Struct]: + ) -> tuple[list[Field], dict[str, int], Struct]: """Returns the necessary info required to unpack a record's fields, restricted to a subset of fieldnames 'fields' if specified. Returns a list of field info tuples, a name-index lookup dict, @@ -2325,17 +2504,20 @@ def __recordFields( def __record( self, - fieldTuples: list[FieldTuple], + fieldTuples: list[Field], recLookup: dict[str, int], recStruct: Struct, oid: Optional[int] = None, ) -> Optional[_Record]: """Reads and returns a dbf record row as a list of values. Requires specifying - a list of field info tuples 'fieldTuples', a record name-index dict 'recLookup', + a list of field info Field namedtuples 'fieldTuples', a record name-index dict 'recLookup', and a Struct instance 'recStruct' for unpacking these fields.
""" f = self.__getFileObj(self.dbf) + # The only format chars from self.__recordFmt, in recStruct from __recordFields, + # are s and x (ascii encoded str and pad byte) so everything in recordContents is bytes + # https://docs.python.org/3/library/struct.html#format-characters recordContents = recStruct.unpack(f.read(recStruct.size)) # deletion flag field is always unpacked as first value (see __recordFmt) @@ -2355,14 +2537,14 @@ def __record( # parse each value record = [] - for (__name, typ, __size, deci), value in zip(fieldTuples, recordContents): - if typ in {"N", "F"}: + for (__name, typ, __size, decimal), value in zip(fieldTuples, recordContents): + if typ in {FieldType.N, FieldType.F}: # numeric or float: number stored as a string, right justified, and padded with blanks to the width of the field. value = value.split(b"\0")[0] value = value.replace(b"*", b"") # QGIS NULL is all '*' chars if value == b"": value = None - elif deci: + elif decimal: try: value = float(value) except ValueError: @@ -2382,7 +2564,7 @@ def __record( except ValueError: # not parseable as int, set to None value = None - elif typ == "D": + elif typ is FieldType.D: # date: 8 bytes - date stored as a string in the format YYYYMMDD. if ( not value.replace(b"\x00", b"") @@ -2398,9 +2580,9 @@ def __record( y, m, d = int(value[:4]), int(value[4:6]), int(value[6:8]) value = date(y, m, d) except (TypeError, ValueError): - # if invalid date, just return as unicode string so user can decide - value = u(value.strip()) - elif typ == "L": + # if invalid date, just return as unicode string so user can decide + value = str(value.strip()) + elif typ is FieldType.L: # logical: 1 byte - initialized to 0x20 (space) otherwise T or F.
if value == b" ": value = None # space means missing or not yet set @@ -2412,8 +2594,7 @@ def __record( else: value = None # unknown value is set to missing else: - # anything else is forced to string/unicode - value = u(value, self.encoding, self.encodingErrors) + value = value.decode(self.encoding, self.encodingErrors) value = value.strip().rstrip( "\x00" ) # remove null-padding at end of strings @@ -2570,6 +2751,9 @@ def iterShapeRecords( class Writer: """Provides write support for ESRI Shapefiles.""" + W = TypeVar("W", bound=WriteSeekableBinStream) + + # pylint: disable=unused-argument def __init__( self, target: Union[str, os.PathLike, None] = None, @@ -2578,22 +2762,24 @@ def __init__( *, encoding: str = "utf-8", encodingErrors: str = "strict", - shp: Optional[BinaryWritableSeekable] = None, - shx: Optional[BinaryWritableSeekable] = None, - dbf: Optional[BinaryWritableSeekable] = None, - **kwargs, # pylint: disable=unused-argument + shp: Optional[WriteSeekableBinStream] = None, + shx: Optional[WriteSeekableBinStream] = None, + dbf: Optional[WriteSeekableBinStream] = None, + # Keep kwargs even though unused, to preserve PyShp 2.4 API + **kwargs, + # pylint: enable=unused-argument ): self.target = target self.autoBalance = autoBalance - self.fields: list[FieldTuple] = [] + self.fields: list[Field] = [] self.shapeType = shapeType - self.shp: Optional[BinaryFileStreamT] = None - self.shx: Optional[BinaryFileStreamT] = None - self.dbf: Optional[BinaryFileStreamT] = None + self.shp: Optional[WriteSeekableBinStream] = None + self.shx: Optional[WriteSeekableBinStream] = None + self.dbf: Optional[WriteSeekableBinStream] = None self._files_to_close: list[BinaryFileStreamT] = [] if target: target = fsdecode_if_pathlike(target) - if not is_string(target): + if not isinstance(target, str): raise TypeError( f"The target filepath {target!r} must be of type str/unicode or path-like, not {type(target)}." 
) @@ -2619,9 +2805,9 @@ def __init__( # Geometry record offsets and lengths for writing shx file. self.recNum = 0 self.shpNum = 0 - self._bbox = None - self._zbox = None - self._mbox = None + self._bbox: Optional[BBox] = None + self._zbox: Optional[ZBox] = None + self._mbox: Optional[MBox] = None # Use deletion flags in dbf? Default is false (0). Note: Currently has no effect, records should NOT contain deletion flags. self.deletionFlag = 0 # Encoding @@ -2680,6 +2866,8 @@ def close(self): # Flush files for attribute in (self.shp, self.shx, self.dbf): + if attribute is None: + continue if hasattr(attribute, "flush") and not getattr(attribute, "closed", False): try: attribute.flush() @@ -2695,14 +2883,12 @@ def close(self): pass self._files_to_close = [] - W = TypeVar("W", bound=BinaryWritableSeekable) - @overload - def __getFileObj(self, f: str) -> IO[bytes]: ... + def __getFileObj(self, f: str) -> WriteSeekableBinStream: ... @overload - def __getFileObj(self, f: None) -> NoReturn: ... + def __getFileObj(self, f: None) -> Never: ... @overload - def __getFileObj(self, f: W) -> W: ... + def __getFileObj(self, f: WriteSeekableBinStream) -> WriteSeekableBinStream: ... 
def __getFileObj(self, f): """Safety handler to verify file-like objects""" if not f: @@ -2719,22 +2905,32 @@ def __getFileObj(self, f): return f raise ShapefileException(f"Unsupported file-like object: {f}") - def __shpFileLength(self): + def __shpFileLength(self) -> int: """Calculates the file length of the shp file.""" + shp = self.__getFileObj(self.shp) + # Remember starting position - start = self.shp.tell() + + start = shp.tell() # Calculate size of all shapes - self.shp.seek(0, 2) - size = self.shp.tell() + shp.seek(0, 2) + size = shp.tell() # Calculate size as 16-bit words size //= 2 # Return to start - self.shp.seek(start) + shp.seek(start) return size - def __bbox(self, s): - x = [] - y = [] + def __bbox(self, s: Shape) -> BBox: + x: list[float] = [] + y: list[float] = [] + + if self._bbox: + x.append(self._bbox.xmin) + y.append(self._bbox.ymin) + x.append(self._bbox.xmax) + y.append(self._bbox.ymax) + if len(s.points) > 0: px, py = list(zip(*s.points))[:2] x.extend(px) @@ -2747,23 +2943,11 @@ def __bbox(self, s): "Cannot create bbox. Expected a valid shape with at least one point. " f"Got a shape of type '{s.shapeType}' and 0 points." ) - bbox = [min(x), min(y), max(x), max(y)] - # update global - if self._bbox: - # compare with existing - self._bbox = [ - min(bbox[0], self._bbox[0]), - min(bbox[1], self._bbox[1]), - max(bbox[2], self._bbox[2]), - max(bbox[3], self._bbox[3]), - ] - else: - # first time bbox is being set - self._bbox = bbox - return bbox + self._bbox = BBox(xmin=min(x), ymin=min(y), xmax=max(x), ymax=max(y)) + return self._bbox - def __zbox(self, s): - z = [] + def __zbox(self, s) -> ZBox: + z: list[float] = [] if self._zbox: z.extend(self._zbox) @@ -2777,14 +2961,14 @@ def __zbox(self, s): # Original self._zbox bounds (if any) are the first two entries. 
# Set zbox for the first, and all later times - self._zbox = [min(z), max(z)] + self._zbox = ZBox(zmin=min(z), zmax=max(z)) return self._zbox - def __mbox(self, s): + def __mbox(self, s) -> MBox: mpos = 3 if s.shapeType in _HasZ._shapeTypes else 2 - m = [] + m: list[float] = [] if self._mbox: - m.extend(self._mbox) + m.extend(m_bound for m_bound in self._mbox if m_bound is not None) for p in s.points: try: @@ -2801,32 +2985,32 @@ def __mbox(self, s): # Original self._mbox bounds (if any) are the first two entries. # Set mbox for the first, and all later times - self._mbox = [min(m), max(m)] + self._mbox = MBox(mmin=min(m), mmax=max(m)) return self._mbox @property def shapeTypeName(self) -> str: return SHAPETYPE_LOOKUP[self.shapeType or 0] - def bbox(self): + def bbox(self) -> Optional[BBox]: """Returns the current bounding box for the shapefile which is the lower-left and upper-right corners. It does not contain the elevation or measure extremes.""" return self._bbox - def zbox(self): + def zbox(self) -> Optional[ZBox]: """Returns the current z extremes for the shapefile.""" return self._zbox - def mbox(self): + def mbox(self) -> Optional[MBox]: """Returns the current m extremes for the shapefile.""" return self._mbox def __shapefileHeader( self, - fileObj: Optional[BinaryWritableSeekable], - headerType: str = "shp", - ): + fileObj: Optional[WriteSeekableBinStream], + headerType: Literal["shp", "dbf", "shx"] = "shp", + ) -> None: """Writes the specified header type to the specified file-like object. Several of the shapefile formats are so similar that a single generic method to read or write them is warranted.""" @@ -2853,7 +3037,7 @@ def __shapefileHeader( # In such cases of empty shapefiles, ESRI spec says the bbox values are 'unspecified'. # Not sure what that means, so for now just setting to 0s, which is the same behavior as in previous versions. # This would also make sense since the Z and M bounds are similarly set to 0 for non-Z/M type shapefiles. 
- bbox = [0, 0, 0, 0] + bbox = BBox(0, 0, 0, 0) f.write(pack("<4d", *bbox)) except error: raise ShapefileException( @@ -2862,25 +3046,25 @@ def __shapefileHeader( else: f.write(pack("<4d", 0, 0, 0, 0)) # Elevation - if self.shapeType in (11, 13, 15, 18): + if self.shapeType in {POINTZ} | _HasZ._shapeTypes: # Z values are present in Z type zbox = self.zbox() if zbox is None: # means we have empty shapefile/only null geoms (see commentary on bbox above) - zbox = [0, 0] + zbox = ZBox(0, 0) else: # As per the ESRI shapefile spec, the zbox for non-Z type shapefiles are set to 0s - zbox = [0, 0] + zbox = ZBox(0, 0) # Measure - if self.shapeType in (11, 13, 15, 18, 21, 23, 25, 28, 31): + if self.shapeType in {POINTM, POINTZ} | _HasM._shapeTypes: # M values are present in M or Z type mbox = self.mbox() if mbox is None: # means we have empty shapefile/only null geoms (see commentary on bbox above) - mbox = [0, 0] + mbox = MBox(0, 0) else: # As per the ESRI shapefile spec, the mbox for non-M type shapefiles are set to 0s - mbox = [0, 0] + mbox = MBox(0, 0) # Try writing try: f.write(pack("<4d", zbox[0], zbox[1], mbox[0], mbox[1])) @@ -2889,7 +3073,7 @@ def __shapefileHeader( "Failed to write shapefile elevation and measure values. Floats required." ) - def __dbfHeader(self): + def __dbfHeader(self) -> None: """Writes the dbf header and field descriptors.""" f = self.__getFileObj(self.dbf) f.seek(0) @@ -2910,7 +3094,7 @@ def __dbfHeader(self): raise ShapefileException( "Shapefile dbf header length exceeds maximum length." 
) - recordLength = sum(int(field[2]) for field in fields) + 1 + recordLength = sum(field.size for field in fields) + 1 header = pack( " None: # Balance if already not balanced if self.autoBalance and self.recNum < self.shpNum: self.balance() @@ -2959,8 +3147,8 @@ def shape( if self.shx: self.__shxRecord(offset, length) - def __shpRecord(self, s): - f = self.__getFileObj(self.shp) + def __shpRecord(self, s: Shape) -> tuple[int, int]: + f: WriteSeekableBinStream = self.__getFileObj(self.shp) offset = f.tell() self.shpNum += 1 @@ -2991,7 +3179,7 @@ def __shpRecord(self, s): # or .flush is called if not using RawIOBase). # https://docs.python.org/3/library/io.html#id2 # https://docs.python.org/3/library/io.html#io.BufferedWriter - b_io = io.BytesIO() + b_io: ReadWriteSeekableBinStream = io.BytesIO() # Record number, Content length place holder b_io.write(pack(">2i", self.shpNum, -1)) @@ -3023,7 +3211,7 @@ def __shpRecord(self, s): f.write(b_io.read()) return offset, length - def __shxRecord(self, offset, length): + def __shxRecord(self, offset: int, length: int) -> None: """Writes the shx records.""" f = self.__getFileObj(self.shx) @@ -3036,8 +3224,10 @@ def __shxRecord(self, offset, length): f.write(pack(">i", length)) def record( - self, *recordList: Iterable[RecordValue], **recordDict: dict[str, RecordValue] - ): + self, + *recordList: RecordValue, + **recordDict: RecordValue, + ) -> None: """Creates a dbf attribute record. You can submit either a sequence of field values or keyword arguments of field names and values. 
Before adding records you must add fields for the record values using the @@ -3048,7 +3238,7 @@ def record( # Balance if already not balanced if self.autoBalance and self.recNum > self.shpNum: self.balance() - + record: list[RecordValue] fieldCount = sum(1 for field in self.fields if field[0] != "DeletionFlag") if recordList: record = list(recordList) @@ -3072,7 +3262,77 @@ def record( record = ["" for _ in range(fieldCount)] self.__dbfRecord(record) - def __dbfRecord(self, record): + @staticmethod + def _dbf_missing_placeholder( + value: RecordValue, field_type: FieldType, size: int + ) -> str: + if field_type in {FieldType.N, FieldType.F}: + return "*" * size # QGIS NULL + if field_type is FieldType.D: + return "0" * 8 # QGIS NULL for date type + if field_type is FieldType.L: + return " " + return str(value)[:size].ljust(size) + + @overload + @staticmethod + def _try_coerce_to_numeric_str(value: date, size: int, decimal: int) -> Never: ... + @overload + @staticmethod + def _try_coerce_to_numeric_str( + value: RecordValueNotDate, size: int, decimal: int + ) -> str: ... + @staticmethod + def _try_coerce_to_numeric_str(value, size, decimal): + # numeric or float: number stored as a string, + # right justified, and padded with blanks + # to the width of the field. + if not decimal: + # force to int + try: + # first try to force directly to int. + # forcing a large int to float and back to int + # will lose information and result in wrong nr. + int_val = int(value) + except ValueError: + # forcing directly to int failed, so was probably a float. 
+ int_val = int(float(value)) + except TypeError: + raise ShapefileException(f"Could not form int from: {value}") + # length capped to the field size + return format(int_val, "d")[:size].rjust(size) + + try: + f_val = float(value) + except ValueError: + raise ShapefileException(f"Could not form float from: {value}") + # length capped to the field size + return format(f_val, f".{decimal}f")[:size].rjust(size) + + @staticmethod + def _try_coerce_to_date_str(value: RecordValue) -> str: + # date: 8 bytes - date stored as a string in the format YYYYMMDD. + if isinstance(value, date): + return f"{value.year:04d}{value.month:02d}{value.day:02d}" + if isinstance(value, (list, tuple)) and len(value) == 3: + return f"{value[0]:04d}{value[1]:02d}{value[2]:02d}" + if isinstance(value, str) and len(value) == 8: + return value # value is already a date string + + raise ShapefileException( + "Date values must be either a datetime.date object, a list/tuple, a YYYYMMDD string, or a missing value." + ) + + @staticmethod + def _try_coerce_to_logical_str(value: RecordValue) -> str: + # logical: 1 byte - initialized to 0x20 (space) otherwise T or F. + if value == 1: # True == 1 + return "T" + if value == 0: # False == 0 + return "F" + return " " # unknown is set to space + + def __dbfRecord(self, record: list[RecordValue]) -> None: """Writes the dbf records.""" f = self.__getFileObj(self.dbf) if self.recNum == 0: @@ -3087,72 +3347,39 @@ def __dbfRecord(self, record): fields = ( field for field in self.fields if field[0] != "DeletionFlag" ) # ignore deletionflag field in case it was specified - for (fieldName, fieldType, size, deci), value in zip(fields, record): + for (fieldName, type_, size, decimal), value in zip(fields, record): # write - fieldType = fieldType.upper() size = int(size) - if fieldType in ("N", "F"): - # numeric or float: number stored as a string, right justified, and padded with blanks to the width of the field. 
- if value in MISSING: - value = b"*" * size # QGIS NULL - elif not deci: - # force to int - try: - # first try to force directly to int. - # forcing a large int to float and back to int - # will lose information and result in wrong nr. - value = int(value) - except ValueError: - # forcing directly to int failed, so was probably a float. - value = int(float(value)) - value = format(value, "d")[:size].rjust( - size - ) # caps the size if exceeds the field size - else: - value = float(value) - value = format(value, f".{deci}f")[:size].rjust( - size - ) # caps the size if exceeds the field size - elif fieldType == "D": - # date: 8 bytes - date stored as a string in the format YYYYMMDD. - if isinstance(value, date): - value = f"{value.year:04d}{value.month:02d}{value.day:02d}" - elif isinstance(value, list) and len(value) == 3: - value = f"{value[0]:04d}{value[1]:02d}{value[2]:02d}" - elif value in MISSING: - value = b"0" * 8 # QGIS NULL for date type - elif is_string(value) and len(value) == 8: - pass # value is already a date string - else: - raise ShapefileException( - "Date values must be either a datetime.date object, a list, a YYYYMMDD string, or a missing value." - ) - elif fieldType == "L": - # logical: 1 byte - initialized to 0x20 (space) otherwise T or F. 
- if value in MISSING: - value = b" " # missing is set to space - elif value in [True, 1]: - value = b"T" - elif value in [False, 0]: - value = b"F" - else: - value = b" " # unknown is set to space + str_val: str + + if value in MISSING: + str_val = self._dbf_missing_placeholder(value, type_, size) + elif type_ in {FieldType.N, FieldType.F}: + str_val = self._try_coerce_to_numeric_str(value, size, decimal) + elif type_ is FieldType.D: + str_val = self._try_coerce_to_date_str(value) + elif type_ is FieldType.L: + str_val = self._try_coerce_to_logical_str(value) else: - # anything else is forced to string, truncated to the length of the field - value = b(value, self.encoding, self.encodingErrors)[:size].ljust(size) - if not isinstance(value, bytes): - # just in case some of the numeric format() and date strftime() results are still in unicode (Python 3 only) - value = b( - value, "ascii", self.encodingErrors - ) # should be default ascii encoding - if len(value) != size: + if isinstance(value, bytes): + str_val = value.decode(self.encoding, self.encodingErrors) + else: + # anything else is forced to string. + str_val = str(value) + + # Truncate or right pad to the length of the field + encoded_val = str_val.encode(self.encoding, self.encodingErrors)[ + :size + ].ljust(size) + + if len(encoded_val) != size: raise ShapefileException( - "Shapefile Writer unable to pack incorrect sized value" - f" (size {len(value)}) into field '{fieldName}' (size {size})." + f"Shapefile Writer unable to pack incorrect sized {value=!r} " + f"(size {len(encoded_val)}) into field '{fieldName}' (size {size})." 
) - f.write(value) + f.write(encoded_val) - def balance(self): + def balance(self) -> None: """Adds corresponding empty attributes or null geometry records depending on which type of record was created to make sure all three files are in synch.""" @@ -3161,24 +3388,26 @@ def balance(self): while self.recNum < self.shpNum: self.record() - def null(self): + def null(self) -> None: """Creates a null shape.""" self.shape(NullShape()) - def point(self, x: float, y: float): + def point(self, x: float, y: float) -> None: """Creates a POINT shape.""" pointShape = Point() pointShape.points.append((x, y)) self.shape(pointShape) - def pointm(self, x: float, y: float, m: Optional[float] = None): + def pointm(self, x: float, y: float, m: Optional[float] = None) -> None: """Creates a POINTM shape. If the m (measure) value is not set, it defaults to NoData.""" pointShape = PointM() pointShape.points.append((x, y, m)) self.shape(pointShape) - def pointz(self, x: float, y: float, z: float = 0.0, m: Optional[float] = None): + def pointz( + self, x: float, y: float, z: float = 0.0, m: Optional[float] = None + ) -> None: """Creates a POINTZ shape. If the z (elevation) value is not set, it defaults to 0. If the m (measure) value is not set, it defaults to NoData.""" @@ -3186,20 +3415,20 @@ def pointz(self, x: float, y: float, z: float = 0.0, m: Optional[float] = None): pointShape.points.append((x, y, z, m)) self.shape(pointShape) - def multipoint(self, points: PointsT): + def multipoint(self, points: PointsT) -> None: """Creates a MULTIPOINT shape. Points is a list of xy values.""" # nest the points inside a list to be compatible with the generic shapeparts method self._shapeparts(parts=[points], polyShape=MultiPoint()) - def multipointm(self, points: PointsT): + def multipointm(self, points: PointsT) -> None: """Creates a MULTIPOINTM shape. Points is a list of xym values. 
If the m (measure) value is not included, it defaults to None (NoData).""" # nest the points inside a list to be compatible with the generic shapeparts method self._shapeparts(parts=[points], polyShape=MultiPointM()) - def multipointz(self, points: PointsT): + def multipointz(self, points: PointsT) -> None: """Creates a MULTIPOINTZ shape. Points is a list of xyzm values. If the z (elevation) value is not included, it defaults to 0. @@ -3207,32 +3436,32 @@ def multipointz(self, points: PointsT): # nest the points inside a list to be compatible with the generic shapeparts method self._shapeparts(parts=[points], polyShape=MultiPointZ()) - def line(self, lines: list[PointsT]): + def line(self, lines: list[PointsT]) -> None: """Creates a POLYLINE shape. Lines is a collection of lines, each made up of a list of xy values.""" self._shapeparts(parts=lines, polyShape=Polyline()) - def linem(self, lines: list[PointsT]): + def linem(self, lines: list[PointsT]) -> None: """Creates a POLYLINEM shape. Lines is a collection of lines, each made up of a list of xym values. If the m (measure) value is not included, it defaults to None (NoData).""" self._shapeparts(parts=lines, polyShape=PolylineM()) - def linez(self, lines: list[PointsT]): + def linez(self, lines: list[PointsT]) -> None: """Creates a POLYLINEZ shape. Lines is a collection of lines, each made up of a list of xyzm values. If the z (elevation) value is not included, it defaults to 0. If the m (measure) value is not included, it defaults to None (NoData).""" self._shapeparts(parts=lines, polyShape=PolylineZ()) - def poly(self, polys: list[PointsT]): + def poly(self, polys: list[PointsT]) -> None: """Creates a POLYGON shape. Polys is a collection of polygons, each made up of a list of xy values. Note that for ordinary polygons the coordinates must run in a clockwise direction. 
If some of the polygons are holes, these must run in a counterclockwise direction.""" self._shapeparts(parts=polys, polyShape=Polygon()) - def polym(self, polys: list[PointsT]): + def polym(self, polys: list[PointsT]) -> None: """Creates a POLYGONM shape. Polys is a collection of polygons, each made up of a list of xym values. Note that for ordinary polygons the coordinates must run in a clockwise direction. @@ -3240,7 +3469,7 @@ def polym(self, polys: list[PointsT]): If the m (measure) value is not included, it defaults to None (NoData).""" self._shapeparts(parts=polys, polyShape=PolygonM()) - def polyz(self, polys: list[PointsT]): + def polyz(self, polys: list[PointsT]) -> None: """Creates a POLYGONZ shape. Polys is a collection of polygons, each made up of a list of xyzm values. Note that for ordinary polygons the coordinates must run in a clockwise direction. @@ -3249,7 +3478,7 @@ def polyz(self, polys: list[PointsT]): If the m (measure) value is not included, it defaults to None (NoData).""" self._shapeparts(parts=polys, polyShape=PolygonZ()) - def multipatch(self, parts: list[PointsT], partTypes: list[int]): + def multipatch(self, parts: list[PointsT], partTypes: list[int]) -> None: """Creates a MULTIPATCH shape. Parts is a collection of 3D surface patches, each made up of a list of xyzm values. PartTypes is a list of types that define each of the surface patches. @@ -3276,7 +3505,7 @@ def multipatch(self, parts: list[PointsT], partTypes: list[int]): def _shapeparts( self, parts: list[PointsT], polyShape: Union[Polyline, Polygon, MultiPoint] - ): + ) -> None: """Internal method for adding a shape that has multiple collections of points (parts): lines, polygons, and multipoint shapes. 
""" @@ -3287,7 +3516,7 @@ def _shapeparts( # if shapeType in (5, 15, 25, 31): # This method is never actually called on a MultiPatch # so we omit its shapeType (31) for efficiency - if isinstance(polyShape, Polygon): + if compatible_with(polyShape, Polygon): for part in parts: if part[0] != part[-1]: part.append(part[0]) @@ -3305,25 +3534,20 @@ def _shapeparts( self.shape(polyShape) def field( - # Types of args should match *FieldTuple + # Types of args should match *Field self, name: str, - fieldType: str = "C", + field_type: Union[str, FieldType] = FieldType.C, size: int = 50, decimal: int = 0, - ): + ) -> None: """Adds a dbf field descriptor to the shapefile.""" - if fieldType == "D": - size = 8 - decimal = 0 - elif fieldType == "L": - size = 1 - decimal = 0 if len(self.fields) >= 2046: raise ShapefileException( "Shapefile Writer reached maximum number of fields: 2046." ) - self.fields.append((name, fieldType, size, decimal)) + field_ = Field.from_unchecked(name, field_type, size, decimal) + self.fields.append(field_) # Begin Testing @@ -3450,7 +3674,7 @@ def _test(args: list[str] = sys.argv[1:], verbosity: bool = False) -> int: new_url = _replace_remote_url(old_url) example.source = example.source.replace(old_url, new_url) - runner = doctest.DocTestRunner(verbose=verbosity) + runner = doctest.DocTestRunner(verbose=verbosity, optionflags=doctest.FAIL_FAST) if verbosity == 0: print(f"Running {len(tests.examples)} doctests...") diff --git a/test_shapefile.py b/test_shapefile.py index 2a10d3e..a2ffbff 100644 --- a/test_shapefile.py +++ b/test_shapefile.py @@ -695,7 +695,7 @@ def test_reader_fields(): field = fields[0] assert isinstance(field[0], str) # field name - assert field[1] in ["C", "N", "F", "L", "D", "M"] # field type + assert field[1].name in ["C", "N", "F", "L", "D", "M"] # field type assert isinstance(field[2], int) # field length assert isinstance(field[3], int) # decimal length