import logging

from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from zookeeper import NoNodeException
import yaml

from juju.state.base import StateBase
from juju.state.relation import RelationStateManager, ServiceRelationState
from juju.state.service import ServiceStateManager, parse_service_name
from juju.state.errors import (
    RelationStateNotFound, ServiceStateNotFound, ServiceUnitStateNotFound, InvalidRelationIdentity,
    StopWatcher, NoMatchingEndpoints, AmbiguousRelation)
from juju.unit.workflow import WorkflowStateClient


class AnySetting(object):
    """Sentinel type meaning "this key must be set, to any value".

    Used as a value in the ``expected_settings`` mapping passed to
    ``Watcher.watch_unit_settings``.
    """

    def __repr__(self):
        return "<Any setting>"

    # str() and repr() render identically.
    __str__ = __repr__


# Shared sentinel instance; compared by identity (``is``).
ANY_SETTING = AnySetting()


class Watcher(StateBase):
    """Wait for, and watch, juju state stored in ZooKeeper.

    Most public methods are ``inlineCallbacks`` generators and therefore
    return Deferreds that fire once the awaited condition is observed in
    the topology or in a watched ZooKeeper node.
    """

    def __init__(self, client):
        super(Watcher, self).__init__(client)
        self.log = logging.getLogger("aiki.watcher")
        self.relation_state_manager = RelationStateManager(self._client)
        self.service_state_manager = ServiceStateManager(self._client)

    @inlineCallbacks
    def wait_for_relation(self, *descriptors):
        """Wait for a relation described by `descriptors` to exist in the topology.

        :param descriptors: endpoint descriptors; each is a service name,
            optionally followed by ``:relation_name``. The service of the
            first descriptor is used to resolve the relation ident.
        :returns: Deferred firing with the relation ident once the relation
            appears in the topology.
        """
        done = Deferred()

        @inlineCallbacks
        def expected(*ignored):
            try:
                relation_id = yield self.get_relation_id_from_descriptors(*descriptors)
                relation_ident = yield self.get_relation_ident(descriptors[0], relation_id)
            except (ServiceStateNotFound, RelationStateNotFound):
                self.log.debug(
                    "Topology changed but continuing to wait for relation: %s",
                    descriptors)
            else:
                done.callback(relation_ident)
                raise StopWatcher()

        yield self._watch_topology(expected)
        returnValue((yield done))

    @inlineCallbacks
    def wait_for_service(self, service_name):
        """Wait for `service_name` to exist in the topology.

        :returns: Deferred firing with `service_name` once it exists.
        """
        done = Deferred()

        @inlineCallbacks
        def expected(*ignored):
            try:
                yield self.service_state_manager.get_service_state(service_name)
            except ServiceStateNotFound:
                self.log.debug(
                    "Topology changed, still waiting for existence of service: %s",
                    service_name)
            else:
                done.callback(service_name)
                raise StopWatcher()

        yield self._watch_topology(expected)
        returnValue((yield done))

    @inlineCallbacks
    def wait_for_unit(self, unit_name):
        """Wait for `unit_name` to exist in the topology.

        :returns: Deferred firing with the unit's *service* name.
        """
        # NOTE(review): this fires with the service name rather than
        # `unit_name`; possibly a copy from wait_for_service. Kept as-is
        # because callers may depend on it.
        done = Deferred()

        @inlineCallbacks
        def expected(*ignored):
            try:
                service_name = parse_service_name(unit_name)
                service = yield self.service_state_manager.get_service_state(service_name)
                yield service.get_unit_state(unit_name)
            except (ServiceUnitStateNotFound, ServiceStateNotFound):
                self.log.debug(
                    "Topology changed, still waiting for existence of unit: %s",
                    unit_name)
            else:
                done.callback(service_name)
                raise StopWatcher()

        yield self._watch_topology(expected)
        returnValue((yield done))

    @inlineCallbacks
    def watch_new_service_units(self, service_name, cb):
        """Call `cb(unit_name)` once for every unit of `service_name`.

        This includes units that already exist and units added later.
        Each call is scheduled with ``reactor.callLater(0, ...)`` rather
        than made from inside the topology-watch callback. Watching stops
        once the service is no longer found.
        """
        # Imported here rather than at module level, as in the original
        # code (presumably so that importing this module does not install
        # the default reactor).
        from twisted.internet import reactor

        service = yield self.service_state_manager.get_service_state(service_name)
        watched_units = set()

        @inlineCallbacks
        def check(*ignored):
            try:
                units = yield service.get_all_unit_states()
            except ServiceStateNotFound:
                self.log.warning("Service is no longer in topology: %s", service_name)
                raise StopWatcher()
            for unit in units:
                # NOTE ignore units that have gone missing for the
                # time being; from actual testing, this seems
                # unlikely, but we might need to do more elaborate
                # canceling (StopWatcher). Regardless, that's what
                # would do with juju, but this is jitsu.
                if unit.unit_name not in watched_units:
                    watched_units.add(unit.unit_name)
                    reactor.callLater(0, cb, unit.unit_name)

        yield self._watch_topology(check)

    def setup_zk_watch(self, path, callback):
        """Repeatedly call `callback` while the node at `path` exists.

        On each check, ``exists_and_watch(path)`` is issued. If the node
        exists, ``callback(True)`` is invoked; it may return a Deferred.
        The exists-watch is then re-armed so the check runs again when it
        fires.

        Raising `StopWatcher` from `callback` ends the watch. A
        `NoNodeException` raised by `callback` is logged and ignored. The
        watch is not re-armed once the client is disconnected.

        :returns: Deferred firing when the first check (and callback, if
            run) completes.
        """

        @inlineCallbacks
        def manage_callback(*ignored):
            # Need to guard on the client being connected in the case
            # 1) a watch is waiting to run (in the reactor);
            # 2) and the connection is closed.
            #
            # It remains the responsibility of `callback` to raise
            # `StopWatcher` to end the watch.
            if not self._client.connected:
                returnValue(None)
            exists_d, watch_d = self._client.exists_and_watch(path)
            stat = yield exists_d
            exists = bool(stat)
            if exists:
                try:
                    yield callback(exists)
                except StopWatcher:
                    returnValue(None)
                except NoNodeException as e:
                    # This may occur if the callback is trying to process
                    # data related to this path, so just ignore
                    self.log.debug("Ignoring no node exception when watching %s: %s", path, e)

            # Re-arm: check again when the exists-watch on `path` fires.
            watch_d.addCallback(manage_callback)

        return manage_callback()

    def get_relation_id(self, relation_ident):
        """Return the (internal) relation id for `relation_ident`.

        `relation_ident` must have the form ``"<relation_name>:<N>"``, where
        ``N`` is all digits. The result is ``"relation-"`` followed by ``N``
        zero-padded to ten digits.

        Reads ``self.topology``, which must already have been loaded (for
        example by `get_relations` or `watch_unit_settings`).

        :raises InvalidRelationIdentity: if `relation_ident` is malformed.
        :raises RelationStateNotFound: if the relation is not in the topology.
        """
        # NOTE may want to refactor by adding to juju.state.relation
        parts = relation_ident.split(":")
        if len(parts) != 2 or not parts[1].isdigit():
            raise InvalidRelationIdentity(relation_ident)
        normalized_id = parts[1]
        relation_id = "relation-%s" % normalized_id.zfill(10)
        if not self.topology.has_relation(relation_id):
            raise RelationStateNotFound()
        return relation_id

    @inlineCallbacks
    def get_relations(self, service_name):
        """Return a `ServiceRelationState` for each relation of `service_name`.

        As a side effect, this refreshes ``self.topology`` from ZooKeeper.
        """
        self.topology = yield self._read_topology()
        service = yield self.service_state_manager.get_service_state(service_name)
        internal_service_id = service.internal_id
        relations = [
            ServiceRelationState(
                self._client,
                internal_service_id,
                info["relation_id"],
                info["scope"],
                **info["service"])
            for info in self.topology.get_relations_for_service(internal_service_id)]
        returnValue(relations)

    @inlineCallbacks
    def get_relation_ident(self, descriptor, relation_id):
        """Return the relation ident of `relation_id` for `descriptor`'s service.

        Only the service part of `descriptor` (the text before any ``:``)
        is used.

        :raises RelationStateNotFound: if that service has no relation with
            internal id `relation_id` (for example, because the relation was
            removed concurrently).
        """
        service_name = descriptor.split(":", 1)[0]
        relations = yield self.get_relations(service_name)
        for relation in relations:
            if relation.internal_relation_id == relation_id:
                returnValue(relation.relation_ident)
        # Raising the domain error (instead of IndexError) lets
        # wait_for_relation treat this as "keep waiting".
        raise RelationStateNotFound()

    @inlineCallbacks
    def get_relation_id_from_descriptors(self, *descriptors):
        """Return the internal relation id matching `descriptors`.

        :raises NoMatchingEndpoints: if no endpoint pair matches.
        :raises AmbiguousRelation: if several endpoint pairs match and the
            extra ones do not involve a ``juju-``-prefixed relation name.
        """
        endpoint_pairs = yield self.service_state_manager.join_descriptors(*descriptors)
        if not endpoint_pairs:
            raise NoMatchingEndpoints()
        # Extra candidate pairs are tolerated only if one of their endpoints
        # is a "juju-"-prefixed relation; the first pair is then used.
        for pair in endpoint_pairs[1:]:
            if not (pair[0].relation_name.startswith("juju-") or
                    pair[1].relation_name.startswith("juju-")):
                raise AmbiguousRelation(descriptors, endpoint_pairs)
        endpoints = endpoint_pairs[0]
        # A pair whose two endpoints are identical (presumably a peer
        # relation) is looked up with the single endpoint.
        if endpoints[0] == endpoints[1]:
            endpoints = endpoints[0:1]

        # so this needs to be put in the context of a continual watch until the topology is in effect
        relation_state = yield self.relation_state_manager.get_relation_state(*endpoints)
        returnValue(relation_state.internal_id)

    @inlineCallbacks
    def watch_unit_settings(self, unit_name, relation_ident, expected_settings):
        """Wait until `unit_name`'s relation settings match `expected_settings`.

        Each ``key -> value`` entry of `expected_settings` requires:

        * ``None``: `key` must be absent;
        * ``ANY_SETTING``: `key` must be present, with any value;
        * anything else: the setting for `key` must equal that value.

        An empty settings node is treated as having no settings.

        :returns: Deferred firing with the matching settings dict. It never
            fires if the settings never match.
        :raises InvalidRelationIdentity, RelationStateNotFound: from
            `get_relation_id`.
        """
        # NOTE may want to refactor by adding to juju.state.relation
        # the generic settings path logic
        # get_relation_id reads self.topology, hence both assignments.
        self.topology = topology = yield self._read_topology()
        relation_id = self.get_relation_id(relation_ident)
        unit_id = topology.get_service_unit_id_from_name(unit_name)
        container = topology.get_service_unit_container(unit_id)
        # If the unit has a container, its settings live under an extra
        # path component taken from the container's last element.
        container_info = "%s/" % container[-1] if container else ""
        settings_path = "/relations/%s/%ssettings/%s" % (relation_id, container_info, unit_id)
        done = Deferred()

        @inlineCallbacks
        def expected(*ignored):
            content_yaml, stat = yield self._client.get(settings_path)
            content = yaml.safe_load(content_yaml)
            if content is None:
                # An empty node parses to None; `key in None` would raise.
                content = {}
            if self._settings_match(content, expected_settings):
                done.callback(content)
                raise StopWatcher()

        yield self.setup_zk_watch(settings_path, expected)
        settings = yield done
        returnValue(settings)

    @staticmethod
    def _settings_match(content, expected_settings):
        """Return True if the `content` dict satisfies every expected setting."""
        for key, value in expected_settings.items():
            if value is None:
                # key should *not* have a setting
                if key in content:
                    return False
            elif value is ANY_SETTING:
                if key not in content:
                    return False
            elif content.get(key) != value:
                return False
        return True

    @inlineCallbacks
    def get_unit_agent_state(self, unit):
        """Return the agent state string for `unit`.

        Returns ``"pending"`` if the unit has no workflow state yet, and
        ``"down"`` if its agent is not connected. Otherwise it returns the
        workflow state with underscores replaced by dashes.
        """
        # NOTE extracted from juju.control.status
        unit_workflow_client = WorkflowStateClient(self._client, unit)
        workflow_state = yield unit_workflow_client.get_state()
        if not workflow_state:
            returnValue("pending")
        unit_connected = yield unit.has_agent()
        if not unit_connected:
            returnValue("down")
        returnValue(workflow_state.replace("_", "-"))

    @inlineCallbacks
    def wait_for_unit_state(self, unit_name, expected_states, excluded_states):
        """Wait until the agent state of `unit_name` is acceptable.

        The state computed by `get_unit_agent_state` is acceptable if it is
        in `expected_states`. It is also acceptable if `excluded_states` is
        non-empty and the state is not in it.

        :returns: Deferred firing with the accepted agent state.
        """
        service_name = parse_service_name(unit_name)
        service = yield self.service_state_manager.get_service_state(service_name)
        unit = yield service.get_unit_state(unit_name)
        done = Deferred()

        @inlineCallbacks
        def expected(*ignored):
            # Both watches below share this callback; whichever fires after
            # `done` has been triggered simply stops.
            if done.called:
                raise StopWatcher()
            agent_state = yield self.get_unit_agent_state(unit)
            self.log.debug("%s has state: %s", unit_name, agent_state)
            if agent_state in expected_states or (excluded_states and agent_state not in excluded_states):
                if not done.called:
                    done.callback(agent_state)
                raise StopWatcher()

        # Watch both the agent ephemeral plus the workflow state - the
        # agent state is a composite of the two
        yield self.setup_zk_watch(unit._get_agent_path(), expected)
        unit_workflow_client = WorkflowStateClient(self._client, unit)
        yield self.setup_zk_watch(unit_workflow_client.zk_state_path, expected)

        agent_state = yield done
        returnValue(agent_state)

    @inlineCallbacks
    def wait_for_unit_ports(self, unit_name, open_ports, closed_ports):
        """Wait until `unit_name` has every port in `open_ports` open and none in `closed_ports`.

        Ports are ``(port, proto)`` pairs. Any iterable of pairs is accepted
        for `open_ports` and `closed_ports`.

        :returns: Deferred firing with the set of currently open
            ``(port, proto)`` pairs.
        """
        # Normalize so callers may pass lists/tuples as well as sets.
        open_ports = set(open_ports)
        closed_ports = set(closed_ports)
        service_name = parse_service_name(unit_name)
        service = yield self.service_state_manager.get_service_state(service_name)
        unit = yield service.get_unit_state(unit_name)
        done = Deferred()

        @inlineCallbacks
        def expected(*ignored):
            if done.called:
                raise StopWatcher()
            port_infos = yield unit.get_open_ports()
            current_ports = set((p["port"], p["proto"]) for p in port_infos)
            self.log.debug("%s has open ports: %s", unit_name, current_ports)
            if open_ports <= current_ports and not (closed_ports & current_ports):
                if not done.called:
                    done.callback(current_ports)
                raise StopWatcher()

        yield unit.watch_ports(expected)
        current_ports = yield done
        returnValue(current_ports)
