#!/usr/bin/env python3

# Copyright (C) 2020 Matěj Týč <matyc@redhat.com>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the
# Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
# Boston, MA 02110-1301 USA

import argparse
import collections
import datetime
import re
import sys
import pathlib
import xml.etree.ElementTree as ET
import xml.dom.minidom
import json


NS = "http://checklists.nist.gov/xccdf/1.2"
DEFAULT_PROFILE_SUFFIX = "_customized"
DEFAULT_REVERSE_DNS = "org.ssgproject.content"
ROLES = ["full", "unscored", "unchecked"]
SEVERITIES = ["unknown", "info", "low", "medium", "high"]
ATTRIBUTES = ["role", "severity"]


def quote(string):
    """Return ``string`` wrapped in a pair of double quotes."""
    double_quote = '"'
    return "".join((double_quote, string, double_quote))


def assignment_to_tuple(assignment):
    """Split a ``<name>=<value>`` assignment into a ``(name, value)`` tuple.

    Only the first ``=`` separates the name from the value, so the value
    itself may contain further ``=`` characters.

    Raises:
        ValueError: if ``assignment`` contains no ``=`` at all.
    """
    varname, separator, value = assignment.partition("=")
    if not separator:
        # Report the offending input instead of an opaque unpacking error.
        msg = (f"Invalid assignment '{assignment}': "
               f"expected the form <name>=<value>.")
        raise ValueError(msg)
    return varname, value


def is_valid_xccdf_id(string):
    """Tell whether ``string`` starts like a full XCCDF ID.

    A full ID has the form ``xccdf_<namespace>_<type>_<name>``, where the
    type is one of benchmark, profile, rule, group, value, testresult or
    tailoring, and the name is not empty.
    """
    id_pattern = (
        r"^xccdf_[a-zA-Z0-9.-]+_(benchmark|profile|rule|group|value|"
        r"testresult|tailoring)_.+"
    )
    if re.match(id_pattern, string):
        return True
    return False


class Tailoring:
    """Collect profile customizations and write them as an XCCDF 1.2 tailoring.

    Rule and group selections, value overrides and rule/value refinements
    are accumulated on the instance and serialized by ``to_xml``. IDs that
    are not full XCCDF IDs are expanded using ``reverse_dns``.
    """

    def __init__(self):
        self.id = "xccdf_auto_tailoring_default"
        self.reverse_dns = DEFAULT_REVERSE_DNS
        self.version = 1
        # None means "derive the ID from the extended profile", see profile_id.
        self._profile_id = None
        self.extends = ""
        self.original_ds_filename = ""
        self.profile_title = ""

        self.value_changes = []
        self.rules_to_select = []
        self.rules_to_unselect = []
        self.groups_to_select = []
        self.groups_to_unselect = []
        # Both map a full ID to a dict of {attribute: refined value}.
        self._rule_refinements = collections.defaultdict(dict)
        self._value_refinements = collections.defaultdict(dict)

    @property
    def profile_id(self):
        """ID of the tailored profile.

        Unless an ID was set explicitly, it is the ID of the extended
        profile with ``DEFAULT_PROFILE_SUFFIX`` appended.
        """
        if self._profile_id is not None:
            return self._profile_id
        return self.extends + DEFAULT_PROFILE_SUFFIX

    @profile_id.setter
    def profile_id(self, new_profile_id):
        self._profile_id = new_profile_id

    def rule_refinements(self, rule_id, attribute):
        """Return the refined value of ``attribute`` for ``rule_id``.

        Raises:
            KeyError: if that attribute of the rule has not been refined.
        """
        # Use .get() so that a lookup of an unknown rule does not insert an
        # empty entry into the defaultdict, which to_xml() would otherwise
        # serialize as an empty refine-rule element.
        return self._rule_refinements.get(rule_id, {})[attribute]

    @staticmethod
    def _find_rule_enumeration(attribute):
        """Return the list of values allowed for a refine-rule attribute.

        Raises:
            ValueError: if the attribute is neither "role" nor "severity".
        """
        if attribute == "role":
            enumeration = ROLES
        elif attribute == "severity":
            enumeration = SEVERITIES
        else:
            msg = f"Unsupported refine-rule attribute {attribute}"
            raise ValueError(msg)
        return enumeration

    @staticmethod
    def _validate_rule_refinement_params(rule_id, attribute, value):
        """Raise ValueError unless the rule refinement is well-formed.

        The rule ID has to be a full XCCDF ID, the attribute has to be
        supported and the value has to be allowed for that attribute.
        """
        if not is_valid_xccdf_id(rule_id):
            msg = f"Rule id '{rule_id}' is invalid!"
            raise ValueError(msg)
        enumeration = Tailoring._find_rule_enumeration(attribute)
        if value in enumeration:
            return
        allowed = ", ".join(map(quote, enumeration))
        msg = (f"Can't refine {attribute} of rule '{rule_id}' to '{value}'. "
               f"Allowed {attribute} values are: {allowed}.")
        raise ValueError(msg)

    @staticmethod
    def _validate_value_refinement_params(value_id, attribute, value):
        """Raise ValueError unless the value refinement is well-formed.

        The value ID has to be a full XCCDF ID and the only supported
        attribute is "selector".
        """
        if not is_valid_xccdf_id(value_id):
            msg = f"Value id '{value_id}' is invalid!"
            raise ValueError(msg)
        if attribute == 'selector':
            return
        msg = (f"Can't refine {attribute} of value '{value_id}' to '{value}'. "
               f"Unsupported refine-value attribute {attribute}.")
        raise ValueError(msg)

    def _prevent_duplicate_rule_refinement(self, attribute, rule_id, value):
        """Raise ValueError if ``attribute`` of the rule is already refined."""
        # .get() keeps this check free of side effects on the defaultdict.
        refinements = self._rule_refinements.get(rule_id, {})
        if attribute not in refinements:
            return
        current = refinements[attribute]
        msg = (f"Can't refine {attribute} of rule '{rule_id}' to '{value}'. "
               f"This rule {attribute} is already refined to '{current}'.")
        raise ValueError(msg)

    def _prevent_duplicate_value_refinement(self, attribute, value_id, value):
        """Raise ValueError if ``attribute`` of the value is already refined."""
        # .get() keeps this check free of side effects on the defaultdict.
        refinements = self._value_refinements.get(value_id, {})
        if attribute not in refinements:
            return
        current = refinements[attribute]
        msg = (f"Can't refine {attribute} of value '{value_id}' to '{value}'. "
               f"This value {attribute} is already refined to '{current}'.")
        raise ValueError(msg)

    def refine_rule(self, rule_id, attribute, value):
        """Record a refinement of a rule identified by its full XCCDF ID.

        Raises:
            ValueError: if the parameters are invalid or the attribute
                of that rule has already been refined.
        """
        Tailoring._validate_rule_refinement_params(rule_id, attribute, value)
        self._prevent_duplicate_rule_refinement(attribute, rule_id, value)
        self._rule_refinements[rule_id][attribute] = value

    def refine_value(self, value_id, attribute, value):
        """Record a refinement of a value identified by its full XCCDF ID.

        Raises:
            ValueError: if the parameters are invalid or the attribute
                of that value has already been refined.
        """
        Tailoring._validate_value_refinement_params(value_id, attribute, value)
        self._prevent_duplicate_value_refinement(attribute, value_id, value)
        self._value_refinements[value_id][attribute] = value

    def change_rule_attribute(self, rule_id, attribute, value):
        """Refine a rule given by either its full ID or its ID suffix."""
        full_rule_id = self._full_rule_id(rule_id)
        self.refine_rule(full_rule_id, attribute, value)

    def change_value_attribute(self, var_id, attribute, value):
        """Refine a value given by either its full ID or its ID suffix."""
        full_value_id = self._full_var_id(var_id)
        self.refine_value(full_value_id, attribute, value)

    def change_rules_attributes(self, assignments, attribute):
        """Apply ``<rule_id>=<value>`` assignments to one rule attribute."""
        for change in assignments:
            rule_id, value = assignment_to_tuple(change)
            self.change_rule_attribute(rule_id, attribute, value)

    def change_roles(self, assignments):
        """Apply ``<rule_id>=<role>`` assignments."""
        self.change_rules_attributes(assignments, "role")

    def change_severities(self, assignments):
        """Apply ``<rule_id>=<severity>`` assignments."""
        self.change_rules_attributes(assignments, "severity")

    def change_values(self, assignments):
        """Apply ``<varname>=<value>`` assignments as value overrides."""
        for change in assignments:
            varname, value = assignment_to_tuple(change)
            self.add_value_change(varname, value)

    def change_selectors(self, assignments):
        """Apply ``<varname>=<selector>`` assignments as value refinements."""
        for change in assignments:
            varname, selector = assignment_to_tuple(change)
            self.change_value_attribute(varname, "selector", selector)

    def _full_id(self, string, el_type):
        """Return ``string`` as a full XCCDF ID of the given element type.

        A string that already is a full XCCDF ID is returned unchanged,
        otherwise it is treated as a suffix and prefixed accordingly.
        """
        if is_valid_xccdf_id(string):
            return string
        return f"xccdf_{self.reverse_dns}_{el_type}_{string}"

    def _full_profile_id(self, string):
        return self._full_id(string, "profile")

    def _full_var_id(self, string):
        return self._full_id(string, "value")

    def _full_rule_id(self, string):
        return self._full_id(string, "rule")

    def _full_group_id(self, string):
        return self._full_id(string, "group")

    def add_value_change(self, varname, value):
        """Record that the variable ``varname`` is to be set to ``value``."""
        self.value_changes.append((varname, value))

    @staticmethod
    def _add_select_elements(container_element, full_ids, selected):
        """Append one select element per ID with the given selected flag."""
        for full_id in full_ids:
            change = ET.SubElement(container_element, "{%s}select" % NS)
            change.set("idref", full_id)
            change.set("selected", selected)

    def _add_group_select_operations(self, container_element):
        """Append select elements for selected, then unselected groups."""
        self._add_select_elements(
            container_element,
            map(self._full_group_id, self.groups_to_select), "true")
        self._add_select_elements(
            container_element,
            map(self._full_group_id, self.groups_to_unselect), "false")

    def _add_rule_select_operations(self, container_element):
        """Append select elements for selected, then unselected rules."""
        self._add_select_elements(
            container_element,
            map(self._full_rule_id, self.rules_to_select), "true")
        self._add_select_elements(
            container_element,
            map(self._full_rule_id, self.rules_to_unselect), "false")

    def _add_value_overrides(self, container_element):
        """Append one set-value element per recorded value change."""
        for varname, value in self.value_changes:
            change = ET.SubElement(container_element, "{%s}set-value" % NS)
            change.set("idref", self._full_var_id(varname))
            change.text = str(value)

    def rule_refinements_to_xml(self, profile_el):
        """Append one refine-rule element per refined rule."""
        for rule_id, refinements in self._rule_refinements.items():
            ref_rule_el = ET.SubElement(profile_el, "{%s}refine-rule" % NS)
            ref_rule_el.set("idref", rule_id)
            for attr, val in refinements.items():
                ref_rule_el.set(attr, val)

    def value_refinements_to_xml(self, profile_el):
        """Append one refine-value element per refined value."""
        for value_id, refinements in self._value_refinements.items():
            ref_value_el = ET.SubElement(profile_el, "{%s}refine-value" % NS)
            ref_value_el.set("idref", value_id)
            for attr, val in refinements.items():
                ref_value_el.set(attr, val)

    def to_xml(self, location=None):
        """Serialize the tailoring as pretty-printed XML.

        Args:
            location: filename to write to. ``None`` or ``"-"`` writes
                to standard output.
        """
        root = ET.Element("{%s}Tailoring" % NS)
        root.set("id", self.id)

        benchmark = ET.SubElement(root, "{%s}benchmark" % NS)
        datastream_uri = pathlib.Path(
            self.original_ds_filename).absolute().as_uri()
        benchmark.set("href", datastream_uri)

        version = ET.SubElement(root, "{%s}version" % NS)
        version.set("time", datetime.datetime.now().isoformat())
        version.text = str(self.version)

        profile = ET.SubElement(root, "{%s}Profile" % NS)
        profile.set("id", self._full_profile_id(self.profile_id))
        profile.set("extends", self._full_profile_id(self.extends))

        # Title has to be there due to the schema definition.
        title = ET.SubElement(profile, "{%s}title" % NS)
        title.set("override", "true" if self.profile_title else "false")
        title.text = self.profile_title

        self._add_group_select_operations(profile)
        self._add_rule_select_operations(profile)
        self._add_value_overrides(profile)
        self.rule_refinements_to_xml(profile)
        self.value_refinements_to_xml(profile)

        root_str = ET.tostring(root)
        pretty_xml = xml.dom.minidom.parseString(root_str).toprettyxml()
        if location is None or location == "-":
            # Write directly; using sys.stdout as a context manager would
            # close it for the rest of the process.
            sys.stdout.write(pretty_xml)
        else:
            # The XML declaration names no encoding, so the file has to be
            # UTF-8 regardless of the locale.
            with open(location, "w", encoding="utf-8") as f:
                f.write(pretty_xml)

    def _import_groups_from_tailoring(self, tailoring):
        """Record group selections from the "groups" part of a JSON profile."""
        for group_id, props in tailoring.get("groups", {}).items():
            if "evaluate" not in props:
                continue
            if props["evaluate"]:
                self.groups_to_select.append(group_id)
            else:
                self.groups_to_unselect.append(group_id)

    def _import_variables_from_tailoring(self, tailoring):
        """Record overrides and selectors from the "variables" part."""
        for variable_id, props in tailoring.get("variables", {}).items():
            if "value" in props:
                self.add_value_change(variable_id, props["value"])
            if "option_id" in props:
                self.change_value_attribute(
                    variable_id, "selector", props["option_id"])

    def _import_rules_from_tailoring(self, tailoring):
        """Record selections and refinements from the "rules" part."""
        for rule_id, props in tailoring.get("rules", {}).items():
            if "evaluate" in props:
                if props["evaluate"]:
                    self.rules_to_select.append(rule_id)
                else:
                    self.rules_to_unselect.append(rule_id)
            for attr in ATTRIBUTES:
                if attr in props:
                    self.change_rule_attribute(rule_id, attr, props[attr])

    def import_json_tailoring(self, json_tailoring):
        """Load customizations from a JSON tailoring file.

        The file has to define exactly one profile. Its "base_profile_id"
        becomes the extended profile; "id" and "title" are taken over
        when present.

        Raises:
            ValueError: if the file defines no profile or more than one,
                or if a contained refinement is invalid.
            KeyError: if the profile lacks "base_profile_id".
        """
        with open(json_tailoring, "r", encoding="utf-8") as jf:
            all_tailorings = json.load(jf)

        if 'profiles' in all_tailorings and all_tailorings['profiles']:
            if len(all_tailorings['profiles']) > 1:
                raise ValueError("The autotailor tool currently does not support multi-profile JSON tailoring.")
            tailoring = all_tailorings['profiles'][0]
        else:
            raise ValueError("JSON Tailoring does not define any profiles.")

        self.extends = tailoring["base_profile_id"]

        # Store an ID only if the file provides one. Storing the derived
        # default would freeze it, so that a later change of the extended
        # profile would no longer be reflected in the profile ID.
        if "id" in tailoring:
            self.profile_id = tailoring["id"]
        self.profile_title = tailoring.get("title", self.profile_title)

        self._import_groups_from_tailoring(tailoring)
        self._import_rules_from_tailoring(tailoring)
        self._import_variables_from_tailoring(tailoring)


def get_parser():
    """Build the argument parser for the command-line interface.

    The parser takes the data stream filename and an optional base
    profile ID as positional arguments, followed by the options that
    describe the individual customizations and the output location.
    """
    # Hint shared by the options that can be given repeatedly.
    repeat_hint = "Specify the argument multiple times if needed."
    role_options = ", ".join(ROLES)
    severity_options = ", ".join(SEVERITIES)

    profile_help = (
        "Specify ID of the base profile. ID of the profile can be "
        "either its full ID, or the suffix, in which case the "
        "'xccdf_<id-namespace>_profile' prefix will be prepended internally.")
    json_help = (
        "JSON Tailoring (https://github.com/ComplianceAsCode/schemas/blob/main/tailoring/schema.json) "
        "filename.")
    namespace_help = (
        "The reverse-DNS style string that is part of entities IDs in "
        "the corresponding data stream. If left out, the default value "
        f"'{DEFAULT_REVERSE_DNS}' is used.")
    var_value_help = (
        "Specify modification of the XCCDF value in form "
        "<varname>=<value>. Name of the variable can be either its full name, "
        "or the suffix, in which case the 'xccdf_<id-namespace>_value' prefix "
        f"will be prepended internally. {repeat_hint}")
    var_select_help = (
        "Specify refinement of the XCCDF value in form "
        "<varname>=<selector>. Name of the variable can be either its full name, "
        "or the suffix, in which case the 'xccdf_<id-namespace>_value' prefix "
        f"will be prepended internally. {repeat_hint}")
    rule_role_help = (
        "Specify refinement of the XCCDF rule role in form "
        "<rule_id>=<role>. Name of the rule can be either its full name, or "
        "the suffix, in which case the 'xccdf_<id-namespace>_rule_' prefix "
        "will be prepended internally. The value of <role> can be one of "
        f"{role_options}. {repeat_hint}")
    rule_severity_help = (
        "Specify refinement of the XCCDF rule severity in form "
        "<rule_id>=<severity>. Name of the rule can be either its full name, "
        "or the suffix, in which case the 'xccdf_<id-namespace>_rule_' "
        "prefix will be prepended internally. The value of <severity> can be "
        f"one of {severity_options}. {repeat_hint}")
    select_help = (
        "Specify what rules to select. "
        "The rule ID can be either full, or just the suffix, in which case "
        "the 'xccdf_<id-namespace>_rule' prefix will be prepended internally. "
        f"{repeat_hint}")
    unselect_help = (
        "Specify what rules to unselect. "
        "The argument works the same way as the --select argument.")
    new_profile_help = (
        "Specify the ID of the tailored profile. "
        "The ID of the new profile can be either its full ID, or the suffix, "
        "in which case the 'xccdf_<id-namespace>_profile_' prefix will be "
        "prepended internally. If left out, the new ID will be obtained "
        f"by appending '{DEFAULT_PROFILE_SUFFIX}' to the tailored profile ID.")
    output_help = (
        "Where to save the tailoring file. If not supplied, write to "
        "standard output.")

    cli = argparse.ArgumentParser(
        description="This script produces XCCDF 1.2 tailoring files "
        "to be used by SCAP scanners and SCAP data streams.")

    # Positional arguments.
    cli.add_argument(
        "datastream", metavar="DS_FILENAME",
        help="The tailored data stream filename.")
    cli.add_argument(
        "profile", metavar="BASE_PROFILE_ID", nargs='?', default="",
        help=profile_help)

    # General options.
    cli.add_argument(
        "-j", "--json-tailoring", metavar="JSON_TAILORING_FILENAME",
        default="", help=json_help)
    cli.add_argument("--title", default="", help="Title of the new profile.")
    cli.add_argument(
        "--id-namespace", default=DEFAULT_REVERSE_DNS, help=namespace_help)

    # Repeatable customization options.
    cli.add_argument(
        "-v", "--var-value", metavar="VAR=VALUE", action="append",
        default=[], help=var_value_help)
    cli.add_argument(
        "-V", "--var-select", metavar="VAR=SELECTOR", action="append",
        default=[], help=var_select_help)
    cli.add_argument(
        "-r", "--rule-role", metavar="RULE=ROLE", action="append",
        default=[], help=rule_role_help)
    cli.add_argument(
        "-e", "--rule-severity", metavar="RULE=SEVERITY", action="append",
        default=[], help=rule_severity_help)
    cli.add_argument(
        "-s", "--select", metavar="RULE_ID", action="append",
        default=[], help=select_help)
    cli.add_argument(
        "-u", "--unselect", metavar="RULE_ID", action="append",
        default=[], help=unselect_help)

    # Naming of the result and its destination.
    cli.add_argument("-p", "--new-profile-id", help=new_profile_help)
    cli.add_argument("-o", "--output", default="-", help=output_help)
    return cli


if __name__ == "__main__":
    # Script entry point: parse the command line, assemble the tailoring
    # and write it out.
    cli_parser = get_parser()
    options = cli_parser.parse_args()

    # The base profile has to come either from the command line or from
    # the JSON tailoring file.
    if not (options.profile or options.json_tailoring):
        cli_parser.error(
            "one of the following arguments has to be provided: "
            "BASE_PROFILE_ID or --json-tailoring JSON_TAILORING_FILENAME")

    tailoring = Tailoring()
    tailoring.original_ds_filename = options.datastream
    tailoring.reverse_dns = options.id_namespace

    # The JSON file is imported first, so the explicit command-line
    # arguments handled below take precedence over its contents.
    if options.json_tailoring:
        tailoring.import_json_tailoring(options.json_tailoring)

    if options.profile:
        tailoring.extends = options.profile
    if options.new_profile_id:
        tailoring.profile_id = options.new_profile_id
    if options.title:
        tailoring.profile_title = options.title

    tailoring.rules_to_select += options.select
    tailoring.rules_to_unselect += options.unselect

    tailoring.change_values(options.var_value)
    tailoring.change_selectors(options.var_select)
    tailoring.change_roles(options.rule_role)
    tailoring.change_severities(options.rule_severity)

    tailoring.to_xml(options.output)
