#!/usr/bin/python
import sys
import numpy as np
from scipy.special import xlogy

# Entropy is in units of kB

# Number of atoms, taken from the first whitespace-separated token on the
# first line of the 'Natoms' file. The whole line is read: readline(4) would
# silently truncate any count with five or more digits (e.g. "16384" -> 1638).
with open('Natoms', 'r') as Nfile:
    _natoms_fields = Nfile.readline().split()
if not _natoms_fields:
    sys.exit("Natoms: expected an integer atom count on the first line")
Natoms = int(_natoms_fields[0])

# Number of trailing rows (the last `take` runs) of each correlation file to
# average over. It must be a positive integer: take <= 0 would select no rows,
# and the resulting division by zero would write NaN entropies without error.
try:
    take = int(sys.argv[1])
except (IndexError, ValueError):
    sys.exit("usage: %s TAKE   (positive integer: number of trailing rows to average)"
             % sys.argv[0])
if take < 1:
    sys.exit("TAKE must be a positive integer, got %d" % take)

# Integer temperatures to process, read from the 'Tlist' file.
# ndmin=1 keeps a single-temperature file iterable; without it loadtxt returns
# a 0-d array, and `for T in Tlist` raises TypeError.
with open('Tlist', 'r') as TlistFile:
    Tlist = np.loadtxt(TlistFile, dtype=int, ndmin=1)

# Per-cluster entropy results, accumulated as [T, s] rows (entropy in kB).
S_TET, S_TRI, S_NNN, S_NN, S_POINT = [], [], [], [], []


def _sum_xlnx(*probs):
    """Return the total of p*ln(p) over every element of every array given.

    Uses scipy's xlogy, so 0*ln(0) contributes 0 instead of NaN.
    """
    return sum(np.sum(xlogy(p, p)) for p in probs)


for T in Tlist:
    fmtT = "%04d" % T
    with open('correlations_data/ZABGD_T' + fmtT, 'r') as ZABGDcorr:
        # ndmin=2 keeps a single-row file 2-D so the row slice below works.
        Z0 = np.loadtxt(ZABGDcorr, ndmin=2)
    l = len(Z0)
    # Keep the last `take` rows, or all rows if take exceeds the file length.
    # A bare Z0[l-take:] would keep only take-l rows when l < take < 2l,
    # because a negative start index counts from the end.
    ZABGD0 = Z0[max(l - take, 0):, :]
    Nrun = len(ZABGD0)
    if Nrun == 0:
        sys.exit("no rows to average in correlations_data/ZABGD_T%s (take=%d)"
                 % (fmtT, take))

    # Average over runs, normalise by 6*Natoms, and index as ZABGD[a, b, g, d]
    # (presumably the tetrahedron cluster probabilities with 4 states per
    # site -- TODO confirm against the producer of ZABGD_T*).
    ZABGD = np.sum(ZABGD0, axis=0) / Nrun
    ZABGD = ZABGD / (6 * Natoms)
    ZABGD = ZABGD.reshape((4, 4, 4, 4))

    # Three-site marginals: sum out one site.
    UAGD = ZABGD.sum(axis=1)                       # [a, g, d]
    UBGD = ZABGD.sum(axis=0)                       # [b, g, d]
    UGAB = ZABGD.sum(axis=3).transpose(2, 0, 1)    # [g, a, b]
    UDAB = ZABGD.sum(axis=2).transpose(2, 0, 1)    # [d, a, b]
    # Two-site marginals.
    YAG = UGAB.sum(axis=2).T                       # [a, g]
    YAD = UDAB.sum(axis=2).T                       # [a, d]
    YBG = UGAB.sum(axis=1).T                       # [b, g]
    YBD = UDAB.sum(axis=1).T                       # [b, d]
    VAB = UDAB.sum(axis=0)                         # [a, b]
    VGD = UAGD.sum(axis=0)                         # [g, d]
    # Single-site marginals.
    XA = YAG.sum(axis=1); XB = YBG.sum(axis=1)
    XG = YAG.sum(axis=0); XD = YAD.sum(axis=0)

    # Sum of p*ln(p) for each family of marginals.
    S_X = _sum_xlnx(XA, XB, XG, XD)
    S_Y = _sum_xlnx(YAG, YAD, YBG, YBD)
    S_V = _sum_xlnx(VAB, VGD)
    S_U = _sum_xlnx(UAGD, UBGD, UGAB, UDAB)
    S_Z = _sum_xlnx(ZABGD)

    # Linear combinations giving the entropy at each level of cluster
    # approximation (coefficients unchanged from the original).
    s_TET = (1./4.)*S_X + (-1.)*S_Y + (-3./2)*S_V + 3.*S_U + (-6.)*S_Z
    s_TRI = (-23./4)*S_X + (5.)*S_Y + (9./2)*S_V + (-3.)*S_U
    s_NNN = (13./4)*S_X + (-1.)*S_Y + (-3./2)*S_V
    s_NN = (7./4)*S_X + (-1.)*S_Y
    s_POINT = (-1./4.)*S_X

    S_TET.append([T, s_TET])
    S_TRI.append([T, s_TRI])
    S_NNN.append([T, s_NNN])
    S_NN.append([T, s_NN])
    S_POINT.append([T, s_POINT])

# Turn each accumulated list of [T, s] rows into a two-column array, then
# write every table to its own data file (T as an integer, s to 12 decimals).
S_TET, S_TRI, S_NNN, S_NN, S_POINT = map(
    np.array, (S_TET, S_TRI, S_NNN, S_NN, S_POINT))

_outputs = (
    ('S-TET_vs_T.dat', S_TET),
    ('S-TRI_vs_T.dat', S_TRI),
    ('S-NNN_vs_T.dat', S_NNN),
    ('S-NN_vs_T.dat', S_NN),
    ('S-POINT_vs_T.dat', S_POINT),
)
for out_name, table in _outputs:
    with open(out_name, 'w') as out_file:
        np.savetxt(out_file, table, fmt='%4d %6.12f')
