#!/usr/bin/env python3

import os
import re
import sys
import json
import shutil
import hashlib
import tempfile
import argparse
import subprocess

# Description shown by `--help` (passed to argparse as the parser description).
HELP = '''
Prepare a packaged environment, caching/reusing previously built packages.
The path to the environment package will be written to stdout.
If the --approx option is specified, the returned image may contain additional (unrequested) packages.
This mode reuses existing environments that contain a superset of the spec.
If the --merge option is given, spec may be merged into an existing environment.
The argument to --merge gives the maximum Jaccard distance at which to merge.
'''

# File extension of the packaged environment tarballs stored in the cache.
TAR_EXT = '.tar.gz'
# Cache directory: lives under the system temp dir and is suffixed with the
# effective uid of the calling user.
CACHEDIR = os.path.join(tempfile.gettempdir(), 'python_package-{}'.format(os.geteuid()))
# In-memory indices of the on-disk cache, (re)populated by load_cache().
# SPECS:    env name (file name without extension) -> spec loaded from its .json
SPECS = {}
# PKGSETS:  frozenset of tagged dependency strings -> env name
PKGSETS = {}
# PKGSIZE:  env name -> size in bytes of the env's tarball
PKGSIZE = {}
# PKGMTIME: env name -> modification time of the env's .json file
PKGMTIME = {}

def jaccard(a, b):
    """Return the Jaccard distance between two sets.

    The distance is ``1 - |a & b| / |a | b|``: 0.0 for identical sets and
    1.0 for disjoint ones.  Two empty sets are considered identical and
    yield 0.0 (the plain formula would divide by zero).
    """
    union = a | b
    if not union:
        # Both sets are empty, so there is nothing that distinguishes them.
        return 0.0
    return 1 - float(len(a & b))/len(union)

def deps_to_set(conda_deps, pip_deps):
    """Combine conda and pip dependencies into a single frozenset.

    Each conda entry is tagged with a leading 'C' and each pip entry with a
    leading 'P' so both kinds can live in the same set.
    """
    tagged = set()
    tagged.update('C' + dep for dep in conda_deps)
    tagged.update('P' + dep for dep in pip_deps)
    return frozenset(tagged)

def set_to_deps(s):
    """Split a set of tagged dependencies back into (conda, pip) lists.

    Entries starting with 'C' go to the conda list and entries starting with
    'P' to the pip list, each with the tag removed; entries with any other
    first character are ignored.
    """
    tagged = list(s)
    conda = [item[1:] for item in tagged if item[0] == 'C']
    pip = [item[1:] for item in tagged if item[0] == 'P']
    return conda, pip

def extract_deps(spec):
    """Split a spec's dependency list into conda and pip dependencies.

    ``spec['dependencies']`` is a list whose plain entries are conda
    dependencies and whose dict entries must have the form
    ``{'pip': [...]}``.  The pip section may appear at any position in the
    list; if several are present their contents are concatenated.

    Returns a ``(conda_deps, pip_deps)`` tuple of new lists; ``spec`` itself
    is not modified.

    Raises:
        KeyError: if ``spec`` has no 'dependencies' entry.
        ValueError: if a dict entry is anything other than ``{'pip': [...]}``.
    """
    conda_deps = []
    pip_deps = []
    for dep in spec['dependencies']:
        if not isinstance(dep, dict):
            conda_deps.append(dep)
            continue
        # Only a single-key {'pip': [...]} section is understood.
        if list(dep) != ['pip']:
            raise ValueError('unsupported dependency section: {!r}'.format(dep))
        pip_deps.extend(dep['pip'])
    return conda_deps, pip_deps

def insert_deps(spec, conda_deps, pip_deps):
    """Store the given dependencies in ``spec['dependencies']`` in sorted order.

    Both argument lists are sorted in place.  The stored list is a copy of the
    conda dependencies, followed by a ``{'pip': [...]}`` entry when there is
    at least one pip dependency.
    """
    for deps in (conda_deps, pip_deps):
        deps.sort()
    combined = list(conda_deps)
    if pip_deps:
        combined.append({'pip': list(pip_deps)})
    spec['dependencies'] = combined

def hash_spec(spec):
    """Return the SHA-256 hex digest identifying ``spec``.

    The spec's dependency lists are first rewritten in sorted order (this
    modifies ``spec`` in place), and the spec is serialized as JSON with
    sorted keys before hashing.
    """
    conda, pip = extract_deps(spec)
    insert_deps(spec, conda, pip)  # puts the dependency lists in sorted order
    canonical = json.dumps(spec, sort_keys=True)
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()

def _dep_name(dep):
    """Return the package name of a tagged dependency string.

    The leading 'C'/'P' tag is dropped and the name ends at the first version
    specifier character or whitespace; a dependency without any version
    constraint is its own name.
    """
    return re.split(r'[=<>!~\s]', dep[1:], maxsplit=1)[0]

def compatible(a, b):
    """Return True if two tagged dependency sets do not conflict.

    The sets conflict when a package name occurs in both but with different
    full entries (different version constraint or different installer tag).
    """
    #TODO compare channel configs?
    names_a = {_dep_name(x): x for x in a}
    names_b = {_dep_name(x): x for x in b}
    # Only packages present on both sides can disagree.
    return all(names_a[name] == names_b[name]
               for name in names_a.keys() & names_b.keys())

def load_cache():
    """Rebuild the in-memory cache indices from the contents of CACHEDIR.

    The cache directory is created if missing.  Every regular file ending in
    '.json' is loaded as a spec and recorded in SPECS, PKGSETS, PKGSIZE (size
    of the matching tarball) and PKGMTIME (mtime of the .json file).  A
    summary line is written to stderr.
    """
    for index in (SPECS, PKGSETS, PKGSIZE, PKGMTIME):
        index.clear()
    os.makedirs(CACHEDIR, exist_ok=True)

    for entry in os.scandir(CACHEDIR):
        # Directories and anything that is not a regular file are skipped.
        if entry.is_dir() or not entry.is_file():
            continue
        stem, suffix = os.path.splitext(entry.name)
        if suffix != '.json':
            continue
        with open(os.path.join(CACHEDIR, entry.name)) as handle:
            loaded = json.load(handle)
        SPECS[stem] = loaded
        conda, pip = extract_deps(loaded)
        PKGSETS[deps_to_set(conda, pip)] = stem
        tarball = os.path.join(CACHEDIR, stem) + TAR_EXT
        PKGSIZE[stem] = os.stat(tarball).st_size
        PKGMTIME[stem] = entry.stat().st_mtime

    total_bytes = sum(PKGSIZE.values())
    summary = 'package cache contains {} env(s) totalling {} byte(s)\n'
    sys.stderr.write(summary.format(len(SPECS), total_bytes))

def find_exact(spec):
    """Return the path of the cached tarball built from exactly this spec.

    The lookup is by spec hash.  On a hit the env's .json file is touched
    (its timestamps are set to now) and the tarball path is returned;
    otherwise None is returned.
    """
    digest = hash_spec(spec)
    tarball = os.path.join(CACHEDIR, digest + TAR_EXT)
    if not os.access(tarball, os.R_OK):
        return None
    sys.stderr.write('found exact match\n')
    os.utime(os.path.join(CACHEDIR, digest + '.json'))
    return tarball

def find_matching(spec):
    """Return the path of a cached env whose packages are a superset of spec's.

    Among all cached envs containing every requested dependency, the one with
    the smallest Jaccard distance to the request is chosen.  Its .json file is
    touched and the tarball path returned; None is returned when no cached env
    qualifies.
    """
    wanted = deps_to_set(*extract_deps(spec))
    supersets = sorted(
            (have for have in PKGSETS if wanted.issubset(have)),
            key=lambda have: jaccard(wanted, have))
    sys.stderr.write('found {} env(s) that satisfy requirements\n'.format(len(supersets)))
    if not supersets:
        return None
    stem = PKGSETS[supersets[0]]
    os.utime(os.path.join(CACHEDIR, stem + '.json'))
    return os.path.join(CACHEDIR, stem + TAR_EXT)

def insert(spec):
    """Build a new environment for ``spec`` and add it to the cache.

    The spec is written (as JSON) to a temporary '.yml' file and the external
    ``python_package_create`` command is run to produce a tarball, with its
    stdout redirected to our stderr.  Both files are created under temporary
    names inside CACHEDIR and only renamed to their final, hash-based names
    once the build has succeeded.

    Returns the path of the new tarball.

    Raises:
        subprocess.CalledProcessError: if the build command fails.  Temporary
            files are removed in that case.
    """
    out = os.path.join(CACHEDIR, hash_spec(spec))

    # Paths of temp files that still need cleaning up; reset to None once the
    # file has been moved to its final name.
    meta_path = None
    tarball_path = None
    try:
        meta_fd, meta_path = tempfile.mkstemp(dir=CACHEDIR, suffix='.yml')
        with os.fdopen(meta_fd, mode='w') as meta:
            json.dump(spec, meta, indent=2)

        tarball_fd, tarball_path = tempfile.mkstemp(dir=CACHEDIR, suffix=TAR_EXT)
        sys.stderr.write('creating new env at {}\n'.format(tarball_path))
        # Only the reserved name is needed; the build command writes the file.
        os.close(tarball_fd)
        subprocess.run(['python_package_create', meta_path, tarball_path],
                check=True, stdout=sys.stderr)

        os.replace(tarball_path, out + TAR_EXT)
        tarball_path = None
        os.replace(meta_path, out + '.json')
        meta_path = None
    finally:
        # mkstemp never deletes its files, so remove whatever was not
        # published (e.g. after a failed build).
        for leftover in (meta_path, tarball_path):
            if leftover is None:
                continue
            try:
                os.unlink(leftover)
            except FileNotFoundError:
                pass

    return out + TAR_EXT

def merge(spec, alpha):
    """Merge ``spec`` into the most similar compatible cached environment.

    Candidates are cached envs that are compatible with the requested
    dependencies and whose Jaccard distance to them is below ``alpha``.  The
    closest one is rebuilt with the union of both dependency sets (keeping
    the other fields of the existing spec), and the env it replaces is
    removed from the cache once the new one has been built.

    Returns the path of the merged env's tarball, or None if there is no
    candidate.
    """
    s = deps_to_set(*extract_deps(spec))
    # Compute each distance once: (distance, dependency set) pairs.
    candidates = [(jaccard(s, x), x) for x in PKGSETS if compatible(x, s)]
    candidates = [(d, x) for d, x in candidates if d < alpha]
    if not candidates:
        sys.stderr.write('no similar envs found (max distance={})\n'.format(alpha))
        return None
    distance, best = min(candidates, key=lambda c: c[0])
    sys.stderr.write('merging with similar env (distance={})\n'.format(distance))

    old_name = PKGSETS[best]
    old_prefix = os.path.join(CACHEDIR, old_name)
    existing = SPECS[old_name]
    insert_deps(existing, *set_to_deps(s | best))
    # Build the merged env first so that a failed build does not cost us the
    # env that is already cached.
    merged = insert(existing)
    # If the union equals the old dependency set the new env has been written
    # over the old files, which must then not be deleted.
    if merged != old_prefix + TAR_EXT:
        os.unlink(old_prefix + '.json')
        os.unlink(old_prefix + TAR_EXT)
    return merged

def delete(env):
    """Remove the env named ``env`` from the cache directory.

    Deletes the env's .json file and tarball, drops it from PKGMTIME and
    returns its recorded size in bytes (from PKGSIZE).
    """
    sys.stderr.write('deleting {}\n'.format(env))
    base = os.path.join(CACHEDIR, env)
    for suffix in ('.json', TAR_EXT):
        os.unlink(base + suffix)
    del PKGMTIME[env]
    freed = PKGSIZE[env]
    return freed

def shrink_cache(capacity):
    """Delete cached envs until their total size is at most ``capacity`` bytes.

    Envs are removed oldest first, according to the modification times in
    PKGMTIME.  All progress messages go to stderr, since stdout carries the
    path of the resulting environment.
    """
    sys.stderr.write('cleaning cache\n')
    total = sum(PKGSIZE.values())
    if total > capacity:
        sys.stderr.write('clearing {} byte(s) from the cache\n'.format(total - capacity))
    deleted = 0
    # Stop when nothing is left to delete, even if the target cannot be
    # reached (e.g. a negative capacity).
    while total > capacity and PKGMTIME:
        oldest = min(PKGMTIME, key=PKGMTIME.get)
        total -= delete(oldest)
        deleted += 1
    if deleted > 0:
        sys.stderr.write('deleted {} env(s)\n'.format(deleted))

def get_env(spec, approx, alpha):
    """Return the path of an environment tarball satisfying ``spec``.

    When ``approx`` is set or ``alpha`` is positive, a cached superset env is
    preferred, then a merge with a similar env (maximum distance ``alpha``).
    Otherwise only an exact cache hit is accepted.  If none of these yields a
    result, a new env is built.
    """
    fuzzy = approx or alpha > 0
    if not fuzzy:
        found = find_exact(spec)
    else:
        found = find_matching(spec) or merge(spec, alpha)
    if found:
        return found
    return insert(spec)

def check_context(ctx):
    """Validate a context name.

    A context name must be a non-empty string of ASCII letters, digits,
    underscores and dashes.

    Raises:
        ValueError: if ``ctx`` contains any other character.
    """
    # fullmatch rather than match with '$': '$' would also accept a name that
    # ends in a newline.
    if not re.fullmatch(r'[-a-zA-Z0-9_]+', ctx):
        raise ValueError('contexts must consist of alphanumerics, underscores, and dashes')

def keep(path, ctx):
    """Hard-link the file at ``path`` into the directory of context ``ctx``.

    The context directory is CACHEDIR/<ctx> and is created if needed.  An
    already existing link of the same name is left untouched.  Returns the
    path of the link inside the context directory.

    Raises ValueError (from check_context) for an invalid context name.
    """
    check_context(ctx)
    ctx_dir = os.path.join(CACHEDIR, ctx)
    os.makedirs(ctx_dir, exist_ok=True)
    pinned = os.path.join(ctx_dir, os.path.basename(path))
    try:
        os.link(path, pinned)
    except FileExistsError:
        # A file of that name is already present in this context.
        pass
    return pinned

def free_ctx(ctx):
    """Remove the directory of context ``ctx`` and everything in it.

    A context directory that does not exist is not an error.  Raises
    ValueError (from check_context) for an invalid context name.
    """
    check_context(ctx)
    ctx_dir = os.path.join(CACHEDIR, ctx)
    try:
        shutil.rmtree(ctx_dir)
    except FileNotFoundError:
        return

if __name__ == '__main__':
    # Command-line entry point: resolve a spec to an environment tarball and
    # print the tarball's path on stdout.
    parser = argparse.ArgumentParser(description=HELP)
    parser.add_argument('spec', help='Spec file, or - to read from stdin.', nargs='?')
    parser.add_argument('--approx', action='store_true', help='Reuse environments containing a superset of spec')
    parser.add_argument('--merge', type=float, default=0, help='Allow merging of similar environments. Implies --approx')
    parser.add_argument('--capacity', type=int, help='Limit on cache size in bytes')
    parser.add_argument('--keep', metavar='CTX', help='Protect the returned environment until the context is freed')
    parser.add_argument('--free', metavar='CTX', help='Allow environments in the given context to be cleaned up')
    args = parser.parse_args()

    # --free is a standalone operation: drop the context and exit.
    if args.free:
        if args.spec:
            parser.error('--free must be used by itself')
        free_ctx(args.free)
        sys.exit(0)

    if not args.spec:
        parser.error('spec is required')

    if args.spec == '-':
        spec = json.load(sys.stdin)
    else:
        # Close the spec file as soon as it has been parsed.
        with open(args.spec) as spec_file:
            spec = json.load(spec_file)

    load_cache()
    # NOTE(review): because of the truthiness test, --capacity 0 is treated
    # the same as no limit.
    if args.capacity:
        shrink_cache(args.capacity)
        # Re-read the cache so the indices reflect the deletions.
        load_cache()

    out = get_env(spec, args.approx, args.merge)
    if args.keep:
        print(keep(out, args.keep))
    else:
        print(out)
