From 1f573c10e7145243c69d1a2f212d49de273a66a6 Mon Sep 17 00:00:00 2001 From: Jochen Vothknecht <jochen3120@gmail.com> Date: Tue, 14 Dec 2021 11:07:56 +0100 Subject: [PATCH] Adding SpiderMQTT --- SiliconTorch/MQTT.py | 418 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 418 insertions(+) create mode 100644 SiliconTorch/MQTT.py diff --git a/SiliconTorch/MQTT.py b/SiliconTorch/MQTT.py new file mode 100644 index 0000000..32f5cce --- /dev/null +++ b/SiliconTorch/MQTT.py @@ -0,0 +1,418 @@ + +import yaml +import json +import random +import logging +import paho.mqtt.client as mqtt + + +######################################## +#### #### +#### Code from fxk8y/SpiderMQTT #### +#### #### +#### Will be replaced with import #### +#### when SpiderMQTT matures! #### +#### #### +######################################## + +# - Topic was taken as is +# - Message was taken as is (and will be debugged here!) +# - Executor was taken as is, but won't be used for threaded execution +# - SpiderMQTT will lose its request-feature, other bugs will be fixed here + +class Topic(str): + """A str-type for interaction with MQTT topics. + + Behaves like a normal str object except: + - can't have leading or trailing slashes + - can't contain double slashes + - addition is only defined for adding other Topic() objects""" + + def __new__(cls, data=''): + data = str(data) + + if len(data) < 1: + raise ValueError('Topic cannot be constructed from empty str') + + while '//' in data: + data = data.replace('//', '/') + + if data.startswith('/'): + data = data[1:] + + if data.endswith('/'): + data = data[:-1] + + return super().__new__(cls, data) + + def __add__(self, other): + if isinstance(other, Topic): + return Topic(str(self) + '/' + str(other)) + else: + return NotImplemented + + def containsWildcard(self): + return self.count('+') > 0 or self.count('#') > 0 + + def compare(self, other): + """Compares two topics according to the MQTT specification + + One argument may contain topic wildcards. + Can be called both as a class method or instance method. + + Arguments: + self -- The Topic object itself when called as an instance method, or any Topic/str object in case of a class method call + other -- The other Topic or string to compare to""" + + if (self.count('+') > 0 or self.count('#') > 0) and (other.count('+') > 0 or other.count('#') > 0): + raise ValueError('Only one topic may contain wildcards') + + x_parts = self.split('/') + y_parts = other.split('/') + + lx = len(x_parts) + ly = len(y_parts) + + result = True + + for i in range(min(lx, ly)): + x = x_parts[i] + y = y_parts[i] + + if x == y: + continue + elif x == '+' or y == '+': + continue + elif x == '#' or y == '#': + return True + elif x != y: + return False + else: + if lx == ly: + return True + elif lx < ly and x_parts[-1] == '#': + return True + elif ly < lx and y_parts[-1] == '#': + return True + else: + return False + + +class Message: + """Fancy MQTT message container + + Implements various conversions to python-native types. + The conversion results are cached internally. + Every conversion method takes a parameter called default whose value is returned as-is. + Default values are not cached in case of conversion failure. + """ + + class DEFAULT: + """Marker object for caching failed conversions""" + pass + + def __init__(self, topic: str, payload: bytearray): + self.cache = {} + self.topic = topic + self.payload = payload + + def raw(self): + """Get the raw payload as bytearray""" + + return self.payload + + def bool(self, default=False): + """Coerce payload to bool + + 'true', 'yes', 'ja', and '1' are treated as True. + 'false', 'no', 'nope', 'nein' and '0' are treated as False. + The conversion is case-insensitive.""" + + def convert(): + payload = self.payload.lower() + + if payload in ['true', 'yes', 'ja', '1']: + return True + elif payload in ['false', 'no', 'nope', 'nein', '0']: + return False + else: + return Message.DEFAULT + + return self._getOrElseUpdate(bool, convert, default) + + def int(self, default=0): + """Coerce payload to int""" + + def convert(): + try: + return int(self.payload) + except: + return Message.DEFAULT + + return self._getOrElseUpdate(int, convert, default) + + def float(self, default=0.0): + """Coerce payload to float""" + + def convert(): + try: + return float(self.payload) + except: + return Message.DEFAULT + + return self._getOrElseUpdate(float, convert, default) + + def complex(self, default=0j): + """Coerce payload to complex""" + + def convert(): + try: + return complex(self.payload) + except: + return Message.DEFAULT + + return self._getOrElseUpdate(complex, convert, default) + + def str(self, default=''): + """Decodes the payload as UTF-8""" + + def convert(): + try: + return self.payload.decode('utf-8') + except: + return Message.DEFAULT + + return self._getOrElseUpdate(str, convert, default) + + def json(self, default={}): + """Parses the payload as a JSON object""" + + def convert(): + try: + return json.loads(self.payload.decode('utf-8')) + except: + return Message.DEFAULT + + return self._getOrElseUpdate("json", convert, default) + + def yaml(self, default={}): + """Parses the payload as a YAML document""" + + def convert(): + try: + return yaml.safe_load(self.payload.decode('utf-8')) + except: + return Message.DEFAULT + + return self._getOrElseUpdate("yaml", convert, default) + + def _getOrElseUpdate(self, key, f, default): + if key in self.cache: + out = self.cache[key] + + if out is Message.DEFAULT: + return default + else: + return out + else: + out = f() + self.cache[key] = out + + if out is Message.DEFAULT: + return default + else: + return out + + def __str__(self): + return 'Message[topic=\'{}\' payload=\'{}\']'.format(self.topic, self.str(default='<binary garbage>')) + + +class Executor: + + __instance = None + + def __new__(cls, *args, **kwargs): + if cls.__instance is None: + cls.__instance = super(Executor, cls).__new__(cls) + return cls.__instance + + def __init__(self, callback, *args, **kwargs): + if isinstance(callback, (list, set)): + cbs = callback + else: + cbs = [callback] + + for cb in cbs: + try: + cb(*args, **kwargs) + except Exception as ex: + pass # TODO: logging!!! + + +class SpiderMQTT: + + class Request: + + def __init__(self, spider, requestTopic, responseTopic, payload, callback, pub_qos): + self.spider = spider + self.pub_qos = pub_qos + self.payload = payload + self.callback = callback + self.requestTopic = Topic(requestTopic) + self.responseTopic = Topic(responseTopic) + + self.__msgCallback = self.onMessage + self.__subCallback = self.onSubscribe + + if self.requestTopic.containsWildcard(): + raise ValueError('requestTopic mustn\'t contain any wildcards') + + if self.spider.isConnected(): + self.spider.addCallback(self.__msgCallback, self.__subCallback) + else: + pass # TODO: figure out what to do... + + def onSubscribe(self): + self.spider.publish(self.requestTopic, self.payload, qos=pub_qos, retain=False) + + def onMessage(self, msg): + self.spider.removeCallback(self.__msgCallback, self.__subCallback) + self.spider.requests.remove(self) + Executor(self.callback, msg) + + def __init__(self, broker: str, port: int = 1883, user: str = None, password: str = None, sub_qos: int = 0, + will_topic: str = None, will_payload = None, will_qos: int = 0, will_retain: bool = False, backgroundTask: bool = True): + """backgroundTask: True for run in Task, False for run blocking. TODO: write proper docstring!!!!""" + + self.sub_qos = sub_qos + self.connected = False + self.requests = set() + self.subscriptions = {} + self.pendingMessages = [] + + logging.basicConfig(format='[{asctime:s}][{levelname:s}] {name:s} in {funcName:s}(): {message:s}', datefmt='%H:%M:%S %d.%m.%Y', style='{', level=logging.DEBUG) + + self.log = logging.getLogger(__name__) + + client_id = 'SpiderMQTT[{:X}]'.format(random.randint(0x100000000000, 0xFFFFFFFFFFFF)) + self.mqtt = mqtt.Client(client_id=client_id, clean_session=True) + self.mqtt.enable_logger(self.log) + self.mqtt.reconnect_delay_set(1, 1) + + if user is not None: + self.mqtt.username_pw_set(user, password) + + if will_topic is not None: + self.mqtt.will_set(will_topic, will_payload, will_qos, will_retain) + + # self.mqtt.on_log = self.__onLog + self.mqtt.on_message = self.__onMessage + self.mqtt.on_publish = self.__onPublish + self.mqtt.on_connect = self.__onConnect + self.mqtt.on_subscribe = self.__onSubscribe + self.mqtt.on_disconnect = self.__onDisconnect + self.mqtt.on_unsubscribe = self.__onUnSubscribe + + self.mqtt.connect(broker, port) + + if backgroundTask: + def _shutdown(): + self.mqtt.loop_stop() + self.__shutdownFunc = _shutdown + + self.mqtt.loop_start() + else: + def _shutdown(): + self.running = False + self.__shutdownFunc = _shutdown + + self.running = True + while self.running: + self.mqtt.loop() + + def isConnected(self): + return self.connected + + def publish(self, topic, payload=None, qos=0, retain=False, prettyPrintYAML=False): + if isinstance(payload, str): + pl = payload.encode('utf-8') + elif isinstance(payload, (bool, int, float, complex)): + pl = str(payload).encode('utf-8') + elif isinstance(payload, (list, set, dict)): + pl = yaml.dump(payload, default_flow_style=not prettyPrintYAML).encode('utf-8') + else: + pl = payload + + if self.isConnected(): + self.mqtt.publish(topic, pl, qos, retain) + else: + msg = Message(topic, pl) + msg.qos = qos + msg.retain = retain + + self.pendingMessages += [msg] + + def request(self, requestTopic, responseTopic, payload, callback, pub_qos=0): + self.requests.add(self.Request(self, requestTopic, responseTopic, payload, callback, pub_qos)) + # TODO: further actions needed???? + + def __onMessage(self, _cl, _ud, msg): + message = Message(msg.topic, msg.payload) + + for sub in self.subscriptions.values(): + sub.onMessage(message) + + def __onConnect(self, *ignored): + self.connected = True + + for sub in self.subscriptions.values(): + sub.onConnect() + + for msg in self.pendingMessages: + msg.mid = self.mqtt.publish(msg.topic, msg.payload, msg.qos, msg.retain).mid + + def __onDisconnect(self, *ignored): + self.connected = False + + for sub in self.subscriptions.values(): + sub.onDisconnect() + + def __onSubscribe(self, _cl, _ud, mid, _gq): + for sub in self.subscriptions.values(): + sub.onSubscribe(mid) + + def __onUnSubscribe(self, _cl, _ud, mid): + for sub in self.subscriptions.values(): + sub.onUnSubscribe(mid) + + def __onPublish(self, _cl, _ud, mid): + for msg in self.pendingMessages: + if hasattr(msg, 'mid') and msg.mid == mid: + self.pendingMessages.remove(msg) + + def __onLog(): + pass + + def addCallback(self, topic, callback, subscribeCallback=None): + if topic in self.subscriptions: + self.subscriptions[topic].addCallback(callback) + self.subscriptions[topic].addSubscribeCallback(callback) + else: + self.subscriptions[topic] = Subscription(self, topic, callback, subscribeCallback) + + def subscribe(self, topic, callback): + self.addCallback(topic, callback) + + def removeCallback(self, callback, subscribeCallback=None): + for sub in self.subscriptions: + sub.removeCallback(callback) + sub.removeSubscribeCallback(subscribeCallback) + + if sub.callbacks == {}: + self.mqtt.unsubscribe(sub.topic) + del self.subscriptions[sub.topic] + + def shutdown(self): + self.mqtt.disconnect() + self.__shutdownFunc() + -- GitLab