#!/usr/bin/env python
#
# Copyright (C) 2012  Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import sys
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot
import numpy
from glue.ligolw import ligolw
from glue.ligolw import array
from glue.ligolw import param
array.use_in(ligolw.LIGOLWContentHandler)
param.use_in(ligolw.LIGOLWContentHandler)
from glue.ligolw import utils
from pylal import series as lalseries
from gstlal import reference_psd

## @file
# A program to plot reference psds; see gstlal_plot_psd for more info

## @package gstlal_plot_psd
#
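# A program to plot one or more PSDs, such as those generated by
# gstlal_reference_psd, on a single log-log figure; each legend entry also
# reports the horizon distance computed from that PSD.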
#
# ### Usage:
#
#		gstlal_plot_psd OUTPUT-NAME PSD-FILE-1 [PSD-FILE-2 ...]
#
# e.g.,
#
#		gstlal_plot_psd test.png psd.xml.gz
#

if len(sys.argv) < 3:
	# sys.exit(<str>) writes the message to stderr and exits with status 1,
	# so Makefiles / DAG jobs notice the misuse (a bare sys.exit() exits 0).
	sys.exit("USAGE gstlal_plot_psd OUTPUT-NAME PSD-FILE-1 PSD-FILE-2 ...")

outname = sys.argv[1]
psd_files = sys.argv[2:]

# Only prefix legend entries with the file name when several files are
# plotted; with a single file the instrument name alone is unambiguous.
label_with_filename = len(psd_files) > 1

for fname in psd_files:
	xmldoc = utils.load_filename(fname, verbose = True, contenthandler = ligolw.LIGOLWContentHandler)
	psds = lalseries.read_psd_xmldoc(xmldoc)
	# Sort by instrument so the legend order and line colours are the same
	# from run to run (plain dict iteration order is arbitrary).
	for instrument, psd in sorted(psds.items()):
		# Frequency of each sample: the series starts at f0 with spacing deltaF.
		freqs = psd.f0 + psd.deltaF * numpy.arange(len(psd.data))
		# compute horizon up to 90% of nyquist to avoid roll off; the leading
		# arguments (1.4, 1.4, 8, 10) are presumably the two component masses,
		# the SNR threshold and the low-frequency cutoff -- TODO confirm against
		# reference_psd.horizon_distance
		horizon = reference_psd.horizon_distance(psd, 1.4, 1.4, 8, 10, len(psd.data) * psd.deltaF * 0.90)
		if label_with_filename:
			name = "%s %s" % (fname, instrument)
		else:
			name = instrument
		pyplot.loglog(freqs, psd.data, label = "%s: %.0f Mpc" % (name, horizon), alpha = 0.75)

# Axis decoration is the same for every curve, so set it once.
pyplot.xlabel('Frequency (Hz)')
pyplot.ylabel('Strain^2/Hz')
pyplot.legend()
pyplot.grid()
pyplot.savefig(outname)
