Merge pull request #1641 from rbberger/fix_external_via_python

Extend lib interface to allow setting fix external callbacks
This commit is contained in:
Axel Kohlmeyer 2019-08-21 10:54:59 -04:00 committed by GitHub
commit bf85bff783
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 147 additions and 3 deletions

View File

@ -0,0 +1,36 @@
# this example requires the LAMMPS Python package (lammps.py) to be installed
# and LAMMPS to be loadable as shared library in LD_LIBRARY_PATH
import lammps
def callback(caller, ntimestep, nlocal, tag, x, fext):
"""
This callback receives a caller object that was setup when registering the callback
In addition to timestep and number of local atoms, the tag and x arrays are passed as
NumPy arrays. The fext array is a force array allocated for fix external, which
can be used to apply forces to all atoms. Simply update the value in the array,
it will be directly written into the LAMMPS C arrays
"""
print("Data passed by caller (optional)", caller)
print("Timestep:", ntimestep)
print("Number of Atoms:", nlocal)
print("Atom Tags:", tag)
print("Atom Positions:", x)
print("Force Additions:", fext)
fext.fill(1.0)
print("Force additions after update:", fext)
print("="*40)
L = lammps.lammps()
L.file("in.fix_external")
# you can pass an arbitrary Python object to the callback every time it is called
# this can be useful if you need more state information such as the LAMMPS ptr to
# make additional library calls
custom_object = ["Some data", L]
L.set_fix_external_callback("2", callback, custom_object)
L.command("run 100")

View File

@ -0,0 +1,23 @@
# LAMMPS input for coupling LAMMPS with Python via fix external
units metal
dimension 3
atom_style atomic
atom_modify sort 0 0.0
lattice diamond 5.43
region box block 0 1 0 1 0 1
create_box 1 box
create_atoms 1 box
mass 1 28.08
velocity all create 300.0 87293 loop geom
fix 1 all nve
fix 2 all external pf/callback 1 1
#dump 2 all image 25 image.*.jpg type type &
# axes yes 0.8 0.02 view 60 -30
#dump_modify 2 pad 3
thermo 1

View File

@ -219,6 +219,12 @@ class lammps(object):
self.c_imageint = get_ctypes_int(self.extract_setting("imageint"))
self._installed_packages = None
# add way to insert Python callback for fix external
self.callback = {}
self.FIX_EXTERNAL_CALLBACK_FUNC = CFUNCTYPE(None, c_void_p, self.c_bigint, c_int, POINTER(self.c_tagint), POINTER(POINTER(c_double)), POINTER(POINTER(c_double)))
self.lib.lammps_set_fix_external_callback.argtypes = [c_void_p, c_char_p, self.FIX_EXTERNAL_CALLBACK_FUNC, c_void_p]
self.lib.lammps_set_fix_external_callback.restype = None
# shut-down LAMMPS instance
def __del__(self):
@ -602,6 +608,42 @@ class lammps(object):
self._installed_packages.append(sb.value.decode())
return self._installed_packages
def set_fix_external_callback(self, fix_name, callback, caller=None):
import numpy as np
def _ctype_to_numpy_int(ctype_int):
if ctype_int == c_int32:
return np.int32
elif ctype_int == c_int64:
return np.int64
return np.intc
def callback_wrapper(caller_ptr, ntimestep, nlocal, tag_ptr, x_ptr, fext_ptr):
if cast(caller_ptr,POINTER(py_object)).contents:
pyCallerObj = cast(caller_ptr,POINTER(py_object)).contents.value
else:
pyCallerObj = None
tptr = cast(tag_ptr, POINTER(self.c_tagint * nlocal))
tag = np.frombuffer(tptr.contents, dtype=_ctype_to_numpy_int(self.c_tagint))
tag.shape = (nlocal)
xptr = cast(x_ptr[0], POINTER(c_double * nlocal * 3))
x = np.frombuffer(xptr.contents)
x.shape = (nlocal, 3)
fptr = cast(fext_ptr[0], POINTER(c_double * nlocal * 3))
f = np.frombuffer(fptr.contents)
f.shape = (nlocal, 3)
callback(pyCallerObj, ntimestep, nlocal, tag, x, f)
cFunc = self.FIX_EXTERNAL_CALLBACK_FUNC(callback_wrapper)
cCaller = cast(pointer(py_object(caller)), c_void_p)
self.callback[fix_name] = { 'function': cFunc, 'caller': caller }
self.lib.lammps_set_fix_external_callback(self.lmp, fix_name.encode(), cFunc, cCaller)
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
@ -872,8 +914,8 @@ class PyLammps(object):
output = self.__getattr__('run')(*args, **kwargs)
if(lammps.has_mpi4py):
output = self.lmp.comm.bcast(output, root=0)
output = self.lmp.comm.bcast(output, root=0)
self.runs += get_thermo_data(output)
return output

View File

@ -37,6 +37,7 @@
#include "error.h"
#include "force.h"
#include "info.h"
#include "fix_external.h"
#if defined(LAMMPS_EXCEPTIONS)
#include "exceptions.h"
@ -1595,7 +1596,7 @@ void lammps_create_atoms(void *ptr, int n, tagint *id, int *type,
if (lmp->atom->natoms != natoms_prev + n) {
char str[128];
sprintf(str,"Library warning in lammps_create_atoms, "
snprintf(str, 128, "Library warning in lammps_create_atoms, "
"invalid total atoms " BIGINT_FORMAT " " BIGINT_FORMAT,
lmp->atom->natoms,natoms_prev+n);
if (lmp->comm->me == 0)
@ -1605,6 +1606,40 @@ void lammps_create_atoms(void *ptr, int n, tagint *id, int *type,
END_CAPTURE
}
/* ----------------------------------------------------------------------
find fix external with given ID and set the callback function
and caller pointer
------------------------------------------------------------------------- */
void lammps_set_fix_external_callback(void *ptr, char *id, FixExternalFnPtr callback_ptr, void * caller)
{
LAMMPS *lmp = (LAMMPS *) ptr;
FixExternal::FnPtr callback = (FixExternal::FnPtr) callback_ptr;
BEGIN_CAPTURE
{
int ifix = lmp->modify->find_fix(id);
if (ifix < 0) {
char str[128];
snprintf(str, 128, "Can not find fix with ID '%s'!", id);
lmp->error->all(FLERR,str);
}
Fix *fix = lmp->modify->fix[ifix];
if (strcmp("external",fix->style) != 0){
char str[128];
snprintf(str, 128, "Fix '%s' is not of style external!", id);
lmp->error->all(FLERR,str);
}
FixExternal * fext = (FixExternal*) fix;
fext->set_callback(callback, caller);
}
END_CAPTURE
}
// ----------------------------------------------------------------------
// library API functions for accessing LAMMPS configuration
// ----------------------------------------------------------------------

View File

@ -58,6 +58,14 @@ void lammps_gather_atoms_subset(void *, char *, int, int, int, int *, void *);
void lammps_scatter_atoms(void *, char *, int, int, void *);
void lammps_scatter_atoms_subset(void *, char *, int, int, int, int *, void *);
#ifdef LAMMPS_BIGBIG
typedef void (*FixExternalFnPtr)(void *, int64_t, int, int64_t *, double **, double **);
void lammps_set_fix_external_callback(void *, char *, FixExternalFnPtr, void*);
#else
typedef void (*FixExternalFnPtr)(void *, int, int, int *, double **, double **);
void lammps_set_fix_external_callback(void *, char *, FixExternalFnPtr, void*);
#endif
int lammps_config_has_package(char * package_name);
int lammps_config_package_count();
int lammps_config_package_name(int index, char * buffer, int max_size);