From 1f573c10e7145243c69d1a2f212d49de273a66a6 Mon Sep 17 00:00:00 2001
From: Jochen Vothknecht <jochen3120@gmail.com>
Date: Tue, 14 Dec 2021 11:07:56 +0100
Subject: [PATCH] Adding SpiderMQTT

---
 SiliconTorch/MQTT.py | 418 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 418 insertions(+)
 create mode 100644 SiliconTorch/MQTT.py

diff --git a/SiliconTorch/MQTT.py b/SiliconTorch/MQTT.py
new file mode 100644
index 0000000..32f5cce
--- /dev/null
+++ b/SiliconTorch/MQTT.py
@@ -0,0 +1,418 @@
+
+import yaml
+import json
+import random
+import logging
+import paho.mqtt.client as mqtt
+
+
+########################################
+####                                ####
+####   Code from fxk8y/SpiderMQTT   ####
+####                                ####
+####  Will be replaced with import  ####
+####    when SpiderMQTT matures!    ####
+####                                ####
+########################################
+
+# - Topic was taken as is
+# - Message was taken as is  (and will be debugged here!)
+# - Executor was taken as is, but won't be used for threaded execution
+# - SpiderMQTT will lose its request-feature, other bugs will be fixed here 
+
+class Topic(str):
+  """A str-type for interaction with MQTT topics.
+
+  Behaves like a normal str object except:
+    - can't have leading or trailing slashes
+    - can't contain double slashes
+    - addition is only defined for adding other Topic() objects"""
+
+  def __new__(cls, data=''):
+    data = str(data)
+
+    if len(data) < 1:
+      raise ValueError('Topic cannot be constructed from empty str')
+
+    while '//' in data:
+      data = data.replace('//', '/')
+
+    if data.startswith('/'):
+      data = data[1:]
+
+    if data.endswith('/'):
+      data = data[:-1]
+
+    return super().__new__(cls, data)
+
+  def __add__(self, other):
+    if isinstance(other, Topic):
+      return Topic(str(self) + '/' + str(other))
+    else:
+      return NotImplemented
+
+  def containsWildcard(self):
+    return self.count('+') > 0 or self.count('#') > 0
+
+  def compare(self, other):
+    """Compares two topics according to the MQTT specification
+
+    One argument may contain topic wildcards.
+    Can be called both as a class method or instance method.
+
+    Arguments:
+    self -- The Topic object itself when called as an instance method, or any Topic/str object in case of a class method call
+    other -- The other Topic or string to compare to"""
+
+    if (self.count('+') > 0 or self.count('#') > 0) and (other.count('+') > 0 or other.count('#') > 0):
+      raise ValueError('Only one topic may contain wildcards')
+
+    x_parts = self.split('/')
+    y_parts = other.split('/')
+
+    lx = len(x_parts)
+    ly = len(y_parts)
+
+    result = True
+
+    for i in range(min(lx, ly)):
+      x = x_parts[i]
+      y = y_parts[i]
+
+      if x == y:
+        continue
+      elif x == '+' or y == '+':
+        continue
+      elif x == '#' or y == '#':
+        return True
+      elif x != y:
+        return False
+    else:
+      if lx == ly:
+        return True
+      elif lx < ly and x_parts[-1] == '#':
+        return True
+      elif ly < lx and y_parts[-1] == '#':
+        return True
+      else:
+        return False
+
+
+class Message:
+  """Fancy MQTT message container
+  
+  Implements various conversions to python-native types.
+  The conversion results are cached internally.
+  Every conversion method takes a parameter called default whose value is returned as-is.
+  Default values are not cached in case of conversion failure.
+  """
+
+  class DEFAULT:
+    """Marker object for caching failed conversions"""
+    pass
+
+  def __init__(self, topic: str, payload: bytearray):
+    self.cache = {}
+    self.topic = topic
+    self.payload = payload
+
+  def raw(self):
+    """Get the raw payload as bytearray"""
+
+    return self.payload
+
+  def bool(self, default=False):
+    """Coerce payload to bool
+
+    'true', 'yes', 'ja', and '1' are treated as True.
+    'false', 'no', 'nope', 'nein' and '0' are treated as False.
+    The conversion is case-insensitive."""
+
+    def convert():
+      payload = self.payload.lower()
+
+      if payload in ['true', 'yes', 'ja', '1']:
+        return True
+      elif payload in ['false', 'no', 'nope', 'nein', '0']:
+        return False
+      else:
+        return Message.DEFAULT
+
+    return self._getOrElseUpdate(bool, convert, default)
+
+  def int(self, default=0):
+    """Coerce payload to int"""
+
+    def convert():
+      try:
+        return int(self.payload)
+      except:
+        return Message.DEFAULT
+
+    return self._getOrElseUpdate(int, convert, default)
+
+  def float(self, default=0.0):
+    """Coerce payload to float"""
+
+    def convert():
+      try:
+        return float(self.payload)
+      except:
+        return Message.DEFAULT
+
+    return self._getOrElseUpdate(float, convert, default)
+
+  def complex(self, default=0j):
+    """Coerce payload to complex"""
+
+    def convert():
+      try:
+        return complex(self.payload)
+      except:
+        return Message.DEFAULT
+
+    return self._getOrElseUpdate(complex, convert, default)
+
+  def str(self, default=''):
+    """Decodes the payload as UTF-8"""
+
+    def convert():
+      try:
+        return self.payload.decode('utf-8')
+      except:
+        return Message.DEFAULT
+
+    return self._getOrElseUpdate(str, convert, default)
+
+  def json(self, default={}):
+    """Parses the payload as a JSON object"""
+
+    def convert():
+      try:
+        return json.loads(self.payload.decode('utf-8'))
+      except:
+        return Message.DEFAULT
+
+    return self._getOrElseUpdate("json", convert, default)
+
+  def yaml(self, default={}):
+    """Parses the payload as a YAML document"""
+
+    def convert():
+      try:
+        return yaml.safe_load(self.payload.decode('utf-8'))
+      except:
+        return Message.DEFAULT
+
+    return self._getOrElseUpdate("yaml", convert, default)
+
+  def _getOrElseUpdate(self, key, f, default):
+    if key in self.cache:
+      out = self.cache[key]
+
+      if out is Message.DEFAULT:
+        return default
+      else:
+        return out
+    else:
+      out = f()
+      self.cache[key] = out
+
+      if out is Message.DEFAULT:
+        return default
+      else:
+        return out
+
+  def __str__(self):
+    return 'Message[topic=\'{}\' payload=\'{}\']'.format(self.topic, self.str(default='<binary garbage>'))
+
+
+class Executor:
+
+  __instance = None
+
+  def __new__(cls, *args, **kwargs):
+    if cls.__instance is None:
+      cls.__instance = super(Executor, cls).__new__(cls)
+    return cls.__instance
+
+  def __init__(self, callback, *args, **kwargs):
+    if isinstance(callback, (list, set)):
+      cbs = callback
+    else:
+      cbs = [callback]
+
+    for cb in cbs:
+      try:
+        cb(*args, **kwargs)
+      except Exception as ex:
+        pass  # TODO: logging!!!
+
+
+class SpiderMQTT:
+
+  class Request:
+
+    def __init__(self, spider, requestTopic, responseTopic, payload, callback, pub_qos):
+      self.spider = spider
+      self.pub_qos = pub_qos
+      self.payload = payload
+      self.callback = callback
+      self.requestTopic = Topic(requestTopic)
+      self.responseTopic = Topic(responseTopic)
+
+      self.__msgCallback = self.onMessage
+      self.__subCallback = self.onSubscribe
+
+      if self.requestTopic.containsWildcard():
+        raise ValueError('requestTopic mustn\'t contain any wildcards')
+
+      if self.spider.isConnected():
+        self.spider.addCallback(self.__msgCallback, self.__subCallback)
+      else:
+        pass  # TODO: figure out what to do...
+
+    def onSubscribe(self):
+      self.spider.publish(self.requestTopic, self.payload, qos=pub_qos, retain=False)
+
+    def onMessage(self, msg):
+      self.spider.removeCallback(self.__msgCallback, self.__subCallback)
+      self.spider.requests.remove(self)
+      Executor(self.callback, msg)
+
+  def __init__(self, broker: str, port: int = 1883, user: str = None, password: str = None, sub_qos: int = 0,
+               will_topic: str = None, will_payload = None, will_qos: int = 0, will_retain: bool = False, backgroundTask: bool = True):
+    """backgroundTask: True for run in Task, False for run blocking. TODO: write proper docstring!!!!"""
+
+    self.sub_qos = sub_qos
+    self.connected = False
+    self.requests = set()
+    self.subscriptions = {}
+    self.pendingMessages = []
+
+    logging.basicConfig(format='[{asctime:s}][{levelname:s}] {name:s} in {funcName:s}(): {message:s}', datefmt='%H:%M:%S %d.%m.%Y', style='{', level=logging.DEBUG)
+
+    self.log = logging.getLogger(__name__)
+
+    client_id = 'SpiderMQTT[{:X}]'.format(random.randint(0x100000000000, 0xFFFFFFFFFFFF))
+    self.mqtt = mqtt.Client(client_id=client_id, clean_session=True)
+    self.mqtt.enable_logger(self.log)
+    self.mqtt.reconnect_delay_set(1, 1)
+
+    if user is not None:
+      self.mqtt.username_pw_set(user, password)
+
+    if will_topic is not None:
+      self.mqtt.will_set(will_topic, will_payload, will_qos, will_retain)
+
+    # self.mqtt.on_log = self.__onLog
+    self.mqtt.on_message = self.__onMessage
+    self.mqtt.on_publish = self.__onPublish
+    self.mqtt.on_connect = self.__onConnect
+    self.mqtt.on_subscribe = self.__onSubscribe
+    self.mqtt.on_disconnect = self.__onDisconnect
+    self.mqtt.on_unsubscribe = self.__onUnSubscribe
+
+    self.mqtt.connect(broker, port)
+
+    if backgroundTask:
+      def _shutdown():
+        self.mqtt.loop_stop()
+      self.__shutdownFunc = _shutdown
+
+      self.mqtt.loop_start()
+    else:
+      def _shutdown():
+        self.running = False
+      self.__shutdownFunc = _shutdown
+
+      self.running = True
+      while self.running:
+        self.mqtt.loop()
+
+  def isConnected(self):
+    return self.connected
+
+  def publish(self, topic, payload=None, qos=0, retain=False, prettyPrintYAML=False):
+    if isinstance(payload, str):
+      pl = payload.encode('utf-8')
+    elif isinstance(payload, (bool, int, float, complex)):
+      pl = str(payload).encode('utf-8')
+    elif isinstance(payload, (list, set, dict)):
+      pl = yaml.dump(payload, default_flow_style=not prettyPrintYAML).encode('utf-8')
+    else:
+      pl = payload
+
+    if self.isConnected():
+      self.mqtt.publish(topic, pl, qos, retain)
+    else:
+      msg = Message(topic, pl)
+      msg.qos = qos
+      msg.retain = retain
+
+      self.pendingMessages += [msg]
+
+  def request(self, requestTopic, responseTopic, payload, callback, pub_qos=0):
+    self.requests.add(self.Request(self, requestTopic, responseTopic, payload, callback, pub_qos))
+    # TODO: further actions needed????
+
+  def __onMessage(self, _cl, _ud, msg):
+    message = Message(msg.topic, msg.payload)
+
+    for sub in self.subscriptions.values():
+      sub.onMessage(message)
+
+  def __onConnect(self, *ignored):
+    self.connected = True
+
+    for sub in self.subscriptions.values():
+      sub.onConnect()
+
+    for msg in self.pendingMessages:
+      msg.mid = self.mqtt.publish(msg.topic, msg.payload, msg.qos, msg.retain).mid
+
+  def __onDisconnect(self, *ignored):
+    self.connected = False
+
+    for sub in self.subscriptions.values():
+      sub.onDisconnect()
+
+  def __onSubscribe(self, _cl, _ud, mid, _gq):
+    for sub in self.subscriptions.values():
+      sub.onSubscribe(mid)
+
+  def __onUnSubscribe(self, _cl, _ud, mid):
+    for sub in self.subscriptions.values():
+      sub.onUnSubscribe(mid)
+
+  def __onPublish(self, _cl, _ud, mid):
+    for msg in self.pendingMessages:
+      if hasattr(msg, 'mid') and msg.mid == mid:
+        self.pendingMessages.remove(msg)
+
+  def __onLog():
+    pass
+
+  def addCallback(self, topic, callback, subscribeCallback=None):
+    if topic in self.subscriptions:
+      self.subscriptions[topic].addCallback(callback)
+      self.subscriptions[topic].addSubscribeCallback(callback)
+    else:
+      self.subscriptions[topic] = Subscription(self, topic, callback, subscribeCallback)
+
+  def subscribe(self, topic, callback):
+    self.addCallback(topic, callback)
+
+  def removeCallback(self, callback, subscribeCallback=None):
+    for sub in self.subscriptions:
+      sub.removeCallback(callback)
+      sub.removeSubscribeCallback(subscribeCallback)
+
+      if sub.callbacks == {}:
+        self.mqtt.unsubscribe(sub.topic)
+        del self.subscriptions[sub.topic]
+
+  def shutdown(self):
+    self.mqtt.disconnect()
+    self.__shutdownFunc()
+
-- 
GitLab