Add extract_compute, extract_fix, and extract_variable to lammps.numpy

This commit is contained in:
Richard Berger 2020-08-27 16:15:59 -04:00
parent a216d3f5f5
commit 0b8136a38b
No known key found for this signature in database
GPG Key ID: A9E83994E0BA0CAB
1 changed files with 48 additions and 0 deletions

View File

@ -435,6 +435,54 @@ class lammps(object):
return self.darray(raw_ptr, nelem, dim)
def extract_compute(self, cid, style, datatype):
value = self.lmp.extract_compute(cid, style, datatype)
if style in (LMP_STYLE_GLOBAL, LMP_STYLE_LOCAL):
if datatype == LMP_TYPE_VECTOR:
nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_VECTOR)
print("NROWS", nrows)
return self.darray(value, nrows)
elif datatype == LMP_TYPE_ARRAY:
nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_ROWS)
ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
return self.darray(value, nrows, ncols)
elif style == LMP_STYLE_ATOM:
if datatype == LMP_TYPE_VECTOR:
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
return self.darray(value, nlocal)
elif datatype == LMP_TYPE_ARRAY:
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
return self.darray(value, nlocal, ncols)
return value
def extract_fix(self, fid, style, datatype, nrow=0, ncol=0):
value = self.lmp.extract_fix(fid, style, datatype, nrow, ncol)
if style == LMP_STYLE_ATOM:
if datatype == LMP_TYPE_VECTOR:
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
return self.darray(value, nlocal)
elif datatype == LMP_TYPE_ARRAY:
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
return self.darray(value, nlocal, ncols)
elif style == LMP_STYLE_LOCAL:
if datatype == LMP_TYPE_VECTOR:
nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
return self.darray(value, nrows)
elif datatype == LMP_TYPE_ARRAY:
nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
return self.darray(value, nrows, ncols)
return value
def extract_variable(self, name, group=None, datatype=LMP_VAR_EQUAL):
value = self.lmp.extract_variable(name, group, datatype)
if datatype == LMP_VAR_ATOM:
return np.ctypeslib.as_array(value)
return value
def iarray(self, c_int_type, raw_ptr, nelem, dim=1):
np_int_type = self._ctype_to_numpy_int(c_int_type)