add side-effect mark to sponge operators

This commit is contained in:
huangbingjian 2021-09-27 19:54:37 +08:00
parent fc4d5b0da5
commit fac79a10a6
2 changed files with 6 additions and 0 deletions

View File

@ -2791,6 +2791,7 @@ class MDIterationLeapFrogLiujian(PrimitiveWithInfer):
self.init_prim_io_names(
inputs=['inverse_mass', 'sqrt_mass_inverse', 'vel', 'crd', 'frc', 'acc', 'rand_state', 'rand_frc'],
outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, inverse_mass, sqrt_mass_inverse, vel, crd, frc, acc, rand_state, rand_frc):
n = self.atom_numbers

View File

@ -312,6 +312,7 @@ class RefreshCrdVel(PrimitiveWithInfer):
self.init_prim_io_names(
inputs=['crd', 'vel', 'test_frc', 'mass_inverse'],
outputs=['res'])
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, crd_shape, vel_shape, test_frc_shape, mass_inverse_shape):
cls_name = self.name
@ -625,6 +626,7 @@ class MDIterationLeapFrogLiujianWithMaxVel(PrimitiveWithInfer):
self.init_prim_io_names(
inputs=['inverse_mass', 'sqrt_mass_inverse', 'vel', 'crd', 'frc', 'acc', 'rand_state', 'rand_frc'],
outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, inverse_mass, sqrt_mass_inverse, vel, crd, frc, acc, rand_state, rand_frc):
n = self.atom_numbers
@ -1085,6 +1087,7 @@ class MDIterationLeapFrog(PrimitiveWithInfer):
self.init_prim_io_names(
inputs=['sqrt_mass_inverse', 'vel', 'crd', 'frc', 'acc', 'inverse_mass'],
outputs=['res'])
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, vel, crd, frc, acc, inverse_mass):
n = self.atom_numbers
@ -1162,6 +1165,7 @@ class MDIterationLeapFrogWithMaxVel(PrimitiveWithInfer):
self.init_prim_io_names(
inputs=['vel', 'crd', 'frc', 'acc', 'inverse_mass'],
outputs=['res'])
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, vel, crd, frc, acc, inverse_mass):
n = self.atom_numbers
@ -1227,6 +1231,7 @@ class MDIterationGradientDescent(PrimitiveWithInfer):
self.init_prim_io_names(
inputs=['crd', 'frc'],
outputs=['res'])
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, crd, frc):
n = self.atom_numbers