Midea/device.py

409 lines
15 KiB
Python

import threading
try:
from enum import StrEnum
except ImportError:
from ..backports.enum import StrEnum
from enum import IntEnum
from security import LocalSecurity, MSGTYPE_HANDSHAKE_REQUEST, MSGTYPE_ENCRYPTED_REQUEST
from packet_builder import PacketBuilder
from message import MessageType, MessageQuerySubtype, MessageSubtypeResponse, MessageQuestCustom
import socket
import logging
import time
# Module-level logger shared by everything defined in this file.
_LOGGER = logging.getLogger(__name__)
class AuthException(Exception):
    """Signals that the handshake reply from the device was not acceptable."""
class ResponseException(Exception):
    """Signals that a reply was parsed into neither a success nor padding."""
class RefreshFailed(Exception):
    """Signals that every query of a status refresh counted as an error."""
class DeviceAttributes(StrEnum):
pass
class ParseMessageResult(IntEnum):
    """Outcome reported by ``MiedaDevice.parse_message`` for received bytes."""

    # At least one complete message was extracted and none was b"ERROR".
    SUCCESS = 0

    # No complete message is available yet; the bytes were only buffered.
    PADDING = 1

    # One of the extracted messages was the literal b"ERROR" marker.
    ERROR = 99
class MiedaDevice(threading.Thread):
def __init__(self,
             name: str,
             device_id: int,
             device_type: int,
             ip_address: str,
             port: int,
             token: str,
             key: str,
             protocol: int,
             model: str,
             attributes: dict):
    """Set up the device state; no connection is opened here.

    Args:
        name: Display name, exposed through the ``name`` property.
        device_id: Identifier used when building packets and in log lines.
        device_type: Device type code passed to the query/command messages.
        ip_address: Address the TCP socket connects to.
        port: TCP port the socket connects to.
        token: Hex string, decoded with ``bytes.fromhex``; a falsy value
            is stored as ``None``.
        key: Hex string, decoded with ``bytes.fromhex``; a falsy value
            is stored as ``None``.
        protocol: Protocol version; ``3`` selects the handshake and the
            8370 framing, any other value the plain v2 framing.
        model: Model name, exposed through the ``model`` property.
        attributes: Initial attribute mapping; a falsy value is replaced
            by a fresh empty dict.
    """
    threading.Thread.__init__(self)
    self._attributes = attributes if attributes else {}
    self._socket = None
    self._ip_address = ip_address
    self._port = port
    self._security = LocalSecurity()
    self._token = bytes.fromhex(token) if token else None
    self._key = bytes.fromhex(key) if key else None
    # Bytes received but not yet assembled into a complete message.
    self._buffer = b""
    self._device_name = name
    self._device_id = device_id
    self._device_type = device_type
    self._protocol = protocol
    self._model = model
    # Callbacks invoked with every status dict (see register_update).
    self._updates = []
    # Class names of queries that timed out and are skipped from then on.
    self._unsupported_protocol = []
    self._is_run = False
    self._available = True
    self._device_protocol_version = 0
    self._sub_type = None
    self._sn = None
    # Intervals are compared against clock differences in run(), i.e. seconds.
    self._refresh_interval = 30
    self._heartbeat_interval = 10
    self._default_refresh_interval = 30
@property
def name(self):
    """Display name given to the constructor."""
    return self._device_name
@property
def available(self):
    """Availability flag last stored by ``enable_device``."""
    return self._available
@property
def device_id(self):
    """Device identifier given to the constructor."""
    return self._device_id
@property
def device_type(self):
    """Device type code given to the constructor."""
    return self._device_type
@property
def model(self):
    """Model name given to the constructor."""
    return self._model
@property
def sub_type(self):
    """Device subtype, or ``0`` while it is unknown (``None``) or falsy."""
    return self._sub_type or 0
@staticmethod
def fetch_v2_message(msg):
    """Split a byte buffer into complete v2 frames.

    Each frame declares its own total length as a little-endian 16-bit
    value in bytes 4 and 5.

    Args:
        msg: Bytes-like buffer holding zero or more frames, possibly
            followed by an incomplete one.

    Returns:
        A ``(frames, remainder)`` tuple: the list of complete frames in
        arrival order and the trailing bytes that do not form a complete
        frame yet.
    """
    header_len = 6  # smallest prefix that contains the length field
    frames = []
    while len(msg) >= header_len:
        declared_len = msg[4] + (msg[5] << 8)
        if declared_len < header_len:
            # A frame cannot be shorter than its own header. A declared
            # length of 0 would never advance the buffer (endless loop),
            # and there is no way to resynchronise, so drop the rest.
            msg = msg[:0]
            break
        if len(msg) < declared_len:
            # Incomplete frame: keep it until more bytes arrive.
            break
        frames.append(msg[:declared_len])
        msg = msg[declared_len:]
    return frames, msg
def connect(self, refresh_status=True):
    """Open the TCP connection and bring the device online.

    Creates a socket with a 10 second timeout, connects, performs the
    handshake when the protocol is 3 and optionally runs a blocking
    status refresh. On success the device is marked available.

    Args:
        refresh_status: When true, call ``refresh_status`` and wait for
            the replies before reporting success.

    Returns:
        ``True`` when every step succeeded, ``False`` otherwise. All
        failures are logged and swallowed; nothing is raised.
    """
    try:
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.settimeout(10)
        _LOGGER.debug(f"[{self._device_id}] Connecting to {self._ip_address}:{self._port}")
        self._socket.connect((self._ip_address, self._port))
        _LOGGER.debug(f"[{self._device_id}] Connected")
        if self._protocol == 3:
            self.authenticate()
            # Only report success when a handshake actually took place.
            _LOGGER.debug(f"[{self._device_id}] Authentication success")
        if refresh_status:
            self.refresh_status(wait_response=True)
        self.enable_device(True)
        return True
    except socket.timeout:
        _LOGGER.debug(f"[{self._device_id}] Connection timed out")
    except socket.error:
        _LOGGER.debug(f"[{self._device_id}] Connection error")
    except AuthException:
        _LOGGER.debug(f"[{self._device_id}] Authentication failed")
    except ResponseException:
        _LOGGER.debug(f"[{self._device_id}] Unexpected response received")
    except RefreshFailed:
        _LOGGER.debug(f"[{self._device_id}] Refresh status is timed out")
    except Exception as e:
        _LOGGER.error(f"[{self._device_id}] Unknown error: {e.__traceback__.tb_frame.f_globals['__file__']}, "
                      f"{e.__traceback__.tb_lineno}, {repr(e)}")
    # Release the socket of the failed attempt here so that callers other
    # than run() do not leak it; close_socket() is safe to call again.
    self.close_socket()
    self.enable_device(False)
    return False
def authenticate(self):
    """Perform the handshake on the already connected socket.

    Sends the token wrapped as a handshake request, reads up to 512
    bytes and hands bytes 8..71 of the reply, together with the key, to
    the security layer.

    Raises:
        AuthException: The reply is shorter than 20 bytes.
    """
    request = self._security.encode_8370(
        self._token, MSGTYPE_HANDSHAKE_REQUEST)
    _LOGGER.debug(f"[{self._device_id}] Handshaking")
    # sendall() keeps writing until the whole request is out; plain
    # send() may transmit only part of it.
    self._socket.sendall(request)
    response = self._socket.recv(512)
    if len(response) < 20:
        raise AuthException()
    response = response[8: 72]
    self._security.tcp_key(response, self._key)
def send_message(self, data):
    """Send ``data`` with the framing that matches the device protocol.

    Protocol 3 goes through ``send_message_v3`` as an encrypted request;
    every other protocol value is sent unchanged via ``send_message_v2``.
    """
    if self._protocol != 3:
        self.send_message_v2(data)
        return
    self.send_message_v3(data, msg_type=MSGTYPE_ENCRYPTED_REQUEST)
def send_message_v2(self, data):
    """Write ``data`` to the socket as is.

    When no socket is open the data is not sent; the failure is only
    logged at debug level.
    """
    if self._socket is not None:
        # sendall() retries until every byte is written; plain send()
        # may stop after a partial write and silently truncate the packet.
        self._socket.sendall(data)
    else:
        _LOGGER.debug(f"[{self._device_id}] Send failure, device disconnected, data: {data.hex()}")
def send_message_v3(self, data, msg_type=MSGTYPE_ENCRYPTED_REQUEST):
    """Wrap ``data`` in an 8370 frame of ``msg_type`` and send it.

    Args:
        data: Packet to wrap.
        msg_type: 8370 message type passed to the security layer;
            defaults to an encrypted request.
    """
    data = self._security.encode_8370(data, msg_type)
    # Diagnostic dump goes to the logger instead of stdout.
    _LOGGER.debug(f"[{self._device_id}] Encrypted frame: {data}")
    self.send_message_v2(data)
def build_send(self, cmd):
    """Serialize ``cmd``, wrap it in a packet and send it to the device.

    Args:
        cmd: Message object providing ``serialize()``.
    """
    data = cmd.serialize()
    _LOGGER.debug(f"[{self._device_id}] Sending: {cmd}")
    msg = PacketBuilder(self._device_id, data).finalize()
    # Diagnostic dumps go to the logger instead of stdout.
    _LOGGER.debug(f"[{self._device_id}] Serialized data: {data}, packet: {msg}")
    self.send_message(msg)
def refresh_status(self, wait_response=False):
    """Send every status query and optionally wait for each reply.

    The queries come from ``build_query``; while the subtype is still
    unknown a subtype query is sent first. Queries whose class name is
    listed as unsupported are skipped and counted as errors.

    Args:
        wait_response: When true, block after each query until a message
            is parsed successfully. A socket timeout marks the query
            class as unsupported; an unexpected parse result counts as
            an error.

    Raises:
        RefreshFailed: The number of errors equals the number of queries.
        socket.error: The peer closed the connection while waiting.
    """
    cmds = self.build_query()
    if self._sub_type is None:
        cmds = [MessageQuerySubtype(self.device_type)] + cmds
    error_count = 0
    for cmd in cmds:
        cmd_name = cmd.__class__.__name__
        if cmd_name in self._unsupported_protocol:
            error_count += 1
            continue
        self.build_send(cmd)
        if not wait_response:
            continue
        try:
            while True:
                chunk = self._socket.recv(512)
                if len(chunk) == 0:
                    raise socket.error
                outcome = self.parse_message(chunk)
                if outcome == ParseMessageResult.SUCCESS:
                    break
                if outcome != ParseMessageResult.PADDING:
                    raise ResponseException
        except socket.timeout:
            error_count += 1
            self._unsupported_protocol.append(cmd_name)
            _LOGGER.debug(f"[{self._device_id}] Does not supports "
                          f"the protocol {cmd.__class__.__name__}, ignored")
        except ResponseException:
            error_count += 1
    if error_count == len(cmds):
        raise RefreshFailed
def set_subtype(self):
    """Hook called right after the subtype was stored; does nothing here."""
def pre_process_message(self, msg):
    """Handle subtype replies before the regular message processing.

    Args:
        msg: Decrypted message; byte 9 carries the message type.

    Returns:
        ``False`` when the message was a subtype reply and has been
        consumed here, ``True`` when the caller should process it.
    """
    if msg[9] != MessageType.querySubtype:
        return True
    message = MessageSubtypeResponse(msg)
    _LOGGER.debug(f"[{self.device_id}] Received: {message}")
    self._sub_type = message.sub_type
    # Give subclasses a chance to react to the freshly stored subtype.
    self.set_subtype()
    self._device_protocol_version = message.device_protocol_version
    _LOGGER.debug(f"[{self._device_id}] Subtype: {self._sub_type}. "
                  f"Device protocol version: {self._device_protocol_version}")
    return False
def parse_message(self, msg):
    """Buffer received bytes and dispatch every complete message.

    The bytes are appended to the internal buffer and split by the 8370
    decoder (protocol 3) or by ``fetch_v2_message``. Heartbeats are
    ignored; other messages longer than 56 bytes are decrypted and
    passed to ``pre_process_message`` / ``process_message``, and a
    non-empty status is published through ``update_all``. Malformed
    messages are logged and skipped.

    Args:
        msg: Bytes just read from the socket.

    Returns:
        ``ParseMessageResult.PADDING`` when no complete message is
        available yet, ``ParseMessageResult.ERROR`` as soon as a message
        equals ``b"ERROR"``, otherwise ``ParseMessageResult.SUCCESS``.
    """
    if self._protocol == 3:
        messages, self._buffer = self._security.decode_8370(self._buffer + msg)
    else:
        messages, self._buffer = self.fetch_v2_message(self._buffer + msg)
    if len(messages) == 0:
        return ParseMessageResult.PADDING
    for message in messages:
        if message == b"ERROR":
            return ParseMessageResult.ERROR
        if len(message) < 6:
            # The header fields below live in bytes 2..5; indexing a
            # shorter message would raise IndexError and drop the link.
            _LOGGER.warning(
                f"[{self._device_id}] Illegal message, "
                f"original message = {msg.hex()}, buffer = {self._buffer.hex()}, "
                f"8370 decoded = {message.hex()}, message length = {len(message)}"
            )
            continue
        payload_len = message[4] + (message[5] << 8) - 56
        payload_type = message[2] + (message[3] << 8)
        if payload_type in [0x1001, 0x0001]:
            # Heartbeat detected
            continue
        if len(message) <= 56:
            _LOGGER.warning(
                f"[{self._device_id}] Illegal message, "
                f"original message = {msg.hex()}, buffer = {self._buffer.hex()}, "
                f"8370 decoded = {message.hex()}, payload type = {payload_type}, "
                f"alleged payload length = {payload_len}, message length = {len(message)}, "
            )
            continue
        # The encrypted payload sits between byte 40 and the last 16 bytes.
        cryptographic = message[40:-16]
        if payload_len % 16 != 0:
            _LOGGER.warning(
                f"[{self._device_id}] Illegal payload, "
                f"original message = {msg.hex()}, buffer = {self._buffer.hex()}, "
                f"8370 decoded = {message.hex()}, payload type = {payload_type}, "
                f"alleged payload length = {payload_len}, factual payload length = {len(cryptographic)}"
            )
            continue
        decrypted = self._security.aes_decrypt(cryptographic)
        if not self.pre_process_message(decrypted):
            continue
        try:
            status = self.process_message(decrypted)
            if len(status) > 0:
                self.update_all(status)
            else:
                _LOGGER.debug(f"[{self._device_id}] Unidentified protocol")
        except Exception as e:
            # Best effort: one bad message must not end the connection,
            # but the cause is kept in the log instead of being dropped.
            _LOGGER.error(f"[{self._device_id}] Error in process message, msg = {decrypted.hex()}, "
                          f"error = {repr(e)}")
    return ParseMessageResult.SUCCESS
def build_query(self):
    """Return the status queries to send; must be provided by a subclass.

    ``refresh_status`` prepends a list to the result and iterates it.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def process_message(self, msg):
    """Turn a decrypted message into a status; must be provided by a subclass.

    ``parse_message`` takes ``len()`` of the result and forwards a
    non-empty one to ``update_all``.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def send_command(self, cmd_type, cmd_body: bytearray):
    """Send a custom command; socket errors are logged, not raised.

    Args:
        cmd_type: Command type forwarded to ``MessageQuestCustom``.
        cmd_body: Raw command body.
    """
    command = MessageQuestCustom(self._device_type, cmd_type, cmd_body)
    try:
        self.build_send(command)
    except socket.error as e:
        _LOGGER.debug(f"[{self._device_id}] Interface send_command failure, {repr(e)}, "
                      f"cmd_type: {cmd_type}, cmd_body: {cmd_body.hex()}")
def send_heartbeat(self):
    """Send a packet with a single zero byte as body and ``msg_type=0``."""
    heartbeat = PacketBuilder(self._device_id, bytearray([0x00])).finalize(msg_type=0)
    self.send_message(heartbeat)
def register_update(self, update):
    """Add ``update`` to the callbacks that ``update_all`` invokes."""
    self._updates.append(update)
def update_all(self, status):
    """Log ``status`` and pass it to every registered callback in order."""
    _LOGGER.debug(f"[{self._device_id}] Status update: {status}")
    for notify in self._updates:
        notify(status)
def enable_device(self, available=True):
    """Store the availability flag and publish it to the callbacks.

    Args:
        available: New availability; published as ``{"available": value}``.
    """
    self._available = available
    self.update_all({"available": available})
def open(self):
    """Start the worker thread unless it is already flagged as running."""
    if self._is_run:
        return
    self._is_run = True
    threading.Thread.start(self)
def close(self):
    """Clear the running flag and close the socket; no-op when not running."""
    if not self._is_run:
        return
    self._is_run = False
    self.close_socket()
def close_socket(self):
    """Reset the per-connection state and close the socket, if any.

    Clears the list of unsupported queries and the receive buffer even
    when no socket is open.
    """
    self._unsupported_protocol = []
    self._buffer = b""
    if not self._socket:
        return
    self._socket.close()
    self._socket = None
def set_ip_address(self, ip_address):
    """Switch to a new IP address and drop the current connection.

    Nothing happens when ``ip_address`` equals the stored address.
    """
    if self._ip_address == ip_address:
        return
    _LOGGER.debug(f"[{self._device_id}] Update IP address to {ip_address}")
    self._ip_address = ip_address
    self.close_socket()
def set_refresh_interval(self, refresh_interval):
    """Store the refresh period used by ``run``; values <= 0 disable refreshes."""
    self._refresh_interval = refresh_interval
def run(self):
    """Thread body: keep the device connected and process its messages.

    While the running flag is set, the loop (re)connects with a 5 second
    pause between failed attempts, then reads from the socket with a
    1 second timeout. Status refreshes and heartbeats are sent when
    their intervals have elapsed. The connection is dropped and
    re-established after an ``ERROR`` message, after 120 consecutive
    receive timeouts, or after any socket or unexpected error.
    """
    while self._is_run:
        while self._socket is None:
            if self.connect(refresh_status=True) is False:
                # Release the socket of the failed attempt before checking
                # the flag so it is not leaked when shutting down.
                self.close_socket()
                if not self._is_run:
                    return
                time.sleep(5)
        timeout_counter = 0
        # Monotonic clock: interval checks must not be affected by
        # wall-clock adjustments (NTP, DST, manual changes).
        previous_refresh = previous_heartbeat = time.monotonic()
        self._socket.settimeout(1)
        while True:
            try:
                now = time.monotonic()
                # A refresh interval <= 0 disables periodic refreshes.
                if 0 < self._refresh_interval <= now - previous_refresh:
                    self.refresh_status()
                    previous_refresh = now
                if now - previous_heartbeat >= self._heartbeat_interval:
                    self.send_heartbeat()
                    previous_heartbeat = now
                msg = self._socket.recv(512)
                if len(msg) == 0:
                    raise socket.error("Connection closed by peer")
                result = self.parse_message(msg)
                if result == ParseMessageResult.ERROR:
                    _LOGGER.debug(f"[{self._device_id}] Message 'ERROR' received")
                    self.close_socket()
                    break
                elif result == ParseMessageResult.SUCCESS:
                    timeout_counter = 0
            except socket.timeout:
                # One timeout per idle second; 120 in a row ends the link.
                timeout_counter += 1
                if timeout_counter >= 120:
                    _LOGGER.debug(f"[{self._device_id}] Heartbeat timed out")
                    self.close_socket()
                    break
            except socket.error as e:
                _LOGGER.debug(f"[{self._device_id}] Socket error {repr(e)}")
                self.close_socket()
                break
            except Exception as e:
                _LOGGER.error(f"[{self._device_id}] Unknown error :{e.__traceback__.tb_frame.f_globals['__file__']}, "
                              f"{e.__traceback__.tb_lineno}, {repr(e)}")
                self.close_socket()
                break
# def set_attribute(self, attr, value):
# raise NotImplementedError
def get_attribute(self, attr):
    """Return the stored value for ``attr``, or ``None`` when it is absent."""
    return self._attributes.get(attr)
def set_customize(self, customize):
    """Hook for applying a customization; does nothing in this class."""
@property
def attributes(self):
    """Copy of the attribute mapping with every key converted via ``str``."""
    return {str(key): value for key, value in self._attributes.items()}