#! /usr/bin/env python
# Author: Martin C. Frith 2020
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import print_function

import bisect
import gzip
import optparse
import sys

def openFile(fileName):
    """Return a readable text stream for fileName.

    "-" selects standard input; a name ending in ".gz" is opened through
    gzip in text mode; any other name is opened as an ordinary file.
    """
    if fileName == "-":
        return sys.stdin
    isGzipped = fileName.endswith(".gz")
    if not isGzipped:
        return open(fileName)
    return gzip.open(fileName, "rt")  # xxx dubious for Python2

def genesFromLines(opts, lines):
    """Yield (chrom, beg, end, geneName) for each gene record in lines.

    Two whitespace-delimited layouts are recognised per line:

    * BED: fewer than 6 fields, or a 6th field that is a strand symbol
      ("+", "-" or "."). Columns 1-3 give chrom/beg/end; column 4 is the
      name if present, otherwise the name is built as "chrom:beg-end".
    * Anything else is read as genePred with chrom in column 3, beg/end in
      columns 5-6 and the name in column 13 when present, else column 1
      (presumably the UCSC refGene.txt layout with a leading bin column).

    Each yielded range is widened on both sides by opts.distance + 1.
    Blank lines and lines whose first field starts with "#" are skipped.
    """
    extension = opts.distance + 1
    for line in lines:
        fields = line.split()
        if not fields or fields[0].startswith("#"):
            continue  # blank line, or a comment / column-header line
        # "." is the BED symbol for "no strand"; without it such lines would
        # be misread as genePred and fail when converting "." to an integer.
        if len(fields) < 6 or fields[5] in ("+", "-", "."):  # BED
            chrom = fields[0]
            beg = int(fields[1])
            end = int(fields[2])
            if len(fields) > 3:
                geneName = fields[3]
            else:
                geneName = "{0}:{1}-{2}".format(*fields)
        else:  # genePred
            chrom = fields[2]
            beg = int(fields[4])
            end = int(fields[5])
            geneName = fields[12 if len(fields) > 12 else 0]
        yield chrom, beg - extension, end + extension, geneName

def seqRangeFromText(text):
    """Parse "name:beg<end" or "name:beg>end" into (name, beg, end).

    The two coordinates are returned as integers in the order written, so
    for the ">" form the first number may be the larger one.

    Raises ValueError if the text has no colon, or if the part after the
    last colon is not two integers separated by "<" or ">".
    """
    # Split at the LAST colon: sequence names may themselves contain colons
    # (e.g. "HLA-A*01:01:01:01"), whereas the range part never does.
    sequenceName, r = text.rsplit(":", 1)
    parts = r.split("<" if "<" in r else ">")
    beg, end = map(int, parts)
    return sequenceName, beg, end

def breakpointsFromSeqRanges(seqRanges):
    """Yield (sequenceName, position) for the inner ends of a range chain.

    Every range except the first contributes its first coordinate, and
    every range except the last contributes its second coordinate, so the
    outermost two ends of the chain are never reported.
    """
    for index, rangeText in enumerate(seqRanges):
        name, first, second = seqRangeFromText(rangeText)
        hasPredecessor = index > 0
        hasSuccessor = len(seqRanges) > index + 1
        if hasPredecessor:
            yield name, first
        if hasSuccessor:
            yield name, second

def groupsFromLines(lines, wantedGroupNames):
    """Yield (groupName, rangeLists, groupLines) for each group in lines.

    Only lines whose first whitespace-separated field is "#" (and that have
    at least one more field) are interpreted:

    * "# PART name ...": ends the current group, if any, and turns echoing
      on or off depending on whether name is in wantedGroupNames.
    * second field contains ":": continuation of the current group's most
      recent range list (fields from the second onward are appended).
    * exactly two fields: starts a new group named by the second field,
      with one initially empty range list.
    * anything else, while inside a group: starts a new range list from
      the third field onward.

    Side effect: while echoing is on, every input line (interpreted or not)
    is printed to standard output unchanged.

    wantedGroupNames is consulted each time a PART line is read, so the
    caller may add names to it between yields.
    """
    groupName = None
    # Bound up front so that a range-like line appearing before any group
    # header cannot touch an unassigned variable.
    rangeLists = []
    groupLines = []
    isWanted = False
    for line in lines:
        fields = line.split()
        if len(fields) > 1 and fields[0] == "#":
            if fields[1] == "PART":
                if groupName:
                    # Yield BEFORE testing wantedGroupNames, so the caller
                    # can register this group's name in time for the test.
                    yield groupName, rangeLists, groupLines
                    groupName = None
                # A bare "# PART" line names no group, so nothing is wanted.
                isWanted = len(fields) > 2 and fields[2] in wantedGroupNames
            elif ":" in fields[1]:
                # Outside a group there is nothing to continue; ignoring the
                # line also avoids mutating lists that were already yielded.
                if groupName:
                    rangeLists[-1] += fields[1:]
                    groupLines.append(line)
            elif len(fields) == 2:
                if groupName:
                    yield groupName, rangeLists, groupLines
                groupName = fields[1]
                rangeLists = [[]]
                groupLines = [line]
            elif groupName:
                rangeLists.append(fields[2:])
                groupLines.append(line)
        if isWanted:
            print(line, end="")
    if groupName:
        yield groupName, rangeLists, groupLines

def genesAtBreakpoints(breakpoints, genes, maxGeneLength):
    """Yield the name of every gene whose range strictly contains a breakpoint.

    breakpoints: iterable of (sequenceName, position); duplicates are
        collapsed before searching.
    genes: sorted list of (chrom, beg, end, geneName) tuples.
    maxGeneLength: the largest end - beg over all genes; it bounds how far
        back the scan has to look.

    A gene is reported once per distinct breakpoint with beg < pos < end,
    so the same name can be yielded more than once.
    """
    for sequenceName, pos in set(breakpoints):
        # Everything at or after this index sorts above (name, pos, pos),
        # so only the entries before it are candidates.
        stop = bisect.bisect(genes, (sequenceName, pos, pos))
        for k in range(stop - 1, -1, -1):
            chrom, beg, end, geneName = genes[k]
            if chrom < sequenceName:
                break  # moved onto an earlier sequence
            if beg + maxGeneLength <= pos:
                break  # no gene starting this early can reach pos
            if pos < end:
                yield geneName

def main(opts, genesFile, groupsFile):
    """Report, for each group in groupsFile, the genes at its breakpoints.

    opts: parsed options providing .distance (used when reading genes) and
        .output (0 = print "groupName<TAB>gene1,gene2,..." per group;
        otherwise print the group's own summary lines and mark the group as
        wanted, so that its later PART section is echoed while parsing).
    genesFile, groupsFile: file names as accepted by openFile.

    Groups with no gene at any breakpoint produce no output.
    """
    genes = sorted(genesFromLines(opts, openFile(genesFile)))
    # max() raises ValueError on an empty sequence; with no genes the bound
    # is irrelevant because there is nothing to scan.
    if genes:
        maxGeneLength = max(end - beg for chrom, beg, end, geneName in genes)
    else:
        maxGeneLength = 0
    # Filled in while iterating: the parser consults this set each time it
    # reaches a PART line.
    wantedGroupNames = set()
    for groupData in groupsFromLines(openFile(groupsFile), wantedGroupNames):
        groupName, rangeLists, groupLines = groupData
        breakpointLists = map(breakpointsFromSeqRanges, rangeLists)
        breakpoints = [j for i in breakpointLists for j in i]
        geneNames = genesAtBreakpoints(breakpoints, genes, maxGeneLength)
        geneNames = sorted(set(geneNames))
        if geneNames:
            if opts.output == 0:
                print(groupName, ",".join(geneNames), sep="\t")
            else:
                print(*groupLines, sep="")
                wantedGroupNames.add(groupName)

if __name__ == "__main__":
    # Command-line entry point: two positional arguments (gene file, groups
    # file) are required; any other argument count just prints the help text.
    parser = optparse.OptionParser(
        usage="%prog [options] refGene.txt groups.maf > groups.txt",
        description="Find genes at or near rearrangement breakpoints.")
    parser.add_option(
        "-d", "--distance", metavar="D", type="int", default=100,
        help="maximum distance (default=%default)")
    parser.add_option(
        "-o", "--output", metavar="N", type="int", default=0,
        help="0=simple table, 1=parts of input near genes "
        "(default=%default)")
    opts, args = parser.parse_args()
    if len(args) != 2:
        parser.print_help()
    else:
        main(opts, *args)
