forked from mindspore-Ecosystem/mindspore
!12010 Convert MakeRefKey to an internal interface
From: @liuyang_655 Reviewed-by: @kingxian,@zhunaipan Signed-off-by: @kingxian
This commit is contained in:
commit
e09aaafdaf
|
@ -41,7 +41,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter,
|
|||
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
|
||||
TensorSummary, HistogramSummary, Print, Assert)
|
||||
from .control_ops import ControlDepend, GeSwitch, Merge
|
||||
from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign
|
||||
from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey
|
||||
|
||||
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
|
||||
BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
|
||||
|
@ -86,7 +86,7 @@ from . import _quant_ops
|
|||
from ._quant_ops import *
|
||||
from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,
|
||||
ConfusionMatrix, PopulationCount,
|
||||
CheckValid, MakeRefKey, Partial, Depend, identity, CheckBprop, Push, Pull)
|
||||
CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull)
|
||||
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
|
||||
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
|
||||
CusMatMulCubeDenseRight,
|
||||
|
@ -290,8 +290,8 @@ __all__ = [
|
|||
'Floor',
|
||||
'NMSWithMask',
|
||||
'IOU',
|
||||
'MakeRefKey',
|
||||
'Partial',
|
||||
'MakeRefKey',
|
||||
'Depend',
|
||||
'identity',
|
||||
'AvgPool',
|
||||
|
|
|
@ -435,7 +435,7 @@ class Receive(PrimitiveWithInfer):
|
|||
will be send by the Send op with the same "sr_tag".
|
||||
src_rank (int): A required integer identifying the source rank.
|
||||
shape (list[int]): A required list identifying the shape of the tensor to be received.
|
||||
dtype (Type): A required Type indentifying the type of the tensor to be received. The supported types:
|
||||
dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
|
||||
int8, int16, int32, float16, float32.
|
||||
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from ..._checkparam import Validator as validator
|
|||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
from ...common.dtype import tensor, dtype_to_pytype
|
||||
from ..primitive import prim_attr_register, PrimitiveWithInfer
|
||||
from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
|
||||
|
||||
|
||||
class ScalarCast(PrimitiveWithInfer):
|
||||
|
@ -308,3 +308,52 @@ class LambApplyWeightAssign(PrimitiveWithInfer):
|
|||
args = {"w_norm": w_norm_dtype, "g_norm": g_norm_dtype, "lr": lr_dtype}
|
||||
validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
|
||||
return var_dtype
|
||||
|
||||
|
||||
class MakeRefKey(Primitive):
|
||||
"""
|
||||
Makes a RefKey instance by string. RefKey stores the name of Parameter, can be passed through the functions,
|
||||
and used for Assign target.
|
||||
|
||||
Args:
|
||||
tag (str): Parameter name to make the RefKey.
|
||||
|
||||
Inputs:
|
||||
No inputs.
|
||||
|
||||
Outputs:
|
||||
RefKeyType, made from the Parameter name.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Parameter, Tensor
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>> import mindspore.ops as ops
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
... self.y = Parameter(Tensor(np.ones([2, 3]), mstype.int32), name="y")
|
||||
... self.make_ref_key = ops.MakeRefKey("y")
|
||||
...
|
||||
... def construct(self, x):
|
||||
... key = self.make_ref_key()
|
||||
... ref = ops.make_ref(key, x, self.y)
|
||||
... return ref * x
|
||||
...
|
||||
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mindspore.int32)
|
||||
>>> net = Net()
|
||||
>>> output = net(x)
|
||||
>>> print(output)
|
||||
[[ 1 4 9]
|
||||
[16 25 36]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, tag):
|
||||
validator.check_value_type('tag', tag, (str,), self.name)
|
||||
|
||||
def __call__(self):
|
||||
pass
|
||||
|
|
|
@ -343,55 +343,6 @@ class IOU(PrimitiveWithInfer):
|
|||
return anchor_boxes
|
||||
|
||||
|
||||
class MakeRefKey(Primitive):
|
||||
"""
|
||||
Makes a RefKey instance by string. RefKey stores the name of Parameter, can be passed through the functions,
|
||||
and used for Assign target.
|
||||
|
||||
Args:
|
||||
tag (str): Parameter name to make the RefKey.
|
||||
|
||||
Inputs:
|
||||
No inputs.
|
||||
|
||||
Outputs:
|
||||
RefKeyType, made from the Parameter name.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Parameter, Tensor
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>> import mindspore.ops as ops
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
... self.y = Parameter(Tensor(np.ones([2, 3]), mstype.int32), name="y")
|
||||
... self.make_ref_key = ops.MakeRefKey("y")
|
||||
...
|
||||
... def construct(self, x):
|
||||
... key = self.make_ref_key()
|
||||
... ref = ops.make_ref(key, x, self.y)
|
||||
... return ref * x
|
||||
...
|
||||
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mindspore.int32)
|
||||
>>> net = Net()
|
||||
>>> output = net(x)
|
||||
>>> print(output)
|
||||
[[ 1 4 9]
|
||||
[16 25 36]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, tag):
|
||||
validator.check_value_type('tag', tag, (str,), self.name)
|
||||
|
||||
def __call__(self):
|
||||
pass
|
||||
|
||||
|
||||
class Partial(Primitive):
|
||||
"""
|
||||
Makes a partial function instance, used for pynative mode.
|
||||
|
|
Loading…
Reference in New Issue