import socket
import sys
import socketserver
import threading
import logging

from django.core.cache import cache
from commlib.proto import message_pb2
from google.protobuf.message import DecodeError
from functools import partial
from commlib.crypto_lob import (
    create_null_signature,
    create_sha2_signature,
    create_sha3_signature,
    create_signing_string,
    tskdf,
    Encryptor,
)
from commlib.settings import (
    SECRET_KEY,
    TSKDF_LENGTH,
    UDP_IP_ADDRESS,
    UDP_PORT_NO,
    MC_URL,
)
from commlib.UDP.payloadconstructor import PayloadConstructor
from commlib.UDP.messagedeconstructor import is_valid_signature, decrypt_payload
from commlib.util import collector_to_signing_key


logger = logging.getLogger(__name__)
# NOTE(review): configuring the root logger at import time affects every module
# in the process; consider moving this to the application entry point.
logging.basicConfig(level=logging.DEBUG, format="%(message)s")

# Registry of payload handlers keyed by handler function name. The server looks
# a name up as an attribute of the decoded Payload message, and uses the special
# key "empty_handler" for messages whose payload is empty.
dispatch = {}


def register_dispatch(func):
    """
    Register ``func`` as a payload handler and return it unchanged.

    Meant to be used as a decorator. The handler is stored in ``dispatch``
    under ``func.__name__``. That name should be the name of a field on the
    Payload protobuf message, or "empty_handler" for the handler used when a
    message carries an empty payload. Expected handler names include:

        -  "ping"
        -  "location_message"
        -  "no_location_message"
        -  "location_sync_request"
        -  "location_sync_message"
        -  "mission_status_sync_message"
        -  "mission_status_sync_request"
        -  "mission_status_update_message"
        -  "prepare_mission_message"
        -  "complete_mission_message"
        -  "mission_state_message"

    If a handler with the same name is already registered, a warning is logged
    and the existing registration is kept.

    Args:
        func: the handler callable to register.

    Returns:
        ``func`` itself, so the decorated name still refers to the handler.
    """
    name = func.__name__
    # Keys are names, so the duplicate check must compare names, not objects.
    if name in dispatch:
        logger.warning("%r already registered in DISPATCH", func)
    else:
        dispatch[name] = func
    return func


class InvalidSignature(Exception):
    """
    Error intended to signal that a received message failed signature checks.

    MyUDPRequestHandler.handle catches it and logs a warning rather than
    letting it propagate to the socket server.
    """


class SocketServerClass(threading.Thread):
    """
    Thread running a threaded UDP server whose decoded payloads are routed to
    the handlers registered in ``dispatch``.

    Args:
        udp_ip_address: address the UDP server binds to.
        udp_port_no: port the UDP server binds to.
    """

    def __init__(self, udp_ip_address=UDP_IP_ADDRESS, udp_port_no=UDP_PORT_NO):
        super().__init__()
        # Auxiliary UDP socket; it is not bound here (the server created below
        # owns its own listening socket). Kept because handlers receive this
        # object as their ``server`` argument.
        # NOTE(review): confirm whether any handler actually uses self.sock.
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.server_address = (udp_ip_address, udp_port_no)
        self.payload_constructor = PayloadConstructor()
        self.encryptor = Encryptor()

        # socketserver instantiates the handler with (request, client_address,
        # server); partial pre-binds the two extra leading arguments that
        # MyUDPRequestHandler.__init__ expects.
        handler = partial(
            MyUDPRequestHandler, self.encryptor, self.dispatch_payload_handler
        )

        try:
            # Binds to server_address immediately, so this can raise OSError.
            self.UDPServerObject = socketserver.ThreadingUDPServer(
                self.server_address, handler
            )
        except OSError:
            # Don't leak the auxiliary socket if the server cannot bind.
            self.sock.close()
            raise

    def dispatch_payload_handler(self, message, collector_id, wfile):
        """
        Decode ``message.payload`` and invoke the matching registered handler(s).

        An empty payload goes to the "empty_handler" registration, called with
        ``payload=None``. Otherwise, each registered name that is an attribute
        of the decoded Payload with a non-empty value has its handler called.
        More than one handler can run for a single message. Handlers are
        called as ``handler(message, payload, collector_id, self, wfile)``.

        Args:
            message: decoded message whose ``payload`` holds serialized
                (already decrypted) Payload bytes.
            collector_id: identifier of the sending collector, passed through.
            wfile: the request handler's output stream, passed through.

        Returns:
            None. Undecodable payloads are logged and dropped.
        """
        payload = message_pb2.Payload()
        try:
            payload.ParseFromString(message.payload)
        except DecodeError:
            logger.warning(
                "Dropping undecodable payload from collector %r", collector_id
            )
            return None

        if payload.ByteSize() == 0:
            empty_handler = dispatch.get("empty_handler")
            if empty_handler is None:
                logger.warning(
                    "Empty payload received but no empty_handler is registered"
                )
            else:
                empty_handler(message, None, collector_id, self, wfile)
            return None

        for name, handler in dispatch.items():
            if not hasattr(payload, name):
                continue
            # Only fields that actually carry data trigger their handler.
            if getattr(payload, name).ByteSize() != 0:
                handler(message, payload, collector_id, self, wfile)
        return None

    def run(self):
        """Serve UDP requests on this thread until stop() is called."""
        self.UDPServerObject.serve_forever()

    def stop(self):
        """
        Stop serving and release the sockets owned by this object.

        Closing the server socket alone does not end serve_forever(); shutdown()
        is required. shutdown() waits for serve_forever() to exit, so it is
        only called while this thread is alive. If run() never started, it
        would block forever.
        """
        if self.is_alive():
            self.UDPServerObject.shutdown()
        self.UDPServerObject.server_close()
        self.sock.close()


class MyUDPRequestHandler(socketserver.DatagramRequestHandler):
    """
    Per-datagram handler that decodes, verifies, decrypts and dispatches one
    message.

    socketserver creates one instance per received datagram. The two extra
    leading constructor arguments are pre-bound by the owning server with
    functools.partial.
    """

    def __init__(self, encryptor, dispatch_payload_handler, *args, **kwargs):
        # Must be set before super().__init__(): the base constructor runs
        # setup(), handle() and finish() immediately.
        self.encryptor = encryptor
        self.dispatch_payload_handler = dispatch_payload_handler

        super().__init__(*args, **kwargs)

    def handle(self):
        """
        Process one received datagram.

        Parses the datagram as a Message and looks up the collector's signing
        info in the Django cache. If the cache has no entry, the module's
        SECRET_KEY/TSKDF_LENGTH are used. It then verifies the signature,
        decrypts the payload and forwards the message to
        ``dispatch_payload_handler``.

        Datagrams that cannot be decoded, lack signing info, or fail signature
        verification are logged and dropped.
        """
        # For UDP servers self.request is a (data, socket) pair. The data is
        # binary protobuf and must NOT be stripped: whitespace-valued bytes
        # (e.g. 0x0a, 0x20) at either end are part of the encoding.
        data = self.request[0]
        try:
            message = message_pb2.Message.FromString(data)

            signing_info = cache.get(
                collector_to_signing_key(message.collector),
                default={"secret_key": SECRET_KEY, "tskdf_length": TSKDF_LENGTH},
            )
            # cache.get can still return None if None was stored for the key.
            if signing_info is None:
                logger.warning(
                    "No signing information for collector %r", message.collector
                )
                return

            secret_key = signing_info["secret_key"]
            tskdf_length = signing_info["tskdf_length"]

            if not is_valid_signature(message, secret_key, tskdf_length):
                logger.warning("Invalid Signature Received")
                return

            decrypted_payload = decrypt_payload(
                message, secret_key, tskdf_length, self.encryptor
            )
            if decrypted_payload is None:
                logger.debug(
                    "No decrypted payload for collector %r; dropping message",
                    message.collector,
                )
                return

            message.payload = decrypted_payload
            self.dispatch_payload_handler(message, message.collector, self.wfile)

        except DecodeError:
            # Datagrams come from the network and may be malformed or hostile.
            logger.warning(
                "Dropping undecodable datagram from %s", self.client_address[0]
            )
        except InvalidSignature:
            logger.warning("Invalid Signature Received")
