forked from mindspore-Ecosystem/mindspore
!49600 fix vmap clone when prim is primc
Merge pull request !49600 from r1chardf1d0/b1
This commit is contained in:
commit
2a1ad99827
|
@ -26,7 +26,7 @@ from mindspore.ops.operations import math_ops
|
|||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.ops.function import _VmapGeneralPreprocess
|
||||
from mindspore.ops.primitive import Primitive
|
||||
from mindspore.ops.primitive import Primitive, _PrimitiveC
|
||||
from mindspore.ops.operations.random_ops import UniformCandidateSampler, RandomShuffle
|
||||
from mindspore.ops._grad.grad_base import BpropRegistry as VmapRuleRegistry
|
||||
|
||||
|
@ -418,6 +418,8 @@ def _vmap_clone_prim(prim):
|
|||
"""
|
||||
Cloning a new primitive object same as `prim`.
|
||||
"""
|
||||
if isinstance(prim, _PrimitiveC):
|
||||
return _PrimitiveC(prim.name, prim.attrs)
|
||||
new_ops = _ops_vmap_clone_prim_dict.get(prim.name, None)
|
||||
if new_ops is None:
|
||||
raise ValueError("Failed to get the primitive object of {} from `_ops_vmap_clone_prim_dict`. Please register "
|
||||
|
|
Loading…
Reference in New Issue