forked from mindspore-Ecosystem/mindspore
!24379 add side-effect mark to sponge operators
Merge pull request !24379 from huangbingjian/sponge_v
This commit is contained in:
commit
9cf07ea45d
|
@ -2791,6 +2791,7 @@ class MDIterationLeapFrogLiujian(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(
|
||||
inputs=['inverse_mass', 'sqrt_mass_inverse', 'vel', 'crd', 'frc', 'acc', 'rand_state', 'rand_frc'],
|
||||
outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
def infer_shape(self, inverse_mass, sqrt_mass_inverse, vel, crd, frc, acc, rand_state, rand_frc):
|
||||
n = self.atom_numbers
|
||||
|
|
|
@ -312,6 +312,7 @@ class RefreshCrdVel(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(
|
||||
inputs=['crd', 'vel', 'test_frc', 'mass_inverse'],
|
||||
outputs=['res'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
def infer_shape(self, crd_shape, vel_shape, test_frc_shape, mass_inverse_shape):
|
||||
cls_name = self.name
|
||||
|
@ -625,6 +626,7 @@ class MDIterationLeapFrogLiujianWithMaxVel(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(
|
||||
inputs=['inverse_mass', 'sqrt_mass_inverse', 'vel', 'crd', 'frc', 'acc', 'rand_state', 'rand_frc'],
|
||||
outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
def infer_shape(self, inverse_mass, sqrt_mass_inverse, vel, crd, frc, acc, rand_state, rand_frc):
|
||||
n = self.atom_numbers
|
||||
|
@ -1086,6 +1088,7 @@ class MDIterationLeapFrog(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(
|
||||
inputs=['sqrt_mass_inverse', 'vel', 'crd', 'frc', 'acc', 'inverse_mass'],
|
||||
outputs=['res'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
def infer_shape(self, vel, crd, frc, acc, inverse_mass):
|
||||
n = self.atom_numbers
|
||||
|
@ -1163,6 +1166,7 @@ class MDIterationLeapFrogWithMaxVel(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(
|
||||
inputs=['vel', 'crd', 'frc', 'acc', 'inverse_mass'],
|
||||
outputs=['res'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
def infer_shape(self, vel, crd, frc, acc, inverse_mass):
|
||||
n = self.atom_numbers
|
||||
|
@ -1228,6 +1232,7 @@ class MDIterationGradientDescent(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(
|
||||
inputs=['crd', 'frc'],
|
||||
outputs=['res'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
def infer_shape(self, crd, frc):
|
||||
n = self.atom_numbers
|
||||
|
|
Loading…
Reference in New Issue