#!/usr/bin/env python3
# SPDX-License-Identifier: BSD-3-Clause
#
# Config file generator for the DMXGUS lump / config file.
#
# This script generates a GUS config file with instrument mappings for
# 256, 512, 768 and 1024KB cards. This is done automatically from the config
# in config.py. The instruments to include are selected using statistics
# processed from a collection of well-known Doom WAD files.
#

import re
import sys

import config
import stats

# Comment block written at the very top of the generated config file.
# Leading whitespace (the newline right after the opening quotes) is
# stripped before it is written.
HEADER_TEXT = """
# Freedoom GUS config.
# Autogenerated by the gen-ultramid script.
# Please do not manually edit this file!
# The ordering of entries in this file is significant, to work around a
# bug in Doom's DMX sound library.
"""


def normalize_stats(stats):
    """Rescale percussion statistics so they are comparable to the rest.

    Percussion instruments tend to be used much more often than the main
    instruments, which would otherwise hand them an outsized priority.
    Each percussion figure is therefore multiplied by the ratio of the
    main average to the percussion average.

    Args:
      stats: List of usage statistics. The first 128 entries are treated
        as the main instruments, everything after them as percussion.
    Returns:
      A new list holding the main entries unchanged, followed by the
      rescaled percussion entries as floats.
    """
    melodic = stats[:128]
    percussion = stats[128:]

    melodic_mean = sum(melodic) / len(melodic)
    # The percussion average is taken over 35 fewer entries than the slice
    # contains. NOTE(review): presumably 35 percussion slots are never
    # used and would otherwise drag the average down -- confirm.
    percussion_mean = sum(percussion) / (len(percussion) - 35)

    rescaled = [
        (float(count) * melodic_mean) / percussion_mean
        for count in percussion
    ]
    return melodic + rescaled


def ranked_patches():
    """Return the GUS patch names, most worthwhile first.

    Every patch in config.GUS_INSTR_PATCHES is scored by dividing its
    normalized usage statistic by its file size, i.e. a benefit/cost
    ratio where the cost is the size the instrument file will occupy in
    RAM. Names are returned in descending score order; patches with an
    equal score are ordered by descending name.
    """
    usage = normalize_stats(stats.INSTRUMENT_STATS)

    scored = sorted(
        (float(usage[instr_id]) / config.PATCH_FILE_SIZES[name], name)
        for instr_id, name in config.GUS_INSTR_PATCHES.items()
    )
    # The sort is ascending; flip it so the best ratio comes first.
    scored.reverse()

    return [name for _score, name in scored]


def midi_instr_for_name(name):
    """Look up the MIDI instrument number that uses a GUS patch name.

    Args:
      name: Patch name to search for among the values of
        config.GUS_INSTR_PATCHES.
    Returns:
      The key of the first entry whose patch name equals `name`.
    Raises:
      KeyError: No entry uses the given patch name.
    """
    # A private sentinel distinguishes "nothing found" from any real key.
    not_found = object()
    instr_id = next(
        (
            candidate_id
            for candidate_id, patch_name in config.GUS_INSTR_PATCHES.items()
            if name == patch_name
        ),
        not_found,
    )
    if instr_id is not_found:
        raise KeyError("Unknown instrument: %s" % name)
    return instr_id


def patchset_size(mapping):
    """Add up the patch file sizes of the self-mapped instruments.

    Args:
      mapping: Dictionary mapping from instrument number to instrument
        number.
    Returns:
      The sum of config.PATCH_FILE_SIZES for every instrument that maps
      to itself; 0 when no instrument does.
    """
    return sum(
        config.PATCH_FILE_SIZES[config.GUS_INSTR_PATCHES[instr_id]]
        for instr_id, target_id in mapping.items()
        if instr_id == target_id
    )


def patchset_to_string(mapping):
    """Describe the self-mapped instruments of a mapping, one per line.

    Args:
      mapping: Dictionary mapping from instrument number to instrument
        number.
    Returns:
      A string with one "<patch name> (<size> bytes)" line for each
      instrument that maps to itself, in the mapping's iteration order.
      Lines are joined with newlines; there is no trailing newline.
    """
    lines = []
    for instr_id, target_id in mapping.items():
        # Only instruments that map to themselves are listed.
        if instr_id != target_id:
            continue
        patch_name = config.GUS_INSTR_PATCHES[instr_id]
        patch_bytes = config.PATCH_FILE_SIZES[patch_name]
        lines.append("%s (%i bytes)" % (patch_name, patch_bytes))
    return "\n".join(lines)


def mapping_for_size(size):
    """Select a set of patches for the given RAM size.

    Every instrument named in config.SIMILAR_GROUPS is first pointed at
    the leader (first member) of its group; this is the minimal
    configuration. Then, walking the patches in decreasing priority,
    each patch that still fits is switched to map to itself.

    Args:
      size: Size in bytes of the GUS RAM.
    Returns:
      Dictionary mapping from instrument number to instrument number.
      An instrument that maps to itself is included in the output.
    Raises:
      AssertionError: The minimal configuration (group leaders only)
        does not fit in the available space.
      KeyError: A group names an unknown patch, or a ranked patch does
        not appear in any similarity group.
    """
    # Leave some extra space. The ultramid.ini distributed with the
    # GUS drivers says this:
    #   The libraries are built in such a way as to leave 8K+32bytes
    #   after the patches are loaded for digital audio.
    # NOTE(review): the quote says 8K+32 bytes but the amount reserved
    # here is 32K+8 bytes. It is left as-is because changing it would
    # alter the generated mappings -- confirm which value is intended.
    size -= 32 * 1024 + 8

    # Get a list of patches sorted by decreasing priority.
    patches = ranked_patches()

    # Start by processing the similarity groups and pointing all
    # instruments in a group to their group leader.
    result = {}

    for group in config.SIMILAR_GROUPS:
        leader = group[0]
        leader_index = midi_instr_for_name(leader)
        for patch in group:
            patch_index = midi_instr_for_name(patch)
            result[patch_index] = leader_index

    # We now have a mapping that should cover every instrument with
    # a fallback. Go through the patches in order of priority and add
    # patches that will fit.
    curr_size = patchset_size(result)

    # Raised explicitly instead of using an `assert` statement so that
    # the check is still performed when Python runs with -O.
    if curr_size >= size:
        raise AssertionError(
            "Minimal config for %s will not fit in RAM! (%i):\n%s"
            % (size, curr_size, patchset_to_string(result))
        )

    for patch in patches:
        patch_index = midi_instr_for_name(patch)
        patch_size = config.PATCH_FILE_SIZES[patch]

        # Promote the patch only if it is not already loaded as itself
        # and adding it keeps the total strictly below the limit.
        if (
            result[patch_index] != patch_index
            and curr_size + patch_size < size
        ):
            result[patch_index] = patch_index
            curr_size += patch_size

    return result


def instrument_patches(mappings):
    """Returns list of MIDI instruments in an appropriate order for output.

    The ordering in the output file is important, because of a bug in the
    DMX sound library; when patches are shared between instruments it is
    only possible to refer to instruments listed earlier in the file.

    An instrument is emitted once, in every mapping, it either maps to
    itself or maps to an instrument that has already been emitted.
    Passes over the instruments (in ascending instrument number) are
    repeated until all of them have been emitted.

    Args:
      mappings: List of mappings from instrument ID to leader instrument.
    Yields:
      A tuple containing each MIDI instrument number and patch file name
      to load.
    Raises:
      AssertionError: A complete pass emitted nothing although
        instruments remain, so no valid ordering can be produced.
      KeyError: An instrument is missing from one of the mappings.
    """
    done_instr_ids = set()
    # The instrument table is not modified here, so sort it only once.
    ordered_patches = sorted(config.GUS_INSTR_PATCHES.items())

    # Make multiple passes until we've done all the instruments.
    while len(done_instr_ids) < len(ordered_patches):
        made_progress = False
        for instr_id, name in ordered_patches:
            if instr_id in done_instr_ids:
                continue

            # Hold the instrument back while any mapping points it at a
            # different instrument that has not been emitted yet.
            must_wait = any(
                mapping[instr_id] != instr_id
                and mapping[instr_id] not in done_instr_ids
                for mapping in mappings
            )
            if must_wait:
                continue

            yield instr_id, name
            done_instr_ids.add(instr_id)
            made_progress = True

        # Raised explicitly instead of using an `assert` statement:
        # with -O an assert is removed and this loop would never end.
        if not made_progress:
            raise AssertionError("infinite loop while producing patches list")


# The path of the file to generate is the one and only argument.
if len(sys.argv) != 2:
    print("Usage: %s <filename>" % sys.argv[0])
    sys.exit(1)

# One mapping per supported card size (256KB, 512KB, 768KB, 1024KB), in
# the same order as the columns written below.
mappings = tuple(
    mapping_for_size(kilobytes * 1024)
    for kilobytes in (256, 512, 768, 1024)
)

with open(sys.argv[1], "w") as output:
    output.write(HEADER_TEXT.lstrip())

    # Each line holds the instrument number, the instrument it maps to
    # for each of the four card sizes, and finally the patch name.
    for instr_id, name in instrument_patches(mappings):
        per_size = tuple(mapping[instr_id] for mapping in mappings)
        line = "%i, %i, %i, %i, %i, %s" % ((instr_id,) + per_size + (name,))

        output.write(line + "\n")
