#!/usr/bin/python3
# pylint: disable=invalid-name

"""
A JSON-based interface for depsolving using DNF.

Reads a request through stdin and prints the result to stdout.
In case of error, a structured error is printed to stdout as well.
"""
import json
import os
import sys
import tempfile
from datetime import datetime, timezone

import dnf
import hawkey


class Solver():

    # pylint: disable=too-many-arguments
    def __init__(self, repos, module_platform_id, persistdir, cachedir, arch):
        self.base = dnf.Base()

        # Enable fastestmirror to ensure we choose the fastest mirrors for
        # downloading metadata (when depsolving) and downloading packages.
        self.base.conf.fastestmirror = True

        # We use the same cachedir for multiple architectures. Unfortunately,
        # this is something that doesn't work well in certain situations
        # with zchunk:
        # Imagine that we already have cache for arch1. Then, we use dnf-json
        # to depsolve for arch2. If ZChunk is enabled and available (that's
        # the case for Fedora), dnf will try to download only differences
        # between arch1 and arch2 metadata. But, as these are completely
        # different, dnf must basically redownload everything.
        # For downloding deltas, zchunk uses HTTP range requests. Unfortunately,
        # if the mirror doesn't support multi range requests, then zchunk will
        # download one small segment per a request. Because we need to update
        # the whole metadata (10s of MB), this can be extremely slow in some cases.
        # I think that we can come up with a better fix but let's just disable
        # zchunk for now. As we are already downloading a lot of data when
        # building images, I don't care if we download even more.
        self.base.conf.zchunk = False

        # Set the rest of the dnf configuration.
        self.base.conf.module_platform_id = module_platform_id
        self.base.conf.config_file_path = "/dev/null"
        self.base.conf.persistdir = persistdir
        self.base.conf.cachedir = cachedir
        self.base.conf.substitutions['arch'] = arch
        self.base.conf.substitutions['basearch'] = dnf.rpm.basearch(arch)

        for repo in repos:
            self.base.repos.add(self._dnfrepo(repo, self.base.conf))
        self.base.fill_sack(load_system_repo=False)

    # pylint: disable=too-many-branches
    @staticmethod
    def _dnfrepo(desc, parent_conf=None):
        """Makes a dnf.repo.Repo out of a JSON repository description"""

        repo = dnf.repo.Repo(desc["id"], parent_conf)

        if "name" in desc:
            repo.name = desc["name"]
        if "baseurl" in desc:
            repo.baseurl = desc["baseurl"]
        elif "metalink" in desc:
            repo.metalink = desc["metalink"]
        elif "mirrorlist" in desc:
            repo.mirrorlist = desc["mirrorlist"]
        else:
            assert False

        if desc.get("ignoressl", False):
            repo.sslverify = False
        if "sslcacert" in desc:
            repo.sslcacert = desc["sslcacert"]
        if "sslclientkey" in desc:
            repo.sslclientkey = desc["sslclientkey"]
        if "sslclientcert" in desc:
            repo.sslclientcert = desc["sslclientcert"]

        if "check_gpg" in desc:
            repo.gpgcheck = desc["check_gpg"]
        if "check_repogpg" in desc:
            repo.repo_gpgcheck = desc["check_repogpg"]
        if "gpgkey" in desc:
            repo.gpgkey = [desc["gpgkey"]]
        if "gpgkeys" in desc:
            # gpgkeys can contain a full key, or it can be a URL
            # dnf expects urls, so write the key to a temporary location and add the file://
            # path to repo.gpgkey
            keydir = os.path.join(parent_conf.persistdir, "gpgkeys")
            if not os.path.exists(keydir):
                os.makedirs(keydir, mode=0o700, exist_ok=True)

            for key in desc["gpgkeys"]:
                if key.startswith("-----BEGIN PGP PUBLIC KEY BLOCK-----"):
                    # Not using with because it needs to be a valid file for the duration. It
                    # is inside the temporary persistdir so will be cleaned up on exit.
                    # pylint: disable=consider-using-with
                    keyfile = tempfile.NamedTemporaryFile(dir=keydir, delete=False)
                    keyfile.write(key.encode("utf-8"))
                    repo.gpgkey.append(f"file://{keyfile.name}")
                    keyfile.close()
                else:
                    repo.gpgkey.append(key)

        # In dnf, the default metadata expiration time is 48 hours. However,
        # some repositories never expire the metadata, and others expire it much
        # sooner than that. We therefore allow this to be configured. If nothing
        # is provided we error on the side of checking if we should invalidate
        # the cache. If cache invalidation is not necessary, the overhead of
        # checking is in the hundreds of milliseconds. In order to avoid this
        # overhead accumulating for API calls that consist of several dnf calls,
        # we set the expiration to a short time period, rather than 0.
        repo.metadata_expire = desc.get("metadata_expire", "20s")

        # This option if True disables modularization filtering. Effectively
        # disabling modularity for given repository.
        if "module_hotfixes" in desc:
            repo.module_hotfixes = desc["module_hotfixes"]

        return repo

    @staticmethod
    def _timestamp_to_rfc3339(timestamp):
        return datetime.utcfromtimestamp(timestamp).strftime('%Y-%m-%dT%H:%M:%SZ')

    def dump(self):
        packages = []
        for package in self.base.sack.query().available():
            packages.append({
                "name": package.name,
                "summary": package.summary,
                "description": package.description,
                "url": package.url,
                "repo_id": package.repoid,
                "epoch": package.epoch,
                "version": package.version,
                "release": package.release,
                "arch": package.arch,
                "buildtime": self._timestamp_to_rfc3339(package.buildtime),
                "license": package.license
            })
        return packages

    def search(self, args):
        """ Perform a search on the available packages

        args contains a "search" dict with parameters to use for searching.
        "packages" list of package name globs to search for
        "latest" is a boolean that will return only the latest NEVRA instead
        of all matching builds in the metadata.

        eg.

            "search": {
                "latest": false,
                "packages": ["tmux", "vim*", "*ssh*"]
            },
        """
        pkg_globs = args.get("packages", [])

        packages = []

        # NOTE: Build query one piece at a time, don't pass all to filterm at the same
        # time.
        available = self.base.sack.query().available()
        for name in pkg_globs:
            # If the package name glob has * in it, use glob.
            # If it has *name* use substr
            # If it has neither use exact match
            if "*" in name:
                if name[0] != "*" or name[-1] != "*":
                    q = available.filter(name__glob=name)
                else:
                    q = available.filter(name__substr=name.replace("*", ""))
            else:
                q = available.filter(name__eq=name)

            if args.get("latest", False):
                q = q.latest()

            for package in q:
                packages.append({
                    "name": package.name,
                    "summary": package.summary,
                    "description": package.description,
                    "url": package.url,
                    "repo_id": package.repoid,
                    "epoch": package.epoch,
                    "version": package.version,
                    "release": package.release,
                    "arch": package.arch,
                    "buildtime": self._timestamp_to_rfc3339(package.buildtime),
                    "license": package.license
                })
        return packages

    def depsolve(self, transactions):
        last_transaction = []

        for idx, transaction in enumerate(transactions):
            self.base.reset(goal=True)
            self.base.sack.reset_excludes()

            # don't install weak-deps for transactions after the 1st transaction
            if idx > 0:
                self.base.conf.install_weak_deps = False

            # set the packages from the last transaction as installed
            for installed_pkg in last_transaction:
                self.base.package_install(installed_pkg, strict=True)

            # depsolve the current transaction
            self.base.install_specs(
                transaction.get("package-specs"),
                transaction.get("exclude-specs"),
                reponame=transaction.get("repo-ids"),
            )
            self.base.resolve()

            # store the current transaction result
            last_transaction.clear()
            for tsi in self.base.transaction:
                # Avoid using the install_set() helper, as it does not guarantee
                # a stable order
                if tsi.action not in dnf.transaction.FORWARD_ACTIONS:
                    continue
                last_transaction.append(tsi.pkg)

        dependencies = []
        for package in last_transaction:
            dependencies.append({
                "name": package.name,
                "epoch": package.epoch,
                "version": package.version,
                "release": package.release,
                "arch": package.arch,
                "repo_id": package.repoid,
                "path": package.relativepath,
                "remote_location": package.remote_location(),
                "checksum": (
                    f"{hawkey.chksum_name(package.chksum[0])}:"
                    f"{package.chksum[1].hex()}"
                )
            })

        return dependencies


def setup_cachedir(request):
    """Pick the dnf cache directory for a request.

    When the OVERWRITE_CACHE_DIR environment variable is set to a non-empty
    value, the cache lives in a per-architecture subdirectory of it and the
    request's own "cachedir" is ignored. Otherwise the request's "cachedir"
    is used.

    Returns a (cache_dir, error) pair: error is None on success, and
    cache_dir is "" together with an error dict when no directory is set.
    """
    arch = request["arch"]

    # When dnf-json runs as a service the operator pins the cache location
    # through the environment, so that users cannot choose it themselves.
    override = os.environ.get("OVERWRITE_CACHE_DIR", "")
    chosen = os.path.join(override, arch) if override else request.get("cachedir", "")

    if chosen:
        return chosen, None
    return "", {"kind": "Error", "reason": "No cache dir set"}


def solve(request, cache_dir):
    """Run the requested command and return a (result, error) pair.

    request must contain "command", "arch", "module_platform_id" and
    "arguments" (with "repos", plus "transactions" for depsolve and
    optionally "search" for search). cache_dir is used as the dnf cache
    directory; the dnf persistdir is a temporary directory that is removed
    again before returning.

    On success error is None. On failure result is None and error is a dict
    with "kind" and "reason". Only dnf errors and unknown commands are
    reported that way; other exceptions propagate.
    """
    command = request["command"]
    arch = request["arch"]
    module_platform_id = request["module_platform_id"]
    arguments = request["arguments"]

    # Reject unknown commands up front: falling through the dispatch below
    # would leave `result` unbound, and that only after the expensive
    # repository setup.
    known_cmds = ("depsolve", "dump", "search")
    if command not in known_cmds:
        return None, {
            "kind": "InvalidRequest",
            "reason": f"invalid command '{command}': must be one of {', '.join(known_cmds)}"
        }

    transactions = arguments.get("transactions")
    with tempfile.TemporaryDirectory() as persistdir:
        try:
            solver = Solver(
                arguments["repos"],
                module_platform_id,
                persistdir,
                cache_dir,
                arch
            )
            if command == "dump":
                result = solver.dump()
            elif command == "depsolve":
                result = solver.depsolve(transactions)
            else:
                result = solver.search(arguments.get("search", {}))

        except dnf.exceptions.MarkingErrors as e:
            printe("error install_specs")
            return None, {
                "kind": "MarkingErrors",
                "reason": f"Error occurred when marking packages for installation: {e}"
            }
        except dnf.exceptions.DepsolveError as e:
            printe("error depsolve")
            # collect list of packages for error; be lenient about missing
            # keys so that building the message cannot mask the real error
            pkgs = []
            for t in transactions or []:
                pkgs.extend(t.get("package-specs") or [])
            return None, {
                "kind": "DepsolveError",
                "reason": f"There was a problem depsolving {', '.join(pkgs)}: {e}"
            }
        except dnf.exceptions.RepoError as e:
            return None, {
                "kind": "RepoError",
                "reason": f"There was a problem reading a repository: {e}"
            }
        except dnf.exceptions.Error as e:
            printe("error repository setup")
            return None, {
                "kind": type(e).__name__,
                "reason": str(e)
            }
    return result, None


def printe(*msg):
    """Print the given values to stderr, exactly like print() does to stdout."""
    stream = sys.stderr
    print(*msg, file=stream)


def fail(err):
    """Report an error dict and terminate the process with exit status 1.

    A "kind: reason" line goes to stderr, the JSON-encoded error to stdout.
    """
    kind = err['kind']
    reason = err['reason']
    printe(f"{kind}: {reason}")

    encoded = json.dumps(err)
    print(encoded)
    sys.exit(1)


def respond(result):
    """Write the JSON encoding of result to stdout, followed by a newline."""
    encoded = json.dumps(result)
    print(encoded)


def validate_request(request):
    """Check that a decoded request has the fields the solver needs.

    Verifies that the request is an object with a known "command", a
    non-empty "arch" and "module_platform_id", and an "arguments" object
    with non-empty "repos". A depsolve request must additionally carry a
    "transactions" list (it may be empty).

    Returns None if the request is acceptable, otherwise an error dict with
    "kind" set to "InvalidRequest" and a human-readable "reason".
    """
    # The request is decoded from untrusted JSON, so it may be any JSON value.
    if not isinstance(request, dict):
        return {
            "kind": "InvalidRequest",
            "reason": "request must be a JSON object"
        }

    command = request.get("command")
    valid_cmds = ("depsolve", "dump", "search")
    if command not in valid_cmds:
        return {
            "kind": "InvalidRequest",
            "reason": f"invalid command '{command}': must be one of {', '.join(valid_cmds)}"
        }

    if not request.get("arch"):
        return {
            "kind": "InvalidRequest",
            "reason": "no 'arch' specified"
        }

    if not request.get("module_platform_id"):
        return {
            "kind": "InvalidRequest",
            "reason": "no 'module_platform_id' specified"
        }
    arguments = request.get("arguments")
    if not arguments:
        return {
            "kind": "InvalidRequest",
            "reason": "empty 'arguments'"
        }
    if not isinstance(arguments, dict):
        return {
            "kind": "InvalidRequest",
            "reason": "'arguments' must be an object"
        }

    if not arguments.get("repos"):
        return {
            "kind": "InvalidRequest",
            "reason": "no 'repos' specified"
        }

    # depsolve iterates over the transactions, so without a list it would
    # fail with a TypeError instead of a structured error.
    if command == "depsolve" and not isinstance(arguments.get("transactions"), list):
        return {
            "kind": "InvalidRequest",
            "reason": "no 'transactions' specified"
        }

    return None


def main():
    """Read a JSON request from stdin, run it and print the JSON answer.

    The result is printed to stdout on success. Any reported error is
    printed as a structured JSON error to stdout (and summarized on stderr)
    and the process exits with status 1.
    """
    try:
        request = json.load(sys.stdin)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # Keep the promise of a structured error even when the request
        # itself cannot be decoded. fail() does not return.
        fail({
            "kind": "InvalidRequest",
            "reason": f"request is not valid JSON: {e}"
        })

    err = validate_request(request)
    if err:
        fail(err)

    cachedir, err = setup_cachedir(request)
    if err:
        fail(err)
    result, err = solve(request, cachedir)
    if err:
        fail(err)
    else:
        respond(result)


# Serve a request only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
