forked from lijiext/lammps
Detect correct integer type in lammps python interface
This commit is contained in:
parent
b9fd1156b2
commit
93be2d264e
|
@ -4,9 +4,14 @@ import ctypes
|
|||
import traceback
|
||||
import numpy as np
|
||||
|
||||
class LAMMPSIntegrator(object):
|
||||
def __init__(self, ptr):
|
||||
class LAMMPSFix(object):
|
||||
def __init__(self, ptr, group_name="all"):
|
||||
self.lmp = lammps.lammps(ptr=ptr)
|
||||
self.group_name = group_name
|
||||
|
||||
class LAMMPSIntegrator(LAMMPSFix):
|
||||
def __init__(self, ptr, group_name="all"):
|
||||
super(LAMMPSIntegrator, self).__init__(ptr, group_name)
|
||||
|
||||
def init(self):
|
||||
pass
|
||||
|
@ -29,8 +34,9 @@ class LAMMPSIntegrator(object):
|
|||
|
||||
class NVE(LAMMPSIntegrator):
|
||||
""" Python implementation of fix/nve """
|
||||
def __init__(self, ptr):
|
||||
def __init__(self, ptr, group_name="all"):
|
||||
super(NVE, self).__init__(ptr)
|
||||
assert(self.group_name == "all")
|
||||
|
||||
def init(self):
|
||||
dt = self.lmp.extract_global("dt", 1)
|
||||
|
@ -66,8 +72,9 @@ class NVE(LAMMPSIntegrator):
|
|||
|
||||
class NVE_Opt(LAMMPSIntegrator):
|
||||
""" Tuned Python implementation of fix/nve """
|
||||
def __init__(self, ptr):
|
||||
def __init__(self, ptr, group_name="all"):
|
||||
super(NVE_Opt, self).__init__(ptr)
|
||||
assert(self.group_name == "all")
|
||||
|
||||
def init(self):
|
||||
dt = self.lmp.extract_global("dt", 1)
|
||||
|
|
|
@ -37,7 +37,7 @@ def get_ctypes_int(size):
|
|||
return c_int32
|
||||
elif size == 8:
|
||||
return c_int64
|
||||
return c_int
|
||||
return c_int
|
||||
|
||||
class MPIAbortException(Exception):
|
||||
def __init__(self, message):
|
||||
|
@ -266,25 +266,41 @@ class lammps(object):
|
|||
def __init__(self, lmp):
|
||||
self.lmp = lmp
|
||||
|
||||
def extract_atom_iarray(self, name, nelem, dim=1):
|
||||
if dim == 1:
|
||||
tmp = self.lmp.extract_atom(name, 0)
|
||||
ptr = cast(tmp, POINTER(c_int * nelem))
|
||||
else:
|
||||
tmp = self.lmp.extract_atom(name, 1)
|
||||
ptr = cast(tmp[0], POINTER(c_int * nelem * dim))
|
||||
def _ctype_to_numpy_int(self, ctype_int):
|
||||
if ctype_int == c_int32:
|
||||
return np.int32
|
||||
elif ctype_int == c_int64:
|
||||
return np.int64
|
||||
return np.intc
|
||||
|
||||
a = np.frombuffer(ptr.contents, dtype=np.intc)
|
||||
def extract_atom_iarray(self, name, nelem, dim=1):
|
||||
if name in ['id', 'molecule']:
|
||||
c_int_type = self.lmp.c_tagint
|
||||
elif name in ['image']:
|
||||
c_int_type = self.lmp.c_imageint
|
||||
else:
|
||||
c_int_type = c_int
|
||||
|
||||
np_int_type = self._ctype_to_numpy_int(c_int_type)
|
||||
|
||||
if dim == 1:
|
||||
tmp = self.lmp.extract_atom(name, 0)
|
||||
ptr = cast(tmp, POINTER(c_int_type * nelem))
|
||||
else:
|
||||
tmp = self.lmp.extract_atom(name, 1)
|
||||
ptr = cast(tmp[0], POINTER(c_int_type * nelem * dim))
|
||||
|
||||
a = np.frombuffer(ptr.contents, dtype=np_int_type)
|
||||
a.shape = (nelem, dim)
|
||||
return a
|
||||
|
||||
def extract_atom_darray(self, name, nelem, dim=1):
|
||||
if dim == 1:
|
||||
tmp = self.lmp.extract_atom(name, 2)
|
||||
ptr = cast(tmp, POINTER(c_double * nelem))
|
||||
tmp = self.lmp.extract_atom(name, 2)
|
||||
ptr = cast(tmp, POINTER(c_double * nelem))
|
||||
else:
|
||||
tmp = self.lmp.extract_atom(name, 3)
|
||||
ptr = cast(tmp[0], POINTER(c_double * nelem * dim))
|
||||
tmp = self.lmp.extract_atom(name, 3)
|
||||
ptr = cast(tmp[0], POINTER(c_double * nelem * dim))
|
||||
|
||||
a = np.frombuffer(ptr.contents)
|
||||
a.shape = (nelem, dim)
|
||||
|
|
Loading…
Reference in New Issue