handle the case where the variable type is invalid and thus a null pointer is returned

This commit is contained in:
Axel Kohlmeyer 2020-10-12 06:16:28 -04:00
parent 6cda1e16ae
commit 8c5da70823
No known key found for this signature in database
GPG Key ID: D9B44E93BF0C375A
1 changed files with 7 additions and 4 deletions

View File

@ -1114,7 +1114,7 @@ class lammps(object):
after the data is copied to a Python variable or list. after the data is copied to a Python variable or list.
The variable must be either an equal-style (or equivalent) The variable must be either an equal-style (or equivalent)
variable or an atom-style variable. The variable type has to variable or an atom-style variable. The variable type has to
provided as ``vartype`` parameter which may be two constants: provided as ``vartype`` parameter which may be one of two constants:
``LMP_VAR_EQUAL`` or ``LMP_VAR_STRING``; it defaults to ``LMP_VAR_EQUAL`` or ``LMP_VAR_STRING``; it defaults to
equal-style variables. equal-style variables.
The group parameter is only used for atom-style variables and The group parameter is only used for atom-style variables and
@ -1135,7 +1135,8 @@ class lammps(object):
if vartype == LMP_VAR_EQUAL: if vartype == LMP_VAR_EQUAL:
self.lib.lammps_extract_variable.restype = POINTER(c_double) self.lib.lammps_extract_variable.restype = POINTER(c_double)
ptr = self.lib.lammps_extract_variable(self.lmp,name,group) ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
result = ptr[0] if ptr: result = ptr[0]
else: return None
self.lib.lammps_free(ptr) self.lib.lammps_free(ptr)
return result return result
elif vartype == LMP_VAR_ATOM: elif vartype == LMP_VAR_ATOM:
@ -1143,8 +1144,10 @@ class lammps(object):
result = (c_double*nlocal)() result = (c_double*nlocal)()
self.lib.lammps_extract_variable.restype = POINTER(c_double) self.lib.lammps_extract_variable.restype = POINTER(c_double)
ptr = self.lib.lammps_extract_variable(self.lmp,name,group) ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
for i in range(nlocal): result[i] = ptr[i] if ptr:
self.lib.lammps_free(ptr) for i in range(nlocal): result[i] = ptr[i]
self.lib.lammps_free(ptr)
else: return None
return result return result
return None return None