diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 893ae1eb..b8c08caa 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -23,17 +23,17 @@ import logging import os from types import MethodType -from typing import Optional +from typing import Any, ClassVar, Optional, Type from alembic.config import Config from alembic.environment import EnvironmentContext from alembic.migration import MigrationContext from alembic.script import ScriptDirectory from alembic.util.exc import CommandError -from sqlalchemy import MetaData, create_engine, desc, inspect, text +from sqlalchemy import MetaData, Subquery, create_engine, desc, inspect, text from sqlalchemy.engine.url import URL, make_url from sqlalchemy.event import listen from sqlalchemy.exc import ArgumentError, DatabaseError, DBAPIError -from sqlalchemy.orm import Session +from sqlalchemy.orm import Query, Session from sqlalchemy.pool import NullPool, StaticPool from .compatibility import CompatibilityTransformations, compatibility_transformations from .db_mapping_base import DatabaseMappingBase, MappedItemBase, MappedTable, PublicItem @@ -49,6 +49,8 @@ ) from .helpers import ( Asterisk, + ItemType, + LegacyItemType, _create_first_spine_database, compare_schemas, copy_database_bind, @@ -119,7 +121,7 @@ class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, Dat """ - _sq_name_by_item_type = { + _sq_name_by_item_type: ClassVar[dict[ItemType, str]] = { "alternative": "alternative_sq", "scenario": "scenario_sq", "scenario_alternative": "scenario_alternative_sq", @@ -234,25 +236,25 @@ def session(self): return self._session @staticmethod - def item_types() -> list[str]: + def item_types() -> list[ItemType]: return [x for x in DatabaseMapping._sq_name_by_item_type if not ITEM_CLASS_BY_TYPE[x].is_protected] @staticmethod - def all_item_types() -> list[str]: + def all_item_types() -> list[ItemType]: return list(DatabaseMapping._sq_name_by_item_type) @staticmethod - def item_factory(item_type): + def item_factory(item_type: ItemType) -> Type[MappedItemBase]: return ITEM_CLASS_BY_TYPE[item_type] def _query_commit_count(self) -> int: with self: return self.query(self.commit_sq).count() - def make_item(self, item_type: str, **item) -> MappedItemBase: + def make_item(self, item_type: ItemType, **item) -> MappedItemBase: return ITEM_CLASS_BY_TYPE[item_type](self, **item) - def _make_sq(self, item_type): + def _make_sq(self, item_type: ItemType) -> Subquery: sq_name = self._sq_name_by_item_type[item_type] return getattr(self, sq_name) @@ -425,35 +427,35 @@ def _receive_engine_close(self, dbapi_con, _connection_record): copy_database_bind(self._original_engine, self.engine) @staticmethod - def real_item_type(tablename): + def real_item_type(item_type: ItemType | LegacyItemType) -> ItemType: return { "object_class": "entity_class", "relationship_class": "entity_class", "object": "entity", "relationship": "entity", - }.get(tablename, tablename) + }.get(item_type, item_type) @staticmethod - def _convert_legacy(tablename, item): - if tablename in ("entity_class", "entity"): + def _convert_legacy(item_type: ItemType, item: dict) -> None: + if item_type in ("entity_class", "entity"): object_class_id_list = tuple(item.pop("object_class_id_list", ())) if object_class_id_list: item["dimension_id_list"] = object_class_id_list object_class_name_list = tuple(item.pop("object_class_name_list", ())) if object_class_name_list: item["dimension_name_list"] = object_class_name_list - if tablename == "entity": + if item_type == "entity": object_id_list = tuple(item.pop("object_id_list", ())) if object_id_list: item["element_id_list"] = object_id_list object_name_list = tuple(item.pop("object_name_list", ())) if object_name_list: item["element_name_list"] = object_name_list - if tablename in ("parameter_definition", "parameter_value"): + if item_type in ("parameter_definition", "parameter_value"): entity_class_id = item.pop("object_class_id", None) or item.pop("relationship_class_id", None) if entity_class_id: item["entity_class_id"] = entity_class_id - if tablename == "parameter_value": + if item_type == "parameter_value": entity_id = item.pop("object_id", None) or item.pop("relationship_id", None) if entity_id: item["entity_id"] = entity_id @@ -512,10 +514,10 @@ def add(self, mapped_table: MappedTable, **kwargs) -> PublicItem: checked_item.replaced_item_waiting_for_removal = existing_item return checked_item.public_item - def add_by_type(self, item_type: str, **kwargs) -> PublicItem: + def add_by_type(self, item_type: ItemType, **kwargs) -> PublicItem: return self.add(self._mapped_tables[item_type], **kwargs) - def apply_many_by_type(self, item_type: str, method_name: str, items: list[dict], **kwargs) -> None: + def apply_many_by_type(self, item_type: ItemType, method_name: str, items: list[dict], **kwargs) -> None: mapped_table = self._mapped_tables[item_type] method = getattr(self, method_name) for item in items: @@ -541,7 +543,7 @@ def item(self, mapped_table: MappedTable, **kwargs) -> PublicItem: raise SpineDBAPIError(f"{mapped_table.item_type} matching {kwargs} has been removed") return item.public_item - def get_or_add_by_type(self, item_type: str, **kwargs) -> PublicItem: + def get_or_add_by_type(self, item_type: ItemType, **kwargs) -> PublicItem: return self.get_or_add(self.mapped_table(item_type), **kwargs) def get_or_add(self, mapped_table: MappedTable, **kwargs) -> PublicItem: @@ -550,7 +552,7 @@ def get_or_add(self, mapped_table: MappedTable, **kwargs) -> PublicItem: except SpineDBAPIError: return self.add(mapped_table, **kwargs) - def item_by_type(self, item_type: str, **kwargs) -> PublicItem: + def item_by_type(self, item_type: ItemType, **kwargs) -> PublicItem: return self.item(self._mapped_tables[item_type], **kwargs) def find(self, mapped_table: MappedTable, **kwargs) -> list[PublicItem]: @@ -564,17 +566,17 @@ def find(self, mapped_table: MappedTable, **kwargs) -> list[PublicItem]: for entity in entities: print(f"{entity['name']}: {entity['description']}") """ - mapped_table.check_fields(kwargs, valid_types=(type(None),)) fetched = self._fetched.get(mapped_table.item_type, -1) == self._get_commit_count() if not kwargs: if not fetched: self.do_fetch_all(mapped_table) return [i.public_item for i in mapped_table.values() if i.is_valid()] + mapped_table.check_fields(kwargs, valid_types=(type(None),)) if not fetched: self._do_fetch_more(mapped_table, offset=0, limit=None, real_commit_count=None, **kwargs) return [i.public_item for i in mapped_table.values() if i.is_valid() and _fields_equal(i, kwargs)] - def find_by_type(self, item_type: str, **kwargs) -> list[PublicItem]: + def find_by_type(self, item_type: ItemType, **kwargs) -> list[PublicItem]: return self.find(self._mapped_tables[item_type], **kwargs) @staticmethod @@ -600,7 +602,7 @@ def update(mapped_table: MappedTable, **kwargs) -> Optional[PublicItem]: mapped_table.update_item(item_update, target_item, updated_fields) return target_item.public_item - def update_by_type(self, item_type: str, **kwargs) -> PublicItem: + def update_by_type(self, item_type: ItemType, **kwargs) -> PublicItem: return self.update(self._mapped_tables[item_type], **kwargs) def add_or_update(self, mapped_table: MappedTable, **kwargs) -> Optional[PublicItem]: @@ -614,7 +616,7 @@ def add_or_update(self, mapped_table: MappedTable, **kwargs) -> Optional[PublicI pass return self.update(mapped_table, **kwargs) - def add_or_update_by_type(self, item_type: str, **kwargs) -> PublicItem: + def add_or_update_by_type(self, item_type: ItemType, **kwargs) -> PublicItem: return self.add_or_update(self._mapped_tables[item_type], **kwargs) @staticmethod @@ -638,7 +640,7 @@ def remove(mapped_table: MappedTable, **kwargs) -> None: if not removed_item: raise SpineDBAPIError("failed to remove") - def remove_by_type(self, item_type: str, **kwargs) -> None: + def remove_by_type(self, item_type: ItemType, **kwargs) -> None: self.remove(self._mapped_tables[item_type], **kwargs) @staticmethod @@ -661,19 +663,21 @@ def restore(mapped_table: MappedTable, **kwargs) -> PublicItem: raise SpineDBAPIError("failed to restore item") return restored_item.public_item - def restore_by_type(self, item_type: str, **kwargs) -> PublicItem: + def restore_by_type(self, item_type: ItemType, **kwargs) -> PublicItem: return self.restore(self._mapped_tables[item_type], **kwargs) - def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs): + def get_item( + self, item_type: ItemType | LegacyItemType, fetch: bool = True, skip_removed: bool = True, **kwargs + ) -> PublicItem | dict: """Finds and returns an item matching the arguments, or an empty dict if none found. This is legacy method. Use :meth:`item` instead. This method supports legacy item types, e.g. object and relationship_class. Args: - item_type (str): One of . - fetch (bool, optional): Whether to fetch the DB in case the item is not found in memory. - skip_removed (bool, optional): Whether to ignore removed items. + item_type: One of . + fetch: Whether to fetch the DB in case the item is not found in memory. + skip_removed: Whether to ignore removed items. **kwargs: Fields and values for one the unique keys as specified for the item type in :ref:`db_mapping_schema`. @@ -699,21 +703,23 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs): return {} return item.public_item - def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs): + def get_items( + self, item_type: ItemType | LegacyItemType, fetch: bool = True, skip_removed: bool = True, **kwargs + ) -> list[PublicItem]: """Finds and returns all the items of one type. This is legacy method. Use :meth:`find` instead. This method supports legacy item types, e.g. object and relationship_class. Args: - item_type (str): One of . - fetch (bool, optional): Whether to fetch the DB before returning the items. - skip_removed (bool, optional): Whether to ignore removed items. + item_type: One of . + fetch: Whether to fetch the DB before returning the items. + skip_removed: Whether to ignore removed items. **kwargs: Fields and values for one the unique keys as specified for the item type in :ref:`db_mapping_schema`. Returns: - list(:class:`PublicItem`): The items. + The items. """ item_type = self.real_item_type(item_type) mapped_table = self.mapped_table(item_type) @@ -723,7 +729,7 @@ def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs): get_items = mapped_table.valid_values if skip_removed else mapped_table.values return [x.public_item for x in get_items() if all(x.get(k) == v for k, v in kwargs.items())] - def item_active_in_scenario(self, item, scenario_id): + def item_active_in_scenario(self, item: PublicItem, scenario_id: TempId) -> bool | None: """Checks if an item is active in a given scenario. Takes into account the ranks of the alternatives and figures @@ -732,12 +738,11 @@ def item_active_in_scenario(self, item, scenario_id): :meta private: Args: - item (:class:`PublicItem`): Item value to check - scenario_id (:class:`TempId`): The id of the scenario to test against + item: Item value to check + scenario_id: The id of the scenario to test against Returns: - result (bool or None): True if the item is active, False if not, - None if no entity alternatives are specified. + True if the item is active, False if not, None if no entity alternatives are specified. """ scenario_table = self._mapped_tables["scenario"] scenario = scenario_table.find_item_by_id(scenario_id) @@ -756,7 +761,10 @@ def item_active_in_scenario(self, item, scenario_id): @staticmethod def _modify_items( - function: Callable[[dict | PublicItem | MappedItemBase], tuple[list[PublicItem], list[str]]], + function: Callable[ + [dict | PublicItem | MappedItemBase | TempId | int], + tuple[PublicItem | None | tuple[PublicItem | None, PublicItem | None], str | None], + ], *items, strict: bool = False, ) -> tuple[list[PublicItem], list[str]]: @@ -771,19 +779,21 @@ def _modify_items( modified.append(item) return modified, errors - def add_item(self, item_type, check=True, **kwargs): + def add_item( + self, item_type: ItemType | LegacyItemType, check: bool = True, **kwargs + ) -> tuple[PublicItem | None, str | None]: """Adds an item to the in-memory mapping. This is legacy method. Use :meth:`add` instead. This method supports legacy item types, e.g. object and relationship_class. Args: - item_type (str): One of . - check (bool): Whether to check for data integrity. + item_type: One of . + check: Whether to check for data integrity. **kwargs: Fields and values as specified for the item type in :ref:`db_mapping_schema`. Returns: - tuple(:class:`PublicItem` or None, str): The added item and any errors. + The added item and any errors. """ item_type = self.real_item_type(item_type) self._convert_legacy(item_type, kwargs) @@ -795,41 +805,45 @@ def add_item(self, item_type, check=True, **kwargs): except SpineDBAPIError as error: return None, str(error) - def add_items(self, item_type, *items, check=True, strict=False): + def add_items( + self, item_type: ItemType | LegacyItemType, *items: dict, check: bool = True, strict: bool = False + ) -> tuple[list[PublicItem | None], list[str | None]]: """Adds many items to the in-memory mapping. This is legacy method. Use the :meth:`add_entities`, :meth:`add_entity_classes` etc. methods instead. This method supports legacy item types, e.g. object and relationship_class. Args: - item_type (str): One of . - *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values of the item type, + item_type: One of . + *items: One or more :class:`dict` objects mapping fields to values of the item type, as specified in :ref:`db_mapping_schema`. - check (bool): Whether to check for data integrity. - strict (bool): Whether the method should raise :exc:`~.exception.SpineIntegrityError` + check: Whether to check for data integrity. + strict: Whether the method should raise :exc:`~.exception.SpineIntegrityError` if the insertion of one of the items violates an integrity constraint. Returns: - tuple(list(:class:`PublicItem`),list(str)): items successfully added and found violations. + items successfully added and found violations. """ return self._modify_items(lambda x: self.add_item(item_type, check=check, **x), *items, strict=strict) - def update_item(self, item_type, check=True, **kwargs): + def update_item( + self, item_type: ItemType | LegacyItemType, check: bool = True, **kwargs + ) -> tuple[PublicItem | None, str | None]: """Updates an item in the in-memory mapping. This is legacy method. Use :meth:`update` instead. This method supports legacy item types, e.g. object and relationship_class. Args: - item_type (str): One of . - check (bool): Whether to check for data integrity and legacy item types. + item_type: One of . + check: Whether to check for data integrity and legacy item types. **kwargs: Fields to update and their new values as specified for the item type in :ref:`db_mapping_schema`. Returns: - tuple(:class:`PublicItem` or None, str): The updated item and any errors. + The updated item and any errors. """ + item_type = self.real_item_type(item_type) if check: - item_type = self.real_item_type(item_type) self._convert_legacy(item_type, kwargs) mapped_table = self.mapped_table(item_type) try: @@ -846,39 +860,42 @@ def update_item(self, item_type, check=True, **kwargs): mapped_table.update_item(candidate_item, target_item, updated_fields) return target_item.public_item, "" - def update_items(self, item_type, *items, check=True, strict=False): + def update_items( + self, item_type: ItemType | LegacyItemType, *items: dict, check: bool = True, strict: bool = False + ) -> tuple[list[PublicItem | None], list[str | None]]: """Updates many items in the in-memory mapping. This is legacy method. Use the :meth:`update_entities`, :meth:`update_entity_classes` etc. methods instead. This method supports legacy item types, e.g. object and relationship_class. Args: - item_type (str): One of . - *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values of the item type, + item_type: One of . + *items: One or more :class:`dict` objects mapping fields to values of the item type, as specified in :ref:`db_mapping_schema` and including the `id`. - check (bool): Whether to check for data integrity. - strict (bool): Whether the method should raise :exc:`~.exception.SpineIntegrityError` + check: Whether to check for data integrity. + strict: Whether the method should raise :exc:`~.exception.SpineIntegrityError` if the update of one of the items violates an integrity constraint. Returns: - tuple(list(:class:`PublicItem`),list(str)): items successfully updated and found violations. + items successfully updated and found violations. """ return self._modify_items(lambda x: self.update_item(item_type, check=check, **x), *items, strict=strict) - def add_update_item(self, item_type, check=True, **kwargs): + def add_update_item( + self, item_type: ItemType | LegacyItemType, check: bool = True, **kwargs + ) -> tuple[PublicItem | None, PublicItem | None, str | None]: """Adds an item to the in-memory mapping if it doesn't exist; otherwise updates the current one. This is legacy method. Use :meth:`add_or_update` instead. This method supports legacy item types, e.g. object and relationship_class. Args: - item_type (str): One of . - check (bool): Whether to check for data integrity. + item_type: One of . + check: Whether to check for data integrity. **kwargs: Fields and values as specified for the item type in :ref:`db_mapping_schema`. Returns: - tuple(:class:`PublicItem` or None, :class:`PublicItem` or None, str): The added item if any, - the updated item if any, and any errors. + The added item if any, the updated item if any, and any errors. """ added, add_error = self.add_item(item_type, check=check, **kwargs) if not add_error: @@ -888,7 +905,13 @@ def add_update_item(self, item_type, check=True, **kwargs): return None, updated, update_error return None, None, add_error or update_error - def add_update_items(self, item_type, *items, check=True, strict=False): + def add_update_items( + self, + item_type: ItemType | LegacyItemType, + *items: PublicItem | MappedItemBase | dict, + check: bool = True, + strict: bool = False, + ) -> tuple[list[PublicItem | None], list[PublicItem | None], list[str | None]]: """Adds or updates many items into the in-memory mapping. This is legacy method. @@ -896,16 +919,15 @@ def add_update_items(self, item_type, *items, check=True, strict=False): This method supports legacy item types, e.g. object and relationship_class. Args: - item_type (str): One of . - *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values of the item type, + item_type: One of . + *items: One or more :class:`dict` objects mapping fields to values of the item type, as specified in :ref:`db_mapping_schema`. - check (bool): Whether to check for data integrity. - strict (bool): Whether the method should raise :exc:`~.exception.SpineIntegrityError` + check: Whether to check for data integrity. + strict: Whether the method should raise :exc:`~.exception.SpineIntegrityError` if the insertion of one of the items violates an integrity constraint. Returns: - tuple(list(:class:`PublicItem`),list(:class:`PublicItem`),list(str)): items successfully added, - items successfully updated, and found violations. + items successfully added, items successfully updated, and found violations. """ def _function(item): @@ -918,16 +940,18 @@ def _function(item): updated = [x for x in updated if x] return added, updated, errors - def remove_item(self, item_type, id_, check=True): + def remove_item( + self, item_type: ItemType | LegacyItemType, id_: TempId | int, check: bool = True + ) -> tuple[PublicItem | None, str | None]: """Removes an item from the in-memory mapping. This is legacy method. Use :meth:`remove` instead. This method supports legacy item types, e.g. object and relationship_class. Args: - item_type (str): One of . - id_ (int): The id of the item to remove. - check (bool): Whether to check for data integrity. + item_type: One of . + id_: The id of the item to remove. + check: Whether to check for data integrity. Returns: tuple(:class:`PublicItem` or None, str): The removed item and any errors. @@ -942,7 +966,7 @@ def remove_item(self, item_type, id_, check=True): return (removed_item.public_item, None) if removed_item else (None, "failed to remove") def remove_items( - self, item_type: str, *ids, check: bool = True, strict: bool = False + self, item_type: ItemType | LegacyItemType, *ids: TempId | int, check: bool = True, strict: bool = False ) -> tuple[list[PublicItem], list[str]]: """Removes many items from the in-memory mapping. @@ -969,30 +993,34 @@ def remove_items( return [], [] return self._modify_items(lambda x: self.remove_item(item_type, x, check=check), *ids, strict=strict) - def cascade_remove_items(self, cache=None, **kwargs): + def cascade_remove_items(self, cache: Any | None = None, **kwargs: TempId | int): # Legacy for item_type, ids in kwargs.items(): self.remove_items(item_type, *ids) - def restore_item(self, item_type, id_): + def restore_item( + self, item_type: ItemType | LegacyItemType, id_: TempId | int + ) -> tuple[PublicItem | None, str | None]: """Restores a previously removed item into the in-memory mapping. This is legacy method. Use :meth:`restore` instead. This method supports legacy item types, e.g. object and relationship_class. Args: - item_type (str): One of . - id_ (int): The id of the item to restore. + item_type: One of . + id_: The id of the item to restore. Returns: - tuple(:class:`PublicItem` or None, str): The restored item if any and possible error. + The restored item if any and possible error. """ item_type = self.real_item_type(item_type) mapped_table = self.mapped_table(item_type) restored_item = mapped_table.restore_item(id_) return (restored_item.public_item, None) if restored_item else (None, "failed to restore item") - def restore_items(self, item_type, *ids): + def restore_items( + self, item_type: ItemType | LegacyItemType, *ids: TempId | int + ) -> tuple[list[PublicItem | None], list[str | None]]: """Restores many previously removed items into the in-memory mapping. This is legacy method. @@ -1000,15 +1028,15 @@ def restore_items(self, item_type, *ids): This method supports legacy item types, e.g. object and relationship_class. Args: - item_type (str): One of . + item_type: One of . *ids: Ids of items to be removed. Returns: - tuple(list(:class:`PublicItem`),list(str)): items successfully restored and found violations. + items successfully restored and found violations. """ return self._modify_items(lambda x: self.restore_item(item_type, x), *ids) - def purge_items(self, item_type: str) -> bool: + def purge_items(self, item_type: ItemType | LegacyItemType) -> bool: """Removes all items of one type. This is legacy method. Use :meth:`remove_entity`, :meth:`remove_entity_class` etc. @@ -1023,18 +1051,20 @@ def purge_items(self, item_type: str) -> bool: """ return bool(self.remove_items(item_type, Asterisk)[0]) - def fetch_more(self, item_type, offset=0, limit=None, **kwargs): + def fetch_more( + self, item_type: ItemType | LegacyItemType, offset: int = 0, limit: int | None = None, **kwargs + ) -> list[PublicItem]: """Fetches items from the DB into the in-memory mapping, incrementally. Args: - item_type (str): One of . - offset (int): The initial row. - limit (int): The maximum number of rows to fetch. + item_type: One of . + offset: The initial row. + limit: The maximum number of rows to fetch. **kwargs: Fields and values for one the unique keys as specified for the item type in :ref:`db_mapping_schema`. Returns: - list(:class:`PublicItem`): The items fetched. + The items fetched. """ item_type = self.real_item_type(item_type) mapped_table = self.mapped_table(item_type) @@ -1043,7 +1073,7 @@ def fetch_more(self, item_type, offset=0, limit=None, **kwargs): for x in self._do_fetch_more(mapped_table, offset=offset, limit=limit, real_commit_count=None, **kwargs) ] - def fetch_all(self, *item_types) -> list[PublicItem]: + def fetch_all(self, *item_types: ItemType | LegacyItemType) -> list[PublicItem]: """Fetches items from the DB into the in-memory mapping. Unlike :meth:`fetch_more`, this method fetches entire tables. @@ -1059,7 +1089,7 @@ def fetch_all(self, *item_types) -> list[PublicItem]: items += [item.public_item for item in self.do_fetch_all(mapped_table, commit_count)] return items - def query(self, *entities, **kwargs): + def query(self, *entities, **kwargs) -> Query: """Returns a :class:`~spinedb_api.query.Query` object to execute against the mapped DB. To perform custom ``SELECT`` statements, call this method with one or more of the documented @@ -1226,32 +1256,34 @@ def _muster_items_by_modified_status( to_add.append(item) return to_add, to_update, to_remove - def rollback_session(self): + def rollback_session(self) -> None: """Discards all the changes from the in-memory mapping.""" if not self._rollback(): raise NothingToRollback() if self._memory: self._memory_dirty = False - def has_external_commits(self): + def has_external_commits(self) -> bool: """Tests whether the database has had commits from other sources than this mapping. Returns: - bool: True if database has external commits, False otherwise + True if database has external commits, False otherwise """ return self._commit_count != self._query_commit_count() - def add_ext_entity_metadata(self, *items, **kwargs): + def add_ext_entity_metadata(self, *items, **kwargs) -> tuple[list[PublicItem | None], list[PublicItem | None]]: metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) self.add_items("metadata", *metadata_items, **kwargs) return self.add_items("entity_metadata", *items, **kwargs) - def add_ext_parameter_value_metadata(self, *items, **kwargs): + def add_ext_parameter_value_metadata( + self, *items, **kwargs + ) -> tuple[list[PublicItem | None], list[PublicItem | None]]: metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) self.add_items("metadata", *metadata_items, **kwargs) return self.add_items("parameter_value_metadata", *items, **kwargs) - def get_metadata_to_add_with_item_metadata_items(self, *items): + def get_metadata_to_add_with_item_metadata_items(self, *items) -> list[dict]: metadata_table = self._mapped_tables["metadata"] new_metadata = [] for item in items: @@ -1262,19 +1294,21 @@ def get_metadata_to_add_with_item_metadata_items(self, *items): new_metadata.append(metadata) return new_metadata - def _update_ext_item_metadata(self, tablename, *items, **kwargs): + def _update_ext_item_metadata( + self, tablename: ItemType | LegacyItemType, *items, **kwargs + ) -> tuple[list[PublicItem | None], list[str | None]]: metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) added, errors = self.add_items("metadata", *metadata_items, **kwargs) updated, more_errors = self.update_items(tablename, *items, **kwargs) return added + updated, errors + more_errors - def update_ext_entity_metadata(self, *items, **kwargs): + def update_ext_entity_metadata(self, *items, **kwargs) -> tuple[list[PublicItem | None], list[str | None]]: return self._update_ext_item_metadata("entity_metadata", *items, **kwargs) - def update_ext_parameter_value_metadata(self, *items, **kwargs): + def update_ext_parameter_value_metadata(self, *items, **kwargs) -> tuple[list[PublicItem | None], list[str | None]]: return self._update_ext_item_metadata("parameter_value_metadata", *items, **kwargs) - def remove_unused_metadata(self): + def remove_unused_metadata(self) -> None: used_metadata_ids = set() for x in self._mapped_tables["entity_metadata"].valid_values(): used_metadata_ids.add(x["metadata_id"]) @@ -1287,7 +1321,7 @@ def get_filter_configs(self) -> list[dict]: """Returns the config dicts of filters applied to this database mapping.""" return self.filter_configs - def close(self): + def close(self) -> None: self._original_engine.dispose() self._original_engine = None self.engine = None diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 1cddd1df..b1a4ea13 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -14,16 +14,21 @@ from contextlib import suppress from dataclasses import dataclass from difflib import SequenceMatcher -from typing import Any, ClassVar, Optional, Type, TypedDict, Union +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Type, TypedDict, Union +from sqlalchemy import Subquery +from sqlalchemy.orm import Query from .exception import SpineDBAPIError -from .helpers import Asterisk, AsteriskType +from .helpers import Asterisk, AsteriskType, ItemType from .mapped_item_status import Status from .temp_id import TempId, resolve +if TYPE_CHECKING: + from .db_mapping import DatabaseMapping + @dataclass(frozen=True) class DirtyItems: - item_type: str + item_type: ItemType to_add: list[MappedItemBase] to_update: list[MappedItemBase] to_remove: list[MappedItemBase] @@ -42,10 +47,10 @@ def __init__(self): self._closed = False self._context_open_count = 0 self._mapped_tables = {item_type: MappedTable(self, item_type) for item_type in self.all_item_types()} - self._fetched = {} - self._commit_count = None + self._fetched: dict[ItemType, int] = {} + self._commit_count: int | None = None item_types = self.item_types() - self._sorted_item_types = [] + self._sorted_item_types: list[ItemType] = [] while item_types: item_type = item_types.pop(0) if not self.item_factory(item_type).ref_types().isdisjoint(item_types): @@ -62,54 +67,27 @@ def close(self): self._closed = True @staticmethod - def item_types(): - """Returns a list of public item types from the DB mapping schema (equivalent to the table names). - - :meta private: - - Returns: - list(str) - """ + def item_types() -> list[ItemType]: + """Returns a list of public item types from the DB mapping schema (equivalent to the table names).""" raise NotImplementedError() @staticmethod - def all_item_types(): - """Returns a list of all item types from the DB mapping schema (equivalent to the table names). - - :meta private: - - Returns: - list(str) - """ + def all_item_types() -> list[ItemType]: + """Returns a list of all item types from the DB mapping schema (equivalent to the table names).""" raise NotImplementedError() @staticmethod - def item_factory(item_type): - """Returns a subclass of :class:`.MappedItemBase` to make items of given type. - - :meta private: - - Args: - item_type (str) - - Returns: - function - """ + def item_factory(item_type: ItemType) -> Type[MappedItemBase]: + """Returns a subclass of :class:`.MappedItemBase` to make items of given type.""" raise NotImplementedError() - def _make_sq(self, item_type): + def _make_sq(self, item_type: ItemType) -> Subquery: """Returns a :class:`~sqlalchemy.sql.expression.Alias` object representing a subquery to collect items of given type. - - Args: - item_type (str) - - Returns: - :class:`~sqlalchemy.sql.expression.Alias` """ raise NotImplementedError() - def _make_query(self, sq): + def _make_query(self, sq: Subquery) -> Query: """Returns a :class:`~sqlalchemy.orm.query.Query` object from given subquery. Args: @@ -128,10 +106,10 @@ def _query_commit_count(self): """ raise NotImplementedError() - def make_item(self, item_type: str, **item) -> MappedItemBase: + def make_item(self, item_type: ItemType, **item) -> MappedItemBase: raise NotImplementedError - def dirty_ids(self, item_type): + def dirty_ids(self, item_type: ItemType): return { item["id"] for item in self._mapped_tables[item_type].valid_values() @@ -223,7 +201,7 @@ def refresh_session(self): self._commit_count = None self._fetched.clear() - def mapped_table(self, item_type: str) -> MappedTable: + def mapped_table(self, item_type: ItemType) -> MappedTable: """Returns mapped table for given item type.""" try: return self._mapped_tables[item_type] @@ -255,7 +233,7 @@ def reset_purging(self): for mapped_table in self._mapped_tables.values(): mapped_table.reset_purging() - def _add_descendants(self, item_types): + def _add_descendants(self, item_types: set[ItemType]) -> None: while True: changed = False for item_type in set(self.item_types()) - item_types: @@ -266,20 +244,16 @@ def _add_descendants(self, item_types): break def _get_commit_count(self) -> int: - """Returns current commit count. - - Returns: - int - """ + """Returns current commit count.""" if self._commit_count is None: self._commit_count = self._query_commit_count() return self._commit_count - def _do_make_query(self, item_type, **kwargs): + def _do_make_query(self, item_type: ItemType, **kwargs) -> Query: """Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type. Args: - item_type (str): item type + item_type: item type **kwargs: query filters Returns: @@ -303,11 +277,11 @@ def _do_make_query(self, item_type, **kwargs): pass return qry - def _get_next_chunk(self, item_type, offset, limit, **kwargs): + def _get_next_chunk(self, item_type: ItemType, offset: int, limit: int, **kwargs) -> list[dict]: """Gets chunk of items from the DB. Returns: - list(dict): list of dictionary items. + list of dictionary items. """ with self: qry = self._do_make_query(item_type, **kwargs) @@ -380,7 +354,7 @@ def item(self, mapped_table: MappedTable, **kwargs) -> PublicItem: class MappedTable(dict): - def __init__(self, db_map: DatabaseMappingBase, item_type: str, *args, **kwargs): + def __init__(self, db_map: DatabaseMapping, item_type: ItemType, *args, **kwargs): """ Args: db_map: the DB mapping where this mapped table belongs. @@ -402,7 +376,7 @@ def purged(self) -> bool: def purged(self, purged: bool) -> None: self.wildcard_item.status = Status.to_remove if purged else Status.committed - def get(self, id_, default=None): + def get(self, id_: TempId | int, default: MappedItemBase | None = None) -> MappedItemBase | None: id_ = self._temp_id_lookup.get(id_, id_) return super().get(id_, default) @@ -440,7 +414,7 @@ def find_item(self, item: dict, fetch: bool = True) -> MappedItemBase: return self.find_item_by_id(id_, fetch=fetch) return self.find_item_by_unique_key(item, fetch=fetch) - def find_item_by_id(self, id_: TempId, fetch: bool = True) -> MappedItemBase: + def find_item_by_id(self, id_: TempId | int, fetch: bool = True) -> MappedItemBase: current_item = self.get(id_) if current_item is None and fetch: self._db_map.do_fetch_all(self) @@ -518,7 +492,7 @@ def _prepare_item( if unique_item is not current_item and unique_item.is_valid(): raise SpineDBAPIError(f"there's already a {self.item_type} with {dict(zip(key, value))}") - def item_to_remove(self, id_: TempId | AsteriskType) -> MappedItemBase: + def item_to_remove(self, id_: TempId | int | AsteriskType) -> MappedItemBase: if id_ is Asterisk: return self.wildcard_item return self.find_item_by_id(id_) @@ -646,7 +620,7 @@ def remove_item(self, item: Optional[MappedItemBase]) -> Optional[MappedItemBase item.cascade_remove() return item - def restore_item(self, id_: TempId) -> Optional[MappedItemBase]: + def restore_item(self, id_: TempId | int) -> Optional[MappedItemBase]: if id_ is Asterisk: self.purged = False for current_item in self.values(): @@ -678,7 +652,7 @@ class FieldDict(TypedDict): class MappedItemBase(dict): """A dictionary that represents a db item.""" - item_type: ClassVar[str] = "not implemented" + item_type: ClassVar[ItemType] = NotImplemented fields: ClassVar[dict[str, FieldDict]] = {} """A dictionary mapping fields to a another dict mapping "type" to a Python type, "value" to a description of the value for the key, and "optional" to a bool.""" @@ -689,7 +663,7 @@ class MappedItemBase(dict): """A tuple where each element is itself a tuple of fields corresponding to a unique key.""" required_key_combinations: ClassVar[tuple[tuple[str, ...], ...]] = () """Tuple containing tuples of required keys and their possible alternatives.""" - _references: ClassVar[dict[str, str]] = {} + _references: ClassVar[dict[str, ItemType]] = {} """A dictionary mapping source fields to reference item type. Used to access external fields. """ @@ -704,7 +678,7 @@ class MappedItemBase(dict): When accessing fields in _external_fields, we first find the reference pointed at by the source field, and then return the target field of that reference. """ - _alt_references: ClassVar[dict[tuple[str, ...], tuple[str, tuple[str, ...]]]] = {} + _alt_references: ClassVar[dict[tuple[str, ...], tuple[ItemType, tuple[str, ...]]]] = {} """A dictionary mapping source fields, to a tuple of reference item type and reference fields. Used only to resolve internal fields at item creation. """ @@ -721,7 +695,7 @@ class MappedItemBase(dict): fields_not_requiring_cascade_update: ClassVar[set[str]] = set() is_protected: ClassVar[bool] = False - def __init__(self, db_map: DatabaseMappingBase, **kwargs): + def __init__(self, db_map: DatabaseMapping, **kwargs): """ Args: db_map: the DB where this item belongs. @@ -760,7 +734,7 @@ def handle_refetch(self) -> None: self._valid = None @classmethod - def ref_types(cls) -> set[str]: + def ref_types(cls) -> set[ItemType]: """Returns a set of item types that this class refers.""" return set(cls._references.values()) @@ -1234,7 +1208,7 @@ def __init__(self, mapped_item: MappedItemBase): self._mapped_item = mapped_item @property - def item_type(self) -> str: + def item_type(self) -> ItemType: return self._mapped_item.item_type @property @@ -1242,7 +1216,7 @@ def mapped_item(self) -> MappedItemBase: return self._mapped_item @property - def db_map(self) -> DatabaseMappingBase: + def db_map(self) -> DatabaseMapping: return self._mapped_item.db_map def __getitem__(self, key): diff --git a/spinedb_api/export_mapping/group_functions.py b/spinedb_api/export_mapping/group_functions.py index edfb7e76..53021e6b 100644 --- a/spinedb_api/export_mapping/group_functions.py +++ b/spinedb_api/export_mapping/group_functions.py @@ -9,26 +9,17 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### -""" -Contains functions to group values in pivot tables with hidden columns or rows. - -""" +""" Contains functions to group values in pivot tables with hidden columns or rows. """ +from typing import Any, ClassVar, Type import numpy as np class GroupFunction: - NAME = NotImplemented - DISPLAY_NAME = NotImplemented - - def __call__(self, items): - """Performs the grouping. Reduces the given list of items into a single value. - - Args: - items (list or None) + NAME: ClassVar[str] = NotImplemented + DISPLAY_NAME: ClassVar[str] = NotImplemented - Returns: - Any - """ + def __call__(self, items: list | None) -> Any: + """Performs the grouping. Reduces the given list of items into a single value.""" raise NotImplementedError @@ -116,28 +107,36 @@ def __call__(self, items): return items[0] -_classes = (NoGroup, GroupSum, GroupMean, GroupMin, GroupMax, GroupConcat, GroupOneOrNone) +_classes: tuple[Type[GroupFunction], ...] = ( + NoGroup, + GroupSum, + GroupMean, + GroupMin, + GroupMax, + GroupConcat, + GroupOneOrNone, +) -GROUP_FUNCTION_DISPLAY_NAMES = [klass.DISPLAY_NAME for klass in _classes] +GROUP_FUNCTION_DISPLAY_NAMES: list[str] = [klass.DISPLAY_NAME for klass in _classes] -def group_function_name_from_display(display_name): +def group_function_name_from_display(display_name: str) -> str: return {klass.DISPLAY_NAME: klass.NAME for klass in _classes}.get(display_name, NoGroup.NAME) -def group_function_display_from_name(name): +def group_function_display_from_name(name: str) -> str: return {klass.NAME: klass.DISPLAY_NAME for klass in _classes}.get(name, NoGroup.DISPLAY_NAME) -def from_str(name): +def from_str(name: str | None) -> GroupFunction: """ Creates group function from name. Args: - name (str, NoneType): group function name or None if no aggregation wanted. + name: group function name or None if no aggregation wanted. Returns: - GroupFunction or NoneType + GroupFunction; NoGroup if no aggregation wanted """ constructor = {klass.NAME: klass for klass in _classes}.get(name, NoGroup) return constructor() diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index a9703837..1c046740 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -17,7 +17,7 @@ import json from operator import itemgetter import os -from typing import Any +from typing import Any, Literal, TypeAlias import warnings from alembic.config import Config from alembic.environment import EnvironmentContext @@ -59,6 +59,31 @@ from sqlalchemy.sql.selectable import SelectBase from .exception import SpineDBAPIError, SpineDBVersionError +ItemType: TypeAlias = Literal[ + "alternative", + "commit", + "display_mode", + "entity", + "entity_alternative", + "entity_class_display_mode", + "entity_class", + "entity_group", + "entity_location", + "entity_metadata", + "list_value", + "metadata", + "parameter_definition", + "parameter_type", + "parameter_value", + "parameter_value_list", + "parameter_value_metadata", + "scenario", + "scenario_alternative", + "superclass_subclass", +] + +LegacyItemType = Literal["object", "object_class", "relationship", "relationship_class"] + SUPPORTED_DIALECTS = { "mysql": "pymysql", "sqlite": "sqlite3", diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index f5223276..d7c94477 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -15,11 +15,11 @@ import inspect from operator import itemgetter import re -from typing import ClassVar, Optional, Union +from typing import ClassVar, Literal, Optional, Type, Union from . import arrow_value from .db_mapping_base import DatabaseMappingBase, MappedItemBase, MappedTable from .exception import SpineDBAPIError -from .helpers import DisplayStatus, name_from_dimensions, name_from_elements +from .helpers import DisplayStatus, ItemType, name_from_dimensions, name_from_elements from .parameter_value import ( RANK_1_TYPES, UNPARSED_NULL_VALUE, @@ -1448,13 +1448,17 @@ class EntityLocationItem(MappedItemBase): } -ITEM_CLASSES = tuple( +ITEM_CLASSES: tuple[Type[MappedItemBase], ...] = tuple( x for x in tuple(locals().values()) if inspect.isclass(x) and issubclass(x, MappedItemBase) and x != MappedItemBase ) -ITEM_CLASS_BY_TYPE = {klass.item_type: klass for klass in ITEM_CLASSES} +ITEM_CLASS_BY_TYPE: dict[ItemType, Type[MappedItemBase]] = {klass.item_type: klass for klass in ITEM_CLASSES} -def _byname_iter(item: Union[EntityClassItem, EntityItem], id_list_name: str, table: MappedTable) -> Iterator[str]: +def _byname_iter( + item: Union[EntityClassItem, EntityItem], + id_list_name: Literal["dimension_id_list", "element_id_list"], + table: MappedTable, +) -> Iterator[str]: id_list = item[id_list_name] if not id_list: yield item["name"] diff --git a/spinedb_api/mapping.py b/spinedb_api/mapping.py index d67a7430..4628cecf 100644 --- a/spinedb_api/mapping.py +++ b/spinedb_api/mapping.py @@ -12,10 +12,11 @@ """Base class for import and export mappings.""" +from __future__ import annotations from enum import Enum, unique from itertools import takewhile import re -from typing import Optional +from typing import Any, ClassVar, Optional from spinedb_api import InvalidMapping _TABLEFUL_FIXED_POSITION_RE = re.compile(r"^\s*(.+):\s*(\d+)\s*,\s*(\d+)\s*$") @@ -33,26 +34,26 @@ class Position(Enum): fixed = "fixed" -def is_pivoted(position): +def is_pivoted(position: Position | int) -> bool: """Checks if position is pivoted. Args: - position (Position or int): position + position: position Returns: - bool: True if position is pivoted, False otherwise + True if position is pivoted, False otherwise """ return isinstance(position, int) and position < 0 -def is_regular(position): +def is_regular(position: Position | int) -> bool: """Checks if position is column index. Args: - position (Position or int): position + position: position Returns: - bool: True if position is a column index, False otherwise + True if position is a column index, False otherwise """ return isinstance(position, int) and position >= 0 @@ -76,67 +77,67 @@ class Mapping: """Base class for import/export item mappings. Attributes: - position (int or Position): defines where the data is written/read in the output table. + position: defines where the data is written/read in the output table. Nonnegative numbers are columns, negative numbers are pivot rows, and then there are some special cases in the Position enum. - parent (Mapping or None): Another mapping that's the 'parent' of this one. + parent: Another mapping that's the 'parent' of this one. Used to determine if a mapping is root, in which case it needs to yield the header. """ - MAP_TYPE = None + MAP_TYPE: ClassVar[str] = NotImplemented """Mapping type identifier for serialization.""" - def __init__(self, position, value=None, filter_re=""): + def __init__(self, position: Position | int, value: Any = None, filter_re: str = ""): """ Args: - position (int or Position): column index or Position - value (Any): fixed value - filter_re (str): regular expression for filtering + position: column index or Position + value: fixed value + filter_re: regular expression for filtering """ - self._child = None + self._child: Mapping | None = None self._value = None self._unfixed_value_data = self._data self._filter_re = None - self.parent = None + self.parent: Mapping | None = None self.position = position self.value = value self.filter_re = filter_re @property - def child(self): + def child(self) -> Mapping | None: return self._child @child.setter - def child(self, child): + def child(self, child: Mapping | None) -> None: self._child = child if isinstance(child, Mapping): child.parent = self @property - def value(self): + def value(self) -> Any: """Fixed value.""" return self._value @value.setter - def value(self, value): + def value(self, value: Any) -> None: self._value = value self._set_fixed_value_data() @property - def filter_re(self): + def filter_re(self) -> str: return self._filter_re.pattern if self._filter_re is not None else "" @filter_re.setter def filter_re(self, filter_re): self._filter_re = re.compile(filter_re) if filter_re else None - def _data(self, row): + def _data(self, row: int) -> Any: raise NotImplementedError() - def _fixed_value_data(self, _row): + def _fixed_value_data(self, _row: int) -> Any: return self._value - def _set_fixed_value_data(self): + def _set_fixed_value_data(self) -> None: if self._value is None: self._data = self._unfixed_value_data return @@ -152,50 +153,50 @@ def __eq__(self, other): and self._filter_re == other._filter_re ) - def tail_mapping(self): + def tail_mapping(self) -> Mapping: """Returns the last mapping in the chain. Returns: - Mapping: last child mapping + last child mapping """ if self._child is None: return self return self._child.tail_mapping() - def count_mappings(self): + def count_mappings(self) -> int: """ Counts this and child mappings. Returns: - int: number of mappings + number of mappings """ return 1 + (self.child.count_mappings() if self.child is not None else 0) - def flatten(self): + def flatten(self) -> list[Mapping]: """ Flattens the mapping tree. Returns: - list of Mapping: mappings in parent-child-grand child-etc order + mappings in parent-child-grand child-etc order """ return [self] + (self.child.flatten() if self.child is not None else []) - def is_effective_leaf(self): + def is_effective_leaf(self) -> bool: """Tests if mapping is effectively the leaf mapping. Returns: - bool: True if mapping is effectively the last child, False otherwise + True if mapping is effectively the last child, False otherwise """ return self._child is None or all( child.position in (Position.hidden, Position.table_name) for child in self._child.flatten()[:-1] ) - def is_pivoted(self): + def is_pivoted(self) -> bool: """ Queries recursively if export items are pivoted. Returns: - bool: True if any of the items is pivoted, False otherwise + True if any of the items is pivoted, False otherwise """ if is_pivoted(self.position): return True @@ -203,15 +204,15 @@ def is_pivoted(self): return False return self.child.is_pivoted() - def non_pivoted_width(self, parent_is_pivoted=False): + def non_pivoted_width(self, parent_is_pivoted: bool = False) -> int: """ Calculates columnar width of non-pivoted data. Args: - parent_is_pivoted (bool): True if a parent item is pivoted, False otherwise + parent_is_pivoted: True if a parent item is pivoted, False otherwise Returns: - int: non-pivoted data width + non-pivoted data width """ if self.child is None: if is_regular(self.position) and not parent_is_pivoted: @@ -220,14 +221,14 @@ def non_pivoted_width(self, parent_is_pivoted=False): width = self.position + 1 if is_regular(self.position) else 0 return max(width, self.child.non_pivoted_width(parent_is_pivoted or is_pivoted(self.position))) - def non_pivoted_columns(self, parent_is_pivoted=False): + def non_pivoted_columns(self, parent_is_pivoted: bool = False) -> list[int]: """Gathers non-pivoted columns from mappings. Args: - parent_is_pivoted (bool): True if a parent item is pivoted, False otherwise + parent_is_pivoted: True if a parent item is pivoted, False otherwise Returns: - list of int: indexes of non-pivoted columns + indexes of non-pivoted columns """ if self._child is None: if is_regular(self.position) and not parent_is_pivoted: @@ -238,30 +239,30 @@ def non_pivoted_columns(self, parent_is_pivoted=False): parent_is_pivoted or pivoted ) - def last_pivot_row(self): + def last_pivot_row(self) -> int: return max( (-(m.position + 1) for m in self.flatten() if isinstance(m.position, int) and m.position < 0), default=-1 ) - def query_parents(self, what): + def query_parents(self, what: str) -> Any: """Queries parent mapping for specific information. Args: - what (str): query identifier + what: query identifier Returns: - Any: query result or None if no parent recognized the identifier + query result or None if no parent recognized the identifier """ if self.parent is None: return None return self.parent.query_parents(what) - def to_dict(self): + def to_dict(self) -> dict: """ Serializes mapping into dict. Returns: - dict: serialized mapping + serialized mapping """ position = self.position.value if isinstance(self.position, Position) else self.position mapping_dict = {"map_type": self.MAP_TYPE, "position": position} @@ -272,15 +273,15 @@ def to_dict(self): return mapping_dict -def unflatten(mappings): +def unflatten(mappings: list[Mapping]) -> Mapping: """ Builds a mapping hierarchy from flattened mappings. Args: - mappings (Iterable of Mapping): flattened mappings + mappings: flattened mappings Returns: - Mapping: root mapping + root mapping """ root = None current = None @@ -294,15 +295,15 @@ def unflatten(mappings): return root -def value_index(flattened_mappings): +def value_index(flattened_mappings: list[Mapping]) -> int: """ Returns index of last non-hidden mapping in flattened mapping list. Args: - flattened_mappings (list of Mapping): flattened mappings + flattened_mappings: flattened mappings Returns: - int: value mapping index + value mapping index """ return ( len(flattened_mappings) @@ -311,14 +312,14 @@ def value_index(flattened_mappings): ) -def to_dict(root_mapping): +def to_dict(root_mapping: Mapping) -> list[dict]: """ Serializes mappings into JSON compatible data structure. Args: - root_mapping (Mapping): root mapping + root_mapping: root mapping Returns: - list: serialized mappings + serialized mappings """ return list(mapping.to_dict() for mapping in root_mapping.flatten()) diff --git a/spinedb_api/spine_io/exporters/writer.py b/spinedb_api/spine_io/exporters/writer.py index 5dbb6c1b..e261ab27 100644 --- a/spinedb_api/spine_io/exporters/writer.py +++ b/spinedb_api/spine_io/exporters/writer.py @@ -10,28 +10,37 @@ # this program. If not, see . ###################################################################################################################### """ Module contains the :class:`Writer` base class and functions to write tabular data. """ +from __future__ import annotations from contextlib import contextmanager from copy import copy from sqlalchemy.exc import OperationalError -from spinedb_api import SpineDBAPIError +from spinedb_api import DatabaseMapping, SpineDBAPIError from spinedb_api.export_mapping import rows, titles -from spinedb_api.export_mapping.export_mapping import drop_non_positioned_tail +from spinedb_api.export_mapping.export_mapping import ExportMapping, drop_non_positioned_tail from spinedb_api.export_mapping.group_functions import NoGroup -def write(db_map, writer, *mappings, empty_data_header=True, max_tables=None, max_rows=None, group_fns=NoGroup.NAME): +def write( + db_map: DatabaseMapping, + writer: Writer, + *mappings: ExportMapping, + empty_data_header: bool | list[bool] = True, + max_tables: int | None = None, + max_rows: int | None = None, + group_fns: str | list[str] = NoGroup.NAME, +): """ Writes given mapping. Args: - db_map (DatabaseMapping): database map - writer (Writer): target writer - mappings (Mapping): root mappings - empty_data_header (bool or Iterable of bool): True to write at least header rows even if there is no data, + db_map: database map + writer: target writer + mappings: root mappings + empty_data_header: True to write at least header rows even if there is no data, False to write nothing; a list of booleans applies to each mapping individually - max_tables (int, optional): maximum number of tables to write - max_rows (int, optional): maximum number of rows/table to write - group_fns (str or Iterable of str): group function names for each mappings + max_table: maximum number of tables to write + max_rows: maximum number of rows/table to write + group_fns: group function names for each mappings """ if isinstance(empty_data_header, bool): empty_data_header = len(mappings) * [empty_data_header] diff --git a/spinedb_api/spine_io/importers/json_reader.py b/spinedb_api/spine_io/importers/json_reader.py index eaf70553..5f4031ea 100644 --- a/spinedb_api/spine_io/importers/json_reader.py +++ b/spinedb_api/spine_io/importers/json_reader.py @@ -32,7 +32,7 @@ class JSONReader(Reader): # File extensions for modal widget that that returns source object and action (OK, CANCEL) FILE_EXTENSIONS = "*.json" - def __init__(self, settings): + def __init__(self, settings: dict | None): super().__init__(settings) self._filename = None self._root_prefix = None diff --git a/spinedb_api/spine_io/importers/reader.py b/spinedb_api/spine_io/importers/reader.py index ec08174e..35830a8a 100644 --- a/spinedb_api/spine_io/importers/reader.py +++ b/spinedb_api/spine_io/importers/reader.py @@ -11,19 +11,26 @@ ###################################################################################################################### """ Contains a base class for a data source readers used in importing. """ - +from collections.abc import Callable, Iterator from dataclasses import dataclass, field from itertools import islice -from typing import Any +from typing import Any, ClassVar, Type from spinedb_api import DateTime, Duration, ParameterValueFormatError from spinedb_api.exception import InvalidMappingComponent, ReaderError from spinedb_api.import_mapping.generator import get_mapped_data, identity +from spinedb_api.import_mapping.import_mapping import ImportMapping from spinedb_api.import_mapping.import_mapping_compat import parse_named_mapping_spec from spinedb_api.mapping import Position, parse_fixed_position_value -TYPE_STRING_TO_CLASS = {"string": str, "datetime": DateTime, "duration": Duration, "float": float, "boolean": bool} +TYPE_STRING_TO_CLASS: dict[str, Type] = { + "string": str, + "datetime": DateTime, + "duration": Duration, + "float": float, + "boolean": bool, +} -TYPE_CLASS_TO_STRING = {type_class: string for string, type_class in TYPE_STRING_TO_CLASS.items()} +TYPE_CLASS_TO_STRING: dict[Type, str] = {type_class: string for string, type_class in TYPE_STRING_TO_CLASS.items()} @dataclass @@ -35,11 +42,11 @@ class Reader: """A base class to read data.""" # name of data source, ex: "Text/CSV" - DISPLAY_NAME = "unnamed source" + DISPLAY_NAME: ClassVar[str] = "unnamed source" # dict with option specification for source. - OPTIONS = {} - BASE_OPTIONS = { + OPTIONS: ClassVar[dict] = {} + BASE_OPTIONS: ClassVar[dict[str, dict[str, Any]]] = { "max_rows": { "type": int, "label": "Max rows", @@ -51,24 +58,24 @@ class Reader: } # File extensions for modal widget that that returns action (OK, CANCEL) and source object - FILE_EXTENSIONS = NotImplemented + FILE_EXTENSIONS: ClassVar[str] = NotImplemented - def __init__(self, settings): + def __init__(self, settings: dict | None): """ Args: - settings (dict, optional): connector specific settings or None + settings: connector specific settings or None """ - def connect_to_source(self, source, **extras): + def connect_to_source(self, source: str, **extras) -> None: """Connects to source, ex: connecting to a database where source is a connection string. Args: - source (str): file path or URL to connect to + source: file path or URL to connect to **extras: additional source specific connection data """ raise NotImplementedError() - def disconnect(self): + def disconnect(self) -> None: """Disconnect from connected source.""" raise NotImplementedError() @@ -76,10 +83,8 @@ def get_tables_and_properties(self) -> dict[str, TableProperties]: """Returns table names and properties.""" raise NotImplementedError() - def get_data_iterator(self, table, options, max_rows=-1): - """ - Returns a data iterator and data header. - """ + def get_data_iterator(self, table: str, options: dict, max_rows: int = -1) -> tuple[Iterator[list], list[str]]: + """Returns a data iterator and data header.""" raise NotImplementedError() def get_table_cell(self, table: str, row: int, column: int, options: dict) -> Any: @@ -95,7 +100,7 @@ def get_table_cell(self, table: str, row: int, column: int, options: dict) -> An raise ReaderError(f"{table} doesn't have column {column}") @staticmethod - def _resolve_max_rows(options, max_rows=-1): + def _resolve_max_rows(options: dict, max_rows: int = -1) -> int: options_max_rows = options.get("max_rows", -1) if options_max_rows == -1: return max_rows @@ -103,7 +108,7 @@ def _resolve_max_rows(options, max_rows=-1): return options_max_rows return min(max_rows, options_max_rows) - def get_data(self, table, options, max_rows=-1, start=0): + def get_data(self, table: str, options: dict, max_rows: int = -1, start: int = 0) -> tuple[list, list[str]]: """ Return data read from data source table in table. If max_rows is specified only that number of rows. @@ -114,7 +119,12 @@ def get_data(self, table, options, max_rows=-1, start=0): data = list(data_iter) return data, header - def resolve_values_for_fixed_position_mappings(self, tables_mappings, table_options, reader_error_is_fatal): + def resolve_values_for_fixed_position_mappings( + self, + tables_mappings: dict[str, list[tuple[str, ImportMapping]]], + table_options: dict, + reader_error_is_fatal: bool, + ) -> dict[str, list[tuple[str, ImportMapping]]]: for table, named_mappings in tables_mappings.items(): parsed_mappings = [] for mapping_name, root_mapping in named_mappings: @@ -139,30 +149,30 @@ def resolve_values_for_fixed_position_mappings(self, tables_mappings, table_opti def get_mapped_data( self, - tables_mappings, - table_options, - table_column_convert_specs, - table_default_column_convert_fns, - table_row_convert_specs, - unparse_value=identity, - max_rows=-1, - ): + tables_mappings: dict[str, list[tuple[str, ImportMapping]]], + table_options: dict, + table_column_convert_specs: dict[str, dict], + table_default_column_convert_fns: dict[str, Callable[[Any], Any]], + table_row_convert_specs: dict[str, dict], + unparse_value: Callable[[Any], tuple[bytes, str]] = identity, + max_rows: int = -1, + ) -> tuple[dict[str, list], list[str | tuple[str, str]]]: """ Reads all mappings in dict tables_mappings, where key is name of table and value is the mappings for that table. Args: - tables_mappings (dict): mapping from table name to list of import mappings - table_options (dict): mapping from table name to table-specific import options - table_column_convert_specs (dict): mapping from table name to column data type conversion settings - table_default_column_convert_fns (dict): mapping from table name to + tables_mappings: mapping from table name to list of import mappings + table_options: mapping from table name to table-specific import options + table_column_convert_specs: mapping from table name to column data type conversion settings + table_default_column_convert_fns: mapping from table name to default column data type converter - table_row_convert_specs (dict): mapping from table name to row data type conversion settings - unparse_value (Callable): callable that converts imported values to database representation - max_rows (int): maximum number of source rows to map + table_row_convert_specs: mapping from table name to row data type conversion settings + unparse_value: callable that converts imported values to database representation + max_rows: maximum number of source rows to map Returns: - tuple: mapped data and a list of errors, if any + mapped data and a list of errors, if any """ mapped_data = {} errors = [] diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 2efa08b0..8b427f38 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -9,13 +9,14 @@ # this program. If not, see . ###################################################################################################################### from __future__ import annotations -from typing import Any, Optional +from typing import Any, ClassVar, Optional +from spinedb_api.helpers import ItemType class TempId: - _next_id = {} + _next_id: ClassVar[dict[ItemType, int]] = {} - def __init__(self, id_: int, item_type: str, temp_id_lookup: dict[int, TempId]): + def __init__(self, id_: int, item_type: ItemType, temp_id_lookup: dict[int, TempId]): super().__init__() self._id = id_ self._item_type = item_type @@ -24,13 +25,13 @@ def __init__(self, id_: int, item_type: str, temp_id_lookup: dict[int, TempId]): self._temp_id_lookup[self._id] = self @staticmethod - def new_unique(item_type: str, temp_id_lookup: dict[int, TempId]) -> TempId: + def new_unique(item_type: ItemType, temp_id_lookup: dict[int, TempId]) -> TempId: id_ = TempId._next_id.get(item_type, -1) TempId._next_id[item_type] = id_ - 1 return TempId(id_, item_type, temp_id_lookup) @property - def item_type(self) -> str: + def item_type(self) -> ItemType: return self._item_type @property