#!/usr/bin/env python3

import os
import sys

BASE_DIR = "/opt/rbta/ad/mgmtportal/api/core"
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)

os.environ["DJANGO_SETTINGS_MODULE"] = "project.settings"

import json
import subprocess
import time
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from traceback import format_exc
from typing import Any, Dict, Iterable, List, Set, Tuple

import ldap
from aldpro_logging import SyslogLogger
from aldpro_om2 import Dn
from django.conf import settings
from ldap.controls import SimplePagedResultsControl

from directory_service.models.ntp_dnsrecord import NtpDnsrecord
from directory_service.models.ntp_server import NtpServer

logger = SyslogLogger("aldpro-ntp-management")

BASE_DN = settings.BASE_DN
LDAP_SERVER = settings.LDAP_SERVER
LDAP_SSL = settings.LDAP_SSL
LDAP_PROTO = "ldaps" if LDAP_SSL else "ldap"
LDAP_PORT = 636 if LDAP_SSL else 389
LDAP_USER = f"uid={settings.LDAP_USER},cn=sysaccounts,cn=etc,{BASE_DN}"
LDAP_PASSWORD = settings.LDAP_PASSWORD

MODULE_NAME = __name__.split(".")[-1]
SLEEP_PERIOD = 120


class NtpException(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


class DaemonBaseLDAPTransport:
    def __new__(cls, *args, **kwargs):
        if not hasattr(cls, "instance"):
            cls.instance = super(DaemonBaseLDAPTransport, cls).__new__(cls)
        return cls.instance


class DaemonLDAPTransport(DaemonBaseLDAPTransport):
    def __init__(self, ldap_server: str):
        self.ldap_server = ldap_server
        self.__initialize_ldap = None

    def __ldap_disconnect(self) -> None:
        """Закрыть соединение к ldap

        Args:
            initialize_ldap (ldap.ldapobject.SimpleLDAPObject): объект соединения
        """
        if self.__initialize_ldap:
            self.__initialize_ldap.unbind()

    def __ldap_connect(self, domain_controller: str) -> ldap.ldapobject.SimpleLDAPObject:
        """Открыть соединенеие к ldap

        Args:
            domain_controller (str): Адрес контролера домена, к которому подсодениться

        Returns:
            ldap.ldapobject.SimpleLDAPObject: объект соединения
        """
        transport_obj = None
        ldap_url = f"{LDAP_PROTO}://{domain_controller}:{LDAP_PORT}"
        ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER)
        initialize_ldap = ldap.initialize(ldap_url)
        initialize_ldap.set_option(ldap.OPT_REFERRALS, 0)
        initialize_ldap.protocol_version = ldap.VERSION3
        try:
            initialize_ldap.simple_bind_s(
                LDAP_USER,
                LDAP_PASSWORD,
            )
        except ldap.LDAPError as e:
            logger.error(f"Ошибка соединения с ldap {e}")
            logger.error(f"\n{format_exc()}")
        else:
            transport_obj = initialize_ldap
        return transport_obj

    def _get_data(
        self, target_dn, filterstr="(objectClass=*)", attrlist=[], scope=ldap.SCOPE_SUBTREE
    ) -> Iterable[Dict[str, List[bytes]]]:
        """Получение данных из LDAP

        Args:
            target_dn (str): DN базовой записи
            filterstr (str): LDAP фильтр для поиска записей
            attrlist (list): Список атрибутов для получения
            scope: Уровень поиска записей

        Returns:
            Dict[str, List[bytes]]: Перечень атрибутов найденных записей
        """
        for entry in self.get_response(search_base=target_dn, filter_str=filterstr, attrlist=attrlist, scope=scope):
            for attrs in entry:
                yield attrs[1]

    def _get_record(self, target_dn: str) -> Dict[str, List[bytes]]:
        """Получение конкретной записи из LDAP

        Args:
            target_dn (str): DN базовой записи

        Returns:
            Dict[str, List[bytes]]: Перечень атрибутов найденной записи
        """
        try:
            raw_record = self.__initialize_ldap.search_s(target_dn, scope=ldap.SCOPE_BASE)
        except ldap.NO_SUCH_OBJECT:
            raw_record = dict()
        raw_record_data = raw_record[-1][-1] if raw_record else raw_record
        return raw_record_data

    def get_response(
        self,
        search_base: str,
        filter_str: str = "(objectClass=*)",
        attrlist=None,
        scope=ldap.SCOPE_ONELEVEL,
    ) -> Iterable[List[Tuple[str, Dict[str, List[bytes]]]]]:
        """Выполняет поиск в контейнере

        Args:
            search_base (str): Контейнер для поиска
            filter_str (str): Фильтр для поиска, по умолчанию - (objectClass=*) - вернет все объекты в контейнере
            attrlist (list): Список атрибутов для получения
            scope: Уровень поиска записей

        Returns:
            List[Tuple[str, Dict[str, List[bytes]]]]: Перечень объектов и их cn
        """
        PAGE_CONTROL: SimplePagedResultsControl = SimplePagedResultsControl(True, size=150, cookie="")
        if attrlist is None:
            attrlist = [
                "cn",
            ]
        try:
            message_id = self.__initialize_ldap.search_ext(
                base=f"{search_base}{BASE_DN}",
                scope=scope,
                filterstr=filter_str,
                attrlist=attrlist,
                serverctrls=[PAGE_CONTROL],
            )

            pages = 0
            while True:
                pages += 1
                _, rdata, _, serverctrls = self.__initialize_ldap.result3(message_id)
                yield rdata
                controls = [
                    control for control in serverctrls if control.controlType == SimplePagedResultsControl.controlType
                ]
                if not controls:
                    logger.error("The server ignores RFC 2696 control")
                    break
                if not controls[0].cookie:
                    break
                PAGE_CONTROL.cookie = controls[0].cookie
                message_id = self.__initialize_ldap.search_ext(
                    base=f"{search_base}{BASE_DN}",
                    scope=scope,
                    filterstr=filter_str,
                    attrlist=attrlist,
                    serverctrls=[PAGE_CONTROL],
                )
        except ldap.LDAPError as e:
            raise Exception(f"Ошибка поиска сущностей {e}")

    def is_connected(self) -> bool:
        """Получение статуса подключения к LDAP-серверу

        Returns:
            bool: статус подключения
        """
        return True if self.__initialize_ldap else False

    def __enter__(self):
        self.__initialize_ldap = self.__ldap_connect(settings.LDAP_SERVER)
        return self

    def __exit__(self, *args):
        self.__ldap_disconnect()


class NtpManagement:
    def __init__(self, path, transport):
        self.path = path
        self.transport = transport
        self.ldif_data = str()

    def _full_dn(self, *args: Dn) -> Dn:
        """Получение полного DN записи
        Args:
            *args (Dn): относительный путь до записи

        Returns:
            Dn: Полный путь до записи с доменным именем
        """
        return Dn(*args, settings.BASE_DN)

    def _get_dcs(self) -> Iterable[Dict[str, str]]:
        """Получение всех доступных контроллеров домена (КД)

        Returns:
            List[Dict[str, str]]: Метаинформация о каждом КД
        """
        dcs = self.transport._get_data(
            target_dn="cn=computers,cn=accounts,",
            filterstr="(&(objectClass=rbta-subsystem-dc)(rbtaSubsystemRole=dc))",
            attrlist=["cn", "rbtaSubsystemMetainfo"],
        )
        logger.info("Данные о КД получены")

        updated_dcs = list()
        for dc in dcs:
            dc_metainfo = json.loads(dc["rbtaSubsystemMetainfo"][0].decode())
            location = dc_metainfo["location"]
            cn = dc["cn"][0].decode()
            dc["ipalocation"] = location
            dc["cn"] = cn
            del dc["rbtaSubsystemMetainfo"]
            updated_dcs.append(dc)

        return updated_dcs

    def _get_locations(self) -> List[str]:
        """Получение сайтов

        Returns:
            locations (list): список сайтов
        """
        locations = self.transport._get_data(
            target_dn="cn=locations,cn=etc,",
            filterstr="(objectClass=ipaLocationObject)",
            attrlist=["idnsname"],
        )
        logger.info("Данные о сайтах получены")
        locations = [location["idnsname"][0].decode() for location in locations]
        return locations

    def _get_locations_in_dns(self) -> List[str]:
        """Получение сайтов в DNS записях

        Returns:
            locations_in_dns (list): список сайтов
        """
        locations_in_dns = self.transport._get_data(
            target_dn=f"{NtpDnsrecord.__superior__},",
            filterstr="(idnsname=_ntp._udp._roots.*._locations)",
            attrlist=["idnsname"],
        )
        logger.info("Данные о сайтах в DNS записях получены")
        locations_in_dns = [
            location_in_dns["idnsname"][0].decode().split(".")[-2] for location_in_dns in locations_in_dns
        ]
        return locations_in_dns

    def _get_ntp_servers_info(self, dcs: Iterable[Dict[str, str]]) -> Dict[str, List[str]]:
        """Получение информации о внутренних NTP серверах

        Args:
            dcs: (Iterable[Dict[str, str]]): генератор с КД

        Returns:
            ntps_info (Dict[str, List[str]]): словарь с валидными/невалидными NTP серверами
        """
        ntp_entries = self.transport._get_data(
            target_dn="cn=ntp,cn=services,cn=aldpro,cn=etc,",
            filterstr="(&(objectClass=rbta-ald-ntp)(isexternal=FALSE))",
            attrlist=["ntpserver"],
        )
        logger.info("Данные об NTP-серверах получены")

        ntps_info = {"invalid": [], "valid": []}
        for ntp_entry in ntp_entries:
            ntp_is_valid = False
            ntpserver = ntp_entry["ntpserver"][0].decode()
            for dc in dcs:
                if ntpserver == dc["cn"]:
                    ntp_is_valid = True
            if not ntp_is_valid:
                ntps_info["invalid"].append(ntpserver)
                logger.info(f"Найден несуществующий КД в службах NTP: {ntpserver}")
            else:
                ntps_info["valid"].append(ntpserver)
                logger.info(f"Найден действующий КД в службах NTP: {ntpserver}")

        return ntps_info

    def _get_location_servers_info(
        self, locations: List[str], dcs: Iterable[Dict[str, str]], root_ntp: List[str]
    ) -> Tuple[Dict[str, Dict[str, List[str]]], Set[str], Set[str]]:
        """Получение корневых NTP серверов
        Args:
            locations (list): список сайтов
            dcs (iterable): список КД
            root_ntp (list): список корневых NTP серверов

        Returns:
            tuple: кортеж метаинформации о сайтах КД и дефолтных серверах NTP
        """
        location_dcs = {location: {"default": [], "root": []} for location in locations}
        default_ntps = set()
        default_roots = set()

        for dc in dcs:
            cn = dc["cn"]

            dc_short = cn.split(".", maxsplit=1)[0]
            dc_arecords = self.transport._get_data(
                target_dn=f"idnsname={dc_short},{NtpDnsrecord.__superior__},", attrlist=["arecord"]
            )
            dc_arecords = next(dc_arecords).get("arecord", [b""])
            dc_arecords = [arecrod.decode() for arecrod in dc_arecords]

            default_ntps.update(dc_arecords)

            location = dc["ipalocation"]
            if "idnsname=" in location:
                location = location.split("=")[-1]

            dc_location = location_dcs.get(location)
            if not dc_location:
                self._process_error(f"Некорректное имя сайта '{location}' у {cn}")

            if cn in root_ntp:
                dc_location["root"].extend(dc_arecords)
                default_roots.update(dc_arecords)
            else:
                dc_location["default"].extend(dc_arecords)

        return location_dcs, default_ntps, default_roots

    def _generate_ldif(
        self,
        invalid_ntps: List[str],
        default_ntps: Set[str],
        default_roots: Set[str],
        location_dcs: Dict[str, Dict[str, List[str]]],
        deleted_locations: Set[str],
    ) -> None:
        """Генерация LDIF-файла
        Args:
            default_ntps (set): список дефолтных NTP-серверов
            default_roots (set): список дефолтных корневых NTP-серверов
            location_dcs (dict): список корневых NTP серверов
            deleted_locations (set): удаленные сайты

        Returns:
            tuple: кортеж метаинформации о сайтах КД и дефолтных серверах NTP
        """
        self._write_arecords(self._full_dn(NtpDnsrecord.default_nonroot_dn()), default_ntps)
        self._write_arecords(self._full_dn(NtpDnsrecord.default_root_dn()), default_roots)
        for location, ntps in location_dcs.items():
            self._write_arecords(self._full_dn(NtpDnsrecord.nonroot_dn(location)), ntps["default"])
            self._write_arecords(self._full_dn(NtpDnsrecord.root_dn(location)), ntps["root"])

        for location in deleted_locations:
            self.ldif_data["delete"].append(
                (self._full_dn(NtpDnsrecord.nonroot_dn(location)), self._full_dn(NtpDnsrecord.root_dn(location)))
            )
            logger.info(f"Добавление idnsname записей на удаление, связанных с сайтом {location}")

        for invalid_ntp in invalid_ntps:
            self.ldif_data["delete"].append((self._full_dn(NtpServer.x_dn(invalid_ntp)),))
            logger.info(f"Добавление NTP-сервера {invalid_ntp} на удаление из служб")

    def _write_arecords(self, target_dn: str, arecords: List[str] = []) -> bool:
        """Добавление aRecords записей к нужному формату модификации (add, del, modify)

        Args:
            target_dn (str): DN базовой записи
            arecords (list): Список aRecords записей

        Returns:
            yield Dict[str, str]: Метаинформация о каждом КД
        """
        raw_record = self.transport._get_record(target_dn)
        if not raw_record:
            self.ldif_data["add"].append(target_dn)
        raw_record = {k.lower(): v for k, v in raw_record.items()}

        old_arecords = set([arecord.decode() for arecord in raw_record.get("arecord", set())])
        new_arecords = set(arecords)

        if old_arecords != new_arecords:
            data_for_modify_add = new_arecords - old_arecords
            data_for_modify_delete = old_arecords - new_arecords
            changes = {"add": data_for_modify_add, "delete": data_for_modify_delete}
            for mod, data in changes.items():
                if data:
                    self.ldif_data["modify"].update({target_dn: {mod: {"aRecord": list(data)}}})
                    logger.info(f"Записана модификация {mod} для {list(data)} у {target_dn}")

        return True

    def _write_ldif(self) -> bool:
        """Генерация LDIF файла

        Returns:
            bool: Статус выполнения создания LDIF-файла
        """
        ldif_body = str()
        result = False
        for mod, data in self.ldif_data.items():
            if mod == "add":
                for dn in data:
                    ldif_body += f"dn: {dn}\n"
                    ldif_body += f"changetype: {mod}\n"
                    ldif_body += "objectclass: top\n"
                    ldif_body += "objectclass: idnsrecord\n"
                    ldif_body += f'idnsname: {dn.split(",", 1)[0].split("=")[1]}\n'
                    ldif_body += "\n"
            elif mod == "modify":
                for dn, dn_data in data.items():
                    ldif_body += f"dn: {dn}\n"
                    ldif_body += f"changetype: {mod}\n"
                    for operation_type, attribute_info in dn_data.items():
                        attributes = list(attribute_info)
                        for attribute in attributes:
                            ldif_body += f"{operation_type}: {attribute}\n"
                            for attribute_value in attribute_info[attribute]:
                                ldif_body += f"{attribute}: {attribute_value}\n"
                    ldif_body += "\n"
            else:
                for dns_tuple in data:
                    for dn in dns_tuple:
                        ldif_body += f"dn: {dn}\n"
                        ldif_body += f"changetype: {mod}\n"
                        ldif_body += "\n"
        if ldif_body:
            logger.info(f"Добавление собранных данных в {self.path}")
            with open(self.path, "w+") as f:
                f.write(ldif_body)
            result = True
        else:
            logger.info(f"Добавление собранных данных в {self.path} не требуется")

        return result

    def _execute_ldif(self, cmd: Tuple[str]) -> Tuple[str, str]:
        """Выполняет команду, переданную в cmd.
        Определяет признак успеха по содержимому stderr.
        Возвращает tuple.
            без ошибки: (True, декодированный stdout)
            с ошибкой: (False, декодированный stderr)

        Args:
            cmd (tuple): системная команда для исполнения

        Returns:
            tuple:
                без ошибки: (True, декодированный stdout)
                с ошибкой: (False, декодированный stderr)
        """
        command: subprocess.Popen = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stdin=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        output, error = command.communicate()
        command.terminate()

        error = error.decode("utf-8")
        error = [x for x in error.split("\n") if not (x.startswith("SASL") or x == "")]

        if len(error):
            message = "\n".join(error)
            full_error_message = f"Ошибка применения {self.path} к схеме LDAP: {message}"
            return False, full_error_message
        else:
            logger.info(f"{self.path} успешно применен к схеме LDAP")
            os.remove(self.path)
            return True, output.decode("utf-8")

    def _process_error(self, message: str) -> None:
        """Окончания работы программы с выводом ошибки

        Args:
            message (str): текст ошибки
        """
        raise NtpException(message)

    def check_ntp_exceptions(method):
        """Декоратор, обрабатывающий ошибки в миграции NTP"""

        def wrapper(*args, **kwargs):
            logger.info("Старт миграции NTP...")
            try:
                method(*args, **kwargs)
            except Exception as e:
                logger.error(f"Работа миграции NTP завершена с ошибкой: {e}")
                logger.error(f"\n{format_exc()}")
            else:
                logger.info("Миграция NTP успешно завершена")

        return wrapper

    @check_ntp_exceptions
    def run(self) -> None:
        """Запуск работы программы"""
        self.ldif_data = {"add": list(), "modify": dict(), "delete": list()}

        dcs = self._get_dcs()
        locations = self._get_locations()
        locations_in_dns = self._get_locations_in_dns()
        ntps_info = self._get_ntp_servers_info(dcs)
        invalid_ntps, root_ntps = ntps_info["invalid"], ntps_info["valid"]

        deleted_locations = set(locations_in_dns) - set(locations)

        location_dcs, default_ntps, default_roots = self._get_location_servers_info(locations, dcs, root_ntps)

        self._generate_ldif(invalid_ntps, default_ntps, default_roots, location_dcs, deleted_locations)

        res = self._write_ldif()
        if res:
            result = self._execute_ldif(
                cmd=("/usr/bin/ldapmodify", "-D", LDAP_USER, "-w", LDAP_PASSWORD, "-f", self.path)
            )
            if not result[0]:
                self._process_error(result[1])


def parse_args() -> Any:
    """Окончания работы программы с выводом ошибки

    Returns:
        args (Any): передеанные аргументы командной строки
    """
    desc = """
    Демон обработки NTP-записей.
    Данный юнит выполняет следующие действия:
    \t- Периодическая синхронизация структуры NTP с доступными КД
    """
    parser = ArgumentParser(description=desc, formatter_class=RawDescriptionHelpFormatter)
    parser.add_argument(
        "-lp",
        "--ldif_path",
        nargs="?",
        action="store",
        help="Путь до файла .ldif",
        default="/tmp/99-migrate-ntp-data.ldif",
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    transport = DaemonLDAPTransport(LDAP_SERVER)
    ntp_management = NtpManagement(args.ldif_path, transport)

    while True:
        with ntp_management.transport:
            if ntp_management.transport.is_connected():
                ntp_management.run()

        time.sleep(SLEEP_PERIOD)
