#!/usr/bin/python3

"""
This script generates Graphviz-compatible dot scripts from the class
definitions of the containing project. Specify the root class to generate with
the -i (multiple roots can be specified). Specify parts of the hierarchy to
exclude with -x. Default configurations can be specified in the containing
project's setup.cfg under [{SETUP_SECTION}]
"""

from __future__ import annotations

import re
import sys
# NOTE(review): `from __future__ import annotations` (above) is only accepted
# by Python 3.7+, so on 3.6 the file presumably fails to compile before this
# check can run -- confirm the intended minimum version.
assert sys.version_info >= (3, 6), 'Script requires Python 3.6+'
import typing as t
from pathlib import Path
from configparser import ConfigParser
from argparse import ArgumentParser, Namespace, FileType


# Absolute path of the directory one level above the directory containing
# this script; setup.cfg and relative source/output settings are resolved
# against it.
PROJECT_ROOT = (Path(__file__).parent / '..').resolve()
# Name of the setup.cfg section holding this script's defaults: the script's
# file name followed by ":settings".
SETUP_SECTION = str(Path(__file__).name) + ':settings'


def main(args: t.List[str] = None):
    """
    Entry point: build the class map, optionally filter it, and write the
    rendered dot script to the configured output.

    *args* is the list of command line arguments (without the program name);
    when it is None, ``sys.argv[1:]`` is used instead.
    """
    argv = sys.argv[1:] if args is None else args
    options = get_config(argv)

    class_map = make_class_map(options.source, options.omit)
    # Filtering is only applied when at least one root was requested
    wants_filter = options.include or options.exclude
    if wants_filter:
        class_map = filter_map(
            class_map,
            include_roots=options.include,
            exclude_roots=options.exclude,
        )

    dot_script = render_map(class_map, options.abstract)
    options.output.write(dot_script)


def get_config(args: t.List[str]) -> Namespace:
    """
    Build the effective configuration from setup.cfg and the command line.

    Defaults are read from the SETUP_SECTION section of the setup.cfg found
    in PROJECT_ROOT (multi-valued options are written one value per line).
    Relative *source* and *output* values found there are resolved against
    PROJECT_ROOT. *args* is then parsed with those values as the argument
    defaults.

    Returns the parsed namespace, in which ``source``, ``include``,
    ``exclude``, ``omit`` and ``abstract`` are sets of strings. If no source
    is given anywhere, ``source`` falls back to PROJECT_ROOT itself.
    """
    config = ConfigParser(
        defaults={
            'source': '',
            'include': '',
            'exclude': '',
            'abstract': '',
            'omit': '',
            'output': '-',
        },
        delimiters=('=',), default_section=SETUP_SECTION,
        empty_lines_in_values=False, interpolation=None,
        # Adds sect.getlist(): one list item per line of the option's value
        converters={'list': lambda s: s.strip().splitlines()})
    config.read(PROJECT_ROOT / 'setup.cfg')
    sect = config[SETUP_SECTION]
    # Resolve source and output defaults relative to setup.cfg
    if sect['source']:
        sect['source'] = '\n'.join(
            str(PROJECT_ROOT / source)
            for source in sect.getlist('source')
        )
    # '-' is left alone so that FileType maps it to stdout
    if sect['output'] and sect['output'] != '-':
        sect['output'] = str(PROJECT_ROOT / sect['output'])

    parser = ArgumentParser(description=__doc__.format(**globals()))
    parser.add_argument(
        '-s', '--source', action='append', metavar='PATH',
        default=sect.getlist('source'),
        help="the pattern(s) of files to search for classes; can be specified "
        "multiple times. Default: %(default)r")
    parser.add_argument(
        '-i', '--include', action='append', metavar='CLASS',
        # Previously read the 'exclude' option here, so configured excludes
        # were also treated as includes and 'include' was ignored
        default=sect.getlist('include'),
        help="only include classes which have CLASS somewhere in their "
        "ancestry; can be specified multiple times. Default: %(default)r")
    parser.add_argument(
        '-x', '--exclude', action='append', metavar='CLASS',
        default=sect.getlist('exclude'),
        help="exclude any classes which have CLASS somewhere in their "
        "ancestry; can be specified multiple times. Default: %(default)r")
    parser.add_argument(
        '-o', '--omit', action='append', metavar='CLASS',
        default=sect.getlist('omit'),
        help="omit the specified class, but not its descendents from the "
        "chart; can be specified multiple times. Default: %(default)r")
    parser.add_argument(
        '-a', '--abstract', action='append', metavar='CLASS',
        default=sect.getlist('abstract'),
        help="mark the specified class as abstract, rendering it in a "
        "different color; can be specified multiple times. Default: "
        "%(default)r")
    parser.add_argument(
        'output', nargs='?', type=FileType('w'),
        default=sect['output'],
        help="the file to write the output to; defaults to stdout")
    ns = parser.parse_args(args)
    # De-duplicate the multi-valued options
    ns.abstract = set(ns.abstract)
    ns.include = set(ns.include)
    ns.exclude = set(ns.exclude)
    ns.omit = set(ns.omit)
    if not ns.source:
        ns.source = [str(PROJECT_ROOT)]
    ns.source = set(ns.source)
    return ns


def make_class_map(search_paths: t.List[str], omit: t.Set[str])\
        -> t.Dict[str, t.Set[str]]:
    """
    Find all Python source files matching *search_paths*, extract (via a crude
    regex) all class definitions and return a mapping of class-name to the set
    of base class names.

    Each search path is either a directory, which is searched recursively for
    ``*.py`` files, or a path whose final component is a glob pattern (e.g.
    ``pkg/*.py``). Directories matched by such a pattern are searched
    recursively as well. Files are read as UTF-8.

    All classes listed in *omit* will be excluded from the result, but not
    their descendents (useful for excluding "object" etc.). A class written
    without bases is treated as deriving from "object".
    """
    class_re = re.compile(
        r'^class\s+(?P<name>\w+)\s*(?:\((?P<bases>.*)\))?:', re.MULTILINE)

    def find_files() -> t.Iterator[Path]:
        for path in search_paths:
            p = Path(path)
            # A plain directory is taken as-is; anything else is treated as
            # a glob pattern on its final path component
            matches = [p] if p.is_dir() else p.parent.glob(p.name)
            for found in matches:
                if found.is_dir():
                    # Opening a directory would raise; descend into it
                    yield from sorted(
                        child for child in found.rglob('*.py')
                        if child.is_file())
                else:
                    yield found

    def find_classes() -> t.Iterator[t.Tuple[str, t.Set[str]]]:
        for py_file in find_files():
            with py_file.open(encoding='utf-8') as f:
                source = f.read()
            for match in class_re.finditer(source):
                name = match.group('name')
                if name in omit:
                    continue
                bases = {
                    base.strip()
                    for base in (match.group('bases') or 'object').split(',')
                }
                # Drop omitted bases, and the empty name left behind by a
                # trailing comma in the base list
                yield name, {
                    base for base in bases
                    if base and base not in omit
                }

    return dict(find_classes())


def filter_map(class_map: t.Dict[str, t.Set[str]], include_roots: t.Set[str],
               exclude_roots: t.Set[str]) -> t.Dict[str, t.Set[str]]:
    """
    Returns *class_map* (which is a mapping such as that returned by
    :func:`make_class_map`), with only those classes which have at least one
    of the *include_roots* in their ancestry, and none of the *exclude_roots*.

    A class counts as part of its own ancestry. If *include_roots* is empty,
    every class passes the include test. Bases of the retained classes that
    were not themselves retained are added back as entries, linked only to
    each other. Bases that are not defined in *class_map* at all (e.g.
    classes from other libraries) are added with no bases of their own.
    """
    def ancestry(cls: str) -> t.Set[str]:
        # Iterative walk with a visited set, so cyclic name maps (possible
        # when unrelated classes share a name) cannot recurse forever
        seen: t.Set[str] = set()
        pending = [cls]
        while pending:
            current = pending.pop()
            if current not in seen:
                seen.add(current)
                pending.extend(class_map.get(current, ()))
        return seen

    filtered = {}
    for name, bases in class_map.items():
        lineage = ancestry(name)
        included = (
            not include_roots or not lineage.isdisjoint(include_roots))
        if included and lineage.isdisjoint(exclude_roots):
            filtered[name] = bases

    pure_bases = {
        base for bases in filtered.values() for base in bases
    } - set(filtered)
    # Make a second pass to fill in missing links between classes that are
    # only included as bases of other classes
    for base in pure_bases:
        # .get(): a base may be unknown to class_map (defined elsewhere)
        filtered[base] = pure_bases & class_map.get(base, set())
    return filtered


def render_map(class_map: t.Dict[str, t.Set[str]], abstract: t.Set[str]) -> str:
    """
    Renders *class_map* (which is a mapping such as that returned by
    :func:`make_class_map`) to graphviz's dot language.

    The *abstract* collection determines which classes will be rendered
    lighter to indicate their abstract nature. All classes with names ending
    "Mixin" will be implicitly rendered in a different style.

    Names which are not plain dot identifiers (e.g. dotted bases such as
    "abc.ABC"), or which collide with a dot keyword (e.g. a class called
    "Node"), are emitted as double-quoted IDs so the output stays valid; all
    other names are emitted unquoted.
    """
    # dot keywords are case-insensitive and cannot be used as bare IDs
    dot_keywords = {'node', 'edge', 'graph', 'digraph', 'subgraph', 'strict'}
    plain_id = re.compile(r'[A-Za-z_][A-Za-z_0-9]*\Z')

    def dot_id(name: str) -> str:
        if plain_id.match(name) and name.lower() not in dot_keywords:
            return name
        # Inside a quoted dot ID only the double-quote needs escaping
        return '"{}"'.format(name.replace('"', '\\"'))

    # Every class mentioned anywhere, as a key or as a base
    names = set(class_map)
    for bases in class_map.values():
        names.update(bases)
    ordered_names = sorted(names)

    template = """\
digraph classes {{
    graph [rankdir=RL];
    node [shape=rect, style=filled, fontname=Sans, fontsize=10];
    edge [];

    /* Mixin classes */
    node [color="#c69ee0", fontcolor="#000000"]
    {mixin_nodes}

    /* Abstract classes */
    node [color="#9ec6e0", fontcolor="#000000"]
    {abstract_nodes}

    /* Concrete classes */
    node [color="#2980b9", fontcolor="#ffffff"];
    {concrete_nodes}

    /* Edges */
    {edges}
}}
"""

    return template.format(
        mixin_nodes='\n    '.join(
            '{name};'.format(name=dot_id(name))
            for name in ordered_names
            if name.endswith('Mixin')
        ),
        abstract_nodes='\n    '.join(
            '{name};'.format(name=dot_id(name))
            for name in ordered_names
            if name in abstract
        ),
        concrete_nodes='\n    '.join(
            '{name};'.format(name=dot_id(name))
            for name in ordered_names
            if not name.endswith('Mixin')
            and name not in abstract
        ),
        edges='\n    '.join(
            '{name}->{base};'.format(name=dot_id(name), base=dot_id(base))
            for name, bases in sorted(class_map.items())
            for base in sorted(bases)
        ),
    )


# Run the generator only when executed as a script. main() has no return
# statement, so sys.exit() receives None and the exit status is 0 unless an
# exception propagates.
if __name__ == '__main__':
    sys.exit(main())
