include: "Common.snake"

configfile: "config.yaml"

import os.path

# Workflow start handler: make sure the "tmp" directory exists
# (mkdir -p is a no-op when it is already there).
onstart:
    shell("mkdir -p tmp")

def final_stage(w):
    """Select the target file(s) of the whole run from the configuration.

    The latest enabled stage wins: reassembly, then propagation, then the
    preliminary binning. The wildcards argument ``w`` is not used.
    """
    if not config["reassembly"]["enabled"]:
        if not config["propagation"]["enabled"]:
            # Neither later stage is enabled: stop on the preliminary binning
            return "binning/filtered_bins.tsv"
        # Stop on the propagation
        return "propagation.done"
    # Stop after the bin choosing
    return ("reassembly/choose.done", "propagation.done")

# Default target: requests whatever final_stage() selects for this configuration.
rule all:
    input:
        final_stage
    message:
        "Dataset of {SAMPLE_COUNT} samples from {IN} has been processed."

# ---- Assembly ----------------------------------------------------------------

# Assemble with MegaHIT
rule megahit:
    input:   left=left_reads, right=right_reads
    output:  "assembly/megahit/{group}/final.contigs.fa"
    # Each mate list is handed to MegaHIT as one comma-separated argument.
    params:  left=lambda w: ",".join(left_reads(w)),
             right=lambda w: ",".join(right_reads(w)),
             dir="assembly/megahit/{group}"
    threads: THREADS
    # One log per group: the log must carry the same wildcards as the output,
    # otherwise jobs for different groups would write to the same file.
    # It sits next to (not inside) params.dir, which is wiped below.
    log:     "assembly/megahit/{group}.log"
    message: "Assembling {wildcards.group} with MegaHIT"
    # The output directory is removed before the run.
    # NOTE(review): presumably because MegaHIT refuses to write into an
    # existing directory - confirm.
    shell:   "rm -rf {params.dir} &&"
             " {SOFT}/megahit/megahit -1 {params.left} -2 {params.right}"
             " -t {threads} -o {params.dir} >{log} 2>&1"

# Rewrite MegaHIT FASTA headers of the form
#   >k<K>_<id> flag=<f> multi=<cov> len=<len>
# into
#   >NODE_<id>_length_<len>_cov_<cov>
rule megahit_convert:
    input:   "assembly/megahit/{group}/final.contigs.fa"
    output:  "assembly/megahit/{group}.fasta"
    message: "Converting megahit contig names from {wildcards.group} into our format"
    # Every backslash is doubled so that the Python literal has no invalid
    # escape sequences; sed still receives a single backslash in each place.
    shell:   "sed -r 's/>(k[0-9]+)_([0-9]+) flag=. multi=([0-9]+|[0-9]+\\.[0-9]+) len=([0-9]+)/>NODE_\\2_length_\\4_cov_\\3/' {input} > {output}"

# Assemble with SPAdes
rule spades:
    input:
        left=left_reads,
        right=right_reads
    output:
        "assembly/spades/{group}.fasta"
    params:
        # Every read file gets its own -1 / -2 flag on the command line
        left=lambda wc: " ".join(expand("-1 {r}", r=left_reads(wc))),
        right=lambda wc: " ".join(expand("-2 {r}", r=right_reads(wc))),
        dir="assembly/spades/{group}",
        # --only-assembler is passed whenever is_fastq() is false for the group
        bh=lambda wc: "" if is_fastq(wc) else "--only-assembler"
    threads:
        THREADS
    log:
        "assembly/{group}.log"
    message:
        "Assembling {wildcards.group} with metaSPAdes"
    # Run metaSPAdes, then publish its scaffolds as the rule output
    shell:
        "{ASSEMBLER_DIR}/spades.py {params.bh} --meta -m 400 -t {threads} "
        "{params.left} {params.right} "
        "--save-gp -o {params.dir} >{log} 2>&1 && "
        "cp {params.dir}/scaffolds.fasta {output}"

# Produce the "full" (unsplit) contig set from the selected assembler's output.
rule copy_contigs:
    input:   "assembly/{}/{{group}}.fasta".format(ASSEMBLER)
    # The group wildcard is constrained to names like sample12 / group3.
    # The backslash is doubled so the literal has no invalid "\d" escape;
    # the resulting regex is unchanged.
    output:  "assembly/full/{group,(sample|group)\\d+}.fasta"
    message: "Prepare full contigs of {wildcards.group}"
    shell:   "{SCRIPTS}/cut_fasta.py -l {MIN_CONTIG_LENGTH} -m {input} > {output}"

# Produce the contig set cut into SPLIT_LENGTH pieces from the selected assembler's output.
rule split_contigs:
    input:   "assembly/{}/{{group}}.fasta".format(ASSEMBLER)
    # The group wildcard is constrained to names like sample12 / group3.
    # The backslash is doubled so the literal has no invalid "\d" escape;
    # the resulting regex is unchanged.
    output:  "assembly/splits/{group,(sample|group)\\d+}.fasta"
    message: "Cutting contigs from {wildcards.group} into {SPLIT_LENGTH} bp splits"
    shell:   "{SCRIPTS}/cut_fasta.py -c {SPLIT_LENGTH} -l {MIN_CONTIG_LENGTH} -m {input} > {output}"

#---- Generating profiles/depths -----------------------------------------------

# config["profile"]["var"] if present; otherwise true only for the metabat binner
PROF_VAR = config["profile"].get("var", BINNER == "metabat")
# Profile type tag handed to the profile scripts
PROF_TYPE = "mpl" if not PROF_VAR else "var"
# Contig set used downstream: the splits when a split length is configured, else the full contigs
FRAGS = "splits" if SPLIT_LENGTH else "full"

# Importing the specific profiler
include: "profilers/{}.snake".format(PROFILER)

#---- Binning ------------------------------------------------------------------

# Importing the specific binner
include: "binners/{}.snake".format(BINNER)

# Combine the contig profiles and the binning into per-bin profiles.
rule bin_profiles:
    input:
        "binning/profiles.tsv",
        "binning/binning.tsv"
    output:
        "binning/bins.prof"
    message:
        "Deriving bin profiles"
    shell:
        "{SCRIPTS}/bin_profiles.py {input} {PROF_TYPE} >{output}"

# Run choose_bins.py with BIN_LENGTH over the binning to get the filtered bin list.
rule choose_bins:
    input:
        "binning/binning.tsv"
    output:
        "binning/filtered_bins.tsv"
    message:
        "Filter small bins"
    shell:
        "{SCRIPTS}/choose_bins.py {BIN_LENGTH} {input} >{output}"

#---- Post-clustering pipeline -------------------------------------------------

# Split the global binning table into one annotation file per group.
# Input lines look like "<sample>-<contig>\t<bin>"; each output line is
# "<contig>\t<bin> <bin> ...".
rule annotate:
    input:   "binning/binning.tsv"
    output:  expand("binning/annotation/{sample}.ann", sample=GROUPS)
    params:  "binning/annotation/"
    message: "Preparing raw annotations"
    run:
        # Load the whole annotation: {sample: {contig: [bins]}}
        samples_annotation = dict()
        with open(input[0]) as input_file:
            for line in input_file:
                if not line.strip():
                    # Blank lines (e.g. a trailing one) carry no annotation
                    continue
                annotation_str = line.split("\t", 1)
                bin_id = annotation_str[1].strip()
                sample_contig = annotation_str[0].split('-', 1)
                if len(sample_contig) > 1:
                    sample = sample_contig[0]
                    contig = sample_contig[1]
                else: # Backward compatibility with old alternative pipeline runs
                    sample = "group1"
                    contig = sample_contig[0]
                annotation = samples_annotation.setdefault(sample, dict())
                annotation.setdefault(contig, []).append(bin_id)

        # Serialize it in the propagator format
        written = set()
        for sample, annotation in samples_annotation.items():
            out_path = os.path.join(params[0], sample + ".ann")
            with open(out_path, "w") as sample_out:
                for contig, bins in annotation.items():
                    print(contig, "\t", " ".join(bins), sep="", file=sample_out)
            written.add(os.path.normpath(out_path))

        # A group without any binned contig still has a declared output;
        # create it empty so that the job does not fail on a missing file.
        for out_path in output:
            if os.path.normpath(out_path) not in written:
                open(out_path, "w").close()

# Propagation stage
# Path, inside a SPAdes output directory, to the graph pack saved before repeat resolution
SAVES = "K{}/saves/01_before_repeat_resolution/graph_pack".format(ASSEMBLY_K)

# Run prop_binning for one group; the binned reads go to the shared "reads" directory.
rule prop_binning:
    input:
        contigs="assembly/spades/{group}.fasta",
        splits="assembly/{}/{{group}}.fasta".format(FRAGS),
        ann="binning/annotation/{group}.ann",
        left=left_reads,
        right=right_reads,
        bins="binning/filtered_bins.tsv"
    output:
        ann="propagation/annotation/{group}.ann",
        edges="propagation/edges/{group}.fasta"
    params:
        # Graph pack saved by the SPAdes run of this group
        saves=os.path.join("assembly/spades/{group}/", SAVES),
        # Space-separated members of the group
        samples=lambda wc: " ".join(GROUPS[wc.group])
    log:
        "binning/{group}.log"
    message:
        "Propagating annotation & binning reads for {wildcards.group}"
    shell:
        "mkdir -p reads && {BIN}/prop_binning "
        "-k {ASSEMBLY_K} -s {params.saves} -c {input.contigs} -b {input.bins} "
        "-n {params.samples} -l {input.left} -r {input.right} -t {MIN_CONTIG_LENGTH} "
        "-a {input.ann} -f {input.splits} -o reads -p {output.ann} -e {output.edges} "
        ">{log} 2>&1"

# Aggregation point: touches the flag file once every group's propagated annotation exists.
rule prop_all:
    input:
        expand("propagation/annotation/{group}.ann", group=GROUPS)
    output:
        touch("propagation.done")
    message:
        "Finished propagation of all annotations."

# Run choose_samples.py over the bin profiles once propagation is finished.
rule choose_samples:
    input:
        prof="binning/bins.prof",
        bins="binning/filtered_bins.tsv",
        flag="propagation.done"
    output:
        touch("reassembly/choose.done")
    log:
        "reassembly/choose_samples.log"
    message:
        "Choosing bins for reassembly and samples for them"
    # Leftovers of a previous run (*.info files, the excluded directory) are removed first
    shell:
        "rm -f reassembly/*.info && "
        "rm -rf reassembly/excluded && "
        "{SCRIPTS}/choose_samples.py -f {input.bins} -r reads -o reassembly {input.prof} "
        ">{log} 2>&1"
