import hmac
import unittest
import time
from binascii import unhexlify
from math import ceil

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.kbkdf import KBKDFHMAC, CounterLocation, Mode


class TSKDF(object):
    """Time-step key derivation built on HMAC-based KBKDF in counter mode.

    All outputs are derived from a single key derivation key. Each kind of
    output (per-message key, per-message IV, HMAC key, encryption key) uses
    its own label prefix, so the derived values are domain-separated even
    though they share the same input key.
    """

    # Length in bytes of the IV returned by ``derive``.
    IV_LENGTH = 16

    def __init__(self, key, length, digest, label=b"TSKDF", backend=None):
        """Store the derivation parameters.

        key:     the key derivation key; a ``str`` is encoded as UTF-8,
                 ``bytes``/``bytearray``/``memoryview`` are used as-is.
        length:  length in bytes of every derived key (not of the IV).
        digest:  name of a hash class in
                 ``cryptography.hazmat.primitives.hashes``,
                 e.g. "SHA3_512" or "SHA512".
        label:   bytes appended to the purpose prefix of each KDF label.
        backend: optional cryptography backend; the default backend is
                 used when this is falsy.

        Raises ``TypeError`` if ``key`` is neither text nor bytes-like, and
        ``AttributeError`` if ``digest`` does not name an attribute of
        ``hashes``.
        """
        if isinstance(key, str):
            key_bytes = key.encode("utf-8")
        elif isinstance(key, (bytes, bytearray, memoryview)):
            # Copy so later mutation of the caller's buffer cannot change
            # the key material held by this instance.
            key_bytes = bytes(key)
        else:
            raise TypeError(
                "key must be str or bytes-like, not {}".format(type(key).__name__)
            )
        self.key = key_bytes  # The Key Derivation Key
        self.digest = getattr(hashes, digest)()
        self.label = label
        self.length = length
        # Cache for derive_hmac(); it is NOT invalidated if key, label,
        # digest or length are reassigned afterwards.
        self.hmac_key = None
        self.backend = backend if backend else default_backend()

    @staticmethod
    def derive_timestep():
        """Return the current time as whole minutes since the epoch."""
        seconds = time.time()
        return int(seconds / 60)

    def _build_kdf(self, length, purpose, context):
        """Return a new single-use KBKDFHMAC for ``purpose``.

        A new object is built for every derivation because a KBKDFHMAC
        instance may only be used for one ``derive`` call.
        """
        # https://cryptography.io/en/latest/hazmat/primitives/key-derivation-functions/#cryptography.hazmat.primitives.kdf.kbkdf.KBKDFHMAC
        return KBKDFHMAC(
            algorithm=self.digest,
            mode=Mode.CounterMode,
            length=length,
            rlen=4,
            llen=4,
            location=CounterLocation.BeforeFixed,
            label=purpose + self.label,
            context=context,
            fixed=None,  # use label and context for fixed
            backend=self.backend,
        )

    def derive(self, timestep, sequence):
        """Derive a ``(key, iv)`` pair bound to ``timestep`` and ``sequence``.

        Both arguments must be integers; they are formatted into the KDF
        context. The key is ``self.length`` bytes long and the IV is
        ``IV_LENGTH`` bytes long.
        """
        context = "T:{:d} S:{:d}".format(timestep, sequence).encode("utf-8")
        kdf_iv = self._build_kdf(self.IV_LENGTH, b"iv for ", context)
        kdf_key = self._build_kdf(self.length, b"key for ", context)
        return kdf_key.derive(self.key), kdf_iv.derive(self.key)

    def derive_hmac(self):
        """Return the HMAC key, deriving it on first use and caching it."""
        # Check the cache first so repeated calls do not build a KDF object
        # that would be thrown away unused.
        if self.hmac_key is None:
            kdf = self._build_kdf(self.length, b"hmac for ", b"")
            self.hmac_key = kdf.derive(self.key)
        return self.hmac_key

    def derive_encryption(self):
        """Return the encryption key (derived afresh on every call)."""
        kdf = self._build_kdf(self.length, b"encryption for ", b"")
        return kdf.derive(self.key)
