From c8541428056ceedd62d4ac701dbb67610b7a034d Mon Sep 17 00:00:00 2001 From: Yusuf Ali Date: Sat, 11 Jan 2025 19:15:38 -0500 Subject: [PATCH] feat: upgrade to latest servc spec --- .github/workflows/servc.yml | 2 +- README.md | 12 ++--- main.py | 42 ++++----------- servc/server.py | 36 ++++--------- servc/svc/__init__.py | 10 ++-- servc/svc/com/bus/__init__.py | 43 +++++++++++---- servc/svc/com/bus/rabbitmq.py | 4 +- servc/svc/com/cache/__init__.py | 11 ++-- servc/svc/com/cache/redis.py | 7 +++ servc/svc/com/http/__init__.py | 49 ++++++++--------- servc/svc/com/storage/__init__.py | 26 ++------- servc/svc/com/storage/iceberg.py | 34 +++++++++--- servc/svc/com/storage/lake.py | 17 +++--- servc/svc/com/worker/__init__.py | 64 ++++++++--------------- servc/svc/com/worker/hooks/__init__.py | 15 ++---- servc/svc/com/worker/hooks/parallelize.py | 17 +++--- servc/svc/com/worker/types.py | 13 ++++- servc/svc/config/__init__.py | 2 + tests/hooks/test_complete.py | 15 +++--- tests/hooks/test_parallelize.py | 48 +++++++++-------- tests/lake/test_iceberg.py | 3 +- tests/svc/test_rabbitmq.py | 7 ++- tests/svc/test_redis.py | 2 +- 23 files changed, 218 insertions(+), 261 deletions(-) diff --git a/.github/workflows/servc.yml b/.github/workflows/servc.yml index 2b6efd5..c18c8be 100644 --- a/.github/workflows/servc.yml +++ b/.github/workflows/servc.yml @@ -4,7 +4,7 @@ on: push: env: - SERVC_VERSION: 0.4.2 + SERVC_VERSION: 0.5.0 permissions: contents: write diff --git a/README.md b/README.md index 3139e14..3da2689 100644 --- a/README.md +++ b/README.md @@ -13,21 +13,15 @@ Serv-C implmentation for Python. Documentation can be found [here][1] Here is the most simple example of use, starting a server to handle requests at the route `my-route`; ```python -from typing import Any, List +from typing import Any -from servc.svc import Middleware from servc.server import start_server -from servc.svc.com.bus import BusComponent -from servc.svc.com.cache import CacheComponent -from servc.svc.com.worker.types import EMIT_EVENT, RESOLVER_RETURN_TYPE +from servc.svc.com.worker.types import RESOLVER_CONTEXT, RESOLVER_RETURN_TYPE def inputProcessor( messageId: str, - bus: BusComponent, - cache: CacheComponent, payload: Any, - components: List[Middleware], - emitEvent: EMIT_EVENT, + context: RESOLVER_CONTEXT, ) -> RESOLVER_RETURN_TYPE: return True diff --git a/main.py b/main.py index b10e9b3..1ed90d3 100755 --- a/main.py +++ b/main.py @@ -1,25 +1,17 @@ #!/usr/bin/env python import os -from typing import Any, List +from typing import Any from servc.server import start_server -from servc.svc import Middleware from servc.svc.client.send import sendMessage -from servc.svc.com.bus import BusComponent -from servc.svc.com.cache import CacheComponent -from servc.svc.com.worker.types import EMIT_EVENT, RESOLVER_RETURN_TYPE +from servc.svc.com.worker.types import RESOLVER_CONTEXT, RESOLVER_RETURN_TYPE from servc.svc.idgen.simple import simple from servc.svc.io.input import InputType def test_resolver( - id: str, - bus: BusComponent, - cache: CacheComponent, - payload: str | list[str], - _c: List[Middleware], - emitEvent: EMIT_EVENT, + id: str, payload: Any, context: RESOLVER_CONTEXT ) -> RESOLVER_RETURN_TYPE: if not isinstance(payload, list): sendMessage( @@ -34,8 +26,8 @@ def test_resolver( "inputs": payload, }, }, - bus, - cache, + context["bus"], + context["cache"], simple, ) return False @@ -43,32 +35,18 @@ def test_resolver( if not isinstance(x, str): return False - emitEvent( + context["bus"].emitEvent( os.getenv("EVENT", "my-event"), payload, ) return True -def test_hook( - id: str, - _b: BusComponent, - _c: CacheComponent, - p: List[Any], - _ch: List[Middleware], - _e: EMIT_EVENT, -) -> RESOLVER_RETURN_TYPE: - return [x for x in p] +def test_hook(id: str, payload: Any, context: RESOLVER_CONTEXT) -> RESOLVER_RETURN_TYPE: + return [x for x in payload] -def fail( - id: str, - _b: BusComponent, - _c: CacheComponent, - _p: Any, - _ch: List[Middleware], - _e: EMIT_EVENT, -) -> RESOLVER_RETURN_TYPE: +def fail(id: str, payload: Any, context: RESOLVER_CONTEXT) -> RESOLVER_RETURN_TYPE: raise Exception("This is a test exception") @@ -77,7 +55,7 @@ def main(): resolver={ "test": test_resolver, "fail": fail, - "hook": lambda id, _b, _c, p, _ch, _e: len(p), + "hook": lambda id, p, _c: len(p), "hook_part": test_hook, }, ) diff --git a/servc/server.py b/servc/server.py index c0861ac..f7155e4 100644 --- a/servc/server.py +++ b/servc/server.py @@ -1,5 +1,5 @@ from multiprocessing import Process -from typing import Any, List, Tuple +from typing import List from servc.svc import Middleware from servc.svc.com.bus import BusComponent, OnConsuming @@ -15,14 +15,7 @@ def blankOnConsuming(route: str): print("Consuming on route", route, flush=True) -COMPONENT_ARRAY = List[Tuple[type[Middleware], List[Any]]] - - -def compose_components(component_list: COMPONENT_ARRAY) -> List[Middleware]: - components: List[Middleware] = [] - for [componentClass, args] in component_list: - components.append(componentClass(*args)) - return components +COMPONENT_ARRAY = List[type[Middleware]] def start_consumer( @@ -38,15 +31,10 @@ def start_consumer( ): config = configClass() config.setAll(configDictionary) - bus = busClass( - config.get("conf.bus.url"), - config.get("conf.bus.routemap"), - config.get("conf.bus.prefix"), - ) - cache = cacheClass(config.get("conf.cache.url")) + bus = busClass(config.get(f"conf.{busClass.name}")) + cache = cacheClass(config.get(f"conf.{cacheClass.name}")) + consumer = workerClass( - config.get("conf.bus.route"), - config.get("conf.instanceid"), resolver, eventResolver, onConsuming, @@ -54,7 +42,7 @@ def start_consumer( busClass, cache, config, - compose_components(components), + [X(config.get(f"conf.{X.name}")) for X in components], ) consumer.connect() @@ -93,18 +81,12 @@ def start_server( ) consumer.start() - bus = busClass( - config.get("conf.bus.url"), - config.get("conf.bus.routemap"), - config.get("conf.bus.prefix"), - ) - cache = cacheClass(config.get("conf.cache.url")) + bus = busClass(config.get(f"conf.{busClass.name}")) + cache = cacheClass(config.get(f"conf.{cacheClass.name}")) http = httpClass( - int(config.get("conf.http.port")), + config.get(f"conf.{httpClass.name}"), bus, cache, - config.get("conf.bus.route"), - config.get("conf.instanceid"), consumer, resolver, eventResolver, diff --git a/servc/svc/__init__.py b/servc/svc/__init__.py index 7ae3f0f..a042d4c 100644 --- a/servc/svc/__init__.py +++ b/servc/svc/__init__.py @@ -3,6 +3,8 @@ from enum import Enum from typing import Callable, List +from servc.svc.config import Config + class ComponentType(Enum): BUS = "bus" @@ -16,7 +18,7 @@ class ComponentType(Enum): class Middleware: _children: List[Middleware] - _name: str + name: str _isReady: bool @@ -28,15 +30,11 @@ class Middleware: _close: Callable[..., bool] - def __init__(self): + def __init__(self, _config: Config): self._children = [] self._isReady = False self._isOpen = False - @property - def name(self) -> str: - return self._name - @property def isReady(self) -> bool: isReadyCheck = self._isReady diff --git a/servc/svc/com/bus/__init__.py b/servc/svc/com/bus/__init__.py index 18fbb81..b8a1eac 100644 --- a/servc/svc/com/bus/__init__.py +++ b/servc/svc/com/bus/__init__.py @@ -1,6 +1,7 @@ -from typing import Any, Callable, Union +from typing import Any, Callable, Dict, Union from servc.svc import ComponentType, Middleware +from servc.svc.config import Config from servc.svc.io.input import EventPayload, InputPayload, InputType from servc.svc.io.output import StatusCode @@ -10,20 +11,40 @@ class BusComponent(Middleware): + name: str = "bus" + _type: ComponentType = ComponentType.BUS _url: str - _routeMap: dict + _routeMap: Dict[str, str] _prefix: str - def __init__(self, url: str, routeMap: dict, prefix: str): - super().__init__() + _instanceId: str + + _route: str + + def __init__(self, config: Config): + super().__init__(config) + + self._url = str(config.get("url")) + self._prefix = str(config.get("prefix")) + self._instanceId = str(config.get("instanceid")) + self._route = str(config.get("route")) + + routemap = config.get("routemap") + if routemap is None or not isinstance(routemap, dict): + routemap = {} + self._routeMap = routemap + + @property + def instanceId(self) -> str: + return self._instanceId - self._url = url - self._routeMap = routeMap - self._prefix = prefix + @property + def route(self) -> str: + return self._route def getRoute(self, route: str) -> str: if route in self._routeMap: @@ -33,7 +54,7 @@ def getRoute(self, route: str) -> str: def publishMessage(self, route: str, message: InputPayload | EventPayload) -> bool: return True - def emitEvent(self, event: str, instanceId: str, details: Any) -> bool: + def emitEvent(self, event: str, details: Any) -> bool: return self.publishMessage( self.getRoute(event), { @@ -41,11 +62,11 @@ def emitEvent(self, event: str, instanceId: str, details: Any) -> bool: "route": self.getRoute(event), "event": event, "details": details, - "instanceId": instanceId, + "instanceId": self._instanceId, }, ) - def create_queue(self, queue: str, bindEventExchange: bool = True) -> bool: + def create_queue(self, queue: str, bindEventExchange: bool) -> bool: return False def delete_queue(self, queue: str) -> bool: @@ -59,6 +80,6 @@ def subscribe( route: str, inputProcessor: InputProcessor, onConsuming: OnConsuming | None, - bindEventExchange: bool = True, + bindEventExchange: bool, ) -> bool: return True diff --git a/servc/svc/com/bus/rabbitmq.py b/servc/svc/com/bus/rabbitmq.py index b5d2d1c..0d5c8c9 100644 --- a/servc/svc/com/bus/rabbitmq.py +++ b/servc/svc/com/bus/rabbitmq.py @@ -108,7 +108,7 @@ def get_channel(self, method: Callable | None, args: Tuple | None): on_open_callback=lambda c: on_channel_open(c, method, args) ) - def create_queue(self, queue: str, bindEventExchange: bool = False, channel: pika.channel.Channel | None = None) -> bool: # type: ignore + def create_queue(self, queue: str, bindEventExchange: bool, channel: pika.channel.Channel | None = None) -> bool: # type: ignore if not self.isReady: return self._connect(self.create_queue, (queue, bindEventExchange)) if not channel: @@ -180,7 +180,7 @@ def subscribe( # type: ignore route: str, inputProcessor: InputProcessor, onConsuming: OnConsuming | None, - bindEventExchange: bool = True, + bindEventExchange: bool, channel: pika.channel.Channel | None = None, ) -> bool: if not self.isReady: diff --git a/servc/svc/com/cache/__init__.py b/servc/svc/com/cache/__init__.py index ad29574..62c1394 100644 --- a/servc/svc/com/cache/__init__.py +++ b/servc/svc/com/cache/__init__.py @@ -1,19 +1,18 @@ from typing import Any from servc.svc import ComponentType, Middleware +from servc.svc.config import Config from servc.svc.io.output import StatusCode from servc.svc.io.response import generateResponseArtifact class CacheComponent(Middleware): - _type: ComponentType = ComponentType.CACHE - - _url: str + name: str = "cache" - def __init__(self, url: str): - super().__init__() + _type: ComponentType = ComponentType.CACHE - self._url = url + def __init__(self, config: Config): + super().__init__(config) def setKey(self, id: str, value: Any) -> str: return "" diff --git a/servc/svc/com/cache/redis.py b/servc/svc/com/cache/redis.py index 4d6cce1..54264be 100644 --- a/servc/svc/com/cache/redis.py +++ b/servc/svc/com/cache/redis.py @@ -7,6 +7,7 @@ from redis import Redis from servc.svc.com.cache import CacheComponent +from servc.svc.config import Config def decimal_default(obj: Any) -> None | str | float: @@ -20,6 +21,12 @@ def decimal_default(obj: Any) -> None | str | float: class CacheRedis(CacheComponent): _redisClient: Redis + _url: str + + def __init__(self, config: Config): + super().__init__(config) + self._url = str(config.get("url")) + @property def conn(self): return self._redisClient diff --git a/servc/svc/com/http/__init__.py b/servc/svc/com/http/__init__.py index 5bd7d0a..f18e9ec 100644 --- a/servc/svc/com/http/__init__.py +++ b/servc/svc/com/http/__init__.py @@ -1,6 +1,6 @@ import os from multiprocessing import Process -from typing import Dict, Tuple, TypedDict +from typing import Dict, List, Tuple, TypedDict from flask import Flask, jsonify, request # type: ignore @@ -9,6 +9,7 @@ from servc.svc.com.bus import BusComponent from servc.svc.com.cache import CacheComponent from servc.svc.com.worker import RESOLVER_MAPPING +from servc.svc.config import Config from servc.svc.idgen.simple import simple from servc.svc.io.input import InputPayload, InputType from servc.svc.io.output import StatusCode @@ -29,6 +30,8 @@ def methodGrabber(m: RESOLVER_MAPPING) -> Dict[str, Tuple[str, ...]]: class HTTPInterface(Middleware): + name: str = "http" + _type: ComponentType = ComponentType.INTERFACE _port: int @@ -41,39 +44,30 @@ class HTTPInterface(Middleware): _consumer: Process - _route: str - - _instanceId: str - _info: ServiceInformation def __init__( self, - port: int, + config: Config, bus: BusComponent, cache: CacheComponent, - route: str, - instanceId: str, consumerthread: Process, resolvers: RESOLVER_MAPPING, eventResolvers: RESOLVER_MAPPING, ): - super().__init__() - self._port = port + super().__init__(config) + self._port = int(config.get("port")) self._server = Flask(__name__) self._bus = bus self._cache = cache self._children.append(self._bus) self._children.append(self._cache) - - self._route = route - self._instanceId = instanceId self._consumer = consumerthread self._info = { - "instanceId": self._instanceId, - "queue": self._bus.getRoute(self._route), + "instanceId": self._bus.instanceId, + "queue": self._bus.route, "methods": methodGrabber(resolvers), "eventHandlers": methodGrabber(eventResolvers), } @@ -106,6 +100,8 @@ def _health(self): except AssertionError: pid = self._consumer.pid try: + if not pid: + raise Exception("No PID") os.kill(pid, 0) except OSError: consumerAlive = False @@ -126,29 +122,28 @@ def _postMessage(self): return self._getInformation() if content_type == "application/json": body = request.json + if not body: + return "bad request", StatusCode.INVALID_INPUTS.value # compatibility patch - if "inputs" in body and "argument" not in body: + if "inputs" in body and "argument" not in body and body["inputs"]: body["argument"] = body["inputs"] - must_have_keys = ("type",) + must_have_keys: List[str] = ["type"] for key in must_have_keys: if key not in body: return f"missing key {key}", StatusCode.INVALID_INPUTS.value if body["type"] == InputType.EVENT.value: - must_have_keys = ("event", "details") + must_have_keys = ["event", "details"] for key in must_have_keys: if key not in body: return f"missing key {key}", StatusCode.INVALID_INPUTS.value - instanceId = ( - body["instanceId"] if "instanceId" in body else self._instanceId - ) - id = self._bus.emitEvent(body["event"], instanceId, body["details"]) + id = self._bus.emitEvent(body["event"], body["details"]) return id elif body["type"] == InputType.INPUT.value: - must_have_keys = ("route", "argument") + must_have_keys = ["route", "argument"] for key in must_have_keys: if key not in body: return f"missing key {key}", StatusCode.INVALID_INPUTS.value @@ -161,16 +156,16 @@ def _postMessage(self): } if "instanceId" in body: payload["instanceId"] = body["instanceId"] + force: bool = True if "force" in body and body["force"] else False - id = sendMessage( + res_id = sendMessage( payload, self._bus, self._cache, simple, - True if "force" in body and body["force"] else False, - [], + force=force, ) - return id + return res_id else: return "bad request", StatusCode.INVALID_INPUTS.value diff --git a/servc/svc/com/storage/__init__.py b/servc/svc/com/storage/__init__.py index e973ad0..52684a4 100644 --- a/servc/svc/com/storage/__init__.py +++ b/servc/svc/com/storage/__init__.py @@ -1,29 +1,9 @@ -from typing import Any, Dict - from servc.svc import ComponentType, Middleware -from servc.svc.io.output import StatusCode +from servc.svc.config import Config class StorageComponent(Middleware): _type: ComponentType = ComponentType.STORAGE - _config: Dict[str, str] - - def __init__(self, config: Dict[str, Any] | None): - super().__init__() - - if config is None: - config = {} - self._config = config - - # def list(self, path: str) -> Tuple[str]: - # return tuple([]) - - def delete(self, path: str) -> bool: - return False - - def get(self, path: str) -> Any: - return None - - def upload(self, path: str, data: Any) -> StatusCode: - return StatusCode.OK + def __init__(self, config: Config): + super().__init__(config) diff --git a/servc/svc/com/storage/iceberg.py b/servc/svc/com/storage/iceberg.py index 3de97a8..b32e285 100644 --- a/servc/svc/com/storage/iceberg.py +++ b/servc/svc/com/storage/iceberg.py @@ -12,23 +12,26 @@ from pyiceberg.types import NestedField from servc.svc.com.storage.lake import Lake, LakeTable +from servc.svc.config import Config class IceBerg(Lake): - # _config # _table _catalog: Catalog _ice: Table | None - def __init__(self, config: Dict[str, Any] | None, table: LakeTable | str): + def __init__(self, config: Config, table: LakeTable | str): super().__init__(config, table) - if not config: - raise Exception("Config is required") + catalog_name = str(config.get("catalog_name")) + catalog_properties_raw = config.get("catalog_properties") + if not isinstance(catalog_properties_raw, dict): + catalog_properties_raw = {} + catalog_properties: Dict = catalog_properties_raw self._catalog = load_catalog( - config.get("catalog_name", None), - **{**config.get("catalog_properties", {})}, + catalog_name, + **{**catalog_properties}, ) self._ice = None @@ -60,7 +63,7 @@ def _connect(self): ) partitionSpec: PartitionSpec = PartitionSpec(*partitions) - self._catalog.create_namespace_if_not_exists(self._getDatabase()) + self._catalog.create_namespace_if_not_exists(self._database) self._ice = self._catalog.create_table( tableName, self._table["schema"], @@ -73,7 +76,7 @@ def _connect(self): self._isReady = self._table is not None self._isOpen = self._table is not None - return None + return self._table is not None def _close(self): if self._isOpen: @@ -83,6 +86,8 @@ def _close(self): return False def getPartitions(self) -> Dict[str, List[Any]] | None: + if not self._isOpen: + self._connect() if self._ice is None: raise Exception("Table not connected") @@ -96,20 +101,27 @@ def getPartitions(self) -> Dict[str, List[Any]] | None: return partitions def getSchema(self) -> Schema | None: + if not self._isOpen: + self._connect() if self._ice is None: raise Exception("Table not connected") return self._ice.schema().as_arrow() def getCurrentVersion(self) -> str | None: + if not self._isOpen: + self._connect() if self._ice is None: raise Exception("Table not connected") + snapshot = self._ice.current_snapshot() if snapshot is None: return None return str(snapshot.snapshot_id) def getVersions(self) -> List[str] | None: + if not self._isOpen: + self._connect() if self._ice is None: raise Exception("Table not connected") @@ -118,6 +130,8 @@ def getVersions(self) -> List[str] | None: return [str(x) for x in chunked.to_pylist()] def insert(self, data: List[Any]) -> bool: + if not self._isOpen: + self._connect() if self._ice is None: raise Exception("Table not connected") @@ -127,6 +141,8 @@ def insert(self, data: List[Any]) -> bool: def overwrite( self, data: List[Any], partitions: Dict[str, List[Any]] | None = None ) -> bool: + if not self._isOpen: + self._connect() if self._ice is None: raise Exception("Table not connected") @@ -154,6 +170,8 @@ def readRaw( version: str | None = None, options: Any | None = None, ) -> DataScan: + if not self._isOpen: + self._connect() if self._ice is None: raise Exception("Table not connected") diff --git a/servc/svc/com/storage/lake.py b/servc/svc/com/storage/lake.py index 386de82..ee765fe 100644 --- a/servc/svc/com/storage/lake.py +++ b/servc/svc/com/storage/lake.py @@ -4,6 +4,7 @@ from pyarrow import RecordBatchReader, Schema, Table from servc.svc.com.storage import StorageComponent +from servc.svc.config import Config class Medallion(Enum): @@ -21,19 +22,23 @@ class LakeTable(TypedDict): class Lake(StorageComponent): + name: str = "lake" + _table: Any - def __init__(self, config: Dict[str, Any] | None, table: LakeTable | str): + _database: str + + def __init__(self, config: Config, table: LakeTable | str): super().__init__(config) + self._table = table + self._database = str(config.get("database")) + if not isinstance(self._table, str) and "options" not in self._table: self._table["options"] = {} - def _getDatabase(self) -> str: - return self._config.get("database", "default") - def _get_table_name(self) -> str: - schema: str = self._getDatabase() + schema: str = self._database name_w_medallion: str = "" if isinstance(self._table, str): @@ -46,7 +51,7 @@ def _get_table_name(self) -> str: return ".".join([schema, name_w_medallion]) @property - def name(self) -> str: + def tablename(self) -> str: return self._get_table_name() def getPartitions(self) -> Dict[str, List[Any]] | None: diff --git a/servc/svc/com/worker/__init__.py b/servc/svc/com/worker/__init__.py index bcb539b..aecd98a 100644 --- a/servc/svc/com/worker/__init__.py +++ b/servc/svc/com/worker/__init__.py @@ -4,7 +4,7 @@ from servc.svc.com.bus import BusComponent, OnConsuming from servc.svc.com.cache import CacheComponent from servc.svc.com.worker.hooks import evaluate_post_hooks, evaluate_pre_hooks -from servc.svc.com.worker.types import EMIT_EVENT, RESOLVER_MAPPING +from servc.svc.com.worker.types import RESOLVER_CONTEXT, RESOLVER_MAPPING from servc.svc.config import Config from servc.svc.io.input import InputType from servc.svc.io.output import ( @@ -17,21 +17,17 @@ from servc.svc.io.response import getAnswerArtifact, getErrorArtifact -def HEALTHZ( - _id: str, bus: BusComponent, cache: CacheComponent, _any: Any, c: List[Middleware] -) -> StatusCode: - for component in [bus, cache, *c]: +def HEALTHZ(_id: str, _any: Any, c: RESOLVER_CONTEXT) -> StatusCode: + for component in [c["bus"], c["cache"], *c["middlewares"]]: if not component.isReady: return StatusCode.SERVER_ERROR return StatusCode.OK class WorkerComponent(Middleware): - _type: ComponentType = ComponentType.WORKER - - _route: str + name: str = "worker" - _instanceId: str + _type: ComponentType = ComponentType.WORKER _resolvers: RESOLVER_MAPPING @@ -51,8 +47,6 @@ class WorkerComponent(Middleware): def __init__( self, - route: str, - instanceId: str, resolvers: RESOLVER_MAPPING, eventResolvers: RESOLVER_MAPPING, onConsuming: OnConsuming, @@ -62,9 +56,7 @@ def __init__( config: Config, otherComponents: List[Middleware] = [], ): - super().__init__() - self._route = route - self._instanceId = instanceId + super().__init__(config) self._resolvers = resolvers self._eventResolvers = eventResolvers self._onConsuming = onConsuming @@ -73,7 +65,7 @@ def __init__( self._cache = cache self._config = config self._bindToEventExchange = ( - config.get("conf.bus.bindtoeventexchange") + config.get(f"conf.{self.name}.bindtoeventexchange") if len(self._eventResolvers.keys()) > 0 else False ) @@ -97,30 +89,30 @@ def connect(self): super().connect() print("Consumer now Subscribing", flush=True) - print(" Route:", self._route, flush=True) - print(" InstanceId:", self._instanceId, flush=True) + print(" Route:", self._bus.route, flush=True) + print(" InstanceId:", self._bus.instanceId, flush=True) print(" Resolvers:", self._resolvers.keys(), flush=True) print(" Event Resolvers:", self._eventResolvers.keys(), flush=True) print(" Bind to Event Exchange:", self._bindToEventExchange, flush=True) self._bus.subscribe( - self._route, + self._bus.route, self.inputProcessor, self._onConsuming, bindEventExchange=self._bindToEventExchange, ) - def emitEvent(self, bus: BusComponent, eventName: str, details: Any): - bus.emitEvent(eventName, self._instanceId, details) - def inputProcessor(self, message: Any) -> StatusCode: bus = self._busClass( - self._config.get("conf.bus.url"), - self._config.get("conf.bus.routemap"), - self._config.get("conf.bus.prefix"), + self._config.get(f"conf.{self._bus.name}"), ) cache = self._cache - emitEvent: EMIT_EVENT = lambda x, y: self.emitEvent(bus, x, y) + context: RESOLVER_CONTEXT = { + "bus": bus, + "cache": cache, + "middlewares": self._children, + "config": self._config, + } if "type" not in message or "route" not in message: return StatusCode.INVALID_INPUTS @@ -134,14 +126,7 @@ def inputProcessor(self, message: Any) -> StatusCode: return StatusCode.INVALID_INPUTS if message["event"] not in self._eventResolvers: return StatusCode.METHOD_NOT_FOUND - self._eventResolvers[message["event"]]( - "", - bus, - cache, - {**message}, - self._children, - emitEvent, - ) + self._eventResolvers[message["event"]]("", {**message}, context) return StatusCode.OK if message["type"] in [InputType.INPUT.value, InputType.INPUT]: @@ -157,7 +142,7 @@ def inputProcessor(self, message: Any) -> StatusCode: ), ) return StatusCode.INVALID_INPUTS - if "instanceId" in message and message["instanceId"] != self._instanceId: + if "instanceId" in message and message["instanceId"] != bus.instanceId: return StatusCode.NO_PROCESSING if message["argumentId"] in ["raw", "plain"] and message["inputs"]: @@ -184,14 +169,10 @@ def inputProcessor(self, message: Any) -> StatusCode: return StatusCode.METHOD_NOT_FOUND continueExecution = evaluate_pre_hooks( - self._route, self._resolvers, - bus, - cache, message, artifact, - self._children, - emitEvent, + context, ) if not continueExecution: return StatusCode.OK @@ -200,11 +181,8 @@ def inputProcessor(self, message: Any) -> StatusCode: try: response = self._resolvers[artifact["method"]]( message["id"], - bus, - cache, artifact["inputs"], - self._children, - emitEvent, + context, ) cache.setKey(message["id"], getAnswerArtifact(message["id"], response)) except NotAuthorizedException as e: diff --git a/servc/svc/com/worker/hooks/__init__.py b/servc/svc/com/worker/hooks/__init__.py index 038f172..e2ef16f 100644 --- a/servc/svc/com/worker/hooks/__init__.py +++ b/servc/svc/com/worker/hooks/__init__.py @@ -1,6 +1,3 @@ -from typing import List - -from servc.svc import Middleware from servc.svc.com.bus import BusComponent from servc.svc.com.cache import CacheComponent from servc.svc.com.worker.hooks.oncomplete import process_complete_hook @@ -8,7 +5,7 @@ evaluate_part_pre_hook, process_post_part_hook, ) -from servc.svc.com.worker.types import EMIT_EVENT, RESOLVER_MAPPING +from servc.svc.com.worker.types import RESOLVER_CONTEXT, RESOLVER_MAPPING from servc.svc.io.hooks import Hooks, OnCompleteHook, PartHook from servc.svc.io.input import ArgumentArtifact, InputPayload @@ -45,14 +42,10 @@ def evaluate_post_hooks( def evaluate_pre_hooks( - route: str, resolvers: RESOLVER_MAPPING, - bus: BusComponent, - cache: CacheComponent, message: InputPayload, artifact: ArgumentArtifact, - children: List[Middleware], - emit: EMIT_EVENT, + context: RESOLVER_CONTEXT, ) -> bool: hooks: Hooks = {} if "hooks" in artifact and isinstance(artifact["hooks"], dict): @@ -61,9 +54,7 @@ def evaluate_pre_hooks( return True for prehook in (evaluate_part_pre_hook,): - continueExecution = prehook( - route, resolvers, bus, cache, message, artifact, children, emit - ) + continueExecution = prehook(resolvers, message, artifact, context) if not continueExecution: return False diff --git a/servc/svc/com/worker/hooks/parallelize.py b/servc/svc/com/worker/hooks/parallelize.py index 3582c2b..9845463 100644 --- a/servc/svc/com/worker/hooks/parallelize.py +++ b/servc/svc/com/worker/hooks/parallelize.py @@ -1,10 +1,9 @@ from typing import List -from servc.svc import Middleware from servc.svc.client.send import sendMessage from servc.svc.com.bus import BusComponent from servc.svc.com.cache import CacheComponent -from servc.svc.com.worker.types import EMIT_EVENT, RESOLVER_MAPPING +from servc.svc.com.worker.types import RESOLVER_CONTEXT, RESOLVER_MAPPING from servc.svc.idgen.simple import simple as idGenerator from servc.svc.io.hooks import Hooks, OnCompleteHook, PartHook from servc.svc.io.input import ArgumentArtifact, InputPayload, InputType @@ -42,23 +41,21 @@ def process_post_part_hook( def evaluate_part_pre_hook( - route: str, resolvers: RESOLVER_MAPPING, - bus: BusComponent, - cache: CacheComponent, message: InputPayload, artifact: ArgumentArtifact, - children: List[Middleware], - emit: EMIT_EVENT, + context: RESOLVER_CONTEXT, ) -> bool: + bus = context["bus"] + cache = context["cache"] + route = bus.route hooks: Hooks = artifact.get("hooks", {}) method = artifact["method"] part_method = f"{method}_part" if part_method not in resolvers: return True - jobs = resolvers[part_method]( - message["id"], bus, cache, artifact, children, emit) + jobs = resolvers[part_method](message["id"], artifact, context) if not isinstance(jobs, list): print(f"Resolver {part_method} did not return a list") return True @@ -88,7 +85,7 @@ def evaluate_part_pre_hook( # create task queue task_queue = f"part.{route}-{method}-{message['id']}" - bus.create_queue(task_queue) + bus.create_queue(task_queue, False) # publish messages to part queue payload_template: InputPayload = { diff --git a/servc/svc/com/worker/types.py b/servc/svc/com/worker/types.py index aa8ea05..7de58d8 100644 --- a/servc/svc/com/worker/types.py +++ b/servc/svc/com/worker/types.py @@ -1,16 +1,25 @@ -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, TypedDict, Union from servc.svc import Middleware from servc.svc.com.bus import BusComponent from servc.svc.com.cache import CacheComponent +from servc.svc.config import Config from servc.svc.io.output import StatusCode EMIT_EVENT = Callable[[str, Any], None] RESOLVER_RETURN_TYPE = Union[StatusCode, Any, None] + +class RESOLVER_CONTEXT(TypedDict): + bus: BusComponent + cache: CacheComponent + middlewares: List[Middleware] + config: Config + + RESOLVER = Callable[ - [str, BusComponent, CacheComponent, Any, List[Middleware], EMIT_EVENT], + [str, Any, RESOLVER_CONTEXT], RESOLVER_RETURN_TYPE, ] diff --git a/servc/svc/config/__init__.py b/servc/svc/config/__init__.py index f634e49..823ea60 100644 --- a/servc/svc/config/__init__.py +++ b/servc/svc/config/__init__.py @@ -52,6 +52,8 @@ def __init__(self, config_path: str | None = None): ): self.setValue(key.replace("__", ".").lower(), value) + self.setValue("conf.bus.instanceid", self.get("conf.instanceid")) + # validate conf.file matches config_path. Otherwise, raise an exception because we are not able to load the configuration file if self.get("conf.file") != config_path: raise Exception("Configuration file does not match the configuration path") diff --git a/tests/hooks/test_complete.py b/tests/hooks/test_complete.py index 29217e4..23b2bca 100644 --- a/tests/hooks/test_complete.py +++ b/tests/hooks/test_complete.py @@ -2,13 +2,13 @@ import pika -from tests.hooks import get_route_message from servc.svc.com.bus.rabbitmq import BusRabbitMQ from servc.svc.com.cache.redis import CacheRedis from servc.svc.com.worker.hooks.oncomplete import process_complete_hook from servc.svc.config import Config from servc.svc.io.hooks import CompleteHookType from servc.svc.io.input import ArgumentArtifact, InputPayload, InputType +from tests.hooks import get_route_message message: InputPayload = { "id": "123", @@ -35,8 +35,8 @@ class TestCompleteHook(unittest.TestCase): @classmethod def setUpClass(cls) -> None: config = Config() - cls.bus = BusRabbitMQ(config.get("conf.bus.url"), {}, "") - cls.cache = CacheRedis(config.get("conf.cache.url")) + cls.bus = BusRabbitMQ(config.get("conf.bus")) + cls.cache = CacheRedis(config.get("conf.cache")) params = pika.URLParameters(config.get("conf.bus.url")) cls.conn = pika.BlockingConnection(params) @@ -51,14 +51,15 @@ def tearDownClass(cls) -> None: cls.conn.close() def setUp(self): - self.bus.create_queue("random") + self.bus.create_queue("random", False) def tearDown(self): self.bus.delete_queue("random") def test_complete_hook_simple(self): - res = process_complete_hook(self.bus, self.cache, message, - art, art["hooks"]["on_complete"][0]) + res = process_complete_hook( + self.bus, self.cache, message, art, art["hooks"]["on_complete"][0] + ) body, _ = get_route_message(self.channel, self.cache, "random") self.assertTrue(body["argument"]["inputs"]["inputs"], art["inputs"]) @@ -71,7 +72,7 @@ def test_w_hook_override(self): "route": "random", "inputs": True, } - res = process_complete_hook(self.bus, self.cache, message, art, hook) + res = process_complete_hook(self.bus, self.cache, message, art, hook) body, _ = get_route_message(self.channel, self.cache, "random") self.assertTrue(body["argument"]["inputs"], True) diff --git a/tests/hooks/test_parallelize.py b/tests/hooks/test_parallelize.py index 2cf92d6..8a16599 100644 --- a/tests/hooks/test_parallelize.py +++ b/tests/hooks/test_parallelize.py @@ -34,8 +34,8 @@ partHook: PartHook = art["hooks"]["part"] testMapping: RESOLVER_MAPPING = { - "mymethod": lambda _m, _b, _c, p, *y: len(p), - "mymethod_part": lambda _m, _b, _c, p, *y: [x for x in p], + "mymethod": lambda _m, p, _c: len(p), + "mymethod_part": lambda _m, p, _c: [x for x in p], "myothermethod": lambda *z: 1, } @@ -46,8 +46,14 @@ class TestParallelize(unittest.TestCase): @classmethod def setUpClass(cls) -> None: config = Config() - cls.bus = BusRabbitMQ(config.get("conf.bus.url"), {}, "") - cls.cache = CacheRedis(config.get("conf.cache.url")) + cls.bus = BusRabbitMQ(config.get("conf.bus")) + cls.cache = CacheRedis(config.get("conf.cache")) + cls.context = { + "bus": cls.bus, + "cache": cls.cache, + "middlewares": [], + "config": config, + } params = pika.URLParameters(config.get("conf.bus.url")) cls.conn = pika.BlockingConnection(params) @@ -69,14 +75,14 @@ def test_part_queue(self): self.assertFalse(res) def test_existing_part_queue(self): - self.bus.create_queue("test_part") + self.bus.create_queue("test_part", False) self.bus.publishMessage("test_part", "test") res = process_post_part_hook(self.bus, self.cache, message, art, partHook) self.assertTrue(res) def test_greater_than_total_parts(self): - self.bus.create_queue("test_part") + self.bus.create_queue("test_part", False) self.bus.publishMessage("test_part", "test") self.bus.publishMessage("test_part", "test") @@ -85,46 +91,44 @@ def test_greater_than_total_parts(self): def test_pre_hook_method_check(self): # continue because there is no part method - res = evaluate_part_pre_hook( - "test", testMapping, self.bus, self.cache, message, art, [], emit - ) + self.bus._route = "test" + res = evaluate_part_pre_hook(testMapping, message, art, self.context) self.assertTrue(res) def test_check_w_part_method(self): + self.bus._route = "test" art2 = json.loads(json.dumps(art)) art2["method"] = "mymethod" del art2["hooks"]["part"] # not because there is a part method - res = evaluate_part_pre_hook( - "test", testMapping, self.bus, self.cache, message, art2, [], emit - ) + res = evaluate_part_pre_hook(testMapping, message, art2, self.context) self.assertFalse(res) def test_check_w_part_method_and_hook(self): + self.bus._route = "test" # true because the hook exists already - res = evaluate_part_pre_hook( - "test", testMapping, self.bus, self.cache, message, art, [], emit - ) + res = evaluate_part_pre_hook(testMapping, message, art, self.context) self.assertTrue(res) def test_non_list_partifier(self): + self.bus._route = "test" art2 = json.loads(json.dumps(art)) art2["method"] = "mymethod" del art2["hooks"]["part"] new_mapping: RESOLVER_MAPPING = { - "mymethod": lambda _m, _b, _c, p, *y: len(p), - "mymethod_part": lambda _m, _b, _c, p, *y: len(p), + "mymethod": lambda _m, p, _c: len(p), + "mymethod_part": lambda _m, p, _c: len(p), "myothermethod": lambda *z: 1, } # true because resolver did not return a list - res = evaluate_part_pre_hook( - "test", new_mapping, self.bus, self.cache, message, art2, [], emit - ) + res = evaluate_part_pre_hook(new_mapping, message, art2, self.context) self.assertTrue(res) def test_w_on_complete_hook(self): + self.bus._route = "test" + art2 = json.loads(json.dumps(art)) art2["method"] = "mymethod" del art2["hooks"]["part"] @@ -132,9 +136,7 @@ def test_w_on_complete_hook(self): {"type": "test", "route": "test", "method": "test"} ] - res = evaluate_part_pre_hook( - "test", testMapping, self.bus, self.cache, message, art2, [], emit - ) + res = evaluate_part_pre_hook(testMapping, message, art2, self.context) self.assertFalse(res) diff --git a/tests/lake/test_iceberg.py b/tests/lake/test_iceberg.py index 25ccc8e..e94796b 100644 --- a/tests/lake/test_iceberg.py +++ b/tests/lake/test_iceberg.py @@ -26,6 +26,7 @@ } config = { + "database": "default", "catalog_name": "default", "catalog_properties": { "type": "sql", @@ -47,7 +48,7 @@ def test_connect(self): self.assertTrue(self.iceberg.isOpen) def test_name(self): - self.assertEqual(self.iceberg.name, "default.bronze-test") + self.assertEqual(self.iceberg.tablename, "default.bronze-test") def test_insert(self): self.iceberg.overwrite([]) diff --git a/tests/svc/test_rabbitmq.py b/tests/svc/test_rabbitmq.py index 38c86f9..87a7d96 100644 --- a/tests/svc/test_rabbitmq.py +++ b/tests/svc/test_rabbitmq.py @@ -11,7 +11,7 @@ class TestRabbitMQ(unittest.TestCase): @classmethod def setUpClass(cls) -> None: config = Config() - cls.bus = BusRabbitMQ(config.get("conf.bus.url"), {}, "") + cls.bus = BusRabbitMQ(config.get("conf.bus")) params = pika.URLParameters(config.get("conf.bus.url")) cls.conn = pika.BlockingConnection(params) @@ -41,8 +41,7 @@ def test_get_route(self): self.assertTrue(self.bus.getRoute(route).startswith(prefix)) self.bus._routeMap = mapPrefix - self.assertEqual(self.bus.getRoute(route), - "".join([prefix, mapPrefix[route]])) + self.assertEqual(self.bus.getRoute(route), "".join([prefix, mapPrefix[route]])) self.assertEqual( self.bus.getRoute("fake_route"), "".join([prefix, "fake_route"]) ) @@ -108,7 +107,7 @@ def test_nonexistent_queue_length(self): def test_existent_queue_length(self): route = "test_route" self.bus.delete_queue(route) - self.bus.create_queue(route) + self.bus.create_queue(route, False) self.assertEqual(self.bus.get_queue_length(route), 0) diff --git a/tests/svc/test_redis.py b/tests/svc/test_redis.py index 979ed3d..08cea9d 100644 --- a/tests/svc/test_redis.py +++ b/tests/svc/test_redis.py @@ -10,7 +10,7 @@ class TestRedis(unittest.TestCase): @classmethod def setUpClass(cls) -> None: config = Config() - cls.cache = CacheRedis(config.get("conf.cache.url")) + cls.cache = CacheRedis(config.get("conf.cache")) @classmethod def tearDownClass(cls) -> None: