diff --git a/README.md b/README.md index cd4758c..66951be 100644 --- a/README.md +++ b/README.md @@ -430,7 +430,7 @@ and the bounding box area the shapefile covers: >>> len(sf) 663 >>> sf.bbox - BBox(xmin=-122.515048, ymin=37.652916, xmax=-122.327622, ymax=37.863433) + (-122.515048, 37.652916, -122.327622, 37.863433) Finally, if you would prefer to work with the entire shapefile in a different format, you can convert all of it to a GeoJSON dictionary, although you may lose @@ -1388,7 +1388,7 @@ Shapefiles containing M-values can be examined in several ways: >>> r = shapefile.Reader('shapefiles/test/linem') >>> r.mbox # the lower and upper bound of M-values in the shapefile - MBox(mmin=0.0, mmax=3.0) + (0.0, 3.0) >>> r.shape(0).m # flat list of M-values [0.0, None, 3.0, None, 0.0, None, None] @@ -1421,7 +1421,7 @@ To examine a Z-type shapefile you can do: >>> r = shapefile.Reader('shapefiles/test/linez') >>> r.zbox # the lower and upper bound of Z-values in the shapefile - ZBox(zmin=0.0, zmax=22.0) + (0.0, 22.0) >>> r.shape(0).z # flat list of Z-values [18.0, 20.0, 22.0, 0.0, 0.0, 0.0, 0.0, 15.0, 13.0, 14.0] diff --git a/src/shapefile.py b/src/shapefile.py index 65ffce2..2c741d0 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -12,7 +12,6 @@ import array import doctest -import enum import io import logging import os @@ -26,6 +25,7 @@ IO, Any, Container, + Final, Generic, Iterable, Iterator, @@ -38,6 +38,7 @@ TypedDict, TypeVar, Union, + cast, overload, ) from urllib.error import HTTPError @@ -122,22 +123,25 @@ PointT = Union[Point2D, PointMT, PointZT] PointsT = list[PointT] +BBox = tuple[float, float, float, float] +MBox = tuple[float, float] +ZBox = tuple[float, float] -class BBox(NamedTuple): - xmin: float - ymin: float - xmax: float - ymax: float +# class BBox(NamedTuple): +# xmin: float +# ymin: float +# xmax: float +# ymax: float -class MBox(NamedTuple): - mmin: Optional[float] - mmax: Optional[float] +# class MBox(NamedTuple): +# mmin: Optional[float] +# mmax: Optional[float] -class ZBox(NamedTuple): - zmin: float - zmax: float +# class ZBox(NamedTuple): +# zmin: float +# zmax: float class WriteableBinStream(Protocol): @@ -171,22 +175,49 @@ def read(self, size: int = -1): ... BinaryFileT = Union[str, IO[bytes]] BinaryFileStreamT = Union[IO[bytes], io.BytesIO, WriteSeekableBinStream] +FieldTypeT = Literal["C", "D", "F", "L", "M", "N"] + # https://en.wikipedia.org/wiki/.dbf#Database_records -class FieldType(enum.Enum): - # Use an ascii-encoded byte of the name, to save a decoding step. - C = "Character" # (str) - D = "Date" - F = "Floating point" - L = "Logical" # (bool) - M = "Memo" # Legacy. (10 digit str, starting block in an .dbt file) - N = "Numeric" # (int) +class FieldType: + """A bare bones 'enum', as the enum library noticeably slows performance.""" + + # __slots__ = ["C", "D", "F", "L", "M", "N", "__members__"] + + C: Final = "C" # "Character" # (str) + D: Final = "D" # "Date" + F: Final = "F" # "Floating point" + L: Final = "L" # "Logical" # (bool) + M: Final = "M" # "Memo" # Legacy. (10 digit str, starting block in an .dbt file) + N: Final = "N" # "Numeric" # (int) + __members__: set[FieldTypeT] = { + "C", + "D", + "F", + "L", + "M", + "N", + } # set(__slots__) - {"__members__"} + + # def raise_if_invalid(field_type: Hashable): + # if field_type not in FieldType.__members__: + # raise ShapefileException( + # f"field_type must be in {{FieldType.__members__}}. Got: {field_type=}. " + # ) + + +FIELD_TYPE_ALIASES: dict[Union[str, bytes], FieldTypeT] = {} +for c in FieldType.__members__: + FIELD_TYPE_ALIASES[c.upper()] = c + FIELD_TYPE_ALIASES[c.lower()] = c + FIELD_TYPE_ALIASES[c.encode("ascii").lower()] = c + FIELD_TYPE_ALIASES[c.encode("ascii").upper()] = c # Use functional syntax to have an attribute named type, a Python keyword class Field(NamedTuple): name: str - field_type: FieldType + field_type: FieldTypeT size: int decimal: int @@ -194,34 +225,32 @@ class Field(NamedTuple): def from_unchecked( cls, name: str, - field_type: Union[str, FieldType] = FieldType.C, + field_type: Union[str, bytes, FieldTypeT] = "C", size: int = 50, decimal: int = 0, ) -> Self: - if isinstance(field_type, str): - if field_type.upper() in FieldType.__members__: - field_type = FieldType[field_type.upper()] - else: - raise ShapefileException( - "type must be C,D,F,L,M,N, or a FieldType enum member. " - f"Got: {field_type=}. " - ) + try: + type_ = FIELD_TYPE_ALIASES[field_type] + except KeyError: + raise ShapefileException( + f"field_type must be in {{FieldType.__members__}}. Got: {field_type=}. " + ) - if field_type is FieldType.D: + if type_ is FieldType.D: size = 8 decimal = 0 - elif field_type is FieldType.L: + elif type_ is FieldType.L: size = 1 decimal = 0 # A doctest in README.md previously passed in a string ('40') for size, # so explictly convert name to str, and size and decimal to ints. return cls( - name=str(name), field_type=field_type, size=int(size), decimal=int(decimal) + name=str(name), field_type=type_, size=int(size), decimal=int(decimal) ) def __repr__(self) -> str: - return f'Field(name="{self.name}", field_type=FieldType.{self.field_type.name}, size={self.size}, decimal={self.decimal})' + return f'Field(name="{self.name}", field_type=FieldType.{self.field_type}, size={self.size}, decimal={self.decimal})' RecordValueNotDate = Union[bool, int, float, str, date] @@ -390,8 +419,9 @@ def rewind(coords: Reversible[PointT]) -> PointsT: def ring_bbox(coords: PointsT) -> BBox: """Calculates and returns the bounding box of a ring.""" xs, ys = map(list, list(zip(*coords))[:2]) # ignore any z or m values - bbox = BBox(xmin=min(xs), ymin=min(ys), xmax=max(xs), ymax=max(ys)) - return bbox + # bbox = BBox(xmin=min(xs), ymin=min(ys), xmax=max(xs), ymax=max(ys)) + # return bbox + return min(xs), min(ys), max(xs), max(ys) def bbox_overlap(bbox1: BBox, bbox2: BBox) -> bool: @@ -697,8 +727,8 @@ def __init__( # Preserve previous behaviour for anyone who set self.shapeType = None if not isinstance(shapeType, _NoShapeTypeSentinel): self.shapeType = shapeType - self.points = points or [] - self.parts = parts or [] + self.points: PointsT = points or [] + self.parts: Sequence[int] = parts or [] if partTypes: self.partTypes = partTypes @@ -973,7 +1003,7 @@ class _CanHaveBBox(Shape): bbox: Optional[BBox] = None def _get_set_bbox_from_byte_stream(self, b_io: ReadableBinStream) -> BBox: - self.bbox: BBox = BBox(*_Array[float]("d", unpack("<4d", b_io.read(32)))) + self.bbox: BBox = unpack("<4d", b_io.read(32)) return self.bbox @staticmethod @@ -1132,7 +1162,7 @@ def _get_nparts_from_byte_stream(b_io: ReadableBinStream) -> int: return unpack(" int: + def _write_nparts_to_byte_stream(b_io: WriteableBinStream, s: _CanHaveParts) -> int: return b_io.write(pack("= nPoints * 8: self.m = [] - for m in _Array[float]("d", unpack(f"<{nPoints}d", b_io.read(nPoints * 8))): + for m in unpack(f"<{nPoints}d", b_io.read(nPoints * 8)): if m > NODATA: self.m.append(m) else: @@ -1293,20 +1327,20 @@ def _write_ms_to_byte_stream( f"Failed to write measure extremes for record {i}. Expected floats" ) try: - if hasattr(s, "m"): + if getattr(s, "m", False): # if m values are stored in attribute - ms = [m if m is not None else NODATA for m in s.m] + ms = [m if m is not None else NODATA for m in cast(_HasM, s).m] else: # if m values are stored as 3rd/4th dimension # 0-index position of m value is 3 if z type (x,y,z,m), or 2 if m type (x,y,m) mpos = 3 if s.shapeType in _HasZ._shapeTypes else 2 - ms = [] - for p in s.points: - if len(p) > mpos and p[mpos] is not None: - ms.append(p[mpos]) - else: - ms.append(NODATA) + ms = [ + cast(float, p[mpos]) + if len(p) > mpos and p[mpos] is not None + else NODATA + for p in s.points + ] num_bytes_written += b_io.write(pack(f"<{len(ms)}d", *ms)) @@ -1330,6 +1364,10 @@ class _HasZ(_CanHaveBBox): ) z: Sequence[float] + def __init__(self, *args, **kwargs): + self.z = [] + super().__init__(*args, **kwargs) + def _set_zs_from_byte_stream(self, b_io: ReadableBinStream, nPoints: int): __zmin, __zmax = unpack("<2d", b_io.read(16)) # pylint: disable=unused-private-member self.z = _Array[float]("d", unpack(f"<{nPoints}d", b_io.read(nPoints * 8))) @@ -1350,12 +1388,12 @@ def _write_zs_to_byte_stream( f"Failed to write elevation extremes for record {i}. Expected floats." ) try: - if hasattr(s, "z"): + if getattr(s, "z", False): # if z values are stored in attribute - zs = s.z + zs = cast(_HasZ, s).z else: # if z values are stored as 3rd dimension - zs = [p[2] if len(p) > 2 else 0 for p in s.points] + zs = [cast(float, p[2]) if len(p) > 2 else 0 for p in s.points] num_bytes_written += b_io.write(pack(f"<{len(zs)}d", *zs)) except error: @@ -1406,16 +1444,14 @@ def _write_single_point_m_to_byte_stream( # Write a single M value # Note: missing m values are autoset to NODATA. - if hasattr(s, "m"): + if getattr(s, "m", False): # if m values are stored in attribute try: # if not s.m or s.m[0] is None: # s.m = (NODATA,) # m = s.m[0] - if s.m and s.m[0] is not None: - m = s.m[0] - else: - m = NODATA + s = cast(_HasM, s) + m = s.m[0] if s.m and s.m[0] is not None else NODATA except error: raise ShapefileException( f"Failed to write measure value for record {i}. Expected floats." @@ -1432,7 +1468,7 @@ def _write_single_point_m_to_byte_stream( # s.points[0][mpos] = NODATA m = NODATA else: - m = s.points[0][mpos] + m = cast(float, s.points[0][mpos]) except error: raise ShapefileException( @@ -1977,7 +2013,7 @@ def __seek_0_on_file_obj_wrap_or_open_from_name( if hasattr(file_, "read"): # Copy if required try: - file_.seek(0) # type: ignore + file_.seek(0) return file_ except (NameError, io.UnsupportedOperation): return io.BytesIO(file_.read()) @@ -2208,18 +2244,23 @@ def __shpHeader(self) -> None: shp.seek(32) self.shapeType = unpack("= NODATA else None for m_bound in unpack("<2d", shp.read(16)) ] - self.mbox = MBox(mmin=m_bounds[0], mmax=m_bounds[1]) + # self.mbox = MBox(mmin=m_bounds[0], mmax=m_bounds[1]) + self.mbox: tuple[Optional[float], Optional[float]] = (m_bounds[0], m_bounds[1]) def __shape( self, oid: Optional[int] = None, bbox: Optional[BBox] = None @@ -2418,7 +2459,7 @@ def __dbfHeader(self) -> None: name = encoded_name.decode(self.encoding, self.encodingErrors) name = name.lstrip() - field_type = FieldType[encoded_type_char.decode("ascii").upper()] + field_type = FIELD_TYPE_ALIASES[encoded_type_char] self.fields.append(Field(name, field_type, size, decimal)) terminator = dbf.read(1) @@ -2538,7 +2579,7 @@ def __record( # parse each value record = [] for (__name, typ, __size, decimal), value in zip(fieldTuples, recordContents): - if typ in {FieldType.N, FieldType.F}: + if typ is FieldType.N or typ is FieldType.F: # numeric or float: number stored as a string, right justified, and padded with blanks to the width of the field. value = value.split(b"\0")[0] value = value.replace(b"*", b"") # QGIS NULL is all '*' chars @@ -2632,7 +2673,9 @@ def records(self, fields: Optional[list[str]] = None) -> list[_Record]: f = self.__getFileObj(self.dbf) f.seek(self.__dbfHdrLength) fieldTuples, recLookup, recStruct = self.__recordFields(fields) - for i in range(self.numRecords): # type: ignore + # self.__dbfHeader() sets self.numRecords, so it's fine to cast it to int + # (to tell mypy it's not None). + for i in range(cast(int, self.numRecords)): r = self.__record( oid=i, fieldTuples=fieldTuples, recLookup=recLookup, recStruct=recStruct ) @@ -2922,71 +2965,79 @@ def __shpFileLength(self) -> int: return size def __bbox(self, s: Shape) -> BBox: - x: list[float] = [] - y: list[float] = [] + xs: list[float] = [] + ys: list[float] = [] - if self._bbox: - x.append(self._bbox.xmin) - y.append(self._bbox.ymin) - x.append(self._bbox.xmax) - y.append(self._bbox.ymax) - - if len(s.points) > 0: - px, py = list(zip(*s.points))[:2] - x.extend(px) - y.extend(py) - else: + if not s.points: # this should not happen. # any shape that is not null should have at least one point, and only those should be sent here. # could also mean that earlier code failed to add points to a non-null shape. - raise ValueError( + raise ShapefileException( "Cannot create bbox. Expected a valid shape with at least one point. " - f"Got a shape of type '{s.shapeType}' and 0 points." + f"Got a shape of type {s.shapeType=} and 0 points." ) - self._bbox = BBox(xmin=min(x), ymin=min(y), xmax=max(x), ymax=max(y)) - return self._bbox - def __zbox(self, s) -> ZBox: - z: list[float] = [] - if self._zbox: - z.extend(self._zbox) - - for p in s.points: - try: - z.append(p[2]) - except IndexError: - # point did not have z value - # setting it to 0 is probably ok, since it means all are on the same elevation - z.append(0) - - # Original self._zbox bounds (if any) are the first two entries. - # Set zbox for the first, and all later times - self._zbox = ZBox(zmin=min(z), zmax=max(z)) - return self._zbox + for point in s.points: + xs.append(point[0]) + ys.append(point[1]) - def __mbox(self, s) -> MBox: - mpos = 3 if s.shapeType in _HasZ._shapeTypes else 2 - m: list[float] = [] - if self._mbox: - m.extend(m_bound for m_bound in self._mbox if m_bound is not None) + shape_bbox = (min(xs), min(ys), max(xs), max(ys)) + # update global + if self._bbox: + # compare with existing + self._bbox = ( + min(shape_bbox[0], self._bbox[0]), + min(shape_bbox[1], self._bbox[1]), + max(shape_bbox[2], self._bbox[2]), + max(shape_bbox[3], self._bbox[3]), + ) + else: + # first time bbox is being set + self._bbox = shape_bbox + return shape_bbox + + def __zbox(self, s: Union[_HasZ, PointZ]) -> ZBox: + shape_zs: list[float] = [] + if s.z: + shape_zs.extend(s.z) + else: + for p in s.points: + # On a ShapeZ type, M is at index 4, and the point can be a 3-tuple or 4-tuple. + z = p[2] if len(p) >= 3 and p[2] is not None else 0 + shape_zs.append(z) + zbox = (min(shape_zs), max(shape_zs)) + # update global + if self._zbox: + # compare with existing + self._zbox = (min(zbox[0], self._zbox[0]), max(zbox[1], self._zbox[1])) + else: + # first time zbox is being set + self._zbox = zbox + return zbox + + def __mbox(self, s: Union[_HasM, PointM]) -> MBox: + mpos = 3 if s.shapeType in _HasZ._shapeTypes | PointZ._shapeTypes else 2 + shape_ms: list[float] = [] + if s.m: + shape_ms.extend(m for m in s.m if m is not None) + else: + for p in s.points: + m = p[mpos] if len(p) >= mpos + 1 else None + if m is not None: + shape_ms.append(m) - for p in s.points: - try: - if p[mpos] is not None: - # mbox should only be calculated on valid m values - m.append(p[mpos]) - except IndexError: - # point did not have m value so is missing - # mbox should only be calculated on valid m values - pass - if not m: + if not shape_ms: # only if none of the shapes had m values, should mbox be set to missing m values - m.append(NODATA) - - # Original self._mbox bounds (if any) are the first two entries. - # Set mbox for the first, and all later times - self._mbox = MBox(mmin=min(m), mmax=max(m)) - return self._mbox + shape_ms.append(NODATA) + mbox = (min(shape_ms), max(shape_ms)) + # update global + if self._mbox: + # compare with existing + self._mbox = (min(mbox[0], self._mbox[0]), max(mbox[1], self._mbox[1])) + else: + # first time mbox is being set + self._mbox = mbox + return mbox @property def shapeTypeName(self) -> str: @@ -3037,7 +3088,8 @@ def __shapefileHeader( # In such cases of empty shapefiles, ESRI spec says the bbox values are 'unspecified'. # Not sure what that means, so for now just setting to 0s, which is the same behavior as in previous versions. # This would also make sense since the Z and M bounds are similarly set to 0 for non-Z/M type shapefiles. - bbox = BBox(0, 0, 0, 0) + # bbox = BBox(0, 0, 0, 0) + bbox = (0, 0, 0, 0) f.write(pack("<4d", *bbox)) except error: raise ShapefileException( @@ -3046,25 +3098,29 @@ def __shapefileHeader( else: f.write(pack("<4d", 0, 0, 0, 0)) # Elevation - if self.shapeType in {POINTZ} | _HasZ._shapeTypes: + if self.shapeType in PointZ._shapeTypes | _HasZ._shapeTypes: # Z values are present in Z type zbox = self.zbox() if zbox is None: # means we have empty shapefile/only null geoms (see commentary on bbox above) - zbox = ZBox(0, 0) + # zbox = ZBox(0, 0) + zbox = (0, 0) else: # As per the ESRI shapefile spec, the zbox for non-Z type shapefiles are set to 0s - zbox = ZBox(0, 0) + # zbox = ZBox(0, 0) + zbox = (0, 0) # Measure - if self.shapeType in {POINTM, POINTZ} | _HasM._shapeTypes: + if self.shapeType in PointM._shapeTypes | _HasM._shapeTypes: # M values are present in M or Z type mbox = self.mbox() if mbox is None: # means we have empty shapefile/only null geoms (see commentary on bbox above) - mbox = MBox(0, 0) + # mbox = MBox(0, 0) + mbox = (0, 0) else: # As per the ESRI shapefile spec, the mbox for non-M type shapefiles are set to 0s - mbox = MBox(0, 0) + # mbox = MBox(0, 0) + mbox = (0, 0) # Try writing try: f.write(pack("<4d", zbox[0], zbox[1], mbox[0], mbox[1])) @@ -3111,7 +3167,7 @@ def __dbfHeader(self) -> None: encoded_name = field.name.encode(self.encoding, self.encodingErrors) encoded_name = encoded_name.replace(b" ", b"_") encoded_name = encoded_name[:10].ljust(11).replace(b" ", b"\x00") - encodedFieldType = field.field_type.name.encode("ascii") + encodedFieldType = field.field_type.encode("ascii") fld = pack( "<11sc4xBB14x", encoded_name, @@ -3155,7 +3211,7 @@ def __shpRecord(self, s: Shape) -> tuple[int, int]: # Shape Type if self.shapeType is None and s.shapeType != NULL: self.shapeType = s.shapeType - if s.shapeType not in {NULL, self.shapeType}: + if s.shapeType not in (NULL, self.shapeType): raise ShapefileException( f"The shape's type ({s.shapeType}) must match " f"the type of the shapefile ({self.shapeType})." @@ -3163,15 +3219,17 @@ def __shpRecord(self, s: Shape) -> tuple[int, int]: # For both single point and multiple-points non-null shapes, # update bbox, mbox and zbox of the whole shapefile - new_bbox = self.__bbox(s) if s.shapeType != NULL else None - new_mbox = ( - self.__mbox(s) - if s.shapeType in {POINTM, POINTZ} | _HasM._shapeTypes - else None - ) - new_zbox = ( - self.__zbox(s) if s.shapeType in {POINTZ} | _HasZ._shapeTypes else None - ) + shape_bbox = self.__bbox(s) if s.shapeType != NULL else None + + if s.shapeType in PointM._shapeTypes | _HasM._shapeTypes: + shape_mbox = self.__mbox(cast(Union[_HasM, PointM], s)) + else: + shape_mbox = None + + if s.shapeType in PointZ._shapeTypes | _HasZ._shapeTypes: + shape_zbox = self.__zbox(cast(Union[_HasZ, PointZ], s)) + else: + shape_zbox = None # Create an in-memory binary buffer to avoid # unnecessary seeks to files on disk @@ -3194,9 +3252,9 @@ def __shpRecord(self, s: Shape) -> tuple[int, int]: b_io=b_io, s=s, i=self.shpNum, - bbox=new_bbox, - mbox=new_mbox, - zbox=new_zbox, + bbox=shape_bbox, + mbox=shape_mbox, + zbox=shape_zbox, ) # Finalize record length as 16-bit words @@ -3262,77 +3320,7 @@ def record( record = ["" for _ in range(fieldCount)] self.__dbfRecord(record) - @staticmethod - def _dbf_missing_placeholder( - value: RecordValue, field_type: FieldType, size: int - ) -> str: - if field_type in {FieldType.N, FieldType.F}: - return "*" * size # QGIS NULL - if field_type is FieldType.D: - return "0" * 8 # QGIS NULL for date type - if field_type is FieldType.L: - return " " - return str(value)[:size].ljust(size) - - @overload - @staticmethod - def _try_coerce_to_numeric_str(value: date, size: int, decimal: int) -> Never: ... - @overload - @staticmethod - def _try_coerce_to_numeric_str( - value: RecordValueNotDate, size: int, decimal: int - ) -> str: ... - @staticmethod - def _try_coerce_to_numeric_str(value, size, decimal): - # numeric or float: number stored as a string, - # right justified, and padded with blanks - # to the width of the field. - if not decimal: - # force to int - try: - # first try to force directly to int. - # forcing a large int to float and back to int - # will lose information and result in wrong nr. - int_val = int(value) - except ValueError: - # forcing directly to int failed, so was probably a float. - int_val = int(float(value)) - except TypeError: - raise ShapefileException(f"Could not form int from: {value}") - # length capped to the field size - return format(int_val, "d")[:size].rjust(size) - - try: - f_val = float(value) - except ValueError: - raise ShapefileException(f"Could not form float from: {value}") - # length capped to the field size - return format(f_val, f".{decimal}f")[:size].rjust(size) - - @staticmethod - def _try_coerce_to_date_str(value: RecordValue) -> str: - # date: 8 bytes - date stored as a string in the format YYYYMMDD. - if isinstance(value, date): - return f"{value.year:04d}{value.month:02d}{value.day:02d}" - if isinstance(value, (list, tuple)) and len(value) == 3: - return f"{value[0]:04d}{value[1]:02d}{value[2]:02d}" - if isinstance(value, str) and len(value) == 8: - return value # value is already a date string - - raise ShapefileException( - "Date values must be either a datetime.date object, a list/tuple, a YYYYMMDD string, or a missing value." - ) - - @staticmethod - def _try_coerce_to_logical_str(value: RecordValue) -> str: - # logical: 1 byte - initialized to 0x20 (space) otherwise T or F. - if value == 1: # True == 1 - return "T" - if value == 0: # False == 0 - return "F" - return " " # unknown is set to space - - def __dbfRecord(self, record: list[RecordValue]) -> None: + def __dbfRecord(self, record): """Writes the dbf records.""" f = self.__getFileObj(self.dbf) if self.recNum == 0: @@ -3347,37 +3335,83 @@ def __dbfRecord(self, record: list[RecordValue]) -> None: fields = ( field for field in self.fields if field[0] != "DeletionFlag" ) # ignore deletionflag field in case it was specified - for (fieldName, type_, size, decimal), value in zip(fields, record): + for (fieldName, fieldType, size, deci), value in zip(fields, record): # write - size = int(size) - str_val: str - - if value in MISSING: - str_val = self._dbf_missing_placeholder(value, type_, size) - elif type_ in {FieldType.N, FieldType.F}: - str_val = self._try_coerce_to_numeric_str(value, size, decimal) - elif type_ is FieldType.D: - str_val = self._try_coerce_to_date_str(value) - elif type_ is FieldType.L: - str_val = self._try_coerce_to_logical_str(value) - else: - if isinstance(value, bytes): - str_val = value.decode(self.encoding, self.encodingErrors) - else: - # anything else is forced to string. - str_val = str(value) - - # Truncate or right pad to the length of the field - encoded_val = str_val.encode(self.encoding, self.encodingErrors)[ - :size - ].ljust(size) + # fieldName, fieldType, size and deci were already checked + # when their Field instance was created and added to self.fields + str_val: Optional[str] = None - if len(encoded_val) != size: + if fieldType in ("N", "F"): + # numeric or float: number stored as a string, right justified, and padded with blanks to the width of the field. + if value in MISSING: + str_val = "*" * size # QGIS NULL + elif not deci: + # force to int + try: + # first try to force directly to int. + # forcing a large int to float and back to int + # will lose information and result in wrong nr. + num_val = int(value) + except ValueError: + # forcing directly to int failed, so was probably a float. + num_val = int(float(value)) + str_val = format(num_val, "d")[:size].rjust( + size + ) # caps the size if exceeds the field size + else: + f_val = float(value) + str_val = format(f_val, f".{deci}f")[:size].rjust( + size + ) # caps the size if exceeds the field size + elif fieldType == "D": + # date: 8 bytes - date stored as a string in the format YYYYMMDD. + if isinstance(value, date): + str_val = f"{value.year:04d}{value.month:02d}{value.day:02d}" + elif isinstance(value, list) and len(value) == 3: + str_val = f"{value[0]:04d}{value[1]:02d}{value[2]:02d}" + elif value in MISSING: + str_val = "0" * 8 # QGIS NULL for date type + elif isinstance(value, str) and len(value) == 8: + pass # value is already a date string + else: + raise ShapefileException( + "Date values must be either a datetime.date object, a list, a YYYYMMDD string, or a missing value." + ) + elif fieldType == "L": + # logical: 1 byte - initialized to 0x20 (space) otherwise T or F. + if value in MISSING: + str_val = " " # missing is set to space + elif value in [True, 1]: + str_val = "T" + elif value in [False, 0]: + str_val = "F" + else: + str_val = " " # unknown is set to space + + if str_val is None: + # Types C and M, and anything else, value is forced to string, + # encoded by the codec specified to the Writer (utf-8 by default), + # then the resulting bytes are padded and truncated to the length + # of the field + encoded = ( + str(value) + .encode(self.encoding, self.encodingErrors)[:size] + .ljust(size) + ) + else: + # str_val was given a not-None string value + # under the checks for fieldTypes "N", "F", "D", or "L" above + # Numeric, logical, and date numeric types are ascii already, but + # for Shapefile or dbf spec reasons + # "should be default ascii encoding" + encoded = str_val.encode("ascii", self.encodingErrors) + + if len(encoded) != size: raise ShapefileException( - f"Shapefile Writer unable to pack incorrect sized {value=!r} " - f"(size {len(encoded_val)}) into field '{fieldName}' (size {size})." + f"Shapefile Writer unable to pack incorrect sized {value=}" + f" (encoded as {len(encoded)}B) into field '{fieldName}' ({size}B)." ) - f.write(encoded_val) + f.write(encoded) def balance(self) -> None: """Adds corresponding empty attributes or null geometry records depending @@ -3537,7 +3571,7 @@ def field( # Types of args should match *Field self, name: str, - field_type: Union[str, FieldType] = FieldType.C, + field_type: FieldTypeT = "C", size: int = 50, decimal: int = 0, ) -> None: diff --git a/test_shapefile.py b/test_shapefile.py index a2ffbff..2a10d3e 100644 --- a/test_shapefile.py +++ b/test_shapefile.py @@ -695,7 +695,7 @@ def test_reader_fields(): field = fields[0] assert isinstance(field[0], str) # field name - assert field[1].name in ["C", "N", "F", "L", "D", "M"] # field type + assert field[1] in ["C", "N", "F", "L", "D", "M"] # field type assert isinstance(field[2], int) # field length assert isinstance(field[3], int) # decimal length