# SPDX-FileCopyrightText: Huawei Inc.
#
# SPDX-License-Identifier: Apache-2.0
#
# Create a diff of two JSON CVE statuses

import sys
import getopt
import re
from deepdiff import DeepDiff
from pprint import pprint

# Verbosity flag: 0 keeps vprint() silent; parse_args() sets it to 1 for -v/--verbose
verbose_level = 0


def show_syntax_and_exit(code):
    """
    Show the program syntax and exit
    Arguments:
        code: the exit code to return
    """
    # __name__ evaluates to "__main__" when this file is run as a script,
    # so the usage line is built from the invoked script path instead
    print("Syntax: %s [-h] [-v] file1 file2" % sys.argv[0])
    sys.exit(code)


def exit_error(code, message):
    """
    Print an error report and terminate the program
    Arguments:
        code: the exit status to terminate with
        message: the text printed after the "Error: " prefix
    """
    report = "Error: %s" % message
    print(report)
    sys.exit(code)


def vprint(*args, **kwargs):
    """
    Forward to print() only when verbose output is enabled
    Arguments:
        args: positional arguments handed to print unchanged
        kwargs: keyword arguments handed to print unchanged
    """
    global verbose_level

    # Quiet mode: drop the message
    if not verbose_level:
        return
    print(*args, **kwargs)


def parse_args(argv):
    """
    Interpret the command line and record options in global variables
    Arguments:
        argv: program arguments
    Returns:
        Tuple with the first two positional (file) arguments
    """
    global verbose_level
    try:
        options, positionals = getopt.getopt(argv, "hv", ["help", "verbose"])
    except getopt.GetoptError:
        show_syntax_and_exit(1)
    for flag, _ in options:
        if flag == "-h" or flag == "--help":
            show_syntax_and_exit(0)
        elif flag == "-v" or flag == "--verbose":
            verbose_level = 1
        else:
            show_syntax_and_exit(1)

    if len(positionals) < 2:
        exit_error(1, "Need at least two files to compare")
    # Positional arguments beyond the second are not used
    first, second = positionals[0], positionals[1]
    return (first, second)


def validate_cve_json(data):
    """
    Check correctness of the loaded JSON data
    Arguments:
        data: loaded data (the parsed top-level JSON value)
    Returns:
        Bool: True if file is correct, False otherwise
        Error message: describes the first problem found, empty if none
    """
    # A top-level value that is not a JSON object cannot carry a version
    if not isinstance(data, dict) or data.get("version") != "1":
        return False, "Unrecognized format version number"
    if "package" not in data:
        return False, "Mandatory 'package' key not found"

    keys_in_package = {"name", "layer", "version", "issue"}
    keys_in_issue = {"id", "scorev2", "scorev3", "vector", "status"}
    for package in data["package"]:
        missing = keys_in_package - package.keys()
        if missing:
            return False, "Missing a mandatory key in package: %s" % (missing,)

        for issue in package["issue"]:
            missing = keys_in_issue - issue.keys()
            if missing:
                # "name" is guaranteed to exist by the package check above
                return (
                    False,
                    "Missing mandatory keys %s in 'issue' for the package '%s'"
                    % (missing, package["name"]),
                )
    return True, ""


def get_name(value):
    """
    Sort-key helper: extract the package name from a package entry
    Argument:
        value: package entry, indexable by the "name" key
    Return:
        The entry's name, used as the sorting key
    """
    name = value["name"]
    return name


def load_cve_json(filename):
    """
    Load and validate a JSON CVE status file
    Arguments:
        filename: the file to open
    Returns:
        The list stored under the "package" key, sorted by package name
    Exits through exit_error (code 1) when the file cannot be read, is not
    valid JSON, or does not pass validate_cve_json.
    """
    import json

    out = {}
    try:
        # JSON text is UTF-8; do not depend on the locale's default encoding
        with open(filename, "r", encoding="utf-8") as f:
            out = json.load(f)
    except FileNotFoundError:
        exit_error(1, "Input file (%s) not found" % (filename))
    except OSError as error:
        # Permission denied, path is a directory, ...
        exit_error(1, "Cannot read input file (%s): %s" % (filename, error))
    except (json.decoder.JSONDecodeError, UnicodeDecodeError) as error:
        exit_error(1, "Malformed JSON file: %s" % str(error))

    # Validate file
    validated, error = validate_cve_json(out)
    if not validated:
        exit_error(1, error)

    # Sort by package name before returning
    out_sorted = sorted(out["package"], key=get_name)
    return out_sorted


def calculate_diff(data1, data2):
    """
    Build the DeepDiff comparison of two JSON data sets
    Arguments:
        data1: source data to compare (current)
        data2: destination data to compare (upstream)
    Returns:
        DeepDiff result requested with the "tree" view, ignoring item order
    """
    options = {"ignore_order": True, "view": "tree"}
    return DeepDiff(data1, data2, **options)


def calculate_package_diff(diff):
    """
    Parse a difference in packages from the deep diff
    Arguments:
        diff: the deep diff, tree format
    Returns:
        removed_set: removed packages list
        added_set: added packages list
    """
    # Raw strings: "\[" inside a normal string literal is an invalid escape
    # sequence. The patterns themselves are unchanged, except that the
    # property-name class no longer admits a stray comma.
    package_path = r"^root\[[0-9]+\]$"
    name_path = r"^root\[[0-9]+\]\['name'\]$"
    property_path = r"^root\[[0-9]+\]\['[a-zA-Z0-9]+'\]$"
    product_path = r"^root\[[0-9]+\]\['products'\]\[[0-9]+\]\['product'\]$"

    removed_set = []
    added_set = []
    vprint("Package status:")
    if "iterable_item_added" in diff.keys():
        for p in diff["iterable_item_added"]:
            # Only p.t2 exists
            # Keep whole packages ("root[XXXX]") only; deeper paths, such as
            # root[XXXX]['issue'][YYYY] for a new issue, are skipped here
            if not re.search(package_path, p.path()):
                continue
            vprint("Added package: %s %s" % (p.t2["name"], p.t2["version"]))
            added_set.append(p.t2)
    if "iterable_item_removed" in diff.keys():
        for p in diff["iterable_item_removed"]:
            # Only p.t1 exists
            # Keep whole packages ("root[XXXX]") only
            if not re.search(package_path, p.path()):
                continue
            vprint("Removed package: %s %s" % (p.t1["name"], p.t1["version"]))
            removed_set.append(p.t1)
    if "values_changed" in diff.keys():
        for p in diff["values_changed"]:
            path = p.path()
            if re.search(name_path, path):
                # A name change is counted as removal of t1 plus addition of t2
                vprint("Removed package: %s %s" % (p.up.t1["name"], p.up.t1["version"]))
                vprint("Added package: %s %s" % (p.up.t2["name"], p.up.t2["version"]))
                removed_set.append(p.up.t1)
                added_set.append(p.up.t2)
                # TODO: Handle the special case of linux
            elif re.search(property_path, path):
                # A high level property root[XXX]['somename'] was changed.
                # It is reported only if the parent hasn't been removed.
                if p.up.t1 not in removed_set:
                    vprint("Found changed items for: %s" % (p.up.t1["name"]))
            elif re.search(product_path, path):
                # Format: root[XX]['products'][XX]['product'];
                # three levels up is the package entry root[XX]
                if p.up.up.up.t1 not in removed_set:
                    vprint(
                        "Change of product names for package: %s %s"
                        % (p.up.up.up.t1["name"], p.up.up.up.t2["name"])
                    )
            # TODO: product table varies when the high level has changed
            # TODO: issue table version when the high level has changed
    vprint("")
    return (removed_set, added_set)


def calculate_cve_diff(diff):
    """
    Parse a difference in CVEs from the deep diff
    Arguments:
        diff: the deep diff, tree format
    Returns:
        removed_cves: ids of CVEs that went from Unpatched to Patched/Ignored
        added_cves: ids of CVEs that went from Patched/Ignored to Unpatched
    """
    # Raw strings: "\[" inside a normal string literal is an invalid escape
    # sequence.
    status_path = r"^root\[[0-9]+\]\['issue'\]\[[0-9]+\]\['status'\]$"
    issue_path = r"^root\[[0-9]+\]\['issue'\]\[[0-9]+\]$"
    # Digits are allowed so that the 'scorev2' and 'scorev3' fields match
    issue_field_path = r"^root\[[0-9]+\]\['issue'\]\[[0-9]+\]\['[a-zA-Z0-9]+'\]$"
    resolved = ("Patched", "Ignored")

    added_cves = []
    removed_cves = []
    vprint("CVE status:")
    # Look for added/removed CVEs
    if "values_changed" in diff.keys():
        # Differences like <root[42]['issue'][2562]['status'] t1:'Patched', t2:'Unpatched'>
        for p in diff["values_changed"]:
            path = p.path()
            if re.search(status_path, path):
                # One level up is the issue entry, three levels up the package
                issue = p.up.t2
                package = p.up.up.up.t2
                if p.t1 in resolved and p.t2 == "Unpatched":
                    # New CVE
                    label = "Unpatched CVE"
                    added_cves.append(issue["id"])
                elif p.t1 == "Unpatched" and p.t2 in resolved:
                    # Fixed CVE
                    label = "Fixed CVE"
                    removed_cves.append(issue["id"])
                else:
                    label = "Unknown status of CVE"
                vprint(
                    "%s: %s (%s %s)"
                    % (label, issue["id"], package["name"], package["version"])
                )
                continue

            # Other changes in the issue table
            if re.search(issue_path, path):
                # The whole issue element has changed
                package = p.up.up
            elif re.search(issue_field_path, path):
                # Some fields of the issue changed
                package = p.up.up.up
            else:
                continue
            # If package names are different, this is a spurious diff
            if package.t1["name"] == package.t2["name"]:
                vprint(
                    "Changed CVE for package %s %s"
                    % (package.t2["name"], package.t2["version"])
                )
    vprint("")
    return (removed_cves, added_cves)


def main(argv):
    """
    Compare two CVE status files and print a summary of the differences
    Arguments:
        argv: program arguments
    """
    file1, file2 = parse_args(argv)

    # Load (and validate) both inputs, first file first, then diff them
    diff = calculate_diff(load_cve_json(file1), load_cve_json(file2))

    removed_set, added_set = calculate_package_diff(diff)
    removed_cves, added_cves = calculate_cve_diff(diff)

    package_counts = (len(added_set), len(removed_set))
    cve_counts = (len(removed_cves), len(added_cves))

    print("Summary:")
    print(
        "Package report: the new version adds %s packages and removes %d packages"
        % package_counts
    )
    print(
        "CVE report: the new version removes %s CVEs and adds %s CVEs"
        % cve_counts
    )


# Script entry point: pass the command line without the program name to main()
if __name__ == "__main__":
    main(sys.argv[1:])
