import socket
import sys
import commlib.settings
import struct
import threading
import time
import requests
import logging
import json
import time
from requests.auth import HTTPBasicAuth
from commlib.settings import (
    SECRET_KEY,
    TSKDF_LENGTH,
    UDP_IP_ADDRESS,
    UDP_PORT_NO,
    MC_URL,
    COLLECTOR_PK,
    MC_USERNAME,
    MC_PASSWORD,
)
from commlib.crypto_lob import (
    create_sha2_signature,
    create_sha3_signature,
    tskdf,
    Encryptor,
)
from commlib.proto import message_pb2
from commlib.UDP.messageconstructor import (
    construct_message,
    create_sha2_signature,
    create_sha3_signature,
    create_signing_string,
    encrypt_and_sign_message,
)
from commlib.UDP.messagedeconstructor import is_valid_signature, decrypt_payload
from google.protobuf.internal.decoder import _DecodeVarint32

from datetime import datetime, timezone

logger = logging.getLogger(__name__)


# Registry of payload handlers keyed by function name. Client.dispatch_payload_handler
# calls every handler whose name matches a non-empty field of the decoded Payload,
# and routes empty payloads to the "empty_handler" entry.
dispatch = {}


def register_dispatch(func):
    """
    Register ``func`` in ``dispatch`` under ``func.__name__`` and return it unchanged.

    Meant to be used as a plain decorator (``@register_dispatch``). The handler's
    name must match the ``Payload`` field it handles, e.g.:
        -  "ping"
        -  "location_message"
        -  "no_location_message"
        -  "location_sync_request"
        -  "location_sync_message"
        -  "mission_status_sync_message"
        -  "mission_status_sync_request"
        -  "mission_status_update_message"
        -  "prepare_mission_message"
        -  "complete_mission_message"
        -  "mission_state_message"
    or be named "empty_handler" to receive messages whose payload is empty.

    If a handler with the same name is already registered, a warning is logged
    and the existing registration is kept.

    Args:
        func (callable): handler invoked as ``func(message, payload, collector_id, client)``.

    Returns:
        callable: ``func`` itself, so the decorated module-level name still
        refers to the real handler (previously it was rebound to a no-op wrapper).
    """
    name = func.__name__
    # The registry is keyed by name, so the duplicate check must look up the
    # name; checking the function object itself could never match.
    if name in dispatch:
        logger.warning("%r already registered in DISPATCH", func)
    else:
        dispatch[name] = func
    return func


def get_mission_key(mission_pk):
    """Return the ``mission_states`` dictionary key for *mission_pk* (e.g. ``"mission_1"``)."""
    return f"mission_{mission_pk!s}"


class Client(threading.Thread):
    """Thread that receives, verifies and dispatches messages over UDP.

    The receive loop lives in :meth:`run`. Separately, :meth:`get_mission_pk`
    and :meth:`initialize_device_sync` pull this collector's missions and their
    device states from the MC HTTP API, whose base URL is set with
    :meth:`register_mc_url`.
    """

    def __init__(
        self,
        secret_key=SECRET_KEY,
        tskdf_length=TSKDF_LENGTH,
        udp_ip_address=UDP_IP_ADDRESS,
        udp_port_no=UDP_PORT_NO,
        collector_pk=COLLECTOR_PK,
        username=MC_USERNAME,
        password=MC_PASSWORD,
        outgoing_meters_q=None,
        outgoing_loc_q=None,
        comm_service_db=None,
        outgoing_pb_q=None,
        itron_dir=r"C:\Export",
        deduping_cooldown=30,
        request_timeout=30,
    ):
        """Create the UDP socket, connect it to the server and store settings.

        Args:
            secret_key, tskdf_length: key material passed to
                ``is_valid_signature`` / ``decrypt_payload`` for incoming messages.
            udp_ip_address, udp_port_no: address the UDP socket is connected to.
            collector_pk: collector id used in the MC "request_missions" URL.
            username, password: HTTP basic-auth credentials for MC requests.
            outgoing_meters_q, outgoing_loc_q, outgoing_pb_q, comm_service_db,
            itron_dir, deduping_cooldown: stored on the instance but not used by
                this class directly; presumably read by dispatch handlers, which
                receive the client as their last argument.
            request_timeout: seconds passed as ``timeout=`` to each MC HTTP
                request; requests that time out are retried.
        """
        super().__init__()
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.server_address = (udp_ip_address, udp_port_no)
        try:
            self.sock.connect(self.server_address)
        except OSError:
            # Don't leak the socket if connect() fails (e.g. unresolvable host).
            self.sock.close()
            raise
        self.collector_pk = collector_pk
        self.mission_pks = None
        self.mission_states = {}
        self.mc_url = None
        self.outgoing_pb_q = outgoing_pb_q
        self.outgoing_meters_q = outgoing_meters_q
        self.outgoing_loc_q = outgoing_loc_q
        self.itron_dir = itron_dir
        self.deduping_cooldown = deduping_cooldown
        self.username = username
        self.password = password
        self.comm_service_db = comm_service_db
        self.request_timeout = request_timeout

        # We still retain this for decrypting incoming Received Messages
        self.secret_key = secret_key
        self.tskdf_length = tskdf_length
        self.encryptor = Encryptor()

    def register_mc_url(self, mc_url):
        """Set the MC base URL that API request URLs are built from."""
        self.mc_url = mc_url

    def _wait_before_retry(self, url):
        """Log that ``url`` gave no usable response, then sleep 5 s before retrying."""
        logger.warning(
            "We did not receive a response from %s. Retrying in 5s...", url
        )
        time.sleep(5)

    def _get_json_with_retry(self, url):
        """GET ``url`` with basic auth until it answers HTTP 200; return the decoded JSON.

        Non-200 responses, connection errors and timeouts are logged and retried
        every 5 seconds. Other request errors (e.g. a malformed URL) propagate.
        """
        auth = HTTPBasicAuth(self.username, self.password)
        while True:
            try:
                response = requests.get(url, auth=auth, timeout=self.request_timeout)
            except (requests.ConnectionError, requests.Timeout) as exc:
                logger.warning("Request to %s failed: %s", url, exc)
            else:
                if response.status_code == 200:
                    return response.json()
            self._wait_before_retry(url)

    def get_mission_pk(self):
        """Fetch this collector's mission pks from MC into ``self.mission_pks``.

        Blocks until MC returns a non-null ``"missions"`` value. Requires
        :meth:`register_mc_url` to have been called first; otherwise the URL
        concatenation raises ``TypeError``.
        """
        mission_url = (
            self.mc_url
            + "/api/v2/payload/"
            + str(self.collector_pk)
            + "/request_missions/"
        )

        while self.mission_pks is None:
            missions = self._get_json_with_retry(mission_url)["missions"]
            if missions is None:
                # Without a pause a null list would re-request in a tight loop.
                self._wait_before_retry(mission_url)
            self.mission_pks = missions

    def initialize_device_sync(self):
        """Load the initial device states for every mission in ``self.mission_pks``.

        For each mission, stores the list of GeoJSON feature ``properties`` dicts
        under ``get_mission_key(mission_pk)`` in ``self.mission_states``. Missions
        already present are skipped. Blocks, retrying every 5 seconds, until each
        mission has been loaded. Requires :meth:`get_mission_pk` to have run.
        """
        for mission_pk in self.mission_pks:
            mission_key = get_mission_key(mission_pk)
            device_sync_url = (
                self.mc_url + "/api/v2/mission_status?mission_pk=" + str(mission_pk)
            )

            # Check per mission: testing "is mission_states empty" would stop
            # syncing after the first mission has been stored.
            while mission_key not in self.mission_states:
                response = self._get_json_with_retry(device_sync_url)
                # Example initialization response
                # {
                #     "data": '{"type": "FeatureCollection", "features": [{"geometry": {"type": "Point", "coordinates": [-114.3741464, 51.1027555]}, "properties": {"device_id": 13202, "device_state": "unheard", "mission_pk": 1, "read_at": null}}, {"geometry": {"type": "Point", "coordinates": [-114.3741464, 51.1027555]}, "properties": {"device_id": 13201, "device_state": "unheard", "mission_pk": 1, "read_at": null}}]}',
                #     "status": 200,
                # }
                raw_data = response["data"]
                # "data" is a JSON-encoded string; it may be "null" (or a JSON
                # null) when there is no state to report yet.
                data = None if raw_data is None else json.loads(raw_data)
                if data is None:
                    self._wait_before_retry(device_sync_url)
                    continue
                self.mission_states[mission_key] = [
                    feature["properties"] for feature in data["features"]
                ]
                logger.debug("Initialized sync for mission %s", mission_pk)

    def dispatch_payload_handler(self, message, collector_id):
        """Call every registered handler whose payload field is populated.

        ``message.payload`` must already hold the decrypted serialized
        ``Payload``. An empty payload goes to ``dispatch["empty_handler"]``
        (which must be registered). Otherwise each handler whose name matches a
        non-empty ``Payload`` field is called as
        ``handler(message, payload, collector_id, self)``.
        """
        payload = message_pb2.Payload()
        payload.ParseFromString(message.payload)
        if payload.ByteSize() == 0:
            dispatch["empty_handler"](message, None, collector_id, self)
            return None

        for name, handler in dispatch.items():
            field = getattr(payload, name, None)
            if field is not None and field.ByteSize() != 0:
                handler(message, payload, collector_id, self)
        return None

    def send_message(self, data):
        """Send raw bytes to the connected UDP server."""
        self.sock.sendall(data)

    def send_empty_message(self, collector_id, sequence, tskdf, cipher, mac):
        """
        Build, encrypt, sign and send a message with no payload.

        ARGS:
            collector_id (int)
            sequence (int)
            tskdf: key material passed through to ``encrypt_and_sign_message``
            cipher: Message.AES_256_CTR_TSKDF_HMAC_SHA2_512 | Message.AES_256_CTR_TSKDF_HMAC_SHA3_512
            mac: Message.HMAC_SHA2_256_KDF_HMAC_SHA2_512 | Message.HMAC_SHA3_256_KDF_HMAC_SHA3_512

        RAISES:
            ValueError: if ``cipher`` is not one of the values above.
        """
        # NOTE(review): the signature function is chosen from `cipher`, not
        # `mac`; confirm this is intended when cipher and mac families differ.
        signers = {
            message_pb2.Message.AES_256_CTR_TSKDF_HMAC_SHA2_512: create_sha2_signature,
            message_pb2.Message.AES_256_CTR_TSKDF_HMAC_SHA3_512: create_sha3_signature,
        }
        try:
            create_signature = signers[cipher]
        except KeyError:
            raise ValueError(f"Unsupported cipher: {cipher!r}") from None

        message = construct_message(collector_id, sequence, cipher, mac)
        data = encrypt_and_sign_message(
            tskdf, message, None, self.encryptor, create_signature, cipher, mac
        )
        self.send_message(data)

    def run(self):
        """Receive datagrams until the socket is closed, dispatching valid messages.

        Each datagram is parsed as a ``Message``. Non-empty messages with a valid
        signature and a successfully decrypted payload are handed to
        :meth:`dispatch_payload_handler`. Malformed datagrams, receive errors and
        handler exceptions are logged and do not stop the loop.
        """
        from google.protobuf.message import DecodeError

        # fileno() is -1 once the socket has been closed (e.g. by stop()).
        while self.sock.fileno() != -1:
            logger.debug("Waiting to receive")
            try:
                data = self.sock.recv(4096)
            except OSError as exc:
                if self.sock.fileno() == -1:
                    break
                # e.g. ConnectionRefusedError reported on a connected UDP socket.
                logger.warning("Receive from %s failed: %s", self.server_address, exc)
                continue

            try:
                message = message_pb2.Message.FromString(data)
            except DecodeError:
                # Datagrams come from the network; a malformed one must not
                # kill the receiver thread.
                logger.warning("Discarding malformed datagram (%d bytes)", len(data))
                continue

            if message.ByteSize() == 0:
                continue
            if not is_valid_signature(message, self.secret_key, self.tskdf_length):
                logger.debug("Discarding message with invalid signature")
                continue
            decrypted_payload = decrypt_payload(
                message, self.secret_key, self.tskdf_length, self.encryptor
            )
            if decrypted_payload is None:
                continue

            message.payload = decrypted_payload
            try:
                self.dispatch_payload_handler(message, message.collector)
            except Exception:
                # Thread top-level boundary: log the handler failure and keep receiving.
                logger.exception(
                    "Handler failed for message from collector %s", message.collector
                )

    def stop(self):
        """Close the socket so that :meth:`run` exits its receive loop."""
        try:
            # Best-effort attempt to wake a thread blocked in recv(); closing
            # alone may not interrupt it. May fail if already closed/unconnected.
            self.sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        self.sock.close()

