import hmac
import logging
import struct

from google.protobuf.message import DecodeError

from commlib.crypto_lob import (
    create_sha2_signature,
    create_sha3_signature,
    create_signing_string,
    tskdf,
    Encryptor,
)
from commlib.proto import message_pb2


# Module-level logger named after this module; used for the signature and
# decryption diagnostics emitted by the functions in this file.
logger = logging.getLogger(__name__)


class InvalidSignature(Exception):
    """Raised when a message's recomputed signature does not match the one it carries."""


def _signatures_equal(expected, received):
    """Compare a recomputed signature with a received one in constant time.

    ``expected`` is the value returned by ``create_sha2_signature`` /
    ``create_sha3_signature``; it is encoded to UTF-8 bytes when it is a
    ``str``. ``received`` is the raw ``message.signature`` value (the caller
    previously ``.decode``d it, so it is bytes). Comparing bytes avoids a
    ``UnicodeDecodeError`` on a malformed signature, and
    ``hmac.compare_digest`` avoids leaking how many leading characters
    matched through timing.
    """
    if isinstance(expected, str):
        expected = expected.encode("utf-8")
    return hmac.compare_digest(expected, received)


def is_valid_signature(message, secret_key, tskdf_length):
    """Verify the signature carried by *message*.

    The scheme is selected from ``message.mac``. For the two HMAC schemes, a
    key is derived with ``tskdf.TSKDF(secret_key, tskdf_length, <hash>)
    .derive_hmac()``. The expected signature is then recomputed over
    ``create_signing_string(message)`` and compared with
    ``message.signature``.

    Args:
        message: protobuf ``message_pb2.Message`` with ``mac`` and
            ``signature`` fields.
        secret_key: shared secret forwarded to ``tskdf.TSKDF``.
        tskdf_length: length argument forwarded to ``tskdf.TSKDF``.

    Returns:
        True if the message is unsigned or its signature matches.
        False if ``message.mac`` is not a recognised scheme.

    Raises:
        InvalidSignature: if the recomputed signature does not match.
    """
    mac = message.mac
    if mac == message_pb2.Message.HMAC_SHA2_256_KDF_HMAC_SHA2_512:
        label, hash_name, sign = "SHA2 SIGNED", "SHA512", create_sha2_signature
    elif mac == message_pb2.Message.HMAC_SHA3_256_KDF_HMAC_SHA3_512:
        label, hash_name, sign = "SHA3 SIGNED", "SHA3_512", create_sha3_signature
    elif mac == message_pb2.Message.UNSIGNED:
        # NOTE(review): accepting UNSIGNED lets a sender (or an attacker who
        # can rewrite ``mac``) skip verification entirely; confirm callers
        # enforce a required scheme where that matters.
        logger.info("UNSIGNED")
        return True
    else:
        # Previously fell through and returned None; fail closed explicitly.
        logger.warning("Unrecognised signature scheme %r; rejecting", mac)
        return False

    logger.info(label)
    kdf = tskdf.TSKDF(secret_key, tskdf_length, hash_name)
    signing_string = create_signing_string(message)
    derived_key = kdf.derive_hmac()
    expected = sign(derived_key, signing_string)
    if not _signatures_equal(expected, message.signature):
        raise InvalidSignature()
    return True


def _try_decrypt(kdf, timestep, message, encryptor):
    """Decrypt ``message.payload`` with keys derived for *timestep*.

    Returns the plaintext bytes if they parse as a ``message_pb2.Payload``,
    otherwise None.
    """
    kdf_key, kdf_iv = kdf.derive(timestep, message.sequence)
    decrypted_payload = encryptor.decrypt_payload(kdf_key, kdf_iv, message.payload)
    try:
        # Parsing is only a sanity check that the key was right.
        # NOTE(review): AES-CTR is unauthenticated, so plaintext decrypted with
        # a wrong key may still parse; this is a heuristic, not an integrity
        # check.
        message_pb2.Payload().ParseFromString(decrypted_payload)
    except DecodeError:
        return None
    return decrypted_payload


def decrypt_payload(message, secret_key, tskdf_length, encryptor):
    """Return the decrypted payload bytes of *message*.

    The cipher is selected from ``message.cipher``:

    * ``AES_256_CTR_TSKDF_HMAC_SHA2_512`` / ``..._SHA3_512``: keys are derived
      with ``tskdf.TSKDF`` (``"SHA512"`` / ``"SHA3_512"``) for the current
      timestep. If the result does not parse as a ``Payload``, the previous
      timestep (``timestep - 1``) is tried once.
    * ``NULL``: ``message.payload`` is returned unchanged.

    Args:
        message: protobuf ``message_pb2.Message`` with ``cipher``,
            ``sequence`` and ``payload`` fields.
        secret_key: shared secret forwarded to ``tskdf.TSKDF``.
        tskdf_length: length argument forwarded to ``tskdf.TSKDF``.
        encryptor: object providing ``decrypt_payload(key, iv, ciphertext)``.

    Returns:
        The plaintext payload bytes, or None if neither timestep yields a
        parseable payload or the cipher is not recognised.
    """
    cipher = message.cipher
    if cipher == message_pb2.Message.AES_256_CTR_TSKDF_HMAC_SHA2_512:
        label, hash_name = "SHA2 ENCRYPTED", "SHA512"
    elif cipher == message_pb2.Message.AES_256_CTR_TSKDF_HMAC_SHA3_512:
        label, hash_name = "SHA3 ENCRYPTED", "SHA3_512"
    elif cipher == message_pb2.Message.NULL:
        logger.info("MESSAGE NOT ENCRYPTED")
        return message.payload
    else:
        logger.warning("Unrecognised cipher %r; cannot decrypt", cipher)
        return None

    logger.info(label)
    kdf = tskdf.TSKDF(secret_key, tskdf_length, hash_name)
    timestep = kdf.derive_timestep()

    plaintext = _try_decrypt(kdf, timestep, message, encryptor)
    if plaintext is not None:
        return plaintext

    logger.info(
        "It was not decrypted successfully. Try using a Previous Timestep"
    )
    # The sender may have encrypted just before a timestep boundary.
    # (The old SHA3 path mistakenly retried ``timestep`` here.)
    return _try_decrypt(kdf, timestep - 1, message, encryptor)
