#!/usr/bin/env python

#
# Find the timed regions of the HPCC benchmarks and surround them with
# calls to the comm diagnostics routines.
#
# ASSUMPTIONS:
# - The name of the binary is <name>_commDiags, where name is the
#   corresponding HPCC benchmark
# - There is only 1 timed region in the program text
# - Region starts with the first call to getCurrentTime() in the program text
# - Region ends with the last call to getCurrentTime() in the program text
# - Occurrence of getCurrentTime() is the last line of the statement
# - SPECIAL CASE for stream-ep: To get consistent results, include the outer
#   coforall loop in the comm counts.
#

import sys, shutil

# Text that marks the start/end of a timed region in the benchmark source.
timerstr = 'getCurrentTime()'

# Suffix that the test binary name carries after the benchmark name.
BINARY_SUFFIX = '_commDiags'

# Directory holding the original HPCC benchmark sources (relative to cwd).
HPCC_SOURCE_DIR = '../../../release/examples/benchmarks/hpcc/'

#
# Can't have - in module name
#
MODULE_NAME_OVERRIDES = {
    'ra-atomics': 'raAtomics',
    'stream-ep': 'streamEP',
}

# Chapel statements emitted at the start / end of the measured region.
START_DIAGS = 'resetCommDiagnostics();\nstartCommDiagnostics();\n'
STOP_DIAGS = 'stopCommDiagnostics();\nwriteln(getCommDiagnostics());\n'


def benchmark_name(binary_name):
    """Return the benchmark name encoded in a ``<name>_commDiags`` binary name.

    Everything before the first occurrence of ``_commDiags`` is returned.
    If the suffix is absent the whole name is returned unchanged (slicing
    with the -1 that ``str.find`` reports would silently drop the last
    character instead).
    """
    return binary_name.partition(BINARY_SUFFIX)[0]


def module_name(testname):
    """Return the name of the generated module/file for *testname*.

    Benchmarks whose name contains a ``-`` are mapped to a fixed
    replacement; all other names are used as-is.
    """
    return MODULE_NAME_OVERRIDES.get(testname, testname)


def insert_comm_diags(lines, testname):
    """Return *lines* with comm diagnostics calls inserted.

    lines    -- the benchmark source, one string per line (newline included)
    testname -- the benchmark name; 'stream-ep' is handled specially

    Lines containing ``timerstr`` alternate between opening and closing a
    region: after the opening line the diagnostics are reset and started,
    and before the closing line they are stopped and printed.

    For 'stream-ep' nothing is inserted around the timer calls.  Instead
    the diagnostics are started before any line that begins with
    'coforall loc in Locales' and stopped before any line that begins
    with 'printResults', so the outer coforall is part of the counts.

    The result is a new list; *lines* is not modified.
    """
    is_stream_ep = testname == 'stream-ep'
    output = []
    in_region = False

    for line in lines:
        stripped = line.strip()

        if is_stream_ep:
            if stripped.startswith('printResults'):
                output.append(STOP_DIAGS)
            if stripped.startswith('coforall loc in Locales'):
                output.append(START_DIAGS)

        if timerstr not in stripped:
            output.append(line)
        elif not in_region:
            # Opening timer call: diagnostics start right after it.
            in_region = True
            output.append(line)
            if not is_stream_ep:
                output.append(START_DIAGS)
        else:
            # Closing timer call: diagnostics stop right before it.
            if not is_stream_ep:
                output.append(STOP_DIAGS)
            output.append(line)
            in_region = False

    return output


def main(argv):
    """Generate the instrumented benchmark named by ``argv[1]``.

    Reads ``<testname>.chpl`` from HPCC_SOURCE_DIR and writes the
    instrumented copy to ``<module name>.chpl`` in the current directory.
    Arguments after the first are ignored.
    """
    if len(argv) < 2:
        sys.exit('usage: ' + argv[0] + ' <name>' + BINARY_SUFFIX)

    testname = benchmark_name(argv[1])

    #
    # get the source
    #
    with open(HPCC_SOURCE_DIR + testname + '.chpl') as src:
        lines = src.readlines()

    #
    # Insert calls
    #
    with open(module_name(testname) + '.chpl', 'w') as out:
        out.write('//\n')
        out.write('// Auto-generated: ' + testname
                  + ' with communication diagnostics\n')
        out.write('//\n')
        out.write('\nuse CommDiagnostics;\n')
        out.writelines(insert_comm_diags(lines, testname))


if __name__ == '__main__':
    main(sys.argv)
