import struct
import time
import logging

import matplotlib.pyplot as plt
import numpy as np
from numba import jit

import cleargrid.r900.galois as gf

from .exceptions import *

logger = logging.getLogger(__name__)


@jit(nopython=True)
def trim_buffer(buffer, start=0, end=0, base=0):
    """Clip a block of samples to the requested ``[start, end]`` sample window.

    ``base`` is the absolute sample index of ``buffer[0]``.  A ``start`` or
    ``end`` of ``0`` or ``None`` means that bound is not set.

    Returns a ``(buffer, base)`` pair:

    * ``(None, -1)``   -- the block begins after ``end``.
    * ``(None, None)`` -- the block ends before ``start``.
    * ``(buffer[start - base:], start)`` -- ``start`` falls inside the block,
      so the leading samples are dropped.
    * ``(buffer, base)`` -- the block is returned untouched.
    """
    buffer_len = len(buffer)
    start = 0 if start is None else start
    end = 0 if end is None else end

    # The end bound is tested on its own: previously it hung off an ``elif``
    # of the start test and was silently ignored whenever ``start`` was set.
    if end and base > end:
        return None, -1

    if start:
        if (base + buffer_len) < start:
            # Whole block lies before the window.
            return None, None
        if base < start:
            # Window opens part-way through this block.
            return buffer[start - base :], start
    return buffer, base


# Import-time call with throwaway arguments.  trim_buffer is @jit-decorated,
# so this presumably forces numba to compile it for this argument signature
# up front rather than on the first real block -- TODO confirm the real call
# sites use the same argument types.
trim_buffer(np.array([1]), 0, 0, 0)


def symbolize(buffer, chip_length):
    """Classify every sample position of ``buffer`` as one of six symbols.

    For each position a window of four consecutive chips (``chip_length``
    samples each) is correlated, via differences of the running sum, against
    three sign patterns::

        column 0:  + + - -
        column 1:  + - + -
        column 2:  + - - +

    The column with the largest absolute response gives the symbol (0..2);
    3 is added when that response is positive, so results lie in 0..5.

    Args:
        buffer: 1-D array-like of samples.
        chip_length: number of samples per chip.

    Returns:
        Integer array of length ``max(len(buffer) - 4 * chip_length, 0)``.
    """
    samples = np.asarray(buffer)
    if np.issubdtype(samples.dtype, np.unsignedinteger):
        # An unsigned running sum cannot hold the negative filter outputs:
        # they would wrap to huge positive numbers and corrupt both the
        # argmax and the sign test below, so accumulate in a signed type.
        running = np.cumsum(samples, dtype=np.int64)
    else:
        running = np.cumsum(samples)

    # Running sum sampled at the five chip boundaries of each window.  The
    # three inner boundaries are doubled because each of them closes one chip
    # and opens the next.
    s0 = running[0 : -chip_length * 4]
    s1 = 2 * running[chip_length : -chip_length * 3]
    s2 = 2 * running[chip_length * 2 : -chip_length * 2]
    s3 = 2 * running[chip_length * 3 : -chip_length]
    s4 = running[chip_length * 4 :]

    filtered = np.stack(
        [s2 - s4 - s0, s1 - s2 + s3 - s4 - s0, s1 - s3 + s4 - s0], axis=1
    )

    strongest = np.argmax(np.abs(filtered), axis=1)
    # Signed response of the winning column for every position.
    response = np.take_along_axis(
        filtered, strongest[:, np.newaxis], axis=1
    )[:, 0]
    return strongest + np.where(response > 0, 3, 0)


# Import-time call on a one-sample buffer (yields an empty result).
# NOTE(review): symbolize is not @jit-decorated in this file, so unlike the
# neighbouring warm-up calls this compiles nothing -- possibly a leftover;
# confirm before removing.
symbolize(np.array([1]), 1)


@jit(nopython=True)
def array_equal(a, b):
    """Return True when ``a`` and ``b`` are element-wise equal.

    The first and last elements are compared before handing over to
    ``np.array_equal``, so a mismatch at either end is rejected without the
    full comparison.  Both arrays must be non-empty.
    """
    ends_match = a[0] == b[0] and a[-1] == b[-1]
    if not ends_match:
        return False
    return np.array_equal(a, b)


# Import-time call with throwaway arguments; array_equal is @jit-decorated,
# so this presumably triggers numba compilation for this signature up front.
array_equal(np.array([1]), np.array([1]))


@jit(nopython=True)
def search(preamble, buffer, chip_length, offset=0):
    """Find the first position in ``buffer`` where ``preamble`` occurs.

    Starting at ``offset``, each candidate position is sampled with a stride
    of ``4 * chip_length`` and its first ``len(preamble)`` values are compared
    with ``preamble``.  Positions closer to the end of ``buffer`` than one
    full preamble span are not examined.

    Returns the matching index, or ``-1`` when there is no match.
    """
    stride = chip_length * 4
    wanted = len(preamble)
    last_start = len(buffer) - (wanted * stride)

    position = offset
    while position < last_start:
        window = buffer[position::stride][:wanted]
        if array_equal(preamble, window):
            return position
        position += 1
    return -1


# Import-time call with throwaway arguments; search is @jit-decorated, so
# this presumably triggers numba compilation for this signature up front.
search(np.array([1]), np.array([1]), 1)

from .input_filters import get_input_filter


class DataStream:
    """Block-wise scanner that extracts R900 messages from a sample stream.

    Samples are pulled from an input filter one block at a time and kept in a
    sliding buffer holding the previous and the current block.  Each buffer
    is converted to symbols, searched for the preamble, and every hit is
    decoded into payload bytes and checked with a Galois-field syndrome.
    """

    # Symbol rate; together with the sample rate it fixes the chip length.
    BAUD = 32768
    # Default sample rate of the input (72 samples per chip at BAUD).
    SAMPLE_RATE = 2359296
    # Default number of samples read per block (1/100 of SAMPLE_RATE).
    BLOCK_SIZE = SAMPLE_RATE // 100
    # Default preamble, one symbol digit per character.
    PREAMBLE = "1111111145222521"

    # Default number of 5-bit payload values decoded by ``bytes``.
    PAYLOAD_SYMBOLS = 21

    def __init__(
        self,
        filter_spec="cleargrid.r900.input_filter.source_detect",
        source=None,
        sample_rate=SAMPLE_RATE,
        blocksize=BLOCK_SIZE,
        preamble=PREAMBLE,
        start_sample=None,
        end_sample=None,
        chip_length=None,
    ):
        """Set up the input filter and the sliding sample buffer.

        Args:
            filter_spec: passed with ``source`` to ``get_input_filter``.
            source: passed to ``get_input_filter``.
            sample_rate: sample rate used for timing and the chip length.
            blocksize: number of samples requested from the input per read.
            preamble: default preamble, a string of symbol digits.
            start_sample: absolute sample index to start decoding at, if any.
            end_sample: absolute sample index to stop decoding at, if any.
            chip_length: samples per chip; defaults to
                ``round(sample_rate / BAUD)``.
        """
        self.input = get_input_filter(filter_spec, source)
        self.sample_rate = sample_rate
        self.blocksize = blocksize
        self.start_sample = start_sample
        self.end_sample = end_sample
        self.chip_length = chip_length or int(
            np.round(self.sample_rate / self.BAUD)
        )
        self.field = gf.Field(32, 37, 2)
        # Two blocks of zeros; the first read shifts one of them out.
        self._mag = np.full((self.blocksize * 2,), 0, dtype="I")
        self._cumsum = None
        self._symbols = None
        # Absolute index of _mag[0].  Starting two blocks "before zero" makes
        # the first sample actually read land on absolute index 0.
        self._base_index = -2 * self.blocksize
        self._block_start = 0
        self._preamble = preamble

    def _trim(self):
        """Clip the buffer to the start/end window.

        Raises:
            NoData: the buffer lies entirely before ``start_sample``.
            DataEnd: the buffer begins after ``end_sample``.
        """
        buffer, base = trim_buffer(
            self._mag,
            start=self.start_sample,
            end=self.end_sample,
            base=self._base_index,
        )
        if buffer is not None:
            self._mag, self._base_index = buffer, base
        elif base is None:
            raise NoData()
        elif base == -1:
            raise DataEnd()

    def process(self, preamble=None):
        """Read blocks until the input is exhausted and yield decoded messages.

        Args:
            preamble: string of symbol digits to search for; defaults to the
                preamble given at construction.

        Yields:
            ``(absolute_sample_index, preamble_array, payload_bytes)`` for
            every preamble hit whose payload decodes successfully.
        """
        if not preamble:
            preamble = self._preamble
        preamble = np.array([int(x) for x in preamble])

        while self._read():
            start = time.monotonic()
            try:
                self._trim()
            except NoData:
                continue
            except DataEnd:
                break
            self._symbolize()
            yield from self._search(preamble)
            elapsed = time.monotonic() - start
            # Fraction of the block's real-time duration spent processing it.
            pct = elapsed / (self.blocksize / self.sample_rate)
            logger.debug(
                f"Block complete in {elapsed * 1e6:.2f}us [{pct:0.2%} realtime]"
            )

    def plot(self, filtered):
        """Debug helper: plot the buffer, its running sum and filter outputs.

        Args:
            filtered: 2-D array with three columns and
                ``len(self._mag) - 4 * self.chip_length`` rows.
        """
        _sample_no = np.arange(len(self._mag)) + self._base_index
        # Nothing in this class assigns _cumsum after construction, so derive
        # it from the buffer when it has not been provided.
        cumsum = self._cumsum
        if cumsum is None:
            cumsum = np.cumsum(self._mag)
        fig = plt.figure(figsize=(10, 20), dpi=200)

        # Integer position codes: Matplotlib no longer accepts the string form.
        ax = fig.add_subplot(211)
        ax.plot(_sample_no, self._mag, linewidth=0.2)
        ax.axvline(self._block_start, c="r", linewidth=0.2)
        ax = ax.twinx()
        ax.plot(_sample_no, cumsum, "g", linewidth=0.2)

        # The filter outputs are 4 chips shorter than the buffer.
        data_end = -self.chip_length * 4
        ax = fig.add_subplot(212)
        for column in range(3):
            ax.plot(
                _sample_no[:data_end], filtered[:, column], linewidth=0.2,
            )
        ax.axvline(self._block_start, c="r", linewidth=0.2)
        ax = ax.twinx()
        ax.scatter(_sample_no[:data_end], self._symbols, c="r", s=0.2)

        plt.show()

    def _symbolize(self):
        """Convert the current buffer into per-sample symbols."""
        self._symbols = symbolize(self._mag, self.chip_length)

    def bytes(self, symbols, length=PAYLOAD_SYMBOLS):
        """Decode ``symbols`` into payload bytes and verify the checksum.

        Consecutive symbol pairs are combined as ``first * 6 + second`` into
        ``length`` values, each of which must fit in five bits.  The values
        are concatenated as 5-bit groups and split into bytes (a trailing
        partial group becomes its own byte).

        Args:
            symbols: numpy array holding at least ``2 * length`` symbols.
            length: number of 5-bit values to decode.

        Returns:
            The decoded payload as ``bytes``.

        Raises:
            R900Error: fewer than ``2 * length`` symbols were supplied.
            BadSymbol: a combined value does not fit in five bits.
            ChecksumError: the syndrome is non-zero or could not be computed.
        """
        symbol_length = length * 2
        # NOTE(review): the handler below only logs BadSymbol/ChecksumError if
        # they derive from R900Error -- confirm in .exceptions.
        try:
            if len(symbols) < symbol_length:
                raise R900Error()

            interim = (
                symbols[0:symbol_length:2] * 6 + symbols[1:symbol_length:2]
            )
            values = interim[:length]

            for symbol in values:
                # Valid values are 0..31: they are emitted as 5 bits and fed
                # to a 32-element field.  (32 used to slip through here.)
                if symbol > 31:
                    raise BadSymbol("Encountered bad symbol")

            bits = "".join([f"{symbol:05b}" for symbol in values])
            _bytes = b"".join(
                [
                    int(bits[b : b + 8], 2).to_bytes(1, "big")
                    for b in range(0, len(bits), 8)
                ]
            )

            # 31-entry codeword: leading values at the front, the last five
            # values in the tail, zeros in between.
            rsbuf = [0] * 31
            rsbuf[: length - 5] = interim[: length - 5]
            rsbuf[26:] = interim[length - 5 : length]
            try:
                syndromes = self.field.syndrome(rsbuf, 5, 29)
            except Exception as e:
                logger.debug("Error while calculating checksum", exc_info=e)
                raise ChecksumError() from e

            if any(syndromes):
                raise ChecksumError()

            return _bytes
        except R900Error as e:
            logger.debug("Encountered Error extracting bytes", exc_info=e)
            raise

    def _search(self, preamble):
        """Yield every decodable message found in the current symbol buffer.

        Yields:
            ``(absolute_sample_index, preamble, payload_bytes)`` per hit.
            Hits whose payload raises ``R900Error`` are skipped.
        """
        stride = self.chip_length * 4
        ndx = 0
        while True:
            ndx = search(preamble, self._symbols, self.chip_length, offset=ndx)
            if ndx < 0:
                break
            try:
                # One symbol per stride, with the preamble itself dropped.
                payload = self._symbols[ndx::stride][len(preamble) :]
                yield self._base_index + ndx, preamble, self.bytes(payload)
            except R900Error:
                logger.debug("Unable to extract signal")
            except Exception as e:
                logger.debug("Error while extrating bytes", exc_info=e)
                raise
            # Resume the scan just past this hit.
            ndx += 1

    def timestamp(self, offset):
        """Return the time of buffer position ``offset`` scaled by 1e6.

        The value is ``(base_index + offset) / sample_rate * 1e6``, i.e.
        microseconds when ``sample_rate`` is in samples per second.
        """
        return ((self._base_index + offset) / self.sample_rate) * 1e6

    def _read(self):
        """Pull one block from the input and slide the buffer forward.

        The last ``blocksize`` samples of the old buffer are kept in front of
        the new block.

        Returns:
            Number of samples read; ``0`` ends the loop in ``process``.
        """
        mag = self.input(self.blocksize)

        prev = self._mag[-self.blocksize :]
        # Number of samples dropped from the front of the old buffer.
        step = len(self._mag) - len(prev)

        self._mag = np.concatenate((prev, mag), axis=None)

        self._base_index += step
        self._block_start = self._base_index + len(prev)
        logger.debug(self._base_index)
        return len(mag)
