!24379 add side-effect mark to sponge operators

Merge pull request !24379 from huangbingjian/sponge_v
This commit is contained in:
i-robot 2021-09-29 09:03:53 +00:00 committed by Gitee
commit 9cf07ea45d
2 changed files with 6 additions and 0 deletions

View File

@ -2791,6 +2791,7 @@ class MDIterationLeapFrogLiujian(PrimitiveWithInfer):
self.init_prim_io_names(
inputs=['inverse_mass', 'sqrt_mass_inverse', 'vel', 'crd', 'frc', 'acc', 'rand_state', 'rand_frc'],
outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, inverse_mass, sqrt_mass_inverse, vel, crd, frc, acc, rand_state, rand_frc):
n = self.atom_numbers

View File

@ -312,6 +312,7 @@ class RefreshCrdVel(PrimitiveWithInfer):
self.init_prim_io_names(
inputs=['crd', 'vel', 'test_frc', 'mass_inverse'],
outputs=['res'])
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, crd_shape, vel_shape, test_frc_shape, mass_inverse_shape):
cls_name = self.name
@ -625,6 +626,7 @@ class MDIterationLeapFrogLiujianWithMaxVel(PrimitiveWithInfer):
self.init_prim_io_names(
inputs=['inverse_mass', 'sqrt_mass_inverse', 'vel', 'crd', 'frc', 'acc', 'rand_state', 'rand_frc'],
outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, inverse_mass, sqrt_mass_inverse, vel, crd, frc, acc, rand_state, rand_frc):
n = self.atom_numbers
@ -1086,6 +1088,7 @@ class MDIterationLeapFrog(PrimitiveWithInfer):
self.init_prim_io_names(
inputs=['sqrt_mass_inverse', 'vel', 'crd', 'frc', 'acc', 'inverse_mass'],
outputs=['res'])
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, vel, crd, frc, acc, inverse_mass):
n = self.atom_numbers
@ -1163,6 +1166,7 @@ class MDIterationLeapFrogWithMaxVel(PrimitiveWithInfer):
self.init_prim_io_names(
inputs=['vel', 'crd', 'frc', 'acc', 'inverse_mass'],
outputs=['res'])
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, vel, crd, frc, acc, inverse_mass):
n = self.atom_numbers
@ -1228,6 +1232,7 @@ class MDIterationGradientDescent(PrimitiveWithInfer):
self.init_prim_io_names(
inputs=['crd', 'frc'],
outputs=['res'])
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, crd, frc):
n = self.atom_numbers