Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ simplejson==3.20.1
flask==3.1.0
pyyaml==6.0.2
pyiceberg[sql-sqlite,pyarrow]==0.8.1
deltalake==0.25.4
1 change: 1 addition & 0 deletions servc/svc/com/bus/rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import simplejson
from pika.adapters.asyncio_connection import AsyncioConnection # type: ignore
from pika.adapters.blocking_connection import BlockingConnection # type: ignore

from servc.svc.com.bus import BusComponent, InputProcessor, OnConsuming
from servc.svc.com.cache.redis import decimal_default
from servc.svc.io.input import EventPayload, InputPayload, InputType
Expand Down
188 changes: 188 additions & 0 deletions servc/svc/com/storage/delta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import os
from typing import Any, Dict, List, Tuple

import pyarrow as pa
from deltalake import DeltaTable, write_deltalake
from pyarrow import Schema, Table

from servc.svc.com.storage.lake import Lake, LakeTable
from servc.svc.config import Config


class Delta(Lake[DeltaTable]):
    """Lake implementation backed by Delta Lake (delta-rs).

    Maps the generic ``Lake`` contract (insert / overwrite / read /
    partitions / versions) onto a ``deltalake.DeltaTable``. Storage
    location and credentials are derived from the ``catalog_properties``
    section of the service config.
    """

    # Populated per-instance in __init__; annotation only, so instances do
    # not share a mutable class-level dict.
    _storageOptions: Dict[str, str]

    # Base URI under which table directories are created.
    _location_prefix: str

    # Table definition (name, schema, partitions, options).
    _table: LakeTable

    def __init__(self, config: Config, table: LakeTable):
        super().__init__(config, table)

        self._table = table

        catalog_properties_raw = config.get("catalog_properties")
        if not isinstance(catalog_properties_raw, dict):
            catalog_properties_raw = {}

        # TODO: make generic for all storage types
        if catalog_properties_raw.get("type") == "local":
            # Local filesystem layout: everything under a single directory.
            self._location_prefix = str(
                catalog_properties_raw.get("location", "/tmp/delta")
            )
            self._storageOptions = {}
        else:
            # S3-compatible storage: tables live under
            # <warehouse>/<access-key-id>/.
            self._location_prefix = os.path.join(
                str(catalog_properties_raw.get("warehouse")),
                str(catalog_properties_raw.get("s3.access-key-id")),
            )
            self._storageOptions = {
                "AWS_ACCESS_KEY_ID": str(
                    catalog_properties_raw.get("s3.access-key-id")
                ),
                "AWS_SECRET_ACCESS_KEY": str(
                    catalog_properties_raw.get("s3.secret-access-key")
                ),
                "AWS_ENDPOINT_URL": str(catalog_properties_raw.get("s3.endpoint")),
                "AWS_ALLOW_HTTP": "true",
                "aws_conditional_put": "etag",
            }

    def _connect(self):
        """Create the Delta table (a no-op when it already exists) and cache it.

        Returns ``None`` when already open, otherwise defers to the parent's
        ``_connect`` to flip the open/ready flags.
        """
        if self.isOpen:
            return None

        tablename = self._get_table_name()
        uri = os.path.join(self._location_prefix, tablename)
        self._conn = DeltaTable.create(
            table_uri=uri,
            name=tablename,
            schema=self._table["schema"],
            partition_by=self._table["partitions"],
            mode="ignore",  # do not error if the table already exists
            storage_options=self._storageOptions,
        )

        return super()._connect()

    def optimize(self):
        """Compact small files, vacuum stale files, and checkpoint metadata."""
        table = self.getConn()

        print("Optimizing", self._get_table_name(), flush=True)
        table.optimize.compact()
        table.vacuum()
        table.cleanup_metadata()
        table.create_checkpoint()

    def getPartitions(self) -> Dict[str, List[Any]] | None:
        """Return a mapping of partition column -> distinct values present."""
        table = self.getConn()

        partitions: Dict[str, List[Any]] = {}
        for obj in table.partitions():
            for key, value in obj.items():
                if key not in partitions:
                    partitions[key] = []
                if value not in partitions[key]:
                    partitions[key].append(value)

        return partitions

    def getCurrentVersion(self) -> str | None:
        """Return the current Delta table version as a string."""
        table = self.getConn()
        return str(table.version())

    def getVersions(self) -> List[str] | None:
        # NOTE(review): only the current version is reported; the full
        # version history is not enumerated here.
        return [str(self.getCurrentVersion())]

    def insert(self, data: List[Any]) -> bool:
        """Append rows (list of dicts) to the table. Always returns True."""
        table = self.getConn()
        write_deltalake(
            table,
            data=pa.Table.from_pylist(data, self.getSchema()),
            storage_options=self._storageOptions,
            mode="append",
        )
        return True

    @staticmethod
    def _sql_literal(value: Any) -> str:
        """Render a Python scalar as a delta-rs SQL predicate literal.

        Strings are single-quoted with embedded quotes doubled; other
        scalars (ints, floats, bools) are rendered with ``str``.
        """
        if isinstance(value, str):
            return "'" + value.replace("'", "''") + "'"
        return str(value)

    def _filters(
        self,
        partitions: Dict[str, List[Any]] | None = None,
    ) -> List[Tuple[str, str, Any]] | None:
        """Convert a partition-value mapping into DNF filter tuples.

        Single-valued keys become ``(key, "=", value)``; multi-valued keys
        become ``(key, "in", values)``. Returns ``None`` when there is
        nothing to filter on.
        """
        if partitions is None:
            return None
        filters: List[Tuple[str, str, Any]] = []
        for key, values in partitions.items():
            if len(values) == 1:
                filters.append((key, "=", values[0]))
            else:
                filters.append((key, "in", values))
        return filters if len(filters) > 0 else None

    def overwrite(
        self, data: List[Any], partitions: Dict[str, List[Any]] | None = None
    ) -> bool:
        """Overwrite the table (or just the given partitions) with *data*.

        The partition filters are rendered as a SQL ``predicate`` so only
        matching partitions are replaced; without partitions the whole
        table is overwritten.
        """
        table = self.getConn()

        predicate: str | None = None
        flt = self._filters(partitions)
        if flt is not None:
            # Render each tuple as a SQL clause; the previous
            # " ".join/" & ".join approach produced invalid predicates and
            # raised TypeError on non-string or "in" values.
            clauses: List[str] = []
            for key, op, value in flt:
                if op == "in":
                    rendered = ", ".join(self._sql_literal(v) for v in value)
                    clauses.append(f"{key} IN ({rendered})")
                else:
                    clauses.append(f"{key} {op} {self._sql_literal(value)}")
            predicate = " AND ".join(clauses)

        # engine="rust" dropped: rust is the default engine and the keyword
        # is deprecated/removed in recent deltalake releases — TODO confirm
        # against the pinned deltalake version.
        write_deltalake(
            table,
            data=pa.Table.from_pylist(data, self.getSchema()),
            storage_options=self._storageOptions,
            mode="overwrite",
            predicate=predicate,
        )
        return True

    def readRaw(
        self,
        columns: List[str],
        partitions: Dict[str, List[Any]] | None = None,
        version: str | None = None,
        options: Any | None = None,
    ) -> Table:
        """Read the table as a pyarrow Table.

        ``columns == ["*"]`` (or an empty list) selects all columns;
        ``version`` time-travels to that table version;
        ``options["filter"]`` applies a row-level dataset filter.
        """
        table = self.getConn()
        if version is not None:
            table.load_as_version(int(version))

        if options is None or not isinstance(options, dict):
            options = {}

        # Guard the empty list: columns[0] would raise IndexError.
        rcolumns = columns if columns and columns[0] != "*" else None

        if options.get("filter", None) is not None:
            return table.to_pyarrow_dataset(
                partitions=self._filters(partitions),
            ).to_table(
                filter=options.get("filter"),
                columns=rcolumns,
            )
        return table.to_pyarrow_table(
            columns=rcolumns,
            partitions=self._filters(partitions),
        )

    def read(
        self,
        columns: List[str],
        partitions: Dict[str, List[Any]] | None = None,
        version: str | None = None,
        options: Any | None = None,
    ) -> Table:
        """Alias for :meth:`readRaw` — Delta returns pyarrow tables directly."""
        return self.readRaw(columns, partitions, version, options)

    def getSchema(self) -> Schema | None:
        """Return the table schema converted to a pyarrow Schema."""
        table = self.getConn()

        return table.schema().to_pyarrow()

    def _close(self):
        """Mark the component closed; delta-rs holds no live connection."""
        if self._isOpen:
            self._isReady = False
            self._isOpen = False
            return True
        return False
63 changes: 19 additions & 44 deletions servc/svc/com/storage/iceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
from servc.svc.config import Config


class IceBerg(Lake):
class IceBerg(Lake[Table]):
# _table
_catalog: Catalog
_ice: Table | None

def __init__(self, config: Config, table: LakeTable | str):
super().__init__(config, table)
Expand All @@ -33,7 +32,6 @@ def __init__(self, config: Config, table: LakeTable | str):
catalog_name,
**{**catalog_properties},
)
self._ice = None

def _connect(self):
if self.isOpen:
Expand All @@ -47,7 +45,7 @@ def _connect(self):
except:
doesExist = False
if doesExist:
self._ice = self._catalog.load_table(tableName)
self._conn = self._catalog.load_table(tableName)

elif not doesExist and isinstance(self._table, str):
raise Exception(f"Table {tableName} does not exist")
Expand All @@ -72,7 +70,7 @@ def _connect(self):
self._catalog.create_namespace_if_not_exists(self._database)

# TODO: undo this garbage when rest catalog works
self._ice = self._catalog.create_table_if_not_exists(
self._conn = self._catalog.create_table_if_not_exists(
tableName,
self._table["schema"],
partition_spec=partitionSpec,
Expand All @@ -82,9 +80,7 @@ def _connect(self):
properties=self._table["options"].get("properties", {}),
)

self._isReady = self._table is not None
self._isOpen = self._table is not None
return self._table is not None
return super()._connect()

def _close(self):
if self._isOpen:
Expand All @@ -94,13 +90,10 @@ def _close(self):
return False

def getPartitions(self) -> Dict[str, List[Any]] | None:
if not self._isOpen:
self._connect()
if self._ice is None:
raise Exception("Table not connected")
table = self.getConn()

partitions: Dict[str, List[Any]] = {}
for obj in self._ice.inspect.partitions().to_pylist():
for obj in table.inspect.partitions().to_pylist():
for key, value in obj["partition"].items():
field = key.replace("_partition", "")
if field not in partitions:
Expand All @@ -109,54 +102,39 @@ def getPartitions(self) -> Dict[str, List[Any]] | None:
return partitions

def getSchema(self) -> Schema | None:
if not self._isOpen:
self._connect()
if self._ice is None:
raise Exception("Table not connected")
table = self.getConn()

return self._ice.schema().as_arrow()
return table.schema().as_arrow()

def getCurrentVersion(self) -> str | None:
if not self._isOpen:
self._connect()
if self._ice is None:
raise Exception("Table not connected")
table = self.getConn()

snapshot = self._ice.current_snapshot()
snapshot = table.current_snapshot()
if snapshot is None:
return None
return str(snapshot.snapshot_id)

def getVersions(self) -> List[str] | None:
if not self._isOpen:
self._connect()
if self._ice is None:
raise Exception("Table not connected")
table = self.getConn()

snapshots: paTable = self._ice.inspect.snapshots()
snapshots: paTable = table.inspect.snapshots()
chunked = snapshots.column("snapshot_id")
return [str(x) for x in chunked.to_pylist()]

def insert(self, data: List[Any]) -> bool:
if not self._isOpen:
self._connect()
if self._ice is None:
raise Exception("Table not connected")
table = self.getConn()

self._ice.append(pa.Table.from_pylist(data, self.getSchema()))
table.append(pa.Table.from_pylist(data, self.getSchema()))
return True

def overwrite(
self, data: List[Any], partitions: Dict[str, List[Any]] | None = None
) -> bool:
if not self._isOpen:
self._connect()
if self._ice is None:
raise Exception("Table not connected")
table = self.getConn()

df = pa.Table.from_pylist(data, self.getSchema())
if partitions is None or len(partitions) == 0:
self._ice.overwrite(df)
table.overwrite(df)
return True

# when partitions are provided, we need to filter the data
Expand All @@ -168,7 +146,7 @@ def overwrite(
for i in range(1, len(boolPartition)):
right_side = And(right_side, boolPartition[i])

self._ice.overwrite(df, overwrite_filter=right_side)
table.overwrite(df, overwrite_filter=right_side)
return True

def readRaw(
Expand All @@ -178,10 +156,7 @@ def readRaw(
version: str | None = None,
options: Any | None = None,
) -> DataScan:
if not self._isOpen:
self._connect()
if self._ice is None:
raise Exception("Table not connected")
table = self.getConn()

if options is None:
options = {}
Expand All @@ -197,7 +172,7 @@ def readRaw(
options.get("row_filter", AlwaysTrue()), right_side
)

return self._ice.scan(
return table.scan(
row_filter=options.get("row_filter", AlwaysTrue()),
selected_fields=tuple(columns),
limit=options.get("limit", None),
Expand Down
Loading