from const import *
import numpy as np, matplotlib, matplotlib.pyplot as plt
import os
import sys
import emcee
import nonGRCurves as crv
import pickle

##################################################################
numDays    = 2
noiseSigma = 1
nwalkers   = 50
burnLength = 50
mcmcLength = 1000

obsTime     = numDays * Td
inputTimes  = np.arange(0, obsTime, sampR)
data_length = len(inputTimes)
romer = np.squeeze(crv.deltaTheta(1, obsTime).romerDelay(inputTimes))

ndim = 5

# Injected ("true") parameters in the order [h0, phi0, cosI, psi, beta].
# They may be given on the command line as exactly ndim numbers; otherwise
# the defaults below are used.
if len(sys.argv) == ndim + 1:
    true_values = [float(arg) for arg in sys.argv[1:]]
else:
    if len(sys.argv) > 1:
        # Don't silently ignore a malformed command line.
        sys.stderr.write("Expected %d parameters (h0 phi0 cosI psi beta), "
                         "got %d; using defaults.\n"
                         % (ndim, len(sys.argv) - 1))
    true_values = [0.7, 1.5, 0.3, 2.0, 1.5]

# Bind the individual names in BOTH branches: changeBase reads the global
# `psi` as its reference rotation angle, so it must exist even when the
# values come from the command line (otherwise NameError).
h0, phi0, cosI, psi, beta = true_values

##################################################################
# Inclusive bounds of the flat priors, one (min, max) pair per sampled
# parameter; ln_probability rejects points outside this box and the walkers
# are initialised uniformly inside it.
minH0,   maxH0   = 0, 1         # h0
minPhi,  maxPhi  = 0, np.pi     # phi0
minCosI, maxCosI = -1, 1        # cos(iota)
minPsi,  maxPsi  = 0, np.pi     # psi
minBeta, maxBeta = 0.5, 2.0     # beta
##################################################################

def base(length):
    """Return the antenna patterns of crv.Curve(1, 0, length).

    The result is what changeBase expects as its ``APs`` argument, i.e. a
    (plus, cross) pair.  The meaning of the fixed arguments 1 and 0 is
    defined in nonGRCurves.Curve (TODO: confirm and document there).

    Note: further down, the script rebinds the module-level name ``base`` to
    the value returned by this function.
    """
    return crv.Curve(1, 0, length).makeAPs()

def changeBase(APs, inputTime, h0, phi0, cosI, psi2, beta):
    """Build the complex signal template for one point in parameter space.

    Steps:
    1. Rotate the (plus, cross) antenna patterns by twice the offset of
       ``psi2`` from the module-level reference angle ``psi``.
    2. Weight them by (1 + cosI**2)/4 and -i*cosI/2 respectively and sum.
    3. Scale by the amplitude ``h0`` and apply the phase factor
       exp(i*phi0).
    4. Multiply by crv.deltaTheta(beta, inputTime).expAngle(romer).

    Arguments
    ---------
    APs: pair
      (plus, cross) antenna patterns, e.g. as returned by base().
    inputTime:
      forwarded to crv.deltaTheta together with ``beta``.
    h0, phi0, cosI, psi2, beta: float
      the five sampled parameters, in the same order as ln_probability's
      ``parameters``.

    Returns
    -------
    The complex template (an array, presumably, when the antenna patterns
    are numpy arrays).

    Note: reads the globals ``psi`` (reference rotation angle) and
    ``romer`` (delay array), which must be defined at call time.
    """
    plusAP, crossAP = APs
    twoDelta = 2 * (psi2 - psi)
    c, s = np.cos(twoDelta), np.sin(twoDelta)
    rotPlus = plusAP * c + crossAP * s
    rotCross = crossAP * c - plusAP * s

    signal = (1 + cosI ** 2) / 4 * rotPlus - 1j/2 * cosI * rotCross
    signal *= float(h0)
    signal *= np.exp(1j * phi0)
    signal *= crv.deltaTheta(beta, inputTime).expAngle(romer)
    return signal

count = 0
def ln_probability(parameters, data):
    """ Return the natural logarithm of the (unnormalised) posterior
    probability of a point in parameter space.

    The observed data are modelled as the template built by changeBase plus
    complex Gaussian noise whose real and imaginary parts are independent,
    zero-mean, with standard deviation noiseSigma.  All priors are flat
    inside the module-level min*/max* bounds.

    Arguments
    ---------
    parameters: sequence of 5 floats
      sampled point in the form [h0, phi0, cosI, psi, beta].
    data: np.array
      array of observed (complex) data.

    Returns
    -------
    lnprob: float
      log posterior up to an additive constant,
      ln p(parameters|data) = ln p(data|parameters) + ln p(parameters) + const,
      or -np.inf outside the prior box.
    """
    lowers = (minH0, minPhi, minCosI, minPsi, minBeta)
    uppers = (maxH0, maxPhi, maxCosI, maxPsi, maxBeta)

    # Progress counter.  NOTE(review): with emcee 2.x threads>1 the calls are
    # presumably made from a multiprocessing pool, so each worker process
    # keeps its own copy of `count`.
    global count
    count += 1
    if count % 1000 == 0:
        sys.stdout.write("Check: %d\n" % (count))

    if not all(lo <= p <= hi for lo, p, hi in zip(lowers, parameters, uppers)):
        return -np.inf

    # `base` here is the module-level antenna-pattern object (the function of
    # that name is rebound to its result before sampling starts).
    template = changeBase(base, inputTimes, *parameters)
    # Gaussian log-likelihood of complex residuals with per-component
    # variance noiseSigma**2: -sum(|r|^2) / (2 sigma^2).  No further factor
    # of 1/2 belongs here -- it would widen the posterior by sqrt(2).
    # The flat prior only contributes a constant and is omitted.
    return -np.sum(np.abs(data - template) ** 2) / (2.0 * noiseSigma ** 2)

##################################################################
# FAKE COMPLEX-VALUED DATA (injected signal + Gaussian noise)
##################################################################
# Complex white noise: independent N(0, noiseSigma) draws for the real part
# (drawn first) and the imaginary part.
noise = (np.random.normal(0., noiseSigma, data_length)
         + 1j * np.random.normal(0., noiseSigma, data_length))

# From here on `base` names the antenna patterns returned by the function of
# the same name; ln_probability reads this global.
base = base(obsTime)

# Signal at the true parameters, and the simulated observation.
injection = changeBase(base, inputTimes, *true_values)
data = noise + injection

##################################################################
# Visual sanity check: overlay the noisy data on the injected signal
##################################################################
# Noisy data as lines (real = red, imaginary = blue) ...
for trace, fmt in ((data.real, 'r'), (data.imag, 'b')):
    plt.plot(trace, fmt)
# ... with the noiseless injection drawn on top as markers.
for trace, colour in ((injection.real, 'r'), (injection.imag, 'b')):
    plt.plot(trace, 'o', color=colour)
plt.title("Simulated data and template")
plt.show()


##################################################################
# RUN MCMC
##################################################################

# Initial walker positions: one [h0, phi0, cosI, psi, beta] vector per
# walker, drawn uniformly inside the prior box.
p0 = [[np.random.uniform(minH0, maxH0), np.random.uniform(minPhi, maxPhi),
       np.random.uniform(minCosI, maxCosI), np.random.uniform(minPsi, maxPsi),
       np.random.uniform(minBeta, maxBeta)]
      for _ in range(nwalkers)]

# `threads` is the emcee 2.x parallelisation keyword.
sampler = emcee.EnsembleSampler(nwalkers, ndim, ln_probability,
                                args=[data], threads=3)

# Burn-in: advance the walkers and keep their final positions `pos`.
pos, prob, state = sampler.run_mcmc(p0, burnLength)
# Drop the burn-in samples from the stored chain ...
sampler.reset()
print("Burnt in")
# ... and continue from where the burn-in ended.  Restarting from p0 here
# would discard the burn-in entirely.
sampler.run_mcmc(pos, mcmcLength)

# Close the file deterministically once the samples are written.
with open("posterior_samples.p", "wb") as samples_file:
    pickle.dump(sampler.flatchain, samples_file)


##################################################################
# PLOT POSTERIOR PDFs
##################################################################
# Axis/title labels (matplotlib mathtext).  Raw strings avoid invalid escape
# sequences; the resulting strings are identical to the original ones.
param_list = [r'$h_0$', r'$\phi_0$', r'$\cos\, \iota$', r'$\psi$', r'$\beta$']

# savefig does not create missing directories.
imageDir = "../Images"
if not os.path.isdir(imageDir):
    os.makedirs(imageDir)

# Autocorrelation time estimate per parameter, computed once.
try:
    acor = np.atleast_1d(sampler.acor)
except Exception as err:
    # Best effort after a long run: e.g. emcee may refuse to estimate it for
    # short chains.  Log and plot unthinned samples instead of crashing.
    sys.stderr.write("Could not estimate autocorrelation time (%s); "
                     "plotting unthinned samples.\n" % err)
    acor = np.ones(ndim)

for i in range(ndim):
    # Thin the flattened chain by the autocorrelation time.  Slice steps must
    # be positive integers: truncate, clamp to >= 1, and skip thinning when
    # the estimate is not finite.
    tau = acor[i]
    thin = max(1, int(tau)) if np.isfinite(tau) else 1
    plt.figure()
    # nwalkers is reused as the number of histogram bins.
    plt.hist(sampler.flatchain[:, i][::thin], nwalkers, color="k",
             histtype="step")
    plt.axvline(true_values[i])
    plt.title("%s" % (param_list[i]))
    # NOTE: the file name is the raw LaTeX label (contains '$', '\' and
    # spaces); kept unchanged for compatibility with existing outputs.
    plt.savefig("%s/%s.png" % (imageDir, param_list[i]))
    plt.show()
