#!/usr/bin/python3

# Program pdoscar parses VASP DOSCAR files to extract the total and partial DOS
# in a format suitable for plotting (e.g. save to a file or pipe into xmgrace).
# Syntax: pdoscar DOSCAR -flags
# Examples:
# 1. "pdoscar DOSCAR" to extract total DOS or spin up and down components.
# 2. "pdoscar DOSCAR -atom1 N1" to extract all local DOS components of atom N1.
# 3. "pdoscar DOSCAR -atom1 N1 -atom2 N2" to sum over atoms N1 through N2 (inclusive, N2 > N1).
# 4. "pdoscar DOSCAR -atom1 N1 -p" to sum all the p orbitals of atom N1.
# Atom numbering starts at 1 (1-indexing)
#
# Currently available options:
# -h -efs -atom1 -atom2 -s -p -d -T2g -Eg -f -TotAtom -TotUpDown.
# See below for definitions.
# Note VASP partial DOS ordering (columns)
# E s py pz px dxy dyz dz2 dxz dx2 f-3 f-2 f-1 f0 f1 f2 f3
#
# For spin-polarized calculations each orbital column is reported as spin up
# followed by spin down; non-collinear spins are not handled.
#
# -p sums over px, py and pz and over both spin components;
# likewise -s, -d and -f for the s, d and f orbitals.
# -T2g sums dxy, dyz and dxz; -Eg sums dz2 and dx2.
#
# -TotAtom sums all partials (every orbital and spin) of the selected atom(s)
# -TotUpDown splits that sum into its spin-up and spin-down parts
##### -Suppress is DEPRECATED and disabled: it used to drop the complete list
##### of partials, keeping only the sums requested on the command line
#
# Spin polarized VASP DOSCAR ordering is:
# E DOS+ DOS- i-DOS+ i-DOS-
# followed by loop over atoms:
# E s+ s- px+ px- .....

import sys
import numpy as np

# Print a short usage summary and stop when called without arguments or
# with -h. sys.exit() is used rather than the site-module exit(), which is
# intended for the interactive shell only and may be absent (python -S).
if len(sys.argv) == 1 or '-h' in sys.argv:
    print("Syntax: pdoscar DOSCAR -options")
    print("Options: -h -efs -s -p -d -T2g -Eg -f -TotAtom -TotUpDown -atom1 N1 -atom2 N2")
    print("Columns: E s py pz px dxy dyz dz2 dxz dx2 f-3 f-2 f-1 f0 f1 f2 f3")
    print("See source code for details")
    sys.exit()

def DosPrint(dos):
    """Print a DOS table to stdout, one row per line.

    ``dos`` is an iterable of rows. row[0] (the energy) is written with
    format ``{:8.3f}`` and every further entry with ``{:10.6f}``. Each field
    is followed by a single space, so every line ends with one trailing
    space before the newline.
    """
    for row in dos:
        energy, values = row[0], row[1:]
        text = '{:8.3f} '.format(energy)
        text += ''.join('{:10.6f} '.format(value) for value in values)
        print(text)

# Read the whole DOSCAR as a list of lines, each split on whitespace into
# string fields.
with open(sys.argv[1]) as doscar:
    data = [line.rstrip().split() for line in doscar]

# Natom is the first field of line 1; line 6 carries NEDOS (field 3) and
# the Fermi energy (field 4). The first total-DOS line holds E followed by
# Nspin DOS and Nspin integrated-DOS columns, which determines Nspin.
header = data[5]
Natom = int(data[0][0])
Nedos = int(header[2])
Nspin = int((len(data[6]) - 1) / 2)
if '-efs' in sys.argv:
    EF = float(header[3])  # energies are later shifted so that E_F = 0
else:
    EF = 0.

# Slice bounds of each orbital family inside a partial-DOS row: column 0
# is E, followed by 1 s, 3 p, 5 d and 7 f orbitals, Nspin columns each.
sStart = 1
sStop = pStart = sStart + 1*Nspin
pStop = dStart = pStart + 3*Nspin
dStop = fStart = dStart + 5*Nspin
fStop = fStart + 7*Nspin

# Default option (no -atom1): print the first data block of the DOSCAR,
# i.e. the total DOS. Only the energy column and the Nspin DOS columns are
# printed; the integrated-DOS columns that follow them are dropped.
if '-atom1' not in sys.argv:
    DataStart = 6
    DataStop = DataStart+Nedos
    dos = np.asarray(data[DataStart:DataStop],dtype=float)
    dos[...,0] -= EF

    DosPrint(dos[...,0:Nspin+1])
    # sys.exit() instead of the site-module exit(), which is meant for the
    # interactive shell only.
    sys.exit()

# Local DOS requires at least one atom selection. Syntax: -atom1 N1
# Atom N (1-indexed) starts at data[6 + N*(Nedos+1)]: the total-DOS block
# and one header line per atom block precede it.
# Errors are reported through sys.exit(msg), i.e. on stderr with a non-zero
# status, so they are never mistaken for DOS data when stdout is piped.
if '-atom1' in sys.argv:
    i = sys.argv.index('-atom1')
    try:
        Atom1 = int(sys.argv[i+1])
    except (IndexError, ValueError):
        sys.exit("-atom1 requires an integer atom index")
    # N1 = 0 would silently read the total-DOS block as if it were an atom.
    if not 1 <= Atom1 <= Natom:
        sys.exit("Atom index out of range: use 1 <= N1 <= {}".format(Natom))
    DataStart = 6+Atom1*(Nedos+1)
    DataStop = DataStart+Nedos
    if DataStop > len(data):
        sys.exit("Insufficient data! Check if LORBIT = 12")
    Nproj = len(data[DataStart])
    dos = np.asarray(data[DataStart:DataStop],dtype=float)
    dos[...,0] -= EF

# Option to define range of atoms. Syntax -atom1 N1 -atom2 N2
# Local DOS will be summed over atoms N1 through N2 (inclusive). Only the
# partial columns are accumulated; the energy column of atom N1 is kept.
if '-atom2' in sys.argv:
    i = sys.argv.index('-atom2')
    try:
        Atom2 = int(sys.argv[i+1])
    except (IndexError, ValueError):
        sys.exit("-atom2 requires an integer atom index")
    # Validate up front: an N2 past the last atom block would otherwise fail
    # inside the loop with an opaque numpy error, and N2 < N1 was silently
    # ignored.
    if not Atom1 <= Atom2 <= Natom:
        sys.exit("Atom range out of bounds: use N1 <= N2 <= {}".format(Natom))
    if 6+Atom2*(Nedos+1)+Nedos > len(data):
        sys.exit("Insufficient data! Check if LORBIT = 12")

    for Atom in range(Atom1,Atom2):
        DataStart += Nedos+1 # next atom block: Nedos data lines + 1 header line
        DataStop = DataStart+Nedos
        dos2 = np.asarray(data[DataStart:DataStop],dtype=float)
        dos[...,1:] += dos2[...,1:]

# TotAtom will sum partial DOS over all orbitals and is inserted right after
# the energy column. Other options sum over subsets of orbitals and are
# appended after the partials, giving the column layout
#   E [TotAtom] partials... [Up Down] [s] [p] [T2g] [Eg] [d] [f]
# Every sum is taken from dos0, the partial-DOS array as it was before any
# column was added. np.insert/np.append return new arrays, so dos0 keeps
# the original layout. Slicing `dos` itself would be off by one column for
# every option evaluated after -TotAtom.
dos0 = dos

if '-TotAtom' in sys.argv:
    AtomTot = np.sum(dos0[..., 1:], axis=-1, keepdims=True)
    dos = np.insert(dos, [1], AtomTot, axis=1)

# Odd partial columns are spin up and even ones spin down, so this split is
# only meaningful for spin-polarized (Nspin == 2) data.
if '-TotUpDown' in sys.argv:
    UpTot = np.sum(dos0[..., 1::2], axis=-1, keepdims=True)
    DownTot = np.sum(dos0[..., 2::2], axis=-1, keepdims=True)
    dos = np.append(dos, UpTot, axis=1)
    dos = np.append(dos, DownTot, axis=1)

if '-s' in sys.argv:
    sTot = np.sum(dos0[..., sStart:sStop], axis=-1, keepdims=True)
    dos = np.append(dos, sTot, axis=1)

if '-p' in sys.argv:
    pTot = np.sum(dos0[..., pStart:pStop], axis=-1, keepdims=True)
    dos = np.append(dos, pTot, axis=1)

# d columns are ordered dxy dyz dz2 dxz dx2 (Nspin columns each):
# T2g = dxy + dyz + dxz
if '-T2g' in sys.argv:
    T2g = np.sum(dos0[..., dStart:dStart+2*Nspin], axis=-1, keepdims=True)
    T2g += np.sum(dos0[..., dStart+3*Nspin:dStart+4*Nspin], axis=-1, keepdims=True)
    dos = np.append(dos, T2g, axis=1)

# Eg = dz2 + dx2 (same d-column ordering as above)
if '-Eg' in sys.argv:
    Eg = np.sum(dos0[..., dStart+2*Nspin:dStart+3*Nspin], axis=-1, keepdims=True)
    Eg += np.sum(dos0[..., dStart+4*Nspin:dStop], axis=-1, keepdims=True)
    dos = np.append(dos, Eg, axis=1)

if '-d' in sys.argv:
    dTot = np.sum(dos0[..., dStart:dStop], axis=-1, keepdims=True)
    dos = np.append(dos, dTot, axis=1)

# The f columns are not present in every DOSCAR. Fail loudly instead of
# appending a column of zeros summed from an empty slice.
if '-f' in sys.argv:
    if dos0.shape[1] < fStop:
        sys.exit("No f-orbital columns in the selected DOS block")
    fTot = np.sum(dos0[..., fStart:fStop], axis=-1, keepdims=True)
    dos = np.append(dos, fTot, axis=1)

# Write the assembled table (energy, partials and any requested sums) to
# stdout. The former -Suppress option is deprecated and its code removed.
DosPrint(dos)
