#!/usr/bin/env python
# Copyright (c) 2011-2012, Universite de Versailles St-Quentin-en-Yvelines
#
# This file is part of ASK.  ASK is free software: you can redistribute
# it and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation, version 2.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import numpy as np
from common.util import fatal
from common.configuration import Configuration


def should_stop(ts, window, threshold):
    """Report the relative change over the last iterations and stop if small.

    The rate of change is ``(worst - best) / worst``, where ``best`` and
    ``worst`` are the minimum and maximum of the second column of ``ts``
    restricted to its last ``window`` rows.  The rate is printed, and when
    it is strictly below ``threshold`` the project helper ``fatal`` is
    called with the message "Threshold reached: stopping" and code 254
    (presumably terminating the process with that exit status -- confirm
    against common.util.fatal).

    Args:
        ts: time series, either a plain 2-D array (one row per sample) or
            a 1-D structured array such as the one returned by
            ``np.genfromtxt(..., names=True)``.  The second column/field
            is the monitored value.
        window: number of trailing samples to consider.
        threshold: rate under which the stop is requested.

    Returns:
        None.
    """
    # the change rate should only be evaluated on the window
    recent = np.asarray(ts)[-window:]

    # A structured array (genfromtxt with names=True) is one-dimensional
    # and cannot be indexed with [:, 1]: select its second field by name.
    field_names = recent.dtype.names
    if field_names is not None:
        values = recent[field_names[1]]
    else:
        values = recent[:, 1]

    best = np.min(values)
    worst = np.max(values)

    if worst == best:
        # No variation at all in the window.  This also covers the
        # all-zero series, for which the plain formula is 0/0 (nan) and
        # would never compare below the threshold.
        rate = 0.0
    elif worst == 0:
        # best < 0 here: the numerator is positive and the relative change
        # is unbounded.  Same result as the division, without the warning.
        rate = float("inf")
    else:
        rate = float((worst - best) / worst)

    print("Rate of change in the last {0} iterations = {1}"
          .format(window, rate))

    if rate < threshold:
        fatal("Threshold reached: stopping", 254)


if __name__ == "__main__":
    import argparse

    # Command line: configuration file, current iteration number, and the
    # labelled/model file paths (the latter two are parsed but not used
    # in this script).
    cli = argparse.ArgumentParser(
        description="Stops the driver when the percentage change over an"
        " iteration window is under a given threshold")
    cli.add_argument('configuration')
    cli.add_argument('iteration', type=int)
    cli.add_argument('labelled_file')
    cli.add_argument('model_file')
    options = cli.parse_args()
    settings = Configuration(options.configuration)

    # Tunables of this control module, with their defaults.
    window = settings("modules.control.params.window", int, 5)
    threshold = settings("modules.control.params.threshold", float, 0.1)

    # Use the timeseries file given for the control module when there is
    # one; otherwise fall back to the reporter's timeseries file.
    timeseries_key = ("modules.control.params.timeseries"
                      if "timeseries" in settings("modules.control.params")
                      else "modules.reporter.params.timeseries")
    timeseries = settings(timeseries_key)

    # Nothing is evaluated before a full window of samples (and at least
    # two iterations) is available.
    first_checked_iteration = max(window, 2)
    if options.iteration >= first_checked_iteration:
        samples = np.genfromtxt(timeseries, names=True)
        should_stop(samples, window, threshold)
