#!/usr/bin/python2

# Copyright John Dennis, jdennis@redhat.com

import getopt
import zipfile
import re
import fnmatch
import os
import sys
import readline
# The rpm python bindings are optional.  have_rpm records whether they
# could be imported so RPM reporting can be refused cleanly when absent.
try:
    import rpm
    have_rpm = True
except ImportError:
    # Only a failed import means "no RPM support"; a bare except would
    # also swallow KeyboardInterrupt and SystemExit.
    have_rpm = False
    

#-------------------------------------------------------------------------------

# Name the script was invoked as, used in the help text.
prog_name = os.path.basename(sys.argv[0])

# Prompt shown in interactive mode.
prompt = "Enter class: "
# Default directories searched for jar files.
std_java_dir = "/usr/share/java"
std_jni_dir = "/usr/lib/java"
# Used to convert between zip entry paths (a/b/C) and dotted class
# names (a.b.C).  The dot pattern is a raw string so the backslash is an
# explicit regex escape rather than an (invalid) string escape.
path_to_class_re = re.compile('/')
class_to_path_re = re.compile(r'\.')
# Cache mapping a file path to the name of the rpm that owns it.
rpm_file_cache = {}
# Exclusion patterns as given on the command line, and their compiled form.
exclude_jar_patterns = []
exclude_jar_regexps = []

# Run time settings, adjusted by the command line options.
config = {
    'jar-dirs'     : [std_java_dir, std_jni_dir],
    'action'       : 'query',
    'interactive'  : False,
    'ignore-case'  : False,
    'pattern-glob' : True,
    'show-rpms'    : False,
    'summary'      : False,
}

#-------------------------------------------------------------------------------


def get_rpm_name_by_file_path(path):
    '''
    Return the name of the first package in the rpm database that
    matches path, or None.

    None is returned when path is None, when no package matches, or
    when the rpm query fails.  A failed query is reported on stderr
    and is otherwise non-fatal.
    '''
    if path is None:
        return None

    name = None
    try:
        ts = rpm.ts()
        mi = ts.dbMatch(rpm.RPMTAG_BASENAMES, path)
        # Only the first matching header is of interest
        for h in mi:
            name = h['name']
            break
    except Exception:
        # Best effort: report the failure and carry on without a name.
        # (The message used to be built with a malformed print redirect
        # which itself raised TypeError inside this handler.)
        sys.stderr.write("ERROR: failed to retrieve rpm info for %s\n" % path)
    return name

def get_rpm_name(path):
    '''
    Cached front end to get_rpm_name_by_file_path().

    Returns the rpm name for path, consulting rpm_file_cache first.
    Only successful lookups are stored in the cache; a lookup that
    yields None is returned as None and will be retried next time.
    '''
    try:
        return rpm_file_cache[path]
    except KeyError:
        pass

    owner = get_rpm_name_by_file_path(path)
    if owner is not None:
        rpm_file_cache[path] = owner
    return owner

#-------------------------------------------------------------------------------

def get_jar_classes(jar_path):
    '''
    Return the list of class names contained in the jar (zip) file
    jar_path.

    Every zip entry whose extension is .class is reported with the
    extension removed and each / replaced by a dot, in the order the
    entries appear in the archive.

    If the file cannot be opened as a zip archive an error is printed
    and an empty list is returned.
    '''
    class_list = []

    try:
        f = zipfile.ZipFile(jar_path)
    except IOError as e:
        print("Error: %s (%s)" % (e.filename, e.strerror))
        return class_list
    except Exception as e:
        print("Error: %s" % e)
        return class_list

    # try/finally rather than "with" so the archive is always closed
    # even on interpreters where ZipFile is not a context manager.
    try:
        for name in f.namelist():
            path, ext = os.path.splitext(name)
            if ext == ".class":
                class_list.append(path.replace('/', '.'))
    finally:
        f.close()

    return class_list

def get_files_in_dir(dir_path, recursive=False):
    '''
    Return the paths of the non-directory entries found in dir_path.

    Without recursive only the immediate entries of dir_path are
    examined and anything os.path.isdir() reports as a directory is
    left out.  With recursive the whole tree is walked (following
    symbolic links to directories) and every file name reported by
    os.walk() is returned.
    '''
    if not recursive:
        candidates = [os.path.join(dir_path, entry) for entry in os.listdir(dir_path)]
        return [candidate for candidate in candidates if not os.path.isdir(candidate)]

    found = []
    for root, _sub_dirs, file_names in os.walk(dir_path, followlinks=True):
        found.extend([os.path.join(root, file_name) for file_name in file_names])
    return found

def get_jars(paths, jar_paths):
    '''
    Iterate over all the paths.  If the path is a link then traverse
    each link until it resolves to a file (e.g. the jar_path), record
    the traversal path used to reach the jar_path in a list.

    If the resolved path is a zip file record it in the jar_paths dict,
    this assures only a single path for the jar is recorded because
    there may be many ways to reach it via symbolic links.

    The value in the jar_paths dict for the jar_path is a list of
    traversal lists. Thus if the jar_path was never reached via
    symbolic link traversal it will have an empty list.  If the
    jar_path was reached by one traversal the list will contain one
    list.  If the jar_path was reached by two different traversals the
    list will contain two lists, etc.

    Paths which are neither a symbolic link nor a regular file, and
    symbolic links which form a cycle, are skipped.

    jar_paths is updated in place, nothing is returned.
    '''
    for path in paths:
        links = None
        if os.path.islink(path):
            # The file is a symlink, chase down the links until it resolves to a real file
            links = []
            visited = set()
            tmp_path = path
            while os.path.islink(tmp_path):
                if tmp_path in visited:
                    # Circular chain of links, it can never resolve
                    tmp_path = None
                    break
                visited.add(tmp_path)
                links.append(tmp_path)
                target = os.readlink(tmp_path)
                if target.startswith("/"):
                    tmp_path = target
                else:
                    # Relative targets are relative to the directory holding the link
                    tmp_path = os.path.normpath(os.path.join(os.path.dirname(tmp_path), target))
            if tmp_path is None:
                continue
            links.append(tmp_path)
            # jar_path is what the chain of links finally resolved to
            jar_path = tmp_path
        elif os.path.isfile(path):
            # It's a real file, consider it
            jar_path = path
        else:
            # Not a link and not a regular file (missing, device, ...).
            # Skip it, otherwise jar_path would be unbound or left over
            # from the previous iteration.
            continue

        # Is the path a Zip file, if so consider it a jar and add it to the jar_paths
        if zipfile.is_zipfile(jar_path):
            jar_links = jar_paths.setdefault(jar_path, [])
            if links:
                jar_links.append(links)
def is_jar_excluded(jar_path):
    '''
    Return True if any regular expression in exclude_jar_regexps
    matches somewhere in jar_path, otherwise False.
    '''
    return any(regexp.search(jar_path) for regexp in exclude_jar_regexps)
    

#-------------------------------------------------------------------------------

class JarCollection(object):
    '''
    Table of the classes found in the jar files of a set of directories.

    Attributes:
      jar_dirs         the directories which were scanned
      jar_paths        dict, key is a unique jar path, value is a list
                       of symbolic link traversal lists reaching that jar
      jar_pathnames    sorted list of the keys of jar_paths
      class_jar_paths  dict, key is a dotted class name, value is the
                       list of jar paths containing that class
      classes          sorted list of the keys of class_jar_paths
      all_jars         dict used as a set of every jar path printed by
                       show_class_jar_paths()
      all_rpms         dict used as a set of every rpm name printed by
                       show_class_jar_paths()
    '''

    def __init__(self, jar_dirs=None):
        '''
        Scan each directory in jar_dirs (not recursively) for jar
        files and build the class tables.  jar_dirs of None means no
        directories are scanned.
        '''
        # The default used to be iterated directly which raised
        # TypeError; treat None as an empty search path instead.
        if jar_dirs is None:
            jar_dirs = []

        self.jar_dirs = jar_dirs
        self.jar_pathnames = None
        self.class_jar_paths = {}
        self.classes = None
        self.all_jars = {}
        self.all_rpms = {}

        # Get a list of all UNIQUE jar path names
        self.jar_paths = {}
        for jar_dir in self.jar_dirs:
            paths = get_files_in_dir(jar_dir)
            get_jars(paths, self.jar_paths)

        # self.jar_paths is a dict, each key is one jar file.  The
        # value of the key is a list of traversal lists.  If the
        # jar_path was not reached by symbolic links the list will be
        # empty. If the jar path could be reached by one or more
        # traversal paths then there will be a list of each traversal
        # path. Each traversal path list is a sequence starting with
        # the original path (which by definition is a symbolic link)
        # followed by every intermediate symbolic link ending with the
        # jar_path.
        #
        # Thus the keys in self.jar_paths represent the UNIQUE jar
        # files and the value of the key provides all the ways that
        # jar file can be reached via symbolic links.

        self.jar_pathnames = sorted(self.jar_paths)

        # For each unique jar get its classes and add it to the table of class jar_paths
        for jar_path in self.jar_pathnames:
            if is_jar_excluded(jar_path):
                continue
            for klass in get_jar_classes(jar_path):
                self.class_jar_paths.setdefault(klass, []).append(jar_path)

        self.classes = sorted(self.class_jar_paths)

    def show_class_jar_paths(self, classes=None):
        '''
        Print each class in classes (default: every known class)
        followed by the jars which contain it, with the owning rpm when
        config['show-rpms'] is set.  Every jar and rpm printed is
        remembered in all_jars/all_rpms and listed afterwards when
        config['summary'] is set.
        '''
        if classes is None:
            classes = self.classes

        for klass in classes:
            print("%s" % klass)
            for jar_path in self.class_jar_paths[klass]:
                self.all_jars[jar_path] = True
                if config['show-rpms']:
                    rpm_name = get_rpm_name(jar_path)
                    self.all_rpms[rpm_name] = True
                    print("    %s [rpm: %s]" % (jar_path, rpm_name))
                else:
                    print("    %s" % jar_path)

        if config['summary']:
            print("")
            print("Summary:")
            print("%d Unique Jar's" % (len(self.all_jars)))
            for jar_path in sorted(self.all_jars):
                print("     %s" % jar_path)

            if config['show-rpms']:
                print("%d Unique RPM's" % (len(self.all_rpms)))
                for rpm_name in sorted(self.all_rpms):
                    print("     %s" % rpm_name)

    def show_jar_links(self):
        '''
        For every jar which was reached through symbolic links print
        the jar followed by each link traversal leading to it.
        '''
        for jar_path in self.jar_pathnames:
            jar_links = self.jar_paths[jar_path]
            if not jar_links:
                continue
            print("%s" % jar_path)
            for link_traversal in jar_links:
                print("    %s" % ' -> '.join(link_traversal))

    def lookup_class(self, klass):
        '''
        Return the list of jar paths containing the exact class name
        klass, or None if the class is unknown.
        '''
        return self.class_jar_paths.get(klass)

    def search_class(self, pattern):
        '''
        Return the sorted list of class names matching pattern.

        pattern is a glob when config['pattern-glob'] is set, otherwise
        a regular expression; config['ignore-case'] makes the match
        case insensitive.  If the pattern cannot be compiled an error
        is written to stderr and None is returned.
        '''
        regexp_flags = 0

        if config['pattern-glob']:
            # The translated glob already anchors the end of the string,
            # anchor the start too so the whole class name must match.
            pattern = '^' + fnmatch.translate(pattern)

        if config['ignore-case']:
            regexp_flags |= re.IGNORECASE

        try:
            regexp = re.compile(pattern, regexp_flags)
        except Exception as e:
            sys.stderr.write("ERROR, cannot compile search pattern '%s' (%s)\n" % (pattern, e))
            return None

        return [klass for klass in self.classes if regexp.search(klass)]

    def find_multiple_definitions(self):
        '''
        Return the sorted list of classes found in more than one jar.
        '''
        return [klass for klass in self.classes
                if len(self.class_jar_paths[klass]) > 1]

    def compare_multiple_definitions(self, classes):
        '''
        For each class in classes read its definition from every jar
        containing it and print, for every pair of jars, whether the
        two definitions are byte for byte equal.  A line of totals is
        printed at the end.
        '''
        n_defs = 0
        n_equal = 0
        n_not_equal = 0

        for klass in classes:
            print("comparing %s" % klass)
            jar_paths = self.class_jar_paths[klass]
            n_jar_paths = len(jar_paths)
            # Name of the zip entry holding this class
            name = klass.replace('.', '/') + '.class'

            class_defs = []
            for jar_path in jar_paths:
                n_defs += 1
                f = zipfile.ZipFile(jar_path)
                # Close the archive even if the entry cannot be read
                try:
                    class_defs.append(f.read(name))
                finally:
                    f.close()

            # Compare every unordered pair of definitions exactly once
            for i in range(n_jar_paths - 1):
                for j in range(i + 1, n_jar_paths):
                    if class_defs[i] == class_defs[j]:
                        result_str = "    equal"
                        n_equal += 1
                    else:
                        result_str = "not equal"
                        n_not_equal += 1

                    print("    %s %s %s" % (result_str, jar_paths[i], jar_paths[j]))

        print("%d classes, %d multiple classes, %d class defintions, %d equal, %d not_equal" %
              (len(self.classes), len(classes), n_defs, n_equal, n_not_equal))

#-------------------------------------------------------------------------------
class Completer:
    '''
    Readline tab completion over dotted class names.

    The class names are stored as a tree of nested dicts, one level
    per dot separated component, in self.completions.
    '''

    def __init__(self, classes):
        self.completion_index = 0
        self.completions = {}

        # Insert every class name into the component tree
        for dotted_name in classes:
            node = self.completions
            for part in dotted_name.split("."):
                node = node.setdefault(part, {})

    def do_complete(self, text, state):
        '''
        Readline completer callback.

        state is 0 when a new completion begins, otherwise the search
        resumes just after the previously returned candidate.  Returns
        the next component which starts with text at the current
        position in the tree, or None when there are no more.  Any
        exception is printed and None is returned.
        '''
        try:
            if state == 0:
                start = 0
            else:
                start = self.completion_index + 1

            # Everything on the line before the word being completed
            # selects the level of the tree to complete from
            typed = readline.get_line_buffer()[:readline.get_begidx()]
            node = self.completions
            for segment in typed.split("."):
                if not segment:
                    continue
                node = node.get(segment)
                if node is None:
                    return None

            candidates = sorted(node.keys())
            for idx in range(start, len(candidates)):
                if candidates[idx].startswith(text):
                    self.completion_index = idx
                    return candidates[idx]
            return None
        except Exception as e:
            print("Completer Exception: %s" % (e))


#-------------------------------------------------------------------------------

class Usage(Exception):
    '''Command line usage error, msg holds the text to report.'''
    def __init__(self, msg):
        # Exception.__init__ is not invoked, the text is only stored in .msg
        self.msg = msg

def usage():
    '''
    Return the command help text with the program name and the
    current jar search path filled in.
    '''
    substitutions = {
        'prog_name': prog_name,
        'jar_path': config['jar-dirs'],
    }

    help_template = '''\

%(prog_name)s [pattern ...]

-h --help               print help
-d --dir                add dir to jar search path
-D --clear-dirs         clear the jar search path
-l --list               list every class and the jar's it's located in
-i --ignore-case        when searching with regular expressions ignore case
-R --regexp             when searching use full regular expressions (default is globbing)
-I --interactive        keep prompting for a class (with tab completion)
-m --multiple           list classes which have multiple definitions
-M --multiple-compare   for each multiply defined class show if they are equal
-r --show-rpms          when listing jars show which rpms provides it
-L --links              list each jar symbolic link traversal
-x --exclude            exclude any jar file matching this
                        regular expression. May be specified multiple times.
-s --summary            dump summary information

Standard Jar Search Path = %(jar_path)s

In interactive mode tab will complete up to the next location in the
class hierarchy. Then enter a dot (.) to move to the next location in
the class hierarchy, tab will complete again, repeat until the class
name is fully specified. (Note, after a fully completing a name a
space is inserted which you'll have to backspace over, sorry this is
limitation of the Python completion implementation).

Examples:

# Interactive class lookup with tab completion
%(prog_name)s -I

# Find all classes matching regular expression pattern
%(prog_name)s pattern

# Add a directory to jar search path, interactively lookup classes
%(prog_name)s -d /usr/share/foo/java/lib -I

# Search only this directory and search for pattern
%(prog_name)s -D -d /usr/share/foo/java/lib pattern
'''
    return help_template % substitutions


#-------------------------------------------------------------------------------

def main(argv=None):
    '''
    Command line entry point.

    argv is the full argument vector including the program name,
    sys.argv is used when it is None.  The options update config, the
    jar database is built, then the selected action is run.  With the
    default 'query' action every remaining argument is a class pattern
    to search for, optionally followed by interactive lookups.

    Returns the process exit status: 0 on success, 1 when rpm support
    is unavailable or a pattern cannot be compiled, 2 on a command
    line usage error.
    '''
    if argv is None:
        argv = sys.argv

    try:
        try:
            opts, args = getopt.getopt(argv[1:], 'hd:DlIiRmMrLx:s',
                                       ['help', 'dir=', 'clear-dirs', 'list',
                                        'interactive', 'ignore-case', 'regexp', 'multiple',
                                        'multiple-compare', 'show-rpms', 'links',
                                        'exclude=', 'summary'])
        except getopt.GetoptError as err:
            # will print something like 'option -a not recognized'
            sys.stderr.write("%s\n" % err)
            # The help text used to be built here and thrown away
            sys.stderr.write("%s\n" % usage())
            return 2

        for o, a in opts:
            if o in ('-h', '--help'):
                print(usage())
                return 0
            elif o in ('-l', '--list'):
                config['action'] = 'list'
            elif o in ('-m', '--multiple'):
                config['action'] = 'list-multiple'
            elif o in ('-M', '--multiple-compare'):
                config['action'] = 'compare-multiple'
            elif o in ('-d', '--dir'):
                config['jar-dirs'].append(a)
            elif o in ('-D', '--clear-dirs'):
                config['jar-dirs'] = []
            elif o in ('-i', '--ignore-case'):
                config['ignore-case'] = True
            elif o in ('-R', '--regexp'):
                config['pattern-glob'] = False
            elif o in ('-I', '--interactive'):
                config['interactive'] = True
            elif o in ('-r', '--show-rpms'):
                if have_rpm:
                    config['show-rpms'] = True
                else:
                    sys.stderr.write("ERROR: python rpm module not available, cannot enable RPM reporting\n")
                    return 1
            elif o in ('-L', '--links'):
                config['action'] = 'list-links'
            elif o in ('-x', '--exclude'):
                exclude_jar_patterns.append(a)
            elif o in ('-s', '--summary'):
                config['summary'] = True
            else:
                raise Usage("command argument '%s' not handled, internal error" % o)
    except Usage as e:
        sys.stderr.write("%s\n" % e.msg)
        sys.stderr.write("for help use --help\n")
        return 2

    # Compile the exclusion patterns before the jars are scanned
    for pat in exclude_jar_patterns:
        try:
            regexp = re.compile(pat)
            exclude_jar_regexps.append(regexp)
        except Exception as e:
            sys.stderr.write("ERROR, cannot compile exclude pattern '%s' (%s)\n" % (pat, e))
            return 1

    jc = JarCollection(config['jar-dirs'])

    # Database is now built, determine how we want to query it
    if config['action'] == 'list':
        jc.show_class_jar_paths()
        return 0
    elif config['action'] == 'list-multiple':
        classes = jc.find_multiple_definitions()
        jc.show_class_jar_paths(classes)
        return 0
    elif config['action'] == 'compare-multiple':
        classes = jc.find_multiple_definitions()
        jc.show_class_jar_paths(classes)
        jc.compare_multiple_definitions(classes)
        return 0
    elif config['action'] == 'list-links':
        jc.show_jar_links()
        return 0

    # Default action, search for each pattern given on the command line
    unique_classes = {}
    for pattern in args:
        matches = jc.search_class(pattern)
        if matches is None:
            # The pattern could not be compiled (search_class has
            # reported it); iterating None would raise TypeError.
            return 1
        for klass in matches:
            unique_classes[klass] = True
    if unique_classes:
        jc.show_class_jar_paths(sorted(unique_classes))

    if config['interactive']:
        completer = Completer(jc.classes)
        readline.set_completer_delims(".")
        readline.set_completer(completer.do_complete)
        readline.parse_and_bind("tab: complete")

        # Keep prompting until end of file
        while True:
            try:
                class_name = raw_input(prompt).strip()
            except EOFError:
                break
            print(jc.lookup_class(class_name))

    return 0

#-------------------------------------------------------------------------------

# Only run when executed as a script, the value returned by main() is
# used as the process exit status.
if __name__ == '__main__':
    sys.exit(main())
