From dfe6ae073fb1f34d3120e6efdffd52874fe4e704 Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Wed, 30 Jul 2025 18:09:36 +0100 Subject: [PATCH 01/20] Add in dependency on typing-extensions for Python 3.9 backports of TypeIs and NotRequired --- pyproject.toml | 3 + src/shapefile.py | 225 ++++++++++++++++++++++++++++++++++------------- 2 files changed, 168 insertions(+), 60 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d3e0e89..ca8f667 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,9 @@ classifiers = [ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ] +dependencies = [ + "typing_extensions", +] [project.optional-dependencies] test = ["pytest"] diff --git a/src/shapefile.py b/src/shapefile.py index d0ec177..bb159a2 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -43,6 +43,8 @@ from urllib.parse import urlparse, urlunparse from urllib.request import Request, urlopen +from typing_extensions import NotRequired, TypeIs + # Create named logger logger = logging.getLogger(__name__) @@ -120,15 +122,27 @@ PointsT = list[PointT] BBox = tuple[float, float, float, float] +MBox = tuple[float, float] +ZBox = tuple[float, float] + + +class BinaryWritableSeekable(Protocol): + def write(self, bbox: bytes): ... + def seek(self, offset: int, whence: int = 0): ... # pylint: disable=unused-argument + def tell(self): ... -class BinaryWritable(Protocol): - def write(self, data: bytes): ... +class BinaryReadableSeekable(Protocol): + def seek(self, offset: int, whence: int = 0): ... # pylint: disable=unused-argument + def tell(self): ... + def read(self, size: int = -1): ... -class BinaryWritableSeekable(BinaryWritable): - def seek(self, i: int): ... # pylint: disable=unused-argument +class BinaryReadableWritableSeekable(Protocol): + def write(self, bbox: bytes): ... + def seek(self, offset: int, whence: int = 0): ... 
# pylint: disable=unused-argument def tell(self): ... + def read(self, size: int = -1): ... # File name, file object or anything with a read() method that returns bytes. @@ -227,13 +241,13 @@ class GeoJSONFeatureCollection(TypedDict): features: list[GeoJSONFeature] -class GeoJSONFeatureCollectionWithBBox(GeoJSONFeatureCollection, total=False): +class GeoJSONFeatureCollectionWithBBox(GeoJSONFeatureCollection): # bbox is optional # typing.NotRequired requires Python 3.11 # and we must support 3.9 (at least until October) # https://docs.python.org/3/library/typing.html#typing.Required # Is there a backport? - bbox: list[float] + bbox: NotRequired[list[float]] # Helpers @@ -603,6 +617,24 @@ class _NoShapeTypeSentinel: class Shape: shapeType: int = NULL + _shapeTypes = frozenset( + [ + NULL, + POINT, + POINTM, + POINTZ, + POLYLINE, + POLYLINEM, + POLYLINEZ, + POLYGON, + POLYGONM, + POLYGONZ, + MULTIPOINT, + MULTIPOINTM, + MULTIPOINTZ, + MULTIPATCH, + ] + ) def __init__( self, @@ -838,19 +870,40 @@ def __repr__(self): return f"Shape #{self.__oid}: {self.shapeTypeName}" +S = TypeVar("S", bound=Shape) + + +def compatible_with(s: Shape, cls: type[S]) -> TypeIs[S]: + return s.shapeType in cls._shapeTypes + + class NullShape(Shape): # Shape.shapeType = NULL already, # to preserve handling of default args in Shape.__init__ # Repeated for clarity. 
shapeType = NULL + _shapeTypes = frozenset([NULL]) @classmethod - def from_byte_stream(cls, b_io, next_shape, oid=None, bbox=None): # pylint: disable=unused-argument + def from_byte_stream( + cls, + b_io: BinaryReadableSeekable, + next_shape: int, + oid: Optional[int] = None, + bbox: Optional[BBox] = None, + ) -> NullShape: # pylint: disable=unused-argument # Shape.__init__ sets self.points = points or [] return cls(oid=oid) @staticmethod - def write_to_byte_stream(b_io, s, i, bbox, mbox, zbox): # pylint: disable=unused-argument + def write_to_byte_stream( + b_io: BinaryWritableSeekable, + s: Shape, + i: int, + bbox: Optional[BBox], + mbox: Optional[MBox], + zbox: Optional[ZBox], + ) -> int: # pylint: disable=unused-argument return 0 @@ -876,13 +929,21 @@ class _CanHaveBBox(Shape): ) # Not a BBox because the legacy implementation was a list, not a 4-tuple. - bbox: Optional[Sequence[float]] = None + # bbox: Optional[Sequence[float]] = None + bbox: Optional[BBox] = None - def _set_bbox_from_byte_stream(self, b_io): - self.bbox = _Array[float]("d", unpack("<4d", b_io.read(32))) + def _get_set_bbox_from_byte_stream(self, b_io: BinaryReadableSeekable) -> BBox: + self.bbox: BBox = tuple(_Array[float]("d", unpack("<4d", b_io.read(32)))) + return self.bbox @staticmethod - def _write_bbox_to_byte_stream(b_io, i, bbox): + def _write_bbox_to_byte_stream( + b_io: BinaryWritableSeekable, i: int, bbox: Optional[BBox] + ) -> int: + if not bbox or len(bbox) != 4: + raise ShapefileException( + f"Four numbers required. 
Got: {bbox=}" + ) try: return b_io.write(pack("<4d", *bbox)) except error: @@ -891,20 +952,24 @@ def _write_bbox_to_byte_stream(b_io, i, bbox): ) @staticmethod - def _get_npoints_from_byte_stream(b_io): + def _get_npoints_from_byte_stream(b_io: BinaryReadableSeekable) -> int: return unpack(" int: return b_io.write(pack(" int: + x_ys: list[float] = [] for point in s.points: x_ys.extend(point[:2]) try: @@ -916,31 +981,41 @@ def _write_points_to_byte_stream(b_io, s, i): # pylint: disable=unused-argument @staticmethod - def _get_nparts_from_byte_stream(b_io): - return None + def _get_nparts_from_byte_stream(b_io: BinaryReadableSeekable) -> int: + return 0 - def _set_parts_from_byte_stream(self, b_io, nParts): + def _set_parts_from_byte_stream(self, b_io: BinaryReadableSeekable, nParts: int): pass - def _set_part_types_from_byte_stream(self, b_io, nParts): + def _set_part_types_from_byte_stream( + self, b_io: BinaryReadableSeekable, nParts: int + ): pass - def _set_zs_from_byte_stream(self, b_io, nPoints): + def _set_zs_from_byte_stream(self, b_io: BinaryReadableSeekable, nPoints: int): pass - def _set_ms_from_byte_stream(self, b_io, nPoints, next_shape): + def _set_ms_from_byte_stream( + self, b_io: BinaryReadableSeekable, nPoints: int, next_shape: int + ): pass # pylint: enable=unused-argument @classmethod - def from_byte_stream(cls, b_io, next_shape, oid=None, bbox=None): + def from_byte_stream( + cls, + b_io: BinaryReadableSeekable, + next_shape: int, + oid: Optional[int] = None, + bbox: Optional[BBox] = None, + ) -> Optional[_CanHaveBBox]: # pylint: disable=unused-argument shape = cls(oid=oid) - shape._set_bbox_from_byte_stream(b_io) # pylint: disable=assignment-from-none + shape_bbox = shape._get_set_bbox_from_byte_stream(b_io) # pylint: disable=assignment-from-none # if bbox specified and no overlap, skip this shape - if bbox is not None and not bbox_overlap(bbox, tuple(shape.bbox)): # pylint: disable=no-member + if bbox is not None and not bbox_overlap(bbox, 
shape_bbox): # pylint: disable=no-member #type: ignore [index] # because we stop parsing this shape, caller must skip to beginning of # next shape after we return (as done in f.seek(next_shape)) return None @@ -963,33 +1038,43 @@ def from_byte_stream(cls, b_io, next_shape, oid=None, bbox=None): return shape @staticmethod - def write_to_byte_stream(b_io, s, i, bbox, mbox, zbox): + def write_to_byte_stream( + b_io: BinaryWritableSeekable, + s: Shape, + i: int, + bbox: Optional[BBox], + mbox: Optional[MBox], + zbox: Optional[ZBox], + ) -> int: # We use static methods here and below, # to support s only being an instance of a the # Shape base class (with shapeType set) # i.e. not necessarily one of our newer shape specific # sub classes. - n = _CanHaveBBox._write_bbox_to_byte_stream(b_io, i, bbox) + n = 0 + + if compatible_with(s, _CanHaveBBox): + n += _CanHaveBBox._write_bbox_to_byte_stream(b_io, i, bbox) - if s.shapeType in _CanHaveParts._shapeTypes: + if compatible_with(s, _CanHaveParts): n += _CanHaveParts._write_nparts_to_byte_stream(b_io, s) # Shape types with multiple points per record - if s.shapeType in _CanHaveBBox._shapeTypes: + if compatible_with(s, _CanHaveBBox): n += _CanHaveBBox._write_npoints_to_byte_stream(b_io, s) # Write part indexes. 
Includes MultiPatch - if s.shapeType in _CanHaveParts._shapeTypes: + if compatible_with(s, _CanHaveParts): n += _CanHaveParts._write_part_indices_to_byte_stream(b_io, s) - if s.shapeType == MULTIPATCH: + if compatible_with(s, MultiPatch): n += MultiPatch._write_part_types_to_byte_stream(b_io, s) # Write points for multiple-point records - if s.shapeType in _CanHaveBBox._shapeTypes: + if compatible_with(s, _CanHaveBBox): n += _CanHaveBBox._write_points_to_byte_stream(b_io, s, i) - if s.shapeType in _HasZ._shapeTypes: + if compatible_with(s, _HasZ): n += _HasZ._write_zs_to_byte_stream(b_io, s, i, zbox) - if s.shapeType in _HasM._shapeTypes: + if compatible_with(s, _HasM): n += _HasM._write_ms_to_byte_stream(b_io, s, i, mbox) return n @@ -1012,18 +1097,20 @@ class _CanHaveParts(_CanHaveBBox): ) @staticmethod - def _get_nparts_from_byte_stream(b_io): + def _get_nparts_from_byte_stream(b_io: BinaryReadableSeekable) -> int: return unpack(" int: return b_io.write(pack(f"<{len(s.parts)}i", *s.parts)) @@ -1057,7 +1144,13 @@ def _write_x_y_to_byte_stream(b_io, x, y, i): ) @classmethod - def from_byte_stream(cls, b_io, next_shape, oid=None, bbox=None): + def from_byte_stream( + cls, + b_io: BinaryReadableSeekable, + next_shape: int, + oid: Optional[int] = None, + bbox: Optional[BBox] = None, + ): # pylint: disable=unused-argument shape = cls(oid=oid) x, y = cls._x_y_from_byte_stream(b_io) @@ -1077,7 +1170,14 @@ def from_byte_stream(cls, b_io, next_shape, oid=None, bbox=None): return shape @staticmethod - def write_to_byte_stream(b_io, s, i, bbox, mbox, zbox): # pylint: disable=unused-argument + def write_to_byte_stream( + b_io: BinaryWritableSeekable, + s: Shape, + i: int, + bbox: Optional[BBox], + mbox: Optional[MBox], + zbox: Optional[ZBox], + ) -> int: # pylint: disable=unused-argument # Serialize a single point x, y = s.points[0][0], s.points[0][1] n = Point._write_x_y_to_byte_stream(b_io, x, y, i) @@ -1216,6 +1316,7 @@ def _write_zs_to_byte_stream(b_io, s, i, zbox): 
class MultiPatch(_HasM, _HasZ, _CanHaveParts): shapeType = MULTIPATCH + _shapeTypes = frozenset([MULTIPATCH]) def _set_part_types_from_byte_stream(self, b_io, nParts): self.partTypes = _Array[int]("i", unpack(f"<{nParts}i", b_io.read(nParts * 4))) @@ -1227,6 +1328,8 @@ def _write_part_types_to_byte_stream(b_io, s): class PointM(Point): shapeType = POINTM + _shapeTypes = frozenset([POINTM, POINTZ]) + # same default as in Writer.__shpRecord (if s.shapeType in (11, 21):) # PyShp encodes None m values as NODATA m = (None,) @@ -1297,6 +1400,8 @@ class MultiPointM(MultiPoint, _HasM): class PointZ(PointM): shapeType = POINTZ + _shapeTypes = frozenset([POINTZ]) + # same default as in Writer.__shpRecord (if s.shapeType == 11:) z: Sequence[float] = (0.0,) @@ -2070,7 +2175,7 @@ def __shape( # Read entire record into memory to avoid having to call # seek on the file afterwards - b_io = io.BytesIO(f.read(recLength_bytes)) + b_io: BinaryReadableSeekable = io.BytesIO(f.read(recLength_bytes)) b_io.seek(0) shapeType = unpack(" IO[bytes]: ... + def __getFileObj(self, f: str) -> BinaryWritableSeekable: ... @overload def __getFileObj(self, f: None) -> NoReturn: ... @overload - def __getFileObj(self, f: W) -> W: ... + def __getFileObj(self, f: BinaryWritableSeekable) -> BinaryWritableSeekable: ... def __getFileObj(self, f): """Safety handler to verify file-like objects""" if not f: @@ -2762,8 +2867,8 @@ def __bbox(self, s): self._bbox = bbox return bbox - def __zbox(self, s): - z = [] + def __zbox(self, s) -> ZBox: + z: list[float] = [] if self._zbox: z.extend(self._zbox) @@ -2777,12 +2882,12 @@ def __zbox(self, s): # Original self._zbox bounds (if any) are the first two entries. 
# Set zbox for the first, and all later times - self._zbox = [min(z), max(z)] + self._zbox = (min(z), max(z)) return self._zbox - def __mbox(self, s): + def __mbox(self, s) -> MBox: mpos = 3 if s.shapeType in _HasZ._shapeTypes else 2 - m = [] + m: list[float] = [] if self._mbox: m.extend(self._mbox) @@ -2801,7 +2906,7 @@ def __mbox(self, s): # Original self._mbox bounds (if any) are the first two entries. # Set mbox for the first, and all later times - self._mbox = [min(m), max(m)] + self._mbox = (min(m), max(m)) return self._mbox @property @@ -2959,8 +3064,8 @@ def shape( if self.shx: self.__shxRecord(offset, length) - def __shpRecord(self, s): - f = self.__getFileObj(self.shp) + def __shpRecord(self, s: Shape) -> tuple[int, int]: + f: BinaryWritableSeekable = self.__getFileObj(self.shp) offset = f.tell() self.shpNum += 1 @@ -2991,7 +3096,7 @@ def __shpRecord(self, s): # or .flush is called if not using RawIOBase). # https://docs.python.org/3/library/io.html#id2 # https://docs.python.org/3/library/io.html#io.BufferedWriter - b_io = io.BytesIO() + b_io: BinaryReadableWritableSeekable = io.BytesIO() # Record number, Content length place holder b_io.write(pack(">2i", self.shpNum, -1)) From ebc859df52473c8b23557f107833e85a7c337eb4 Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Wed, 30 Jul 2025 18:42:29 +0100 Subject: [PATCH 02/20] Satisfy Pylint --- src/shapefile.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index bb159a2..5ead714 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -877,6 +877,9 @@ def compatible_with(s: Shape, cls: type[S]) -> TypeIs[S]: return s.shapeType in cls._shapeTypes +# pylint: disable=unused-argument +# Need unused arguments to keep the same call signature for +# different implementations of from_byte_stream and write_to_byte_stream class NullShape(Shape): # Shape.shapeType = NULL 
already, # to preserve handling of default args in Shape.__init__ @@ -891,7 +894,7 @@ def from_byte_stream( next_shape: int, oid: Optional[int] = None, bbox: Optional[BBox] = None, - ) -> NullShape: # pylint: disable=unused-argument + ) -> NullShape: # Shape.__init__ sets self.points = points or [] return cls(oid=oid) @@ -903,7 +906,7 @@ def write_to_byte_stream( bbox: Optional[BBox], mbox: Optional[MBox], zbox: Optional[ZBox], - ) -> int: # pylint: disable=unused-argument + ) -> int: return 0 @@ -928,8 +931,6 @@ class _CanHaveBBox(Shape): ] ) - # Not a BBox because the legacy implementation was a list, not a 4-tuple. - # bbox: Optional[Sequence[float]] = None bbox: Optional[BBox] = None def _get_set_bbox_from_byte_stream(self, b_io: BinaryReadableSeekable) -> BBox: @@ -941,9 +942,7 @@ def _write_bbox_to_byte_stream( b_io: BinaryWritableSeekable, i: int, bbox: Optional[BBox] ) -> int: if not bbox or len(bbox) != 4: - raise ShapefileException( - f"Four numbers required. Got: {bbox=}" - ) + raise ShapefileException(f"Four numbers required. Got: {bbox=}") try: return b_io.write(pack("<4d", *bbox)) except error: @@ -979,7 +978,6 @@ def _write_points_to_byte_stream( f"Failed to write points for record {i}. Expected floats." 
) - # pylint: disable=unused-argument @staticmethod def _get_nparts_from_byte_stream(b_io: BinaryReadableSeekable) -> int: return 0 @@ -1000,8 +998,6 @@ def _set_ms_from_byte_stream( ): pass - # pylint: enable=unused-argument - @classmethod def from_byte_stream( cls, @@ -1009,13 +1005,13 @@ def from_byte_stream( next_shape: int, oid: Optional[int] = None, bbox: Optional[BBox] = None, - ) -> Optional[_CanHaveBBox]: # pylint: disable=unused-argument + ) -> Optional[_CanHaveBBox]: shape = cls(oid=oid) - shape_bbox = shape._get_set_bbox_from_byte_stream(b_io) # pylint: disable=assignment-from-none + shape_bbox = shape._get_set_bbox_from_byte_stream(b_io) # if bbox specified and no overlap, skip this shape - if bbox is not None and not bbox_overlap(bbox, shape_bbox): # pylint: disable=no-member #type: ignore [index] + if bbox is not None and not bbox_overlap(bbox, shape_bbox): # because we stop parsing this shape, caller must skip to beginning of # next shape after we return (as done in f.seek(next_shape)) return None @@ -1150,7 +1146,7 @@ def from_byte_stream( next_shape: int, oid: Optional[int] = None, bbox: Optional[BBox] = None, - ): # pylint: disable=unused-argument + ): shape = cls(oid=oid) x, y = cls._x_y_from_byte_stream(b_io) @@ -1177,7 +1173,7 @@ def write_to_byte_stream( bbox: Optional[BBox], mbox: Optional[MBox], zbox: Optional[ZBox], - ) -> int: # pylint: disable=unused-argument + ) -> int: # Serialize a single point x, y = s.points[0][0], s.points[0][1] n = Point._write_x_y_to_byte_stream(b_io, x, y, i) @@ -1193,6 +1189,9 @@ def write_to_byte_stream( return n +# pylint: enable=unused-argument + + class Polyline(_CanHaveParts): shapeType = POLYLINE @@ -1725,6 +1724,7 @@ def _assert_ext_is_supported(self, ext: str): assert ext in self.CONSTITUENT_FILE_EXTS def __init__( + # pylint: disable=unused-argument self, shapefile_path: Union[str, os.PathLike] = "", /, @@ -1734,7 +1734,9 @@ def __init__( shp: Union[_NoShpSentinel, Optional[BinaryFileT]] = 
_NoShpSentinel(), shx: Optional[BinaryFileT] = None, dbf: Optional[BinaryFileT] = None, - **kwargs, # pylint: disable=unused-argument + # Keep kwargs even though unused, to preserve PyShp 2.4 API + **kwargs, + # pylint: enable=unused-argument ): self.shp = None self.shx = None @@ -2677,6 +2679,7 @@ class Writer: W = TypeVar("W", bound=BinaryWritableSeekable) + # pylint: disable=unused-argument def __init__( self, target: Union[str, os.PathLike, None] = None, @@ -2688,7 +2691,9 @@ def __init__( shp: Optional[BinaryWritableSeekable] = None, shx: Optional[BinaryWritableSeekable] = None, dbf: Optional[BinaryWritableSeekable] = None, - **kwargs, # pylint: disable=unused-argument + # Keep kwargs even though unused, to preserve PyShp 2.4 API + **kwargs, + # pylint: enable=unused-argument ): self.target = target self.autoBalance = autoBalance From 174f589d7963e401285a27fbf7d4fab8abde079d Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Wed, 30 Jul 2025 18:57:06 +0100 Subject: [PATCH 03/20] Rename Binary Stream Protocols --- src/shapefile.py | 80 ++++++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index 5ead714..7597d22 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -126,20 +126,20 @@ ZBox = tuple[float, float] -class BinaryWritableSeekable(Protocol): - def write(self, bbox: bytes): ... +class WriteSeekableBinStream(Protocol): + def write(self, b: bytes): ... # pylint: disable=redefined-outer-name def seek(self, offset: int, whence: int = 0): ... # pylint: disable=unused-argument def tell(self): ... -class BinaryReadableSeekable(Protocol): +class ReadSeekableBinStream(Protocol): def seek(self, offset: int, whence: int = 0): ... # pylint: disable=unused-argument def tell(self): ... def read(self, size: int = -1): ... -class BinaryReadableWritableSeekable(Protocol): - def write(self, bbox: bytes): ... 
+class ReadWriteSeekableBinStream(Protocol): + def write(self, b: bytes): ... # pylint: disable=redefined-outer-name def seek(self, offset: int, whence: int = 0): ... # pylint: disable=unused-argument def tell(self): ... def read(self, size: int = -1): ... @@ -147,7 +147,7 @@ def read(self, size: int = -1): ... # File name, file object or anything with a read() method that returns bytes. BinaryFileT = Union[str, IO[bytes]] -BinaryFileStreamT = Union[IO[bytes], io.BytesIO, BinaryWritableSeekable] +BinaryFileStreamT = Union[IO[bytes], io.BytesIO, WriteSeekableBinStream] FieldTuple = tuple[str, str, int, int] RecordValue = Union[ @@ -890,7 +890,7 @@ class NullShape(Shape): @classmethod def from_byte_stream( cls, - b_io: BinaryReadableSeekable, + b_io: ReadSeekableBinStream, next_shape: int, oid: Optional[int] = None, bbox: Optional[BBox] = None, @@ -900,7 +900,7 @@ def from_byte_stream( @staticmethod def write_to_byte_stream( - b_io: BinaryWritableSeekable, + b_io: WriteSeekableBinStream, s: Shape, i: int, bbox: Optional[BBox], @@ -933,13 +933,13 @@ class _CanHaveBBox(Shape): bbox: Optional[BBox] = None - def _get_set_bbox_from_byte_stream(self, b_io: BinaryReadableSeekable) -> BBox: + def _get_set_bbox_from_byte_stream(self, b_io: ReadSeekableBinStream) -> BBox: self.bbox: BBox = tuple(_Array[float]("d", unpack("<4d", b_io.read(32)))) return self.bbox @staticmethod def _write_bbox_to_byte_stream( - b_io: BinaryWritableSeekable, i: int, bbox: Optional[BBox] + b_io: WriteSeekableBinStream, i: int, bbox: Optional[BBox] ) -> int: if not bbox or len(bbox) != 4: raise ShapefileException(f"Four numbers required. 
Got: {bbox=}") @@ -951,22 +951,22 @@ def _write_bbox_to_byte_stream( ) @staticmethod - def _get_npoints_from_byte_stream(b_io: BinaryReadableSeekable) -> int: + def _get_npoints_from_byte_stream(b_io: ReadSeekableBinStream) -> int: return unpack(" int: return b_io.write(pack(" int: x_ys: list[float] = [] for point in s.points: @@ -979,29 +979,29 @@ def _write_points_to_byte_stream( ) @staticmethod - def _get_nparts_from_byte_stream(b_io: BinaryReadableSeekable) -> int: + def _get_nparts_from_byte_stream(b_io: ReadSeekableBinStream) -> int: return 0 - def _set_parts_from_byte_stream(self, b_io: BinaryReadableSeekable, nParts: int): + def _set_parts_from_byte_stream(self, b_io: ReadSeekableBinStream, nParts: int): pass def _set_part_types_from_byte_stream( - self, b_io: BinaryReadableSeekable, nParts: int + self, b_io: ReadSeekableBinStream, nParts: int ): pass - def _set_zs_from_byte_stream(self, b_io: BinaryReadableSeekable, nPoints: int): + def _set_zs_from_byte_stream(self, b_io: ReadSeekableBinStream, nPoints: int): pass def _set_ms_from_byte_stream( - self, b_io: BinaryReadableSeekable, nPoints: int, next_shape: int + self, b_io: ReadSeekableBinStream, nPoints: int, next_shape: int ): pass @classmethod def from_byte_stream( cls, - b_io: BinaryReadableSeekable, + b_io: ReadSeekableBinStream, next_shape: int, oid: Optional[int] = None, bbox: Optional[BBox] = None, @@ -1035,7 +1035,7 @@ def from_byte_stream( @staticmethod def write_to_byte_stream( - b_io: BinaryWritableSeekable, + b_io: WriteSeekableBinStream, s: Shape, i: int, bbox: Optional[BBox], @@ -1093,19 +1093,19 @@ class _CanHaveParts(_CanHaveBBox): ) @staticmethod - def _get_nparts_from_byte_stream(b_io: BinaryReadableSeekable) -> int: + def _get_nparts_from_byte_stream(b_io: ReadSeekableBinStream) -> int: return unpack(" int: return b_io.write(pack(" int: return b_io.write(pack(f"<{len(s.parts)}i", *s.parts)) @@ -1142,7 +1142,7 @@ def _write_x_y_to_byte_stream(b_io, x, y, i): @classmethod def 
from_byte_stream( cls, - b_io: BinaryReadableSeekable, + b_io: ReadSeekableBinStream, next_shape: int, oid: Optional[int] = None, bbox: Optional[BBox] = None, @@ -1167,7 +1167,7 @@ def from_byte_stream( @staticmethod def write_to_byte_stream( - b_io: BinaryWritableSeekable, + b_io: WriteSeekableBinStream, s: Shape, i: int, bbox: Optional[BBox], @@ -2177,7 +2177,7 @@ def __shape( # Read entire record into memory to avoid having to call # seek on the file afterwards - b_io: BinaryReadableSeekable = io.BytesIO(f.read(recLength_bytes)) + b_io: ReadSeekableBinStream = io.BytesIO(f.read(recLength_bytes)) b_io.seek(0) shapeType = unpack(" BinaryWritableSeekable: ... + def __getFileObj(self, f: str) -> WriteSeekableBinStream: ... @overload def __getFileObj(self, f: None) -> NoReturn: ... @overload - def __getFileObj(self, f: BinaryWritableSeekable) -> BinaryWritableSeekable: ... + def __getFileObj(self, f: WriteSeekableBinStream) -> WriteSeekableBinStream: ... def __getFileObj(self, f): """Safety handler to verify file-like objects""" if not f: @@ -2934,7 +2934,7 @@ def mbox(self): def __shapefileHeader( self, - fileObj: Optional[BinaryWritableSeekable], + fileObj: Optional[WriteSeekableBinStream], headerType: str = "shp", ): """Writes the specified header type to the specified file-like object. @@ -3070,7 +3070,7 @@ def shape( self.__shxRecord(offset, length) def __shpRecord(self, s: Shape) -> tuple[int, int]: - f: BinaryWritableSeekable = self.__getFileObj(self.shp) + f: WriteSeekableBinStream = self.__getFileObj(self.shp) offset = f.tell() self.shpNum += 1 @@ -3101,7 +3101,7 @@ def __shpRecord(self, s: Shape) -> tuple[int, int]: # or .flush is called if not using RawIOBase). 
# https://docs.python.org/3/library/io.html#id2 # https://docs.python.org/3/library/io.html#io.BufferedWriter - b_io: BinaryReadableWritableSeekable = io.BytesIO() + b_io: ReadWriteSeekableBinStream = io.BytesIO() # Record number, Content length place holder b_io.write(pack(">2i", self.shpNum, -1)) From 28e2e59c202b980f836cda152295138f6a31018e Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Wed, 30 Jul 2025 19:09:47 +0100 Subject: [PATCH 04/20] Replace NoReturn with Never --- src/shapefile.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index 7597d22..ce59bfb 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -29,7 +29,6 @@ Iterable, Iterator, Literal, - NoReturn, Optional, Protocol, Reversible, @@ -43,7 +42,7 @@ from urllib.parse import urlparse, urlunparse from urllib.request import Request, urlopen -from typing_extensions import NotRequired, TypeIs +from typing_extensions import Never, NotRequired, TypeIs, # Create named logger logger = logging.getLogger(__name__) @@ -2810,7 +2809,7 @@ def close(self): @overload def __getFileObj(self, f: str) -> WriteSeekableBinStream: ... @overload - def __getFileObj(self, f: None) -> NoReturn: ... + def __getFileObj(self, f: None) -> Never: ... @overload def __getFileObj(self, f: WriteSeekableBinStream) -> WriteSeekableBinStream: ... 
def __getFileObj(self, f): From 3deb12c658752190ed7a1bbf767952faffc94e55 Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Wed, 30 Jul 2025 19:37:51 +0100 Subject: [PATCH 05/20] Type hint rest of shape subclass methods --- src/shapefile.py | 84 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 28 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index ce59bfb..ac85b94 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -42,7 +42,7 @@ from urllib.parse import urlparse, urlunparse from urllib.request import Request, urlopen -from typing_extensions import Never, NotRequired, TypeIs, +from typing_extensions import Never, NotRequired, Self, TypeIs # Create named logger logger = logging.getLogger(__name__) @@ -893,7 +893,7 @@ def from_byte_stream( next_shape: int, oid: Optional[int] = None, bbox: Optional[BBox] = None, - ) -> NullShape: + ) -> Self: # Shape.__init__ sets self.points = points or [] return cls(oid=oid) @@ -941,7 +941,7 @@ def _write_bbox_to_byte_stream( b_io: WriteSeekableBinStream, i: int, bbox: Optional[BBox] ) -> int: if not bbox or len(bbox) != 4: - raise ShapefileException(f"Four numbers required. Got: {bbox=}") + raise ShapefileException(f"Four numbers required for bbox. 
Got: {bbox}") try: return b_io.write(pack("<4d", *bbox)) except error: @@ -1004,7 +1004,7 @@ def from_byte_stream( next_shape: int, oid: Optional[int] = None, bbox: Optional[BBox] = None, - ) -> Optional[_CanHaveBBox]: + ) -> Optional[Self]: shape = cls(oid=oid) shape_bbox = shape._get_set_bbox_from_byte_stream(b_io) @@ -1116,21 +1116,25 @@ class Point(Shape): shapeType = POINT _shapeTypes = frozenset([POINT, POINTM, POINTZ]) - def _set_single_point_z_from_byte_stream(self, b_io): + def _set_single_point_z_from_byte_stream(self, b_io: ReadSeekableBinStream): pass - def _set_single_point_m_from_byte_stream(self, b_io, next_shape): + def _set_single_point_m_from_byte_stream( + self, b_io: ReadSeekableBinStream, next_shape: int + ): pass @staticmethod - def _x_y_from_byte_stream(b_io): + def _x_y_from_byte_stream(b_io: ReadSeekableBinStream): # Unpack _Array too x, y = _Array[float]("d", unpack("<2d", b_io.read(16))) # Convert to tuple return x, y @staticmethod - def _write_x_y_to_byte_stream(b_io, x, y, i): + def _write_x_y_to_byte_stream( + b_io: WriteSeekableBinStream, x: float, y: float, i: int + ) -> int: try: return b_io.write(pack("<2d", x, y)) except error: @@ -1145,7 +1149,7 @@ def from_byte_stream( next_shape: int, oid: Optional[int] = None, bbox: Optional[BBox] = None, - ): + ) -> Optional[Self]: shape = cls(oid=oid) x, y = cls._x_y_from_byte_stream(b_io) @@ -1193,14 +1197,17 @@ def write_to_byte_stream( class Polyline(_CanHaveParts): shapeType = POLYLINE + _shapeTypes = frozenset([POLYLINE, POLYLINEM, POLYLINEZ]) class Polygon(_CanHaveParts): shapeType = POLYGON + _shapeTypes = frozenset([POLYGON, POLYGONM, POLYGONZ]) class MultiPoint(_CanHaveBBox): shapeType = MULTIPOINT + _shapeTypes = frozenset([MULTIPOINT, MULTIPOINTM, MULTIPOINTZ]) class _HasM(_CanHaveBBox): @@ -1218,7 +1225,9 @@ class _HasM(_CanHaveBBox): ) m: Sequence[Optional[float]] - def _set_ms_from_byte_stream(self, b_io, nPoints, next_shape): + def _set_ms_from_byte_stream( + self, b_io: 
ReadSeekableBinStream, nPoints: int, next_shape: int + ): if next_shape - b_io.tell() >= 16: __mmin, __mmax = unpack("<2d", b_io.read(16)) # Measure values less than -10e38 are nodata values according to the spec @@ -1233,7 +1242,11 @@ def _set_ms_from_byte_stream(self, b_io, nPoints, next_shape): self.m = [None for _ in range(nPoints)] @staticmethod - def _write_ms_to_byte_stream(b_io, s, i, mbox): + def _write_ms_to_byte_stream( + b_io: WriteSeekableBinStream, s: Shape, i: int, mbox: Optional[MBox] + ) -> int: + if not mbox or len(mbox) != 2: + raise ShapefileException(f"Two numbers required for mbox. Got: {mbox}") # Write m extremes and values # When reading a file, pyshp converts NODATA m values to None, so here we make sure to convert them back to NODATA # Note: missing m values are autoset to NODATA. @@ -1281,12 +1294,17 @@ class _HasZ(_CanHaveBBox): ) z: Sequence[float] - def _set_zs_from_byte_stream(self, b_io, nPoints): + def _set_zs_from_byte_stream(self, b_io: ReadSeekableBinStream, nPoints: int): __zmin, __zmax = unpack("<2d", b_io.read(16)) # pylint: disable=unused-private-member self.z = _Array[float]("d", unpack(f"<{nPoints}d", b_io.read(nPoints * 8))) @staticmethod - def _write_zs_to_byte_stream(b_io, s, i, zbox): + def _write_zs_to_byte_stream( + b_io: WriteSeekableBinStream, s: Shape, i: int, zbox: Optional[ZBox] + ) -> int: + if not zbox or len(zbox) != 2: + raise ShapefileException(f"Two numbers required for zbox. Got: {zbox}") + # Write z extremes and values # Note: missing z values are autoset to 0, but not sure if this is ideal. 
try: @@ -1316,11 +1334,13 @@ class MultiPatch(_HasM, _HasZ, _CanHaveParts): shapeType = MULTIPATCH _shapeTypes = frozenset([MULTIPATCH]) - def _set_part_types_from_byte_stream(self, b_io, nParts): + def _set_part_types_from_byte_stream( + self, b_io: ReadSeekableBinStream, nParts: int + ): self.partTypes = _Array[int]("i", unpack(f"<{nParts}i", b_io.read(nParts * 4))) @staticmethod - def _write_part_types_to_byte_stream(b_io, s): + def _write_part_types_to_byte_stream(b_io: WriteSeekableBinStream, s: Shape) -> int: return b_io.write(pack(f"<{len(s.partTypes)}i", *s.partTypes)) @@ -1332,7 +1352,9 @@ class PointM(Point): # PyShp encodes None m values as NODATA m = (None,) - def _set_single_point_m_from_byte_stream(self, b_io, next_shape): + def _set_single_point_m_from_byte_stream( + self, b_io: ReadSeekableBinStream, next_shape: int + ): if next_shape - b_io.tell() >= 8: (m,) = unpack(" int: # Write a single M value # Note: missing m values are autoset to NODATA. @@ -1386,15 +1410,19 @@ def _write_single_point_m_to_byte_stream(b_io, s, i): class PolylineM(Polyline, _HasM): shapeType = POLYLINEM + _shapeTypes = frozenset([POLYLINEM, POLYLINEZ]) class PolygonM(Polygon, _HasM): shapeType = POLYGONM + _shapeTypes = frozenset([POLYGONM, POLYGONZ]) class MultiPointM(MultiPoint, _HasM): shapeType = MULTIPOINTM + _shapeTypes = frozenset([MULTIPOINTM, MULTIPOINTZ]) + class PointZ(PointM): shapeType = POINTZ @@ -1403,21 +1431,20 @@ class PointZ(PointM): # same default as in Writer.__shpRecord (if s.shapeType == 11:) z: Sequence[float] = (0.0,) - def _set_single_point_z_from_byte_stream(self, b_io): + def _set_single_point_z_from_byte_stream(self, b_io: ReadSeekableBinStream): self.z = tuple(unpack(" int: # Note: missing z values are autoset to 0, but not sure if this is ideal. 
- + z: float = 0.0 # then write value if hasattr(s, "z"): # if z values are stored in attribute try: - if not s.z: - # s.z = (0,) - z = 0 - else: + if s.z: z = s.z[0] except error: raise ShapefileException( @@ -1426,10 +1453,7 @@ def _write_single_point_z_to_byte_stream(b_io, s, i): else: # if z values are stored as 3rd dimension try: - if len(s.points[0]) < 3: - # s.points[0].append(0) - z = 0 - else: + if len(s.points[0]) >= 3 and s.points[0][2] is not None: z = s.points[0][2] except error: raise ShapefileException( @@ -1441,14 +1465,18 @@ def _write_single_point_z_to_byte_stream(b_io, s, i): class PolylineZ(PolylineM, _HasZ): shapeType = POLYLINEZ + _shapeTypes = frozenset([POLYLINEZ]) class PolygonZ(PolygonM, _HasZ): shapeType = POLYGONZ + _shapeTypes = frozenset([POLYGONZ]) + class MultiPointZ(MultiPointM, _HasZ): shapeType = MULTIPOINTZ + _shapeTypes = frozenset([MULTIPOINTZ]) SHAPE_CLASS_FROM_SHAPETYPE: dict[int, type[Union[NullShape, Point, _CanHaveBBox]]] = { From a1059f383c4f480ec4762f6a9d08b70123214b53 Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Wed, 30 Jul 2025 19:48:56 +0100 Subject: [PATCH 06/20] Relax Read Write BinStream Protocols for readability --- src/shapefile.py | 70 +++++++++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index ac85b94..66669d6 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -125,6 +125,14 @@ ZBox = tuple[float, float] +class WriteableBinStream(Protocol): + def write(self, b: bytes): ... # pylint: disable=redefined-outer-name + + +class ReadableBinStream(Protocol): + def read(self, size: int = -1): ... + + class WriteSeekableBinStream(Protocol): def write(self, b: bytes): ... # pylint: disable=redefined-outer-name def seek(self, offset: int, whence: int = 0): ... 
# pylint: disable=unused-argument @@ -899,7 +907,7 @@ def from_byte_stream( @staticmethod def write_to_byte_stream( - b_io: WriteSeekableBinStream, + b_io: WriteableBinStream, s: Shape, i: int, bbox: Optional[BBox], @@ -932,13 +940,13 @@ class _CanHaveBBox(Shape): bbox: Optional[BBox] = None - def _get_set_bbox_from_byte_stream(self, b_io: ReadSeekableBinStream) -> BBox: + def _get_set_bbox_from_byte_stream(self, b_io: ReadableBinStream) -> BBox: self.bbox: BBox = tuple(_Array[float]("d", unpack("<4d", b_io.read(32)))) return self.bbox @staticmethod def _write_bbox_to_byte_stream( - b_io: WriteSeekableBinStream, i: int, bbox: Optional[BBox] + b_io: WriteableBinStream, i: int, bbox: Optional[BBox] ) -> int: if not bbox or len(bbox) != 4: raise ShapefileException(f"Four numbers required for bbox. Got: {bbox}") @@ -950,22 +958,20 @@ def _write_bbox_to_byte_stream( ) @staticmethod - def _get_npoints_from_byte_stream(b_io: ReadSeekableBinStream) -> int: + def _get_npoints_from_byte_stream(b_io: ReadableBinStream) -> int: return unpack(" int: + def _write_npoints_to_byte_stream(b_io: WriteableBinStream, s: _CanHaveBBox) -> int: return b_io.write(pack(" int: x_ys: list[float] = [] for point in s.points: @@ -978,18 +984,16 @@ def _write_points_to_byte_stream( ) @staticmethod - def _get_nparts_from_byte_stream(b_io: ReadSeekableBinStream) -> int: + def _get_nparts_from_byte_stream(b_io: ReadableBinStream) -> int: return 0 - def _set_parts_from_byte_stream(self, b_io: ReadSeekableBinStream, nParts: int): + def _set_parts_from_byte_stream(self, b_io: ReadableBinStream, nParts: int): pass - def _set_part_types_from_byte_stream( - self, b_io: ReadSeekableBinStream, nParts: int - ): + def _set_part_types_from_byte_stream(self, b_io: ReadableBinStream, nParts: int): pass - def _set_zs_from_byte_stream(self, b_io: ReadSeekableBinStream, nPoints: int): + def _set_zs_from_byte_stream(self, b_io: ReadableBinStream, nPoints: int): pass def _set_ms_from_byte_stream( @@ -1034,7 +1038,7 
@@ def from_byte_stream( @staticmethod def write_to_byte_stream( - b_io: WriteSeekableBinStream, + b_io: WriteableBinStream, s: Shape, i: int, bbox: Optional[BBox], @@ -1092,19 +1096,19 @@ class _CanHaveParts(_CanHaveBBox): ) @staticmethod - def _get_nparts_from_byte_stream(b_io: ReadSeekableBinStream) -> int: + def _get_nparts_from_byte_stream(b_io: ReadableBinStream) -> int: return unpack(" int: + def _write_nparts_to_byte_stream(b_io: WriteableBinStream, s) -> int: return b_io.write(pack(" int: return b_io.write(pack(f"<{len(s.parts)}i", *s.parts)) @@ -1116,7 +1120,7 @@ class Point(Shape): shapeType = POINT _shapeTypes = frozenset([POINT, POINTM, POINTZ]) - def _set_single_point_z_from_byte_stream(self, b_io: ReadSeekableBinStream): + def _set_single_point_z_from_byte_stream(self, b_io: ReadableBinStream): pass def _set_single_point_m_from_byte_stream( @@ -1125,7 +1129,7 @@ def _set_single_point_m_from_byte_stream( pass @staticmethod - def _x_y_from_byte_stream(b_io: ReadSeekableBinStream): + def _x_y_from_byte_stream(b_io: ReadableBinStream): # Unpack _Array too x, y = _Array[float]("d", unpack("<2d", b_io.read(16))) # Convert to tuple @@ -1133,7 +1137,7 @@ def _x_y_from_byte_stream(b_io: ReadSeekableBinStream): @staticmethod def _write_x_y_to_byte_stream( - b_io: WriteSeekableBinStream, x: float, y: float, i: int + b_io: WriteableBinStream, x: float, y: float, i: int ) -> int: try: return b_io.write(pack("<2d", x, y)) @@ -1170,7 +1174,7 @@ def from_byte_stream( @staticmethod def write_to_byte_stream( - b_io: WriteSeekableBinStream, + b_io: WriteableBinStream, s: Shape, i: int, bbox: Optional[BBox], @@ -1243,7 +1247,7 @@ def _set_ms_from_byte_stream( @staticmethod def _write_ms_to_byte_stream( - b_io: WriteSeekableBinStream, s: Shape, i: int, mbox: Optional[MBox] + b_io: WriteableBinStream, s: Shape, i: int, mbox: Optional[MBox] ) -> int: if not mbox or len(mbox) != 2: raise ShapefileException(f"Two numbers required for mbox. 
Got: {mbox}") @@ -1294,13 +1298,13 @@ class _HasZ(_CanHaveBBox): ) z: Sequence[float] - def _set_zs_from_byte_stream(self, b_io: ReadSeekableBinStream, nPoints: int): + def _set_zs_from_byte_stream(self, b_io: ReadableBinStream, nPoints: int): __zmin, __zmax = unpack("<2d", b_io.read(16)) # pylint: disable=unused-private-member self.z = _Array[float]("d", unpack(f"<{nPoints}d", b_io.read(nPoints * 8))) @staticmethod def _write_zs_to_byte_stream( - b_io: WriteSeekableBinStream, s: Shape, i: int, zbox: Optional[ZBox] + b_io: WriteableBinStream, s: Shape, i: int, zbox: Optional[ZBox] ) -> int: if not zbox or len(zbox) != 2: raise ShapefileException(f"Two numbers required for zbox. Got: {zbox}") @@ -1334,13 +1338,11 @@ class MultiPatch(_HasM, _HasZ, _CanHaveParts): shapeType = MULTIPATCH _shapeTypes = frozenset([MULTIPATCH]) - def _set_part_types_from_byte_stream( - self, b_io: ReadSeekableBinStream, nParts: int - ): + def _set_part_types_from_byte_stream(self, b_io: ReadableBinStream, nParts: int): self.partTypes = _Array[int]("i", unpack(f"<{nParts}i", b_io.read(nParts * 4))) @staticmethod - def _write_part_types_to_byte_stream(b_io: WriteSeekableBinStream, s: Shape) -> int: + def _write_part_types_to_byte_stream(b_io: WriteableBinStream, s: Shape) -> int: return b_io.write(pack(f"<{len(s.partTypes)}i", *s.partTypes)) @@ -1367,7 +1369,7 @@ def _set_single_point_m_from_byte_stream( @staticmethod def _write_single_point_m_to_byte_stream( - b_io: WriteSeekableBinStream, s: Shape, i: int + b_io: WriteableBinStream, s: Shape, i: int ) -> int: # Write a single M value # Note: missing m values are autoset to NODATA. 
@@ -1431,12 +1433,12 @@ class PointZ(PointM): # same default as in Writer.__shpRecord (if s.shapeType == 11:) z: Sequence[float] = (0.0,) - def _set_single_point_z_from_byte_stream(self, b_io: ReadSeekableBinStream): + def _set_single_point_z_from_byte_stream(self, b_io: ReadableBinStream): self.z = tuple(unpack(" int: # Note: missing z values are autoset to 0, but not sure if this is ideal. z: float = 0.0 From e1a11a72a50e7422d114376eba585931f45373f6 Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Wed, 30 Jul 2025 20:41:52 +0100 Subject: [PATCH 07/20] Type hint Reader.mbox and dbf field header code --- src/shapefile.py | 50 ++++++++++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index 66669d6..ab48210 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -2176,17 +2176,17 @@ def __shpHeader(self): shp.seek(32) self.shapeType = unpack(" NODATA: - self.mbox.append(m) - else: - self.mbox.append(None) + # Measure values less than -10e38 are nodata values according to the spec + + self.mbox: tuple[Optional[float], Optional[float]] + for i, m_bound in enumerate(unpack("<2d", shp.read(16))): + self.mbox[i] = m_bound if m_bound < NODATA else None + + def __shape( self, oid: Optional[int] = None, bbox: Optional[BBox] = None @@ -2240,13 +2240,19 @@ def __shxOffsets(self): raise ShapefileException( "Shapefile Reader requires a shapefile or file-like object. (no shx file found" ) + if self.numShapes is None: + raise ShapefileException( + "numShapes must not be None. " + " Was there a problem with .__shxHeader() ?" + f"Got: {self.numShapes=}" + ) # Jump to the first record. 
shx.seek(100) # Each index record consists of two nrs, we only want the first one shxRecords = _Array[int]("i", shx.read(2 * self.numShapes * 4)) if sys.byteorder != "big": shxRecords.byteswap() - self._offsets: list[int] = [2 * el for el in shxRecords[::2]] + self._offsets = [2 * el for el in shxRecords[::2]] def __shapeIndex(self, i: Optional[int] = None) -> Optional[int]: """Returns the offset in a .shp file for a shape based on information @@ -2366,18 +2372,20 @@ def __dbfHeader(self): # read fields numFields = (self.__dbfHdrLength - 33) // 32 for __field in range(numFields): - fieldDesc = list(unpack("<11sc4xBB14x", dbf.read(32))) - name = 0 - idx = 0 - if b"\x00" in fieldDesc[name]: - idx = fieldDesc[name].index(b"\x00") + encoded_field_tuple: tuple[bytes,bytes,int,int] = unpack("<11sc4xBB14x", dbf.read(32)) + encoded_name, encoded_field_type_char, size, decimal = encoded_field_tuple + + if b"\x00" in encoded_name: + idx = encoded_name.index(b"\x00") else: - idx = len(fieldDesc[name]) - 1 - fieldDesc[name] = fieldDesc[name][:idx] - fieldDesc[name] = u(fieldDesc[name], self.encoding, self.encodingErrors) - fieldDesc[name] = fieldDesc[name].lstrip() - fieldDesc[1] = u(fieldDesc[1], "ascii") - self.fields.append(fieldDesc) + idx = len(encoded_name) - 1 + encoded_name = encoded_name[:idx] + field_name = u(encoded_name, self.encoding, self.encodingErrors) + field_name = field_name.lstrip() + + field_type_char = u(encoded_field_type_char, "ascii") + + self.fields.append((field_name, field_type_char, size, decimal)) terminator = dbf.read(1) if terminator != b"\r": raise ShapefileException( From 9661cf028c4e7d634f5a4fe19d6f0906374cd968 Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Wed, 30 Jul 2025 20:55:24 +0100 Subject: [PATCH 08/20] Change output in Readme to reflect field data is now a list of tuples --- README.md | 38 +++++++++++++++++++------------------- src/shapefile.py | 10 +++++----- 2 files changed, 24 
insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index a0abc6a..ff7da81 100644 --- a/README.md +++ b/README.md @@ -553,7 +553,7 @@ in the shp geometry file and the dbf attribute file. The field names of a shapefile are available as soon as you read a shapefile. You can call the "fields" attribute of the shapefile as a Python list. Each -field is a Python list with the following information: +field is a Python tuple with the following information: * Field name: the name describing the data at this column index. * Field type: the type of data at this column index. Types can be: @@ -574,24 +574,24 @@ attribute: >>> fields = sf.fields - >>> assert fields == [("DeletionFlag", "C", 1, 0), ["AREA", "N", 18, 5], - ... ["BKG_KEY", "C", 12, 0], ["POP1990", "N", 9, 0], ["POP90_SQMI", "N", 10, 1], - ... ["HOUSEHOLDS", "N", 9, 0], - ... ["MALES", "N", 9, 0], ["FEMALES", "N", 9, 0], ["WHITE", "N", 9, 0], - ... ["BLACK", "N", 8, 0], ["AMERI_ES", "N", 7, 0], ["ASIAN_PI", "N", 8, 0], - ... ["OTHER", "N", 8, 0], ["HISPANIC", "N", 8, 0], ["AGE_UNDER5", "N", 8, 0], - ... ["AGE_5_17", "N", 8, 0], ["AGE_18_29", "N", 8, 0], ["AGE_30_49", "N", 8, 0], - ... ["AGE_50_64", "N", 8, 0], ["AGE_65_UP", "N", 8, 0], - ... ["NEVERMARRY", "N", 8, 0], ["MARRIED", "N", 9, 0], ["SEPARATED", "N", 7, 0], - ... ["WIDOWED", "N", 8, 0], ["DIVORCED", "N", 8, 0], ["HSEHLD_1_M", "N", 8, 0], - ... ["HSEHLD_1_F", "N", 8, 0], ["MARHH_CHD", "N", 8, 0], - ... ["MARHH_NO_C", "N", 8, 0], ["MHH_CHILD", "N", 7, 0], - ... ["FHH_CHILD", "N", 7, 0], ["HSE_UNITS", "N", 9, 0], ["VACANT", "N", 7, 0], - ... ["OWNER_OCC", "N", 8, 0], ["RENTER_OCC", "N", 8, 0], - ... ["MEDIAN_VAL", "N", 7, 0], ["MEDIANRENT", "N", 4, 0], - ... ["UNITS_1DET", "N", 8, 0], ["UNITS_1ATT", "N", 7, 0], ["UNITS2", "N", 7, 0], - ... ["UNITS3_9", "N", 8, 0], ["UNITS10_49", "N", 8, 0], - ... ["UNITS50_UP", "N", 8, 0], ["MOBILEHOME", "N", 7, 0]] + >>> assert fields == [("DeletionFlag", "C", 1, 0), ("AREA", "N", 18, 5), + ... 
("BKG_KEY", "C", 12, 0), ("POP1990", "N", 9, 0), ("POP90_SQMI", "N", 10, 1), + ... ("HOUSEHOLDS", "N", 9, 0), + ... ("MALES", "N", 9, 0), ("FEMALES", "N", 9, 0), ("WHITE", "N", 9, 0), + ... ("BLACK", "N", 8, 0), ("AMERI_ES", "N", 7, 0), ("ASIAN_PI", "N", 8, 0), + ... ("OTHER", "N", 8, 0), ("HISPANIC", "N", 8, 0), ("AGE_UNDER5", "N", 8, 0), + ... ("AGE_5_17", "N", 8, 0), ("AGE_18_29", "N", 8, 0), ("AGE_30_49", "N", 8, 0), + ... ("AGE_50_64", "N", 8, 0), ("AGE_65_UP", "N", 8, 0), + ... ("NEVERMARRY", "N", 8, 0), ("MARRIED", "N", 9, 0), ("SEPARATED", "N", 7, 0), + ... ("WIDOWED", "N", 8, 0), ("DIVORCED", "N", 8, 0), ("HSEHLD_1_M", "N", 8, 0), + ... ("HSEHLD_1_F", "N", 8, 0), ("MARHH_CHD", "N", 8, 0), + ... ("MARHH_NO_C", "N", 8, 0), ("MHH_CHILD", "N", 7, 0), + ... ("FHH_CHILD", "N", 7, 0), ("HSE_UNITS", "N", 9, 0), ("VACANT", "N", 7, 0), + ... ("OWNER_OCC", "N", 8, 0), ("RENTER_OCC", "N", 8, 0), + ... ("MEDIAN_VAL", "N", 7, 0), ("MEDIANRENT", "N", 4, 0), + ... ("UNITS_1DET", "N", 8, 0), ("UNITS_1ATT", "N", 7, 0), ("UNITS2", "N", 7, 0), + ... ("UNITS3_9", "N", 8, 0), ("UNITS10_49", "N", 8, 0), + ... ("UNITS50_UP", "N", 8, 0), ("MOBILEHOME", "N", 7, 0)] The first field of a dbf file is always a 1-byte field called "DeletionFlag", which indicates records that have been deleted but not removed. 
However, diff --git a/src/shapefile.py b/src/shapefile.py index ab48210..2ee5cdf 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -2182,10 +2182,10 @@ def __shpHeader(self): # Measure # Measure values less than -10e38 are nodata values according to the spec - self.mbox: tuple[Optional[float], Optional[float]] - for i, m_bound in enumerate(unpack("<2d", shp.read(16))): - self.mbox[i] = m_bound if m_bound < NODATA else None - + self.mbox: tuple[Optional[float], Optional[float]] = tuple( + m_bound if m_bound >= NODATA else None + for m_bound in unpack("<2d", shp.read(16)) + ) def __shape( @@ -2374,7 +2374,7 @@ def __dbfHeader(self): for __field in range(numFields): encoded_field_tuple: tuple[bytes,bytes,int,int] = unpack("<11sc4xBB14x", dbf.read(32)) encoded_name, encoded_field_type_char, size, decimal = encoded_field_tuple - + if b"\x00" in encoded_name: idx = encoded_name.index(b"\x00") else: From db0e99c5e7f634d781d8f9a64ea0c66419d49315 Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Wed, 30 Jul 2025 20:57:49 +0100 Subject: [PATCH 09/20] Change rbox and mbox in Readme doctests to tuples --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ff7da81..67ecb4c 100644 --- a/README.md +++ b/README.md @@ -1375,7 +1375,7 @@ Shapefiles containing M-values can be examined in several ways: >>> r = shapefile.Reader('shapefiles/test/linem') >>> r.mbox # the lower and upper bound of M-values in the shapefile - [0.0, 3.0] + (0.0, 3.0) >>> r.shape(0).m # flat list of M-values [0.0, None, 3.0, None, 0.0, None, None] @@ -1408,7 +1408,7 @@ To examine a Z-type shapefile you can do: >>> r = shapefile.Reader('shapefiles/test/linez') >>> r.zbox # the lower and upper bound of Z-values in the shapefile - [0.0, 22.0] + (0.0, 22.0) >>> r.shape(0).z # flat list of Z-values [18.0, 20.0, 22.0, 0.0, 0.0, 0.0, 0.0, 15.0, 13.0, 14.0] From 3ae51cc9a4b6492cde12a62d6184b0872666190a Mon 
Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Wed, 30 Jul 2025 21:40:55 +0100 Subject: [PATCH 10/20] Make FieldData a NamedTuple --- README.md | 4 +- src/shapefile.py | 106 ++++++++++++++++++++++++++--------------------- 2 files changed, 61 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 67ecb4c..a0e8f44 100644 --- a/README.md +++ b/README.md @@ -919,8 +919,8 @@ You can also add attributes using keyword arguments where the keys are field nam >>> w = shapefile.Writer('shapefiles/test/dtype') - >>> w.field('FIRST_FLD','C','40') - >>> w.field('SECOND_FLD','C','40') + >>> w.field('FIRST_FLD','C', 40) + >>> w.field('SECOND_FLD','C', 40) >>> w.null() >>> w.null() >>> w.record('First', 'Line') diff --git a/src/shapefile.py b/src/shapefile.py index 2ee5cdf..08f0f05 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -29,6 +29,7 @@ Iterable, Iterator, Literal, + NamedTuple, Optional, Protocol, Reversible, @@ -156,7 +157,14 @@ def read(self, size: int = -1): ... BinaryFileT = Union[str, IO[bytes]] BinaryFileStreamT = Union[IO[bytes], io.BytesIO, WriteSeekableBinStream] -FieldTuple = tuple[str, str, int, int] + +class FieldData(NamedTuple): + name: str + fieldType: str + size: int + decimal: int + + RecordValue = Union[ bool, int, float, str, date ] # A Possible value in a Shapefile record, e.g. 
L, N, F, C, D types @@ -1776,7 +1784,7 @@ def __init__( self.shpLength: Optional[int] = None self.numRecords: Optional[int] = None self.numShapes: Optional[int] = None - self.fields: list[FieldTuple] = [] + self.fields: list[FieldData] = [] self.__dbfHdrLength = 0 self.__fieldLookup: dict[str, int] = {} self.encoding = encoding @@ -2181,12 +2189,11 @@ def __shpHeader(self): self.zbox: ZBox = tuple(unpack("<2d", shp.read(16))) # Measure # Measure values less than -10e38 are nodata values according to the spec - - self.mbox: tuple[Optional[float], Optional[float]] = tuple( - m_bound if m_bound >= NODATA else None + m_bounds = [ + float(m_bound) if m_bound >= NODATA else None for m_bound in unpack("<2d", shp.read(16)) - ) - + ] + self.mbox = tuple(m_bounds[:2]) def __shape( self, oid: Optional[int] = None, bbox: Optional[BBox] = None @@ -2372,20 +2379,22 @@ def __dbfHeader(self): # read fields numFields = (self.__dbfHdrLength - 33) // 32 for __field in range(numFields): - encoded_field_tuple: tuple[bytes,bytes,int,int] = unpack("<11sc4xBB14x", dbf.read(32)) - encoded_name, encoded_field_type_char, size, decimal = encoded_field_tuple + encoded_field_tuple: tuple[bytes, bytes, int, int] = unpack( + "<11sc4xBB14x", dbf.read(32) + ) + encoded_name, encoded_type_char, size, decimal = encoded_field_tuple if b"\x00" in encoded_name: idx = encoded_name.index(b"\x00") else: idx = len(encoded_name) - 1 encoded_name = encoded_name[:idx] - field_name = u(encoded_name, self.encoding, self.encodingErrors) - field_name = field_name.lstrip() + name = u(encoded_name, self.encoding, self.encodingErrors) + name = name.lstrip() - field_type_char = u(encoded_field_type_char, "ascii") + type_char = u(encoded_type_char, "ascii") - self.fields.append((field_name, field_type_char, size, decimal)) + self.fields.append(FieldData(name, type_char, size, decimal)) terminator = dbf.read(1) if terminator != b"\r": raise ShapefileException( @@ -2393,7 +2402,7 @@ def __dbfHeader(self): ) # insert 
deletion field at start - self.fields.insert(0, ("DeletionFlag", "C", 1, 0)) + self.fields.insert(0, FieldData("DeletionFlag", "C", 1, 0)) # store all field positions for easy lookups # note: fieldLookup gives the index position of a field inside Reader.fields @@ -2434,7 +2443,7 @@ def __recordFmt(self, fields: Optional[Container[str]] = None) -> tuple[str, int def __recordFields( self, fields: Optional[Iterable[str]] = None - ) -> tuple[list[FieldTuple], dict[str, int], Struct]: + ) -> tuple[list[FieldData], dict[str, int], Struct]: """Returns the necessary info required to unpack a record's fields, restricted to a subset of fieldnames 'fields' if specified. Returns a list of field info tuples, a name-index lookup dict, @@ -2469,7 +2478,7 @@ def __recordFields( def __record( self, - fieldTuples: list[FieldTuple], + fieldTuples: list[FieldData], recLookup: dict[str, int], recStruct: Struct, oid: Optional[int] = None, @@ -2734,7 +2743,7 @@ def __init__( ): self.target = target self.autoBalance = autoBalance - self.fields: list[FieldTuple] = [] + self.fields: list[FieldData] = [] self.shapeType = shapeType self.shp: Optional[WriteSeekableBinStream] = None self.shx: Optional[WriteSeekableBinStream] = None @@ -2829,6 +2838,8 @@ def close(self): # Flush files for attribute in (self.shp, self.shx, self.dbf): + if attribute is None: + continue if hasattr(attribute, "flush") and not getattr(attribute, "closed", False): try: attribute.flush() @@ -2868,20 +2879,30 @@ def __getFileObj(self, f): def __shpFileLength(self): """Calculates the file length of the shp file.""" + shp = self.__getFileObj(self.shp) + # Remember starting position - start = self.shp.tell() + + start = shp.tell() # Calculate size of all shapes - self.shp.seek(0, 2) - size = self.shp.tell() + shp.seek(0, 2) + size = shp.tell() # Calculate size as 16-bit words size //= 2 # Return to start - self.shp.seek(start) + shp.seek(start) return size - def __bbox(self, s): - x = [] - y = [] + def __bbox(self, s: 
Shape): + x: list[float] = [] + y: list[float] = [] + + if self._bbox: + x.append(self._bbox[0]) + y.append(self._bbox[1]) + x.append(self._bbox[2]) + y.append(self._bbox[3]) + if len(s.points) > 0: px, py = list(zip(*s.points))[:2] x.extend(px) @@ -2894,20 +2915,8 @@ def __bbox(self, s): "Cannot create bbox. Expected a valid shape with at least one point. " f"Got a shape of type '{s.shapeType}' and 0 points." ) - bbox = [min(x), min(y), max(x), max(y)] - # update global - if self._bbox: - # compare with existing - self._bbox = [ - min(bbox[0], self._bbox[0]), - min(bbox[1], self._bbox[1]), - max(bbox[2], self._bbox[2]), - max(bbox[3], self._bbox[3]), - ] - else: - # first time bbox is being set - self._bbox = bbox - return bbox + self._bbox = (min(x), min(y), max(x), max(y)) + return self._bbox def __zbox(self, s) -> ZBox: z: list[float] = [] @@ -3057,7 +3066,7 @@ def __dbfHeader(self): raise ShapefileException( "Shapefile dbf header length exceeds maximum length." ) - recordLength = sum(int(field[2]) for field in fields) + 1 + recordLength = sum(field.size for field in fields) + 1 header = pack( " Date: Wed, 30 Jul 2025 23:08:23 +0100 Subject: [PATCH 11/20] Tackle dbf! Delete u() --- src/shapefile.py | 192 +++++++++++++++++++++++------------------------ 1 file changed, 92 insertions(+), 100 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index 08f0f05..ed91d64 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -43,7 +43,7 @@ from urllib.parse import urlparse, urlunparse from urllib.request import Request, urlopen -from typing_extensions import Never, NotRequired, Self, TypeIs +from typing_extensions import Never, NotRequired, Self, TypeIs, Unpack # Create named logger logger = logging.getLogger(__name__) @@ -267,28 +267,12 @@ class GeoJSONFeatureCollectionWithBBox(GeoJSONFeatureCollection): # Helpers -MISSING = [None, ""] +MISSING = {None, ""} NODATA = -10e38 # as per the ESRI shapefile spec, only used for m-values. 
unpack_2_int32_be = Struct(">2i").unpack -def b( - v: Union[str, bytes], encoding: str = "utf-8", encodingErrors: str = "strict" -) -> bytes: - if isinstance(v, str): - # For python 3 encode str to bytes. - return v.encode(encoding, encodingErrors) - if isinstance(v, bytes): - # Already bytes. - return v - if v is None: - # Since we're dealing with text, interpret None as "" - return b"" - # Force string representation. - return str(v).encode(encoding, encodingErrors) - - def u( v: Union[str, bytes], encoding: str = "utf-8", encodingErrors: str = "strict" ) -> str: @@ -2169,7 +2153,7 @@ def __restrictIndex(self, i: int) -> int: i = range(self.numRecords)[i] return i - def __shpHeader(self): + def __shpHeader(self) -> None: """Reads the header information from a .shp file.""" if not self.shp: raise ShapefileException( @@ -2362,7 +2346,7 @@ def iterShapes(self, bbox: Optional[BBox] = None) -> Iterator[Optional[Shape]]: self.numShapes = i self._offsets = offsets - def __dbfHeader(self): + def __dbfHeader(self) -> None: """Reads a dbf header. 
Xbase-related code borrows heavily from ActiveState Python Cookbook Recipe 362715 by Raymond Hettinger""" if not self.dbf: @@ -2389,10 +2373,10 @@ def __dbfHeader(self): else: idx = len(encoded_name) - 1 encoded_name = encoded_name[:idx] - name = u(encoded_name, self.encoding, self.encodingErrors) + name = encoded_name.decode(self.encoding, self.encodingErrors) name = name.lstrip() - type_char = u(encoded_type_char, "ascii") + type_char = encoded_type_char.decode("ascii") self.fields.append(FieldData(name, type_char, size, decimal)) terminator = dbf.read(1) @@ -2422,14 +2406,14 @@ def __recordFmt(self, fields: Optional[Container[str]] = None) -> tuple[str, int """ if self.numRecords is None: self.__dbfHeader() - structcodes = [f"{fieldinfo[2]}s" for fieldinfo in self.fields] + structcodes = [f"{fieldinfo.size}s" for fieldinfo in self.fields] if fields is not None: # only unpack specified fields, ignore others using padbytes (x) structcodes = [ code - if fieldinfo[0] in fields - or fieldinfo[0] == "DeletionFlag" # always unpack delflag - else f"{fieldinfo[2]}x" + if fieldinfo.name in fields + or fieldinfo.name == "DeletionFlag" # always unpack delflag + else f"{fieldinfo.size}x" for fieldinfo, code in zip(self.fields, structcodes) ] fmt = "".join(structcodes) @@ -2484,11 +2468,14 @@ def __record( oid: Optional[int] = None, ) -> Optional[_Record]: """Reads and returns a dbf record row as a list of values. Requires specifying - a list of field info tuples 'fieldTuples', a record name-index dict 'recLookup', + a list of field info FieldData namedtuples 'fieldTuples', a record name-index dict 'recLookup', and a Struct instance 'recStruct' for unpacking these fields. 
""" f = self.__getFileObj(self.dbf) + # The only format chars in from self.__recordFmt, in recStruct from __recordFields, + # are s and x (ascii encoded str and pad byte) so everything in recordContents is bytes + # https://docs.python.org/3/library/struct.html#format-characters recordContents = recStruct.unpack(f.read(recStruct.size)) # deletion flag field is always unpacked as first value (see __recordFmt) @@ -2552,7 +2539,7 @@ def __record( value = date(y, m, d) except (TypeError, ValueError): # if invalid date, just return as unicode string so user can decide - value = u(value.strip()) + value = str(value.strip()) elif typ == "L": # logical: 1 byte - initialized to 0x20 (space) otherwise T or F. if value == b" ": @@ -2877,7 +2864,7 @@ def __getFileObj(self, f): return f raise ShapefileException(f"Unsupported file-like object: {f}") - def __shpFileLength(self): + def __shpFileLength(self) -> int: """Calculates the file length of the shp file.""" shp = self.__getFileObj(self.shp) @@ -2893,7 +2880,7 @@ def __shpFileLength(self): shp.seek(start) return size - def __bbox(self, s: Shape): + def __bbox(self, s: Shape) -> BBox: x: list[float] = [] y: list[float] = [] @@ -2964,25 +2951,25 @@ def __mbox(self, s) -> MBox: def shapeTypeName(self) -> str: return SHAPETYPE_LOOKUP[self.shapeType or 0] - def bbox(self): + def bbox(self) -> Optional[BBox]: """Returns the current bounding box for the shapefile which is the lower-left and upper-right corners. 
It does not contain the elevation or measure extremes.""" return self._bbox - def zbox(self): + def zbox(self) -> Optional[ZBox]: """Returns the current z extremes for the shapefile.""" return self._zbox - def mbox(self): + def mbox(self) -> Optional[MBox]: """Returns the current m extremes for the shapefile.""" return self._mbox def __shapefileHeader( self, fileObj: Optional[WriteSeekableBinStream], - headerType: str = "shp", - ): + headerType: Literal["shp", "dbf", "shx"] = "shp", + ) -> None: """Writes the specified header type to the specified file-like object. Several of the shapefile formats are so similar that a single generic method to read or write them is warranted.""" @@ -3009,7 +2996,7 @@ def __shapefileHeader( # In such cases of empty shapefiles, ESRI spec says the bbox values are 'unspecified'. # Not sure what that means, so for now just setting to 0s, which is the same behavior as in previous versions. # This would also make sense since the Z and M bounds are similarly set to 0 for non-Z/M type shapefiles. 
- bbox = [0, 0, 0, 0] + bbox = (0, 0, 0, 0) f.write(pack("<4d", *bbox)) except error: raise ShapefileException( @@ -3018,25 +3005,25 @@ def __shapefileHeader( else: f.write(pack("<4d", 0, 0, 0, 0)) # Elevation - if self.shapeType in (11, 13, 15, 18): + if self.shapeType in {POINTZ} | _HasZ._shapeTypes: # Z values are present in Z type zbox = self.zbox() if zbox is None: # means we have empty shapefile/only null geoms (see commentary on bbox above) - zbox = [0, 0] + zbox = (0, 0) else: # As per the ESRI shapefile spec, the zbox for non-Z type shapefiles are set to 0s - zbox = [0, 0] + zbox = (0, 0) # Measure - if self.shapeType in (11, 13, 15, 18, 21, 23, 25, 28, 31): + if self.shapeType in {POINTM, POINTZ} | _HasM._shapeTypes: # M values are present in M or Z type mbox = self.mbox() if mbox is None: # means we have empty shapefile/only null geoms (see commentary on bbox above) - mbox = [0, 0] + mbox = (0, 0) else: # As per the ESRI shapefile spec, the mbox for non-M type shapefiles are set to 0s - mbox = [0, 0] + mbox = (0, 0) # Try writing try: f.write(pack("<4d", zbox[0], zbox[1], mbox[0], mbox[1])) @@ -3045,7 +3032,7 @@ def __shapefileHeader( "Failed to write shapefile elevation and measure values. Floats required." 
) - def __dbfHeader(self): + def __dbfHeader(self) -> None: """Writes the dbf header and field descriptors.""" f = self.__getFileObj(self.dbf) f.seek(0) @@ -3081,10 +3068,10 @@ def __dbfHeader(self): # Field descriptors for field in fields: name, fieldType, size, decimal = field - encoded_name = b(name, self.encoding, self.encodingErrors) + encoded_name = name.encode(self.encoding, self.encodingErrors) encoded_name = encoded_name.replace(b" ", b"_") encoded_name = encoded_name[:10].ljust(11).replace(b" ", b"\x00") - encodedFieldType = b(fieldType, "ascii") + encodedFieldType = fieldType.encode("ascii") fld = pack("<11sc4xBB14x", encoded_name, encodedFieldType, size, decimal) f.write(fld) # Terminator @@ -3093,7 +3080,7 @@ def __dbfHeader(self): def shape( self, s: Union[Shape, HasGeoInterface, dict], - ): + ) -> None: # Balance if already not balanced if self.autoBalance and self.recNum < self.shpNum: self.balance() @@ -3178,7 +3165,7 @@ def __shpRecord(self, s: Shape) -> tuple[int, int]: f.write(b_io.read()) return offset, length - def __shxRecord(self, offset, length): + def __shxRecord(self, offset: int, length: int) -> None: """Writes the shx records.""" f = self.__getFileObj(self.shx) @@ -3191,8 +3178,10 @@ def __shxRecord(self, offset, length): f.write(pack(">i", length)) def record( - self, *recordList: Iterable[RecordValue], **recordDict: dict[str, RecordValue] - ): + self, + *recordList: list[RecordValue], + **recordDict: RecordValue, + ) -> None: """Creates a dbf attribute record. You can submit either a sequence of field values or keyword arguments of field names and values. 
Before adding records you must add fields for the record values using the @@ -3203,10 +3192,10 @@ def record( # Balance if already not balanced if self.autoBalance and self.recNum > self.shpNum: self.balance() - + record: list[RecordValue] fieldCount = sum(1 for field in self.fields if field[0] != "DeletionFlag") if recordList: - record = list(recordList) + record = list(*recordList) while len(record) < fieldCount: record.append("") elif recordDict: @@ -3227,7 +3216,7 @@ def record( record = ["" for _ in range(fieldCount)] self.__dbfRecord(record) - def __dbfRecord(self, record): + def __dbfRecord(self, record: list[RecordValue]) -> None: """Writes the dbf records.""" f = self.__getFileObj(self.dbf) if self.recNum == 0: @@ -3246,68 +3235,69 @@ def __dbfRecord(self, record): # write fieldType = fieldType.upper() size = int(size) - if fieldType in ("N", "F"): + str_val: str + if fieldType in {"N", "F"}: # numeric or float: number stored as a string, right justified, and padded with blanks to the width of the field. if value in MISSING: - value = b"*" * size # QGIS NULL + str_val = "*" * size # QGIS NULL elif not deci: # force to int try: # first try to force directly to int. # forcing a large int to float and back to int # will lose information and result in wrong nr. - value = int(value) + int_val = int(value) except ValueError: # forcing directly to int failed, so was probably a float. - value = int(float(value)) - value = format(value, "d")[:size].rjust( + int_val = int(float(value)) + str_val = format(int_val, "d")[:size].rjust( size ) # caps the size if exceeds the field size else: - value = float(value) - value = format(value, f".{deci}f")[:size].rjust( + f_val = float(value) + str_val = format(f_val, f".{deci}f")[:size].rjust( size ) # caps the size if exceeds the field size elif fieldType == "D": # date: 8 bytes - date stored as a string in the format YYYYMMDD. 
- if isinstance(value, date): - value = f"{value.year:04d}{value.month:02d}{value.day:02d}" - elif isinstance(value, list) and len(value) == 3: - value = f"{value[0]:04d}{value[1]:02d}{value[2]:02d}" - elif value in MISSING: - value = b"0" * 8 # QGIS NULL for date type + if value in MISSING: + str_val = "0" * 8 # QGIS NULL for date type + elif isinstance(value, date): + str_val = f"{value.year:04d}{value.month:02d}{value.day:02d}" + elif isinstance(value, (list, tuple)) and len(value) == 3: + str_val = f"{value[0]:04d}{value[1]:02d}{value[2]:02d}" elif is_string(value) and len(value) == 8: - pass # value is already a date string + str_val = value # value is already a date string else: raise ShapefileException( - "Date values must be either a datetime.date object, a list, a YYYYMMDD string, or a missing value." + "Date values must be either a datetime.date object, a list/tuple, a YYYYMMDD string, or a missing value." ) elif fieldType == "L": # logical: 1 byte - initialized to 0x20 (space) otherwise T or F. 
if value in MISSING: - value = b" " # missing is set to space - elif value in [True, 1]: - value = b"T" - elif value in [False, 0]: - value = b"F" + str_val = " " # missing is set to space + elif value in {True, 1}: + str_val = "T" + elif value in {False, 0}: + str_val = "F" else: - value = b" " # unknown is set to space + str_val = " " # unknown is set to space else: # anything else is forced to string, truncated to the length of the field - value = b(value, self.encoding, self.encodingErrors)[:size].ljust(size) - if not isinstance(value, bytes): - # just in case some of the numeric format() and date strftime() results are still in unicode (Python 3 only) - value = b( - value, "ascii", self.encodingErrors - ) # should be default ascii encoding - if len(value) != size: + str_val = u(value, self.encoding, self.encodingErrors)[:size].ljust( + size + ) + + # should be default ascii encoding + encoded_val = str_val.encode("ascii", self.encodingErrors) + if len(encoded_val) != size: raise ShapefileException( "Shapefile Writer unable to pack incorrect sized value" - f" (size {len(value)}) into field '{fieldName}' (size {size})." + f" (size {len(encoded_val)}) into field '{fieldName}' (size {size})." 
) - f.write(value) + f.write(encoded_val) - def balance(self): + def balance(self) -> None: """Adds corresponding empty attributes or null geometry records depending on which type of record was created to make sure all three files are in synch.""" @@ -3316,24 +3306,26 @@ def balance(self): while self.recNum < self.shpNum: self.record() - def null(self): + def null(self) -> None: """Creates a null shape.""" self.shape(NullShape()) - def point(self, x: float, y: float): + def point(self, x: float, y: float) -> None: """Creates a POINT shape.""" pointShape = Point() pointShape.points.append((x, y)) self.shape(pointShape) - def pointm(self, x: float, y: float, m: Optional[float] = None): + def pointm(self, x: float, y: float, m: Optional[float] = None) -> None: """Creates a POINTM shape. If the m (measure) value is not set, it defaults to NoData.""" pointShape = PointM() pointShape.points.append((x, y, m)) self.shape(pointShape) - def pointz(self, x: float, y: float, z: float = 0.0, m: Optional[float] = None): + def pointz( + self, x: float, y: float, z: float = 0.0, m: Optional[float] = None + ) -> None: """Creates a POINTZ shape. If the z (elevation) value is not set, it defaults to 0. If the m (measure) value is not set, it defaults to NoData.""" @@ -3341,20 +3333,20 @@ def pointz(self, x: float, y: float, z: float = 0.0, m: Optional[float] = None): pointShape.points.append((x, y, z, m)) self.shape(pointShape) - def multipoint(self, points: PointsT): + def multipoint(self, points: PointsT) -> None: """Creates a MULTIPOINT shape. Points is a list of xy values.""" # nest the points inside a list to be compatible with the generic shapeparts method self._shapeparts(parts=[points], polyShape=MultiPoint()) - def multipointm(self, points: PointsT): + def multipointm(self, points: PointsT) -> None: """Creates a MULTIPOINTM shape. Points is a list of xym values. 
If the m (measure) value is not included, it defaults to None (NoData).""" # nest the points inside a list to be compatible with the generic shapeparts method self._shapeparts(parts=[points], polyShape=MultiPointM()) - def multipointz(self, points: PointsT): + def multipointz(self, points: PointsT) -> None: """Creates a MULTIPOINTZ shape. Points is a list of xyzm values. If the z (elevation) value is not included, it defaults to 0. @@ -3362,32 +3354,32 @@ def multipointz(self, points: PointsT): # nest the points inside a list to be compatible with the generic shapeparts method self._shapeparts(parts=[points], polyShape=MultiPointZ()) - def line(self, lines: list[PointsT]): + def line(self, lines: list[PointsT]) -> None: """Creates a POLYLINE shape. Lines is a collection of lines, each made up of a list of xy values.""" self._shapeparts(parts=lines, polyShape=Polyline()) - def linem(self, lines: list[PointsT]): + def linem(self, lines: list[PointsT]) -> None: """Creates a POLYLINEM shape. Lines is a collection of lines, each made up of a list of xym values. If the m (measure) value is not included, it defaults to None (NoData).""" self._shapeparts(parts=lines, polyShape=PolylineM()) - def linez(self, lines: list[PointsT]): + def linez(self, lines: list[PointsT]) -> None: """Creates a POLYLINEZ shape. Lines is a collection of lines, each made up of a list of xyzm values. If the z (elevation) value is not included, it defaults to 0. If the m (measure) value is not included, it defaults to None (NoData).""" self._shapeparts(parts=lines, polyShape=PolylineZ()) - def poly(self, polys: list[PointsT]): + def poly(self, polys: list[PointsT]) -> None: """Creates a POLYGON shape. Polys is a collection of polygons, each made up of a list of xy values. Note that for ordinary polygons the coordinates must run in a clockwise direction. 
If some of the polygons are holes, these must run in a counterclockwise direction.""" self._shapeparts(parts=polys, polyShape=Polygon()) - def polym(self, polys: list[PointsT]): + def polym(self, polys: list[PointsT]) -> None: """Creates a POLYGONM shape. Polys is a collection of polygons, each made up of a list of xym values. Note that for ordinary polygons the coordinates must run in a clockwise direction. @@ -3395,7 +3387,7 @@ def polym(self, polys: list[PointsT]): If the m (measure) value is not included, it defaults to None (NoData).""" self._shapeparts(parts=polys, polyShape=PolygonM()) - def polyz(self, polys: list[PointsT]): + def polyz(self, polys: list[PointsT]) -> None: """Creates a POLYGONZ shape. Polys is a collection of polygons, each made up of a list of xyzm values. Note that for ordinary polygons the coordinates must run in a clockwise direction. @@ -3404,7 +3396,7 @@ def polyz(self, polys: list[PointsT]): If the m (measure) value is not included, it defaults to None (NoData).""" self._shapeparts(parts=polys, polyShape=PolygonZ()) - def multipatch(self, parts: list[PointsT], partTypes: list[int]): + def multipatch(self, parts: list[PointsT], partTypes: list[int]) -> None: """Creates a MULTIPATCH shape. Parts is a collection of 3D surface patches, each made up of a list of xyzm values. PartTypes is a list of types that define each of the surface patches. @@ -3431,7 +3423,7 @@ def multipatch(self, parts: list[PointsT], partTypes: list[int]): def _shapeparts( self, parts: list[PointsT], polyShape: Union[Polyline, Polygon, MultiPoint] - ): + ) -> None: """Internal method for adding a shape that has multiple collections of points (parts): lines, polygons, and multipoint shapes. 
""" @@ -3442,7 +3434,7 @@ def _shapeparts( # if shapeType in (5, 15, 25, 31): # This method is never actually called on a MultiPatch # so we omit its shapeType (31) for efficiency - if isinstance(polyShape, Polygon): + if compatible_with(polyShape, Polygon): for part in parts: if part[0] != part[-1]: part.append(part[0]) @@ -3466,7 +3458,7 @@ def field( fieldType: str = "C", size: int = 50, decimal: int = 0, - ): + ) -> None: """Adds a dbf field descriptor to the shapefile.""" if fieldType == "D": size = 8 From 532c02585954d951820abab3bec09c3ec0880a53 Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Thu, 31 Jul 2025 00:21:46 +0100 Subject: [PATCH 12/20] Need to type dbf coercer methods --- src/shapefile.py | 167 +++++++++++++++++++++++++++++------------------ 1 file changed, 102 insertions(+), 65 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index ed91d64..d260953 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -157,10 +157,20 @@ def read(self, size: int = -1): ... BinaryFileT = Union[str, IO[bytes]] BinaryFileStreamT = Union[IO[bytes], io.BytesIO, WriteSeekableBinStream] +FieldTypeCode = Literal["N", "F", "L", "D", "C", "M"] +FIELD_TYPE_CODES: dict[str, FieldTypeCode] = { + "N": "N", + "F": "F", + "L": "L", + "D": "D", + "C": "C", + "M": "M", +} + class FieldData(NamedTuple): name: str - fieldType: str + fieldType: FieldTypeCode size: int decimal: int @@ -289,10 +299,6 @@ def u( return bytes(v).decode(encoding, encodingErrors) -def is_string(v: Any) -> bool: - return isinstance(v, str) - - @overload def fsdecode_if_pathlike(path: os.PathLike) -> str: ... 
@overload @@ -1776,7 +1782,7 @@ def __init__( # See if a shapefile name was passed as the first argument if shapefile_path: path = fsdecode_if_pathlike(shapefile_path) - if is_string(path): + if isinstance(path, str): if ".zip" in path: # Shapefile is inside a zipfile if path.count(".zip") > 1: @@ -2376,7 +2382,9 @@ def __dbfHeader(self) -> None: name = encoded_name.decode(self.encoding, self.encodingErrors) name = name.lstrip() - type_char = encoded_type_char.decode("ascii") + type_char: FieldTypeCode = FIELD_TYPE_CODES.get( + encoded_type_char.decode("ascii").upper(), "C" + ) self.fields.append(FieldData(name, type_char, size, decimal)) terminator = dbf.read(1) @@ -2473,8 +2481,8 @@ def __record( """ f = self.__getFileObj(self.dbf) - # The only format chars in from self.__recordFmt, in recStruct from __recordFields, - # are s and x (ascii encoded str and pad byte) so everything in recordContents is bytes + # The only format chars in from self.__recordFmt, in recStruct from __recordFields, + # are s and x (ascii encoded str and pad byte) so everything in recordContents is bytes # https://docs.python.org/3/library/struct.html#format-characters recordContents = recStruct.unpack(f.read(recStruct.size)) @@ -2552,8 +2560,7 @@ def __record( else: value = None # unknown value is set to missing else: - # anything else is forced to string/unicode - value = u(value, self.encoding, self.encodingErrors) + value = value.decode(self.encoding, self.encodingErrors) value = value.strip().rstrip( "\x00" ) # remove null-padding at end of strings @@ -2738,7 +2745,7 @@ def __init__( self._files_to_close: list[BinaryFileStreamT] = [] if target: target = fsdecode_if_pathlike(target) - if not is_string(target): + if not isinstance(target, str): raise TypeError( f"The target filepath {target!r} must be of type str/unicode or path-like, not {type(target)}." 
) @@ -3216,6 +3223,65 @@ def record( record = ["" for _ in range(fieldCount)] self.__dbfRecord(record) + @staticmethod + def _dbf_missing_placeholder( + value: RecordValue, fieldType: FieldTypeCode, size: int + ) -> str: + if fieldType in {"N", "F"}: + return "*" * size # QGIS NULL + if fieldType == "D": + return "0" * 8 # QGIS NULL for date type + if fieldType == "L": + return " " + return str(value) + + @staticmethod + def _try_coerce_to_numeric_str(value: RecordValue, size: int, deci: int) -> str: + # numeric or float: number stored as a string, + # right justified, and padded with blanks + # to the width of the field. + if not deci: + # force to int + try: + # first try to force directly to int. + # forcing a large int to float and back to int + # will lose information and result in wrong nr. + int_val = int(value) + except ValueError: + # forcing directly to int failed, so was probably a float. + int_val = int(float(value)) + + str_val = format(int_val, "d") + else: + f_val = float(value) + str_val = format(f_val, f".{deci}f") + + # caps the size if exceeds the field size + return str_val[:size].rjust(size) + + @staticmethod + def _try_coerce_to_date_str(value: RecordValue) -> str: + # date: 8 bytes - date stored as a string in the format YYYYMMDD. + if isinstance(value, date): + return f"{value.year:04d}{value.month:02d}{value.day:02d}" + if isinstance(value, (list, tuple)) and len(value) == 3: + return f"{value[0]:04d}{value[1]:02d}{value[2]:02d}" + if isinstance(value, str) and len(value) == 8: + return value # value is already a date string + + raise ShapefileException( + "Date values must be either a datetime.date object, a list/tuple, a YYYYMMDD string, or a missing value." + ) + + @staticmethod + def _try_coerce_to_logical_str(value: RecordValue) -> str: + # logical: 1 byte - initialized to 0x20 (space) otherwise T or F. 
+ if value == 1: # True == 1 + return "T" + if value == 0: # False == 0 + return "F" + return " " # unknown is set to space + def __dbfRecord(self, record: list[RecordValue]) -> None: """Writes the dbf records.""" f = self.__getFileObj(self.dbf) @@ -3233,63 +3299,31 @@ def __dbfRecord(self, record: list[RecordValue]) -> None: ) # ignore deletionflag field in case it was specified for (fieldName, fieldType, size, deci), value in zip(fields, record): # write - fieldType = fieldType.upper() + fieldType = FIELD_TYPE_CODES.get(fieldType.upper(), fieldType) size = int(size) str_val: str + + if value in MISSING: + str_val = self._dbf_missing_placeholder(value, fieldType, size) if fieldType in {"N", "F"}: - # numeric or float: number stored as a string, right justified, and padded with blanks to the width of the field. - if value in MISSING: - str_val = "*" * size # QGIS NULL - elif not deci: - # force to int - try: - # first try to force directly to int. - # forcing a large int to float and back to int - # will lose information and result in wrong nr. - int_val = int(value) - except ValueError: - # forcing directly to int failed, so was probably a float. - int_val = int(float(value)) - str_val = format(int_val, "d")[:size].rjust( - size - ) # caps the size if exceeds the field size - else: - f_val = float(value) - str_val = format(f_val, f".{deci}f")[:size].rjust( - size - ) # caps the size if exceeds the field size + str_val = self._try_coerce_to_numeric_str(value, size, deci) elif fieldType == "D": - # date: 8 bytes - date stored as a string in the format YYYYMMDD. 
- if value in MISSING: - str_val = "0" * 8 # QGIS NULL for date type - elif isinstance(value, date): - str_val = f"{value.year:04d}{value.month:02d}{value.day:02d}" - elif isinstance(value, (list, tuple)) and len(value) == 3: - str_val = f"{value[0]:04d}{value[1]:02d}{value[2]:02d}" - elif is_string(value) and len(value) == 8: - str_val = value # value is already a date string - else: - raise ShapefileException( - "Date values must be either a datetime.date object, a list/tuple, a YYYYMMDD string, or a missing value." - ) + str_val = self._try_coerce_to_date_str(value) elif fieldType == "L": - # logical: 1 byte - initialized to 0x20 (space) otherwise T or F. - if value in MISSING: - str_val = " " # missing is set to space - elif value in {True, 1}: - str_val = "T" - elif value in {False, 0}: - str_val = "F" - else: - str_val = " " # unknown is set to space + str_val = self._try_coerce_to_logical_str(value) else: - # anything else is forced to string, truncated to the length of the field - str_val = u(value, self.encoding, self.encodingErrors)[:size].ljust( - size - ) - + # + if isinstance(value, bytes): + decoded_val = value.decode(self.encoding, self.encodingErrors) + else: + # anything else is forced to string. + decoded_val = str(value) + # Truncate to the length of the field + str_val = decoded_val[:size].ljust(size) + # should be default ascii encoding encoded_val = str_val.encode("ascii", self.encodingErrors) + if len(encoded_val) != size: raise ShapefileException( "Shapefile Writer unable to pack incorrect sized value" @@ -3455,7 +3489,7 @@ def field( # Types of args should match *FieldData self, name: str, - fieldType: str = "C", + fieldType: FieldTypeCode = "C", size: int = 50, decimal: int = 0, ) -> None: @@ -3466,15 +3500,18 @@ def field( elif fieldType == "L": size = 1 decimal = 0 + elif fieldType not in {"C", "N", "F", "M"}: + raise ShapefileException( + "fieldType must be C,N,F,M,L or D. " f"Got: {fieldType=}. 
" + ) if len(self.fields) >= 2046: raise ShapefileException( "Shapefile Writer reached maximum number of fields: 2046." ) # A doctest in README.md used to pass in a string ('40') for size, so # try to be robust for incorrect types. - self.fields.append( - FieldData(str(name), str(fieldType), int(size), int(decimal)) - ) + fieldType = FIELD_TYPE_CODES.get(str(fieldType)[0].upper(), fieldType) + self.fields.append(FieldData(str(name), fieldType, int(size), int(decimal))) # Begin Testing From bd729ff9fee84bded53e56f1bfd0c10ff8aad48b Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Thu, 31 Jul 2025 15:22:40 +0100 Subject: [PATCH 13/20] Make field code an enum, and BBox, ZBox & MBox named tuples --- src/shapefile.py | 188 +++++++++++++++++++++++++++++------------------ 1 file changed, 115 insertions(+), 73 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index d260953..92fe3a7 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -12,6 +12,7 @@ import array import doctest +import enum import io import logging import os @@ -37,6 +38,7 @@ TypedDict, TypeVar, Union, + cast, overload, ) from urllib.error import HTTPError @@ -121,9 +123,22 @@ PointT = Union[Point2D, PointMT, PointZT] PointsT = list[PointT] -BBox = tuple[float, float, float, float] -MBox = tuple[float, float] -ZBox = tuple[float, float] + +class BBox(NamedTuple): + xmin: float + ymin: float + xmax: float + ymax: float + + +class MBox(NamedTuple): + mmin: Optional[float] + mmax: Optional[float] + + +class ZBox(NamedTuple): + zmin: float + zmax: float class WriteableBinStream(Protocol): @@ -157,27 +172,57 @@ def read(self, size: int = -1): ... 
BinaryFileT = Union[str, IO[bytes]] BinaryFileStreamT = Union[IO[bytes], io.BytesIO, WriteSeekableBinStream] -FieldTypeCode = Literal["N", "F", "L", "D", "C", "M"] -FIELD_TYPE_CODES: dict[str, FieldTypeCode] = { - "N": "N", - "F": "F", - "L": "L", - "D": "D", - "C": "C", - "M": "M", -} + +# https://en.wikipedia.org/wiki/.dbf#Database_records +class FieldType(enum.Enum): + # Use an ascii-encoded byte of the name, to save a decoding step. + C = "Character" # (str) + D = "Date" + F = "Float" + L = "Logical" # (bool) + M = "Memo" # Legacy. (10 digit str, starting block in an .dbt file) + N = "Numeric" # (int) class FieldData(NamedTuple): name: str - fieldType: FieldTypeCode + fieldType: FieldType size: int decimal: int + @classmethod + def from_unchecked( + cls, + name: str, + fieldType: Union[str, FieldType] = FieldType.C, + size: int = 50, + decimal: int = 0, + ) -> Self: + if isinstance(fieldType, str): + try: + fieldType = FieldType[fieldType.upper()] + except: + raise ShapefileException( + "fieldType must be C,N,F,M,L,D or a FieldType enum member. " + f"Got: {fieldType=}. " + ) -RecordValue = Union[ - bool, int, float, str, date -] # A Possible value in a Shapefile record, e.g. L, N, F, C, D types + if fieldType is FieldType.D: + size = 8 + decimal = 0 + elif fieldType is FieldType.L: + size = 1 + decimal = 0 + + # A doctest in README.md previously passed in a string ('40') for size, + # so explictly convert name to str, and size and decimal to ints. + return cls(str(name), fieldType, int(size), int(decimal)) + + +RecordValueNotDate = Union[bool, int, float, str, date] + +# A Possible value in a Shapefile dbf record, i.e. 
L, N, M, F, C, or D types +RecordValue = Union[RecordValueNotDate, date] class HasGeoInterface(Protocol): @@ -356,7 +401,7 @@ def rewind(coords: Reversible[PointT]) -> PointsT: def ring_bbox(coords: PointsT) -> BBox: """Calculates and returns the bounding box of a ring.""" xs, ys = map(list, list(zip(*coords))[:2]) # ignore any z or m values - bbox = min(xs), min(ys), max(xs), max(ys) + bbox = BBox(xmin=min(xs), ymin=min(ys), xmax=max(xs), ymax=max(ys)) return bbox @@ -939,7 +984,7 @@ class _CanHaveBBox(Shape): bbox: Optional[BBox] = None def _get_set_bbox_from_byte_stream(self, b_io: ReadableBinStream) -> BBox: - self.bbox: BBox = tuple(_Array[float]("d", unpack("<4d", b_io.read(32)))) + self.bbox: BBox = BBox(*_Array[float]("d", unpack("<4d", b_io.read(32)))) return self.bbox @staticmethod @@ -1159,7 +1204,7 @@ def from_byte_stream( if bbox is not None: # create bounding box for Point by duplicating coordinates # skip shape if no overlap with bounding box - if not bbox_overlap(bbox, (x, y, x, y)): + if not bbox_overlap(bbox, BBox(x, y, x, y)): return None shape.points = [(x, y)] @@ -2174,16 +2219,18 @@ def __shpHeader(self) -> None: shp.seek(32) self.shapeType = unpack("= NODATA else None for m_bound in unpack("<2d", shp.read(16)) ] - self.mbox = tuple(m_bounds[:2]) + self.mbox = MBox(mmin=m_bounds[0], mmax=m_bounds[1]) def __shape( self, oid: Optional[int] = None, bbox: Optional[BBox] = None @@ -2382,11 +2429,9 @@ def __dbfHeader(self) -> None: name = encoded_name.decode(self.encoding, self.encodingErrors) name = name.lstrip() - type_char: FieldTypeCode = FIELD_TYPE_CODES.get( - encoded_type_char.decode("ascii").upper(), "C" - ) + field_type = FieldType[encoded_type_char.decode("ascii").upper()] - self.fields.append(FieldData(name, type_char, size, decimal)) + self.fields.append(FieldData(name, field_type, size, decimal)) terminator = dbf.read(1) if terminator != b"\r": raise ShapefileException( @@ -2394,7 +2439,7 @@ def __dbfHeader(self) -> None: ) # insert 
deletion field at start - self.fields.insert(0, FieldData("DeletionFlag", "C", 1, 0)) + self.fields.insert(0, FieldData("DeletionFlag", FieldType.C, 1, 0)) # store all field positions for easy lookups # note: fieldLookup gives the index position of a field inside Reader.fields @@ -2892,10 +2937,10 @@ def __bbox(self, s: Shape) -> BBox: y: list[float] = [] if self._bbox: - x.append(self._bbox[0]) - y.append(self._bbox[1]) - x.append(self._bbox[2]) - y.append(self._bbox[3]) + x.append(self._bbox.xmin) + y.append(self._bbox.ymin) + x.append(self._bbox.xmax) + y.append(self._bbox.ymax) if len(s.points) > 0: px, py = list(zip(*s.points))[:2] @@ -2909,7 +2954,7 @@ def __bbox(self, s: Shape) -> BBox: "Cannot create bbox. Expected a valid shape with at least one point. " f"Got a shape of type '{s.shapeType}' and 0 points." ) - self._bbox = (min(x), min(y), max(x), max(y)) + self._bbox = BBox(xmin=min(x), ymin=min(y), xmax=max(x), ymax=max(y)) return self._bbox def __zbox(self, s) -> ZBox: @@ -2927,14 +2972,14 @@ def __zbox(self, s) -> ZBox: # Original self._zbox bounds (if any) are the first two entries. # Set zbox for the first, and all later times - self._zbox = (min(z), max(z)) + self._zbox = ZBox(zmin=min(z), zmax=max(z)) return self._zbox def __mbox(self, s) -> MBox: mpos = 3 if s.shapeType in _HasZ._shapeTypes else 2 m: list[float] = [] if self._mbox: - m.extend(self._mbox) + m.extend(m_bound for m_bound in self._mbox if m_bound is not None) for p in s.points: try: @@ -2951,7 +2996,7 @@ def __mbox(self, s) -> MBox: # Original self._mbox bounds (if any) are the first two entries. # Set mbox for the first, and all later times - self._mbox = (min(m), max(m)) + self._mbox = MBox(mmin=min(m), mmax=max(m)) return self._mbox @property @@ -3003,7 +3048,7 @@ def __shapefileHeader( # In such cases of empty shapefiles, ESRI spec says the bbox values are 'unspecified'. # Not sure what that means, so for now just setting to 0s, which is the same behavior as in previous versions. 
# This would also make sense since the Z and M bounds are similarly set to 0 for non-Z/M type shapefiles. - bbox = (0, 0, 0, 0) + bbox = BBox(0, 0, 0, 0) f.write(pack("<4d", *bbox)) except error: raise ShapefileException( @@ -3017,20 +3062,20 @@ def __shapefileHeader( zbox = self.zbox() if zbox is None: # means we have empty shapefile/only null geoms (see commentary on bbox above) - zbox = (0, 0) + zbox = ZBox(0, 0) else: # As per the ESRI shapefile spec, the zbox for non-Z type shapefiles are set to 0s - zbox = (0, 0) + zbox = ZBox(0, 0) # Measure if self.shapeType in {POINTM, POINTZ} | _HasM._shapeTypes: # M values are present in M or Z type mbox = self.mbox() if mbox is None: # means we have empty shapefile/only null geoms (see commentary on bbox above) - mbox = (0, 0) + mbox = MBox(0, 0) else: # As per the ESRI shapefile spec, the mbox for non-M type shapefiles are set to 0s - mbox = (0, 0) + mbox = MBox(0, 0) # Try writing try: f.write(pack("<4d", zbox[0], zbox[1], mbox[0], mbox[1])) @@ -3078,7 +3123,7 @@ def __dbfHeader(self) -> None: encoded_name = name.encode(self.encoding, self.encodingErrors) encoded_name = encoded_name.replace(b" ", b"_") encoded_name = encoded_name[:10].ljust(11).replace(b" ", b"\x00") - encodedFieldType = fieldType.encode("ascii") + encodedFieldType = fieldType.name.encode("ascii") fld = pack("<11sc4xBB14x", encoded_name, encodedFieldType, size, decimal) f.write(fld) # Terminator @@ -3225,18 +3270,26 @@ def record( @staticmethod def _dbf_missing_placeholder( - value: RecordValue, fieldType: FieldTypeCode, size: int + value: RecordValue, fieldType: FieldType, size: int ) -> str: - if fieldType in {"N", "F"}: + if fieldType in {FieldType.N, FieldType.F}: return "*" * size # QGIS NULL - if fieldType == "D": + if fieldType is FieldType.D: return "0" * 8 # QGIS NULL for date type - if fieldType == "L": + if fieldType is FieldType.L: return " " return str(value) + @overload + @staticmethod + def _try_coerce_to_numeric_str(value: date, size: 
int, deci: int) -> Never: ... + @overload @staticmethod - def _try_coerce_to_numeric_str(value: RecordValue, size: int, deci: int) -> str: + def _try_coerce_to_numeric_str( + value: RecordValueNotDate, size: int, deci: int + ) -> str: ... + @staticmethod + def _try_coerce_to_numeric_str(value, size, deci): # numeric or float: number stored as a string, # right justified, and padded with blanks # to the width of the field. @@ -3247,17 +3300,20 @@ def _try_coerce_to_numeric_str(value: RecordValue, size: int, deci: int) -> str: # forcing a large int to float and back to int # will lose information and result in wrong nr. int_val = int(value) - except ValueError: + except (ValueError, TypeError): # forcing directly to int failed, so was probably a float. int_val = int(float(value)) + except TypeError: + raise ShapefileException(f"Could not form int from: {value}") + # length capped to the field size + return format(int_val, "d")[:size].rjust(size) - str_val = format(int_val, "d") - else: + try: f_val = float(value) - str_val = format(f_val, f".{deci}f") - - # caps the size if exceeds the field size - return str_val[:size].rjust(size) + except ValueError: + raise ShapefileException(f"Could not form float from: {value}") + # length capped to the field size + return format(f_val, f".{deci}f")[:size].rjust(size) @staticmethod def _try_coerce_to_date_str(value: RecordValue) -> str: @@ -3299,20 +3355,18 @@ def __dbfRecord(self, record: list[RecordValue]) -> None: ) # ignore deletionflag field in case it was specified for (fieldName, fieldType, size, deci), value in zip(fields, record): # write - fieldType = FIELD_TYPE_CODES.get(fieldType.upper(), fieldType) size = int(size) str_val: str if value in MISSING: str_val = self._dbf_missing_placeholder(value, fieldType, size) - if fieldType in {"N", "F"}: + elif fieldType in {FieldType.N, FieldType.F}: str_val = self._try_coerce_to_numeric_str(value, size, deci) - elif fieldType == "D": + elif fieldType is FieldType.D: str_val = 
self._try_coerce_to_date_str(value) - elif fieldType == "L": + elif fieldType is FieldType.L: str_val = self._try_coerce_to_logical_str(value) else: - # if isinstance(value, bytes): decoded_val = value.decode(self.encoding, self.encodingErrors) else: @@ -3489,29 +3543,17 @@ def field( # Types of args should match *FieldData self, name: str, - fieldType: FieldTypeCode = "C", + fieldType: Union[str, FieldType] = FieldType.C, size: int = 50, decimal: int = 0, ) -> None: """Adds a dbf field descriptor to the shapefile.""" - if fieldType == "D": - size = 8 - decimal = 0 - elif fieldType == "L": - size = 1 - decimal = 0 - elif fieldType not in {"C", "N", "F", "M"}: - raise ShapefileException( - "fieldType must be C,N,F,M,L or D. " f"Got: {fieldType=}. " - ) if len(self.fields) >= 2046: raise ShapefileException( "Shapefile Writer reached maximum number of fields: 2046." ) - # A doctest in README.md used to pass in a string ('40') for size, so - # try to be robust for incorrect types. - fieldType = FIELD_TYPE_CODES.get(str(fieldType)[0].upper(), fieldType) - self.fields.append(FieldData(str(name), fieldType, int(size), int(decimal))) + + self.fields.append(FieldData.from_unchecked(name, fieldType, size, decimal)) # Begin Testing From 4355660699de0c2fc1a12ecd3a4374d39d6010bd Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Thu, 31 Jul 2025 15:30:07 +0100 Subject: [PATCH 14/20] Remove unused imports and dedupe except TypeError: --- src/shapefile.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index 92fe3a7..3e89bc6 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -38,14 +38,13 @@ TypedDict, TypeVar, Union, - cast, overload, ) from urllib.error import HTTPError from urllib.parse import urlparse, urlunparse from urllib.request import Request, urlopen -from typing_extensions import Never, NotRequired, Self, TypeIs, Unpack +from typing_extensions 
import Never, NotRequired, Self, TypeIs # Create named logger logger = logging.getLogger(__name__) @@ -3119,12 +3118,17 @@ def __dbfHeader(self) -> None: f.write(header) # Field descriptors for field in fields: - name, fieldType, size, decimal = field - encoded_name = name.encode(self.encoding, self.encodingErrors) + encoded_name = field.name.encode(self.encoding, self.encodingErrors) encoded_name = encoded_name.replace(b" ", b"_") encoded_name = encoded_name[:10].ljust(11).replace(b" ", b"\x00") - encodedFieldType = fieldType.name.encode("ascii") - fld = pack("<11sc4xBB14x", encoded_name, encodedFieldType, size, decimal) + encodedFieldType = field.fieldType.name.encode("ascii") + fld = pack( + "<11sc4xBB14x", + encoded_name, + encodedFieldType, + field.size, + field.decimal, + ) f.write(fld) # Terminator f.write(b"\r") @@ -3300,7 +3304,7 @@ def _try_coerce_to_numeric_str(value, size, deci): # forcing a large int to float and back to int # will lose information and result in wrong nr. int_val = int(value) - except (ValueError, TypeError): + except ValueError: # forcing directly to int failed, so was probably a float. int_val = int(float(value)) except TypeError: From 0c7dfc3be7ba2534d77baf7f63045077e836a095 Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Thu, 31 Jul 2025 20:36:37 +0100 Subject: [PATCH 15/20] Define briefer Field.__repr__. Delete u() --- src/shapefile.py | 110 +++++++++++++++++++++-------------------------- 1 file changed, 49 insertions(+), 61 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index 3e89bc6..0c47007 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -177,15 +177,16 @@ class FieldType(enum.Enum): # Use an ascii-encoded byte of the name, to save a decoding step. C = "Character" # (str) D = "Date" - F = "Float" + F = "Floating point" L = "Logical" # (bool) M = "Memo" # Legacy. 
(10 digit str, starting block in a .dbt file)
     N = "Numeric"  # (int)
 
 
-class FieldData(NamedTuple):
+# The attribute is named field_type (rather than "type") to avoid shadowing the builtin.
+class Field(NamedTuple):
     name: str
-    fieldType: FieldType
+    field_type: FieldType
     size: int
     decimal: int
 
@@ -193,29 +194,34 @@ class FieldData(NamedTuple):
     def from_unchecked(
         cls,
         name: str,
-        fieldType: Union[str, FieldType] = FieldType.C,
+        field_type: Union[str, FieldType] = FieldType.C,
         size: int = 50,
         decimal: int = 0,
     ) -> Self:
-        if isinstance(fieldType, str):
-            try:
-                fieldType = FieldType[fieldType.upper()]
-            except:
+        if isinstance(field_type, str):
+            if field_type.upper() in FieldType.__members__:
+                field_type = FieldType[field_type.upper()]
+            else:
                 raise ShapefileException(
-                    "fieldType must be C,N,F,M,L,D or a FieldType enum member. "
-                    f"Got: {fieldType=}. "
+                    "field_type must be C,D,F,L,M,N, or a FieldType enum member. "
+                    f"Got: {field_type=}. "
                 )
-        if fieldType is FieldType.D:
+        if field_type is FieldType.D:
             size = 8
             decimal = 0
-        elif fieldType is FieldType.L:
+        elif field_type is FieldType.L:
             size = 1
             decimal = 0
         # A doctest in README.md previously passed in a string ('40') for size,
         # so explicitly convert name to str, and size and decimal to ints.
-        return cls(str(name), fieldType, int(size), int(decimal))
+        return cls(
+            name=str(name), field_type=field_type, size=int(size), decimal=int(decimal)
+        )
+
+    def __repr__(self) -> str:
+        return f'Field(name="{self.name}", field_type=FieldType.{self.field_type.name}, size={self.size}, decimal={self.decimal})'
 
 
 RecordValueNotDate = Union[bool, int, float, str, date]
@@ -327,22 +333,6 @@ class GeoJSONFeatureCollectionWithBBox(GeoJSONFeatureCollection):
 unpack_2_int32_be = Struct(">2i").unpack
 
 
-def u(
-    v: Union[str, bytes], encoding: str = "utf-8", encodingErrors: str = "strict"
-) -> str:
-    if isinstance(v, bytes):
-        # For python 3 decode bytes to str.
-        return v.decode(encoding, encodingErrors)
-    if isinstance(v, str):
-        # Already str. 
- return v - if v is None: - # Since we're dealing with text, interpret None as "" - return "" - # Force string representation. - return bytes(v).decode(encoding, encodingErrors) - - @overload def fsdecode_if_pathlike(path: os.PathLike) -> str: ... @overload @@ -1818,7 +1808,7 @@ def __init__( self.shpLength: Optional[int] = None self.numRecords: Optional[int] = None self.numShapes: Optional[int] = None - self.fields: list[FieldData] = [] + self.fields: list[Field] = [] self.__dbfHdrLength = 0 self.__fieldLookup: dict[str, int] = {} self.encoding = encoding @@ -2430,7 +2420,7 @@ def __dbfHeader(self) -> None: field_type = FieldType[encoded_type_char.decode("ascii").upper()] - self.fields.append(FieldData(name, field_type, size, decimal)) + self.fields.append(Field(name, field_type, size, decimal)) terminator = dbf.read(1) if terminator != b"\r": raise ShapefileException( @@ -2438,7 +2428,7 @@ def __dbfHeader(self) -> None: ) # insert deletion field at start - self.fields.insert(0, FieldData("DeletionFlag", FieldType.C, 1, 0)) + self.fields.insert(0, Field("DeletionFlag", FieldType.C, 1, 0)) # store all field positions for easy lookups # note: fieldLookup gives the index position of a field inside Reader.fields @@ -2479,7 +2469,7 @@ def __recordFmt(self, fields: Optional[Container[str]] = None) -> tuple[str, int def __recordFields( self, fields: Optional[Iterable[str]] = None - ) -> tuple[list[FieldData], dict[str, int], Struct]: + ) -> tuple[list[Field], dict[str, int], Struct]: """Returns the necessary info required to unpack a record's fields, restricted to a subset of fieldnames 'fields' if specified. Returns a list of field info tuples, a name-index lookup dict, @@ -2514,13 +2504,13 @@ def __recordFields( def __record( self, - fieldTuples: list[FieldData], + fieldTuples: list[Field], recLookup: dict[str, int], recStruct: Struct, oid: Optional[int] = None, ) -> Optional[_Record]: """Reads and returns a dbf record row as a list of values. 
Requires specifying
-        a list of field info FieldData namedtuples 'fieldTuples', a record name-index dict 'recLookup',
+        a list of field info Field namedtuples 'fieldTuples', a record name-index dict 'recLookup',
         and a Struct instance 'recStruct' for unpacking these fields.
         """
         f = self.__getFileObj(self.dbf)
@@ -2547,14 +2537,14 @@ def __record(
 
         # parse each value
         record = []
-        for (__name, typ, __size, deci), value in zip(fieldTuples, recordContents):
+        for (__name, typ, __size, decimal), value in zip(fieldTuples, recordContents):
             if typ in {"N", "F"}:
                 # numeric or float: number stored as a string, right justified, and padded with blanks to the width of the field.
                 value = value.split(b"\0")[0]
                 value = value.replace(b"*", b"")  # QGIS NULL is all '*' chars
                 if value == b"":
                     value = None
-                elif deci:
+                elif decimal:
                     try:
                         value = float(value)
                     except ValueError:
@@ -2590,7 +2580,7 @@ def __record(
                 y, m, d = int(value[:4]), int(value[4:6]), int(value[6:8])
                 value = date(y, m, d)
             except (TypeError, ValueError):
-                # if invalid date, just return as unicode string so user can decide
+                # if invalid date, just return as unicode string so user can decide
                 value = str(value.strip())
             elif typ == "L":
                 # logical: 1 byte - initialized to 0x20 (space) otherwise T or F. 
@@ -2781,7 +2771,7 @@ def __init__( ): self.target = target self.autoBalance = autoBalance - self.fields: list[FieldData] = [] + self.fields: list[Field] = [] self.shapeType = shapeType self.shp: Optional[WriteSeekableBinStream] = None self.shx: Optional[WriteSeekableBinStream] = None @@ -3121,7 +3111,7 @@ def __dbfHeader(self) -> None: encoded_name = field.name.encode(self.encoding, self.encodingErrors) encoded_name = encoded_name.replace(b" ", b"_") encoded_name = encoded_name[:10].ljust(11).replace(b" ", b"\x00") - encodedFieldType = field.fieldType.name.encode("ascii") + encodedFieldType = field.field_type.name.encode("ascii") fld = pack( "<11sc4xBB14x", encoded_name, @@ -3273,31 +3263,29 @@ def record( self.__dbfRecord(record) @staticmethod - def _dbf_missing_placeholder( - value: RecordValue, fieldType: FieldType, size: int - ) -> str: - if fieldType in {FieldType.N, FieldType.F}: + def _dbf_missing_placeholder(value: RecordValue, type: FieldType, size: int) -> str: + if type in {FieldType.N, FieldType.F}: return "*" * size # QGIS NULL - if fieldType is FieldType.D: + if type is FieldType.D: return "0" * 8 # QGIS NULL for date type - if fieldType is FieldType.L: + if type is FieldType.L: return " " return str(value) @overload @staticmethod - def _try_coerce_to_numeric_str(value: date, size: int, deci: int) -> Never: ... + def _try_coerce_to_numeric_str(value: date, size: int, decimal: int) -> Never: ... @overload @staticmethod def _try_coerce_to_numeric_str( - value: RecordValueNotDate, size: int, deci: int + value: RecordValueNotDate, size: int, decimal: int ) -> str: ... @staticmethod - def _try_coerce_to_numeric_str(value, size, deci): + def _try_coerce_to_numeric_str(value, size, decimal): # numeric or float: number stored as a string, # right justified, and padded with blanks # to the width of the field. - if not deci: + if not decimal: # force to int try: # first try to force directly to int. 
@@ -3317,7 +3305,7 @@ def _try_coerce_to_numeric_str(value, size, deci): except ValueError: raise ShapefileException(f"Could not form float from: {value}") # length capped to the field size - return format(f_val, f".{deci}f")[:size].rjust(size) + return format(f_val, f".{decimal}f")[:size].rjust(size) @staticmethod def _try_coerce_to_date_str(value: RecordValue) -> str: @@ -3357,18 +3345,18 @@ def __dbfRecord(self, record: list[RecordValue]) -> None: fields = ( field for field in self.fields if field[0] != "DeletionFlag" ) # ignore deletionflag field in case it was specified - for (fieldName, fieldType, size, deci), value in zip(fields, record): + for (fieldName, type, size, decimal), value in zip(fields, record): # write size = int(size) str_val: str if value in MISSING: - str_val = self._dbf_missing_placeholder(value, fieldType, size) - elif fieldType in {FieldType.N, FieldType.F}: - str_val = self._try_coerce_to_numeric_str(value, size, deci) - elif fieldType is FieldType.D: + str_val = self._dbf_missing_placeholder(value, type, size) + elif type in {FieldType.N, FieldType.F}: + str_val = self._try_coerce_to_numeric_str(value, size, decimal) + elif type is FieldType.D: str_val = self._try_coerce_to_date_str(value) - elif fieldType is FieldType.L: + elif type is FieldType.L: str_val = self._try_coerce_to_logical_str(value) else: if isinstance(value, bytes): @@ -3544,10 +3532,10 @@ def _shapeparts( self.shape(polyShape) def field( - # Types of args should match *FieldData + # Types of args should match *Field self, name: str, - fieldType: Union[str, FieldType] = FieldType.C, + field_type: Union[str, FieldType] = FieldType.C, size: int = 50, decimal: int = 0, ) -> None: @@ -3556,8 +3544,8 @@ def field( raise ShapefileException( "Shapefile Writer reached maximum number of fields: 2046." 
) - - self.fields.append(FieldData.from_unchecked(name, fieldType, size, decimal)) + field_ = Field.from_unchecked(name, field_type, size, decimal) + self.fields.append(field_) # Begin Testing From 5f3f0d27ab8a97aff38c3f47b86684a8f8f8c441 Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Thu, 31 Jul 2025 20:42:54 +0100 Subject: [PATCH 16/20] Don't shadow builtin type --- src/shapefile.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index 0c47007..29fc76e 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -3263,12 +3263,14 @@ def record( self.__dbfRecord(record) @staticmethod - def _dbf_missing_placeholder(value: RecordValue, type: FieldType, size: int) -> str: - if type in {FieldType.N, FieldType.F}: + def _dbf_missing_placeholder( + value: RecordValue, field_type: FieldType, size: int + ) -> str: + if field_type in {FieldType.N, FieldType.F}: return "*" * size # QGIS NULL - if type is FieldType.D: + if field_type is FieldType.D: return "0" * 8 # QGIS NULL for date type - if type is FieldType.L: + if field_type is FieldType.L: return " " return str(value) @@ -3345,18 +3347,18 @@ def __dbfRecord(self, record: list[RecordValue]) -> None: fields = ( field for field in self.fields if field[0] != "DeletionFlag" ) # ignore deletionflag field in case it was specified - for (fieldName, type, size, decimal), value in zip(fields, record): + for (fieldName, type_, size, decimal), value in zip(fields, record): # write size = int(size) str_val: str if value in MISSING: - str_val = self._dbf_missing_placeholder(value, type, size) - elif type in {FieldType.N, FieldType.F}: + str_val = self._dbf_missing_placeholder(value, type_, size) + elif type_ in {FieldType.N, FieldType.F}: str_val = self._try_coerce_to_numeric_str(value, size, decimal) - elif type is FieldType.D: + elif type_ is FieldType.D: str_val = self._try_coerce_to_date_str(value) - elif 
type is FieldType.L: + elif type_ is FieldType.L: str_val = self._try_coerce_to_logical_str(value) else: if isinstance(value, bytes): From 87a1ff53bb022d54e0123b2c407cacaa098055ce Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Thu, 31 Jul 2025 21:32:17 +0100 Subject: [PATCH 17/20] Passes doctests (adjusted to reflect new API: Field, BBox, ZBox & MBox) --- README.md | 45 +++++++++++++++++---------------------------- src/shapefile.py | 32 ++++++++++++++++---------------- 2 files changed, 33 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index a0e8f44..80f0431 100644 --- a/README.md +++ b/README.md @@ -406,7 +406,7 @@ and the bounding box area the shapefile covers: >>> len(sf) 663 >>> sf.bbox - (-122.515048, 37.652916, -122.327622, 37.863433) + BBox(xmin=-122.515048, ymin=37.652916, xmax=-122.327622, ymax=37.863433) Finally, if you would prefer to work with the entire shapefile in a different format, you can convert all of it to a GeoJSON dictionary, although you may lose @@ -553,45 +553,34 @@ in the shp geometry file and the dbf attribute file. The field names of a shapefile are available as soon as you read a shapefile. You can call the "fields" attribute of the shapefile as a Python list. Each -field is a Python tuple with the following information: +field is a Python namedtuple (Field) with the following information: - * Field name: the name describing the data at this column index. - * Field type: the type of data at this column index. Types can be: + * name: the name describing the data at this column index (a string). + * field_type: a FieldType enum member determining the type of data at this column index. Names can be: * "C": Characters, text. * "N": Numbers, with or without decimals. * "F": Floats (same as "N"). * "L": Logical, for boolean True/False values. * "D": Dates. * "M": Memo, has no meaning within a GIS and is part of the xbase spec instead. 
- * Field length: the length of the data found at this column index. Older GIS
+ * size: Field length: the length of the data found at this column index. Older GIS
  software may truncate this length to 8 or 11 characters for "Character" fields.
- * Decimal length: the number of decimal places found in "Number" fields.
+ * decimal: Decimal length. The number of decimal places found in "Number" fields.
+
+A new field can be created directly from the type enum member etc., or as follows:
+
+    >>> shapefile.Field.from_unchecked("Population", "N", 10,0)
+    Field(name="Population", field_type=FieldType.N, size=10, decimal=0)
+
+Using this method the conversion from string to enum is done automatically.
 
 To see the fields for the Reader object above (sf) call the "fields" attribute:
 
-    >>> fields = sf.fields
-
-    >>> assert fields == [("DeletionFlag", "C", 1, 0), ("AREA", "N", 18, 5),
-    ... ("BKG_KEY", "C", 12, 0), ("POP1990", "N", 9, 0), ("POP90_SQMI", "N", 10, 1),
-    ... ("HOUSEHOLDS", "N", 9, 0),
-    ... ("MALES", "N", 9, 0), ("FEMALES", "N", 9, 0), ("WHITE", "N", 9, 0),
-    ... ("BLACK", "N", 8, 0), ("AMERI_ES", "N", 7, 0), ("ASIAN_PI", "N", 8, 0),
-    ... ("OTHER", "N", 8, 0), ("HISPANIC", "N", 8, 0), ("AGE_UNDER5", "N", 8, 0),
-    ... ("AGE_5_17", "N", 8, 0), ("AGE_18_29", "N", 8, 0), ("AGE_30_49", "N", 8, 0),
-    ... ("AGE_50_64", "N", 8, 0), ("AGE_65_UP", "N", 8, 0),
-    ... ("NEVERMARRY", "N", 8, 0), ("MARRIED", "N", 9, 0), ("SEPARATED", "N", 7, 0),
-    ... ("WIDOWED", "N", 8, 0), ("DIVORCED", "N", 8, 0), ("HSEHLD_1_M", "N", 8, 0),
-    ... ("HSEHLD_1_F", "N", 8, 0), ("MARHH_CHD", "N", 8, 0),
-    ... ("MARHH_NO_C", "N", 8, 0), ("MHH_CHILD", "N", 7, 0),
-    ... ("FHH_CHILD", "N", 7, 0), ("HSE_UNITS", "N", 9, 0), ("VACANT", "N", 7, 0),
-    ... ("OWNER_OCC", "N", 8, 0), ("RENTER_OCC", "N", 8, 0),
-    ... ("MEDIAN_VAL", "N", 7, 0), ("MEDIANRENT", "N", 4, 0),
-    ... ("UNITS_1DET", "N", 8, 0), ("UNITS_1ATT", "N", 7, 0), ("UNITS2", "N", 7, 0),
-    ... 
("UNITS3_9", "N", 8, 0), ("UNITS10_49", "N", 8, 0), - ... ("UNITS50_UP", "N", 8, 0), ("MOBILEHOME", "N", 7, 0)] + >>> sf.fields + [Field(name="DeletionFlag", field_type=FieldType.C, size=1, decimal=0), Field(name="AREA", field_type=FieldType.N, size=18, decimal=5), Field(name="BKG_KEY", field_type=FieldType.C, size=12, decimal=0), Field(name="POP1990", field_type=FieldType.N, size=9, decimal=0), Field(name="POP90_SQMI", field_type=FieldType.N, size=10, decimal=1), Field(name="HOUSEHOLDS", field_type=FieldType.N, size=9, decimal=0), Field(name="MALES", field_type=FieldType.N, size=9, decimal=0), Field(name="FEMALES", field_type=FieldType.N, size=9, decimal=0), Field(name="WHITE", field_type=FieldType.N, size=9, decimal=0), Field(name="BLACK", field_type=FieldType.N, size=8, decimal=0), Field(name="AMERI_ES", field_type=FieldType.N, size=7, decimal=0), Field(name="ASIAN_PI", field_type=FieldType.N, size=8, decimal=0), Field(name="OTHER", field_type=FieldType.N, size=8, decimal=0), Field(name="HISPANIC", field_type=FieldType.N, size=8, decimal=0), Field(name="AGE_UNDER5", field_type=FieldType.N, size=8, decimal=0), Field(name="AGE_5_17", field_type=FieldType.N, size=8, decimal=0), Field(name="AGE_18_29", field_type=FieldType.N, size=8, decimal=0), Field(name="AGE_30_49", field_type=FieldType.N, size=8, decimal=0), Field(name="AGE_50_64", field_type=FieldType.N, size=8, decimal=0), Field(name="AGE_65_UP", field_type=FieldType.N, size=8, decimal=0), Field(name="NEVERMARRY", field_type=FieldType.N, size=8, decimal=0), Field(name="MARRIED", field_type=FieldType.N, size=9, decimal=0), Field(name="SEPARATED", field_type=FieldType.N, size=7, decimal=0), Field(name="WIDOWED", field_type=FieldType.N, size=8, decimal=0), Field(name="DIVORCED", field_type=FieldType.N, size=8, decimal=0), Field(name="HSEHLD_1_M", field_type=FieldType.N, size=8, decimal=0), Field(name="HSEHLD_1_F", field_type=FieldType.N, size=8, decimal=0), Field(name="MARHH_CHD", field_type=FieldType.N, size=8, 
decimal=0), Field(name="MARHH_NO_C", field_type=FieldType.N, size=8, decimal=0), Field(name="MHH_CHILD", field_type=FieldType.N, size=7, decimal=0), Field(name="FHH_CHILD", field_type=FieldType.N, size=7, decimal=0), Field(name="HSE_UNITS", field_type=FieldType.N, size=9, decimal=0), Field(name="VACANT", field_type=FieldType.N, size=7, decimal=0), Field(name="OWNER_OCC", field_type=FieldType.N, size=8, decimal=0), Field(name="RENTER_OCC", field_type=FieldType.N, size=8, decimal=0), Field(name="MEDIAN_VAL", field_type=FieldType.N, size=7, decimal=0), Field(name="MEDIANRENT", field_type=FieldType.N, size=4, decimal=0), Field(name="UNITS_1DET", field_type=FieldType.N, size=8, decimal=0), Field(name="UNITS_1ATT", field_type=FieldType.N, size=7, decimal=0), Field(name="UNITS2", field_type=FieldType.N, size=7, decimal=0), Field(name="UNITS3_9", field_type=FieldType.N, size=8, decimal=0), Field(name="UNITS10_49", field_type=FieldType.N, size=8, decimal=0), Field(name="UNITS50_UP", field_type=FieldType.N, size=8, decimal=0), Field(name="MOBILEHOME", field_type=FieldType.N, size=7, decimal=0)] The first field of a dbf file is always a 1-byte field called "DeletionFlag", which indicates records that have been deleted but not removed. 
However, @@ -1375,7 +1364,7 @@ Shapefiles containing M-values can be examined in several ways: >>> r = shapefile.Reader('shapefiles/test/linem') >>> r.mbox # the lower and upper bound of M-values in the shapefile - (0.0, 3.0) + MBox(mmin=0.0, mmax=3.0) >>> r.shape(0).m # flat list of M-values [0.0, None, 3.0, None, 0.0, None, None] @@ -1408,7 +1397,7 @@ To examine a Z-type shapefile you can do: >>> r = shapefile.Reader('shapefiles/test/linez') >>> r.zbox # the lower and upper bound of Z-values in the shapefile - (0.0, 22.0) + ZBox(zmin=0.0, zmax=22.0) >>> r.shape(0).z # flat list of Z-values [18.0, 20.0, 22.0, 0.0, 0.0, 0.0, 0.0, 15.0, 13.0, 14.0] diff --git a/src/shapefile.py b/src/shapefile.py index 29fc76e..f002bdd 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -327,7 +327,7 @@ class GeoJSONFeatureCollectionWithBBox(GeoJSONFeatureCollection): # Helpers -MISSING = {None, ""} +MISSING = (None, "") # Don't make a set, as user input may not be Hashable NODATA = -10e38 # as per the ESRI shapefile spec, only used for m-values. unpack_2_int32_be = Struct(">2i").unpack @@ -2538,7 +2538,7 @@ def __record( # parse each value record = [] for (__name, typ, __size, decimal), value in zip(fieldTuples, recordContents): - if typ in {"N", "F"}: + if typ in {FieldType.N, FieldType.F}: # numeric or float: number stored as a string, right justified, and padded with blanks to the width of the field. value = value.split(b"\0")[0] value = value.replace(b"*", b"") # QGIS NULL is all '*' chars @@ -2564,7 +2564,7 @@ def __record( except ValueError: # not parseable as int, set to None value = None - elif typ == "D": + elif typ is FieldType.D: # date: 8 bytes - date stored as a string in the format YYYYMMDD. 
if (
                     not value.replace(b"\x00", b"")
@@ -2582,7 +2582,7 @@ def __record(
                 except (TypeError, ValueError):
                     # if invalid date, just return as unicode string so user can decide
                     value = str(value.strip())
-            elif typ == "L":
+            elif typ is FieldType.L:
                 # logical: 1 byte - initialized to 0x20 (space) otherwise T or F.
                 if value == b" ":
                     value = None  # space means missing or not yet set
@@ -3225,7 +3225,7 @@ def __shxRecord(self, offset: int, length: int) -> None:
 
     def record(
         self,
-        *recordList: list[RecordValue],
+        *recordList: RecordValue,
         **recordDict: RecordValue,
     ) -> None:
         """Creates a dbf attribute record. You can submit either a sequence of
@@ -3241,7 +3241,7 @@ def record(
         record: list[RecordValue]
         fieldCount = sum(1 for field in self.fields if field[0] != "DeletionFlag")
         if recordList:
-            record = list(*recordList)
+            record = list(recordList)
             while len(record) < fieldCount:
                 record.append("")
         elif recordDict:
@@ -3272,7 +3272,7 @@ def _dbf_missing_placeholder(
             return "0" * 8  # QGIS NULL for date type
         if field_type is FieldType.L:
             return " "
-        return str(value)
+        return str(value)[:size].ljust(size)
 
     @overload
     @staticmethod
@@ -3362,20 +3362,20 @@ def __dbfRecord(self, record: list[RecordValue]) -> None:
                 str_val = self._try_coerce_to_logical_str(value)
             else:
                 if isinstance(value, bytes): 
- decoded_val = str(value) - # Truncate to the length of the field - str_val = decoded_val[:size].ljust(size) + str_val = str(value) - # should be default ascii encoding - encoded_val = str_val.encode("ascii", self.encodingErrors) + # Truncate or right pad to the length of the field + encoded_val = str_val.encode(self.encoding, self.encodingErrors)[ + :size + ].ljust(size) if len(encoded_val) != size: raise ShapefileException( - "Shapefile Writer unable to pack incorrect sized value" - f" (size {len(encoded_val)}) into field '{fieldName}' (size {size})." + f"Shapefile Writer unable to pack incorrect sized {value=!r} " + f"(size {len(encoded_val)}) into field '{fieldName}' (size {size})." ) f.write(encoded_val) @@ -3674,7 +3674,7 @@ def _test(args: list[str] = sys.argv[1:], verbosity: bool = False) -> int: new_url = _replace_remote_url(old_url) example.source = example.source.replace(old_url, new_url) - runner = doctest.DocTestRunner(verbose=verbosity) + runner = doctest.DocTestRunner(verbose=verbosity, optionflags=doctest.FAIL_FAST) if verbosity == 0: print(f"Running {len(tests.examples)} doctests...") From 5dc6257258d758ffb8045b1c66cf7f147b1e0ee1 Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Thu, 31 Jul 2025 21:35:06 +0100 Subject: [PATCH 18/20] Adjust unit test to allow for new field type enum --- test_shapefile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_shapefile.py b/test_shapefile.py index 2a10d3e..a2ffbff 100644 --- a/test_shapefile.py +++ b/test_shapefile.py @@ -695,7 +695,7 @@ def test_reader_fields(): field = fields[0] assert isinstance(field[0], str) # field name - assert field[1] in ["C", "N", "F", "L", "D", "M"] # field type + assert field[1].name in ["C", "N", "F", "L", "D", "M"] # field type assert isinstance(field[2], int) # field length assert isinstance(field[3], int) # decimal length From 435dd3a610aaefa2eb44beaba1bd3a8695e16e0b Mon Sep 17 00:00:00 2001 From: 
James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Thu, 31 Jul 2025 21:46:58 +0100 Subject: [PATCH 19/20] Update docs --- README.md | 28 ++++++++++++++++++++++++++-- changelog.txt | 15 +++++++++++++-- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 80f0431..1fb969f 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,8 @@ The Python Shapefile Library (PyShp) reads and writes ESRI Shapefiles in pure Py - **Author**: [Joel Lawhead](https://github.com/GeospatialPython) - **Maintainers**: [Karim Bahgat](https://github.com/karimbahgat) -- **Version**: 2.3.1 -- **Date**: 28 July, 2022 +- **Version**: 3.0.0-alpha +- **Date**: 31 July, 2025 - **License**: [MIT](https://github.com/GeospatialPython/pyshp/blob/master/LICENSE.TXT) ## Contents @@ -93,6 +93,30 @@ part of your geospatial project. # Version Changes +## 3.0.0-alpha + +### Breaking Changes: +- Python 2 and Python 3.8 support dropped. +- Field info tuple is now a namedtuple (Field) instead of a list. +- Field type codes are now FieldType enum members. +- bbox, mbox and zbox attributes are all new Namedtuples. +- Writer does not mutate shapes. +- New custom subclasses for each shape type: Null, Multipatch, Point, Polyline, + Multipoint, and Polygon, plus the latter 4's M and Z variants (Reader and + Writer are still compatible with their base class, Shape, as before). +- Shape sub classes are creatable from, and serializable to bytes streams, + as per the shapefile spec. + +### Code quality +- Statically typed, and checked with Mypy +- Checked with Ruff. +- f-strings +- Remove Python 2 specific functions. +- Run doctests against wheels. +- Testing of wheels before publishing them +- pyproject.toml src layout +- Slow test marked. + ## 2.4.0 ### Breaking Change. Support for Python 2 and Pythons <= 3.8 to be dropped. 
diff --git a/changelog.txt b/changelog.txt index 48a534a..d2c5f51 100644 --- a/changelog.txt +++ b/changelog.txt @@ -1,9 +1,20 @@ VERSION 3.0.0-alpha -Python 2 and Python 3.8 support dropped + Breaking Changes: + * Python 2 and Python 3.8 support dropped. + * Field info tuple is now a namedtuple (Field) instead of a list. + * Field type codes are now FieldType enum members. + * bbox, mbox and zbox attributes are all new Namedtuples. + * Writer does not mutate shapes. + * New custom subclasses for each shape type: Null, Multipatch, Point, Polyline, + Multipoint, and Polygon, plus the latter 4's M and Z variants (Reader and + Writer are still compatible with their base class, Shape, as before). + * Shape sub classes are creatable from, and serializable to bytes streams, + as per the shapefile spec. -2025-07-22 Code quality + * Statically typed and checked with Mypy + * Checked with Ruff. * Type hints * f-strings * Remove Python 2 specific functions. From c9203040c8c7455f6818d1dc41b239a33d010cea Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Thu, 31 Jul 2025 21:50:09 +0100 Subject: [PATCH 20/20] Trim trailing whitespace --- README.md | 6 +++--- changelog.txt | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 1fb969f..cd4758c 100644 --- a/README.md +++ b/README.md @@ -101,10 +101,10 @@ part of your geospatial project. - Field type codes are now FieldType enum members. - bbox, mbox and zbox attributes are all new Namedtuples. - Writer does not mutate shapes. -- New custom subclasses for each shape type: Null, Multipatch, Point, Polyline, - Multipoint, and Polygon, plus the latter 4's M and Z variants (Reader and +- New custom subclasses for each shape type: Null, Multipatch, Point, Polyline, + Multipoint, and Polygon, plus the latter 4's M and Z variants (Reader and Writer are still compatible with their base class, Shape, as before). 
-- Shape sub classes are creatable from, and serializable to bytes streams, +- Shape sub classes are creatable from, and serializable to bytes streams, as per the shapefile spec. ### Code quality diff --git a/changelog.txt b/changelog.txt index d2c5f51..45bfd76 100644 --- a/changelog.txt +++ b/changelog.txt @@ -6,10 +6,10 @@ VERSION 3.0.0-alpha * Field type codes are now FieldType enum members. * bbox, mbox and zbox attributes are all new Namedtuples. * Writer does not mutate shapes. - * New custom subclasses for each shape type: Null, Multipatch, Point, Polyline, - Multipoint, and Polygon, plus the latter 4's M and Z variants (Reader and + * New custom subclasses for each shape type: Null, Multipatch, Point, Polyline, + Multipoint, and Polygon, plus the latter 4's M and Z variants (Reader and Writer are still compatible with their base class, Shape, as before). - * Shape sub classes are creatable from, and serializable to bytes streams, + * Shape sub classes are creatable from, and serializable to bytes streams, as per the shapefile spec. Code quality