forked from lijiext/lammps
vectorized in parts and made changes as suggested by evoyiatzis
This commit is contained in:
parent
bbb0f5740e
commit
f41a1f8303
|
@ -38,11 +38,11 @@ StringIO (or io if in Python 3.x)
|
|||
|
||||
|
||||
|
||||
import os, sys, numpy as np, argparse, time, pickle
|
||||
import os, numpy as np, argparse, time, pickle
|
||||
from scipy.special import logsumexp
|
||||
from mpi4py import MPI
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
from tqdm import tqdm
|
||||
import gzip, bz2
|
||||
try:
|
||||
# python-2
|
||||
|
@ -78,12 +78,10 @@ def _get_nearest_temp(temps, query_temp):
|
|||
"""
|
||||
|
||||
if isinstance(temps, list): temps = np.array(temps)
|
||||
idx = np.argmin(abs(temps - query_temp))
|
||||
out_temp = temps[idx]
|
||||
return out_temp
|
||||
return temps[np.argmin(np.abs(temps-query_temp))]
|
||||
|
||||
|
||||
def readwrite(trajfn, mode = "rb"):
|
||||
def readwrite(trajfn, mode):
|
||||
"""
|
||||
Helper function for input/output LAMMPS traj files.
|
||||
Trajectories may be plain text, .gz or .bz2 compressed.
|
||||
|
@ -96,11 +94,14 @@ def readwrite(trajfn, mode = "rb"):
|
|||
"""
|
||||
|
||||
if trajfn.endswith(".gz"):
|
||||
return gzip.GzipFile(trajfn, mode)
|
||||
of = gzip.open(trajfn, mode)
|
||||
#return gzip.GzipFile(trajfn, mode)
|
||||
elif trajfn.endswith(".bz2"):
|
||||
return bz2.BZ2File(trajfn, mode)
|
||||
of = bz2.open(trajfn, mode)
|
||||
#return bz2.BZ2File(trajfn, mode)
|
||||
else:
|
||||
return file(trajfn, mode)
|
||||
of = open(trajfn, mode)
|
||||
return of
|
||||
|
||||
|
||||
def get_replica_frames(logfn, temps, nswap, writefreq):
|
||||
|
@ -163,7 +164,7 @@ def get_byte_index(rep_inds, byteindfns, intrajfns):
|
|||
if os.path.isfile(byteindfns[n]): continue
|
||||
|
||||
# extract bytes
|
||||
fobj = readwrite(intrajfns[n])
|
||||
fobj = readwrite(intrajfns[n], "rb")
|
||||
byteinds = [ [0,0] ]
|
||||
|
||||
# place file pointer at first line
|
||||
|
@ -243,7 +244,7 @@ def write_reordered_traj(temp_inds, byte_inds, outtemps, temps,
|
|||
for n in temp_inds:
|
||||
# open string-buffer and file
|
||||
buf = IOBuffer()
|
||||
of = readwrite(outtrajfns[n], mode = "wb")
|
||||
of = readwrite(outtrajfns[n], "wb")
|
||||
|
||||
# get frames
|
||||
abs_temp_ind = np.argmin( abs(temps - outtemps[n]) )
|
||||
|
@ -281,7 +282,7 @@ def write_reordered_traj(temp_inds, byte_inds, outtemps, temps,
|
|||
|
||||
|
||||
def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq,
|
||||
kB = 0.001987):
|
||||
kB):
|
||||
"""
|
||||
Gets configurational log-weights (logw) for each frame and at each temp.
|
||||
from the REMD simulation. ONLY WRITTEN FOR THE CANONICAL (NVT) ensemble.
|
||||
|
@ -348,25 +349,25 @@ def get_canonical_logw(enefn, frametuple_dict, temps, nprod, writefreq,
|
|||
#3) get reduced energies (*ONLY FOR THE CANONICAL ENSEMBLE*)
|
||||
u_kln = np.zeros([ntemps, ntemps, nframes], float)
|
||||
for k in range(ntemps):
|
||||
for l in range(ntemps):
|
||||
u_kln[ k, l, 0:nframes_k[k] ] = beta_k[l] * u_kn[k, 0:nframes_k[k]]
|
||||
|
||||
u_kln[k] = np.outer(beta_k, u_kn[k])
|
||||
|
||||
# run pymbar and extract the free energies
|
||||
print("\nRunning pymbar...")
|
||||
mbar = pymbar.mbar.MBAR(u_kln, nframes_k, verbose = True)
|
||||
f_k = mbar.f_k
|
||||
f_k = mbar.f_k # (1 x k array)
|
||||
|
||||
# calculate the log-weights
|
||||
print("\nExtracting log-weights...")
|
||||
log_nframes = np.log(nframes)
|
||||
logw = dict( (k, np.zeros([ntemps, nframes], float)) for k in range(ntemps) )
|
||||
for l in range(ntemps):
|
||||
# get log-weights to reweight to this temp.
|
||||
for k in range(ntemps):
|
||||
for n in range(nframes):
|
||||
num = -beta_k[k] * u_kn[k,n]
|
||||
denom = f_k - beta_k[k] * u_kn[k,n]
|
||||
# get log-weights to reweight to this temp.
|
||||
for k in range(ntemps):
|
||||
for n in range(nframes):
|
||||
num = -beta_k[k] * u_kn[k,n]
|
||||
denom = f_k - beta_k[k] * u_kn[k,n]
|
||||
for l in range(ntemps):
|
||||
logw[l][k,n] = num - logsumexp(denom) - log_nframes
|
||||
|
||||
return logw
|
||||
|
||||
|
||||
|
@ -515,7 +516,7 @@ if __name__ == "__main__":
|
|||
comm.barrier()
|
||||
|
||||
# open all replica files for reading
|
||||
infobjs = [readwrite(i) for i in intrajfns]
|
||||
infobjs = [readwrite(i, "rb") for i in intrajfns]
|
||||
|
||||
# open all byteindex files
|
||||
byte_inds = dict( (i, np.loadtxt(fn)) for i, fn in enumerate(byteindfns) )
|
||||
|
|
Loading…
Reference in New Issue