#!/usr/bin/python3
"""
Copyright (C) 2023  Michael Ablassmeier <abi@grinser.de>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import os
import sys
import pwd
import libvirt
import signal
import logging
import argparse
import shutil
import glob
import re
from typing import List
from datetime import datetime
from functools import partial
from concurrent.futures import ThreadPoolExecutor, as_completed

try:
    import boto3
    from botocore.exceptions import ClientError
    BOTO3_AVAILABLE = True
except ImportError:
    BOTO3_AVAILABLE = False

from nbd import __version__ as __nbdversion__
from libvirtnbdbackup import sighandle
from libvirtnbdbackup import argopt
from libvirtnbdbackup import __version__
from libvirtnbdbackup import virt
from libvirtnbdbackup.objects import DomainDisk
from libvirtnbdbackup.virt import checkpoint
from libvirtnbdbackup import output
from libvirtnbdbackup.output import stream
from libvirtnbdbackup import common as lib
from libvirtnbdbackup.logcount import logCount
from libvirtnbdbackup import exceptions
from libvirtnbdbackup import rotation
from libvirtnbdbackup.backup import partialfile
from libvirtnbdbackup.backup import job
from libvirtnbdbackup.backup import disk
from libvirtnbdbackup.backup import metadata
from libvirtnbdbackup.backup import check
from libvirtnbdbackup.ssh.exceptions import sshError
from libvirtnbdbackup.virt.exceptions import (
    domainNotFound,
    connectionFailed,
)
from libvirtnbdbackup.output.exceptions import OutputException


def _read_key_from_file(path: str) -> str:
    """Reads a key from a file and strips whitespace."""
    try:
        with open(path, 'r') as f:
            return f.read().strip()
    except FileNotFoundError:
        sys.stderr.write(f"ERROR: S3 key file not found: {path}\n")
        sys.exit(1)
    except Exception as e:
        sys.stderr.write(f"ERROR: Error reading S3 key file {path}: {e}\n")
        sys.exit(1)


def _get_s3_client(args):
    """Build a boto3 S3 client from the endpoint and key files in ``args``.

    Reads ``args.s3_access_key_file`` and ``args.s3_secret_key_file`` (in that
    order) and connects to ``args.s3_endpoint_url``. The process exits with
    status 1 when boto3 is not importable, a key file cannot be read, or the
    client cannot be constructed; otherwise the client is returned.
    """
    if not BOTO3_AVAILABLE:
        logging.error("boto3 library is not installed. Please run 'pip install boto3'.")
        sys.exit(1)

    # Dict literals evaluate left to right: access key first, then secret.
    credentials = {
        "aws_access_key_id": _read_key_from_file(args.s3_access_key_file),
        "aws_secret_access_key": _read_key_from_file(args.s3_secret_key_file),
    }

    endpoint = args.s3_endpoint_url
    logging.info("Connecting to S3 endpoint: %s", endpoint)
    try:
        return boto3.client('s3', endpoint_url=endpoint, **credentials)
    except Exception as err:  # pylint: disable=broad-except
        logging.error("Failed to create S3 client: %s", err)
        sys.exit(1)


def _upload_to_s3(s3_client, local_path: str, bucket_name: str, vm_name: str):
    """
    Uploads the content of a local directory to an S3 bucket intelligently.
    """
    bucket_name = re.sub(r'[^a-z0-9.-]', '', bucket_name.lower())
    cpt_file_name = f"{vm_name}.cpt"

    logging.info("Preparing to intelligently upload backup to S3 bucket: '%s'", bucket_name)
    try:
        try:
            s3_client.head_bucket(Bucket=bucket_name)
            logging.info("Using existing S3 bucket: '%s'", bucket_name)
        except ClientError as e:
            if e.response['Error']['Code'] in ['404', 'NoSuchBucket']:
                logging.info("S3 bucket '%s' not found. Creating it.", bucket_name)
                s3_client.create_bucket(Bucket=bucket_name)
                logging.info("S3 bucket '%s' created successfully.", bucket_name)
            else:
                raise

        logging.info("Listing existing objects in bucket '%s' for comparison...", bucket_name)
        try:
            paginator = s3_client.get_paginator('list_objects_v2')
            pages = paginator.paginate(Bucket=bucket_name)
            existing_s3_keys = {obj['Key'] for page in pages for obj in page.get('Contents', [])}
            logging.info("Found %d existing objects in S3.", len(existing_s3_keys))
        except ClientError as e:
            logging.error("Failed to list objects from S3, cannot perform intelligent upload: %s", e)
            return False

        for root, _, files in os.walk(local_path):
            for filename in files:
                local_file_path = os.path.join(root, filename)
                s3_object_name = os.path.relpath(local_file_path, local_path)

                is_cpt_file = (filename == cpt_file_name)

                if s3_object_name in existing_s3_keys and not is_cpt_file:
                    logging.info("Skipping upload for '%s' as it already exists in S3.", s3_object_name)
                    continue

                if is_cpt_file:
                    logging.info("Uploading and overwriting manifest file '%s' to '%s/%s'", local_file_path, bucket_name, s3_object_name)
                else:
                    logging.info("Uploading new file '%s' to '%s/%s'", local_file_path, bucket_name, s3_object_name)

                s3_client.upload_file(local_file_path, bucket_name, s3_object_name)

        logging.info("S3 upload process completed successfully.")
        return True
    except ClientError as e:
        logging.error("An S3 error occurred: %s", e)
        return False
    except Exception as e:
        logging.error("An unexpected error occurred during S3 upload: %s", e)
        return False


def auth_callback(credentials, user_data):
    """
    Credential callback for ``libvirt.openAuth``.

    For every requested credential whose type is ``VIR_CRED_AUTHNAME`` or
    ``VIR_CRED_PASSPHRASE``, writes ``user_data['user']`` or
    ``user_data['password']`` respectively into slot 4 of that credential
    entry; other credential types are left untouched. Always returns 0.
    """
    # Maps a libvirt credential type to the user_data key that answers it.
    answer_key_by_type = {
        libvirt.VIR_CRED_AUTHNAME: 'user',
        libvirt.VIR_CRED_PASSPHRASE: 'password',
    }
    for cred in credentials:
        answer_key = answer_key_by_type.get(cred[0])
        if answer_key is not None:
            cred[4] = user_data[answer_key]
    return 0


def _build_parser() -> argparse.ArgumentParser:
    """Create the command line parser for the backup utility.

    Returns:
        The populated ``argparse.ArgumentParser``; parsing is left to the
        caller.
    """
    parser = argparse.ArgumentParser(
        description="Backup libvirt/qemu virtual machines. Includes rotation management.",
        epilog=(
            "Examples:\n"
            "   # full backup of domain 'webvm' with all attached disks:\n"
            "\t%(prog)s -d webvm -l full -o /backup/\n"
            "   # incremental backup:\n"
            "\t%(prog)s -d webvm -l inc -o /backup/\n"
            "   # backup with rotation management:\n"
            "\t%(prog)s -rotation -d webvm -o /backup --maxdepth 7 --depthtokeep 3\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )

    opt = parser.add_argument_group("General options")
    opt.add_argument("-d", "--domain", required=False, type=str, help="Domain to backup")
    opt.add_argument(
        "-l",
        "--level",
        default="copy",
        choices=["copy", "full", "inc", "diff", "auto"],
        type=str,
        help="Backup level. (default: %(default)s)",
    )
    opt.add_argument(
        "-t",
        "--type",
        default="stream",
        type=str,
        choices=["stream", "raw"],
        help="Output type: stream or raw. (default: %(default)s)",
    )
    opt.add_argument(
        "-r",
        "--raw",
        default=False,
        action="store_true",
        help="Include full provisioned disk images in backup. (default: %(default)s)",
    )
    opt.add_argument(
        "-o", "--output", required=False, type=str, help="Output target directory, or 's3:/local/path' for S3 backup"
    )
    opt.add_argument(
        "-C",
        "--checkpointdir",
        required=False,
        default=None,
        type=str,
        help="Persistent libvirt checkpoint storage directory",
    )
    opt.add_argument(
        "--scratchdir",
        default="/var/tmp",
        required=False,
        type=str,
        help="Target dir for temporary scratch file. (default: %(default)s)",
    )
    opt.add_argument(
        "-S",
        "--start-domain",
        default=False,
        required=False,
        action="store_true",
        help="Start virtual machine if it is offline. (default: %(default)s)",
    )
    opt.add_argument(
        "--pause",
        default=False,
        required=False,
        action="store_true",
        help="Pause virtual machine while starting backup job. (default: %(default)s)",
    )
    opt.add_argument(
        "-i",
        "--include",
        default=None,
        type=str,
        help="Backup only disk with target dev name (-i vda)",
    )
    opt.add_argument(
        "-x",
        "--exclude",
        default=None,
        type=str,
        help="Exclude disk(s) with target dev name (-x vda,vdb)",
    )
    opt.add_argument(
        "-f",
        "--socketfile",
        default=f"/var/tmp/virtnbdbackup.{os.getpid()}",
        type=str,
        help="Use specified file for NBD Server socket (default: %(default)s)",
    )
    opt.add_argument(
        "-n",
        "--noprogress",
        default=False,
        help="Disable progress bar",
        action="store_true",
    )
    opt.add_argument(
        "--checkpoint-name",
        dest="checkpoint_name",
        help="Custom name for the checkpoint. default is timestamp",
    )
    opt.add_argument(
        "-z",
        "--compress",
        default=False,
        type=int,
        const=2,
        nargs="?",
        choices=range(1, 17),
        metavar="[1-16]",
        help="Compress with lz4 compression level. (default: %(default)s)",
        action="store",
    )
    opt.add_argument(
        "-w",
        "--worker",
        type=int,
        default=None,
        help=(
            "Amount of concurrent workers used "
            "to backup multiple disks. (default: amount of disks)"
        ),
    )
    opt.add_argument(
        "-F",
        "--freeze-mountpoint",
        type=str,
        default=None,
        help=(
            "If qemu agent available, freeze only filesystems on specified mountpoints within"
            " virtual machine (default: all)"
        ),
    )
    opt.add_argument(
        "-e",
        "--strict",
        default=False,
        help=(
            "Change exit code if warnings occur during backup operation. "
            "(default: %(default)s)"
        ),
        action="store_true",
    )
    opt.add_argument(
        "--no-sparse-detection",
        default=False,
        help=(
            "Skip detection of sparse ranges during incremental or differential backup. "
            "(default: %(default)s)"
        ),
        action="store_true",
    )
    opt.add_argument(
        "-T",
        "--threshold",
        type=int,
        default=None,
        help=("Execute backup only if threshold is reached."),
    )
    opt.add_argument(
        "--encryption-key-fingerprint",
        type=str,
        default=None,
        help=("Specify the fingerprint of the encryption key.")
    )
    opt.add_argument(
        "--encryption-key-path",
        type=str,
        default=None,
        help=("Provide the file path to the encryption key. "
              "This is the location on your system where the encryption key is stored.")
    )
    opt.add_argument(
        "--hash",
        type=str,
        default="adler32",
        choices=["adler32", "xxh64"],
        help=("Specify the hashing algorithm to generate checksums")
    )

    rotopt = parser.add_argument_group("Rotation options")
    rotopt.add_argument(
        "-rotation",
        action="store_true",
        help="Enable backup rotation management mode."
    )
    rotopt.add_argument(
        "--maxdepth",
        type=int,
        default=7,
        help="[Rotation] Maximum number of backups in a single chain before archiving (default: %(default)s)."
    )
    rotopt.add_argument(
        "--depthtokeep",
        type=int,
        default=3,
        help="[Rotation] Number of archive chains to retain (default: %(default)s)."
    )
    rotopt.add_argument(
        "--make-synthetic",
        action="store_true",
        help="[Rotation] Use synthetic backups to merge chains."
    )
    rotopt.add_argument(
        "--inc-cleanup",
        action="store_true",
        help="[Rotation] Clear incremental backups (*.inc.data) from the archive after creating synthetics."
    )

    s3opt = parser.add_argument_group("S3 Storage options")
    s3opt.add_argument("--s3-endpoint-url", type=str, default=None, help="S3 endpoint URL, e.g., http://192.168.120.50:9000")
    s3opt.add_argument("--s3-access-key-file", type=str, default=None, help="Path to file containing S3 access key")
    s3opt.add_argument("--s3-secret-key-file", type=str, default=None, help="Path to file containing S3 secret key")
    # A store_true flag whose default is already True can never be switched
    # off, so --s3-keep-local below writes the same destination with
    # store_false. Both actions default to True, so removal stays the default.
    s3opt.add_argument(
        "--s3-remove-local-on-success", action="store_true", default=True,
        help="Remove local backup copy after successful upload to S3. (default: %(default)s)",
    )
    s3opt.add_argument(
        "--s3-keep-local",
        dest="s3_remove_local_on_success",
        action="store_false",
        help=(
            "Keep the local backup copy after a successful upload to S3 "
            "(the default checkpoint directory lives inside it)."
        ),
    )

    remopt = parser.add_argument_group("Remote Backup options")
    argopt.addRemoteArgs(remopt)
    logopt = parser.add_argument_group("Logging options")
    logopt.add_argument(
        "-L",
        "--syslog",
        default=False,
        action="store_true",
        help="Additionally send log messages to syslog (default: %(default)s)",
    )
    logopt.add_argument(
        "--quiet",
        default=False,
        action="store_true",
        help="Disable logging to stderr (default: %(default)s)",
    )
    argopt.addLogColorArgs(logopt)
    debopt = parser.add_argument_group("Debug options")
    debopt.add_argument(
        "-q",
        "--qemu",
        default=False,
        action="store_true",
        help="Use Qemu tools to query extents.",
    )
    debopt.add_argument(
        "-s",
        "--startonly",
        default=False,
        help="Only initialize backup job via libvirt, do not backup any data",
        action="store_true",
    )
    debopt.add_argument(
        "-k",
        "--killonly",
        default=False,
        help="Kill any running block job",
        action="store_true",
    )
    debopt.add_argument(
        "-p",
        "--printonly",
        default=False,
        help="Quit after printing estimated checkpoint size.",
        action="store_true",
    )
    argopt.addDebugArgs(debopt)
    return parser


def _finalize_s3_backup(args, counter, vm_uuid) -> bool:
    """Upload the locally staged backup to S3 and optionally remove it.

    Args:
        args: parsed command line namespace; ``args.output`` is the local
            staging directory and ``args.domain`` the domain name.
        counter: log counter; the upload is skipped if any error was logged.
        vm_uuid: domain UUID, combined with the domain name into the bucket
            name.

    Returns:
        True when the upload succeeded, False otherwise. On failure the
        local staging directory is left in place.
    """
    if counter.count.errors != 0:
        logging.warning("Errors occurred during local backup. Skipping S3 upload.")
        logging.warning("Local backup data is located at: %s", args.output)
        return False

    s3_client = _get_s3_client(args)
    if not s3_client:
        # Defensive: _get_s3_client exits on failure, but never report an
        # upload as successful without a client.
        logging.error("S3 client could not be created.")
        return False

    bucket_name = f"{args.domain}-{vm_uuid}"
    logging.info("Starting upload of local backup from '%s' to S3.", args.output)

    if not _upload_to_s3(s3_client, args.output, bucket_name, args.domain):
        logging.error("S3 upload failed. The local backup is preserved for manual inspection.")
        logging.error("Local backup path: %s", args.output)
        return False

    logging.info("Backup successfully uploaded to S3 bucket: %s", bucket_name)
    if args.s3_remove_local_on_success:
        try:
            logging.info("Removing local backup copy at: %s", args.output)
            shutil.rmtree(args.output)
        except OSError as e:
            # Logged as an error so the log counter makes the run fail.
            logging.error("Failed to clean up local backup directory %s: %s", args.output, e)
    else:
        logging.info("Local backup copy preserved at: %s", args.output)
    return True


def _cleanup_socketfile(args) -> None:
    """Remove a leftover NBD socket file at ``args.socketfile``, if present.

    Only the exact path is considered. The value is user supplied
    (``-f``), so expanding it as a glob pattern could delete unrelated files.
    Removal is best effort: failures are logged as warnings.
    """
    socket_path = getattr(args, "socketfile", None)
    if not socket_path or not os.path.lexists(socket_path):
        return
    logging.info("Performing cleanup of temp files...")
    try:
        logging.info("Removing temp file: %s", socket_path)
        os.unlink(socket_path)
    except OSError as e:
        logging.warning("Could not remove temp file %s: %s", socket_path, e)


def main() -> None:
    """Handle backup operation and settings.

    Parses the command line, then either delegates to rotation management
    (``-rotation``) or runs a single backup:

    1. connect to libvirt and validate the domain and its disks;
    2. create the checkpoint and start the backup job;
    3. save all disks concurrently and write the metadata files;
    4. when ``-o s3:<path>`` is used, upload the locally staged backup.

    Exit status: 1 on errors; 2 when ``--strict`` is set and warnings were
    logged; 0 for the early-exit debug/threshold modes or a normal return.
    The NBD socket file is cleaned up on every exit path.
    """
    parser = _build_parser()
    repository = output.target()
    args = lib.argparse(parser)

    try:

        if args.rotation:
            rotation.handle_rotation(args)
            sys.exit(0)

        if not args.domain or not args.output:
            # Message: "-d/--domain and -o/--output are required for standard backup mode."
            parser.error("аргументы -d/--domain и -o/--output обязательны для стандартного режима бэкапа.")

        s3_mode = False
        if args.output.startswith("s3:"):
            s3_mode = True
            local_staging_path = args.output[3:]

            if not local_staging_path:
                sys.stderr.write("ERROR: For S3 backup, a local path must be specified after 's3:'. Example: 's3:/path/to/backups'\n")
                sys.exit(1)

            # From here on the backup is written to the local staging path.
            args.output = local_staging_path

            if not all([args.s3_endpoint_url, args.s3_access_key_file, args.s3_secret_key_file]):
                sys.stderr.write("ERROR: For S3 backup, --s3-endpoint-url, --s3-access-key-file, and --s3-secret-key-file are required.\n")
                sys.exit(1)

            # The key files are read only after the backup has finished;
            # check they exist now so a typo does not surface hours later.
            for key_file in (args.s3_access_key_file, args.s3_secret_key_file):
                if not os.path.isfile(key_file):
                    sys.stderr.write(f"ERROR: S3 key file not found: {key_file}\n")
                    sys.exit(1)

        lib.setThreadName()
        args.stdout = args.output == "-"
        args.sshClient = None
        args.diskInfo = []
        args.offline = False
        vm_uuid = None

        if args.quiet is True:
            args.noprogress = True

        fileStream = stream.get(args, repository)

        try:
            if not args.stdout:
                fileStream.create(args.output)
        except OutputException as e:
            sys.stderr.write(f"ERROR: Can't open output file: [{e}]\n")
            sys.exit(1)

        if args.worker is not None and args.worker < 1:
            args.worker = 1

        now = datetime.now().strftime("%m%d%Y%H%M%S")
        logFile = f"{args.output}/backup.{args.level}.{now}.log"
        fileLog = lib.getLogFile(logFile) or sys.exit(1)

        counter = logCount()  # pylint: disable=unreachable
        lib.configLogger(args, fileLog, counter)
        lib.printVersion(__version__)

        if s3_mode:
            logging.info("S3 backup mode activated. Staging backup locally in: %s", args.output)

        # Encryption options switch compression on (level 1) when it was not
        # requested explicitly; presumably the encrypted output format
        # requires compressed streams - TODO confirm.
        if args.compress is False and (args.encryption_key_fingerprint or args.encryption_key_path):
            args.compress = 1
        logging.info("Backup level: [%s]", args.level)
        if args.compress is not False:
            logging.info("Compression enabled, level [%s]", args.compress)

        try:
            check.arguments(args)
        except exceptions.BackupException as e:
            logging.error(e)
            sys.exit(1)

        if args.user:
            if not args.password:
                logging.error("Argument --user specified ('%s'), but --password is missing.", args.user)
                logging.error("Password authentication is required when a specific user is requested.")
                sys.exit(1)
            try:
                pwd.getpwnam(args.user)
            except KeyError:
                logging.error("User specified with --user '%s' does not exist on this system.", args.user)
                sys.exit(1)

            auth_creds = [
                [libvirt.VIR_CRED_AUTHNAME, libvirt.VIR_CRED_PASSPHRASE],
                auth_callback,
                {'user': args.user, 'password': args.password}
            ]
            logging.info("Authenticating against libvirt as user '%s'...", args.user)
            try:
                # Open (and immediately close) a connection with openAuth so
                # libvirt verifies the supplied credentials up front.
                conn_test = libvirt.openAuth(args.uri, auth_creds, 0)
                if conn_test:
                    conn_test.close()
            except libvirt.libvirtError as e:
                logging.error("Authentication failed for user '%s': %s", args.user, e)
                sys.exit(1)
            logging.info("Authentication successful.")

        if partialfile.exists(args):
            sys.exit(1)

        if not args.checkpointdir:
            args.checkpointdir = f"{args.output}/checkpoints"
        else:
            logging.info("Store checkpoints in: [%s]", args.checkpointdir)

        fileStream.create(args.checkpointdir)

        def connectionError(_, reason, args):
            """Callback if the libvirt connection drops mid
            data transfer, used to potentially cleanup the
            leftover backup job.
            """
            virConnectCloseReason = (
                "Misc I/O error",
                "End-of-file from server",
                "Keepalive timer triggered",
                "Client side connection close",
                "Unknown",
            )
            logging.error(
                "Libvirt connection error [%s], trying to reconnect",
                virConnectCloseReason[reason],
            )
            try:
                virtClient = virt.client(args)
            except connectionFailed as e:
                logging.error("Unrecoverable connection error: %s", e)
                sys.exit(1)
            domObj = virtClient.getDomain(args.domain)
            if not args.offline:
                logging.error("Attempting to stop backup task")
                virtClient.stopBackup(domObj)
            sys.exit(1)

        try:
            virtClient = virt.client(args)
            domObj = virtClient.getDomain(args.domain)
            vm_uuid = domObj.UUIDString()
        except domainNotFound as e:
            logging.error("%s", e)
            sys.exit(1)
        except connectionFailed as e:
            logging.error("Can't connect libvirt daemon: [%s]", e)
            sys.exit(1)

        virtClient._conn.registerCloseCallback(  # pylint: disable=W0212
            connectionError, args
        )

        logging.info("Libvirt library version: [%s]", virtClient.libvirtVersion)
        logging.info("NBD library version: [%s]", __nbdversion__)

        try:
            check.vmfeature(virtClient, domObj)
            checkpoint.checkForeign(args, domObj)
            check.vmstate(args, virtClient, domObj)
            check.targetDir(args)
        except exceptions.BackupException as e:
            logging.error(e)
            sys.exit(1)
        except exceptions.CheckpointException:
            sys.exit(1)

        if args.raw is True and args.level in ("inc", "diff"):
            logging.warning(
                "Raw disks can't be included during incremental or differential backup."
            )
            logging.warning("Excluding raw disks.")
            args.raw = False

        signal.signal(
            signal.SIGINT,
            partial(sighandle.Backup.catch, args, domObj, virtClient, logging),
        )

        # Sparse detection can only be skipped for inc/diff backups.
        if args.level not in ("inc", "diff") and args.no_sparse_detection is True:
            args.no_sparse_detection = False

        vmConfig = virtClient.getDomainConfig(domObj)
        disks: List[DomainDisk] = virtClient.getDomainDisks(args, vmConfig)
        args.info = virtClient.getDomainInfo(vmConfig)
        if virtClient.getTPMDevice(vmConfig):
            logging.warning("Emulated TPM device attached: User action required.")
            logging.warning(
                "Please manually backup contents of: [/var/lib/libvirt/swtpm/%s/]",
                domObj.UUIDString(),
            )

        try:
            check.diskformat(args, disks)
        except exceptions.BackupException as e:
            logging.info(e)

        if not disks:
            logging.error("Unable to detect disks suitable for backup.")
            metadata.saveFiles(args, vmConfig, disks, fileStream, logFile)
            sys.exit(1)

        try:
            check.blockjobs(args, virtClient, domObj, disks)
        except exceptions.BackupException as e:
            logging.error(e)
            sys.exit(1)

        logging.info(
            "Backup will save [%s] attached disks.",
            len(disks),
        )
        if args.worker is None or args.worker > int(len(disks)):
            args.worker = int(len(disks))
        logging.info("Concurrent backup processes: [%s]", args.worker)

        if args.killonly is True:
            logging.info("Stopping backup job")
            if not virtClient.stopBackup(domObj):
                sys.exit(1)
            sys.exit(0)

        # Validate --threshold once, before any checkpoint handling and
        # before a backup job exists (so there is nothing to stop yet).
        if args.threshold is not None and args.threshold < 1:
            logging.error(
                "Error: --threshold value cannot be less than 1. Aborting operation."
            )
            sys.exit(1)

        try:
            checkpoint.create(args, domObj)
        except exceptions.CheckpointException as errmsg:
            logging.error(errmsg)
            sys.exit(1)

        if args.printonly and args.cpt.parent and not args.offline:
            size = checkpoint.getSize(domObj, args.cpt.parent)
            logging.info("Estimated checkpoint backup size: [%s] Bytes", size)
            sys.exit(0)

        # Compared with checkpoint.getSize(), which is logged in bytes above.
        if args.threshold and args.cpt.parent and not args.offline:
            size = checkpoint.getSize(domObj, args.cpt.parent)
            if size < args.threshold:
                logging.info(
                    "Backup size [%s] does not meet required threshold [%s], skipping backup.",
                    size,
                    args.threshold,
                )
                sys.exit(0)

        if virtClient.remoteHost != "":
            args.sshClient = lib.sshSession(args, virtClient.remoteHost)
            if not args.sshClient:
                logging.error("Remote backup detected but ssh session setup failed")
                sys.exit(1)
            logging.info(
                "Remote NBD Endpoint host: [%s]",
                virtClient.remoteHost,
            )
            if args.offline is True:
                logging.info(
                    "Remote ports used for backup: [%s-%s]",
                    args.nbd_port,
                    args.nbd_port + args.worker,
                )
        else:
            logging.info("Local NBD Endpoint sockets:")
            for sdisk in disks:
                logging.info(
                    "%s: [nbd+unix:///%s?socket=%s]",
                    sdisk.target,
                    sdisk.target,
                    args.socketfile,
                )

        if args.offline is not True:
            logging.info("Temporary scratch file target directory: [%s]", args.scratchdir)
            fileStream.create(args.scratchdir)
            if not job.start(args, virtClient, domObj, disks):
                sys.exit(1)

        if args.level not in ("copy", "diff") and args.offline is False:
            logging.info("Started backup job with checkpoint, saving information.")
            try:
                checkpoint.save(args)
            except exceptions.CheckpointException as e:
                logging.error("Extending checkpoint file failed: [%s]", e)
                sys.exit(1)
            if not checkpoint.backup(args, domObj):
                virtClient.stopBackup(domObj)
                sys.exit(1)

        if args.startonly is True:
            logging.info("Started backup job for debugging, exiting.")
            sys.exit(0)

        backupSize: int = 0
        try:
            with ThreadPoolExecutor(max_workers=args.worker) as executor:
                futures = {
                    executor.submit(
                        disk.backup, args, Disk, count, fileStream, virtClient
                    ): Disk
                    for count, Disk in enumerate(disks)
                }
                for future in as_completed(futures):
                    size, state = future.result()
                    backupSize += size
                    if state is not True:
                        raise exceptions.DiskBackupFailed("Backup of one disk failed")
        except exceptions.BackupException as e:
            logging.error("Disk backup failed: [%s]", e)
        except sshError as e:
            logging.error("Remote Disk backup failed: [%s]", e)
        except Exception as e:  # pylint: disable=broad-except
            logging.critical("Unknown Exception during backup: %s", e)
            logging.exception(e)

        if args.offline is False:
            logging.info("Backup jobs finished, stopping backup task.")
            virtClient.stopBackup(domObj)

        metadata.saveFiles(args, vmConfig, disks, fileStream, logFile)

        # Query the domain's autostart flag while the libvirt connection it
        # was obtained from is still open, then close the connection.
        if domObj.autostart() == 1:
            metadata.backupAutoStart(args)

        virtClient.close()

        s3_upload_success = True
        if s3_mode:
            s3_upload_success = _finalize_s3_backup(args, counter, vm_uuid)

        if counter.count.errors > 0 or not s3_upload_success:
            logging.error("Error during backup")
            sys.exit(1)

        if args.sshClient:
            args.sshClient.disconnect()

        if counter.count.warnings > 0 and args.strict is True:
            logging.info(
                "[%s] Warnings detected during backup operation, forcing exit code 2",
                counter.count.warnings,
            )
            sys.exit(2)

        logging.info("Total saved disk data: [%s]", lib.humanize(backupSize))
        logging.info("Finished successfully")

    finally:
        # Runs for every exit path, including sys.exit() above.
        _cleanup_socketfile(args)

if __name__ == "__main__":
    # Script entry point; main() returns None on success -> exit status 0.
    sys.exit(main())