Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 140 additions & 106 deletions spinedb_api/db_mapping.py

Large diffs are not rendered by default.

106 changes: 40 additions & 66 deletions spinedb_api/db_mapping_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,21 @@
from contextlib import suppress
from dataclasses import dataclass
from difflib import SequenceMatcher
from typing import Any, ClassVar, Optional, Type, TypedDict, Union
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Type, TypedDict, Union
from sqlalchemy import Subquery
from sqlalchemy.orm import Query
from .exception import SpineDBAPIError
from .helpers import Asterisk, AsteriskType
from .helpers import Asterisk, AsteriskType, ItemType
from .mapped_item_status import Status
from .temp_id import TempId, resolve

if TYPE_CHECKING:
from .db_mapping import DatabaseMapping


@dataclass(frozen=True)
class DirtyItems:
item_type: str
item_type: ItemType
to_add: list[MappedItemBase]
to_update: list[MappedItemBase]
to_remove: list[MappedItemBase]
Expand All @@ -42,10 +47,10 @@ def __init__(self):
self._closed = False
self._context_open_count = 0
self._mapped_tables = {item_type: MappedTable(self, item_type) for item_type in self.all_item_types()}
self._fetched = {}
self._commit_count = None
self._fetched: dict[ItemType, int] = {}
self._commit_count: int | None = None
item_types = self.item_types()
self._sorted_item_types = []
self._sorted_item_types: list[ItemType] = []
while item_types:
item_type = item_types.pop(0)
if not self.item_factory(item_type).ref_types().isdisjoint(item_types):
Expand All @@ -62,54 +67,27 @@ def close(self):
self._closed = True

@staticmethod
def item_types():
"""Returns a list of public item types from the DB mapping schema (equivalent to the table names).

:meta private:

Returns:
list(str)
"""
def item_types() -> list[ItemType]:
"""Returns a list of public item types from the DB mapping schema (equivalent to the table names)."""
raise NotImplementedError()

@staticmethod
def all_item_types():
"""Returns a list of all item types from the DB mapping schema (equivalent to the table names).

:meta private:

Returns:
list(str)
"""
def all_item_types() -> list[ItemType]:
"""Returns a list of all item types from the DB mapping schema (equivalent to the table names)."""
raise NotImplementedError()

@staticmethod
def item_factory(item_type):
"""Returns a subclass of :class:`.MappedItemBase` to make items of given type.

:meta private:

Args:
item_type (str)

Returns:
function
"""
def item_factory(item_type: ItemType) -> Type[MappedItemBase]:
"""Returns a subclass of :class:`.MappedItemBase` to make items of given type."""
raise NotImplementedError()

def _make_sq(self, item_type):
def _make_sq(self, item_type: ItemType) -> Subquery:
"""Returns a :class:`~sqlalchemy.sql.expression.Alias` object representing a subquery
to collect items of given type.

Args:
item_type (str)

Returns:
:class:`~sqlalchemy.sql.expression.Alias`
"""
raise NotImplementedError()

def _make_query(self, sq):
def _make_query(self, sq: Subquery) -> Query:
"""Returns a :class:`~sqlalchemy.orm.query.Query` object from given subquery.

Args:
Expand All @@ -128,10 +106,10 @@ def _query_commit_count(self):
"""
raise NotImplementedError()

def make_item(self, item_type: str, **item) -> MappedItemBase:
def make_item(self, item_type: ItemType, **item) -> MappedItemBase:
raise NotImplementedError

def dirty_ids(self, item_type):
def dirty_ids(self, item_type: ItemType):
return {
item["id"]
for item in self._mapped_tables[item_type].valid_values()
Expand Down Expand Up @@ -223,7 +201,7 @@ def refresh_session(self):
self._commit_count = None
self._fetched.clear()

def mapped_table(self, item_type: str) -> MappedTable:
def mapped_table(self, item_type: ItemType) -> MappedTable:
"""Returns mapped table for given item type."""
try:
return self._mapped_tables[item_type]
Expand Down Expand Up @@ -255,7 +233,7 @@ def reset_purging(self):
for mapped_table in self._mapped_tables.values():
mapped_table.reset_purging()

def _add_descendants(self, item_types):
def _add_descendants(self, item_types: set[ItemType]) -> None:
while True:
changed = False
for item_type in set(self.item_types()) - item_types:
Expand All @@ -266,20 +244,16 @@ def _add_descendants(self, item_types):
break

def _get_commit_count(self) -> int:
"""Returns current commit count.

Returns:
int
"""
"""Returns current commit count."""
if self._commit_count is None:
self._commit_count = self._query_commit_count()
return self._commit_count

def _do_make_query(self, item_type, **kwargs):
def _do_make_query(self, item_type: ItemType, **kwargs) -> Query:
"""Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type.

Args:
item_type (str): item type
item_type: item type
**kwargs: query filters

Returns:
Expand All @@ -303,11 +277,11 @@ def _do_make_query(self, item_type, **kwargs):
pass
return qry

def _get_next_chunk(self, item_type, offset, limit, **kwargs):
def _get_next_chunk(self, item_type: ItemType, offset: int, limit: int, **kwargs) -> list[dict]:
"""Gets chunk of items from the DB.

Returns:
list(dict): list of dictionary items.
list of dictionary items.
"""
with self:
qry = self._do_make_query(item_type, **kwargs)
Expand Down Expand Up @@ -380,7 +354,7 @@ def item(self, mapped_table: MappedTable, **kwargs) -> PublicItem:


class MappedTable(dict):
def __init__(self, db_map: DatabaseMappingBase, item_type: str, *args, **kwargs):
def __init__(self, db_map: DatabaseMapping, item_type: ItemType, *args, **kwargs):
"""
Args:
db_map: the DB mapping where this mapped table belongs.
Expand All @@ -402,7 +376,7 @@ def purged(self) -> bool:
def purged(self, purged: bool) -> None:
self.wildcard_item.status = Status.to_remove if purged else Status.committed

def get(self, id_, default=None):
def get(self, id_: TempId | int, default: MappedItemBase | None = None) -> MappedItemBase | None:
id_ = self._temp_id_lookup.get(id_, id_)
return super().get(id_, default)

Expand Down Expand Up @@ -440,7 +414,7 @@ def find_item(self, item: dict, fetch: bool = True) -> MappedItemBase:
return self.find_item_by_id(id_, fetch=fetch)
return self.find_item_by_unique_key(item, fetch=fetch)

def find_item_by_id(self, id_: TempId, fetch: bool = True) -> MappedItemBase:
def find_item_by_id(self, id_: TempId | int, fetch: bool = True) -> MappedItemBase:
current_item = self.get(id_)
if current_item is None and fetch:
self._db_map.do_fetch_all(self)
Expand Down Expand Up @@ -518,7 +492,7 @@ def _prepare_item(
if unique_item is not current_item and unique_item.is_valid():
raise SpineDBAPIError(f"there's already a {self.item_type} with {dict(zip(key, value))}")

def item_to_remove(self, id_: TempId | AsteriskType) -> MappedItemBase:
def item_to_remove(self, id_: TempId | int | AsteriskType) -> MappedItemBase:
if id_ is Asterisk:
return self.wildcard_item
return self.find_item_by_id(id_)
Expand Down Expand Up @@ -646,7 +620,7 @@ def remove_item(self, item: Optional[MappedItemBase]) -> Optional[MappedItemBase
item.cascade_remove()
return item

def restore_item(self, id_: TempId) -> Optional[MappedItemBase]:
def restore_item(self, id_: TempId | int) -> Optional[MappedItemBase]:
if id_ is Asterisk:
self.purged = False
for current_item in self.values():
Expand Down Expand Up @@ -678,7 +652,7 @@ class FieldDict(TypedDict):
class MappedItemBase(dict):
"""A dictionary that represents a db item."""

item_type: ClassVar[str] = "not implemented"
item_type: ClassVar[ItemType] = NotImplemented
fields: ClassVar[dict[str, FieldDict]] = {}
"""A dictionary mapping fields to a another dict mapping "type" to a Python type,
"value" to a description of the value for the key, and "optional" to a bool."""
Expand All @@ -689,7 +663,7 @@ class MappedItemBase(dict):
"""A tuple where each element is itself a tuple of fields corresponding to a unique key."""
required_key_combinations: ClassVar[tuple[tuple[str, ...], ...]] = ()
"""Tuple containing tuples of required keys and their possible alternatives."""
_references: ClassVar[dict[str, str]] = {}
_references: ClassVar[dict[str, ItemType]] = {}
"""A dictionary mapping source fields to reference item type.
Used to access external fields.
"""
Expand All @@ -704,7 +678,7 @@ class MappedItemBase(dict):
When accessing fields in _external_fields, we first find the reference pointed at by the source field,
and then return the target field of that reference.
"""
_alt_references: ClassVar[dict[tuple[str, ...], tuple[str, tuple[str, ...]]]] = {}
_alt_references: ClassVar[dict[tuple[str, ...], tuple[ItemType, tuple[str, ...]]]] = {}
"""A dictionary mapping source fields, to a tuple of reference item type and reference fields.
Used only to resolve internal fields at item creation.
"""
Expand All @@ -721,7 +695,7 @@ class MappedItemBase(dict):
fields_not_requiring_cascade_update: ClassVar[set[str]] = set()
is_protected: ClassVar[bool] = False

def __init__(self, db_map: DatabaseMappingBase, **kwargs):
def __init__(self, db_map: DatabaseMapping, **kwargs):
"""
Args:
db_map: the DB where this item belongs.
Expand Down Expand Up @@ -760,7 +734,7 @@ def handle_refetch(self) -> None:
self._valid = None

@classmethod
def ref_types(cls) -> set[str]:
def ref_types(cls) -> set[ItemType]:
"""Returns a set of item types that this class refers."""
return set(cls._references.values())

Expand Down Expand Up @@ -1234,15 +1208,15 @@ def __init__(self, mapped_item: MappedItemBase):
self._mapped_item = mapped_item

@property
def item_type(self) -> str:
def item_type(self) -> ItemType:
return self._mapped_item.item_type

@property
def mapped_item(self) -> MappedItemBase:
return self._mapped_item

@property
def db_map(self) -> DatabaseMappingBase:
def db_map(self) -> DatabaseMapping:
return self._mapped_item.db_map

def __getitem__(self, key):
Expand Down
43 changes: 21 additions & 22 deletions spinedb_api/export_mapping/group_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,17 @@
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""
Contains functions to group values in pivot tables with hidden columns or rows.

"""
""" Contains functions to group values in pivot tables with hidden columns or rows. """
from typing import Any, ClassVar, Type
import numpy as np


class GroupFunction:
NAME = NotImplemented
DISPLAY_NAME = NotImplemented

def __call__(self, items):
"""Performs the grouping. Reduces the given list of items into a single value.

Args:
items (list or None)
NAME: ClassVar[str] = NotImplemented
DISPLAY_NAME: ClassVar[str] = NotImplemented

Returns:
Any
"""
def __call__(self, items: list | None) -> Any:
"""Performs the grouping. Reduces the given list of items into a single value."""
raise NotImplementedError


Expand Down Expand Up @@ -116,28 +107,36 @@ def __call__(self, items):
return items[0]


_classes = (NoGroup, GroupSum, GroupMean, GroupMin, GroupMax, GroupConcat, GroupOneOrNone)
_classes: tuple[Type[GroupFunction], ...] = (
NoGroup,
GroupSum,
GroupMean,
GroupMin,
GroupMax,
GroupConcat,
GroupOneOrNone,
)

GROUP_FUNCTION_DISPLAY_NAMES = [klass.DISPLAY_NAME for klass in _classes]
GROUP_FUNCTION_DISPLAY_NAMES: list[str] = [klass.DISPLAY_NAME for klass in _classes]


def group_function_name_from_display(display_name):
def group_function_name_from_display(display_name: str) -> str:
return {klass.DISPLAY_NAME: klass.NAME for klass in _classes}.get(display_name, NoGroup.NAME)


def group_function_display_from_name(name):
def group_function_display_from_name(name: str) -> str:
return {klass.NAME: klass.DISPLAY_NAME for klass in _classes}.get(name, NoGroup.DISPLAY_NAME)


def from_str(name):
def from_str(name: str | None) -> GroupFunction:
"""
Creates group function from name.

Args:
name (str, NoneType): group function name or None if no aggregation wanted.
name: group function name or None if no aggregation wanted.

Returns:
GroupFunction or NoneType
GroupFunction; NoGroup if no aggregation wanted
"""
constructor = {klass.NAME: klass for klass in _classes}.get(name, NoGroup)
return constructor()
27 changes: 26 additions & 1 deletion spinedb_api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import json
from operator import itemgetter
import os
from typing import Any
from typing import Any, Literal, TypeAlias
import warnings
from alembic.config import Config
from alembic.environment import EnvironmentContext
Expand Down Expand Up @@ -59,6 +59,31 @@
from sqlalchemy.sql.selectable import SelectBase
from .exception import SpineDBAPIError, SpineDBVersionError

ItemType: TypeAlias = Literal[
"alternative",
"commit",
"display_mode",
"entity",
"entity_alternative",
"entity_class_display_mode",
"entity_class",
"entity_group",
"entity_location",
"entity_metadata",
"list_value",
"metadata",
"parameter_definition",
"parameter_type",
"parameter_value",
"parameter_value_list",
"parameter_value_metadata",
"scenario",
"scenario_alternative",
"superclass_subclass",
]

LegacyItemType = Literal["object", "object_class", "relationship", "relationship_class"]

SUPPORTED_DIALECTS = {
"mysql": "pymysql",
"sqlite": "sqlite3",
Expand Down
Loading