#!/usr/bin/python3

import argparse
import configparser
from enum import Enum
from string import Template
import subprocess
from pathlib import Path
import os
import pwd
import signal
import sys
from typing import (
    Any,
    Collection,
    Optional,
    Sequence,
    TypedDict,
)
from concurrent.futures import ThreadPoolExecutor

__VERSION__ = "0.20251015.0"


def make_table(data: Collection[Collection[Any]], headers: Sequence[str]) -> str:
    """
    Render rows as a plain-text table.

    A dependency-free stand-in for tabulate.tabulate. Every column is padded
    to its widest cell (header included) and followed by one space; a dashed
    rule separates the header from the rows. The result has no trailing
    newline after the last row.
    """
    colnums = len(headers)
    widths = [len(head) for head in headers]

    # Stringify each cell once, widening the columns as we go.
    rendered: list[list[str]] = []
    for row in data:
        assert len(row) == colnums
        texts = [str(col) for col in row]
        for idx, text in enumerate(texts):
            widths[idx] = max(widths[idx], len(text))
        rendered.append(texts)

    def pad(cells: Sequence[str]) -> str:
        # left-align each cell in its column, plus a one-space gutter
        return "".join(cell.ljust(widths[idx]) + " " for idx, cell in enumerate(cells))

    header_line = pad(headers)
    rule_line = "".join("-" * width + " " for width in widths)
    body = "\n".join(pad(texts) for texts in rendered)
    return f"{header_line}\n{rule_line}\n{body}"


# user email & name config
class InnerUserConfigType(TypedDict):
    """One tracked repository entry, as built from a section of the global config."""

    # Location of the repository's working tree.
    path: Path
    # Account git should run as for this repo; when unset, get_sudo_user
    # falls back to the owner of `path`.
    user: Optional[str]


# Mapping of repo name (config section name) to its entry.
UserConfigType = dict[str, InnerUserConfigType]

# Under sudo, use the invoking user's home directory rather than the
# target user's, so the per-user config files are found.
user_home = (
    Path(pwd.getpwuid(int(os.environ["SUDO_UID"])).pw_dir)
    if os.environ.get("SUDO_USER")
    else Path.home()
)
USER_CONFIG_FILE = user_home.joinpath(".git-keeper.conf")
USER_GITCONFIG_FILE = user_home.joinpath(".gitconfig")

# Name of the human operator: the sudo invoker if any, else the process owner.
current_username = os.environ.get("SUDO_USER") or pwd.getpwuid(os.getuid()).pw_name


class GitRuntimeError(RuntimeError):
    """
    Raised when a git invocation fails or times out.

    Attributes:
        message: human-readable description of what failed.
        returncode: exit status associated with the failure.
    """

    returncode: int
    message: str

    def __init__(self, message: str, returncode: int):
        # Chain to RuntimeError so that .args (and therefore repr() and
        # pickling) carries both values even when the exception is built
        # with keyword arguments.
        super().__init__(message, returncode)
        self.message = message
        self.returncode = returncode

    def __str__(self) -> str:
        return f"{self.message} (exit code {self.returncode})"


def _get_passthrough_envs() -> list[str]:
    """
    Build the NAME=value assignments handed to `env` when git runs via sudo.

    Forwards the SSH agent socket (when set) and points git at the invoking
    user's ~/.gitconfig (when that file exists).
    """
    passthrough: list[str] = []
    if auth_sock := os.environ.get("SSH_AUTH_SOCK"):
        passthrough.append(f"SSH_AUTH_SOCK={auth_sock}")
    if USER_GITCONFIG_FILE.exists():
        passthrough.append(f"GIT_CONFIG_GLOBAL={USER_GITCONFIG_FILE.absolute()}")
    return passthrough


def _run_git_command(
    repo: Path,
    command: list[str],
    sudo_user: Optional[str] = None,
    raw: bool = False,
    timeout: Optional[float] = None,
) -> Optional[str]:
    """
    Run `git <command>` inside `repo`, optionally as another user via sudo.

    Args:
        repo: directory used as the working directory of the git process.
        command: git arguments, without the leading "git".
        sudo_user: when set, git is run through `sudo -u <sudo_user> env ...`
            with the pass-through environment from _get_passthrough_envs().
        raw: when True, stdout is not captured (git writes straight to the
            terminal) and None is returned.
        timeout: seconds to wait before giving up, or None for no limit.

    Returns:
        The captured stdout decoded as UTF-8, or None when `raw` is True.

    Raises:
        GitRuntimeError: on timeout, when the process cannot be started, or
            when it exits with a status other than 0 / death by SIGPIPE.
    """
    if sudo_user:
        command = [
            "sudo",
            "-u",
            sudo_user,
            "env",
            *_get_passthrough_envs(),
            "git",
            *command,
        ]
    else:
        command = ["git", *command]
    try:
        p = subprocess.run(
            command,
            cwd=repo,
            stdout=None if raw else subprocess.PIPE,
            timeout=timeout,
            check=False,
        )
    except subprocess.TimeoutExpired as exc:
        raise GitRuntimeError(
            f"Timeout running {command} on {repo}", -signal.SIGKILL
        ) from exc
    except OSError as exc:
        # The process never started (repo directory missing, git/sudo not
        # installed, ...). Callers only handle GitRuntimeError, so convert it
        # instead of letting a traceback escape. There is no real exit status;
        # 127 is the shell's convention for a command that could not be run.
        raise GitRuntimeError(
            f"Error running {command} on {repo}: {exc}", 127
        ) from exc
    # A SIGPIPE death (e.g. the user quit the pager early) is not an error.
    if p.returncode not in (0, -signal.SIGPIPE):
        raise GitRuntimeError(f"Error running {command} on {repo}", p.returncode)
    if raw:
        return None
    # Undecodable bytes are replaced rather than raising UnicodeDecodeError,
    # which no caller is prepared to handle.
    return p.stdout.decode("utf-8", errors="replace")


def run_git_command(
    repo: Path,
    command: list[str],
    sudo_user: Optional[str] = None,
    timeout: Optional[float] = None,
) -> str:
    """Run a git command in `repo` and return its captured stdout as text."""
    captured = _run_git_command(
        repo, command, sudo_user, raw=False, timeout=timeout
    )
    # With raw=False the helper always returns a string.
    return captured  # type: ignore


def run_git_command_raw(
    repo: Path,
    command: list[str],
    sudo_user: Optional[str] = None,
    timeout: Optional[float] = None,
) -> None:
    """Run a git command in `repo` without capturing stdout; returns nothing."""
    _run_git_command(
        repo,
        command,
        sudo_user,
        raw=True,
        timeout=timeout,
    )


def get_sudo_user(repo: InnerUserConfigType, current_uid: int) -> Optional[str]:
    """
    Decide which account git must be run as for `repo`.

    The configured "user" wins; otherwise the owner of the repo path is used.
    Returns None when that account is the current one (no sudo needed).
    """
    configured = repo.get("user")
    if configured:
        owner_uid = pwd.getpwnam(configured).pw_uid
        return configured if owner_uid != current_uid else None

    owner_uid = Path(repo["path"]).stat().st_uid
    if owner_uid == current_uid:
        return None
    return pwd.getpwuid(owner_uid).pw_name


def compare_hashes(repo: Path, sudo_user: Optional[str]) -> tuple[str, str, str]:
    """
    Return (local, upstream, merge-base) commit hashes for the current branch.

    The values are the raw stdout of the respective git commands.
    """
    queries = (
        ["rev-parse", "HEAD"],
        ["rev-parse", "@{u}"],
        ["merge-base", "HEAD", "@{u}"],
    )
    # The generator is consumed in order, so git is invoked in the same
    # sequence as the queries are listed.
    local_hash, remote_hash, base_hash = (
        run_git_command(repo, query, sudo_user=sudo_user) for query in queries
    )
    return local_hash, remote_hash, base_hash


class LocalRemoteStatus(Enum):
    """Relationship between HEAD and its upstream, as computed by fetch_and_compare."""

    # HEAD and the upstream point at the same commit.
    UP_TO_DATE = 1
    # HEAD equals the merge-base: the upstream has commits HEAD lacks.
    NEEDS_PULL = 2
    # The upstream equals the merge-base: HEAD has commits the upstream lacks.
    NEEDS_PUSH = 3
    # Neither side equals the merge-base.
    DIVERGED = 4


class RemoteType(Enum):
    """Classification of a repository's remote URL, as returned by has_remote."""

    # The remote URL was recognised as an SSH remote.
    SSH = 1
    # The remote URL was recognised as an HTTP(S) remote.
    HTTP = 2
    # The repository has no remote configured.
    NONE = 3
    # A remote exists but its URL matched none of the known forms.
    UNKNOWN = 4


def fetch_and_compare(
    repo: Path, sudo_user: Optional[str], fetch: bool = True
) -> LocalRemoteStatus:
    """
    Classify how HEAD relates to its upstream branch.

    When `fetch` is true, `git fetch` is run first (30 second timeout) so
    the comparison is made against fresh remote refs.
    """
    if fetch:
        run_git_command(repo, ["fetch"], sudo_user=sudo_user, timeout=30)

    head, upstream, merge_base = compare_hashes(repo, sudo_user)
    # Checked in order; the first matching condition wins.
    verdicts = (
        (head == upstream, LocalRemoteStatus.UP_TO_DATE),
        (head == merge_base, LocalRemoteStatus.NEEDS_PULL),
        (upstream == merge_base, LocalRemoteStatus.NEEDS_PUSH),
    )
    for matched, verdict in verdicts:
        if matched:
            return verdict
    return LocalRemoteStatus.DIVERGED


def has_remote(repo: Path, sudo_user: Optional[str], push: bool = False) -> RemoteType:
    """
    Classify the remote of `repo` by the scheme of its URL.

    Args:
        repo: repository working directory.
        sudo_user: account to run git as, or None for the current user.
        push: when True, inspect the push URL instead of the fetch URL.

    Returns:
        RemoteType.NONE when no remote is configured, HTTP for http(s) URLs,
        SSH for `git@...` and `ssh://...` URLs, UNKNOWN otherwise.

    Raises:
        GitRuntimeError: when a git command fails.
    """
    # `git remote` prints one name per line; remote names cannot contain
    # whitespace, so split() yields the individual names.
    remotes = run_git_command(repo, ["remote"], sudo_user=sudo_user).split()
    if not remotes:
        return RemoteType.NONE
    # With several remotes, passing the whole listing to `get-url` would fail.
    # Prefer the conventional "origin", otherwise take the first one listed.
    remote = "origin" if "origin" in remotes else remotes[0]

    get_url = ["remote", "get-url"]
    if push:
        get_url.append("--push")
    get_url.append(remote)
    url = run_git_command(repo, get_url, sudo_user=sudo_user).strip()

    # ignore http(s) remotes
    if url.startswith("http"):
        return RemoteType.HTTP
    if url.startswith(("git@", "ssh://")):
        return RemoteType.SSH
    return RemoteType.UNKNOWN


def get_status(repo: Path, sudo_user: Optional[str]) -> tuple[int, int]:
    """
    Count working-tree changes reported by `git status --porcelain=v1`.

    Returns (untracked, staged): entries whose two-character status code is
    "??" are counted as untracked, every other entry as staged.
    """
    porcelain = run_git_command(
        repo, ["status", "--porcelain=v1"], sudo_user=sudo_user
    )
    entries = [entry for entry in porcelain.split("\n") if entry]
    untracked = sum(1 for entry in entries if entry[:2] == "??")
    staged = len(entries) - untracked
    return (untracked, staged)


def get_user_args(user_name: str, user_email: str) -> list[str]:
    """Return the `-c user.name=... -c user.email=...` options to pass to git."""
    identity = {"user.name": user_name, "user.email": user_email}
    options: list[str] = []
    for key, value in identity.items():
        options += ["-c", f"{key}={value}"]
    return options


def find_git_toplevel() -> Optional[Path]:
    """
    Locate the closest ancestor of the current directory containing `.git`.

    The search starts at the current directory, walks upwards and stops at
    the filesystem root or as soon as a directory lives on a different
    device than the starting point. Returns None when nothing is found or
    when the lookup itself fails.
    """
    try:
        start = Path.cwd()
        start_dev = start.stat().st_dev

        for candidate in (start, *start.parents):
            if candidate.stat().st_dev != start_dev:
                # crossed a mount point; do not look any further up
                break
            if (candidate / ".git").exists():
                return candidate
    except Exception as e:
        # don't crash if we can't find git toplevel, just notify the user
        print("Error finding git toplevel:", e)

    return None


def repo_name_check(repo_name: str, config: UserConfigType) -> None:
    """Exit with status 1 (after printing a message) if `repo_name` is not configured."""
    if config.get(repo_name):
        return
    print(f"Repo {repo_name} not found in config")
    sys.exit(1)


def repo_names_handle(repo_names: list[str], config: UserConfigType) -> None:
    """
    Validate the requested repo names against `config`.

    Exits (via repo_name_check) on the first unknown name. When no names were
    requested, meaning "all repos", the "." alias entry is dropped from
    `config` in place.
    """
    for name in repo_names:
        repo_name_check(name, config)
    if not repo_names and config.get("."):
        # Remove . if user wants to show/handle all repos
        config.pop(".")


def status(
    config: UserConfigType,
    repo_names: list[str],
    current_uid: int,
    no_fetch: bool,
    threads: int,
    # (dirty_template_path, clean_template_path)
    motd_paths: Optional[tuple[Path, Path]] = None,
) -> None:
    """
    Report working-tree and remote state of the selected repositories.

    Args:
        config: all known repositories, keyed by name.
        repo_names: names to report on; an empty list means every repo in
            `config`.
        current_uid: uid of the running process, used to decide whether git
            has to be run through sudo.
        no_fetch: skip `git fetch` before comparing with the upstream.
        threads: number of worker threads used to query the repos.
        motd_paths: when given, nothing is printed as a table; instead
            /run/git-keeper-motd is written from the "dirty" template (with
            `$table` listing the repos needing attention) or from the "clean"
            template when every repo is clean and up to date.

    A failure on a single repository is printed and shows up as "?" in the
    affected columns; it does not abort the report for the other repos.
    """
    if not repo_names:
        repo_names = list(config.keys())

    remote_messages = {
        LocalRemoteStatus.UP_TO_DATE: "✅ up to date",
        LocalRemoteStatus.NEEDS_PULL: "⬇️ needs pull",
        LocalRemoteStatus.NEEDS_PUSH: "⬆️ needs push",
    }

    def get_single_status(repo_name: str) -> tuple[Path, str, str, str]:
        repo = config[repo_name]
        repo_path = Path(repo["path"])
        status_str = "?"
        remote_str = "?"

        try:
            sudo_user = get_sudo_user(repo, current_uid)
            untracked, staged = get_status(repo_path, sudo_user)
            if untracked or staged:
                status_str = f"❗ {staged} staged, {untracked} untracked"
            else:
                status_str = "✅ clean"

            # get remote info
            remote = has_remote(repo_path, sudo_user)
            if remote == RemoteType.NONE:
                remote_str = "N/A"
            elif remote == RemoteType.UNKNOWN:
                remote_str = "Unknown"
            else:
                lrstatus = fetch_and_compare(repo_path, sudo_user, not no_fetch)
                remote_str = remote_messages.get(lrstatus, "🔀 diverged")
        except GitRuntimeError as e:
            print(e)
        except (KeyError, OSError) as e:
            # get_sudo_user fails this way for an unknown configured user or
            # a missing repo path; keep the "?" placeholders for this repo
            # instead of aborting the whole report.
            print(f"Error checking {repo_name}: {e}")

        return (repo_path, repo_name, status_str, remote_str)

    with ThreadPoolExecutor(max_workers=threads) as executor:
        table = list(executor.map(get_single_status, repo_names))

    headers = ["Repo", "Name", "Status", "Remote"]
    if not motd_paths:
        print(make_table(table, headers=headers))
        return

    # Keep only the repos that need attention (drop clean and up-to-date ones)
    table = [
        row
        for row in table
        if not ("✅" in row[2] and ("✅" in row[3] or row[3] == "N/A"))
    ]
    # Render the text before opening the output file, so that a missing or
    # malformed template cannot leave a truncated, empty MOTD behind.
    if table:
        with open(motd_paths[0], "r", encoding="utf-8") as tmplf:
            tmpl = Template(tmplf.read())
        content = tmpl.substitute(table=make_table(table, headers=headers))
    else:
        with open(motd_paths[1], "r", encoding="utf-8") as tmplf:
            tmpl = Template(tmplf.read())
        content = tmpl.substitute()
    with open("/run/git-keeper-motd", "w", encoding="utf-8") as f:
        f.write(content)


def commit(
    repo: InnerUserConfigType,
    current_uid: int,
    user_name: str,
    user_email: str,
    args: list[str],
) -> None:
    """
    Stage everything in `repo` and commit it as the given identity.

    `args` are appended to `git commit`. Exits with status 1 when git fails.
    """
    try:
        sudo_user = get_sudo_user(repo, current_uid)
        repo_path = Path(repo["path"])
        # stage all changes first
        run_git_command(repo_path, ["add", "."], sudo_user=sudo_user)
        # then commit interactively, so git may open an editor or print output
        commit_argv = [*get_user_args(user_name, user_email), "commit", *args]
        run_git_command_raw(repo_path, commit_argv, sudo_user=sudo_user)
    except GitRuntimeError as err:
        print(err)
        sys.exit(1)


def update(config: UserConfigType, current_uid: int, threads: int) -> None:
    """
    Synchronise every repository in `config` with its remote.

    Each repo is fetched and then pulled (including submodules), pushed, or
    reported as diverged, depending on how HEAD relates to its upstream.
    Repos without any remote are skipped, and repos whose push URL is HTTP
    are never pushed. Work is spread over `threads` worker threads.

    Failures are reported per repository and never stop the other repos from
    being processed.
    """

    def update_single(repo_name: str) -> None:
        try:
            repo = config[repo_name]
            sudo_user = get_sudo_user(repo, current_uid)
            repo_path = Path(repo["path"])
            push_remote = has_remote(repo_path, sudo_user, push=True)
            if push_remote == RemoteType.NONE:
                # Nothing to sync with; comparing against a non-existent
                # upstream would only produce a spurious error.
                return
            lrstatus = fetch_and_compare(repo_path, sudo_user)
            if lrstatus == LocalRemoteStatus.UP_TO_DATE:
                return
            elif lrstatus == LocalRemoteStatus.NEEDS_PULL:
                run_git_command(repo_path, ["pull"], sudo_user=sudo_user)
                # update submodules, if any
                run_git_command(
                    repo_path,
                    ["submodule", "update", "--init", "--recursive"],
                    sudo_user=sudo_user,
                )
                print(f"Pulled {repo_path}")
            elif lrstatus == LocalRemoteStatus.NEEDS_PUSH:
                if push_remote == RemoteType.HTTP:
                    print(f"Repo {repo_path} has HTTP remote, skipping push")
                    return
                run_git_command(repo_path, ["push"], sudo_user=sudo_user)
                print(f"Pushed {repo_path}")
            elif lrstatus == LocalRemoteStatus.DIVERGED:
                print(f"Repo {repo_path} has diverged (requires manual intervention)")
        except GitRuntimeError as e:
            print(e)
        except (KeyError, OSError) as e:
            # get_sudo_user fails this way for an unknown configured user or
            # a missing repo path. The results of executor.map are never
            # consumed, so without this handler the error would vanish
            # silently.
            print(f"Error updating {repo_name}: {e}")

    with ThreadPoolExecutor(max_workers=threads) as executor:
        executor.map(update_single, config.keys())


def vcs(
    repo: InnerUserConfigType,
    current_uid: int,
    user_name: str,
    user_email: str,
    args: list[str],
) -> None:
    """
    Run an arbitrary git command in `repo` under the given identity.

    `args` are the git arguments (subcommand first). Output is not captured.
    Exits with status 1 when git fails.
    """
    try:
        sudo_user = get_sudo_user(repo, current_uid)
        repo_path = Path(repo["path"])
        git_argv = get_user_args(user_name, user_email) + args
        run_git_command_raw(repo_path, git_argv, sudo_user=sudo_user)
    except GitRuntimeError as err:
        print(err)
        sys.exit(1)


def diff(repo: InnerUserConfigType, current_uid: int) -> None:
    """
    Show `git status`, wait for Enter, then show `git diff HEAD` for `repo`.

    Ctrl-C ends the function quietly; a git failure exits with status 1.
    """

    def show(*git_args: str) -> None:
        run_git_command_raw(repo_path, list(git_args), sudo_user=sudo_user)

    try:
        sudo_user = get_sudo_user(repo, current_uid)
        repo_path = Path(repo["path"])
        # 1. Show git status output
        show("status")
        input("Press enter to show `git diff HEAD`")
        # 2. Show git diff HEAD output
        show("diff", "HEAD")
    except KeyboardInterrupt:
        pass
    except GitRuntimeError as err:
        print(err)
        sys.exit(1)


def handle_user_config() -> tuple[str, str]:
    """
    Determine the identity used for commits made through git-keeper.

    The email is read from the [user] section of ~/.git-keeper.conf, falling
    back to ~/.gitconfig. When neither provides one, the user is prompted and
    the answer is saved to ~/.git-keeper.conf. The name comes from the same
    file, defaulting to the current username.

    Returns:
        (user_email, user_name)
    """

    def check_config(config_path: Path) -> Optional[configparser.ConfigParser]:
        # ~/.gitconfig is not a strict INI file: it may repeat keys or
        # sections, contain valueless keys and literal "%" characters. Use a
        # lenient parser so such files are still readable.
        parser = configparser.ConfigParser(
            strict=False, allow_no_value=True, interpolation=None
        )
        try:
            if not parser.read(config_path):
                return None
        except (configparser.Error, UnicodeDecodeError):
            # Unparsable file: treat it like a file without a usable email.
            return None
        # The config should have ["user"]["email"] available
        if parser.has_section("user") and parser["user"].get("email"):
            return parser
        return None

    # ~/.git-keeper.conf first, and then ~/.gitconfig
    parser = check_config(USER_CONFIG_FILE)
    if not parser:
        parser = check_config(USER_GITCONFIG_FILE)
    if parser:
        user_email = parser["user"]["email"]
    else:
        # ask for user email, save it to USER_CONFIG_FILE
        user_email = input("Enter your email: ")
        # interpolation is disabled so that a "%" in the input cannot make
        # the assignment below fail
        parser = configparser.ConfigParser(interpolation=None)
        parser["user"] = {"email": user_email}
        with open(USER_CONFIG_FILE, "w", encoding="utf-8") as f:
            parser.write(f)
        print(
            f"You can change your email (and name, if necessary) in {USER_CONFIG_FILE}"
        )
    # A missing, empty or valueless "name" falls back to the current username.
    user_name = parser["user"].get("name") or current_username
    return user_email, user_name


def main(args: argparse.Namespace) -> None:
    """
    Load the global config and dispatch to the requested sub-command.

    Args:
        args: parsed command line. `command`, `config` and `parallel` are
            always read; `repos`, `repo`, `args` and `no_fetch` are read by
            the sub-commands that define them.

    Every section of the global config except ".config" describes one
    repository (`path`, optional `user`). When the current directory is
    inside one of the configured repos, that repo is additionally reachable
    under the name ".".

    Exits with status 1 on configuration errors.
    """
    current_uid = os.getuid()
    if args.command != "motd":
        user_email, user_name = handle_user_config()
    else:
        # Make type checker happy
        user_email, user_name = "", ""

    if not args.config.exists():
        print(f"Global config file {args.config} not found")
        sys.exit(1)

    parser = configparser.ConfigParser()
    parser.read(args.config)
    cwd_toplevel = find_git_toplevel()
    # check if config is valid
    config: UserConfigType = {}
    for repo_name in parser.sections():
        if repo_name == ".config":
            # global settings, not a repo
            continue
        if repo_name == ".":
            print("A repo name of . is unsupported.")
            continue
        section = dict(parser.items(repo_name))
        if not section.get("path"):
            # Fail with a readable message instead of a bare KeyError.
            print(f"Repo {repo_name} has no path set in {args.config}")
            sys.exit(1)
        config[repo_name] = {
            "path": Path(section["path"]),
            "user": section.get("user") or None,
        }
        # sometimes the path in config, or cwd toplevel get, could be a symlink...
        # we need to resolve to make sure this case is covered.
        if (
            cwd_toplevel
            and config[repo_name]["path"].resolve() == cwd_toplevel.resolve()
        ):
            config["."] = {
                "path": cwd_toplevel,
                "user": config[repo_name].get("user"),
            }

    threads = args.parallel

    repo_names: list[str]
    if args.command == "status":
        repo_names = args.repos
        repo_names_handle(repo_names, config)

        no_fetch = args.no_fetch
        status(
            config,
            repo_names,
            current_uid,
            no_fetch,
            threads,
        )
    elif args.command == "motd":
        # parser.get() with a fallback also covers a config file that has no
        # [.config] section at all, where parser[".config"] would raise.
        tmpl_dirty_path = Path(
            parser.get(
                ".config",
                "motd-template-dirty",
                fallback="/etc/git-keeper.d/motd-dirty.tmpl",
            )
        )
        tmpl_clean_path = Path(
            parser.get(
                ".config",
                "motd-template-clean",
                fallback="/etc/git-keeper.d/motd-clean.tmpl",
            )
        )
        if not tmpl_dirty_path.exists() or not tmpl_clean_path.exists():
            print(
                f"MOTD template files ({tmpl_dirty_path} or {tmpl_clean_path}) not found, please check your config"
            )
            sys.exit(1)
        all_repos: list[str] = []
        repo_names_handle(all_repos, config)
        status(
            config,
            all_repos,
            current_uid,
            no_fetch=True,
            threads=threads,
            motd_paths=(tmpl_dirty_path, tmpl_clean_path),
        )
    elif args.command == "commit":
        repo_name = args.repo
        repo_name_check(repo_name, config)

        repo = config[repo_name]
        commit(repo, current_uid, user_name, user_email, args.args)
    elif args.command == "update":
        repo_names = args.repos
        repo_names_handle(repo_names, config)

        if repo_names:
            # restrict the update to the repos named on the command line
            config = {name: config[name] for name in repo_names}
        update(config, current_uid, threads)
    elif args.command == "vcs":
        repo_name = args.repo
        repo_name_check(repo_name, config)

        repo = config[repo_name]
        # different from other commands, vcs requires changing path when user is requesting "."
        # so that some git commands could work correctly like "git-keeper vcs . blame ./somefile"
        # if you're inside somedir/ and don't change cwd here, git would not work as expected.
        if repo_name == ".":
            repo["path"] = Path.cwd()
        vcs(repo, current_uid, user_name, user_email, args.args)
    elif args.command == "ls":
        print("\n".join(config.keys()))
    elif args.command == "diff":
        repo_name = args.repo
        repo_name_check(repo_name, config)

        repo = config[repo_name]
        diff(repo, current_uid)


# Command-line entry point: build the argument parser, apply the `gitkp`
# alias, default to `status` when no sub-command is given, then run main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        "git-keeper", description="Tracks specified git repos for sysadmins."
    )
    parser.add_argument(
        "--config",
        "-c",
        help="Path to config file",
        default="/etc/git-keeper.conf",
        type=Path,
    )
    parser.add_argument(
        "--parallel",
        "-p",
        help="Threads to use for status and update commands",
        default=8,
        type=int,
    )
    parser.add_argument(
        "--version", "-v", action="version", version=f"%(prog)s {__VERSION__}"
    )
    subparsers = parser.add_subparsers(dest="command")

    parser_status = subparsers.add_parser("status", help="Show status of repo(s)")
    parser_status.add_argument("repos", nargs="*", help="Repo(s) to show status of")
    parser_status.add_argument(
        "--no-fetch", action="store_true", help="Don't fetch remote"
    )

    parser_commit = subparsers.add_parser(
        "commit", help="Add all and commit changes in repo(s) on behalf of current user"
    )
    parser_commit.add_argument("repo", help="The repo to commit changes in")
    parser_commit.add_argument(
        "args", nargs=argparse.REMAINDER, help="Arguments appended to git commit"
    )

    parser_update = subparsers.add_parser(
        "update", help="Push/pull repo(s) with remote"
    )
    parser_update.add_argument("repos", nargs="*", help="Repo(s) to push/pull")

    parser_vcs = subparsers.add_parser("vcs", help="Run a git command on repo(s)")
    parser_vcs.add_argument("repo", help="The repo to run git command on")
    parser_vcs.add_argument("args", nargs=argparse.REMAINDER, help="Arguments to git")

    parser_ls = subparsers.add_parser("ls", help="Just list all repos")
    parser_diff = subparsers.add_parser("diff", help="Show changes in given repo")
    parser_diff.add_argument("repo", help="The repo to show diff of")

    parser_motd = subparsers.add_parser("motd", help="Generate MOTD output")

    parser_help = subparsers.add_parser("help", help="Show help")

    argv = sys.argv[1:]
    if Path(sys.argv[0]).name == "gitkp":
        # A special alias for git-keeper, to save typing:
        # `gitkp <git args>` is shorthand for `git-keeper vcs . <git args>`.
        # sys.argv[0] is normally a full path, so only its basename is
        # compared, and the rewrite has to happen before parsing so that the
        # git arguments end up in args.args instead of being rejected as an
        # unknown sub-command.
        argv = ["vcs", ".", *argv]
    args = parser.parse_args(argv)

    if args.command is None:
        args.command = "status"
        args.repos = []
        args.no_fetch = False
    elif args.command == "help":
        parser.print_help()
        sys.exit(0)

    main(args)
