forked from lijiext/lammps
Add extract_compute, extract_fix, and extract_variable to lammps.numpy
This commit is contained in:
parent
a216d3f5f5
commit
0b8136a38b
|
@ -435,6 +435,54 @@ class lammps(object):
|
|||
|
||||
return self.darray(raw_ptr, nelem, dim)
|
||||
|
||||
def extract_compute(self, cid, style, datatype):
|
||||
value = self.lmp.extract_compute(cid, style, datatype)
|
||||
|
||||
if style in (LMP_STYLE_GLOBAL, LMP_STYLE_LOCAL):
|
||||
if datatype == LMP_TYPE_VECTOR:
|
||||
nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_VECTOR)
|
||||
print("NROWS", nrows)
|
||||
return self.darray(value, nrows)
|
||||
elif datatype == LMP_TYPE_ARRAY:
|
||||
nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_ROWS)
|
||||
ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
|
||||
return self.darray(value, nrows, ncols)
|
||||
elif style == LMP_STYLE_ATOM:
|
||||
if datatype == LMP_TYPE_VECTOR:
|
||||
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
|
||||
return self.darray(value, nlocal)
|
||||
elif datatype == LMP_TYPE_ARRAY:
|
||||
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
|
||||
ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
|
||||
return self.darray(value, nlocal, ncols)
|
||||
return value
|
||||
|
||||
def extract_fix(self, fid, style, datatype, nrow=0, ncol=0):
|
||||
value = self.lmp.extract_fix(fid, style, datatype, nrow, ncol)
|
||||
if style == LMP_STYLE_ATOM:
|
||||
if datatype == LMP_TYPE_VECTOR:
|
||||
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
|
||||
return self.darray(value, nlocal)
|
||||
elif datatype == LMP_TYPE_ARRAY:
|
||||
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
|
||||
ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
|
||||
return self.darray(value, nlocal, ncols)
|
||||
elif style == LMP_STYLE_LOCAL:
|
||||
if datatype == LMP_TYPE_VECTOR:
|
||||
nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
|
||||
return self.darray(value, nrows)
|
||||
elif datatype == LMP_TYPE_ARRAY:
|
||||
nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
|
||||
ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
|
||||
return self.darray(value, nrows, ncols)
|
||||
return value
|
||||
|
||||
def extract_variable(self, name, group=None, datatype=LMP_VAR_EQUAL):
|
||||
value = self.lmp.extract_variable(name, group, datatype)
|
||||
if datatype == LMP_VAR_ATOM:
|
||||
return np.ctypeslib.as_array(value)
|
||||
return value
|
||||
|
||||
def iarray(self, c_int_type, raw_ptr, nelem, dim=1):
|
||||
np_int_type = self._ctype_to_numpy_int(c_int_type)
|
||||
|
||||
|
|
Loading…
Reference in New Issue