391 lines
14 KiB
Python
391 lines
14 KiB
Python
|
import threading
|
||
|
try:
|
||
|
from enum import StrEnum
|
||
|
except ImportError:
|
||
|
from ..backports.enum import StrEnum
|
||
|
from enum import IntEnum
|
||
|
from .security import LocalSecurity, MSGTYPE_HANDSHAKE_REQUEST, MSGTYPE_ENCRYPTED_REQUEST
|
||
|
from .packet_builder import PacketBuilder
|
||
|
from .message import MessageType, MessageQuerySubtype, MessageSubtypeResponse, MessageQuestCustom
|
||
|
import socket
|
||
|
import logging
|
||
|
import time
|
||
|
|
||
|
# Module-level logger shared by all device instances; call sites prefix each
# message with "[<device_id>]".
_LOGGER = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
class AuthException(Exception):
    """Raised when the protocol-3 handshake reply cannot be used (e.g. too short)."""
|
||
|
|
||
|
|
||
|
class ResponseException(Exception):
    """Raised when the device answers a query with an unexpected/'ERROR' response."""
|
||
|
|
||
|
|
||
|
class RefreshFailed(Exception):
    """Raised when none of the status queries in a refresh could be completed."""
|
||
|
|
||
|
|
||
|
class DeviceAttributes(StrEnum):
|
||
|
pass
|
||
|
|
||
|
|
||
|
class ParseMessageResult(IntEnum):
    """Outcome reported by ``MiedaDevice.parse_message``."""

    # At least one complete message was extracted and handled.
    SUCCESS = 0
    # No complete message yet; the bytes were kept in the buffer.
    PADDING = 1
    # The decoder produced the literal b"ERROR" message.
    ERROR = 99
|
||
|
|
||
|
|
||
|
class MiedaDevice(threading.Thread):
|
||
|
def __init__(self,
|
||
|
name: str,
|
||
|
device_id: int,
|
||
|
device_type: int,
|
||
|
ip_address: str,
|
||
|
port: int,
|
||
|
token: str,
|
||
|
key: str,
|
||
|
protocol: int,
|
||
|
model: str,
|
||
|
attributes: dict):
|
||
|
threading.Thread.__init__(self)
|
||
|
self._attributes = attributes if attributes else {}
|
||
|
self._socket = None
|
||
|
self._ip_address = ip_address
|
||
|
self._port = port
|
||
|
self._security = LocalSecurity()
|
||
|
self._token = bytes.fromhex(token) if token else None
|
||
|
self._key = bytes.fromhex(key) if key else None
|
||
|
self._buffer = b""
|
||
|
self._device_name = name
|
||
|
self._device_id = device_id
|
||
|
self._device_type = device_type
|
||
|
self._protocol = protocol
|
||
|
self._model = model
|
||
|
self._updates = []
|
||
|
self._unsupported_protocol = []
|
||
|
self._is_run = False
|
||
|
self._available = True
|
||
|
self._device_protocol_version = 0
|
||
|
self._sub_type = None
|
||
|
self._sn = None
|
||
|
self._refresh_interval = 30
|
||
|
self._heartbeat_interval = 10
|
||
|
self._default_refresh_interval = 30
|
||
|
|
||
|
@property
|
||
|
def name(self):
|
||
|
return self._device_name
|
||
|
|
||
|
@property
|
||
|
def available(self):
|
||
|
return self._available
|
||
|
|
||
|
@property
|
||
|
def device_id(self):
|
||
|
return self._device_id
|
||
|
|
||
|
@property
|
||
|
def device_type(self):
|
||
|
return self._device_type
|
||
|
|
||
|
@property
|
||
|
def model(self):
|
||
|
return self._model
|
||
|
|
||
|
@property
|
||
|
def sub_type(self):
|
||
|
return self._sub_type if self._sub_type else 0
|
||
|
|
||
|
@staticmethod
|
||
|
def fetch_v2_message(msg):
|
||
|
result = []
|
||
|
while len(msg) > 0:
|
||
|
factual_msg_len = len(msg)
|
||
|
if factual_msg_len < 6:
|
||
|
break
|
||
|
alleged_msg_len = msg[4] + (msg[5] << 8)
|
||
|
if factual_msg_len >= alleged_msg_len:
|
||
|
result.append(msg[:alleged_msg_len])
|
||
|
msg = msg[alleged_msg_len:]
|
||
|
else:
|
||
|
break
|
||
|
return result, msg
|
||
|
|
||
|
def connect(self, refresh_status=True):
|
||
|
try:
|
||
|
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||
|
self._socket.settimeout(10)
|
||
|
_LOGGER.debug(f"[{self._device_id}] Connecting to {self._ip_address}:{self._port}")
|
||
|
self._socket.connect((self._ip_address, self._port))
|
||
|
_LOGGER.debug(f"[{self._device_id}] Connected")
|
||
|
if self._protocol == 3:
|
||
|
self.authenticate()
|
||
|
_LOGGER.debug(f"[{self._device_id}] Authentication success")
|
||
|
if refresh_status:
|
||
|
self.refresh_status(wait_response=True)
|
||
|
self.enable_device(True)
|
||
|
return True
|
||
|
except socket.timeout:
|
||
|
_LOGGER.debug(f"[{self._device_id}] Connection timed out")
|
||
|
except socket.error:
|
||
|
_LOGGER.debug(f"[{self._device_id}] Connection error")
|
||
|
except AuthException:
|
||
|
_LOGGER.debug(f"[{self._device_id}] Authentication failed")
|
||
|
except ResponseException:
|
||
|
_LOGGER.debug(f"[{self._device_id}] Unexpected response received")
|
||
|
except RefreshFailed:
|
||
|
_LOGGER.debug(f"[{self._device_id}] Refresh status is timed out")
|
||
|
except Exception as e:
|
||
|
_LOGGER.error(f"[{self._device_id}] Unknown error: {e.__traceback__.tb_frame.f_globals['__file__']}, "
|
||
|
f"{e.__traceback__.tb_lineno}, {repr(e)}")
|
||
|
self.enable_device(False)
|
||
|
return False
|
||
|
|
||
|
def authenticate(self):
|
||
|
request = self._security.encode_8370(
|
||
|
self._token, MSGTYPE_HANDSHAKE_REQUEST)
|
||
|
_LOGGER.debug(f"[{self._device_id}] Handshaking")
|
||
|
self._socket.send(request)
|
||
|
response = self._socket.recv(512)
|
||
|
if len(response) < 20:
|
||
|
raise AuthException()
|
||
|
response = response[8: 72]
|
||
|
self._security.tcp_key(response, self._key)
|
||
|
|
||
|
def send_message(self, data):
|
||
|
if self._protocol == 3:
|
||
|
self.send_message_v3(data, msg_type=MSGTYPE_ENCRYPTED_REQUEST)
|
||
|
else:
|
||
|
self.send_message_v2(data)
|
||
|
|
||
|
def send_message_v2(self, data):
|
||
|
if self._socket is not None:
|
||
|
self._socket.send(data)
|
||
|
else:
|
||
|
_LOGGER.debug(f"[{self._device_id}] Send failure, device disconnected, data: {data.hex()}")
|
||
|
|
||
|
def send_message_v3(self, data, msg_type=MSGTYPE_ENCRYPTED_REQUEST):
|
||
|
data = self._security.encode_8370(data, msg_type)
|
||
|
self.send_message_v2(data)
|
||
|
|
||
|
def build_send(self, cmd):
|
||
|
data = cmd.serialize()
|
||
|
_LOGGER.debug(f"[{self._device_id}] Sending: {cmd}")
|
||
|
msg = PacketBuilder(self._device_id, data).finalize()
|
||
|
self.send_message(msg)
|
||
|
|
||
|
def refresh_status(self, wait_response=False):
|
||
|
cmds = self.build_query()
|
||
|
if self._sub_type is None:
|
||
|
cmds = [MessageQuerySubtype(self.device_type)] + cmds
|
||
|
error_count = 0
|
||
|
for cmd in cmds:
|
||
|
if cmd.__class__.__name__ not in self._unsupported_protocol:
|
||
|
self.build_send(cmd)
|
||
|
if wait_response:
|
||
|
try:
|
||
|
while True:
|
||
|
msg = self._socket.recv(512)
|
||
|
if len(msg) == 0:
|
||
|
raise socket.error
|
||
|
result = self.parse_message(msg)
|
||
|
if result == ParseMessageResult.SUCCESS:
|
||
|
break
|
||
|
elif result == ParseMessageResult.PADDING:
|
||
|
continue
|
||
|
else:
|
||
|
raise ResponseException
|
||
|
except socket.timeout:
|
||
|
error_count += 1
|
||
|
self._unsupported_protocol.append(cmd.__class__.__name__)
|
||
|
_LOGGER.debug(f"[{self._device_id}] Does not supports "
|
||
|
f"the protocol {cmd.__class__.__name__}, ignored")
|
||
|
except ResponseException:
|
||
|
error_count += 1
|
||
|
else:
|
||
|
error_count += 1
|
||
|
if error_count == len(cmds):
|
||
|
raise RefreshFailed
|
||
|
|
||
|
def set_subtype(self):
|
||
|
pass
|
||
|
|
||
|
def pre_process_message(self, msg):
|
||
|
if msg[9] == MessageType.querySubtype:
|
||
|
message = MessageSubtypeResponse(msg)
|
||
|
_LOGGER.debug(f"[{self.device_id}] Received: {message}")
|
||
|
self._sub_type = message.sub_type
|
||
|
self.set_subtype()
|
||
|
self._device_protocol_version = message.device_protocol_version
|
||
|
_LOGGER.debug(f"[{self._device_id}] Subtype: {self._sub_type}. "
|
||
|
f"Device protocol version: {self._device_protocol_version}")
|
||
|
return False
|
||
|
return True
|
||
|
|
||
|
def parse_message(self, msg):
|
||
|
if self._protocol == 3:
|
||
|
messages, self._buffer = self._security.decode_8370(self._buffer + msg)
|
||
|
else:
|
||
|
messages, self._buffer = self.fetch_v2_message(self._buffer + msg)
|
||
|
if len(messages) == 0:
|
||
|
return ParseMessageResult.PADDING
|
||
|
for message in messages:
|
||
|
if message == b"ERROR":
|
||
|
return ParseMessageResult.ERROR
|
||
|
payload_len = message[4] + (message[5] << 8) - 56
|
||
|
payload_type = message[2] + (message[3] << 8)
|
||
|
if payload_type in [0x1001, 0x0001]:
|
||
|
# Heartbeat detected
|
||
|
pass
|
||
|
elif len(message) > 56:
|
||
|
cryptographic = message[40:-16]
|
||
|
if payload_len % 16 == 0:
|
||
|
decrypted = self._security.aes_decrypt(cryptographic)
|
||
|
if self.pre_process_message(decrypted):
|
||
|
try:
|
||
|
status = self.process_message(decrypted)
|
||
|
if len(status) > 0:
|
||
|
self.update_all(status)
|
||
|
else:
|
||
|
_LOGGER.debug(f"[{self._device_id}] Unidentified protocol")
|
||
|
except Exception as e:
|
||
|
_LOGGER.error(f"[{self._device_id}] Error in process message, msg = {decrypted.hex()}")
|
||
|
else:
|
||
|
_LOGGER.warning(
|
||
|
f"[{self._device_id}] Illegal payload, "
|
||
|
f"original message = {msg.hex()}, buffer = {self._buffer.hex()}, "
|
||
|
f"8370 decoded = {message.hex()}, payload type = {payload_type}, "
|
||
|
f"alleged payload length = {payload_len}, factual payload length = {len(cryptographic)}"
|
||
|
)
|
||
|
else:
|
||
|
_LOGGER.warning(
|
||
|
f"[{self._device_id}] Illegal message, "
|
||
|
f"original message = {msg.hex()}, buffer = {self._buffer.hex()}, "
|
||
|
f"8370 decoded = {message.hex()}, payload type = {payload_type}, "
|
||
|
f"alleged payload length = {payload_len}, message length = {len(message)}, "
|
||
|
)
|
||
|
return ParseMessageResult.SUCCESS
|
||
|
|
||
|
def build_query(self):
|
||
|
raise NotImplementedError
|
||
|
|
||
|
def process_message(self, msg):
|
||
|
raise NotImplementedError
|
||
|
|
||
|
def send_command(self, cmd_type, cmd_body: bytearray):
|
||
|
cmd = MessageQuestCustom(self._device_type, cmd_type, cmd_body)
|
||
|
try:
|
||
|
self.build_send(cmd)
|
||
|
except socket.error as e:
|
||
|
_LOGGER.debug(f"[{self._device_id}] Interface send_command failure, {repr(e)}, "
|
||
|
f"cmd_type: {cmd_type}, cmd_body: {cmd_body.hex()}")
|
||
|
|
||
|
def send_heartbeat(self):
|
||
|
msg = PacketBuilder(self._device_id, bytearray([0x00])).finalize(msg_type=0)
|
||
|
self.send_message(msg)
|
||
|
|
||
|
def register_update(self, update):
|
||
|
self._updates.append(update)
|
||
|
|
||
|
def update_all(self, status):
|
||
|
_LOGGER.debug(f"[{self._device_id}] Status update: {status}")
|
||
|
for update in self._updates:
|
||
|
update(status)
|
||
|
|
||
|
def enable_device(self, available=True):
|
||
|
self._available = available
|
||
|
status = {"available": available}
|
||
|
self.update_all(status)
|
||
|
|
||
|
def open(self):
|
||
|
if not self._is_run:
|
||
|
self._is_run = True
|
||
|
threading.Thread.start(self)
|
||
|
|
||
|
def close(self):
|
||
|
if self._is_run:
|
||
|
self._is_run = False
|
||
|
self.close_socket()
|
||
|
|
||
|
def close_socket(self):
|
||
|
self._unsupported_protocol = []
|
||
|
self._buffer = b""
|
||
|
if self._socket:
|
||
|
self._socket.close()
|
||
|
self._socket = None
|
||
|
|
||
|
def set_ip_address(self, ip_address):
|
||
|
if self._ip_address != ip_address:
|
||
|
_LOGGER.debug(f"[{self._device_id}] Update IP address to {ip_address}")
|
||
|
self._ip_address = ip_address
|
||
|
self.close_socket()
|
||
|
|
||
|
def set_refresh_interval(self, refresh_interval):
|
||
|
self._refresh_interval = refresh_interval
|
||
|
|
||
|
def run(self):
|
||
|
while self._is_run:
|
||
|
while self._socket is None:
|
||
|
if self.connect(refresh_status=True) is False:
|
||
|
if not self._is_run:
|
||
|
return
|
||
|
self.close_socket()
|
||
|
time.sleep(5)
|
||
|
timeout_counter = 0
|
||
|
start = time.time()
|
||
|
previous_refresh = start
|
||
|
previous_heartbeat = start
|
||
|
self._socket.settimeout(1)
|
||
|
while True:
|
||
|
try:
|
||
|
now = time.time()
|
||
|
if 0 < self._refresh_interval <= now - previous_refresh:
|
||
|
self.refresh_status()
|
||
|
previous_refresh = now
|
||
|
if now - previous_heartbeat >= self._heartbeat_interval:
|
||
|
self.send_heartbeat()
|
||
|
previous_heartbeat = now
|
||
|
msg = self._socket.recv(512)
|
||
|
msg_len = len(msg)
|
||
|
if msg_len == 0:
|
||
|
raise socket.error("Connection closed by peer")
|
||
|
result = self.parse_message(msg)
|
||
|
if result == ParseMessageResult.ERROR:
|
||
|
_LOGGER.debug(f"[{self._device_id}] Message 'ERROR' received")
|
||
|
self.close_socket()
|
||
|
break
|
||
|
elif result == ParseMessageResult.SUCCESS:
|
||
|
timeout_counter = 0
|
||
|
except socket.timeout:
|
||
|
timeout_counter = timeout_counter + 1
|
||
|
if timeout_counter >= 120:
|
||
|
_LOGGER.debug(f"[{self._device_id}] Heartbeat timed out")
|
||
|
self.close_socket()
|
||
|
break
|
||
|
except socket.error as e:
|
||
|
_LOGGER.debug(f"[{self._device_id}] Socket error {repr(e)}")
|
||
|
self.close_socket()
|
||
|
break
|
||
|
except Exception as e:
|
||
|
_LOGGER.error(f"[{self._device_id}] Unknown error :{e.__traceback__.tb_frame.f_globals['__file__']}, "
|
||
|
f"{e.__traceback__.tb_lineno}, {repr(e)}")
|
||
|
self.close_socket()
|
||
|
break
|
||
|
|
||
|
def set_attribute(self, attr, value):
|
||
|
raise NotImplementedError
|
||
|
|
||
|
def get_attribute(self, attr):
|
||
|
return self._attributes.get(attr)
|
||
|
|
||
|
def set_customize(self, customize):
|
||
|
pass
|
||
|
|
||
|
@property
|
||
|
def attributes(self):
|
||
|
ret = {}
|
||
|
for status in self._attributes.keys():
|
||
|
ret[str(status)] = self._attributes[status]
|
||
|
return ret
|