import json
import logging
import socket

import arrow

from ..client import GPSClient, GPSWatcher
from ..objects import Error, Position, Velocity
from .exceptions import GPSDConnectionError


class GPSDWatcher(GPSWatcher):
    """Watch-mode session on a :class:`GPSDClient`.

    ``setup`` and ``teardown`` each send GPSD the ``w`` command, and
    ``is_valid`` decides which decoded objects are accepted.
    """

    logger = logging.getLogger("cleargrid.lib.gps.microhard.GPSDWatcher")

    def __init__(self, client, view=False, min_mode=0):
        """
        :param client: client used to send commands and read replies
        :param view: when True, accept any decoded object, not just
            ``TPVReport`` instances
        :param min_mode: reject objects whose ``mode`` attribute (0 when
            absent) is below this value
        """
        super().__init__(client)
        self.view = view
        self.min_mode = min_mode

    def setup(self):
        """Send the ``w`` command to start watch mode."""
        self.logger.debug("Starting watch mode")
        self.client.send(b"w")

    def is_valid(self, data):
        """Return True if ``data`` passes the ``view``/``min_mode`` filters."""
        if not (self.view or isinstance(data, TPVReport)):
            return False
        elif getattr(data, "mode", 0) < self.min_mode:
            return False
        return True

    def teardown(self):
        """Send ``w`` again, then discard everything still pending.

        Reads from the socket until it times out (or is closed), and also
        drops any partial data already held in the client's line buffer so
        the next ``receive()`` does not return stale watch-mode output.
        """
        self.client.send(b"w")
        # clear the buffers
        try:
            while True:
                data = self.client.raw_read()
                if not data:
                    # recv() returns b"" once the peer has closed the
                    # connection; it will never time out, so stop here.
                    break
                self.logger.debug("Discarding %r", data)
        except socket.timeout:
            pass
        # Bytes received earlier but not yet split into lines would
        # otherwise be parsed by the next receive() call.
        self.client.buffer = b""


class GPSDClient(GPSClient):
    """TCP client for GPSD's single-letter command protocol.

    Replies are CRLF-terminated lines of the form
    ``GPSD,<S>=<fields>,<S>=<fields>...``. Each ``<S>=`` phrase is turned
    into an object by the decoder registered for symbol ``<S>`` via
    :meth:`register`.
    """

    # Upper-case symbol -> callable receiving the phrase's whitespace-separated
    # fields as positional arguments. This is one shared mutable dict:
    # register() on a subclass without its own ``decoders`` adds to it too.
    decoders = {}
    logger = logging.getLogger("cleargrid.lib.gps.microhard.GPSDClient")
    host = "192.168.168.1"
    port = 2947
    watcher_class = GPSDWatcher

    @classmethod
    def register(cls, symbol):
        """Return a decorator registering its target as ``symbol``'s decoder.

        The symbol is stored upper-cased; the decorated object is returned
        unchanged.
        """

        def inner(obj):
            cls.decoders[symbol.upper()] = obj
            return obj

        return inner

    def __init__(self, host=None, port=None, **kwargs):
        """Connect to GPSD with a 1.5 s socket timeout.

        ``host``/``port`` are forwarded to ``GPSClient``; presumably ``None``
        falls back to the class-level defaults there (not visible here).

        :raises GPSDConnectionError: if the TCP connection cannot be made
        """
        super().__init__(host, port, **kwargs)
        self.logger.debug("Connecting to GPSD @ %s", self.address)
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self._socket.settimeout(1.5)
            self._socket.connect((self.host, self.port))
        except OSError as exc:
            self.logger.exception(
                "Error connecting to %s:%d", self.host, self.port
            )
            # Release the descriptor of a socket that never connected.
            self._socket.close()
            raise GPSDConnectionError() from exc
        # Bytes received but not yet split into complete lines.
        self.buffer = b""

    def send(self, data):
        """Send all of ``data`` and return the number of bytes sent."""
        # socket.send() may write only part of the buffer; sendall() keeps
        # going until everything is written (or raises).
        self._socket.sendall(data)
        return len(data)

    def raw_read(self):
        """Return up to 1024 bytes from the socket (b"" once it is closed)."""
        return self._socket.recv(1024)

    def receive(self):
        """Read until one CRLF-terminated line is available and parse it.

        :returns: the list produced by :meth:`_parse`
        :raises GPSDConnectionError: if GPSD closes the connection first
        :raises socket.timeout: if no data arrives within the socket timeout
        """
        while True:
            index = self.buffer.find(b"\r\n")
            if index >= 0:
                line = self.buffer[:index]
                self.buffer = self.buffer[index + 2 :]
                return self._parse(line.decode("UTF-8"))
            data = self.raw_read()
            if not data:
                # EOF: without this check the loop would spin forever.
                self.logger.error(
                    "Connection to %s closed by GPSD", self.address
                )
                raise GPSDConnectionError()
            self.buffer += data
            self.logger.debug(
                "Received %d bytes from %s", len(data), self.address
            )

    @classmethod
    def _parse(cls, data):
        """Split a ``GPSD,...`` reply line and decode each phrase.

        :raises ValueError: if the line does not start with the ``GPSD`` tag
        """
        tag, *phrases = data.split(",")
        if tag != "GPSD":
            cls.logger.error("Expected 'GPSD' received %r", tag)
            raise ValueError(data)

        return [cls._parse_phrase(phrase) for phrase in phrases]

    @classmethod
    def _parse_phrase(cls, data):
        """Decode one ``<S>=<fields>`` phrase.

        A phrase looks like ``'L=3 2.37 abcdefgijklmnopqrstuvwxyz'``: a
        symbol, ``=``, then whitespace-separated fields. When no decoder is
        registered for the symbol, ``(symbol, fields)`` is returned.

        :raises ValueError: if the second character is not ``=``
        """
        # len() guard: a short/empty phrase (e.g. from a trailing comma)
        # must raise ValueError, not IndexError.
        if len(data) < 2 or data[1] != "=":
            cls.logger.error(
                "Failed to find expected token '=', found %r", data[1:2]
            )
            raise ValueError(data)
        symbol = data[0].upper()
        parts = data[2:].split()
        if symbol in cls.decoders:
            return cls.decoders[symbol](*parts)
        cls.logger.warning("No decoder registered for symbol %r", symbol)
        return (symbol, parts)

    def version(self):
        """Send the ``l`` command and return the parsed reply."""
        with self.lock:
            self.send(b"l\n")
            data = self.receive()
        return data

    def watch(self, *, view=False, mode=0):
        """Return a watcher of type ``watcher_class`` for this client.

        Honouring ``watcher_class`` lets subclasses substitute their own
        watcher.
        """
        return self.watcher_class(self, view, mode)


@GPSDClient.register("x")
class GPSDObject(object):
    """Base class for decoded GPSD objects that carry a timestamp.

    The class itself is also registered as the decoder for ``X=`` phrases.
    """

    # Field value treated as "unknown"; TPVReport handles mode "?" the same way.
    _UNKNOWN = "?"

    def __init__(self, timestamp, **kwargs):
        """
        :param timestamp: seconds as a number or numeric string, or
            None / "?" when unknown. Extra keyword arguments are ignored.
        """
        self._timestamp = None
        self.timestamp = timestamp

    @property
    def timestamp(self):
        """The timestamp as an ``arrow`` object, or None when unknown."""
        return self._timestamp

    @timestamp.setter
    def timestamp(self, value):
        # Clear on None/"?" instead of silently keeping a stale value, and
        # avoid float("?") raising ValueError.
        if value is None or value == self._UNKNOWN:
            self._timestamp = None
        else:
            self._timestamp = arrow.get(float(value))

    def __repr__(self):
        fmt = "{self.timestamp}"
        return fmt.format(self=self)


@GPSDClient.register("o")
class TPVReport(GPSDObject):
    """Time/position/velocity report decoded from an ``O=`` phrase.

    Expects exactly 15 fields, in the order unpacked in ``__init__``.
    """

    # Named logger, matching the other classes in this module.
    logger = logging.getLogger("cleargrid.lib.gps.microhard.TPVReport")

    def __init__(self, *args):
        """
        :raises ValueError: if the number of fields is not 15
        """
        if len(args) != 15:
            msg = "Expected 15 arguments, received %d" % len(args)
            self.logger.error(msg)
            self.logger.debug("%r", args)
            raise ValueError(msg)

        (
            sentence,
            timestamp,
            time_error,
            latitude,
            longitude,
            altitude,
            h_error,
            v_error,
            track,
            speed,
            climb,
            t_error,
            s_error,
            c_error,
            mode,
        ) = args
        super().__init__(timestamp)
        self.sentence = sentence
        self.error = Error(
            time_error, h_error, v_error, t_error, s_error, c_error
        )
        self.position = Position(latitude, longitude, altitude)
        self.velocity = Velocity(track, speed, climb)
        # "?" marks an unknown mode; report it as 0.
        self.mode = 0 if mode == "?" else int(mode)

    def __repr__(self):
        fmt = "{self.timestamp} {self.position} {self.velocity}"
        return fmt.format(self=self)

    @property
    def json(self):
        """JSON string of timestamp, error, position, velocity and mode.

        ``timestamp`` is ISO-8601, or null when it is unknown.
        """
        timestamp = self.timestamp
        return json.dumps(
            {
                # Guard against None: .isoformat() would raise AttributeError.
                "timestamp": (
                    None if timestamp is None else timestamp.isoformat()
                ),
                "error": self.error.__dict__,
                "position": self.position.__dict__,
                "velocity": self.velocity.__dict__,
                "mode": self.mode,
            }
        )


@GPSDClient.register("y")
class Satellites(GPSDObject):
    """Satellite view decoded from a ``Y=`` phrase."""

    def __init__(self, sentence, timestamp, *satellites):
        """
        :param sentence: first field; kept as ``self.sentence``, as
            TPVReport does, rather than being discarded
        :param timestamp: second field, handled by GPSDObject
        :param satellites: remaining fields, kept as a tuple of strings
        """
        super().__init__(timestamp)
        self.sentence = sentence
        self.satellites = satellites


@GPSDClient.register("w")
def watch(value):
    """Decode a ``W=`` phrase: True exactly when its field is ``"1"``."""
    enabled = "1" == value
    return enabled


if __name__ == "__main__":
    import itertools

    # Number of reports printed for each watch configuration.
    SAMPLE_COUNT = 11

    def _print_reports(client, title=None, **watch_kwargs):
        """Print the watcher and its first SAMPLE_COUNT reports.

        When ``title`` is given, a banner is printed first.
        """
        if title is not None:
            print()
            print("=" * 10, title, "=" * 10)
            print()
        with client.watch(**watch_kwargs) as watcher:
            print(watcher)
            # islice stops after exactly SAMPLE_COUNT items, instead of
            # pulling (and discarding) one more report before breaking.
            for report in itertools.islice(watcher.stream, SAMPLE_COUNT):
                print(report)

    c = GPSDClient("192.168.168.1")
    print(c.version())
    _print_reports(c)
    _print_reports(c, "view = True", view=True)
    _print_reports(c, "min_mode = 3", mode=3)
