#!/usr/bin/python3
"""
Copyright (C) 2023 Michael Ablassmeier <abi@grinser.de>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import os
import io
import sys
import logging
import argparse
import tempfile
import shutil
import glob
import re
from typing import List

try:
    import boto3
    from botocore.exceptions import ClientError
    BOTO3_AVAILABLE = True
except ImportError:
    BOTO3_AVAILABLE = False

from libvirtnbdbackup import argopt
from libvirtnbdbackup import __version__
from libvirtnbdbackup import virt
from libvirtnbdbackup.restore import vmconfig
from libvirtnbdbackup.restore import files
from libvirtnbdbackup.restore import sequence
from libvirtnbdbackup.restore import disk
from libvirtnbdbackup.restore import synthesize
from libvirtnbdbackup import output
from libvirtnbdbackup import common as lib
from libvirtnbdbackup.logcount import logCount
from libvirtnbdbackup.sparsestream import streamer
from libvirtnbdbackup.sparsestream import types
from libvirtnbdbackup.ssh import Mode
from libvirtnbdbackup.virt.exceptions import connectionFailed
from libvirtnbdbackup.exceptions import RestoreError
from libvirtnbdbackup import exceptions


def _read_key_from_file(path: str) -> str:
    """Return the contents of *path* with leading/trailing whitespace removed.

    Used to load S3 credentials from files. If the file is missing or cannot
    be read, the problem is logged and the process exits with status 1.
    """
    try:
        with open(path, 'r') as handle:
            content = handle.read()
    except FileNotFoundError:
        logging.error(f"S3 key file not found: {path}")
        sys.exit(1)
    except Exception as e:
        logging.error(f"Error reading S3 key file {path}: {e}")
        sys.exit(1)
    # str.strip() cannot fail, so it is safe to do outside the try block.
    return content.strip()


def _get_s3_client(args):
    """Build a boto3 S3 client from the command line S3 options.

    Reads the access and secret keys from the files named by
    ``args.s3_access_key_file`` / ``args.s3_secret_key_file`` and connects to
    ``args.s3_endpoint_url``. Exits with status 1 when boto3 is missing, a
    key file cannot be read, or client construction fails.
    """
    if not BOTO3_AVAILABLE:
        logging.error("boto3 library is not installed. Please run 'pip install boto3'.")
        sys.exit(1)

    # Dict literal entries are evaluated in order: access key first, then secret.
    credentials = {
        "aws_access_key_id": _read_key_from_file(args.s3_access_key_file),
        "aws_secret_access_key": _read_key_from_file(args.s3_secret_key_file),
    }
    endpoint = args.s3_endpoint_url

    logging.info(f"Connecting to S3 endpoint: {endpoint}")
    try:
        return boto3.client('s3', endpoint_url=endpoint, **credentials)
    except Exception as e:
        logging.error(f"Failed to create S3 client: {e}")
        sys.exit(1)


def _download_from_s3(s3_client, bucket_name: str, local_path: str):
    """Download every object of an S3 bucket into a local directory.

    Each object key is mapped to a path below ``local_path``, creating
    intermediate directories as needed. Two kinds of keys are skipped:

    * keys ending in "/" (zero-byte "folder" placeholder objects), which
      cannot be written as regular files;
    * keys that would resolve outside ``local_path`` (e.g. "../x" or "/x").
      Bucket contents are external input and must not be able to write
      arbitrary files on the restoring host.

    Args:
        s3_client: boto3 S3 client (anything providing
            ``get_paginator("list_objects_v2")`` and ``download_file``).
        bucket_name: Name of the bucket to download.
        local_path: Existing local directory receiving the files.

    Returns:
        True when the listing succeeded and all usable objects were downloaded.
        False when the bucket listing has no contents or an error occurred
        (errors are logged, not raised).
    """
    logging.info(f"Starting download from S3 bucket '{bucket_name}' to '{local_path}'")
    base_dir = os.path.realpath(local_path)
    try:
        paginator = s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=bucket_name)

        for page in pages:
            if 'Contents' not in page:
                logging.warning(f"S3 bucket '{bucket_name}' is empty or does not exist.")
                return False

            for obj in page['Contents']:
                s3_object_key = obj['Key']

                # Directory placeholder objects carry no data.
                if s3_object_key.endswith('/'):
                    logging.debug(f"Skipping S3 directory marker '{s3_object_key}'")
                    continue

                # Resolve the target and make sure it stays inside base_dir.
                local_file_path = os.path.realpath(os.path.join(base_dir, s3_object_key))
                if (
                    local_file_path == base_dir
                    or os.path.commonpath([base_dir, local_file_path]) != base_dir
                ):
                    logging.warning(
                        f"Skipping S3 object '{s3_object_key}': "
                        f"resolves outside of '{local_path}'"
                    )
                    continue

                os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

                logging.info(f"Downloading 's3://{bucket_name}/{s3_object_key}' to '{local_file_path}'")
                s3_client.download_file(bucket_name, s3_object_key, local_file_path)

        logging.info(f"Successfully downloaded all files from S3 bucket '{bucket_name}'.")
        return True
    except ClientError as e:
        logging.error(f"An S3 error occurred: {e}")
        return False
    except Exception as e:
        logging.error(f"An unexpected error occurred during S3 download: {e}")
        return False

def _restore_uefi_files(args):
    """
    Restore UEFI-specific files (OVMF_CODE*, *_VARS) saved with the backup.

    Candidates named ``<name>.qcow2.virtnbdbackup.<N>`` are collected from
    two places: the backup root (``args.input``) and, if it exists, its
    ``checkpoints`` subdirectory. For every ``<name>.qcow2`` only the copy
    with the highest ``N`` is restored, to ``args.output/<name>.qcow2``. The
    copy goes via SFTP when ``args.sshClient`` is set; otherwise it is a
    local ``shutil.copy2``.

    NOTE(review): the highest N is always selected, independent of --until.

    Raises:
        RestoreError: if copying any selected file fails.
    """
    logging.info("Checking for UEFI-specific backup files (OVMF/VARS)...")

    # Search both the backup root and the checkpoints subdirectory. Searching
    # only one of them would miss files stored in the other.
    search_paths = [args.input]
    checkpoint_dir = os.path.join(args.input, "checkpoints")
    if os.path.isdir(checkpoint_dir):
        search_paths.append(checkpoint_dir)

    patterns = ("OVMF_CODE*.qcow2.virtnbdbackup.*", "*_VARS.qcow2.virtnbdbackup.*")
    uefi_files_found = set()
    for search_path in search_paths:
        # Escape the directory part so glob metacharacters in paths are literal.
        escaped_dir = glob.escape(search_path)
        for pattern in patterns:
            uefi_files_found.update(glob.glob(os.path.join(escaped_dir, pattern)))

    if not uefi_files_found:
        logging.info("No UEFI-specific files found. Skipping.")
        return

    # Group files by their base name to find the latest version of each,
    # e.g. 'OVMF_CODE_4M.qcow2', 'alse18_VARS.qcow2'.
    version_re = re.compile(r"(.+?\.qcow2)\.virtnbdbackup\.(\d+)$")
    latest_files = {}
    # Sorted iteration keeps the choice deterministic when versions tie.
    for file_path in sorted(uefi_files_found):
        match = version_re.search(os.path.basename(file_path))
        if not match:
            continue

        base_name = match.group(1)
        version = int(match.group(2))

        # Keep the first file seen for a base name, or replace it with a
        # strictly higher version.
        if base_name not in latest_files or version > latest_files[base_name]['version']:
            latest_files[base_name] = {'path': file_path, 'version': version}

    logging.info(f"Found {len(latest_files)} unique UEFI files to restore (latest versions).")

    for base_name, file_info in latest_files.items():
        source_path = file_info['path']
        dest_path = os.path.join(args.output, base_name)
        logging.info(f"Restoring UEFI file '{os.path.basename(source_path)}' to '{dest_path}'")
        try:
            if args.sshClient:
                # Remote restore: upload through the SFTP channel of the SSH session.
                logging.debug(f"(via SSH) Uploading {source_path} to {dest_path}")
                args.sshClient.sftp.put(source_path, dest_path)
            else:
                # Local restore: copy2 also preserves file metadata.
                shutil.copy2(source_path, dest_path)
            logging.info(f"Successfully restored {base_name}.")
        except Exception as e:
            logging.error(f"Failed to restore UEFI file {base_name}: {e}")
            raise RestoreError(f"Failed to restore UEFI file {base_name}") from e


def _build_parser() -> argparse.ArgumentParser:
    """Build the command line parser for the dump/verify/restore/synthesize actions."""
    defaultConfig = "vmconfig.xml"
    parser = argparse.ArgumentParser(
        description="Restore or synthesize virtual machine disks",
        epilog=(
            "Examples:\n"
            "   # Dump backup metadata:\n"
            "\t%(prog)s -i /backup/ -o dump\n"
            "   # Verify checksums for existing data files in backup:\n"
            "\t%(prog)s -i /backup/ -o verify\n"
            "   # Complete restore with all disks:\n"
            "\t%(prog)s -i /backup/ -o /target\n"
            "   # Synthesize a new full backup from a chain:\n"
            "\t%(prog)s -a synthesize -i /backup/ -o /new_backup_dir/ "
            "--sequence vda.full.data,vda.inc.virtnbdbackup.1.data,vda.inc.virtnbdbackup.2.data\n"
            "   # Restore from an S3 bucket:\n"
            "\t%(prog)s -i s3://my-s3-bucket -o /target --s3-endpoint-url http://s3.host:9000 ...\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    opt = parser.add_argument_group("General options")
    opt.add_argument(
        "-a",
        "--action",
        required=False,
        type=str,
        choices=["dump", "restore", "verify", "synthesize"],
        default="restore",
        help="Action to perform: (default: %(default)s)",
    )
    opt.add_argument(
        "-i",
        "--input",
        required=True,
        type=str,
        help="Directory including a backup set or S3 bucket (e.g., s3://bucket-name)",
    )
    opt.add_argument(
        "-o", "--output", required=True, type=str, help="Restore target directory or synthesis output directory"
    )
    opt.add_argument(
        "-u",
        "--until",
        required=False,
        type=str,
        help="Restore only until checkpoint, point in time restore.",
    )
    opt.add_argument(
        "-s",
        "--sequence",
        required=False,
        type=str,
        default=None,
        help="Restore or synthesize image based on specified backup files (comma-separated).",
    )
    opt.add_argument(
        "-d",
        "--disk",
        required=False,
        type=str,
        default=None,
        help="Process only disk matching target dev name. (default: %(default)s)",
    )
    opt.add_argument(
        "-n",
        "--noprogress",
        required=False,
        action="store_true",
        default=False,
        help="Disable progress bar",
    )
    opt.add_argument(
        "-f",
        "--socketfile",
        default=f"/var/tmp/virtnbdbackup.{os.getpid()}",
        type=str,
        help="Use specified file for NBD Server socket (default: %(default)s)",
    )
    opt.add_argument(
        "-r",
        "--raw",
        default=False,
        action="store_true",
        help="Copy raw images as is during restore. (default: %(default)s)",
    )
    opt.add_argument(
        "-c",
        "--adjust-config",
        default=False,
        action="store_true",
        help="Adjust vm configuration during restore. (default: %(default)s)",
    )
    opt.add_argument(
        "-D",
        "--define",
        default=False,
        action="store_true",
        help="Register/define VM after restore. (default: %(default)s)",
    )
    opt.add_argument(
        "-C",
        "--config-file",
        default=defaultConfig,
        type=str,
        help=f"Name of the vm config file used for restore. (default: {defaultConfig})",
    )
    opt.add_argument(
        "-N",
        "--name",
        default=None,
        type=str,
        help="Define restored domain with specified name",
    )
    opt.add_argument(
        "-B",
        "--buffsize",
        default=io.DEFAULT_BUFFER_SIZE,
        type=int,
        help="Buffer size to use during verify (default: %(default)s)",
    )
    opt.add_argument(
        "-A",
        "--preallocate",
        default=False,
        action="store_true",
        help="Preallocate restored qcow images. (default: %(default)s)",
    )

    encopt = parser.add_argument_group("Encryption options")
    encopt.add_argument(
        "--encryption-key-path",
        type=str,
        default=None,
        help=("Provide the file path to the private encryption key for decryption.")
    )

    synthopt = parser.add_argument_group("Synthesize options")
    synthopt.add_argument(
        "--compress",
        default=False,
        type=int,
        const=2,
        nargs="?",
        choices=range(1, 17),
        metavar="[1-16]",
        help="[Synthesize] Compress with lz4 compression level. (default: %(default)s)",
        action="store",
    )
    synthopt.add_argument(
        "--hash",
        type=str,
        default="adler32",
        choices=["adler32", "xxh64"],
        help=("[Synthesize] Specify the hashing algorithm for the new backup.")
    )

    s3opt = parser.add_argument_group("S3 Storage options")
    s3opt.add_argument("--s3-endpoint-url", type=str, default=None, help="S3 endpoint URL, e.g., http://192.168.120.50:9000")
    s3opt.add_argument("--s3-access-key-file", type=str, default=None, help="Path to file containing S3 access key")
    s3opt.add_argument("--s3-secret-key-file", type=str, default=None, help="Path to file containing S3 secret key")

    remopt = parser.add_argument_group("Remote Restore options")
    argopt.addRemoteArgs(remopt)
    logopt = parser.add_argument_group("Logging options")
    argopt.addLogArgs(logopt, parser.prog)
    argopt.addLogColorArgs(logopt)
    debopt = parser.add_argument_group("Debug options")
    argopt.addDebugArgs(debopt)
    return parser


def _remove_temp_dir(path: str) -> None:
    """Delete the temporary S3 download directory; log (don't raise) on failure.

    Called from a ``finally`` clause, so raising here would mask the
    original exit status.
    """
    logging.info(f"Cleaning up temporary directory: {path}")
    try:
        shutil.rmtree(path)
    except OSError as e:
        logging.warning("Unable to remove temporary directory [%s]: %s", path, e)


def _select_data_files(args) -> List[str]:
    """Return the data files to process, exiting with status 1 on invalid input.

    A manually given ``--sequence`` is used as-is; it must start with a full
    or copy backup, and it disables ``--define``/``--adjust-config``.
    Otherwise the latest ``*.data`` files in ``args.input`` are used. The
    synthesize action requires ``--sequence``.
    """
    if args.sequence is not None:
        logging.info("Using manually specified sequence of files.")
        dataFiles = args.sequence.split(",")
        if args.action != "synthesize":
            logging.info("Disabling redefine and config adjust options.")
        args.define = False
        args.adjust_config = False

        if "full" not in dataFiles[0] and "copy" not in dataFiles[0]:
            logging.error("Sequence must start with full or copy backup.")
            sys.exit(1)
        return dataFiles

    # --sequence is mandatory for synthesize.
    if args.action == "synthesize":
        logging.error("--sequence is required for the synthesize action.")
        sys.exit(1)

    dataFiles = lib.getLatest(args.input, "*.data")
    if not dataFiles:
        logging.error("No data files found in directory: [%s]", args.input)
        sys.exit(1)
    return dataFiles


def _run_synthesize(args, dataFiles: List[str]) -> None:
    """Create a new full backup from a chain of data files; always exits."""
    logging.info("Starting synthetic backup creation.")
    try:
        if not os.path.exists(args.output):
            logging.info("Create target directory for synthesis: [%s]", args.output)
            os.makedirs(args.output)

        args.stdout = False
        args.type = "stream"
        args.offline = True
        args.qemu = False
        args.no_sparse_detection = False
        args.level_filename = "full"  # the resulting file is written as a full backup

        synthesize.run(args, dataFiles)
    except (RestoreError, exceptions.DiskBackupFailed) as errmsg:
        logging.error("Synthetic backup failed: [%s]", errmsg)
        sys.exit(1)
    logging.info("Synthetic backup created successfully.")
    sys.exit(0)


def _make_remote_dirs(sftp, path: str) -> None:
    """Create ``path`` and missing parents over SFTP, like ``mkdir -p``.

    Needed because paramiko's SFTP client cannot create nested directories
    in one call.
    """
    try:
        sftp.stat(path)
        return
    except FileNotFoundError:
        pass

    logging.info("Target directory [%s] does not exist. Creating it recursively.", path)
    # Absolute paths start from the root; each component is created in turn.
    current_path = "/" if path.startswith("/") else ""
    for directory in path.strip("/").split("/"):
        if not directory:
            continue
        if current_path and not current_path.endswith("/"):
            current_path += "/"
        current_path += directory

        try:
            sftp.stat(current_path)
        except FileNotFoundError:
            logging.debug("Creating intermediate directory: [%s]", current_path)
            sftp.mkdir(current_path)


def _run_restore(args, dataFiles: List[str]) -> None:
    """Restore disks, UEFI files, auxiliary files and VM config to ``args.output``.

    Exits with status 1 on any restore failure.
    """
    if args.define is True:
        args.adjust_config = True

    try:
        virtClient = virt.client(args)
    except connectionFailed as e:
        logging.error("Unable to connect libvirt: [%s]", e)
        sys.exit(1)

    if virtClient.remoteHost:
        if not args.output.startswith("/"):
            logging.error(
                "Absolute target path required for restore to remote system"
            )
            sys.exit(1)

        args.sshClient = lib.sshSession(
            args, virtClient.remoteHost, mode=Mode.UPLOAD
        )
        if not args.sshClient:
            logging.error("Remote restore detected but ssh session setup failed")
            sys.exit(1)

        _make_remote_dirs(args.sshClient.sftp, args.output)
    elif not os.path.exists(args.output):
        # For local restores, os.makedirs already handles recursive creation.
        logging.info("Create target directory: [%s]", args.output)
        os.makedirs(args.output)

    ConfigFiles = lib.getLatest(args.input, "vmconfig*.xml")
    if not ConfigFiles:
        logging.error("No domain config file found")
        sys.exit(1)
    if args.until is not None:
        # The numeric suffix of the checkpoint name selects the config file.
        try:
            ConfigFile = ConfigFiles[int(args.until.split(".")[-1])]
        except (ValueError, IndexError):
            logging.error("No domain config file found for checkpoint [%s]", args.until)
            sys.exit(1)
    else:
        ConfigFile = ConfigFiles[-1]
    logging.info("Using config file: [%s]", ConfigFile)

    autoStart = bool(lib.getLatest(args.input, "autostart.*", -1))

    restConfig: bytes = b""
    try:
        if args.sequence is not None:
            sequence.restore(args, dataFiles, virtClient)
        else:
            restConfig = disk.restore(args, ConfigFile, virtClient)

        _restore_uefi_files(args)
    except RestoreError as errmsg:
        logging.error("Disk restore failed: [%s]", errmsg)
        sys.exit(1)

    files.restore(args, ConfigFile, virtClient)
    vmconfig.restore(args, ConfigFile, restConfig, args.config_file)
    virtClient.refreshPool(args.output)
    if args.define is True:
        if not virtClient.defineDomain(restConfig, autoStart):
            sys.exit(1)


def _run_action(args, stream) -> None:
    """Validate the local backup source and dispatch to the selected action."""
    if not lib.exists(args, args.input):
        logging.error("Backup source [%s] does not exist.", args.input)
        sys.exit(1)

    dataFiles = _select_data_files(args)

    if args.action == "dump" or args.output == "dump":
        files.dump(args, stream, dataFiles)
        sys.exit(0)

    if args.action == "verify" or args.output == "verify":
        if not files.verify(args, dataFiles):
            sys.exit(1)
        sys.exit(0)

    if args.action == "synthesize":
        _run_synthesize(args, dataFiles)
    elif args.action == "restore":
        _run_restore(args, dataFiles)


def main() -> None:
    """Parse arguments, fetch an S3 backup if requested, and run the action.

    A temporary directory holding an S3 download is removed on every exit
    path (normal return, ``sys.exit`` or exception).
    """
    parser = _build_parser()
    args = lib.argparse(parser)
    args.quiet = False
    args.sshClient = None
    # default values for common usage of lib.getDomainDisks
    args.exclude = None
    args.include = args.disk
    lib.setThreadName()
    stream = streamer.SparseStream(types)
    fileLog = lib.getLogFile(args.logfile) or sys.exit(1)
    counter = logCount()  # pylint: disable=unreachable
    lib.configLogger(args, fileLog, counter)
    lib.printVersion(__version__)

    temp_dir_for_s3 = None
    try:
        if args.input.startswith("s3://"):
            logging.info("S3 restore mode activated.")
            if not all([args.s3_endpoint_url, args.s3_access_key_file, args.s3_secret_key_file]):
                logging.error("For S3 restore, --s3-endpoint-url, --s3-access-key-file, and --s3-secret-key-file are required.")
                sys.exit(1)

            bucket_name = args.input[5:]
            temp_dir_for_s3 = tempfile.mkdtemp(prefix="virtnbd-s3-restore-")
            logging.info(f"Created temporary directory for S3 download: {temp_dir_for_s3}")

            s3_client = _get_s3_client(args)
            if not _download_from_s3(s3_client, bucket_name, temp_dir_for_s3):
                logging.error(f"Failed to download backup from S3 bucket '{bucket_name}'.")
                sys.exit(1)
            # From here on, the downloaded copy is treated as a local backup set.
            args.input = temp_dir_for_s3

        _run_action(args, stream)
    finally:
        # sys.exit() raises SystemExit, so this also runs on every early exit.
        if temp_dir_for_s3:
            _remove_temp_dir(temp_dir_for_s3)


if __name__ == "__main__":
    # Run the command line tool only when executed as a script, not on import.
    main()