Vectorized parts of the code and made changes as suggested by evoyiatzis

This commit is contained in:
tanmoy.7989 2019-09-08 10:43:22 -07:00
parent bbb0f5740e
commit f41a1f8303
1 changed file with 24 additions and 23 deletions

View File

@ -38,11 +38,11 @@ StringIO (or io if in Python 3.x)
import os, sys, numpy as np, argparse, time, pickle
import os, numpy as np, argparse, time, pickle
from scipy.special import logsumexp
from mpi4py import MPI
from tqdm import tqdm, trange
from tqdm import tqdm
import gzip, bz2
try:
# python-2
@ -78,12 +78,10 @@ def _get_nearest_temp(temps, query_temp):
"""
if isinstance(temps, list): temps = np.array(temps)
idx = np.argmin(abs(temps - query_temp))
out_temp = temps[idx]
return out_temp
return temps[np.argmin(np.abs(temps-query_temp))]
def readwrite(trajfn, mode = "rb"):
def readwrite(trajfn, mode):
"""
Helper function for input/output LAMMPS traj files.
Trajectories may be plain text, .gz or .bz2 compressed.
@ -96,11 +94,14 @@ def readwrite(trajfn, mode = "rb"):
"""
if trajfn.endswith(".gz"):
return gzip.GzipFile(trajfn, mode)
of = gzip.open(trajfn, mode)
#return gzip.GzipFile(trajfn, mode)
elif trajfn.endswith(".bz2"):
return bz2.BZ2File(trajfn, mode)
of = bz2.open(trajfn, mode)
#return bz2.BZ2File(trajfn, mode)
else:
return file(trajfn, mode)
of = open(trajfn, mode)
return of
def get_replica_frames(logfn, temps, nswap, writefreq):
@ -163,7 +164,7 @@ def get_byte_index(rep_inds, byteindfns, intrajfns):
if os.path.isfile(byteindfns[n]): continue
# extract bytes
fobj = readwrite(intrajfns[n])
fobj = readwrite(intrajfns[n], "rb")
byteinds = [ [0,0] ]
# place file pointer at first line
@ -243,7 +244,7 @@ def write_reordered_traj(temp_inds, byte_inds, outtemps, temps,
for n in temp_inds:
# open string-buffer and file
buf = IOBuffer()
of = readwrite(outtrajfns[n], mode = "wb")
of = readwrite(outtrajfns[n], "wb")
# get frames
abs_temp_ind = np.argmin( abs(temps - outtemps[n]) )
@ -281,7 +282,7 @@ def write_reordered_traj(temp_inds, byte_inds, outtemps, temps,
def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq,
kB = 0.001987):
kB):
"""
Gets configurational log-weights (logw) for each frame and at each temp.
from the REMD simulation. ONLY WRITTEN FOR THE CANONICAL (NVT) ensemble.
@ -348,25 +349,25 @@ def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq,
#3) get reduced energies (*ONLY FOR THE CANONICAL ENSEMBLE*)
u_kln = np.zeros([ntemps, ntemps, nframes], float)
for k in range(ntemps):
for l in range(ntemps):
u_kln[ k, l, 0:nframes_k[k] ] = beta_k[l] * u_kn[k, 0:nframes_k[k]]
u_kln[k] = np.outer(beta_k, u_kn[k])
# run pymbar and extract the free energies
print("\nRunning pymbar...")
mbar = pymbar.mbar.MBAR(u_kln, nframes_k, verbose = True)
f_k = mbar.f_k
f_k = mbar.f_k # (1 x k array)
# calculate the log-weights
print("\nExtracting log-weights...")
log_nframes = np.log(nframes)
logw = dict( (k, np.zeros([ntemps, nframes], float)) for k in range(ntemps) )
for l in range(ntemps):
# get log-weights to reweight to this temp.
for k in range(ntemps):
for n in range(nframes):
num = -beta_k[k] * u_kn[k,n]
denom = f_k - beta_k[k] * u_kn[k,n]
# get log-weights to reweight to this temp.
for k in range(ntemps):
for n in range(nframes):
num = -beta_k[k] * u_kn[k,n]
denom = f_k - beta_k[k] * u_kn[k,n]
for l in range(ntemps):
logw[l][k,n] = num - logsumexp(denom) - log_nframes
return logw
@ -515,7 +516,7 @@ if __name__ == "__main__":
comm.barrier()
# open all replica files for reading
infobjs = [readwrite(i) for i in intrajfns]
infobjs = [readwrite(i, "rb") for i in intrajfns]
# open all byteindex files
byte_inds = dict( (i, np.loadtxt(fn)) for i, fn in enumerate(byteindfns) )