#!/opt/local/Library/Frameworks/Python.framework/Versions/2.7/Resources/Python.app/Contents/MacOS/Python
import optparse
import os
import platform
import sys
import codecs
import csv
from decimal import *
import numpy
import pprint
import numpy
import operator
from Crypto.Random import random
import sys
from numpy.random import RandomState
import numpy as np
from scipy.stats import norm
from scipy.stats import lognorm
import math
import logging

class EstablishmentOfCohort:
	"""Randomly drawn cohort of firms for one step of the simulation.

	Each firm is a dict with two keys:
	  'tranche' -- index of the tranche the firm was assigned to
	  'size'    -- the firm's simulated size (never negative)

	Uniform random numbers come from ``rState``, which must provide
	``getValue()`` and, when ``twister`` is true, ``getValueTwister()``.
	"""

	# Class-level fallbacks only; __init__ overwrites every one per instance.

	averageFirmSize = 18
	numberOfFirms = 3
	tranche = []
	lstDev = 1
	firms = []
	twister = False
	rState = False
	distNorm = False
	roundval = False

	def __init__(self, rState, averageFirmSize=18, lstDev=1, numberOfFirms=3, tranche=None, twister=False, roundval=False, distNorm=False):
		"""Draw ``numberOfFirms`` firms immediately.

		rState          -- random source (see class docstring)
		averageFirmSize -- normal mean / lognormal scale of the firm size
		lstDev          -- spread of the firm size; <= 0 gives every firm
		                   exactly averageFirmSize
		numberOfFirms   -- how many firms to draw
		tranche         -- thresholds; a firm lands in the first tranche whose
		                   value exceeds its uniform draw (presumably
		                   cumulative probabilities -- confirm with the input
		                   file). None means no tranches.
		twister         -- use rState.getValueTwister() instead of getValue()
		roundval        -- 'floor', 'ceil', 'tenths', True (nearest integer)
		                   or False (no rounding)
		distNorm        -- True for normally distributed sizes, otherwise
		                   lognormal
		"""
		self.averageFirmSize = averageFirmSize
		self.numberOfFirms = numberOfFirms
		self.lstDev = lstDev
		# None instead of a shared mutable [] default.
		self.tranche = tranche if tranche is not None else []
		self.twister = twister
		self.roundval = roundval
		self.rState = rState
		self.firms = []
		# Honour the caller's choice (this used to be hard-coded to False,
		# which made the normal distribution unreachable).
		self.distNorm = distNorm

		for _ in range(self.numberOfFirms):
			# Tranche is drawn before size so the random stream is consumed
			# in the same order as before.
			firmDict = {}
			firmDict['tranche'] = self.Findtranche()
			firmDict['size'] = self.GetFirmSize()
			self.firms.append(firmDict)

	def GetFirms(self):
		"""Return the list of firm dicts drawn by the constructor."""
		return self.firms

	def Findtranche(self):
		"""Return the index of the first tranche value above a uniform draw.

		If the draw is not below any value, the last index is returned. With
		no tranches at all, 0 is returned instead of raising.
		"""
		randomValue = self.FindURandom()
		for index, threshold in enumerate(self.tranche):
			if randomValue < threshold:
				return index
		# Fall back to the last tranche (0 when the list is empty).
		return max(len(self.tranche) - 1, 0)

	def GetFirmSize(self):
		"""Draw one firm size.

		With lstDev > 0 the size is the quantile (ppf) of a uniform draw under
		either a normal distribution (loc=averageFirmSize, scale=lstDev) or a
		lognormal one (shape lstDev/averageFirmSize, scale averageFirmSize).
		An infinite quantile is replaced by 0 and logged. With lstDev <= 0
		the size is averageFirmSize. The value is then rounded according to
		``roundval`` and clamped below at 0.
		"""
		if float(self.lstDev) > 0:
			uniform = float(self.FindURandom())
			if self.distNorm == True:
				firmsize = float(norm.ppf(uniform, scale=float(self.lstDev), loc=float(self.averageFirmSize)))
			else:
				firmsize = float(lognorm.ppf(uniform, float(self.lstDev)/float(self.averageFirmSize), scale=float(self.averageFirmSize)))
			# ppf returns +/-inf for draws of exactly 1 or 0.
			if math.isinf(firmsize):
				firmsize = 0
				logging.info("Infinity encountered")
		else:
			firmsize = self.averageFirmSize

		if self.roundval == 'floor':
			firmsize = np.floor(firmsize)
		elif self.roundval == 'ceil':
			firmsize = np.ceil(firmsize)
		elif self.roundval == 'tenths':
			firmsize = np.round(firmsize, 1)
		elif self.roundval == True:
			firmsize = np.round(firmsize)

		# Negative (or NaN) sizes are reported as 0.
		if firmsize > 0:
			return firmsize
		return 0

	def FindURandom(self):
		"""Return one uniform draw from rState (twister or default source)."""
		if self.twister == True:
			return self.rState.getValueTwister()
		return self.rState.getValue()



class CalculateGamma:
	"""Gamma, G and Herfindahl statistics for one simulated cohort of firms.

	The constructor draws a cohort with EstablishmentOfCohort and computes
	every statistic immediately; read them back with GetGamma(),
	GetGValue() and GetHerfindahl(). Here s_i is tranche i's share of total
	firm size and x_i is tranche i's probability mass (the successive
	differences of the tranche list). G is sum((s_i - x_i)^2), or, with
	cMS, the Maurel-Sedillot form sum(s_i^2) - sum(x_i^2).
	"""

	# Class-level fallbacks only; __init__ resets each of these per instance.

	cohort = []
	tranche = []
	industrySize = 0
	herfindahlValue = float(0)
	xSquaredSum = float(0)
	siList = []
	gValue = float(0)
	gamma = float(0)

	def __init__(self, rState, averageFirmSize=18, lstDev=1, numberOfFirms=3, tranche=[], twister=False, roundval=False, distNorm=False, cMS=False):
		"""Draw a cohort and compute its statistics.

		All arguments except ``cMS`` are forwarded unchanged to
		EstablishmentOfCohort; ``cMS`` selects the Maurel-Sedillot G.
		"""
		self.industrySize = 0
		self.herfindahlValue = 0.0
		self.xSquaredSum = 0.0
		self.gValue = 0.0
		self.gamma = 0.0
		self.tranche = tranche
		self.siList = []
		self.cohort = EstablishmentOfCohort(rState, averageFirmSize, lstDev, numberOfFirms, self.tranche, twister, roundval, distNorm).GetFirms()

		self.SquaredSumOvertranche()
		self.Herfindahl()

		# Shares are undefined for an empty industry; fall back to zeros.
		if not self.industrySize > 0:
			self.siList = [0.0] * len(self.tranche)
		else:
			self.DefineSiList()

		computeG = self.CalculateGMS if cMS == True else self.CalculateG
		computeG()
		self.Calculate()

	def GetHerfindahl(self):
		"""Herfindahl index of firm sizes, as a float."""
		return float(self.herfindahlValue)

	def GetGValue(self):
		"""Raw G statistic, as a float."""
		return float(self.gValue)

	def GetGamma(self):
		"""Gamma statistic, as a float."""
		return float(self.gamma)

	def Calculate(self):
		"""Combine G, H and sum(x_i^2) into gamma.

		    gamma = (G - (1 - sum x^2) * H) / ((1 - sum x^2) * (1 - H))

		gamma is 0.0 when the denominator is exactly zero.
		"""
		h = self.herfindahlValue
		xsq = self.xSquaredSum
		complementH = float(1) - h
		top = self.gValue - h + xsq*h
		bottom = complementH - xsq*complementH
		if float(bottom) == float(0):
			self.gamma = float(0)
		else:
			self.gamma = top/bottom

	def _TrancheProbabilities(self):
		"""Per-tranche probability mass: successive differences of self.tranche."""
		masses = []
		previous = None
		for raw in self.tranche:
			current = float(raw)
			masses.append(current if previous is None else current - previous)
			previous = current
		return masses

	def CalculateG(self):
		"""Set gValue to sum over tranches of (s_i - x_i)^2."""
		masses = self._TrancheProbabilities()
		total = 0.0
		for idx in range(len(self.tranche)):
			total = total + (float(self.siList[idx]) - masses[idx])**2
		self.gValue = total

	def CalculateGMS(self):
		"""Set gValue to the Maurel-Sedillot form sum(s_i^2) - sum(x_i^2)."""
		self.gValue = float(self.CalculateSiSQ() - self.xSquaredSum)

	def CalculateSiSQ(self):
		"""Return sum(s_i^2) over the current share list."""
		total = 0.0
		for share in self.siList:
			total = total + float(share)**2
		return float(total)

	def DefineSiList(self):
		"""Append each tranche's share of total firm size to siList."""
		industryTotal = float(self.industrySize)
		for trancheIndex in range(len(self.tranche)):
			trancheTotal = 0
			for firm in self.cohort:
				if firm['tranche'] == trancheIndex:
					trancheTotal = trancheTotal + firm['size']
			self.siList.append(float(trancheTotal)/industryTotal)

	def SquaredSumOvertranche(self):
		"""Set xSquaredSum to sum(x_i^2) over tranche probability masses."""
		total = 0.0
		for mass in self._TrancheProbabilities():
			total = total + mass**2
		self.xSquaredSum = total

	def Herfindahl(self):
		"""Set industrySize and the Herfindahl index of firm sizes.

		Returns 0 (leaving both attributes untouched) when the cohort's total
		size is zero.
		"""
		runningSize = 0
		for firm in self.cohort:
			runningSize = runningSize + float(firm['size'])
		if runningSize == 0:
			return 0
		self.industrySize = runningSize

		hValue = 0.0
		if self.industrySize > 0:
			for firm in self.cohort:
				share = float(firm['size'])/float(self.industrySize)
				hValue = hValue + share**2
		self.herfindahlValue = hValue


class gammaSimulation:
	"""Repeat CalculateGamma draws and summarise the resulting distributions.

	All work is done by the constructor (through Run). Results are read back
	with getGamma(), getHerfindahl() and getGValue(), each returning a dict:

	  getGamma()      -- 'mean'/'std'/'min'/'max'; 'pvalue' (share of draws
	                     with gamma < 0); 'criticalValues' and 'pValues';
	                     Herfindahl percentiles 'hCritical' (97.5th),
	                     'hCriticalLow' (2.5th), 'hCriticals' (95th) and
	                     'hCriticalLows' (5th)
	  getHerfindahl() -- 'mean'/'std'/'min'/'max', 'criticalValues', 'pValues'
	  getGValue()     -- 'mean'/'std'/'min'/'max', 'loops' (accepted draws),
	                     'totalloops' (all draws, including censored ones)

	Per-threshold entries stay '' when no draw qualified.
	"""

	# Class-level fallbacks only; __init__ overwrites them per instance.

	tranche = []
	numberOfFirms = 3
	lstDev = 1
	averageFirmSize = 18
	sGamma = {}
	sHerfindahl = {}
	sGValue = {}
	critcalValues = []
	twister = False
	roundval = False
	rState = False
	distNorm = False
	pValues = []
	cMS = False
	herfCen = False
	herfLow = -1
	herfHigh = -1

	def __init__(self, rState, averageFirmSize=18, lstDev=1, numberOfFirms=3, tranche=None, critcalValues=None, tLoops=1, twister=False, roundval=False, distNorm=False, pValues=None, cMS=False, herfCen=False, herfLow=-1, herfHigh=-1):
		"""Store the configuration and run ``tLoops`` accepted simulations.

		Cohort parameters are forwarded to CalculateGamma. ``critcalValues``
		are cumulative-probability cut-offs. ``pValues`` are thresholds
		against which the empirical CDFs of gamma and Herfindahl are read.
		With ``herfCen`` true, only draws whose Herfindahl lies in
		[herfLow, herfHigh] are kept. None for a list argument means empty.
		"""
		# None instead of shared mutable [] defaults.
		self.tranche = tranche if tranche is not None else []
		self.numberOfFirms = numberOfFirms
		self.averageFirmSize = averageFirmSize
		self.lstDev = lstDev
		self.critcalValues = critcalValues if critcalValues is not None else []
		self.twister = twister
		self.roundval = roundval
		self.rState = rState
		self.distNorm = distNorm
		self.pValues = pValues if pValues is not None else []
		self.cMS = cMS
		self.herfCen = herfCen
		self.herfLow = herfLow
		self.herfHigh = herfHigh
		self.Run(tLoops)

	def getGamma(self):
		"""Gamma summary dict (see class docstring)."""
		return self.sGamma

	def getHerfindahl(self):
		"""Herfindahl summary dict (see class docstring)."""
		return self.sHerfindahl

	def getGValue(self):
		"""G summary dict, including loop counts (see class docstring)."""
		return self.sGValue

	def Run(self, tLoops):
		"""Draw until ``tLoops`` simulations are accepted, then summarise them.

		tLoops must be at least 1; the summary statistics need one draw.
		NOTE(review): with herfCen on, this loop never ends if no draw can
		fall inside [herfLow, herfHigh].
		"""
		gammaList = []
		herfindahlList = []
		gValueList = []
		combinedList = []
		accepted = 0
		attempted = 0
		while accepted < tLoops:
			eCg = CalculateGamma(self.rState, self.averageFirmSize, self.lstDev, self.numberOfFirms, self.tranche, self.twister, self.roundval, self.distNorm, self.cMS)
			gamma = float(eCg.GetGamma())
			herfindahl = float(eCg.GetHerfindahl())
			gValue = float(eCg.GetGValue())
			attempted += 1
			if self.herfCen == False or (self.herfLow <= herfindahl <= self.herfHigh):
				gammaList.append(gamma)
				herfindahlList.append(herfindahl)
				gValueList.append(gValue)
				combinedList.append({'gamma': gamma, 'herfindahl': herfindahl})
				accepted += 1
		self._Summarize(gammaList, herfindahlList, gValueList, combinedList, accepted, attempted)

	@staticmethod
	def _Describe(values):
		"""Mean, population std and NaN-ignoring min/max of ``values``."""
		return {
			'mean': numpy.mean(values),
			'std': numpy.std(values),
			'min': numpy.nanmin(values),
			'max': numpy.nanmax(values),
		}

	def _Summarize(self, gammaList, herfindahlList, gValueList, combinedList, loops, totalLoops):
		"""Fill sGamma, sHerfindahl and sGValue from raw per-draw results.

		Sorts gammaList and herfindahlList independently, and combinedList by
		gamma, all in place. gValueList is left in draw order.
		"""
		gammaList.sort()
		herfindahlList.sort()
		combinedList.sort(key=operator.itemgetter('gamma'))

		total = len(gammaList)
		nCrit = len(self.critcalValues)
		nP = len(self.pValues)
		finalPValue = float(0)
		gpvalueList = [''] * nP
		hvlist = [''] * nP
		gcritcalValues = [''] * nCrit
		hcritcalValues = [''] * nCrit
		hatcrit = [''] * nCrit
		hatcritl = [''] * nCrit
		hatcrits = [''] * nCrit
		hatcritls = [''] * nCrit
		# Whether the Herfindahl percentiles for cut-off i have been set.
		hatDone = [False] * nCrit
		# Herfindahl values of draws ranked below the current one by gamma.
		herflist = []

		for x in range(total):
			# Empirical CDF at this rank, computed directly. Repeatedly adding
			# 1/total drifts (ten 0.1s sum to 0.9999999999999999), which can
			# move a critical value by one draw.
			sumProb = float(x + 1) / float(total)

			if gammaList[x] < 0:
				finalPValue = sumProb

			for i, threshold in enumerate(self.pValues):
				if gammaList[x] < threshold:
					gpvalueList[i] = sumProb
				if herfindahlList[x] < threshold:
					hvlist[i] = sumProb

			for i, cutoff in enumerate(self.critcalValues):
				if sumProb < cutoff:
					gcritcalValues[i] = gammaList[x]
					hcritcalValues[i] = herfindahlList[x]
				elif sumProb >= cutoff and not hatDone[i] and herflist:
					herfArray = numpy.array(herflist)
					hatcrit[i] = numpy.percentile(herfArray, 97.5)
					hatcritl[i] = numpy.percentile(herfArray, 2.5)
					hatcrits[i] = numpy.percentile(herfArray, 95)
					hatcritls[i] = numpy.percentile(herfArray, 5)
					hatDone[i] = True

			herflist.append(combinedList[x]['herfindahl'])

		self.sGamma = self._Describe(gammaList)
		self.sGamma['pvalue'] = finalPValue
		self.sGamma['criticalValues'] = gcritcalValues
		self.sGamma['hCritical'] = hatcrit
		self.sGamma['hCriticalLow'] = hatcritl
		self.sGamma['hCriticals'] = hatcrits
		self.sGamma['hCriticalLows'] = hatcritls
		self.sGamma['pValues'] = gpvalueList

		self.sHerfindahl = {}
		self.sHerfindahl['criticalValues'] = hcritcalValues
		self.sHerfindahl['pValues'] = hvlist
		self.sHerfindahl.update(self._Describe(herfindahlList))

		self.sGValue = self._Describe(gValueList)
		self.sGValue['loops'] = loops
		self.sGValue['totalloops'] = totalLoops
		
		
		
class RandomIntVal:
	"""Source of uniform random numbers for the simulation.

	getValue() draws from a Crypto.Random StrongRandom generator (the
	"fortuna" option of the CLI); getValueTwister() draws from a numpy
	RandomState seeded with ``seed`` (the "mersenne twister" option).
	"""

	# Class-level generators, created once when the class is defined;
	# __init__ replaces both on every instance.
	seed = 1012810
	nState = RandomState(seed)
	cState = random.StrongRandom()

	def __init__(self, seed=1012810):
		"""Create fresh generators and discard 10000 twister samples."""
		self.nState = RandomState(seed)
		self.cState = random.StrongRandom()

		# Warm up the seeded sampler; the samples themselves are thrown away.
		print("Starting Sampler Warm-up")
		self.nState.random_sample(10000)
		print("Warm-up Complete")

	def getValue(self):
		"""Return a float in [0, 1] from the cryptographic generator.

		randint's upper bound is included, so exactly 1.0 is possible.
		"""
		upper = sys.maxint - 1
		drawn = self.cState.randint(0, upper)
		return float(drawn) / upper

	def getValueTwister(self):
		"""Return a float in [0, 1) from the seeded RandomState."""
		return self.nState.random_sample()






def isNumeric(value):
	"""Return True if ``value`` reads as an unsigned decimal number.

	Digits with at most one '.' are accepted, e.g. '42', '0.95', '.5', '5.'.
	Signs, exponents and other text are rejected, so loadFile skips such
	lines. Refusing a second '.' stops strings like '1.2.3' from passing
	here and then crashing float().
	"""
	text = str(value).strip()
	return text.count('.') <= 1 and text.replace('.', '').isdigit()

def fileExists(value):
	"""Return the absolute path of an existing file, or exit with a message.

	Surrounding whitespace is stripped and '~' is expanded before checking.
	"""
	candidate = os.path.expanduser(value.strip())
	if not os.path.isfile(candidate):
		print("I can't find the file " + value)
		sys.exit()
	return os.path.abspath(candidate)

def isReturnFile(myfile):
	"""Resolve ``myfile`` to an absolute output path, or exit with a message.

	Whitespace is stripped and '~' expanded. The path is rejected when it is
	an existing directory, or when its parent directory is missing or not
	writable. Comparing the path string with False, as before, never
	rejected anything.
	"""
	path = os.path.abspath(os.path.expanduser(myfile.strip()))
	parent = os.path.dirname(path)
	if os.path.isdir(parent) and os.access(parent, os.W_OK) and not os.path.isdir(path):
		return path
	print('You can\'t save to that location')
	sys.exit()

def WriteFile(filename,criticalvalues, pvaluesList,data,herfCen=False):
	"""Append one result row to a CSV file, writing a header if it is new.

	filename       -- destination CSV path
	criticalvalues -- each adds C/GHC/GHCL/GHC95/GHCL5/HC columns
	pvaluesList    -- each adds P/HP columns
	data           -- dict keyed by column name (as built by RunSimulation)
	herfCen        -- also emit SavedIterations/TotalIterations columns

	The header is written only when the file does not exist yet. Rows
	appended to an existing file must therefore use the same arguments to
	line up with its header.
	"""
	def suffix(value):
		# Column suffix: the value's text without decimal points (0.95 -> 095).
		return str(value).replace('.', '').strip()

	fieldList = ['NumberOfFirms','FirmSize','StDev','GammaMean','GammaMin','GammaMax','GammaStd','HerfindahlMean','HerfindahlMin','HerfindahlMax','HerfindahlStd','GValueMean','GValueMin','GValueMax','GValueStd','PValue']
	for prefix in ('C', 'GHC', 'GHCL', 'GHC95', 'GHCL5'):
		fieldList.extend(prefix + suffix(cv) for cv in criticalvalues)
	fieldList.extend('P' + suffix(pv) for pv in pvaluesList)
	fieldList.extend('HC' + suffix(cv) for cv in criticalvalues)
	fieldList.extend('HP' + suffix(pv) for pv in pvaluesList)
	if herfCen == True:
		fieldList.extend(['SavedIterations', 'TotalIterations'])

	# Binary modes because the Python 2 csv module expects them. Using 'with'
	# closes the files even if a write raises (e.g. an unexpected key in data).
	if not os.path.isfile(filename):
		with open(filename, 'wb') as headerFile:
			csv.writer(headerFile).writerow(fieldList)

	with open(filename, 'ab+') as rowFile:
		csv.DictWriter(rowFile, fieldList).writerow(data)
	print("Saving # of Firms: " + str(data['NumberOfFirms']) + ", Firm Size: " + str(data['FirmSize']) + ", StDev: " + str(data['StDev']))


def RunSimulation(rState, numberoffirmsList,firmsizeList,sdevList,trancheList,criticalvaluesList,loopsc,destination, twister, roundval, normaldist, pvaluesList, cMS, herfCen = False, herfLow=-1, herfHigh=-1):
	"""Simulate every (number of firms, firm size, stdev) combination.

	Each combination runs one gammaSimulation of ``loopsc`` accepted draws
	and appends its summary row to ``destination`` via WriteFile. The
	entries of ``sdevList`` are relative: the absolute spread passed on is
	firm size * sdev, or 0 when either value is 0.
	"""
	def columnSuffix(value):
		# Same naming rule as the CSV header: drop decimal points.
		return str(value).replace('.','').strip()

	for firmCount in numberoffirmsList:
		for firmSize in firmsizeList:
			for sdev in sdevList:
				if float(sdev) == float(0) or float(firmSize) == float(0):
					absoluteSdev = 0
				else:
					absoluteSdev = float(firmSize*float(sdev))

				simulation = gammaSimulation(rState, firmSize, absoluteSdev, int(firmCount), trancheList, criticalvaluesList, loopsc, twister, roundval, normaldist, pvaluesList, cMS, herfCen, herfLow, herfHigh)
				gamma = simulation.getGamma()
				herfindahl = simulation.getHerfindahl()
				gValue = simulation.getGValue()

				row = {'NumberOfFirms': firmCount, 'FirmSize': firmSize, 'StDev': sdev}
				for label, stats in (('Gamma', gamma), ('Herfindahl', herfindahl), ('GValue', gValue)):
					for column, statKey in (('Mean', 'mean'), ('Min', 'min'), ('Max', 'max'), ('Std', 'std')):
						row[label + column] = stats[statKey]
				row['PValue'] = gamma['pvalue']

				# (column prefix, per-threshold results, thresholds naming the columns),
				# in the order the columns have always been filled.
				perThreshold = (
					('C', gamma['criticalValues'], criticalvaluesList),
					('GHC', gamma['hCritical'], criticalvaluesList),
					('GHCL', gamma['hCriticalLow'], criticalvaluesList),
					('GHC95', gamma['hCriticals'], criticalvaluesList),
					('GHCL5', gamma['hCriticalLows'], criticalvaluesList),
					('HC', herfindahl['criticalValues'], criticalvaluesList),
					('P', gamma['pValues'], pvaluesList),
					('HP', herfindahl['pValues'], pvaluesList),
				)
				for prefix, results, thresholds in perThreshold:
					for position, threshold in enumerate(thresholds):
						row[prefix + columnSuffix(threshold)] = results[position]

				if herfCen == True:
					row['SavedIterations'] = gValue['loops']
					row['TotalIterations'] = gValue['totalloops']

				WriteFile(destination,criticalvaluesList, pvaluesList,row,herfCen)
				
 
def loadFile(value):
	"""Return the numeric lines of the text file at ``value`` as floats.

	Lines rejected by isNumeric (headers, blanks, ...) are skipped silently.
	"""
	numbers = []
	with open(value.strip(), 'rU') as handle:
		for line in handle:
			text = line.strip()
			if isNumeric(text):
				numbers.append(float(text))
	return numbers

def main():
	"""Command-line entry point.

	Parses the options and validates them. On the first problem it prints a
	message and calls sys.exit(). Otherwise it loads the input files and
	runs RunSimulation over every combination they describe.
	"""
	desc = 'Tool to simulate EG statistic and Herfindahl values'
	p = optparse.OptionParser(description=desc)
	p.add_option('--tranchefile', '-t', dest="tranche", help="File containing geographic tranches", default='', metavar='"<File Path>"')
	p.add_option('--criticalvalues', '-c', dest="criticalvalues", help="File containing critical values to test", default='', metavar='"<File Path>"')
	p.add_option('--pvalues',dest="pvalues", help="File containing p values to test", default='', metavar='"<File Path>"')
	p.add_option('--firmsize', '-f', dest="firmsize", help="File containing firm size (head count)", default='', metavar='"<File Path>"')
	p.add_option('--sdev', '-s', dest="sdev", help="File containing the standard deviations to test", default='', metavar='"<File Path>"')
	p.add_option('--numberoffirms', '-n', dest="numberoffirms", help="File containing the number of firms (in an industry) to test", default='', metavar='"<File Path>"')
	p.add_option('--iterations', '-i', type="int", dest="iterations", help="Number of iterations to run for each simulation", default=1000)
	p.add_option('--destination', '-d', dest="destination", help="Main csv file to save simulation(s) output", default='', metavar='"<File Path>"')
	p.add_option("--twister", action="store_true", dest="twister", default=False, help="Use mersenne twister for random number generation instead of fortuna")
	p.add_option("--roundfirmsize", action="store_true", dest="roundval", default=False, help="Round firm size to closest integer")
	p.add_option("--roundfirmsizedown", action="store_true", dest="roundvaldown", default=False, help="Round firm size down to closest integer")
	p.add_option("--roundfirmsizeup", action="store_true", dest="roundvalup", default=False, help="Round firm size up to closest integer")
	p.add_option("--roundfirmsizetenths", action="store_true", dest="roundvaltenths", default=False, help="Round firm size to nearest tenth")
	p.add_option("--seed", type="int", dest="seed", default=1012810, help="Seed the random generator with a specified value")
	p.add_option("--normal", action="store_true", dest="normaldist", default=False, help="Normal distributed firm sizes instead of log normal")
	p.add_option("--maurel", action="store_true", dest="cMS", default=False, help="Use Maurel and Sedillot (1999)'s value of G instead of EG")
	p.add_option("--HerfCensuredLow", type="float", dest="HerfCensuredLow", default=-1.0, help="Toss any simulated result where the Herfindahl is below this value")
	p.add_option("--HerfCensuredHigh", type="float", dest="HerfCensuredHigh", default=-1.0, help="Toss any simulated result where the Herfindahl is above this value")

	(options, arguments) = p.parse_args()
	cMS = options.cMS == True

	herfLow = options.HerfCensuredLow
	herfHigh = options.HerfCensuredHigh
	# Censoring needs a proper window: 0 <= low < high.
	herfRange = herfLow >= 0 and herfHigh > herfLow
	# Moving either bound off its -1 default without forming such a window is
	# an error. Both bounds must be tested, otherwise a lone
	# --HerfCensuredHigh silently runs an uncensored simulation.
	if (herfLow > -1 or herfHigh > -1) and not herfRange:
		print('Sorry, you must specify a censored range such that the low Herfindahl value is less than the high Herfindahl value.  We highly suggest you run a uncensored simulation to find out where the Herfindahl values are likely to be.')
		sys.exit()

	if len(options.destination) > 0:
		destination = isReturnFile(options.destination.strip())
	else:
		print('You must specify a destination file')
		sys.exit()

	if int(options.iterations) <= 0 or int(options.seed) < 0:
		print('You must specify a positive value for both iterations and seeding the random number generator')
		sys.exit()

	pvalues = []
	if len(options.pvalues) > 0:
		pvaluesfile = fileExists(options.pvalues)
		pvalues = loadFile(pvaluesfile)

	required = (options.tranche, options.criticalvalues, options.firmsize, options.numberoffirms, options.sdev)
	if not all(len(opt) > 0 for opt in required):
		print('You must specify files for tranche, critical values, firm size, number of firms, standard deviation')
		sys.exit()

	# Check that every file exists before loading any of them.
	tranchefile = fileExists(options.tranche)
	criticalvaluesfile = fileExists(options.criticalvalues)
	firmsizefile = fileExists(options.firmsize)
	numberoffirmsfile = fileExists(options.numberoffirms)
	sdevfile = fileExists(options.sdev)

	trancheList = loadFile(tranchefile)
	criticalvaluesList = loadFile(criticalvaluesfile)
	firmsizeList = loadFile(firmsizefile)
	numberoffirmsList = loadFile(numberoffirmsfile)
	sdevList = loadFile(sdevfile)

	# The first matching rounding flag wins.
	roundval = False
	if options.roundvaldown == True:
		roundval = 'floor'
	elif options.roundvalup == True:
		roundval = 'ceil'
	elif options.roundvaltenths == True:
		roundval = 'tenths'
	elif options.roundval == True:
		roundval = True

	rState = RandomIntVal(int(options.seed))
	RunSimulation(rState,numberoffirmsList,firmsizeList,sdevList,trancheList,criticalvaluesList,int(options.iterations),destination,options.twister,roundval,options.normaldist,pvalues, cMS, herfRange, options.HerfCensuredLow, options.HerfCensuredHigh)

if __name__ == "__main__":
	# Run the command-line tool only when executed as a script, not on import.
	main()