import datetime
import logging
import signal
import socket
import threading
import typing as _t
from contextlib import closing
from json import loads

import arrow
import configargparse
import redis

from . import __version__
from .datastructures import Queue, State, Task

# Module-level logger used by this module for its operational messages
# (bad payloads, shutdown requests, scrape progress, connection issues).
_og = logging.getLogger("celery.monitoring.server")


def get_parser():
    """Build the command-line / config-file / environment option parser.

    Returns:
        A ``configargparse.ArgParser`` that knows every option of the
        monitoring server: the queue list, the Redis connection settings
        and the scrape-loop tuning knobs.
    """
    parser = configargparse.ArgParser(
        default_config_files=[
            "/etc/cleargrid/celery/monitoring.d/*.conf",
            "/etc/cleargrid/celery/monitoring/server.d/*.conf",
        ]
    )
    parser.add_argument(
        "--version",
        action="version",
        version=f"%(prog)s {__version__}",
    )
    parser.add(
        "-f", "--config-file", is_config_file=True, help="More config files"
    )
    parser.add(
        "--queues",
        env_var="QUEUES",
        help="Comma seperated list of queue names",
        required=True,
    )
    parser.add("--redis-host", env_var="REDIS_HOST", default="localhost")

    # Integer Redis options that can also be supplied via the environment:
    # (flag, environment variable, default). Registration order is kept so
    # the generated help output lists them in the same sequence.
    redis_int_options = (
        ("--redis-port", "REDIS_PORT", "6379"),
        ("--redis-db", "REDIS_DB", "0"),
        ("--redis-connect-timeout", "REDIS_CONNECT_TIMEOUT", "2"),
        ("--redis-socket-timeout", "REDIS_SOCKET_TIMEOUT", "5"),
        ("--redis-retry-on-timeout", "REDIS_RETRY_ON_TIMEOUT", "1"),
        ("--redis-socket-keepalive", "REDIS_SOCKET_KEEPALIVE", "1"),
        ("--redis-keepalives-idle", "REDIS_KEEPALIVES_IDLE", "1"),
        ("--redis-keepalives-count", "REDIS_KEEPALIVES_COUNT", "5"),
        ("--redis-keepalives-interval", "REDIS_KEEPALIVES_INTERVAL", "1"),
        ("--redis-user-timeout", "REDIS_USER_TIMEOUT", "1000"),
    )
    for flag, env_name, default in redis_int_options:
        parser.add(flag, env_var=env_name, type=int, default=default)

    parser.add("--exit-on-connection-failure", action="store_true")

    # Integer options without an environment-variable binding.
    plain_int_options = (
        ("--max-wait-time", "5"),
        ("--collect-interval", "60"),
        ("--queue-page-size", "10"),
    )
    for flag, default in plain_int_options:
        parser.add(flag, type=int, default=default)

    return parser


def decode_json(j: bytes):
    """Decode a JSON payload, returning ``None`` when it cannot be decoded.

    Args:
        j: The raw JSON payload to decode.

    Returns:
        The decoded object, or ``None`` (after logging an error) when *j*
        is malformed JSON, cannot be decoded as text, or is of a type
        ``json.loads`` does not accept (for example ``None``).
    """
    try:
        return loads(j)
    except (TypeError, ValueError):
        # TypeError: input is not str/bytes/bytearray (e.g. None).
        # ValueError: covers json.JSONDecodeError for malformed documents
        # and UnicodeDecodeError for undecodable bytes, both of which
        # would otherwise propagate and abort the caller.
        _og.error("Bad json: %r", j)
        return None


class CeleryRedisMonitor:
    """Periodically scrape Celery task data from Redis and log queue stats.

    The monitor opens a fresh Redis connection for every scrape, collects
    the ``celery-task-meta-*`` entries plus the contents of the configured
    queue lists into a ``State``, and logs per-queue statistics through
    the ``celery.queue.stats`` logger.
    """

    # Signals that request a graceful stop of the run() loop.
    SHUTDOWN_SIGNALS: _t.Set[signal.Signals] = {
        signal.SIGTERM,
        signal.SIGINT,
    }

    def __init__(self, args):
        """Store connection parameters and loop settings.

        Args:
            args: Parsed options namespace; the attributes read here match
                the options declared by ``get_parser``.
        """
        # Keyword arguments handed to ``redis.Redis`` on every scrape.
        self.connparams = {
            "host": args.redis_host,
            "port": args.redis_port,
            "db": args.redis_db,
            "socket_timeout": args.redis_socket_timeout,
            "socket_connect_timeout": args.redis_connect_timeout,
            "socket_keepalive": args.redis_socket_keepalive,
            "socket_keepalive_options": {
                socket.TCP_KEEPCNT: args.redis_keepalives_count,
                socket.TCP_KEEPIDLE: args.redis_keepalives_idle,
                socket.TCP_KEEPINTVL: args.redis_keepalives_interval,
                socket.TCP_USER_TIMEOUT: args.redis_user_timeout,
            },
            "retry_on_timeout": args.redis_retry_on_timeout,
        }
        self.collect_interval = datetime.timedelta(
            seconds=args.collect_interval
        )
        # Upper bound (seconds) for a single wait on the shutdown event.
        self.max_wait_time = args.max_wait_time
        self.shutdown = threading.Event()
        # Most recent successfully scraped state.
        self.data: State = State()
        self.exit_on_connection_failure = args.exit_on_connection_failure
        self.queues = args.queues.split(",")
        self.queue_page_size = args.queue_page_size

    def run(self):
        """Run the scrape loop until a shutdown signal is received.

        Installs handlers for ``SHUTDOWN_SIGNALS`` that set the shutdown
        event, then repeatedly waits for the next scheduled run, scrapes
        Redis and logs the result. Returns ``None`` once shutdown has been
        requested.

        Raises:
            redis.ConnectionError: Propagated from ``scrape_data`` when
                ``exit_on_connection_failure`` is set.
        """

        def signal_shutdown(sig, frame):
            _og.info(
                "Caught %s, requesting shutdown",
                signal.Signals(sig).name,
                stack_info=True,
            )
            self.shutdown.set()

        for sig in self.SHUTDOWN_SIGNALS:
            signal.signal(sig, signal_shutdown)
        next_run = datetime.datetime.now()
        while not self.shutdown.is_set():
            now = datetime.datetime.now()
            diff = (next_run - now).total_seconds()
            _og.debug("Next scrape in %10.6f seconds", diff)
            if diff > 0:
                # Wait in slices of at most max_wait_time; the wait ends
                # early as soon as the shutdown event is set.
                self.shutdown.wait(min(diff, self.max_wait_time))
                continue
            # When more than one interval behind schedule, skip the missed
            # slots instead of running them back to back.
            while next_run - now < -self.collect_interval:
                next_run += self.collect_interval

            _og.info("Scraping data")
            data = self.scrape_data()
            if data:
                # Do not log data if it has not been updated
                self.data = data
                self.log_data(self.data)

            next_run += self.collect_interval

    def scrape_data(self) -> _t.Optional[State]:
        """Open a Redis connection and scrape the current state.

        Returns:
            The scraped ``State``, or ``None`` when the connection failed
            and ``exit_on_connection_failure`` is not set.

        Raises:
            redis.ConnectionError: Re-raised when
                ``exit_on_connection_failure`` is set.
        """
        try:
            with closing(redis.Redis(**self.connparams)) as conn:
                return self._scrape_data(conn)
        except redis.ConnectionError as exc:
            _og.info("ConnectionIssue %r", exc, stack_info=True)
            if self.exit_on_connection_failure:
                raise
        return None

    def _scrape_data(self, conn: redis.Redis) -> State:
        """Collect task metadata and queued tasks into a new ``State``."""
        ret = State()
        ret.tasks.update(self._task_meta(conn))
        for queue in self.queues:
            tasks, in_queue = self._in_queue_tasks(conn, queue)
            ret.tasks.update(tasks)
            ret.queues.add(in_queue)
        return ret

    def _task_meta(self, conn: redis.Redis) -> _t.Set[Task]:
        """Return the tasks stored under ``celery-task-meta-*`` keys.

        Entries that fail to decode, or that carry no ``task_id`` field,
        are skipped.
        """
        ret = set()
        # scan_iter yields one key at a time, so each mget call below
        # fetches the value of a single key.
        for task_key in conn.scan_iter("celery-task-meta-*"):
            for raw in conn.mget(task_key):
                task = decode_json(raw)
                if task and "task_id" in task:
                    ret.add(Task.from_json(task))
        return ret

    def _in_queue_tasks(
        self, conn: redis.Redis, queue: str
    ) -> _t.Tuple[_t.Set[Task], Queue]:
        """Read the Redis list *queue* page by page.

        Args:
            conn: Open Redis connection.
            queue: Name of the Redis list holding the queued messages.

        Returns:
            A tuple of the set of tasks found in the list and a ``Queue``
            to which the ``headers`` of every decoded message were added.
        """
        tasks = set()
        in_queue = Queue(queue)
        page_size = self.queue_page_size
        for start in range(0, conn.llen(queue), page_size):
            # LRANGE treats both offsets as inclusive, so the last index of
            # a page is start + page_size - 1. Using start + page_size would
            # fetch the first element of the next page as well and count
            # every page-boundary message twice.
            for raw in conn.lrange(queue, start, start + page_size - 1):
                task = decode_json(raw)
                if task:
                    tasks.add(Task.from_json(task))
                    in_queue.add(task["headers"])
        return tasks, in_queue

    def log_data(self, data):
        """Log information about queues"""
        # Iterate the state that was passed in rather than self.data, so
        # the method reports exactly what the caller asked for.
        for queue in data.queues:
            self._log_data(queue)

    def _log_data(self, queue):
        """Log task counts and latency for one queue.

        Emits one ``celery.queue.stats`` record per entry of
        ``queue.task_names`` and ``queue.task_names_deduplicated``, plus a
        latency record computed as the time elapsed since ``queue.oldest``.
        """
        logger = logging.getLogger("celery.queue.stats")
        logger.debug("%r", queue)
        for task_name, counts in queue.task_names.items():
            logger.info("%s=%s=%d", queue.name, task_name, counts)
        for task_name, counts in queue.task_names_deduplicated.items():
            logger.info("%s|%s=%d", queue.name, task_name, counts)

        # NOTE(review): assumes queue.oldest is always a datetime-like
        # value, also for an empty queue - confirm against Queue.
        latency = (arrow.get() - queue.oldest).total_seconds()
        logger.info("%s.latency=%.0f", queue.name, latency)


if __name__ == "__main__":
    # Script entry point: set up Sentry error reporting, then run the
    # monitor with options taken from the command line, config files and
    # environment. sentry_sdk is imported only on this path.
    import sentry_sdk
    from sentry_sdk.integrations.django import DjangoIntegration

    sentry_sdk.init(
        dsn=(
            "https://aed18fe1a71a47268ae1e47bf4b7240e@"
            "o265096.ingest.sentry.io/5287296"
        ),
        integrations=[DjangoIntegration()],
    )
    # run() has no return value, so a clean shutdown exits with status 0.
    # SystemExit is raised directly because the exit() helper is injected
    # by the site module for interactive use and is not always available.
    raise SystemExit(CeleryRedisMonitor(get_parser().parse_args()).run())
