import numpy as np
import importlib
import sys
import logging

# Module-level logger: reports filter-loading failures and which filter factory
# was selected for a source.
logger = logging.getLogger(__name__)

# In general, an input filter is specified as a dotted string
# ("module.factory") referencing a factory function. The factory is called
# with a `source` parameter and must return a callable that accepts the number
# of samples to read (`blocksize`). Each call to that callable should return
# `blocksize` magnitude values (the example filters below return fewer once
# the input is exhausted).


def get_input_filter(filter_spec: str, source: str):
    """Resolve *filter_spec* to a factory and build an input filter for *source*.

    Args:
        filter_spec: Dotted path such as ``"package.module.factory"``.
            Everything before the last ``"."`` is imported as a module and the
            final component is looked up on that module.
        source: Passed unchanged to the factory.

    Returns:
        Whatever the factory returns (expected to be the filter callable), or
        ``None`` after logging an error when the spec is malformed, the module
        cannot be found, or the module has no such attribute.

    Exceptions raised by the factory itself propagate to the caller.
    """
    module_path, dot, factory_name = filter_spec.rpartition(".")
    if not (dot and module_path and factory_name):
        logger.error(
            "Invalid filter spec %r: expected 'module.factory'", filter_spec
        )
        return None

    try:
        module = importlib.import_module(module_path)
    except ModuleNotFoundError:
        logger.error("Unable to load %r", filter_spec)
        return None

    factory = getattr(module, factory_name, None)
    if factory is None:
        logger.error("Unable to load %r", filter_spec)
        return None

    # Invoked outside any try so errors raised by the factory (including a
    # ModuleNotFoundError from its own imports) are not misreported as a bad
    # filter spec.
    return factory(source)


def file_source(source: str):
    """Return a readable binary stream for *source*.

    The special value ``"-"`` selects standard input; any other value is
    treated as a path and opened in ``"rb"`` mode. The caller owns the
    returned stream.
    """
    if source != "-":
        return open(source, "rb")
    # stdin is a text wrapper; its `.buffer` attribute yields the raw bytes
    return sys.stdin.buffer


class NoMatch(ValueError):
    """Raised when no registered input filter matches a source's extension."""


class SourceDetect:
    """Registry that picks an input-filter factory from a source's extension.

    Factories are added with :meth:`register`. Calling the instance with a
    source path looks up the text after the last ``"."`` and invokes the
    matching factory with that path.
    """

    def __init__(self):
        # tag (file extension) -> factory(source) returning a filter callable
        self.registry = {}

    def register(self, tag: str = None):
        """Return a decorator that registers a factory under *tag*.

        When *tag* is None the factory's ``__name__`` is used as the tag. The
        decorator returns the factory unchanged. The key is computed per
        decorated factory, so a decorator returned by ``register()`` can be
        reused for several factories without them overwriting each other.
        """

        def decorator(factory):
            key = factory.__name__ if tag is None else tag
            self.registry[key] = factory
            return factory

        return decorator

    def __call__(self, source: str):
        """Build the input filter for *source* based on its extension.

        Raises:
            NoMatch: if *source* contains no ``"."`` or no factory is
                registered for its extension.
        """
        _, dot, ext = source.rpartition(".")
        if not dot:
            raise NoMatch(
                f"Unable to find filter matching {source!r}: no file extension"
            )
        try:
            factory = self.registry[ext]
        except KeyError:
            raise NoMatch(f"Unable to find filter matching {ext!r}") from None
        logger.info("Loading samples using %r", factory.__qualname__)
        # Called outside the try so a KeyError raised by the factory itself is
        # not misreported as a missing filter.
        return factory(source)


# Shared registry instance: the factories below register themselves here under
# their function name, which is matched against a source's file extension
# (e.g. "capture.cu8" selects `cu8`).
source_detect = SourceDetect()

# Two example input filters for interleaved I/Q sample files are defined below.


@source_detect.register()
def cu8(source: str):
    """Return a filter that reads interleaved unsigned 8-bit I/Q samples.

    *source* is a path, or ``"-"`` for standard input (see ``file_source``).
    The returned callable reads up to ``samples`` I/Q pairs and returns their
    magnitudes ``sqrt(I**2 + Q**2)`` as a float array. Fewer values are
    returned once the input is exhausted.
    """
    datafile = file_source(source)
    sample_dtype = np.dtype([("I", np.uint8), ("Q", np.uint8)])
    itemsize = sample_dtype.itemsize

    # NOTE(review): I and Q are used as raw unsigned values with no DC offset
    # removed; if the data comes from a capture centred on ~127.5, confirm
    # this is intended.
    def read_magnitudes(samples: int):
        raw = datafile.read(samples * itemsize)
        # A short read at end of input may stop mid-sample; np.frombuffer
        # raises on a length that is not a multiple of the item size, so drop
        # the incomplete trailing pair instead of crashing.
        usable = len(raw) - len(raw) % itemsize
        block = np.frombuffer(raw[:usable], sample_dtype)
        # Widen before squaring so the sum cannot overflow uint8.
        i = block["I"].astype(int)
        q = block["Q"].astype(int)
        return np.sqrt(i**2 + q**2)

    return read_magnitudes


@source_detect.register()
def cs16(source: str):
    """Return a filter that reads interleaved signed 16-bit I/Q samples.

    *source* is a path, or ``"-"`` for standard input (see ``file_source``).
    The returned callable reads up to ``samples`` I/Q pairs and returns their
    magnitudes ``sqrt(I**2 + Q**2)`` as a float array. Fewer values are
    returned once the input is exhausted.
    """
    datafile = file_source(source)
    # Byte order pinned to little-endian so results do not depend on the host.
    # NOTE(review): this assumes the usual little-endian cs16 layout; confirm
    # for your data source. Identical to the previous native dtype on
    # little-endian machines.
    sample_dtype = np.dtype([("I", "<i2"), ("Q", "<i2")])
    itemsize = sample_dtype.itemsize

    def read_magnitudes(samples: int):
        raw = datafile.read(samples * itemsize)
        # A short read at end of input may stop mid-sample; np.frombuffer
        # raises on a length that is not a multiple of the item size, so drop
        # the incomplete trailing pair instead of crashing.
        usable = len(raw) - len(raw) % itemsize
        block = np.frombuffer(raw[:usable], sample_dtype)
        # Widen before squaring: (-32768)**2 overflows int16.
        i = block["I"].astype(int)
        q = block["Q"].astype(int)
        return np.sqrt(i**2 + q**2)

    return read_magnitudes
