import threading
try:
    from enum import StrEnum
except ImportError:
    from ..backports.enum import StrEnum
from enum import IntEnum
from .security import LocalSecurity, MSGTYPE_HANDSHAKE_REQUEST, MSGTYPE_ENCRYPTED_REQUEST
from .packet_builder import PacketBuilder
from .message import MessageType, MessageQuerySubtype, MessageSubtypeResponse, MessageQuestCustom
import socket
import logging
import time

_LOGGER = logging.getLogger(__name__)


class AuthException(Exception):
    pass


class ResponseException(Exception):
    pass


class RefreshFailed(Exception):
    pass


class DeviceAttributes(StrEnum):
    pass


class ParseMessageResult(IntEnum):
    SUCCESS = 0
    PADDING = 1
    ERROR = 99


class MiedaDevice(threading.Thread):
    def __init__(self,
                 name: str,
                 device_id: int,
                 device_type: int,
                 ip_address: str,
                 port: int,
                 token: str,
                 key: str,
                 protocol: int,
                 model: str,
                 attributes: dict):
        threading.Thread.__init__(self)
        self._attributes = attributes if attributes else {}
        self._socket = None
        self._ip_address = ip_address
        self._port = port
        self._security = LocalSecurity()
        self._token = bytes.fromhex(token) if token else None
        self._key = bytes.fromhex(key) if key else None
        self._buffer = b""
        self._device_name = name
        self._device_id = device_id
        self._device_type = device_type
        self._protocol = protocol
        self._model = model
        self._updates = []
        self._unsupported_protocol = []
        self._is_run = False
        self._available = True
        self._device_protocol_version = 0
        self._sub_type = None
        self._sn = None
        self._refresh_interval = 30
        self._heartbeat_interval = 10
        self._default_refresh_interval = 30

    @property
    def name(self):
        return self._device_name

    @property
    def available(self):
        return self._available

    @property
    def device_id(self):
        return self._device_id

    @property
    def device_type(self):
        return self._device_type

    @property
    def model(self):
        return self._model

    @property
    def sub_type(self):
        return self._sub_type if self._sub_type else 0

    @staticmethod
    def fetch_v2_message(msg):
        result = []
        while len(msg) > 0:
            factual_msg_len = len(msg)
            if factual_msg_len < 6:
                break
            alleged_msg_len = msg[4] + (msg[5] << 8)
            if factual_msg_len >= alleged_msg_len:
                result.append(msg[:alleged_msg_len])
                msg = msg[alleged_msg_len:]
            else:
                break
        return result, msg

    def connect(self, refresh_status=True):
        try:
            self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self._socket.settimeout(10)
            _LOGGER.debug(f"[{self._device_id}] Connecting to {self._ip_address}:{self._port}")
            self._socket.connect((self._ip_address, self._port))
            _LOGGER.debug(f"[{self._device_id}] Connected")
            if self._protocol == 3:
                self.authenticate()
            _LOGGER.debug(f"[{self._device_id}] Authentication success")
            if refresh_status:
                self.refresh_status(wait_response=True)
            self.enable_device(True)
            return True
        except socket.timeout:
            _LOGGER.debug(f"[{self._device_id}] Connection timed out")
        except socket.error:
            _LOGGER.debug(f"[{self._device_id}] Connection error")
        except AuthException:
            _LOGGER.debug(f"[{self._device_id}] Authentication failed")
        except ResponseException:
            _LOGGER.debug(f"[{self._device_id}] Unexpected response received")
        except RefreshFailed:
            _LOGGER.debug(f"[{self._device_id}] Refresh status is timed out")
        except Exception as e:
            _LOGGER.error(f"[{self._device_id}] Unknown error: {e.__traceback__.tb_frame.f_globals['__file__']}, "
                          f"{e.__traceback__.tb_lineno}, {repr(e)}")
        self.enable_device(False)
        return False

    def authenticate(self):
        request = self._security.encode_8370(
            self._token, MSGTYPE_HANDSHAKE_REQUEST)
        _LOGGER.debug(f"[{self._device_id}] Handshaking")
        self._socket.send(request)
        response = self._socket.recv(512)
        if len(response) < 20:
            raise AuthException()
        response = response[8: 72]
        self._security.tcp_key(response, self._key)

    def send_message(self, data):
        if self._protocol == 3:
            self.send_message_v3(data, msg_type=MSGTYPE_ENCRYPTED_REQUEST)
        else:
            self.send_message_v2(data)

    def send_message_v2(self, data):
        if self._socket is not None:
            self._socket.send(data)
        else:
            _LOGGER.debug(f"[{self._device_id}] Send failure, device disconnected, data: {data.hex()}")

    def send_message_v3(self, data, msg_type=MSGTYPE_ENCRYPTED_REQUEST):
        data = self._security.encode_8370(data, msg_type)
        self.send_message_v2(data)

    def build_send(self, cmd):
        data = cmd.serialize()
        _LOGGER.debug(f"[{self._device_id}] Sending: {cmd}")
        msg = PacketBuilder(self._device_id, data).finalize()
        self.send_message(msg)

    def refresh_status(self, wait_response=False):
        cmds = self.build_query()
        if self._sub_type is None:
            cmds = [MessageQuerySubtype(self.device_type)] + cmds
        error_count = 0
        for cmd in cmds:
            if cmd.__class__.__name__ not in self._unsupported_protocol:
                self.build_send(cmd)
                if wait_response:
                    try:
                        while True:
                            msg = self._socket.recv(512)
                            if len(msg) == 0:
                                raise socket.error
                            result = self.parse_message(msg)
                            if result == ParseMessageResult.SUCCESS:
                                break
                            elif result == ParseMessageResult.PADDING:
                                continue
                            else:
                                raise ResponseException
                    except socket.timeout:
                        error_count += 1
                        self._unsupported_protocol.append(cmd.__class__.__name__)
                        _LOGGER.debug(f"[{self._device_id}] Does not supports "
                                      f"the protocol {cmd.__class__.__name__}, ignored")
                    except ResponseException:
                        error_count += 1
            else:
                error_count += 1
        if error_count == len(cmds):
            raise RefreshFailed

    def set_subtype(self):
        pass

    def pre_process_message(self, msg):
        if msg[9] == MessageType.querySubtype:
            message = MessageSubtypeResponse(msg)
            _LOGGER.debug(f"[{self.device_id}] Received: {message}")
            self._sub_type = message.sub_type
            self.set_subtype()
            self._device_protocol_version = message.device_protocol_version
            _LOGGER.debug(f"[{self._device_id}] Subtype: {self._sub_type}. "
                          f"Device protocol version: {self._device_protocol_version}")
            return False
        return True

    def parse_message(self, msg):
        if self._protocol == 3:
            messages, self._buffer = self._security.decode_8370(self._buffer + msg)
        else:
            messages, self._buffer = self.fetch_v2_message(self._buffer + msg)
        if len(messages) == 0:
            return ParseMessageResult.PADDING
        for message in messages:
            if message == b"ERROR":
                return ParseMessageResult.ERROR
            payload_len = message[4] + (message[5] << 8) - 56
            payload_type = message[2] + (message[3] << 8)
            if payload_type in [0x1001, 0x0001]:
                # Heartbeat detected
                pass
            elif len(message) > 56:
                cryptographic = message[40:-16]
                if payload_len % 16 == 0:
                    decrypted = self._security.aes_decrypt(cryptographic)
                    if self.pre_process_message(decrypted):
                        try:
                            status = self.process_message(decrypted)
                            if len(status) > 0:
                                self.update_all(status)
                            else:
                                _LOGGER.debug(f"[{self._device_id}] Unidentified protocol")
                        except Exception as e:
                            _LOGGER.error(f"[{self._device_id}] Error in process message, msg = {decrypted.hex()}")
                else:
                    _LOGGER.warning(
                        f"[{self._device_id}] Illegal payload, "
                        f"original message = {msg.hex()}, buffer = {self._buffer.hex()}, "
                        f"8370 decoded = {message.hex()}, payload type = {payload_type}, "
                        f"alleged payload length = {payload_len}, factual payload length = {len(cryptographic)}"
                    )
            else:
                _LOGGER.warning(
                    f"[{self._device_id}] Illegal message, "
                    f"original message = {msg.hex()}, buffer = {self._buffer.hex()}, "
                    f"8370 decoded = {message.hex()}, payload type = {payload_type}, "
                    f"alleged payload length = {payload_len}, message length = {len(message)}, "
                )
        return ParseMessageResult.SUCCESS

    def build_query(self):
        raise NotImplementedError

    def process_message(self, msg):
        raise NotImplementedError

    def send_command(self, cmd_type, cmd_body: bytearray):
        cmd = MessageQuestCustom(self._device_type, cmd_type, cmd_body)
        try:
            self.build_send(cmd)
        except socket.error as e:
            _LOGGER.debug(f"[{self._device_id}] Interface send_command failure, {repr(e)}, "
                          f"cmd_type: {cmd_type}, cmd_body: {cmd_body.hex()}")

    def send_heartbeat(self):
        msg = PacketBuilder(self._device_id, bytearray([0x00])).finalize(msg_type=0)
        self.send_message(msg)

    def register_update(self, update):
        self._updates.append(update)

    def update_all(self, status):
        _LOGGER.debug(f"[{self._device_id}] Status update: {status}")
        for update in self._updates:
            update(status)

    def enable_device(self, available=True):
        self._available = available
        status = {"available": available}
        self.update_all(status)

    def open(self):
        if not self._is_run:
            self._is_run = True
            threading.Thread.start(self)

    def close(self):
        if self._is_run:
            self._is_run = False
            self.close_socket()

    def close_socket(self):
        self._unsupported_protocol = []
        self._buffer = b""
        if self._socket:
            self._socket.close()
            self._socket = None

    def set_ip_address(self, ip_address):
        if self._ip_address != ip_address:
            _LOGGER.debug(f"[{self._device_id}] Update IP address to {ip_address}")
            self._ip_address = ip_address
            self.close_socket()

    def set_refresh_interval(self, refresh_interval):
        self._refresh_interval = refresh_interval

    def run(self):
        while self._is_run:
            while self._socket is None:
                if self.connect(refresh_status=True) is False:
                    if not self._is_run:
                        return
                    self.close_socket()
                    time.sleep(5)
            timeout_counter = 0
            start = time.time()
            previous_refresh = start
            previous_heartbeat = start
            self._socket.settimeout(1)
            while True:
                try:
                    now = time.time()
                    if 0 < self._refresh_interval <= now - previous_refresh:
                        self.refresh_status()
                        previous_refresh = now
                    if now - previous_heartbeat >= self._heartbeat_interval:
                        self.send_heartbeat()
                        previous_heartbeat = now
                    msg = self._socket.recv(512)
                    msg_len = len(msg)
                    if msg_len == 0:
                        raise socket.error("Connection closed by peer")
                    result = self.parse_message(msg)
                    if result == ParseMessageResult.ERROR:
                        _LOGGER.debug(f"[{self._device_id}] Message 'ERROR' received")
                        self.close_socket()
                        break
                    elif result == ParseMessageResult.SUCCESS:
                        timeout_counter = 0
                except socket.timeout:
                    timeout_counter = timeout_counter + 1
                    if timeout_counter >= 120:
                        _LOGGER.debug(f"[{self._device_id}] Heartbeat timed out")
                        self.close_socket()
                        break
                except socket.error as e:
                    _LOGGER.debug(f"[{self._device_id}] Socket error {repr(e)}")
                    self.close_socket()
                    break
                except Exception as e:
                    _LOGGER.error(f"[{self._device_id}] Unknown error :{e.__traceback__.tb_frame.f_globals['__file__']}, "
                                  f"{e.__traceback__.tb_lineno}, {repr(e)}")
                    self.close_socket()
                    break

    def set_attribute(self, attr, value):
        raise NotImplementedError

    def get_attribute(self, attr):
        return self._attributes.get(attr)

    def set_customize(self, customize):
        pass

    @property
    def attributes(self):
        ret = {}
        for status in self._attributes.keys():
            ret[str(status)] = self._attributes[status]
        return ret