!49600 fix vmap clone when prim is primc

Merge pull request !49600 from r1chardf1d0/b1
This commit is contained in:
i-robot 2023-03-02 07:06:22 +00:00 committed by Gitee
commit 2a1ad99827
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 3 additions and 1 deletions

View File

@ -26,7 +26,7 @@ from mindspore.ops.operations import math_ops
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import nn_ops as nps
from mindspore.ops.function import _VmapGeneralPreprocess
from mindspore.ops.primitive import Primitive
from mindspore.ops.primitive import Primitive, _PrimitiveC
from mindspore.ops.operations.random_ops import UniformCandidateSampler, RandomShuffle
from mindspore.ops._grad.grad_base import BpropRegistry as VmapRuleRegistry
@ -418,6 +418,8 @@ def _vmap_clone_prim(prim):
"""
Cloning a new primitive object same as `prim`.
"""
if isinstance(prim, _PrimitiveC):
return _PrimitiveC(prim.name, prim.attrs)
new_ops = _ops_vmap_clone_prim_dict.get(prim.name, None)
if new_ops is None:
raise ValueError("Failed to get the primitive object of {} from `_ops_vmap_clone_prim_dict`. Please register "