forked from mindspore-Ecosystem/mindspore
add vm support for ApproximateEqual, InplaceUpdateD and InTopKD.
This commit is contained in:
parent
2005ecc284
commit
51fe3501a4
|
@ -116,7 +116,8 @@ static std::map<string, string> tbe_func_adapter_map = {
|
|||
{"a_cos", "acos"},
|
||||
{"a_cos_grad", "acos_grad"},
|
||||
{"histogram_fixed_width", "histogram_fixed_width_d"},
|
||||
{"broadcast_to", "broadcast_to_d"}};
|
||||
{"broadcast_to", "broadcast_to_d"},
|
||||
{"inplace_update", "inplace_update_d"}};
|
||||
|
||||
void TbeAdapter::NormalizeFuncName(std::string *func_name) {
|
||||
if (func_name == nullptr) {
|
||||
|
|
|
@ -682,6 +682,16 @@ def get_bprop_not_equal(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.ApproximateEqual)
|
||||
def get_bprop_approximate_equal(self):
|
||||
"""Grad definition for `ApproximateEqual` operation."""
|
||||
|
||||
def bprop(x, y, out, dout):
|
||||
return zeros_like(x), zeros_like(y)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Greater)
|
||||
def get_bprop_greater(self):
|
||||
"""Grad definition for `Greater` operation."""
|
||||
|
|
|
@ -31,6 +31,7 @@ from .apply_ada_max import _apply_ada_max_tbe
|
|||
from .apply_adadelta import _apply_adadelta_tbe
|
||||
from .apply_adagrad import _apply_adagrad_tbe
|
||||
from .apply_adagrad_v2 import _apply_adagrad_v2_tbe
|
||||
from .approximate_equal import _approximate_equal_tbe
|
||||
from .adam_apply_one import _adam_apply_one_tbe
|
||||
from .assign import _assign_tbe
|
||||
from .assign_add import _assign_add_tbe
|
||||
|
@ -256,3 +257,5 @@ from .sparse_gather_v2 import _sparse_gather_v2_tbe
|
|||
from .data_format_dim_map import _data_format_dim_map_tbe
|
||||
from .histogram_fixed_width import _histogram_fixed_width_tbe
|
||||
from .tensor_scatter_update import _tensor_scatter_update_tbe
|
||||
from .inplace_update import _inplace_update_tbe
|
||||
from .in_top_k import _in_top_k_tbe
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ApproximateEqual op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
approximate_equal_op_info = TBERegOp("ApproximateEqual") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("approximate_equal.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("approximate_equal") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("broadcast") \
|
||||
.attr("tolerance", "optional", "float", "all") \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(approximate_equal_op_info)
|
||||
def _approximate_equal_tbe():
|
||||
"""ApproximateEqual TBE register"""
|
||||
return
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""InTopK op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
in_top_k_op_info = TBERegOp("InTopK") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("in_top_k.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("in_top_k") \
|
||||
.partial_flag(True) \
|
||||
.attr("k", "required", "int", "all") \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(in_top_k_op_info)
|
||||
def _in_top_k_tbe():
|
||||
"""InTopK TBE register"""
|
||||
return
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""InplaceUpdate op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
inplace_update_op_info = TBERegOp("InplaceUpdate") \
|
||||
.fusion_type("INPLACE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("inplace_update_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("inplace_update_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("indices", "required", "listInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "v", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(inplace_update_op_info)
|
||||
def _inplace_update_tbe():
|
||||
"""InplaceUpdate TBE register"""
|
||||
return
|
|
@ -30,7 +30,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||
SpaceToBatchND, BatchToSpaceND, BroadcastTo)
|
||||
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate)
|
||||
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
|
||||
_MirrorOperator, ReduceOp, _VirtualDataset,
|
||||
_VirtualDiv, _GetTensorSlice,
|
||||
|
@ -41,7 +41,7 @@ from .control_ops import ControlDepend, GeSwitch, Merge
|
|||
from .inner_ops import ScalarCast
|
||||
|
||||
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
|
||||
BitwiseXor, Inv, Invert,
|
||||
BitwiseXor, Inv, Invert, ApproximateEqual,
|
||||
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
|
||||
Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
|
||||
Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd,
|
||||
|
@ -73,7 +73,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
|
|||
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
|
||||
ApplyProximalAdagrad, SparseApplyProximalAdagrad,
|
||||
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
|
||||
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell)
|
||||
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
|
||||
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
|
||||
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, ConfusionMatrix)
|
||||
from . import _quant_ops
|
||||
|
@ -306,7 +306,10 @@ __all__ = [
|
|||
"ConfusionMatrix",
|
||||
"BroadcastTo",
|
||||
"Range",
|
||||
"DataFormatDimMap"
|
||||
"DataFormatDimMap",
|
||||
"ApproximateEqual",
|
||||
"InplaceUpdate",
|
||||
"InTopK",
|
||||
]
|
||||
|
||||
__all__.extend(_quant_ops.__all__)
|
||||
|
|
|
@ -2872,3 +2872,70 @@ class BroadcastTo(PrimitiveWithInfer):
|
|||
def infer_dtype(self, x_dtype):
|
||||
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class InplaceUpdate(PrimitiveWithInfer):
|
||||
r"""
|
||||
Updates specified rows with values in `v`.
|
||||
|
||||
Args:
|
||||
indices (Union[int, tuple]): Indices into the left-most dimension of `x`.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - A tensor which to be inplace updated. It can be of the following data types:
|
||||
float32, float16, int32.
|
||||
- **v** (Tensor) - A tensor of the same type as `x`. Same dimension size as `x` except
|
||||
the first dimension, which must be the same as the size of `indices`.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the input `x`.
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.arange(24).reshape(3, 4, 2), mindspore.float32)
|
||||
>>> v = Tensor(np.arange(-8, 8).reshape(2, 4, 2), mindspore.float32)
|
||||
>>> inplace_update = P.InplaceUpdate((0, 2))
|
||||
>>> result = inplace_update(x, v)
|
||||
[[[-8. -7.]
|
||||
[-6. -5.]
|
||||
[-4. -3.]
|
||||
[-2. -1.]]
|
||||
[[ 8. 9.]
|
||||
[10. 11.]
|
||||
[12. 13.]
|
||||
[14. 15.]]
|
||||
[[ 0. 1.]
|
||||
[ 2. 3.]
|
||||
[ 4. 5.]
|
||||
[ 6. 7.]]]
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, indices):
|
||||
"""Init InplaceUpdate"""
|
||||
self.init_prim_io_names(inputs=['x', 'indices', 'v'], outputs=['y'])
|
||||
validator.check_value_type("indices", indices, [int, tuple], self.name)
|
||||
if isinstance(indices, int):
|
||||
self.add_prim_attr('indices', (indices,))
|
||||
for item in self.indices:
|
||||
validator.check_value_type("item of indices", item, [int], self.name)
|
||||
|
||||
def infer_dtype(self, x_dtype, v_dtype):
|
||||
valid_type = [mstype.int32, mstype.float16, mstype.float32]
|
||||
validator.check_tensor_type_same(
|
||||
{
|
||||
"x": x_dtype,
|
||||
"v": v_dtype
|
||||
}, valid_type, self.name)
|
||||
|
||||
return x_dtype
|
||||
|
||||
def infer_shape(self, x_shape, v_shape):
|
||||
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
|
||||
|
||||
x_rank = len(x_shape)
|
||||
for idx in range(x_rank)[1:]:
|
||||
validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
|
||||
|
||||
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
|
||||
Rel.EQ, self.name)
|
||||
|
||||
return x_shape
|
||||
|
|
|
@ -1668,6 +1668,44 @@ class Equal(_LogicBinaryOp):
|
|||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
|
||||
|
||||
|
||||
class ApproximateEqual(_LogicBinaryOp):
|
||||
"""
|
||||
Returns the truth value of abs(x1-x2) < tolerance element-wise.
|
||||
|
||||
Args:
|
||||
tolerance (float): The maximum deviation that two elements can be considered equal. Default: 1e-05.
|
||||
|
||||
Inputs:
|
||||
- **x1** (Tensor) - A tensor. Must be one of the following types: float32, float16.
|
||||
- **x2** (Tensor) - A tensor of the same type and shape as 'x1'.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is same as the shape of 'x1', and the data type is bool.
|
||||
|
||||
Examples:
|
||||
>>> x1 = Tensor(np.array([1, 2, 3]), mindspore.float32)
|
||||
>>> x2 = Tensor(np.array([2, 4, 6]), mindspore.float32)
|
||||
>>> approximate_equal = P.ApproximateEqual(2.)
|
||||
>>> result = approximate_equal(x1, x2)
|
||||
[True True False]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, tolerance=1e-05):
|
||||
"""Init ApproximateEqual"""
|
||||
validator.check_value_type("tolerance", tolerance, [float], self.name)
|
||||
|
||||
def infer_shape(self, x_shape, y_shape):
|
||||
validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
args_dtype = {"x": x_dtype, "y": y_dtype}
|
||||
valid_type = [mstype.float32, mstype.float16]
|
||||
validator.check_tensor_type_same(args_dtype, valid_type, prim_name=self.name)
|
||||
return mstype.tensor_type(mstype.bool_)
|
||||
|
||||
|
||||
class EqualCount(PrimitiveWithInfer):
|
||||
"""
|
||||
Computes the number of the same elements of two tensors.
|
||||
|
|
|
@ -4074,3 +4074,44 @@ class BasicLSTMCell(PrimitiveWithInfer):
|
|||
validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_type_name("b", b_dtype, [mstype.float16, mstype.float32], self.name)
|
||||
return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
|
||||
|
||||
|
||||
class InTopK(PrimitiveWithInfer):
|
||||
r"""
|
||||
Says whether the targets are in the top `k` predictions.
|
||||
|
||||
Args:
|
||||
k (int): Special the number of top elements to look at for computing precision.
|
||||
|
||||
Inputs:
|
||||
- **x1** (Tensor) - A 2D Tensor define the predictions of a batch of samples with float32 data type.
|
||||
- **x2** (Tensor) - A 1D Tensor define the labels of a batch of samples with int32 data type.
|
||||
|
||||
Outputs:
|
||||
Tensor, which is 1 dimension of type bool and has same shape with `x2`. for label of sample `i` in `x2`,
|
||||
if label in first `k` predictions for sample `i` in `x1`, then the value is True else False.
|
||||
|
||||
Examples:
|
||||
>>> x1 = Tensor(np.array([[1, 8, 5, 2, 7], [4, 9, 1, 3, 5]]), mindspore.float32)
|
||||
>>> x2 = Tensor(np.array([1, 3]), mindspore.int32)
|
||||
>>> in_top_k = P.InTopK(3)
|
||||
>>> result = in_top_k(x1, x2)
|
||||
[True False]
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, k):
|
||||
"""Init InTopK"""
|
||||
self.init_prim_io_names(inputs=['x1', 'x2', 'k'], outputs=['y'])
|
||||
validator.check_value_type("k", k, [int], self.name)
|
||||
|
||||
def infer_dtype(self, x1_dtype, x2_dtype):
|
||||
validator.check_tensor_type_same({"x1": x1_dtype}, (mstype.float32,), self.name)
|
||||
validator.check_tensor_type_same({"x2": x2_dtype}, (mstype.int32,), self.name)
|
||||
|
||||
return mstype.tensor_type(mstype.bool_)
|
||||
|
||||
def infer_shape(self, x1_shape, x2_shape):
|
||||
validator.check("x1", len(x1_shape), "", 2, Rel.EQ, self.name)
|
||||
validator.check("x2", len(x2_shape), "", 1, Rel.EQ, self.name)
|
||||
validator.check("size of x2", x2_shape[0], "x1's first dimension", x1_shape[0], Rel.EQ, self.name)
|
||||
return x2_shape
|
||||
|
|
|
@ -671,6 +671,10 @@ test_case_math_ops = [
|
|||
'desc_inputs': [1, [2, 3, 4, 5]],
|
||||
'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
('ApproximateEqual', {
|
||||
'block': P.ApproximateEqual(),
|
||||
'desc_inputs': [[3, 4, 5], [3, 4, 5]],
|
||||
'desc_bprop': [Tensor(np.zeros((3, 4, 5), np.bool_))]}),
|
||||
('Greater', {
|
||||
'block': P.Greater(),
|
||||
'desc_inputs': [[2, 3, 4, 1], [4, 5]],
|
||||
|
@ -1526,6 +1530,18 @@ test_case_array_ops = [
|
|||
'block': P.BroadcastTo((2,3)),
|
||||
'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))],
|
||||
'desc_bprop': [Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32))]}),
|
||||
('InTopK', {
|
||||
'block': P.InTopK(2),
|
||||
'desc_inputs': [Tensor(np.array([[1, 2, 3], [2, 3, 6], [4, 2, 1]]).astype(np.float32)),
|
||||
Tensor(np.array([2, 1, 2]).astype(np.int32))],
|
||||
'skip': ['backward'],
|
||||
}),
|
||||
('InplaceUpdate', {
|
||||
'block': P.InplaceUpdate((0, 2)),
|
||||
'desc_inputs': [Tensor(np.arange(24).reshape(3, 4, 2).astype(np.float32)),
|
||||
Tensor(np.arange(16).reshape(2, 4, 2).astype(np.float32))],
|
||||
'skip': ['backward'],
|
||||
}),
|
||||
]
|
||||
|
||||
test_case_other_ops = [
|
||||
|
|
Loading…
Reference in New Issue