#!/usr/bin/python
import sys, math, copy, time, os, glob, shutil
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import argparse
from pathlib2 import Path

# Command-line interface.
# Positional arguments, in order: L N_it tol mix.
# Optional arguments: -T_i, --Tlist, -F, -Ts.
parser = argparse.ArgumentParser()
# Arguments (required)
parser.add_argument('L', type=int, help='length input for N=2(L^3) atom configuration')
parser.add_argument('-T_i', type=float,
                    help='temperature for initial guess for F (required unless -F is given)')
parser.add_argument('N_it', type=int, help='number of iterations before automatically finishing')
parser.add_argument('tol', type=int,
                    help='tolerance exponent; iterations stop once sum(dF^2) < 1/10**(tol)')
parser.add_argument('mix', type=float,
                    help='float between 0 and 1 of how much of updated F contributes to next F')
parser.add_argument('--Tlist', type=str, default='../Tlist',
                    help='path to file with list of all run temperatures')
parser.add_argument('-F', type=str, default=None, help='path to file with initial guess for F')
parser.add_argument('-Ts', type=str, default=None,
                    help='path to file with T values (used to generate "continuous" data; '
                         'defaults to the run temperatures)')
args = parser.parse_args()

# The initial free energies come either from a file (-F) or from the single
# histogram recorded at -T_i.  Without one of them there is nothing to start
# from, so fail here with a usage message rather than deep inside the script.
if args.F is None and args.T_i is None:
    parser.error('one of -T_i or -F is required to build the initial guess for F')

kB = 8.617333262e-5  # Boltzmann constant (eV / K)

# Sizes derived from the command-line length L.
leng = args.L
L = 2 * leng
N = 2 * (leng**3)

# Run temperatures, one value per entry.  ndmin=1 keeps a file holding a
# single temperature as a length-1 array so it can still be iterated and
# measured with len().
with open(args.Tlist, 'r') as TL:
    Tlist = np.loadtxt(TL, dtype=float, ndmin=1)

# Iteration controls.
N_it = args.N_it              # maximum number of iterations
tol = 10.0 ** (-args.tol)     # convergence threshold on sum(dF^2)
mix = args.mix                # fraction of the updated F mixed into the next F

# Load the energy histogram of every run temperature.
#   tmpH  : list of the raw (energy, count) tables, one per temperature
#   H_all : dict mapping each energy value to its count summed over all
#           temperatures
#   H_i   : copy of the table whose temperature matches -T_i (only needed
#           when no initial F file is supplied)
H_all = {}
tmpH = []
H_i = None
for T in Tlist:
    # File names carry the temperature truncated to an integer and
    # zero-padded to four digits.
    fmt_T = "%04d" % T
    with open('E-Histograms/HistE_T_' + fmt_T + '.dat', 'r') as HistFile:
        # ndmin=2 keeps a single-bin histogram as one (energy, count) row.
        tmp = np.loadtxt(HistFile, dtype=float, ndmin=2)
    if args.F is None and float(fmt_T) == args.T_i:
        H_i = np.copy(tmp)
    tmpH.append(tmp)
    for h in tmp:
        H_all[h[0]] = H_all.get(h[0], 0.) + h[1]

# Without an F file the initial guess is built from H_i, so a -T_i that
# matches no histogram must be reported here instead of surfacing later as
# an undefined name.
if args.F is None and H_i is None:
    raise ValueError('no histogram matches -T_i=%s; it must equal one of the '
                     '(integer-truncated) temperatures listed in %s'
                     % (args.T_i, args.Tlist))

H = np.array(tmpH)
# list(...) makes this work on Python 3 as well, where dict.keys() is a view
# and np.array() would otherwise wrap it in a 0-d object array.  The energies
# stay in the iteration order of H_all.
E = np.array(list(H_all.keys()))
# Shift energies by their mean so the exponentials used later stay centred.
Eshift = np.float128(np.mean(E))
E = np.subtract(E, Eshift)
E = E.astype(np.float128)
N_E = len(E)

# Initial guess for the free energies F (one value per temperature).
if args.F is not None:
    # An F file was supplied: column 0 holds temperatures, column 1 the free
    # energies.  Read it once; ndmin=2 keeps a one-row file two-dimensional.
    # NOTE: this replaces the Tlist read from --Tlist with the temperatures
    # stored in the F file.
    with open(args.F, 'r') as Ffile:
        F_table = np.loadtxt(Ffile, usecols=(0, 1), dtype=np.float128, ndmin=2)
    Tlist = F_table[:, 0].copy()
    F = F_table[:, 1].copy()
    N_T = len(Tlist)
else:
    # No F file: derive F from the single histogram recorded at T_i.
    N_T = len(Tlist)
    T_i = args.T_i
    # Energies of the T_i histogram, shifted like E.
    E_i = H_i[:, 0].astype(np.float128) - Eshift
    # Counts reweighted by exp(+E/(kB*T_i)).
    W_i = H_i[:, 1] * np.exp(E_i / (kB * T_i))  # Density of states
    F = np.zeros(N_T, dtype=np.float128)
    for it, T1 in enumerate(Tlist):
        zt = np.sum(W_i * np.exp(-E_i / (kB * T1)))
        # F stays 0 where the sum is not positive (log undefined).
        if zt > 0:
            F[it] = -kB * T1 * np.log(zt)

# Initialize arrays (per-temperature quantities have length N_T,
# per-energy quantities have length N_E)
Z = np.zeros(N_T, dtype=np.float128)
C = np.zeros(N_T, dtype=np.float128)
Eav = np.zeros(N_T, dtype=np.float128)
E2av = np.zeros(N_T, dtype=np.float128)
Fnew = np.zeros(N_T, dtype=np.float128)
W = np.zeros(N_E, dtype=np.float128)

# Total count for every energy in E.  E was built from H_all.keys(), so
# walking H_all in the same iteration order yields counts aligned with E.
# Reading them directly avoids rebuilding each key as (E + Eshift), which
# only finds the entry when that floating-point round trip is exact.
num = np.array([H_all[key] for key in H_all], dtype=np.float128)
# MAIN LOOP
# Iterates F until the summed squared change dF2 drops below tol, or N_it
# iterations have been done.  Each pass:
#   den[e] = sum_T exp((F_T - E_e) / (kB*T))
#   W      = num / den, normalised to sum to 1
#   Z_T    = sum_e W_e * exp(-E_e / (kB*T));  Fnew_T = -kB*T*log(Z_T)
#   F      = mix*Fnew + (1 - mix)*F
# Eav, E2av and C = (E2av - Eav^2) / (kB*T^2) are refreshed on every pass too.

# exp(-E/(kB*T)) does not depend on F, so it is evaluated once instead of on
# every iteration.  Shape (N_T, N_E): one row per temperature.
boltz = np.exp(-E[np.newaxis, :] / (kB * Tlist[:, np.newaxis]))
E_sq = E**2

cur = 0  # number of iterations actually performed
for l in range(N_it):
    # Shape (N_E, N_T), summed over temperatures.
    den = np.exp((F[np.newaxis, :] - E[:, np.newaxis])
                 / (kB * Tlist[np.newaxis, :])).sum(axis=1)
    W = num / den
    W = W / W.sum()

    Z = boltz.dot(W)
    Eav = boltz.dot(W * E) / Z
    E2av = boltz.dot(W * E_sq) / Z
    C = (E2av - Eav**2) / (kB * (Tlist**2))

    # Fnew stays 0 wherever Z is not positive (log undefined).
    Fnew = np.zeros(N_T, dtype=np.float128)
    positive = Z > 0
    Fnew[positive] = -kB * Tlist[positive] * np.log(Z[positive])

    dF2 = np.sum((Fnew - F)**2)
    F = mix*Fnew + (1. - mix)*F
    #print('After iteration %d, dF2 = %12.12f' % (l,dF2))
    cur += 1
    if dF2 < tol:
        break

# Temperatures at which the final quantities are evaluated: the -Ts file if
# given, otherwise a copy of the run temperatures.
if args.Ts is not None:
    with open(args.Ts, 'r') as TsFile:
        # ndmin=1 keeps a single-value file usable with len() and indexing.
        Ts = np.loadtxt(TsFile, dtype=np.float128, ndmin=1)
else:
    Ts = np.copy(Tlist)

N_Ts = len(Ts)
Z_f = np.zeros(N_Ts, dtype=np.float128)
C_f = np.zeros(N_Ts, dtype=np.float128)
Eav_f = np.zeros(N_Ts, dtype=np.float128)
E2av_f = np.zeros(N_Ts, dtype=np.float128)
F_f = np.zeros(N_Ts, dtype=np.float128)
for it, T in enumerate(Ts):
    # Weights W times the Boltzmann factor at T, computed once and reused for
    # the three sums below.
    w_T = W * np.exp(-E / (kB * T))
    Z_f[it] = np.sum(w_T)
    Eav_f[it] = np.sum(w_T * E) / Z_f[it]
    E2av_f[it] = np.sum(w_T * (E**2)) / Z_f[it]
    C_f[it] = (E2av_f[it] - Eav_f[it]**2) / (kB * (T**2))
    # F_f stays 0 where Z_f is not positive (log undefined).
    if Z_f[it] > 0:
        F_f[it] = -kB * T * np.log(Z_f[it])

# S = (U - F) / T
S_f = (Eav_f - F_f) / Ts

itfile = Path("info.multihist")
# info file layout: [run number] [last iteration] [tolerance]
# Each invocation appends one line.  The run number continues from the last
# recorded line and the iteration column accumulates across runs.
run = 1
it = cur
if itfile.is_file():
    with open('info.multihist', 'r') as infoFile:
        # ndmin=2 gives one row per recorded run even when the file holds a
        # single line, so the last row can always be addressed as INFO[-1].
        INFO = np.loadtxt(infoFile, dtype=int, ndmin=2)
    # An existing but empty file is treated like a first run.
    if INFO.size:
        run = INFO[-1, 0] + 1
        it = INFO[-1, 1] + cur

# Append mode creates the file when it does not exist yet.
with open('info.multihist', 'a') as fileOut:
    fileOut.write("%d %d %d\n" % (run, it, args.tol))

# Write one two-column table (T, value) per quantity, tagged with the run
# number.  np.column_stack builds a real 2-D array; a zip object is a one-shot
# iterator on Python 3 and cannot be passed to np.savetxt.
for label, values in (('F', F_f), ('C', C_f), ('U', Eav_f), ('S', S_f)):
    np.savetxt(label + '_vs_T.multihist.' + str(run) + '.dat',
               np.column_stack((Ts, values)), fmt='%10.1f  %8e')
