diff --git a/CHANGES.md b/CHANGES.md index 7be4d0c..ca4cfcc 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,8 @@ # Change Log +# 1.2.0 +- Added `RabbitMQBoundQueueSensor` to support advanced routing options for queues + # 1.1.1 - Updated pip dependency to pika `1.3.x` to support python >= 3.7 diff --git a/config.schema.yaml b/config.schema.yaml index c2f4e6c..1656eb8 100644 --- a/config.schema.yaml +++ b/config.schema.yaml @@ -41,3 +41,77 @@ sensor_config: - "json" - "pickle" required: false +sensor_binding_config: + description: "RabbitMQ Sensor with advanced routing settings" + type: "object" + required: false + additionalProperties: false + properties: + host: + description: "RabbitMQ host to connect to" + type: "string" + required: true + username: + description: "Optional username for RabbitMQ" + type: "string" + password: + description: "Optional password for RabbitMQ" + type: "string" + secret: true + deserialization_method: + description: "Method used to de-serialize body. Default is to leave body as-is" + type: "string" + enum: + - "json" + - "pickle" + required: false + exchanges: + description: "A list of exchanges, with queues and routing rules" + type: "array" + required: true + items: + type: "object" + additionalProperties: false + properties: + name: + description: "Exchange name" + type: "string" + required: true + exchange_type: + description: "Exchange type" + type: "string" + enum: + - "direct" + - "fanout" + - "topic" + - "headers" + required: false + queues: + description: "A list of queues to monitor for this exchange" + type: "array" + required: true + items: + type: "object" + additionalProperties: false + properties: + name: + description: "Queue name" + type: "string" + required: true + bindings: + description: "A list of bindings for queue to have to this exchange" + type: "array" + required: false + items: + type: "object" + additionalProperties: false + properties: + routing_key: + description: "Optional Routing key to bind queue 
to exchange with" + type: "string" + required: false + arguments: + description: "Optional arguments to provide when binding queue to exchange" + type: "object" + required: false + additionalProperties: true diff --git a/pack.yaml b/pack.yaml index 8882577..5b42a6c 100644 --- a/pack.yaml +++ b/pack.yaml @@ -9,7 +9,7 @@ keywords: - aqmp - stomp - message broker -version: 1.1.1 +version: 1.2.0 python_versions: - "3" author: StackStorm, Inc. diff --git a/rabbitmq.yaml.example b/rabbitmq.yaml.example index 2f7103e..b2a1583 100644 --- a/rabbitmq.yaml.example +++ b/rabbitmq.yaml.example @@ -7,3 +7,42 @@ sensor_config: queues: - "queue1" deserialization_method: "json" +sensor_binding_config: + host: "10.0.0.100" + username: "guest" + password: "guest" + exchanges: + - name: "my_direct_exchange" + queues: + - name: "images" + bindings: + - routing_key: "app.images" + - name: "movies" + bindings: + - routing_key: "app.movies" + - name: "my_fanout_exchange" + exchange_type: "fanout" + queues: + - name: "fanout_1" + - name: "fanout_2" + - name: "fanout_3" + - name: "my_topic_exchange" + exchange_type: "topic" + queues: + - name: "topic_a" + bindings: + - routing_key: "*" + arguments: + x-queue-type: "quorum" + - name: "topic_b" + bindings: + - routing_key: "a.b.c" + - name: "my_headers_exchange" + exchange_type: "headers" + queues: + - name: "headers_1" + bindings: + - arguments: + x-match: "any" + my_header: "val" + other_header: "other_val" diff --git a/sensors/bound_queues_sensor.py b/sensors/bound_queues_sensor.py new file mode 100644 index 0000000..a8cd012 --- /dev/null +++ b/sensors/bound_queues_sensor.py @@ -0,0 +1,311 @@ +import copy +import functools +import json +import pickle +import time + +import pika +import pika.exchange_type + +from st2reactor.sensor.base import Sensor + + +class RabbitMQBoundQueueSensor(Sensor): + TRIGGER = "rabbitmq.routed_message" + + def __init__(self, sensor_service, config=None) -> None: + super(RabbitMQBoundQueueSensor, self).__init__( 
sensor_service=sensor_service, config=config + ) + self._logger = self.sensor_service.get_logger(name=self.__class__.__name__) + self._config = config + self.sensor_config = self._config["sensor_binding_config"] + self._message_dispatch = functools.partial( + self._sensor_service.dispatch, trigger=self.TRIGGER + ) + self._consumer = ReconnectingConsumer( + self._logger, self.sensor_config, self._message_dispatch + ) + self._reconnect_delay = 0 + + def run(self): + while True: + self._consumer.run() + self._maybe_reconnect() + + def cleanup(self): + self._consumer.stop() + + def setup(self): + self._consumer.connect() + + def add_trigger(self, trigger): + pass + + def update_trigger(self, trigger): + pass + + def remove_trigger(self, trigger): + pass + + def _maybe_reconnect(self): + if self._consumer.should_reconnect: + self._consumer.stop() + reconnect_delay = self._get_reconnect_delay() + self._logger.info("Reconnecting after %d seconds", reconnect_delay) + time.sleep(reconnect_delay) + self._consumer = ReconnectingConsumer( + self._logger, self.sensor_config, self._message_dispatch + ) + self._consumer.connect() + + def _get_reconnect_delay(self): + if self._consumer.was_consuming: + self._reconnect_delay = 0 + else: + self._reconnect_delay += 1 + if self._reconnect_delay > 30: + self._reconnect_delay = 30 + return self._reconnect_delay + + +class ReconnectingConsumer: + AMQP_PREFETCH = 1 + DEFAULT_USERNAME = pika.ConnectionParameters.DEFAULT_USERNAME + DEFAULT_PASSWORD = pika.ConnectionParameters.DEFAULT_PASSWORD + DEFAULT_EXCHANGE = pika.exchange_type.ExchangeType.direct + DESERIALIZATION_FUNCTIONS = {"json": json.loads, "pickle": pickle.loads} + + def __init__(self, logger, config, message_dispatch) -> None: + self.should_reconnect = False + self.was_consuming = False + + self._logger = logger + self._config = config + self._dispatch = message_dispatch + self._host = self._config["host"] + self._username = self._config.get("username", self.DEFAULT_USERNAME) + 
self._password = self._config.get("password", self.DEFAULT_PASSWORD) + self._conn = None + self._channel = None + self._closing = False + self._consuming = False + self._consumer_tags = set() + + self._deserialization_method = self._config.get( + "deserialization_method", "json" + ) + if self._deserialization_method not in self.DESERIALIZATION_FUNCTIONS: + raise ValueError( + "Invalid deserialization method specified: %s" + % (self._deserialization_method) + ) + + def run(self): + self._conn.ioloop.start() + + def stop(self): + if not self._closing: + self._closing = True + self._logger.info("Stopping") + if self._consuming: + self._stop_consuming() + self._conn.ioloop.start() + else: + self._conn.ioloop.stop() + self._logger.info("Stopped") + + def connect(self): + credentials = pika.PlainCredentials(self._username, self._password) + connection_params = pika.ConnectionParameters( + host=self._host, credentials=credentials + ) + self._conn = self._open_connection(connection_params) + + def _reconnect(self): + self.should_reconnect = True + self.stop() + + def _open_connection(self, params): + return pika.SelectConnection( + params, + on_open_callback=self._on_connection_open, + on_open_error_callback=self._on_connection_open_error, + on_close_callback=self._on_connection_close, + ) + + def _close_connection(self): + self._consuming = False + if self._conn.is_closing or self._conn.is_closed: + self._logger.debug("Connection is closing or already closed") + else: + self._logger.debug("Closing connection") + self._conn.close() + + def _on_connection_open(self, connection): + self._logger.debug("Connection opened") + self._open_channel() + + def _on_connection_open_error(self, connection, err): + self._logger.error("Connection open failed: %s", err) + self._reconnect() + + def _on_connection_close(self, connection, reason): + self._channel = None + if self._closing: + self._conn.ioloop.stop() + else: + self._logger.warning("Connection closed, reconnect necessary: %s", 
reason) + self._reconnect() + + def _open_channel(self): + self._logger.debug("Creating new channel") + self._conn.channel(on_open_callback=self._on_channel_open) + + def _close_channel(self): + self._logger.debug("Closing channel") + self._channel.close() + + def _on_channel_open(self, channel): + self._logger.debug("Channel opened") + self._channel = channel + + self._logger.debug("Adding channel close callback") + self._channel.add_on_close_callback(self._on_channel_closed) + + self._logger.debug("Setting channel prefetch: %s", self.AMQP_PREFETCH) + self._channel.basic_qos( + prefetch_count=self.AMQP_PREFETCH, callback=self._on_basic_qos_ok + ) + + def _on_channel_closed(self, channel, reason): + self._logger.warning("Channel %i was closed: %s", channel, reason) + self._close_connection() + + def _on_basic_qos_ok(self, frame): + self._logger.debug("QOS set to: %s", self.AMQP_PREFETCH) + self._setup_exchanges() + + def _setup_exchanges(self): + self._logger.debug("Setting up configured exchanges") + for exchange in self._config.get("exchanges", list()): + name = exchange["name"] + exchange_type = exchange.get("exchange_type", self.DEFAULT_EXCHANGE) + queues = exchange.get("queues", list()) + self._logger.debug("Declaring exchange: %s", name) + + # pass config object as extra argument in callback + callback = functools.partial( + self._on_exchange_declare_ok, userdata=(queues, name) + ) + self._channel.exchange_declare( + exchange=name, exchange_type=exchange_type, callback=callback + ) + + def _on_exchange_declare_ok(self, frame, userdata): + queues, exchange = userdata + self._logger.debug("Exchange declared: %s", exchange) + self._setup_queues(queues, exchange) + + def _setup_queues(self, queues, exchange): + self._logger.debug("Setting up configured queues for exchange %s", exchange) + for queue in queues: + self._logger.debug("Declaring queue: %s", queue["name"]) + callback = functools.partial( + self._on_queue_declare_ok, userdata=(queue, exchange) + ) + 
self._channel.queue_declare(queue=queue["name"], callback=callback) + + def _on_queue_declare_ok(self, frame, userdata): + queue, exchange = userdata + self._logger.debug("Queue declared: %s", queue["name"]) + self._setup_bindings(queue, exchange) + + def _setup_bindings(self, queue, exchange): + queue, bindings = queue["name"], queue.get("bindings", list()) + self._logger.debug("Setting up bindings for queue %s", queue) + for binding in bindings: + self._logger.debug("Declaring binding for queue: %s (%s)", queue, bindings) + routing_key, arguments = binding.get("routing_key"), binding.get( + "arguments" + ) + callback = functools.partial(self._on_bind_ok, userdata=(binding, queue)) + self._channel.queue_bind( + queue=queue, + exchange=exchange, + routing_key=routing_key, + arguments=arguments, + callback=callback, + ) + + def _on_bind_ok(self, frame, userdata): + binding, queue = userdata + self._logger.debug("Binding ok for queue: %s (%s)", queue, binding) + self._start_consuming(queue) + + def _start_consuming(self, queue): + self._logger.debug("Issuing consumer related RPC commands") + callback = functools.partial(self._on_message, userdata=queue) + consumer_tag = self._channel.basic_consume(queue, callback) + self._consumer_tags.add(consumer_tag) + if not self._consuming: + # only add callback once + self._logger.debug("Adding consumer cancellation callback") + self._channel.add_on_cancel_callback(self._on_consumer_cancelled) + + self._consuming = True + self.was_consuming = True + + def _stop_consuming(self): + if self._channel: + consumers_copy = copy.deepcopy(self._consumer_tags) + for consumer_tag in consumers_copy: + self._logger.debug( + "Sending a Basic.Cancel RPC command to RabbitMQ for consumer %s", + consumer_tag, + ) + callback = functools.partial( + self._on_consumer_cancelled_ok, userdata=consumer_tag + ) + self._channel.basic_cancel(consumer_tag, callback) + + def _on_consumer_cancelled(self, frame): + consumer_tag = frame.method.consumer_tag + 
self._logger.warning("Consumer was cancelled remotely: %s", consumer_tag) + self._consumer_tags.discard(consumer_tag) + + def _on_consumer_cancelled_ok(self, frame, userdata): + self._logger.debug( + "RabbitMQ acknowledged the cancellation of the consumer: %s", userdata + ) + self._consumer_tags.discard(userdata) + if not self._consumer_tags: + # we either are shutting down or all consumers have been cancelled + self._close_channel() + + def _on_message(self, channel, basic_deliver, properties, body, userdata): + body = body.decode("utf-8") + self._logger.debug("Received message for queue %s with body %s", userdata, body) + + body = self._deserialize_body(body) + payload = {"queue": userdata, "body": body} + + try: + self._dispatch(payload=payload) + finally: + self._channel.basic_ack(basic_deliver.delivery_tag) + + def _deserialize_body(self, body): + if not self._deserialization_method: + return body + + deserialization_func = self.DESERIALIZATION_FUNCTIONS[ + self._deserialization_method + ] + + try: + body = deserialization_func(body) + except json.JSONDecodeError: + pass + + return body diff --git a/sensors/bound_queues_sensor.yaml b/sensors/bound_queues_sensor.yaml new file mode 100644 index 0000000..ba9eb3b --- /dev/null +++ b/sensors/bound_queues_sensor.yaml @@ -0,0 +1,16 @@ +--- +class_name: "RabbitMQBoundQueueSensor" +entry_point: "bound_queues_sensor.py" +description: "Sensor which monitors a RabbitMQ queue with bindings for new messages" +trigger_types: + - name: "routed_message" + description: "Trigger which indicates that a new message has arrived" + payload_schema: + type: "object" + properties: + queue: + type: "string" + body: + anyOf: + - type: "string" + - type: "object"