#!/usr/bin/env python

"""
Augment the alignment_summary.gff file with consensus and variants information.
"""

from collections import namedtuple, defaultdict
import argparse
import logging
import bisect
import json
import gzip
import sys

import numpy as np

from pbcommand.utils import setup_log
from pbcommand.cli import pbparser_runner, get_default_argparser
from pbcommand.models import FileTypes, get_pbparser
from pbcommand.common_options import add_resolved_tool_contract_option
from pbcore.io import GffReader, GffWriter, Gff3Record

from GenomicConsensus.utils import error_probability_to_qv
from GenomicConsensus import __VERSION__

#
# Note: regions use GFF-style coordinates, i.e. the 1-based, inclusive
# start/end convention of the GFF format.
#
Region = namedtuple("Region", ["seqid", "start", "end"])

# Module-level logger; this is the logger handed to pbparser_runner in main().
log = logging.getLogger(__name__)

class Constants(object):
    """
    Fixed identifiers that get_contract_parser() passes to get_pbparser().
    """
    # Driver command line passed to get_pbparser (note the trailing space).
    DRIVER_EXE = "summarizeConsensus --resolved-tool-contract "
    # Tool identifier passed to get_pbparser as its first argument.
    TOOL_ID = "genomic_consensus.tasks.summarize_consensus"


def get_contract_parser():
    """
    Build and return the parser object describing this tool's interface.

    The alignment summary GFF input is registered on the combined parser
    returned by get_pbparser.  The variants input and the output GFF are
    registered separately: once as file types on the tool-contract parser
    and once as the "--variantsGff" / "-o, --output" options on the
    underlying command-line parser.
    """
    parser = get_pbparser(
        Constants.TOOL_ID,
        __VERSION__,
        "Summarize Consensus",
        __doc__,
        Constants.DRIVER_EXE,
        default_level="ERROR")
    tool_contract = parser.tool_contract_parser
    command_line = parser.arg_parser.parser

    # Input 0: the alignment summary GFF.
    parser.add_input_file_type(
        FileTypes.GFF,
        "alignment_summary",
        "Alignment summary GFF",
        "Alignment summary GFF file")

    # Input 1: the variants GFF.
    tool_contract.add_input_file_type(
        FileTypes.GFF,
        "variants",
        "Variants GFF",
        "Variants GFF file")
    command_line.add_argument(
        "--variantsGff",
        type=str,
        help="Input variants.gff or variants.gff.gz filename",
        required=True)

    # Output 0: the augmented alignment summary GFF.
    tool_contract.add_output_file_type(
        FileTypes.GFF,
        "output",
        name="Output GFF file",
        description="New alignment summary GFF file",
        default_name="alignment_summary_variants")
    command_line.add_argument(
        "-o", "--output",
        type=str,
        help="Output alignment_summary.gff filename")

    return parser

def get_args_from_resolved_tool_contract(resolved_tool_contract):
    """
    Translate a resolved tool contract into parsed command-line options.

    The first input file becomes the positional alignment summary argument,
    the second input file becomes --variantsGff and the first output file
    becomes --output.  Returns the result of parsing that argument list with
    the command-line parser from get_contract_parser().
    """
    cli_parser = get_contract_parser().arg_parser.parser
    task = resolved_tool_contract.task
    argv = [task.input_files[0]]
    argv += ["--variantsGff", task.input_files[1]]
    argv += ["--output", task.output_files[0]]
    return cli_parser.parse_args(argv)

def _read_region_summaries(alignmentSummaryPath):
    """
    Create an empty summary for every record of the alignment summary GFF.

    Returns a dict keyed by Region(seqid, start, end).  Each value is a dict
    holding the "ins", "del" and "sub" counters (all starting at 0) and a
    placeholder "cQv" tuple.
    """
    summaries = {}
    reader = GffReader(alignmentSummaryPath)
    try:
        for gffRecord in reader:
            region = Region(gffRecord.seqid, gffRecord.start, gffRecord.end)
            summaries[region] = { "ins" : 0,
                                  "del" : 0,
                                  "sub" : 0,
                                  # TODO: base consensusQV on effective coverage
                                  "cQv" : (20, 20, 20)
                                }
    finally:
        # Release the file handle even if a record fails to parse.
        reader.close()
    return summaries


def _tally_variants(variantRecords, summaries, counterNames):
    """
    Add the length of every variant record to the summary of its region.

    `variantRecords` is an iterable of variant GFF records, `summaries` is
    the dict built by _read_region_summaries (updated in place) and
    `counterNames` maps a variant type to the name of the counter to bump.

    Raises KeyError if a variant refers to a contig with no region (or has a
    type missing from `counterNames`) and AssertionError if the variant does
    not fall inside the region selected for it.
    """
    grouped = defaultdict(list)
    for region in summaries:
        grouped[region.seqid].append(region)

    # Sort the regions of each contig by start and keep a parallel list of
    # the start coordinates, built once, so that every variant lookup is a
    # single binary search instead of a rebuild of the whole list.
    regions_by_contig = {}
    starts_by_contig = {}
    for seqid, regions in grouped.items():
        ordered = sorted(regions, key=lambda r: r.start)
        regions_by_contig[seqid] = ordered
        starts_by_contig[seqid] = [r.start for r in ordered]

    logging.info("Processing variant records")
    for i, variantGffRecord in enumerate(variantRecords, 1):
        seqid = variantGffRecord.seqid
        if seqid not in regions_by_contig:
            raise KeyError(
                "Can't find alignment summary for contig '{s}'".format(
                s=seqid))
        # Last region whose start is <= the variant start.
        idx = bisect.bisect_right(starts_by_contig[seqid],
                                  variantGffRecord.start) - 1
        # XXX we have to be a little careful here - an insertion at the start
        # of a contig will have start=0 versus start=1 for the first region
        if idx < 0:
            idx = 0
        region = regions_by_contig[seqid][idx]
        inRegion = region.start <= variantGffRecord.start <= region.end
        leadingInsertion = (region.start == 1 and
                            variantGffRecord.start == 0 and
                            variantGffRecord.type == "insertion")
        if not (inRegion or leadingInsertion):
            # Raised explicitly (rather than via `assert`) so the check is
            # still performed when Python runs with -O.
            raise AssertionError(
                (seqid, region.start, variantGffRecord.start,
                 region.end, variantGffRecord.type, idx))
        summary = summaries[region]
        counterName = counterNames[variantGffRecord.type]
        variantLength = max(len(variantGffRecord.reference),
                            len(variantGffRecord.variantSeq))
        summary[counterName] += variantLength
        if i % 1000 == 0:
            logging.info("%d records...", i)


def _write_augmented_summary(inputPath, outputPath, headers, summaries,
                             counterNames):
    """
    Copy the alignment summary GFF to `outputPath`, adding summary attributes.

    Metadata lines (starting with "#") are copied through.  The `headers`
    (key, value) pairs are written as "##key value" lines just before the
    first non-metadata line.  Each record of type "region" is written with
    ";cQv=..." and one ";<counter>=<n>" attribute per counter appended.
    Records of any other type are not written to the output.
    """
    with open(inputPath) as inputGff, open(outputPath, "w") as outputGff:
        inHeader = True
        for line in inputGff:
            line = line.rstrip()

            # A blank line carries no record; skip it (indexing its first
            # character would raise IndexError).
            if not line:
                continue

            # Pass any metadata line straight through
            if line.startswith("#"):
                outputGff.write(line + "\n")
                continue

            if inHeader:
                # We are at the end of the header -- write the tool-specific
                # headers
                for k, v in headers:
                    outputGff.write("##%s %s\n" % (k, v))
                inHeader = False

            # Parse the line
            rec = Gff3Record.fromString(line)
            if rec.type != "region":
                continue

            summary = summaries[(rec.seqid, rec.start, rec.end)]
            line += ";%s=%s" % (
                "cQv", ",".join(str(int(f)) for f in summary["cQv"]))
            for counterName in counterNames.values():
                line += ";%s=%d" % (counterName, summary[counterName])
            outputGff.write(line + "\n")


def run(options):
    """
    Augment the alignment summary GFF with per-region variant counts.

    Reads `options.alignment_summary` and `options.variantsGff`, adds up the
    lengths of the insertion, deletion and substitution variants falling in
    each region, and writes the augmented summary to `options.output`.

    Returns 0 on success.  Raises KeyError for a variant on a contig that has
    no region in the alignment summary and AssertionError for a variant that
    lies outside the region selected for it.
    """
    headers = [
        ("source", "GenomicConsensus %s" % __VERSION__),
        ("pacbio-alignment-summary-version", "0.6"),
        ("source-commandline", " ".join(sys.argv)),
        ]
    counterNames = { "insertion"    : "ins",
                     "deletion"     : "del",
                     "substitution" : "sub" }

    inputVariantsGff = GffReader(options.variantsGff)
    try:
        summaries = _read_region_summaries(options.alignment_summary)
        _tally_variants(inputVariantsGff, summaries, counterNames)
    finally:
        # The variants reader is closed on success and on failure alike.
        inputVariantsGff.close()

    _write_augmented_summary(options.alignment_summary, options.output,
                             headers, summaries, counterNames)
    return 0


def args_runner(args):
    """
    Entry point used for plain command-line invocations.

    Passes the parsed options to run() and returns its exit code.
    """
    exit_code = run(options=args)
    return exit_code


def resolved_tool_contract_runner(resolved_tool_contract):
    """
    Entry point used when the tool is driven by a resolved tool contract.

    Converts the contract into parsed options, passes them to run() and
    returns its exit code.
    """
    return run(
        options=get_args_from_resolved_tool_contract(resolved_tool_contract))


def main(argv=sys.argv):
    """
    Run the tool through pbparser_runner and return its result.

    `argv` is a full argument vector; its first element is dropped before
    the remainder is passed on.
    """
    cli_args = argv[1:]
    contract_parser = get_contract_parser()
    return pbparser_runner(
        cli_args,
        contract_parser,
        args_runner,
        resolved_tool_contract_runner,
        log,
        setup_log)

# Executed as a script: use the value returned by main() as the exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
