#!/usr/bin/env python

from optparse import OptionParser
import matplotlib.pyplot as plt


# Command-line interface.  Everything below is evaluated at import time and
# leaves the parsed options in ``opts`` and the input file names in ``fnames``.
p = OptionParser(usage='%prog [OPTION] FILE...',
                 description='plot timings from gpaw parallel timer.  '
                 'The timer dumps a lot of files called "timings.<...>.txt".  '
                 'This programme plots the contents of those files.  '
                 'Typically one would run "%prog timings.*.txt" to plot '
                 'timings on all cores.  (Note: The plotting code is '
                 'rather hacky and ugly at the moment.)')
p.add_option('--threshold', type=float, default=0.01, metavar='FRACTION',
             help='suppress entries of less than FRACTION of total CPU time.  '
             'Such entries are not drawn.  Default: %default')
p.add_option('--ignore', metavar='TIMERS', default='',
             help='comma-separated list of timer names to be ignored.')
p.add_option('--noxcc', action='store_true',
             help='add "XC correction" to --ignore.  XC Correction '
             'is called once for every atom in each SCF step which clutters '
             'the graph a lot.')
p.add_option('--interval', metavar='TIME1:TIME2',
             help='plot only timings within TIME1 and TIME2 '
             'after start of calculation.')
# 'choice' makes optparse reject unknown units with a usage error instead of
# letting the dictionary lookup below fail with a KeyError traceback.
p.add_option('--unit', type='choice', choices=['s', 'm', 'h'], default='s',
             help='time unit.  s, m or h.  Default: %default.')
p.add_option('--align', default='SCF-cycle', metavar='TIMER',
             help='align timings of all processes to first call of '
             'TIMER[=%default].')
opts, fnames = p.parse_args()

# Number of raw time units per plotted time unit; Entry.normalize divides
# by this.
timeunit = {'s': 1.0, 'm': 60.0, 'h': 3600.0}[opts.unit]

# Names of timers that are never drawn.  Surrounding blanks and empty items
# (e.g. from the default '' or a trailing comma) are dropped.
ignored_timers = set(name.strip() for name in opts.ignore.split(',')
                     if name.strip())
if opts.noxcc:
    ignored_timers.add('XC Correction')
fnames.sort()


# Move all plots down by this amount so the plots align *around* the y ticks
# and not *from* the yticks.
plot_vertical_adjustment = 0.25


class Entry:
    """A single timed interval in the timer tree of one process.

    The constructor sets ``name``, the start time ``t1``, ``parent`` (None
    for a root), ``children`` (sub-intervals in creation order),
    ``childnumber`` and ``level`` (-1 for a root, parent level + 1
    otherwise).  ``t2`` and ``dt`` only exist after stop() has been called.
    """

    def __init__(self, name, t1, parent=None, childnumber=None):
        self.name = name
        self.t1 = t1
        self.parent = parent
        self.children = []
        self.childnumber = childnumber
        # A parentless entry is a root; its direct children are level 0.
        self.level = -1 if parent is None else parent.level + 1

    def __repr__(self):
        shown = (self.name, self.t1)
        return 'Entry(name=%s, t1=%s, ...)' % shown

    def stop(self, t2):
        """Record the end time ``t2`` and the resulting duration ``dt``."""
        self.t2 = t2
        self.dt = t2 - self.t1

    def subentry(self, name, time):
        """Create, attach and return a new child starting at ``time``."""
        child = Entry(name, time, parent=self,
                      childnumber=len(self.children))
        self.children.append(child)
        return child

    def iterate(self):
        """Yield all descendants depth-first, each parent before its children.

        The entry itself is not yielded.
        """
        # Stack of child iterators; the top one belongs to the node whose
        # children are currently being visited.
        pending = [iter(self.children)]
        while pending:
            try:
                node = next(pending[-1])
            except StopIteration:
                pending.pop()
                continue
            yield node
            pending.append(iter(node.children))

    def normalize(self, start):
        """Make ``start`` the time origin and divide by the time unit.

        Note: this reads the module-level ``timeunit``.
        """
        self.transform(start, 1.0 / timeunit)

    def transform(self, offset, scale):
        """Apply ``t -> scale * (t - offset)`` to this entry and its subtree."""
        def mapped(t):
            return scale * (t - offset)

        self.t1 = mapped(self.t1)
        self.t2 = mapped(self.t2)
        self.dt = self.t2 - self.t1
        for sub in self.children:
            sub.transform(offset, scale)

    def get_first_occurrence(self, name):
        """Return the first descendant (in iterate() order) called ``name``.

        Raises ValueError if no descendant has that name.
        """
        for node in self.iterate():
            if node.name == name:
                return node
        raise ValueError('Entry not found: %s' % name)


class EntryCollection:
    """The timer trees of all processes, shifted onto a common time axis.

    Attributes:
        entries: the root entries that were passed in (normalized in place).
        totals: dict mapping timer name to its accumulated duration,
            averaged over the number of root entries.
    """

    def __init__(self, entries, align=None):
        """Synchronize and normalize ``entries`` in place.

        Parameters:
            entries: list of root entries (level -1), one per process.
            align: name of the timer whose first occurrence is used as the
                common reference point.  Defaults to the --align option
                (``opts.align``).

        Every root is normalized so that the first occurrence of ``align``
        ends up at the same time in all processes, with the origin chosen
        such that no process starts before zero.  If any process lacks the
        ``align`` timer, a warning is printed and every process is instead
        aligned at its own start time.

        Raises ValueError (from max()) if ``entries`` is empty.
        """
        if align is None:
            align = opts.align

        for entry in entries:
            assert entry.level == -1

        abs_starttimes = [entry.t1 for entry in entries]
        try:
            sync_times = [entry.get_first_occurrence(align).t1
                          for entry in entries]
        except ValueError:
            print('Warning: No "%s" entry found.  Times may not '
                  'be synchronized so well' % align)
            # Without a common reference the best we can do is to let every
            # process start at zero.
            sync_times = abs_starttimes

        # Longest time any process spends before reaching the reference
        # timer; after normalization the reference sits at this time.
        max_diff = max([sync_time - starttime for sync_time, starttime
                        in zip(sync_times, abs_starttimes)])

        # Use the very same reference times that max_diff was computed from,
        # so that --align is honoured consistently.
        for entry, sync_time in zip(entries, sync_times):
            entry.normalize(sync_time - max_diff)

        self.entries = entries
        self.totals = self.get_totals()

    def get_totals(self):
        """Return {timer name: summed duration / number of root entries}."""
        nentries = len(self.entries)
        totals = {}
        for entry in self.entries:
            for child in entry.iterate():
                totals[child.name] = (totals.get(child.name, 0.0)
                                      + child.dt / nentries)
        return totals


def get_timings(fname):
    """Parse one timer log file into a tree of Entry objects.

    Each usable line is split as ``<tok0> <tok1> <time> <name> <tok> <action>``
    where ``<name>`` may contain spaces and ``<action>`` is ``started`` or
    ``stopped``.  A ``started`` line opens a new sub-entry of the innermost
    running timer; a ``stopped`` line closes that timer.

    Lines that cannot be parsed, or that do not fit the current nesting
    (unknown action, or a stop that does not match the running timer), are
    skipped; this guards against files whose writing was interrupted.
    Timers still running at the end of the file are stopped at the last
    time stamp that was read.

    Returns the root Entry, spanning from the start of its first child to
    the end of its last child.

    Raises ValueError if the file contains no timer entries at all.
    """
    root = Entry('root', 0.0)
    head = root  # innermost timer that is currently running
    with open(fname) as fd:
        for line in fd:
            try:
                part1, _, action = line.strip().rsplit(' ', 2)
                tokens = part1.split(' ', 3)
                t = float(tokens[2])
                name = tokens[3]
            except (ValueError, IndexError):
                # Truncated or garbled line.
                continue

            if action == 'started':
                head = head.subentry(name, t)
            elif (action == 'stopped' and head is not root
                  and name == head.name):
                head.stop(t)
                head = head.parent
            # Anything else is inconsistent with the timer stack: skip it.

    # If file is incomplete, cut remaining timers short.  ``t`` is bound
    # here because at least one 'started' line must have been parsed.
    while head is not root:
        head.stop(t)
        head = head.parent

    if not root.children:
        raise ValueError('No timer entries found in %s' % fname)
    root.t1 = root.children[0].t1
    root.stop(root.children[-1].t2)
    return root


# Read the input files.  A file whose name ends with 'metadata.txt' supplies
# one label per line for the y axis; every other file is parsed as the timer
# log of one process.  Since fnames is sorted, alltimings is in file-name
# order.
alltimings = []
metadata = None
for fname in fnames:
    if fname.endswith('metadata.txt'):
        if metadata is not None:
            p.error('more than one metadata file given')
        with open(fname) as fd:
            metadata = [line.strip() for line in fd]
    else:
        alltimings.append(get_timings(fname))

if not alltimings:
    p.error('no timings found')
if metadata is None:
    # No metadata file: label the processes 0, 1, 2, ...  This must be a
    # real list because len(metadata) is taken further down.
    metadata = [str(rank) for rank in range(len(alltimings))]

entries = EntryCollection(alltimings)

# Timer names in the order in which they were first given a style; the legend
# lists them in this order.
ordered_names = []

# First figure: the timeline, one row per rank.
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(1, 1, 1)
fig.subplots_adjust(left=0.08, right=0.95, bottom=0.07, top=0.99)

# Second figure: only carries the legend, so its axes get no ticks.
fig2 = plt.figure(figsize=(10, 8))
nameax = fig2.add_subplot(1, 1, 1)
nameax.set_xticks([])
nameax.set_yticks([])

styles = {}  # timer name -> (color, hatch)
bars = {}  # timer name -> most recent matplotlib bar drawn for it

thecolors = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow',
             'darkred', 'indigo', 'springgreen', 'purple']
thehatches = ['', '//', 'O', '*', 'o', r'\\', '.', '|']


def getstyle(i):
    """Return the (color, hatch) pair for the i'th distinct timer.

    Colors cycle fastest; the hatch advances each time the colors have been
    used up.  There are len(thecolors) * len(thehatches) distinct styles;
    for larger ``i`` the styles repeat from the beginning instead of
    raising IndexError.
    """
    ncolors = len(thecolors)
    hatch_index = (i // ncolors) % len(thehatches)
    return thecolors[i % ncolors], thehatches[hatch_index]

# Time window that is plotted.  The bounds are compared with the normalized
# entry times, i.e. they are in the time unit chosen with --unit.
if opts.interval:
    try:
        plotstarttime, plotendtime = [float(bound) for bound
                                      in opts.interval.split(':')]
    except ValueError:
        # Raised both for non-numeric bounds and for a wrong number of
        # ':'-separated fields.  p.error() prints the usage and exits.
        p.error('--interval must have the form TIME1:TIME2, got "%s"'
                % opts.interval)
else:
    plotstarttime = 0
    plotendtime = max([timing.t2 for timing in alltimings])

# --threshold is a fraction of the total time of a run.  EntryCollection's
# totals are averages per process, so compare them with the average duration
# of the root entries.  (Comparing with the bare fraction would make the
# cut-off an absolute time whose meaning changes with --unit.)
total_time = sum(rootnode.dt for rootnode in alltimings) / len(alltimings)
min_total = opts.threshold * total_time

# Each nesting level occupies 1/compression of the row of a rank.
# hardcoded to max 5.  Graphs will overlap if larger.
compression = 5.0
height = 1.0 / compression

nstyles_used = 0
for rank, rootnode in enumerate(alltimings):
    for child in rootnode.iterate():
        if child.name in ignored_timers:
            continue
        # Skip intervals that lie completely outside the plotted window.
        if child.t2 < plotstarttime or plotendtime < child.t1:
            continue
        # Clip intervals that stick out of the window.
        t1 = max(child.t1, plotstarttime)
        t2 = min(child.t2, plotendtime)

        if child.name not in styles:
            if entries.totals[child.name] < min_total:
                # Insignificant timer: not drawn at all.
                continue
            # First significant occurrence: assign the next free style.
            # (Styles may repeat if getstyle() runs out of distinct ones.)
            styles[child.name] = getstyle(nstyles_used)
            nstyles_used += 1
            ordered_names.append(child.name)

        color, hatch = styles[child.name]
        y0 = rank + (child.level / compression) - plot_vertical_adjustment

        bar = ax.bar([t1], [height], bottom=[y0],
                     width=[t2 - t1],
                     color=color,
                     edgecolor='black',
                     hatch=hatch,
                     label='__nolegend__')
        bars[child.name] = bar

# One y tick per metadata entry.  The labels are materialized first so that
# this also works when ``metadata`` is a one-shot iterator without len()
# (the result of map() on Python 3).
ticklabels = [txt.replace('=', '') for txt in metadata]
ax.set_ylabel('rank')
ax.set_yticks(list(range(len(ticklabels))))
ax.set_yticklabels(ticklabels)

ax.set_xlabel('time / %s' % opts.unit)
ax.axis(xmin=plotstarttime, xmax=plotendtime,
        ymin=-plot_vertical_adjustment - 0.1)

# Roughly 36 elements fit in the window.  For each 36, add another column
# Only two columns will fit.  People will have to maximize the thing if they
# have hundreds of different timers somehow.
ncolumns = 1 + len(ordered_names) // 36

# Draw one zero-sized bar per timer on the legend axes so that the legend
# gets a handle with the right color and hatch for every name.
for name in ordered_names:
    color, hatch = styles[name]
    nameax.bar([0], [0], [0], [0],
               color=color, hatch=hatch, label=name)

nameax.legend(ordered_names,
              handlelength=2.5,
              labelspacing=0.0,
              fontsize='large',
              ncol=ncolumns,
              mode='expand',
              frameon=True,
              loc='best')

plt.show()
