From be2d04e8b0982e0889bae00a580656fbdfc04363 Mon Sep 17 00:00:00 2001 From: Yusuf Ali Date: Sun, 23 Mar 2025 23:02:35 -0400 Subject: [PATCH] feat: adding in delta support --- requirements.txt | 1 + servc/svc/com/bus/rabbitmq.py | 1 + servc/svc/com/storage/delta.py | 188 +++++++++++++++++++++++++++++++ servc/svc/com/storage/iceberg.py | 63 ++++------- servc/svc/com/storage/lake.py | 23 +++- servc/svc/com/storage/tenant.py | 27 +++++ tests/test_delta.py | 147 ++++++++++++++++++++++++ 7 files changed, 403 insertions(+), 47 deletions(-) create mode 100644 servc/svc/com/storage/delta.py create mode 100644 servc/svc/com/storage/tenant.py create mode 100644 tests/test_delta.py diff --git a/requirements.txt b/requirements.txt index 0be7070..66fa967 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ simplejson==3.20.1 flask==3.1.0 pyyaml==6.0.2 pyiceberg[sql-sqlite,pyarrow]==0.8.1 +deltalake==0.25.4 \ No newline at end of file diff --git a/servc/svc/com/bus/rabbitmq.py b/servc/svc/com/bus/rabbitmq.py index 2727780..5ed8d35 100644 --- a/servc/svc/com/bus/rabbitmq.py +++ b/servc/svc/com/bus/rabbitmq.py @@ -9,6 +9,7 @@ import simplejson from pika.adapters.asyncio_connection import AsyncioConnection # type: ignore from pika.adapters.blocking_connection import BlockingConnection # type: ignore + from servc.svc.com.bus import BusComponent, InputProcessor, OnConsuming from servc.svc.com.cache.redis import decimal_default from servc.svc.io.input import EventPayload, InputPayload, InputType diff --git a/servc/svc/com/storage/delta.py b/servc/svc/com/storage/delta.py new file mode 100644 index 0000000..9f606d5 --- /dev/null +++ b/servc/svc/com/storage/delta.py @@ -0,0 +1,188 @@ +import os +from typing import Any, Dict, List, Tuple + +import pyarrow as pa +from deltalake import DeltaTable, write_deltalake +from pyarrow import Schema, Table + +from servc.svc.com.storage.lake import Lake, LakeTable +from servc.svc.config import Config + + +class 
Delta(Lake[DeltaTable]):
    _storageOptions: Dict[str, str] = {}

    _location_prefix: str

    _table: LakeTable

    def __init__(self, config: Config, table: LakeTable):
        super().__init__(config, table)

        self._table = table

        catalog_properties_raw = config.get("catalog_properties")
        if not isinstance(catalog_properties_raw, dict):
            catalog_properties_raw = {}

        # TODO: make generic for all storage types
        if catalog_properties_raw.get("type") == "local":
            self._location_prefix = str(
                catalog_properties_raw.get("location", "/tmp/delta")
            )
            self._storageOptions = {}
        else:
            # NOTE(review): the access-key id becomes a path segment of the
            # table location here. Confirm this per-credential prefix is
            # intentional and not a leftover placeholder for a bucket/database
            # path component.
            self._location_prefix = os.path.join(
                str(catalog_properties_raw.get("warehouse")),
                str(catalog_properties_raw.get("s3.access-key-id")),
            )
            self._storageOptions = {
                "AWS_ACCESS_KEY_ID": str(
                    catalog_properties_raw.get("s3.access-key-id")
                ),
                "AWS_SECRET_ACCESS_KEY": str(
                    catalog_properties_raw.get("s3.secret-access-key")
                ),
                "AWS_ENDPOINT_URL": str(catalog_properties_raw.get("s3.endpoint")),
                "AWS_ALLOW_HTTP": "true",
                "aws_conditional_put": "etag",
            }

    def _connect(self):
        if self.isOpen:
            return None

        tablename = self._get_table_name()
        uri = os.path.join(self._location_prefix, tablename)
        # mode="ignore" creates the table when it is missing and loads the
        # existing one otherwise, so connect doubles as create-if-not-exists.
        self._conn = DeltaTable.create(
            table_uri=uri,
            name=tablename,
            schema=self._table["schema"],
            partition_by=self._table["partitions"],
            mode="ignore",
            storage_options=self._storageOptions,
        )

        return super()._connect()

    def optimize(self):
        table = self.getConn()

        print("Optimizing", self._get_table_name(), flush=True)
        table.optimize.compact()
        table.vacuum()
        table.cleanup_metadata()
        table.create_checkpoint()

    def getPartitions(self) -> Dict[str, List[Any]] | None:
        table = self.getConn()

        partitions: Dict[str, List[Any]] = {}
        for obj in table.partitions():
            for key, value in obj.items():
                if key not in partitions:
                    partitions[key] = []
                if value not in partitions[key]:
                    partitions[key].append(value)

        return 
partitions

    def getCurrentVersion(self) -> str | None:
        table = self.getConn()
        return str(table.version())

    def getVersions(self) -> List[str] | None:
        # Return every known version from the Delta transaction log (not just
        # the current one) so callers can time-travel via read(version=...).
        table = self.getConn()
        return [str(entry["version"]) for entry in table.history()]

    def insert(self, data: List[Any]) -> bool:
        table = self.getConn()
        write_deltalake(
            table,
            data=pa.Table.from_pylist(data, self.getSchema()),
            storage_options=self._storageOptions,
            mode="append",
        )
        return True

    def _filters(
        self,
        partitions: Dict[str, List[Any]] | None = None,
    ) -> List[Tuple[str, str, Any]] | None:
        filters: List[Tuple[str, str, Any]] = []
        if partitions is None:
            return None
        for key, value in partitions.items():
            if len(value) == 1:
                filters.append((key, "=", value[0]))
            else:
                filters.append((key, "in", value))
        return filters if len(filters) > 0 else None

    def overwrite(
        self, data: List[Any], partitions: Dict[str, List[Any]] | None = None
    ) -> bool:
        table = self.getConn()

        # Delta Lake replace-where predicates are SQL expressions: multiple
        # conditions must be joined with "AND" (not "&"), and multi-value
        # partition filters must render as "col IN (...)" — the previous
        # " ".join over an ("col", "in", [values]) tuple raised TypeError.
        # String values are expected to arrive pre-quoted by the caller
        # (e.g. "'2021-01-02'").
        predicate: str | None = None
        conditions = self._filters(partitions)
        if conditions is not None:
            clauses: List[str] = []
            for column, op, value in conditions:
                if op == "in":
                    rendered = ", ".join(str(v) for v in value)
                    clauses.append(f"{column} IN ({rendered})")
                else:
                    clauses.append(f"{column} {op} {value}")
            predicate = " AND ".join(clauses)

        write_deltalake(
            table,
            data=pa.Table.from_pylist(data, self.getSchema()),
            storage_options=self._storageOptions,
            mode="overwrite",
            predicate=predicate,
            # NOTE: the `engine` kwarg is deprecated in deltalake>=0.20 and
            # rust is the only remaining engine, so it is no longer passed.
        )
        return True

    def readRaw(
        self,
        columns: List[str],
        partitions: Dict[str, List[Any]] | None = None,
        version: str | None = None,
        options: Any | None = None,
    ) -> Table:
        table = self.getConn()
        if version is not None:
            # NOTE(review): load_as_version mutates the cached connection;
            # subsequent reads stay pinned to this version until reconnect —
            # confirm that is intended.
            table.load_as_version(int(version))

        if options is None or not isinstance(options, dict):
            options = {}

        # Guard against an empty column list; "*" means "all columns".
        rcolumns = columns if columns and columns[0] != "*" else None

        if options.get("filter", None) is not None:
            return table.to_pyarrow_dataset(
                partitions=self._filters(partitions),
            ).to_table(
                filter=options.get("filter"),
                columns=rcolumns,
            )
        return table.to_pyarrow_table(
            columns=rcolumns,
            partitions=self._filters(partitions),
        )

    def read(
self, + columns: List[str], + partitions: Dict[str, List[Any]] | None = None, + version: str | None = None, + options: Any | None = None, + ) -> Table: + return self.readRaw(columns, partitions, version, options) + + def getSchema(self) -> Schema | None: + table = self.getConn() + + return table.schema().to_pyarrow() + + def _close(self): + if self._isOpen: + self._isReady = False + self._isOpen = False + return True + return False diff --git a/servc/svc/com/storage/iceberg.py b/servc/svc/com/storage/iceberg.py index cf1e4b1..c8e53da 100644 --- a/servc/svc/com/storage/iceberg.py +++ b/servc/svc/com/storage/iceberg.py @@ -15,10 +15,9 @@ from servc.svc.config import Config -class IceBerg(Lake): +class IceBerg(Lake[Table]): # _table _catalog: Catalog - _ice: Table | None def __init__(self, config: Config, table: LakeTable | str): super().__init__(config, table) @@ -33,7 +32,6 @@ def __init__(self, config: Config, table: LakeTable | str): catalog_name, **{**catalog_properties}, ) - self._ice = None def _connect(self): if self.isOpen: @@ -47,7 +45,7 @@ def _connect(self): except: doesExist = False if doesExist: - self._ice = self._catalog.load_table(tableName) + self._conn = self._catalog.load_table(tableName) elif not doesExist and isinstance(self._table, str): raise Exception(f"Table {tableName} does not exist") @@ -72,7 +70,7 @@ def _connect(self): self._catalog.create_namespace_if_not_exists(self._database) # TODO: undo this garbage when rest catalog works - self._ice = self._catalog.create_table_if_not_exists( + self._conn = self._catalog.create_table_if_not_exists( tableName, self._table["schema"], partition_spec=partitionSpec, @@ -82,9 +80,7 @@ def _connect(self): properties=self._table["options"].get("properties", {}), ) - self._isReady = self._table is not None - self._isOpen = self._table is not None - return self._table is not None + return super()._connect() def _close(self): if self._isOpen: @@ -94,13 +90,10 @@ def _close(self): return False def 
getPartitions(self) -> Dict[str, List[Any]] | None: - if not self._isOpen: - self._connect() - if self._ice is None: - raise Exception("Table not connected") + table = self.getConn() partitions: Dict[str, List[Any]] = {} - for obj in self._ice.inspect.partitions().to_pylist(): + for obj in table.inspect.partitions().to_pylist(): for key, value in obj["partition"].items(): field = key.replace("_partition", "") if field not in partitions: @@ -109,54 +102,39 @@ def getPartitions(self) -> Dict[str, List[Any]] | None: return partitions def getSchema(self) -> Schema | None: - if not self._isOpen: - self._connect() - if self._ice is None: - raise Exception("Table not connected") + table = self.getConn() - return self._ice.schema().as_arrow() + return table.schema().as_arrow() def getCurrentVersion(self) -> str | None: - if not self._isOpen: - self._connect() - if self._ice is None: - raise Exception("Table not connected") + table = self.getConn() - snapshot = self._ice.current_snapshot() + snapshot = table.current_snapshot() if snapshot is None: return None return str(snapshot.snapshot_id) def getVersions(self) -> List[str] | None: - if not self._isOpen: - self._connect() - if self._ice is None: - raise Exception("Table not connected") + table = self.getConn() - snapshots: paTable = self._ice.inspect.snapshots() + snapshots: paTable = table.inspect.snapshots() chunked = snapshots.column("snapshot_id") return [str(x) for x in chunked.to_pylist()] def insert(self, data: List[Any]) -> bool: - if not self._isOpen: - self._connect() - if self._ice is None: - raise Exception("Table not connected") + table = self.getConn() - self._ice.append(pa.Table.from_pylist(data, self.getSchema())) + table.append(pa.Table.from_pylist(data, self.getSchema())) return True def overwrite( self, data: List[Any], partitions: Dict[str, List[Any]] | None = None ) -> bool: - if not self._isOpen: - self._connect() - if self._ice is None: - raise Exception("Table not connected") + table = 
self.getConn() df = pa.Table.from_pylist(data, self.getSchema()) if partitions is None or len(partitions) == 0: - self._ice.overwrite(df) + table.overwrite(df) return True # when partitions are provided, we need to filter the data @@ -168,7 +146,7 @@ def overwrite( for i in range(1, len(boolPartition)): right_side = And(right_side, boolPartition[i]) - self._ice.overwrite(df, overwrite_filter=right_side) + table.overwrite(df, overwrite_filter=right_side) return True def readRaw( @@ -178,10 +156,7 @@ def readRaw( version: str | None = None, options: Any | None = None, ) -> DataScan: - if not self._isOpen: - self._connect() - if self._ice is None: - raise Exception("Table not connected") + table = self.getConn() if options is None: options = {} @@ -197,7 +172,7 @@ def readRaw( options.get("row_filter", AlwaysTrue()), right_side ) - return self._ice.scan( + return table.scan( row_filter=options.get("row_filter", AlwaysTrue()), selected_fields=tuple(columns), limit=options.get("limit", None), diff --git a/servc/svc/com/storage/lake.py b/servc/svc/com/storage/lake.py index 254f37b..60752ae 100644 --- a/servc/svc/com/storage/lake.py +++ b/servc/svc/com/storage/lake.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, List, NotRequired, TypedDict +from typing import Any, Dict, Generic, List, NotRequired, TypedDict, TypeVar from pyarrow import RecordBatchReader, Schema, Table @@ -21,13 +21,18 @@ class LakeTable(TypedDict): options: NotRequired[Dict[str, Any]] -class Lake(StorageComponent): +T = TypeVar("T") + + +class Lake(Generic[T], StorageComponent): name: str = "lake" - _table: Any + _table: LakeTable | str _database: str + _conn: T | None = None + def __init__(self, config: Config, table: LakeTable | str): super().__init__(config) @@ -37,6 +42,13 @@ def __init__(self, config: Config, table: LakeTable | str): if not isinstance(self._table, str) and "options" not in self._table: self._table["options"] = {} + def getConn(self) -> T: + if not 
self._isOpen: + self._connect() + if self._conn is None: + raise Exception("Table not connected") + return self._conn + def _get_table_name(self) -> str: schema: str = self._database @@ -58,6 +70,11 @@ def table(self) -> LakeTable | str: def tablename(self) -> str: return self._get_table_name() + def _connect(self): + self._isReady = self._conn is not None + self._isOpen = self._conn is not None + return self._conn is not None + def getPartitions(self) -> Dict[str, List[Any]] | None: return None diff --git a/servc/svc/com/storage/tenant.py b/servc/svc/com/storage/tenant.py new file mode 100644 index 0000000..f52018c --- /dev/null +++ b/servc/svc/com/storage/tenant.py @@ -0,0 +1,27 @@ +from servc.svc.com.storage.lake import Lake, LakeTable +from servc.svc.config import Config + + +class TenantTable(Lake): + _tenant_name: str + + _table: LakeTable + + def __init__(self, config: Config, table: LakeTable, tenant_name: str): + super().__init__(config, table) + self._tenant_name = tenant_name + self._table = table + + def _get_table_name(self) -> str: + schema: str = self._database + + name_w_medallion = "".join( + [ + self._tenant_name, + self._table["medallion"].value, + "_", + self._table["name"], + ] + ) + + return ".".join([schema, name_w_medallion]) diff --git a/tests/test_delta.py b/tests/test_delta.py new file mode 100644 index 0000000..60f435d --- /dev/null +++ b/tests/test_delta.py @@ -0,0 +1,147 @@ +import unittest + +import pyarrow as pa +import pyarrow.dataset as ds + +from servc.svc.com.storage.delta import Delta +from servc.svc.com.storage.lake import LakeTable, Medallion + +schema = pa.schema( + [ + ("date", pa.string()), + ("some_int", pa.int64()), + ] +) + +mytable: LakeTable = { + "name": "test", + "partitions": ["date"], + "medallion": Medallion.BRONZE, + "schema": pa.schema( + [ # type: ignore + pa.field("date", pa.string(), nullable=False), + pa.field("some_int", pa.int64(), nullable=False), + ] + ), +} + +config = { + "database": "default", + 
"catalog_name": "default", + "catalog_properties": { + "type": "local", + "location": "/tmp/delta", + }, +} + + +class TestLakeDelta(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.iceberg = Delta(config, mytable) + + def test_connect(self): + self.iceberg._connect() + self.assertTrue(self.iceberg.isOpen) + + def test_name(self): + self.assertEqual(self.iceberg.tablename, "default.bronze_test") + + def test_insert(self): + self.iceberg.overwrite([]) + self.iceberg.insert([{"date": "2021-01-01", "some_int": 1}]) + data = self.iceberg.read(["date"]).to_pylist() + self.assertEqual(len(data), 1) + self.assertEqual(data, [{"date": "2021-01-01"}]) + + def test_overwrite(self): + self.iceberg.overwrite([]) + self.iceberg.insert([{"date": "2021-01-01", "some_int": 1}]) + self.iceberg.insert([{"date": "2021-01-02", "some_int": 1}]) + self.iceberg.insert([{"date": "2021-01-02", "some_int": 3}]) + + data = self.iceberg.read(["date"]).to_pylist() + self.assertEqual(len(data), 3) + + self.iceberg.overwrite([], {"date": ["'2021-01-02'"]}) + data = self.iceberg.read(["date"]).to_pylist() + self.assertEqual(len(data), 1) + + def test_reading_partitions(self): + self.iceberg.overwrite([]) + self.iceberg.insert([{"date": "2021-01-01", "some_int": 1}]) + self.iceberg.insert([{"date": "2021-01-02", "some_int": 1}]) + self.iceberg.insert([{"date": "2021-01-02", "some_int": 3}]) + + data = self.iceberg.read( + ["date"], partitions={"date": ["2021-01-01"]} + ).to_pylist() + self.assertEqual(len(data), 1) + + data = self.iceberg.read( + ["date"], partitions={"date": ["2021-01-02"]} + ).to_pylist() + self.assertEqual(len(data), 2) + + data = self.iceberg.read( + ["date"], partitions={"date": ["2021-01-02", "2021-01-01"]} + ).to_pylist() + self.assertEqual(len(data), 3) + + data = self.iceberg.read( + ["date"], + partitions={"date": ["2021-01-02"]}, + options={"filter": (ds.field("some_int") == 3)}, + ).to_pylist() + self.assertEqual(len(data), 1) + + data = 
self.iceberg.read( + ["date"], + partitions={"date": ["2021-01-02"]}, + options={"filter": (ds.field("some_int") == 3)}, + ).to_pylist() + self.assertEqual(len(data), 1) + + def test_version_travel(self): + self.iceberg.insert([{"date": "2021-01-01", "some_int": 1}]) + orig_data = self.iceberg.read(["date"]).to_pylist() + currentVersion = self.iceberg.getCurrentVersion() + + versions = self.iceberg.getVersions() + self.assertGreater(len(versions), 0) + self.assertIn(currentVersion, versions) + + self.iceberg.insert([{"date": "2021-01-02", "some_int": 1}]) + new_version = self.iceberg.getCurrentVersion() + self.assertNotEqual(currentVersion, new_version) + + data = self.iceberg.read(["date"], version=currentVersion).to_pylist() + self.assertEqual(len(data), len(orig_data)) + self.assertEqual(data, orig_data) + + def test_partitions(self): + self.iceberg.overwrite([]) + self.iceberg.insert([{"date": "2021-01-01", "some_int": 1}]) + self.iceberg.insert([{"date": "2021-01-02", "some_int": 1}]) + self.iceberg.insert([{"date": "2021-01-02", "some_int": 3}]) + + partitions = self.iceberg.getPartitions() + self.assertEqual(list(partitions.keys()), ["date"]) + self.assertEqual(len(partitions["date"]), 2) + self.assertIn("2021-01-01", partitions["date"]) + self.assertIn("2021-01-02", partitions["date"]) + + def test_schema(self): + schema = self.iceberg.getSchema() + self.assertIsInstance(schema, pa.Schema) + self.assertEqual(len(schema.names), 2) + self.assertEqual(schema.names, ["date", "some_int"]) + + def test_close(self): + self.iceberg.close() + self.iceberg.connect() + self.iceberg.close() + + +if __name__ == "__main__": + unittest.main()