diff --git a/iotc/__init__.py b/iotc/__init__.py index dca23eb..d74aa75 100644 --- a/iotc/__init__.py +++ b/iotc/__init__.py @@ -1,8 +1,9 @@ from iotc.constants import * -from iotc.provision import ProvisioningClient +from iotc.provision import ProvisioningClient, Credentials +from sys import exit import ure import json -from utime import time,sleep +from utime import time, sleep import gc try: from umqtt.robust import MQTTClient @@ -12,27 +13,39 @@ upip.install('micropython-umqtt.robust') from umqtt.robust import MQTTClient gc.collect() -class Command(): - def __init__(self, cmd_name, request_id): - self._cmd_name = cmd_name - self._request_id = request_id + + +class Command(object): + def __init__(self, command_name, command_value, component_name=None): + self._command_name = command_name + self._command_value = command_value + if component_name is not None: + self._component_name = component_name + else: + self._component_name = None + self.reply = None @property def name(self): - return self._cmd_name - @property - def payload(self): - return self._payload + return self._command_name - @payload.setter - def payload(self,value): - self._payload=value + @property + def value(self): + return self._command_value @property - def request_id(self): - return self._request_id + def component_name(self): + return self._component_name + + class IoTCClient(): - def __init__(self, id_scope, device_id, credentials_type: IoTCConnectType, credentials, logger=None): + def __init__(self, + id_scope, + device_id, + credentials_type: IoTCConnectType, + credentials, + logger=None, + storage=None): self._device_id = device_id self._id_scope = id_scope self._credentials_type = credentials_type @@ -40,6 +53,7 @@ def __init__(self, id_scope, device_id, credentials_type: IoTCConnectType, crede self._content_encoding = 'utf-8' self._connected = False self._credentials = credentials + self._storage = storage self._events = {} self._model_id = None if logger is not None: @@ -54,11 +68,12 @@ def set_content_type(self, content_type): def set_content_encoding(self, content_encoding): self._content_encoding = content_encoding - def set_log_level(self,log_level:IoTCLogLevel): + def set_log_level(self, log_level: IoTCLogLevel): self._logger.set_log_level(log_level) - + def _on_message(self, topic, message): topic = topic.decode('utf-8') + self._logger.debug(topic) if topic == HubTopics.TWIN_RES.format(200, self._twin_request_id): self._logger.info('Received twin: {}'.format(message)) @@ -71,54 +86,132 @@ def _on_message(self, topic, message): elif topic.startswith(HubTopics.COMMANDS): # commands - self._logger.info( - 'Received command {} with message: {}'.format(topic, message)) match = self._commands_regex.match(topic) if match is not None: - if all(m is not None for m in [match.group(1), match.group(2)]): + if all(m is not None for m in [match.group(1), + match.group(2)]): command_name = match.group(1) command_req = match.group(2) - command = Command(command_name, command_req) - if message is not None: - command.payload = message - self._on_commands(command) + self.command_req = command_req + command = Command(command_name, message) + try: + command_name_with_components = command_name.split("*") + + if len(command_name_with_components) > 1: + # In a component + self._logger.debug("Command in a component") + command = Command( + command_name_with_components[1], + message, + component_name=command_name_with_components[0], + ) - elif topic.startswith(HubTopics.ENQUEUED_COMMANDS.format(self._device_id)): + def reply_fn(): + self._logger.debug( + 'Acknowledging command {}'.format( + command.name)) + self._mqtt_client.publish( + '$iothub/methods/res/{}/?$rid={}'.format( + 200, command_req).encode('utf-8'), '') + if command.component_name is not None: + self.send_property({ + "{}".format(command.component_name): { + "{}".format(command.name): { + "value": command.value, + "requestId": command_req + } + } + }) + else: + self.send_property({ + "{}".format(command.name): { + "value": command.value, + "requestId": command_req + } + }) + + command.reply = reply_fn + self._on_commands(command) + sleep(0.1) + except: + pass + + elif topic.startswith( + HubTopics.ENQUEUED_COMMANDS.format(self._device_id)): params = topic.split( - "devices/{}/messages/devicebound/".format(self._device_id), 1)[1].split('&') + "devices/{}/messages/devicebound/".format(self._device_id), + 1)[1].split('&') for param in params: p = param.split('=') if p[0] == "method-name": - command_name = p[1].split("Commands%3A")[1] + command_name = decode_uri_component(p[1]) + command = Command(command_name, message) + try: + command_name_with_components = command_name.split("*") - self._logger.info( - 'Received enqueued command {} with message: {}'.format(command_name, message)) - command = Command(command_name, None) - if message is not None: - command.payload = message - self._on_enqueued_commands(command) - - def connect(self): - prov = ProvisioningClient( - self._id_scope, self._device_id, self._credentials_type,self._credentials,self._logger,self._model_id) - creds = prov.register() - self._mqtt_client = MQTTClient(self._device_id, creds.host, 8883, creds.user.encode( - 'utf-8'), creds.password.encode('utf-8'), ssl=True, keepalive=60) + if len(command_name_with_components) > 1: + # In a component + self._logger.debug("Command in a component") + command = Command( + command_name_with_components[1], + message, + component_name=command_name_with_components[0], + ) + except: + pass + + self._logger.debug( + 'Received enqueued command {} with message: {}'.format( + command.name, command.value)) + self._on_enqueued_commands(command) + + def connect(self, force_dps=False): + creds = None + + if force_dps: + self._logger.info("Refreshing credentials...") + + if self._storage is not None and force_dps is False: + creds = self._storage.retrieve() + + if creds is None: + prov = ProvisioningClient(self._id_scope, self._device_id, + self._credentials_type, + self._credentials, self._logger, + self._model_id) + creds = prov.register() + + self._mqtt_client = MQTTClient(self._device_id, + creds.host, + 8883, + creds.user, + creds.password, + ssl=True, + keepalive=60) self._commands_regex = ure.compile( '\$iothub\/methods\/POST\/(.+)\/\?\$rid=(.+)') - self._mqtt_client.connect(False) - self._connected = True - self._logger.info('Device connected!') - self._mqtt_client.set_callback(self._on_message) - self._mqtt_client.subscribe(HubTopics.TWIN) - self._mqtt_client.subscribe('{}/#'.format(HubTopics.PROPERTIES)) - self._mqtt_client.subscribe('{}/#'.format(HubTopics.COMMANDS)) - self._mqtt_client.subscribe( - '{}/#'.format(HubTopics.ENQUEUED_COMMANDS.format(self._device_id))) - - self._logger.debug(self._twin_request_id) - self._mqtt_client.publish( - HubTopics.TWIN_REQ.format(self._twin_request_id).encode('utf-8'), '{{}}') + try: + self._mqtt_client.connect(False) + self._connected = True + self._logger.info('Device connected!') + if self._storage: + self._storage.persist(creds) + self._mqtt_client.set_callback(self._on_message) + self._mqtt_client.subscribe(HubTopics.TWIN) + self._mqtt_client.subscribe('{}/#'.format(HubTopics.PROPERTIES)) + self._mqtt_client.subscribe('{}/#'.format(HubTopics.COMMANDS)) + self._mqtt_client.subscribe('{}/#'.format( + HubTopics.ENQUEUED_COMMANDS.format(self._device_id))) + + self._logger.debug(self._twin_request_id) + self._mqtt_client.publish( + HubTopics.TWIN_REQ.format( + self._twin_request_id).encode('utf-8'), '{{}}') + except: + self._logger.info("ERROR: Failed to connect to Hub") + if force_dps is True: + exit(1) + self.connect(True) def is_connected(self): if self._connected == True: @@ -131,19 +224,18 @@ def set_model_id(self, model): def send_property(self, payload): self._logger.debug('Sending properties {}'.format(json.dumps(payload))) self._mqtt_client.publish( - HubTopics.PROP_REPORT.format(time()).encode('utf-8'), json.dumps(payload)) + HubTopics.PROP_REPORT.format(time()).encode('utf-8'), + json.dumps(payload)) def send_telemetry(self, payload, properties=None): topic = 'devices/{}/messages/events/?$.ct={}&$.ce={}'.format( self._device_id, self._content_type, self._content_encoding) if properties is not None: for prop in properties: - topic += '{}={}&'.format(encode_uri_component(prop), - encode_uri_component(properties[prop])) + topic += '&{}={}'.format(prop, properties[prop]) - topic = topic[:-1] - self._mqtt_client.publish(topic.encode( - 'utf-8'), json.dumps(payload).encode('utf-8')) + self._mqtt_client.publish(topic.encode('utf-8'), + json.dumps(payload).encode('utf-8')) def on(self, event, callback): self._events[event] = callback @@ -153,45 +245,104 @@ def listen(self): return self._mqtt_client.ping() self._mqtt_client.wait_msg() - sleep(1) + sleep(0.5) + + def _handle_property_ack( + self, + callback, + property_name, + property_value, + property_version, + component_name=None, + ): + if callback is not None: + ret = callback(property_name, property_value, component_name) + else: + ret = True + if ret: + if component_name is not None: + self._logger.debug("Acknowledging {}".format(property_name)) + self.send_property({ + "{}".format(component_name): { + "{}".format(property_name): { + "ac": 200, + "ad": "Property received", + "av": property_version, + "value": property_value, + } + } + }) + else: + self._logger.debug("Acknowledging {}".format(property_name)) + self.send_property({ + "{}".format(property_name): { + "ac": 200, + "ad": "Property received", + "av": property_version, + "value": property_value, + } + }) + else: + self._logger.debug( + 'Property "{}" unsuccessfully processed'.format(property_name)) def on_properties_update(self, patch): try: prop_cb = self._events[IoTCEvents.PROPERTIES] except: return + # Set component at false by default + is_component = False for prop in patch: - if prop == '$version': + is_component = False + if prop == "$version": continue - ret = prop_cb(prop, patch[prop]['value']) - if ret: - self._logger.debug('Acknowledging {}'.format(prop)) - self.send_property({ - '{}'.format(prop): { - "value": patch[prop]["value"], - 'status': 'completed', - 'desiredVersion': patch['$version'], - 'message': 'Property received'} - }) + + # check if component + try: + is_component = patch[prop]["__t"] + except KeyError: + pass + if is_component: + for component_prop in patch[prop]: + if component_prop == "__t": + continue + self._logger.debug( + 'In component "{}" for property "{}"'.format( + prop, component_prop)) + self._handle_property_ack( + prop_cb, + component_prop, + patch[prop][component_prop]["value"], + patch["$version"], + prop, + ) else: - self._logger.debug( - 'Property "{}" unsuccessfully processed'.format(prop)) + self._handle_property_ack(prop_cb, prop, patch[prop]["value"], + patch["$version"]) def _cmd_resp(self, command: Command, value): - self._logger.debug( - 'Responding to command "{}" request'.format(command.name)) + self._logger.debug('Responding to command "{}" request'.format( + command.name)) self.send_property({ '{}'.format(command.name): { 'value': value, - 'requestId': command.request_id + 'requestId': self.command_req } }) def _cmd_ack(self, command: Command): self._logger.debug('Acknowledging command {}'.format(command.name)) + output = { + "status": 201, + "payload": { + "command": command.value.decode() + } + } self._mqtt_client.publish( - '$iothub/methods/res/{}/?$rid={}'.format(200, command.request_id).encode('utf-8'), '') + '$iothub/methods/res/{}/?$rid={}'.format( + 200, self.command_req).encode('utf-8'), json.dumps(output)) def _on_commands(self, command: Command): try: @@ -199,10 +350,10 @@ def _on_commands(self, command: Command): except KeyError: return - self._logger.debug( - 'Received command {}'.format(command.name)) + self._logger.debug('Received command {}'.format(command.name)) + self._logger.debug('Received command {}'.format(command.value)) + print(command.value.decode()) self._cmd_ack(command) - cmd_cb(command, self._cmd_resp) def _on_enqueued_commands(self, command: Command): @@ -211,8 +362,6 @@ def _on_enqueued_commands(self, command: Command): except KeyError: return - self._logger.debug( - 'Received enqueued command {}'.format(command.name)) - self._cmd_ack(command) + self._logger.debug('Received enqueued command {}'.format(command.name)) - cmd_cb(command) + cmd_cb(command) \ No newline at end of file diff --git a/iotc/constants.py b/iotc/constants.py index bebf40e..2b27ae1 100644 --- a/iotc/constants.py +++ b/iotc/constants.py @@ -1,17 +1,23 @@ +from itertools import islice + + class IoTCLogLevel: DISABLED = 1 API_ONLY = 2 ALL = 3 + class IoTCConnectType: SYMM_KEY = 1 DEVICE_KEY = 2 + class IoTCEvents: PROPERTIES = 1 COMMANDS = 2 ENQUEUED_COMMANDS = 3 + class HubTopics: TWIN = '$iothub/twin/res/#' TWIN_REQ = '$iothub/twin/GET/?$rid={}' @@ -21,6 +27,7 @@ class HubTopics: COMMANDS = '$iothub/methods/POST' ENQUEUED_COMMANDS = 'devices/{}/messages/devicebound' + class ConsoleLogger: def __init__(self, log_level=IoTCLogLevel.API_ONLY): self._log_level = log_level @@ -39,6 +46,7 @@ def debug(self, message): def set_log_level(self, log_level): self._log_level = log_level + unsafe = { '?': '%3F', ' ': '%20', @@ -51,13 +59,54 @@ def set_log_level(self, log_level): ';': '%3B', '+': '%2B', '=': '%3D', - '@': '%40' + '@': '%40', + '*': '%2A' } + def encode_uri_component(string): ret = '' for char in string: if char in unsafe: char = unsafe[char] ret = '{}{}'.format(ret, char) - return ret \ No newline at end of file + return ret + + +def window(seq, width): + it = iter(seq) + result = tuple(islice(it, width)) + if len(result) == width: + yield result + for elem in it: + result = result[1:] + (elem,) + yield result + + +def decode_uri_component(string): + res = "" + skip = 0 + for chars in window(string, 3): + if skip > 0: + skip -= 1 + continue + if chars[0] == '%': + unescaped = None + char_code = "{}{}{}".format(chars[0], chars[1], chars[2]) + for k, v in unsafe.items(): + if v.lower() == char_code.lower(): + unescaped = k + if unescaped: + res = "{}{}".format(res, unescaped) + skip = 2 + continue + + res = "{}{}".format(res, chars[0]) + + # add last two characters which are skipped from the loop + if skip == 1: + res = "{}{}".format(res, string[len(string)-1]) + elif skip == 0: + res = "{}{}{}".format( + res, string[len(string)-2], string[len(string)-1]) + return res diff --git a/iotc/provision.py b/iotc/provision.py index efebc67..3763487 100644 --- a/iotc/provision.py +++ b/iotc/provision.py @@ -1,11 +1,11 @@ +from iotc.constants import IoTCConnectType, encode_uri_component, ConsoleLogger, IoTCLogLevel +from iotc.hmac import new as hmac +import hashlib +import ubinascii +import json import sys import gc gc.collect() -import json -from iotc.constants import IoTCConnectType,encode_uri_component,ConsoleLogger,IoTCLogLevel -import ubinascii -import hashlib -from iotc.hmac import new as hmac gc.collect() try: from utime import time, sleep @@ -25,10 +25,15 @@ class Credentials: + @classmethod + def create_from_json_string(cls, cred_str): + cred_obj = json.loads(cred_str) + return cls(cred_obj['host'], cred_obj['user'], cred_obj['password']) + def __init__(self, host, user, password): self._host = host - self._user = user - self._password = password + self._user = user.encode('utf-8') + self._password = password.encode('utf-8') @property def host(self): @@ -42,8 +47,11 @@ def user(self): def password(self): return self._password + def to_json_string(self): + return json.dumps({"host": self._host, "user": self.user, "password": self.password}) + def __str__(self): - return 'Host={};User={};Password={}'.format(self._host,self._user,self._password) + return 'Host={};User={};Password={}'.format(self._host, self._user, self._password) class ProvisioningClient(): @@ -55,9 +63,9 @@ def __init__(self, scope_id, registration_id, credentials_type: IoTCConnectType, self._credentials_type = credentials_type self._api_version = '2019-01-15' if logger is not None: - self._logger=logger + self._logger = logger else: - self._logger=ConsoleLogger(IoTCLogLevel.DISABLED) + self._logger = ConsoleLogger(IoTCLogLevel.DISABLED) if model_id is not None: self._model_id = model_id @@ -90,7 +98,7 @@ def __init__(self, scope_id, registration_id, credentials_type: IoTCConnectType, gc.collect() except: pass - + expiry = time() + 946706400 # 6 hours from now in epoch signature = encode_uri_component(self._compute_key( self._device_key, '{}\n{}'.format(resource_uri, expiry)))