#!/usr/bin/python
# Calculates a series of configurational entropy approximations based on CVM entropy formalism.
# To run, requires a directory "./correlations_data" with files of the form "ZABGD_T0800"
# (the prefix "ZABGD_T" followed by the temperature zero-padded to four digits).
# Entropy is in units of kB/atom.

import sys, argparse
import numpy as np
from scipy.special import xlogy

# Labels of the entropy approximations, in the order they are computed.
# Each label also names its output file: 'S-<label>_vs_T.dat'.
ENTROPY_LABELS = ('TET', 'TRI', 'NNN', 'NN', 'POINT')


def build_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser with the positional arguments
        ``Nspecies``, ``L`` and ``n`` and the options ``-T`` and ``--corrdir``.
    """
    # The example file name is written the way the script actually builds it:
    # 'ZABGD_T' immediately followed by the four-digit temperature.
    parser = argparse.ArgumentParser(
        description='Calculates a series of configurational entropy approximations based on '
                    'CVM entropy formalism. To run, requires a directory "./correlations_data" '
                    'with files of the form "ZABGD_T0800". '
                    'Output entropies are in units of kB/site.')
    # Arguments (required)
    parser.add_argument('Nspecies', type=int, help='number of atomic species')
    parser.add_argument('L', type=int, help='specifies length input for N=2(L^3) atom configuration')
    parser.add_argument('n', type=int, help='number of outputs to use (starting from most recently recorded)')
    # Options
    parser.add_argument('-T', metavar="T_list", type=str, default='./Tlist',
                        help='path to file with list of temperatures; default "%(default)s"')
    parser.add_argument('--corrdir', metavar='CORRDIR',
                        help='directory for tetrahedra correlation data files; default: `%(default)s`',
                        default='./correlations_data')
    return parser


def load_recent_samples(source, n):
    """Load correlation samples and keep only the last ``n`` rows.

    Args:
        source: anything ``numpy.loadtxt`` accepts (open file, file name,
            or a list of text lines); one sample per row.
        n (int): maximum number of trailing rows to keep; must be >= 1.

    Returns:
        numpy.ndarray: 2-D array with at most ``n`` rows (all rows when the
        input holds fewer than ``n``).

    Raises:
        ValueError: if ``n`` is smaller than 1 or the input holds no data.
    """
    if n < 1:
        raise ValueError('n must be a positive integer, got %r' % (n,))
    # ndmin=2 keeps a single-row file two-dimensional; without it the row
    # would load as a 1-D array and the slice below would cut columns.
    samples = np.loadtxt(source, ndmin=2)
    if samples.size == 0:
        raise ValueError('correlation data contains no samples')
    # A negative start index is clipped, so this returns every row when
    # fewer than n are available.
    return samples[-n:]


def tetrahedron_probabilities(samples, nspecies, Nsites):
    """Average the samples and normalise them into a 4-index array.

    Args:
        samples: 2-D array-like, one row per sample with ``nspecies**4``
            columns.
        nspecies (int): number of atomic species.
        Nsites (int): number of lattice sites.

    Returns:
        numpy.ndarray: array of shape ``(nspecies,) * 4`` holding the row
        average divided by ``6 * Nsites``.

    Raises:
        ValueError: if the number of columns is not ``nspecies**4``.
    """
    samples = np.atleast_2d(np.asarray(samples, dtype=float))
    if samples.shape[1] != nspecies ** 4:
        raise ValueError('expected %d columns for %d species, found %d'
                         % (nspecies ** 4, nspecies, samples.shape[1]))
    Nrun = len(samples)
    ZABGD = np.sum(samples, axis=0) / Nrun
    # NOTE(review): 6*Nsites is presumably the number of tetrahedra counted
    # per configuration, so that the entries sum to one -- confirm.
    ZABGD = ZABGD / (6 * Nsites)
    return ZABGD.reshape((nspecies, nspecies, nspecies, nspecies))


def cvm_entropies(ZABGD):
    """Evaluate the sequence of CVM entropy approximations.

    Args:
        ZABGD: 4-index array-like of tetrahedron (TET) correlations. The
            four axes are labelled A, B, G, D below (presumably the four
            sublattices -- confirm against the data generator).

    Returns:
        dict: maps each label in ``ENTROPY_LABELS`` to its entropy value,
        in the units stated by the script description (kB/site).
    """
    ZABGD = np.asarray(ZABGD, dtype=float)

    def plogp(*probs):
        # Sum of p*log(p) over every entry; xlogy returns 0 where p == 0.
        return sum(np.sum(xlogy(p, p)) for p in probs)

    # TRI correlations: sum out one tetrahedron index at a time.
    UAGD = np.sum(ZABGD, 1)
    UBGD = np.sum(ZABGD, 0)
    # Reorder the remaining axes (A, B, x) -> (x, A, B).
    UGAB = np.transpose(np.sum(ZABGD, 3), (2, 0, 1))
    UDAB = np.transpose(np.sum(ZABGD, 2), (2, 0, 1))

    # NN pair correlations
    YAG = np.sum(UGAB, 2).T
    YAD = np.sum(UDAB, 2).T
    YBG = np.sum(UGAB, 1).T
    YBD = np.sum(UDAB, 1).T

    # NNN pair correlations
    VAB = np.sum(UDAB, 0)
    VGD = np.sum(UAGD, 0)

    # point correlations
    XA = np.sum(YAG, 1)
    XB = np.sum(YBG, 1)
    XG = np.sum(YAG, 0)
    XD = np.sum(YAD, 0)

    S_X = plogp(XA, XB, XG, XD)
    S_Y = plogp(YAG, YAD, YBG, YBD)
    S_V = plogp(VAB, VGD)
    S_U = plogp(UAGD, UBGD, UGAB, UDAB)
    S_Z = plogp(ZABGD)

    # Each approximation is a fixed linear combination of the cluster sums.
    return {
        'TET': (1. / 4.) * S_X + (-1.) * S_Y + (-3. / 2) * S_V + 3. * S_U + (-6.) * S_Z,
        'TRI': (-23. / 4) * S_X + (5.) * S_Y + (9. / 2) * S_V + (-3.) * S_U,
        'NNN': (13. / 4) * S_X + (-1.) * S_Y + (-3. / 2) * S_V,
        'NN': (7. / 4) * S_X + (-1.) * S_Y,
        'POINT': (-1. / 4.) * S_X,
    }


def main(argv=None):
    """Run the script: read inputs, compute entropies, write the .dat files.

    Args:
        argv (list of str, optional): command-line arguments; defaults to
            ``sys.argv[1:]``.

    For every temperature in the list file, the last ``n`` samples of
    ``<corrdir>/ZABGD_T<TTTT>`` are averaged and the five entropy
    approximations are written, one "T value" row per temperature, to
    ``S-<label>_vs_T.dat`` in the current directory.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    # Non-positive values would yield empty sample slices or a zero divisor.
    if args.Nspecies < 1 or args.L < 1 or args.n < 1:
        parser.error('Nspecies, L and n must all be positive integers')

    Nsites = 2 * (args.L ** 3)

    with open(args.T, 'r') as TlistFile:
        # atleast_1d: a file holding one temperature loads as a 0-d array,
        # which cannot be iterated.
        Tlist = np.atleast_1d(np.loadtxt(TlistFile, dtype=int))

    results = dict((label, []) for label in ENTROPY_LABELS)

    # calculate sequence of CVM entropies at each temperature
    for T in Tlist:
        fmtT = "%04d" % T
        # load correlation data generated by run at temperature T,
        # taking only the last n entries
        with open(args.corrdir + '/ZABGD_T' + fmtT, 'r') as ZABGDcorr:
            samples = load_recent_samples(ZABGDcorr, args.n)

        ZABGD = tetrahedron_probabilities(samples, args.Nspecies, Nsites)
        entropies = cvm_entropies(ZABGD)
        for label in ENTROPY_LABELS:
            results[label].append([T, entropies[label]])

    for label in ENTROPY_LABELS:
        # reshape keeps the table two-column even when no temperature was
        # processed, so savetxt writes an empty file instead of raising.
        table = np.array(results[label], dtype=float).reshape(-1, 2)
        with open('S-%s_vs_T.dat' % label, 'w') as out:
            np.savetxt(out, table, fmt='%4d %6.12f')


if __name__ == "__main__":
    main()
