#!/usr/bin/python
import sys, math, copy, time, os, glob, shutil
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import argparse
from pathlib2 import Path

parser = argparse.ArgumentParser()
# Positional arguments are required; dash-prefixed ones are optional.
parser.add_argument('L', type=int, help='length input for N=2(l^3) atom configuration')
parser.add_argument('-T_i',type=float,help='temperature for initial guess for F')
parser.add_argument('N_it',type=int, help='number of iterations before automatically finishing')
parser.add_argument('tol',type=int,help='tolerance exponent; iterations stop once sum(dF^2) < 1/10**(tol)')
parser.add_argument('mix',type=float,help='float between 0 and 1 of how much of updated F contributes to next F')
parser.add_argument('--Tlist',type=str,default='../Tlist',help='path to file with list of all run temperatures')
parser.add_argument('-F',type=str,default=None,help='path to file with initial guess for F')
parser.add_argument('-Ts',type=str,default=None,help='path to file with T values (used to generate "continous" data; uses `../Tlist` by default')
parser.add_argument('-bins',type=int,default=100,help='number of bins; default: %(default)s')
args = parser.parse_args()

# -T_i is declared optional, but without an -F file the initial free-energy
# guess is built from the histogram at T_i.  Fail here with a usage message
# instead of a NameError much later in the script.
if args.F is None and args.T_i is None:
    parser.error('-T_i is required when no initial guess file is given with -F')

kB = 8.617333262e-5  # Boltzmann constant (eV / K)
leng = args.L        # lattice length l from the command line
L = 2 * leng
N = 2 * (leng**3)    # number of atoms, N = 2 l^3 (scales chi below)

# Run temperatures (integers; they are also used to build histogram file
# names).  ndmin=1 keeps Tlist a 1-D array even when the file holds a single
# temperature; otherwise loadtxt returns a 0-d array and the loops over
# Tlist below fail.
Tlist = np.loadtxt(args.Tlist, dtype=int, ndmin=1)

N_it = args.N_it          # maximum number of self-consistency iterations
tol = 10**(-args.tol)     # stop once sum(dF^2) < tol
mix = args.mix            # fraction of the updated F kept each iteration
Nbins = args.bins         # bins per axis of each 2-D histogram

# Pool the per-temperature (E, p) samples into histograms:
#   H_E_all  : energy-bin centre -> count summed over the p axis
#   H_Ep_all : (energy-bin centre, p-bin centre) -> count
# NOTE(review): np.histogram2d picks bin edges from each file's own data
# range, so bin centres from different temperatures generally do not
# coincide and these dicts mostly concatenate rather than merge bins --
# confirm this is intended.
H_E_all = {}; H_Ep_all = {}
H_i = None  # energy histogram at T_i, used for the initial F guess
for T in Tlist:
    fmt_T = "%04d" % T
    with open('DoubleHistograms/HistData_T_'+fmt_T+'.dat', 'r') as HistFile:
        Es, ps = np.loadtxt(HistFile, unpack=True, dtype=np.float128)
    counts, E_edges, p_edges = np.histogram2d(Es, ps, bins=Nbins)
    # Bin centres = left edges shifted by half a bin width.
    w_E = (E_edges[1] - E_edges[0]) / 2.
    w_p = (p_edges[1] - p_edges[0]) / 2.
    Ebins = E_edges[:-1] + w_E
    pbins = p_edges[:-1] + w_p
    # (Nbins, 2) rows of [energy-bin centre, count summed over p].
    # column_stack keeps the 2-D shape even for Nbins == 1, where the
    # previous squeeze(dstack(...)) collapsed it to 1-D.
    tmpH_E = np.column_stack((Ebins, np.sum(counts, axis=1)))
    if args.F is None and float(fmt_T) == args.T_i:
        H_i = np.copy(tmpH_E)
    for E_c, n_E in tmpH_E:
        H_E_all[E_c] = H_E_all.get(E_c, 0) + n_E
    for ie in range(Nbins):
        for ip in range(Nbins):
            key = (Ebins[ie], pbins[ip])
            H_Ep_all[key] = H_Ep_all.get(key, 0) + counts[ie, ip]

# Without -F the initial guess needs the T_i histogram; stop with a clear
# message rather than a NameError later on.
if args.F is None and H_i is None:
    sys.exit('T_i = %s does not match any temperature in %s' % (args.T_i, args.Tlist))

# Bin centres in the dicts' own key order; the counts are later read back
# from .values() in this same order.  list() is required on Python 3, where
# keys() is a view that np.array would wrap as a 0-d object array.
E = np.array(list(H_E_all.keys()), dtype=np.float128)
Ep = np.array(list(H_Ep_all.keys()), dtype=np.float128)  # rows of (E, p)
# Shift energies by their mean, presumably so that the Boltzmann factors
# exp(-E/kT) below neither overflow nor underflow.
Eshift = np.float128(np.mean(E))
E = E - Eshift
Ep[:, 0] -= Eshift
N_E = len(E)
N_Ep = len(Ep)
# Initial guess for the free energies F at every run temperature.
if args.F is not None:
    # Load two columns "T F" from the user-supplied file.  This replaces
    # Tlist with the temperatures listed there.  ndmin=2 keeps both columns
    # 1-D arrays even for a single-row file.
    # NOTE(review): the histograms were read for the --Tlist temperatures;
    # confirm the -F file lists the same temperatures in the same order.
    Tlist, F = np.loadtxt(args.F, unpack=True, ndmin=2, dtype=np.float128)
    N_T = len(Tlist)
else:
    # Otherwise reweight the single histogram at T_i.  W_i(E) = H_i(E) *
    # exp(E / kT_i) estimates the density of states (E already shifted by
    # Eshift), and F(T) = -kT ln sum_E W_i(E) exp(-E / kT).
    N_T = len(Tlist)
    T_i = args.T_i
    F = np.zeros(N_T, dtype=np.float128)
    E_i = H_i[:, 0].astype(np.float128) - Eshift
    W_i = H_i[:, 1].astype(np.float128) * np.exp(E_i / (kB*T_i))  # Density of states
    for it, T1 in enumerate(Tlist):
        zt = np.sum(W_i * np.exp(-E_i / (kB*T1)))
        # F stays 0 for this temperature if the partition sum underflows.
        if zt > 0:
            F[it] = -kB * T1 * np.log(zt)

# Fallback weights; the main loop rebinds W on every pass, so this is only
# what remains if it runs zero iterations (N_it == 0).
W = np.zeros(N_E)
# Numerator of the multi-histogram equations: the total count in each energy
# bin.  It is taken from .values(), which lines up with the .keys() order
# used to build E, because H_E_all is not modified in between.  The previous
# lookup H_E_all[E[ie] + Eshift] depended on (x - s) + s reproducing x
# exactly.  Floating point does not guarantee that, so it could raise
# KeyError.
num = np.array(list(H_E_all.values()), dtype=np.float128)
# MAIN LOOP: self-consistent multiple-histogram iteration.
# Each pass first forms the normalised weights
#     W(E) = num(E) / sum_k exp((F_k - E) / kT_k).
# It then computes new free energies
#     F_k = -kT_k ln sum_E W(E) exp(-E / kT_k),
# and mixes them into F with weight `mix`.  Iteration stops once
# sum_k (Fnew_k - F_k)^2 < tol or after N_it passes.
# NOTE(review): the textbook equations weight each denominator term by the
# sample count N_k of run k.  Omitting it assumes every run has the same
# number of samples -- confirm.
# (Energy moments and heat capacity used to be computed here on every pass
# but were never used; the reported ones are computed after convergence.)
for l in range(N_it):
    # den(E) = sum_k exp((F_k - E) / kT_k), accumulated run by run over
    # all energy bins at once.
    den = np.zeros(N_E, dtype=np.float128)
    for F_k, T_k in zip(F, Tlist):
        den += np.exp((F_k - E) / (kB*T_k))
    W = (num / den) / np.sum(num / den)
    Fnew = np.zeros(N_T, dtype=np.float128)
    for it, T in enumerate(Tlist):
        Z_T = np.sum(W * np.exp(-E / (kB*T)))
        # Fnew stays 0 for this temperature if the partition sum underflows.
        if Z_T > 0:
            Fnew[it] = -kB*T*np.log(Z_T)
    dF2 = np.sum((Fnew - F)**2)
    F = mix*Fnew + (1. - mix)*F
    if dF2 < tol:
        break

# Temperatures at which the final quantities are evaluated: a user-supplied
# list (-Ts), or else the run temperatures.  ndmin=1 keeps Ts 1-D even for
# a single-temperature file.
if args.Ts is not None:
    Ts = np.loadtxt(args.Ts, dtype=np.float128, ndmin=1)
else:
    Ts = np.copy(Tlist)

# Joint (E, p) weights, built like W in the main loop:
#     W_Ep(E, p) = num2d(E, p) / sum_k exp((F_k - E) / kT_k), normalised.
# The counts are read from .values(), which lines up with the .keys() order
# used to build Ep.  Re-looking them up via Ep[:,0] + Eshift could miss on
# float round-off and raise KeyError.
# NOTE(review): this uses the final (post-mixing) F, whereas W from the main
# loop was built from the F of its last pass; they agree only at convergence.
num2d = np.array(list(H_Ep_all.values()), dtype=np.float128)
den2d = np.zeros(N_Ep, dtype=np.float128)
for F_k, T_k in zip(F, Tlist):
    den2d += np.exp((F_k - Ep[:, 0]) / (kB*T_k))
W_Ep = (num2d / den2d) / np.sum(num2d / den2d)
# Thermodynamic quantities at every temperature in Ts, from the final weights:
#   Z = sum_E W e^{-E/kT},   U = <E>,   C = (<E^2> - <E>^2) / (k T^2),
#   <p>, <p^2> from the joint weights W_Ep,   chi = N (<p^2> - <p>^2) / (k T),
#   F = -kT ln Z (left at 0 if Z underflows),   S = (U - F) / T.
# NOTE(review): E was shifted by Eshift, so the reported U (and F) are
# relative to that shift -- confirm downstream consumers expect this.
N_Ts = len(Ts)
Z_f = np.zeros(N_Ts, dtype=np.float128); C_f = np.zeros(N_Ts, dtype=np.float128)
Eav_f = np.zeros(N_Ts, dtype=np.float128); E2av_f = np.zeros(N_Ts, dtype=np.float128)
F_f = np.zeros(N_Ts, dtype=np.float128)
pav = np.zeros(N_Ts, dtype=np.float128); p2av = np.zeros(N_Ts, dtype=np.float128)
chi = np.zeros(N_Ts, dtype=np.float128)
for it, T in enumerate(Ts):
    # Compute each set of Boltzmann factors once per temperature; the old
    # code recomputed the E factors three times and the (E, p) factors twice.
    boltz_E = np.exp(-E / (kB*T))
    boltz_Ep = np.exp(-Ep[:, 0] / (kB*T))
    Z_f[it] = np.sum(W * boltz_E)
    Eav_f[it] = np.sum(W * E * boltz_E) / Z_f[it]
    E2av_f[it] = np.sum(W * (E**2) * boltz_E) / Z_f[it]
    C_f[it] = (E2av_f[it] - Eav_f[it]**2) / (kB * (T**2))
    # Normalised by the 1-D Z: W and W_Ep share a normalisation because the
    # 2-D counts sum over p to the 1-D counts.
    pav[it] = np.sum(W_Ep * Ep[:, 1] * boltz_Ep) / Z_f[it]
    p2av[it] = np.sum(W_Ep * (Ep[:, 1]**2) * boltz_Ep) / Z_f[it]
    chi[it] = N * (p2av[it] - pav[it]**2) / (kB * T)
    if Z_f[it] > 0:
        F_f[it] = -kB*T*np.log(Z_f[it])

S_f = (Eav_f - F_f) / Ts

# Write each quantity as two columns "T value".  np.column_stack replaces
# zip(): on Python 3 zip() is a lazy iterator, which np.savetxt cannot turn
# into a 2-D array.
for fname, values, fmt in (
        ('F_vs_T.multihist.done.dat', F_f, '%10.1f  %8e'),
        ('C_vs_T.multihist.done.dat', C_f, '%10.1f  %8e'),
        ('U_vs_T.multihist.done.dat', Eav_f, '%10.1f  %8e'),
        ('S_vs_T.multihist.done.dat', S_f, '%10.1f  %8e'),
        ('p_vs_T.multihist.done.dat', pav, '%10.1f  %12.12f'),
        ('chi_vs_T.multihist.done.dat', chi, '%10.1f  %12.12f')):
    np.savetxt(fname, np.column_stack((Ts, values)), fmt=fmt)
