#!/usr/bin/env python

import argparse
import logging
import textwrap

from twisted.internet.defer import inlineCallbacks, returnValue, succeed

from aiki.cli import (
    make_arg_parser, setup_logging, run_command, parse_relation_settings,
    format_states, format_descriptors, format_port_pairs)
from aiki.twistutils import CountDownLatch, wait_for_results
from aiki.watcher import Watcher, ANY_SETTING
from juju.hooks.cli import parse_port_protocol


# Module logger; all messages from this command go to the "watch" logger.
log = logging.getLogger("watch")

# Unit agent states treated as failures when --failfast is given.
ERROR_STATES = set(("error", "install-error", "start-error", "stop-error"))


def main():
    """Script entry point: parse argv, configure logging, run the watch."""
    opts = parse_options(make_parser())
    setup_logging(opts)
    run_command(WatchManager(), opts)


def make_parser(root_parser=None):
    """Build the argument parser for the ``watch`` command.

    :param root_parser: optional parent parser, passed straight through to
        ``make_arg_parser`` (presumably so ``watch`` can be registered as a
        subcommand -- confirm against ``aiki.cli``).
    :returns: the parser with the global ``--any``, ``--failfast`` and
        ``--number`` flags added. Conditions are not declared here; they are
        parsed separately from the leftover arguments by ``parse_options``.
    """
    # The usage text is written by hand because conditions are parsed
    # outside this parser (see get_condition_parser), so argparse cannot
    # generate it; keep it in sync with the flags added below.
    main_parser = make_arg_parser(
        root_parser, "watch",
        help="waits on Juju environment for specified conditions",
        description="Wait on Juju environment for specified conditions to become true",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        usage=textwrap.dedent("""\
        watch [-h] [-e ENVIRONMENT]
                   [--loglevel CRITICAL|ERROR|WARNING|INFO|DEBUG] [--verbose]
                   [--any] [--failfast] [--number]
                   CONDITION [CONDITION ...]

        Watch Juju environment for condition(s) to become true, waiting as necessary.

        Each condition is specified as follows:

          INSTANCE              Either service, relation, or service unit. For a 
                                relation, either its relation id (example: db:0) or
                                descriptors can be used (examples: "mysql wordpress"
                                or "mysql:db wordpress:db")

          -r RELATION, --relation RELATION
                                Specify relation for unit or service
          -n NUM, --num-units NUM
                                For a service, this is the number of units for which
                                the following *unit* conditions must be true. Defaults
                                to 1 if any unit conditions are specified
          --open-port PORT      Port is open for unit
          --closed-port PORT    Port is closed for unit
          --setting SETTING     Relation setting exists for unit
          --state STATE         Unit is in this state
          --x-state STATE       Cannot be in this state for unit
          [NAME=SETTING [NAME=SETTING ... ]]
                                Settings are in effect for unit
        """),
        epilog=textwrap.dedent("""\
        Examples of watches with single conditions:

        $ relation_id=`jitsu watch "wordpress mysql"`        # waits until relation available and prints id; must be quoted
        $ relation_id=`jitsu watch "wordpress:db mysql:app"` # like above, but with fully specified descriptors
        $ jitsu watch mysql                                  # service is deployed
        $ jitsu watch mysql/0 --state=started                # unit is started
        $ timeout 60s jitsu watch mysql/0 --state=started    # watch up to 60s for mysql/0 to be running, using core timeout command
        $ jitsu watch mysql/0 -r db:0 foo=bar                # watch for foo to be set to bar on mysql/0 on relation id of db:0
        $ jitsu watch mysql/0 -r "wordpress mysql" foo=bar   # watch for wordpress<->mysql, then watch for this setting
        $ jitsu watch mysql/0 -r db:0 foo=                   # watch for foo to be unset
        $ jitsu watch mysql/0 -r db:0 --setting=foo          # watch for foo to be set to *some* value

        Multiple conditions can be combined:

        $ jitsu watch mysql/0 -r "wordpress mysql" foo= mysql/1 -r "wordpress mysql" foo=bar # all conditions must apply
        $ jitsu watch --any ...                                                              # any of the conditions may apply
    """))
    main_parser.add_argument("--any", default=False, action="store_true", help="Any of the conditions may be true")
    main_parser.add_argument("--failfast", default=False, action="store_true", help="Immediately fail if any conditions are in error state")
    main_parser.add_argument("--number", default=False, action="store_true", help="Number output by the corresponding condition")
    return main_parser


def get_condition_parser():
    """Build the standalone parser for a single watch condition.

    This parser is deliberately not attached as a subparser: ``parse_options``
    applies it repeatedly, left to right, to successive slices of the command
    line. The ``rest`` positional uses ``argparse.REMAINDER`` so that whatever
    follows a condition's options (its NAME=SETTING pairs and any later
    conditions) is handed back for further processing. No per-option help is
    attached; the user-facing description lives in the ``watch`` usage text.
    """
    # (positional names/flags, keyword arguments) in registration order.
    option_specs = (
        (("-n", "--num-units"), {"default": None, "type": int}),
        (("--open-port",), {"default": [], "dest": "open_ports", "action": "append"}),
        (("--closed-port",), {"default": [], "dest": "closed_ports", "action": "append"}),
        (("-r", "--relation"), {"default": None}),
        (("--setting",), {"default": [], "action": "append"}),
        (("--state",), {"default": [], "dest": "states", "action": "append"}),
        (("--x-state",), {"default": [], "dest": "excluded_states", "action": "append"}),
        (("rest",), {"nargs": argparse.REMAINDER}),
    )
    parser = argparse.ArgumentParser()
    for names, kwargs in option_specs:
        parser.add_argument(*names, **kwargs)
    return parser


def parse_options(main_parser, args=None):
    """Parse the command line into global options plus a list of conditions.

    ``main_parser`` parses the global options. Every argument it does not
    recognize is then consumed left to right as a sequence of conditions.
    Each condition is an INSTANCE token, followed by condition options
    (see ``get_condition_parser``) and then NAME=SETTING pairs. Whatever
    ``parse_relation_settings`` leaves unconsumed starts the next condition.

    :param main_parser: parser built by ``make_parser``.
    :param args: argument list to parse; ``None`` (the default) lets argparse
        read ``sys.argv[1:]`` as before.
    :returns: the options namespace with an added ``conditions`` list. Each
        condition namespace additionally carries ``instance``, ``count`` (its
        0-based position), ``settings`` (from ``parse_relation_settings``,
        with every ``--setting`` name mapped to ``ANY_SETTING``), and
        ``states`` / ``excluded_states`` converted to sets.

    Exits via ``main_parser.error`` if no condition is given, or if a
    condition does not start with its INSTANCE.
    """
    condition_parser = get_condition_parser()
    options, condition_args = main_parser.parse_known_args(args)

    # Partially parse (using remainder args) multiple times, working left
    # to right, until all leftover args are consumed.
    options.conditions = []
    while condition_args:
        instance = condition_args.pop(0)
        if instance.startswith("-"):
            # Each condition must begin with its INSTANCE. An option here
            # (e.g. "watch -r db:0 mysql/0") would otherwise be misread as
            # the instance name.
            main_parser.error(
                "expected INSTANCE (service, unit, or relation) before %s" % instance)
        condition = condition_parser.parse_args(condition_args)
        condition.instance = instance
        condition.count = len(options.conditions)
        condition.settings, condition_args = parse_relation_settings(condition.rest)
        for setting in condition.setting:
            condition.settings[setting] = ANY_SETTING
        condition.states = set(condition.states)
        condition.excluded_states = set(condition.excluded_states)
        options.conditions.append(condition)

    if not options.conditions:
        # The usage requires CONDITION [CONDITION ...]; without one there is
        # nothing to wait on.
        main_parser.error("at least one CONDITION is required")
    return options


class WatchManager(object):
    """Manages the watching of all conditions"""

    def __init__(self):
        self.overall = None
        self.exception = None

    def failfast(self, e):
        log.error("Failed fast: %s", e)
        self.exception = e
        self.overall.callback(None)

    @inlineCallbacks
    def __call__(self, result, client, options):
        self.options = options
        self.client = client
        self.watcher = Watcher(self.client)
        wait_for_conditions = []
        for condition in self.options.conditions:
            wait_for_conditions.append(self.parse_and_wait_on_condition(condition))
        self.overall = wait_for_results(wait_for_conditions, self.options.any)
        yield self.overall
        if self.exception:
            raise self.exception

    def print_result(self, condition, result):
        """Given conditions complete asynchronously, the user can specify numbering to disambiguate"""
        if self.options.number:
            print condition.count,
        print result

    @inlineCallbacks
    def parse_and_wait_on_condition(self, condition):
        """Given a condition description from argparse, creates a watcher for it"""

        # Disambiguate the following cases: relation, unit, or service:
        # 1. Relation
        if len(condition.instance.split()) >= 2:
            # Only handle non-peer relations, since peer relations are automatically established anyway;
            # also handle error case of a relation between more than 2 services
            relation_ident = yield self.wait_for_relation(condition, condition.instance)
            self.print_result(condition, relation_ident)
            return

        # 2. Service unit
        if "/" in condition.instance:
            yield self.wait_for_unit(condition, condition.instance)
            self.print_result(condition, condition.instance)
            return

        # 3. Service
        yield self.wait_for_service(condition)
        self.print_result(condition, condition.instance)

    @inlineCallbacks
    def wait_for_unit(self, condition, unit_name, relation_ident=None):
        log.info("Waiting for unit %s...", unit_name)
        yield self.watcher.wait_for_unit(unit_name)

        # Support --state/--x-state
        if self.options.failfast:
            condition.states.update(ERROR_STATES)

        if condition.states or condition.excluded_states:
            log.info(
                "Waiting for unit %s to be in %s and not in %s", 
                unit_name, format_states(condition.states), format_states(condition.excluded_states))
            agent_state = yield self.watcher.wait_for_unit_state(unit_name, condition.states, condition.excluded_states)
            log.info("Completed waiting for unit %s state: %s", unit_name, agent_state)
            if self.options.failfast and agent_state in ERROR_STATES:
                self.failfast(Exception("Unit %s in error state: %s" % (unit_name, agent_state)))

        # Support --open-port/--closed-port
        if condition.open_ports or condition.closed_ports:
            open_ports = set(parse_port_protocol(p) for p in condition.open_ports)
            closed_ports = set(parse_port_protocol(p) for p in condition.closed_ports)
            log.info(
                "Waiting for unit %s to have open ports in %s and not in %s", 
                unit_name, format_port_pairs(open_ports), format_port_pairs(closed_ports))
            open_ports = yield self.watcher.wait_for_unit_ports(unit_name, open_ports, closed_ports)
            log.info("Completed waiting for unit %s open ports: %s", unit_name, format_port_pairs(open_ports))

        # Support relation settings
        if condition.relation:
            if relation_ident is None:
                relation_ident = yield self.wait_for_relation(condition)
            if condition.settings:
                log.info("Waiting for %s: settings %s", unit_name, condition.settings)
                settings = yield self.watcher.watch_unit_settings(unit_name, relation_ident, condition.settings)
                log.info("Completed waiting for %s: expected %s, actual %s",
                         unit_name, condition.settings, settings)
        log.info("Completed waiting for unit %s", unit_name) 
        returnValue(unit_name)

    @inlineCallbacks
    def wait_for_relation(self, condition, relation=None):
        """Return relation ident corresponding to the condition or relation (if specified), waiting as necessary"""

        if relation is None:
            relation = condition.relation

        if ":" in relation:
            parts = relation.split(":")
            if len(parts) == 2 and parts[1].isdigit():
                relation_ident = relation
                returnValue(relation_ident)

        # Otherwise wait for it:
        descriptors = relation.split()
        if len(descriptors) == 1 or len(descriptors) == 2:
            log.info("Waiting for %s...", format_descriptors(*descriptors))
            relation_ident = yield self.watcher.wait_for_relation(*descriptors)
            log.info("Completed waiting for %s", format_descriptors(*descriptors))
            returnValue(relation_ident)

        raise ValueError("Bad relation: %s" % relation)

    @inlineCallbacks
    def wait_for_service(self, condition):
        """Return service name when sucessfully waited"""
        log.info("Waiting for %s%s service...", 
                 "%d unit(s) of " % condition.num_units if condition.num_units else "",
                 condition.instance)
        yield self.watcher.wait_for_service(condition.instance)
        log.info("Completed waiting for %s service", condition.instance)
        if condition.relation:
            relation_ident = yield self.wait_for_relation(condition)
        else:
            relation_ident = None

        # If options used in the condition imply service unit watching,
        # then default num_units to 1
        if condition.num_units is None and (
                condition.states or condition.excluded_states or condition.settings or
                condition.open_ports or condition.closed_ports):
            condition.num_units = 1

        if condition.num_units:
            latch = CountDownLatch(condition.num_units)

            def new_unit_cb(unit_name):
                # NOTE Could also presumably cancel any removed units, consider that for a future refactoring
                return succeed(latch.add(self.wait_for_unit(condition, unit_name, relation_ident)))

            yield self.watcher.watch_new_service_units(condition.instance, new_unit_cb)
            yield latch.completed
        returnValue(condition.instance)


if __name__ == "__main__":
    # Run the watch command only when executed as a script, not on import.
    main()
