#!/usr/bin/python
import sys, os, glob, random 
import numpy as np
import argparse

# Command-line interface: one required lattice size and an optional path to
# the tetrahedron-energy lookup table.
parser = argparse.ArgumentParser()
# Arguments (required)
parser.add_argument(
    'L',
    type=int,
    help='length input for N=2(l^3) atom configuration',
)
# Options
parser.add_argument(
    '--energies',
    metavar='EFILE',
    default='MoNbTaW_tetEs',
    # Adjacent literals keep the help text free of source indentation; a
    # backslash continuation inside the literal would embed the leading
    # spaces of the continued line in the string.
    help='file containing lookup table of energies for all '
         '256 decorations of tetrahedra; default: `%(default)s`',
)
args = parser.parse_args()

# Side length of the integer index grid: twice the command-line length.
L = 2 * args.L
# Atom count for the requested length (the CLI help states N = 2 * l^3).
N = 2 * (args.L ** 3)
kB = 8.617333262e-5  # Boltzmann constant (eV / K)
# The eight (+-1, +-1, +-1) offsets.
NNdirs = [
    (-1, 1, 1), (-1, -1, 1), (-1, -1, -1), (-1, 1, -1),
    (1, 1, 1), (1, -1, 1), (1, -1, -1), (1, 1, -1),
]
# The six offsets of length 2 along a single axis.
NNNdirs = [
    (2, 0, 0), (0, 2, 0), (0, 0, 2),
    (-2, 0, 0), (0, -2, 0), (0, 0, -2),
]

# Lattice sites: every point of the L x L x L index grid whose three
# coordinates share the same parity (all even or all odd), in lexicographic
# order.
allCoords = [
    (layer, row, col)
    for layer in range(L)
    for row in range(L)
    for col in range(L)
    if layer % 2 == row % 2 == col % 2
]
# Energy lookup table: maps a 4-tuple of integer labels (a, b, c, d) to the
# energy given on the same line of the energies file, divided by 24.
# Each data line is whitespace separated: four integers, then one float.
J = {}
with open(args.energies, 'r') as ClusterEnergies:
    for line in ClusterEnergies:
        cE = line.split()
        if not cE:
            # Blank lines (e.g. a trailing newline at the end of the file)
            # carry no entry; skip them instead of failing on cE[0].
            continue
        a, b, c, d, E = int(cE[0]), int(cE[1]), int(cE[2]), int(cE[3]), float(cE[4])
        # NOTE(review): the 1/24 factor is applied as given; its derivation
        # is not documented in this file.
        J[(a, b, c, d)] = E / 24.0

# Offset triples used by getTetrahedra(): each entry lists three displacement
# vectors from an originating site -- one axis vector of length 2, the
# (1, 1, 1) diagonal, and one further diagonal -- that give the other three
# corners of a tetrahedron.
trigs = [
    [(2, 0, 0), (1, 1, 1), (1, 1, -1)],
    [(2, 0, 0), (1, 1, 1), (1, -1, 1)],
    [(0, 2, 0), (1, 1, 1), (1, 1, -1)],
    [(0, 2, 0), (1, 1, 1), (-1, 1, 1)],
    [(0, 0, 2), (1, 1, 1), (-1, 1, 1)],
    [(0, 0, 2), (1, 1, 1), (1, -1, 1)],
]
def getTetrahedra():
    """Build the set of tetrahedra generated from every site in `allCoords`.

    For each site, every offset triple in `trigs` yields one tetrahedron
    made of the site itself plus the three sites reached by the triple's
    offsets, each wrapped periodically modulo `L`.  The four corners are
    stored in fixed slots: the originating site goes to slot
    ``(l + r + c) % 4`` and the sites reached by the first, second and third
    offset go to that slot shifted by 2, 3 and 1 (modulo 4) respectively.

    Returns:
        set: tuples of four ``(l, r, c)`` coordinate tuples.  Identical
        tuples produced more than once collapse because a set is used.
    """
    def wrapped(origin, step):
        # Periodic boundary: fold each shifted component back into [0, L).
        return tuple((x + dx) % L for x, dx in zip(origin, step))

    allTet = set()
    for l0, r0, c0 in allCoords:
        origin = (l0, r0, c0)
        SL0 = (l0 + r0 + c0) % 4
        for first, second, third in trigs:
            sites = [0, 0, 0, 0]
            sites[SL0] = origin
            sites[(SL0 + 2) % 4] = wrapped(origin, first)
            sites[(SL0 + 3) % 4] = wrapped(origin, second)
            sites[(SL0 + 1) % 4] = wrapped(origin, third)
            allTet.add(tuple(sites))
    return allTet

# Build the tetrahedra once at module level; Ham() iterates over this set.
allTetrahedra=getTetrahedra()

def Ham(sig):
    """Return the total energy of configuration `sig`.

    `sig` is indexed with the (l, r, c) coordinate tuples stored in
    `allTetrahedra`.  For each tetrahedron the four values found at its
    corners form the key into the lookup table `J`, and the looked-up
    energies are added up.

    Note the key order: corners are taken from slots 0, 2, 1, 3 of the
    tetrahedron tuple, not 0, 1, 2, 3.
    """
    energy = 0.
    for corner0, corner1, corner2, corner3 in allTetrahedra:
        decoration = (sig[corner0], sig[corner2], sig[corner1], sig[corner3])
        energy += J[decoration]
    return energy

def dE(file1,file2):
    """Return Ham(configuration in `file1`) minus Ham(configuration in `file2`).

    Both files are read with ``np.loadtxt`` as integers and reshaped to an
    (L, L, L) array before being passed to Ham().
    """
    lattices = [
        np.loadtxt(path, dtype=int).reshape((L, L, L))
        for path in (file1, file2)
    ]
    energy_first, energy_second = [Ham(lattice) for lattice in lattices]
    return energy_first - energy_second

#### Code below was taken from `/home/MD/vasp/hybrid-tmd/CuZn/3x3x3` and slightly adapted

def fix_dir(dir,Tnew):
    """Record `Tnew` as the temperature of run directory `dir`.

    Overwrites ``<dir>/Trun`` with ``str(int(Tnew))``: the value is truncated
    to an integer and written without a trailing newline.

    Args:
        dir: path of the run directory (the name shadows the builtin but is
            kept so existing callers are unaffected).
        Tnew: the new temperature; anything accepted by ``int()``.
    """
    # The context manager closes the file; no explicit close() is needed.
    with open(dir + "/Trun", 'w') as Tfile:
        Tfile.write(str(int(Tnew)))

# ---- Swap attempt between runs at neighbouring temperatures ----
# Every `run*` entry in the working directory is treated as one run; the
# temperature it currently uses is the first line of `<run>/Trun`.
runs = glob.glob('run*')
nruns = len(runs)
Tdict = {}  # temperature -> run directory
Tlist = []  # all temperatures, sorted ascending before pairing

# `./Nrun` supplies the cycle number; when the file is absent, cycle 0 is used.
if os.path.isfile("./Nrun"):
    with open("./Nrun") as fh:
        Nrun = int(fh.readline())
else:
    Nrun = 0

# With exactly two runs the only pair is (0, 1).  Otherwise the first index
# alternates with the cycle number, so successive cycles pair
# (0,1),(2,3),... and then (1,2),(3,4),...
parity = 0 if nruns == 2 else Nrun % 2

for run in runs:
    # The context manager closes the file; no explicit close() is needed.
    with open(run + "/Trun") as fh:
        Trun = float(fh.readline())
    # NOTE(review): Tdict is keyed by temperature, so two runs with an
    # identical Trun would share a single entry.
    Tdict[Trun] = run
    Tlist.append(Trun)
Tlist.sort()

for a in range(parity, nruns - 1, 2):
    Ta, Tb = Tlist[a], Tlist[a + 1]
    adir, bdir = Tdict[Ta], Tdict[Tb]
    dE_ab = dE(adir + "/MCin", bdir + "/MCin")
    dbeta = (1. / kB) * ((1. / Ta) - (1. / Tb))
    exponent = dbeta * dE_ab
    # Acceptance probability min(1, exp(dbeta * dE)).
    prob = 1. if exponent > 0 else np.exp(exponent)
    ran = random.random()
    if ran < prob:
        status = "accept"
        # Exchange the temperatures by rewriting each directory's Trun file.
        fix_dir(adir, Tb)
        fix_dir(bdir, Ta)
    else:
        status = "reject"
    print("Cycle %d swap of Ta=%d with Tb=%d dbeta*dE=%.4f prob=%.4f ran=%.4f %6s" % (Nrun, Ta, Tb, exponent, prob, ran, status))
