whitespace

This commit is contained in:
Axel Kohlmeyer 2022-09-24 15:34:41 -04:00
parent a02ab6eaa1
commit 59ca352e48
No known key found for this signature in database
GPG Key ID: D9B44E93BF0C375A
4 changed files with 12 additions and 12 deletions

View File

@ -1622,8 +1622,8 @@ class lammps(object):
"""Return a string with detailed information about any devices that are
usable by the GPU package.
This is a wrapper around the :cpp:func:`lammps_get_gpu_device_info`
function of the C-library interface.
This is a wrapper around the :cpp:func:`lammps_get_gpu_device_info`
function of the C-library interface.
:return: GPU device info string
:rtype: string

View File

@ -29,7 +29,7 @@ from ctypes import pythonapi, c_int, c_void_p, py_object
class DynamicLoader(importlib.abc.Loader):
def __init__(self,module_name,library,api_version=1013):
self.api_version = api_version
attr = "PyInit_"+module_name
initfunc = getattr(library,attr)
# c_void_p is standin for PyModuleDef *
@ -44,7 +44,7 @@ class DynamicLoader(importlib.abc.Loader):
createfunc.restype = py_object
module = createfunc(self.module_def, spec, self.api_version)
return module
def exec_module(self, module):
execfunc = pythonapi.PyModule_ExecDef
# c_void_p is standin for PyModuleDef *
@ -59,12 +59,12 @@ def activate_mliappy(lmp):
library = lmp.lib
module_names = ["mliap_model_python_couple", "mliap_unified_couple"]
api_version = library.lammps_python_api_version()
for module_name in module_names:
# Make Machinery
loader = DynamicLoader(module_name,library,api_version)
spec = importlib.util.spec_from_loader(module_name,loader)
# Do the import
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module

View File

@ -19,16 +19,16 @@ class MLIAPUnifiedLJ(MLIAPUnified):
def compute_gradients(self, data):
"""Test compute_gradients."""
def compute_descriptors(self, data):
"""Test compute_descriptors."""
def compute_forces(self, data):
"""Test compute_forces."""
eij, fij = self.compute_pair_ef(data)
data.update_pair_energy(eij)
data.update_pair_forces(fij)
def compute_pair_ef(self, data):
rij = data.rij

View File

@ -80,10 +80,10 @@ class TorchWrapper(torch.nn.Module):
n_params : torch.nn.Module (None)
Number of NN model parameters
device : torch.nn.Module (None)
Accelerator device
dtype : torch.dtype (torch.float64)
Dtype to use on device
"""
@ -325,6 +325,6 @@ class ElemwiseModels(torch.nn.Module):
per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=self.dtype)
given_elems, elem_indices = torch.unique(elems, return_inverse=True)
for i, elem in enumerate(given_elems):
self.subnets[elem].to(self.dtype)
self.subnets[elem].to(self.dtype)
per_atom_attributes[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]).flatten()
return per_atom_attributes