From ac8699674611040e5397c4dba2e89f9086c6a1b5 Mon Sep 17 00:00:00 2001 From: buxue Date: Wed, 6 May 2020 14:54:08 +0800 Subject: [PATCH] develop op ScatterMax and dock ge process --- mindspore/ccsrc/transform/convert.cc | 2 + mindspore/ccsrc/transform/op_declare.cc | 5 + mindspore/ccsrc/transform/op_declare.h | 2 + mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/array_ops.py | 49 +++++++++- tests/ut/python/ops/test_ops.py | 121 ++++++++++++++---------- 6 files changed, 128 insertions(+), 54 deletions(-) diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index 0171752dfb0..4a352bf9d29 100644 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -102,6 +102,7 @@ const char kNameReLU6Grad[] = "ReLU6Grad"; const char kNameElu[] = "Elu"; const char kNameEluGrad[] = "EluGrad"; const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; +const char kNameScatterMax[] = "ScatterMax"; const char kNameNMSWithMask[] = "NMSWithMask"; const char kNameCheckValid[] = "CheckValid"; const char kNameSmoothL1Loss[] = "SmoothL1Loss"; @@ -253,6 +254,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, {string(kNameOnesLike), ADPT_DESC(OnesLike)}, {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, + {string(kNameScatterMax), ADPT_DESC(ScatterMax)}, {string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)}, {string(kNameCheckValid), ADPT_DESC(CheckValid)}, {string(kNameSmoothL1Loss), ADPT_DESC(SmoothL1Loss)}, diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 27c1d306aaa..8b4046a361d 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -530,6 +530,11 @@ INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3 ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; OUTPUT_MAP(ScatterNdUpdate) = {{0, OUTPUT_DESC(var)}}; +// ScatterMax +INPUT_MAP(ScatterMax) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; +ATTR_MAP(ScatterMax) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ScatterMax) = {{0, OUTPUT_DESC(var)}}; + // CheckValid INPUT_MAP(CheckValid) = {{1, INPUT_DESC(bbox_tensor)}, {2, INPUT_DESC(img_metas)}}; ATTR_MAP(CheckValid) = EMPTY_ATTR_MAP; diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h index d15a664256e..96b828ac539 100755 --- a/mindspore/ccsrc/transform/op_declare.h +++ b/mindspore/ccsrc/transform/op_declare.h @@ -136,6 +136,8 @@ DECLARE_OP_ADAPTER(OnesLike) DECLARE_OP_USE_OUTPUT(OnesLike) DECLARE_OP_ADAPTER(ScatterNdUpdate) DECLARE_OP_USE_OUTPUT(ScatterNdUpdate) +DECLARE_OP_ADAPTER(ScatterMax) +DECLARE_OP_USE_OUTPUT(ScatterMax) DECLARE_OP_ADAPTER(NMSWithMask) DECLARE_OP_USE_OUTPUT(NMSWithMask) DECLARE_OP_ADAPTER(Unpack) diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index d83f5accd09..acc0de1d549 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Fill, GatherNd, GatherV2, InvertPermutation, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, - SameTypeShape, + SameTypeShape, ScatterMax, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, Shape, Size, Slice, Split, Squeeze, StridedSlice, Tile, @@ -184,6 +184,7 @@ __all__ = [ 'BoundingBoxDecode', 'L2Normalize', 'ScatterNd', + 'ScatterMax', 'ResizeNearestNeighbor', 'Pad', 'MirrorPad', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index aca87cab669..3d5071c8ba4 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1953,7 +1953,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): Using given values to update tensor value, along with the input indices. Args: - use_locking (bool): Whether protect the assignment by a lock. Defaule: True. + use_locking (bool): Whether protect the assignment by a lock. Default: True. Inputs: - **input_x** (Tensor) - The target tensor. @@ -1995,6 +1995,53 @@ class ScatterNdUpdate(PrimitiveWithInfer): return x_dtype +class ScatterMax(PrimitiveWithInfer): + """ + Update the value of the input tensor through the max operation. + + Using given values to update tensor value through the max operation, along with the input indices,. + + Args: + use_locking (bool): Whether protect the assignment by a lock. Default: True. + + Inputs: + - **input_x** (Tensor) - The target tensor. + - **indices** (Tensor) - The index to do max operation whose data type should be int. + - **updates** (Tensor) - The tensor doing the maximum operation with 'input_x', + the data type is same as 'input_x', the shape is 'indices_shape + x_shape[1:]'. + + Outputs: + Tensor, has the same shape and data type as `input_x`. + + Examples: + >>> input_x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32) + >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) + >>> update = Tensor(np.ones([2, 2, 3]) * 88, mindspore.float32) + >>> scatter_max = P.ScatterMax() + >>> output = scatter_max(input_x, indices, update) + [[88.0, 88.0, 88.0], [88.0, 88.0, 88.0]] + """ + + @prim_attr_register + def __init__(self, use_locking=True): + """Init ScatterMax""" + self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) + validator.check_value_type('use_locking', use_locking, (bool,), self.name) + + def infer_shape(self, x_shape, indices_shape, updates_shape): + if updates_shape and updates_shape != indices_shape + x_shape[1:]: + raise ValueError(f"For '{self.name}', the shape of update should be [] or " + f"update_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " + f"indices_shape: {indices_shape}, update_shape: {updates_shape}.") + return x_shape + + def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): + validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) + args = {"x": x_dtype, "updates": updates_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + return x_dtype + + class SpaceToDepth(PrimitiveWithInfer): r""" Rearrange blocks of spatial data into depth. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 7a3d7d967f8..c28786359a3 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -15,7 +15,7 @@ """ test ops """ import functools import numpy as np -from mindspore import ops +from mindspore import ops, Parameter, context from mindspore.ops import functional as F from mindspore.ops import operations as P from mindspore.ops.operations import _grad_ops as G @@ -26,10 +26,10 @@ from mindspore.common import dtype as mstype from ..ut_filter import non_graph_engine from ....mindspore_test_framework.mindspore_test import mindspore_test -from ....mindspore_test_framework.pipeline.forward.compile_forward\ +from ....mindspore_test_framework.pipeline.forward.compile_forward \ import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config, pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) -from ....mindspore_test_framework.pipeline.gradient.compile_gradient\ +from ....mindspore_test_framework.pipeline.gradient.compile_gradient \ import pipeline_for_compile_grad_ge_graph_for_case_by_case_config @@ -150,7 +150,7 @@ class CumSumNet(nn.Cell): class SummaryNet(nn.Cell): - def __init__(self,): + def __init__(self): super(SummaryNet, self).__init__() self.s = P.ScalarSummary() self.add = P.TensorAdd() @@ -161,7 +161,7 @@ class SummaryNet(nn.Cell): class HistogramSummaryNet(nn.Cell): - def __init__(self,): + def __init__(self): super(HistogramSummaryNet, self).__init__() self.summary = P.HistogramSummary() self.add = P.TensorAdd() @@ -173,6 +173,19 @@ class HistogramSummaryNet(nn.Cell): return out +class ScatterMax(nn.Cell): + """ScatterMax net definition""" + + def __init__(self): + super(ScatterMax, self).__init__() + self.scatter_max = P.ScatterMax() + self.ref = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], np.float32)), name="ref") + + def construct(self, indices, updates): + out = self.scatter_max(self.ref, indices, updates) + return out + + test_case_math_ops = [ ('Neg', { 'block': P.Neg(), @@ -298,28 +311,28 @@ test_case_math_ops = [ ('StridedSlice', { 'block': P.StridedSlice(), 'desc_const': [(0, 1, 2, 1), - (2, 3, 3, 4), - (1, 1, 1, 1)], + (2, 3, 3, 4), + (1, 1, 1, 1)], 'desc_inputs': [[2, 3, 3, 5]], 'desc_bprop': [[2, 2, 1, 3]]}), ('Slice_1', { 'block': P.Slice(), 'desc_const': [(0, 1, 2, 1), - (1, 1, 1, 2)], + (1, 1, 1, 2)], 'desc_inputs': [[2, 3, 3, 5]], 'desc_bprop': [[1, 1, 1, 2]]}), ('StridedSliceGrad', { 'block': G.StridedSliceGrad(), 'desc_const': [(64, 1, 1024), - (0, 1, 0), - (64, 2, 1024), - (1, 1, 1)], + (0, 1, 0), + (64, 2, 1024), + (1, 1, 1)], 'desc_inputs': [[64, 128, 1024]], 'skip': ['backward']}), ('RandomChoiceWithMask', { 'block': P.RandomChoiceWithMask(256), 'desc_inputs': [Tensor(np.random.rand(24000, 4).astype(np.bool_))], - 'desc_bprop': [[256,4], [256,4]], + 'desc_bprop': [[256, 4], [256, 4]], 'skip': ['backward']}), ('LessEqual', { 'block': P.LessEqual(), @@ -419,7 +432,7 @@ test_case_math_ops = [ 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))]}), ('NotEqual_0', { 'block': P.NotEqual(), - 'desc_inputs': [ 1, [2, 3, 4, 5]], + 'desc_inputs': [1, [2, 3, 4, 5]], 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))], 'skip': ['backward']}), ('Greater', { @@ -435,13 +448,13 @@ test_case_math_ops = [ 'desc_inputs': [Tensor(np.zeros((3, 4, 5), np.bool_))], 'desc_bprop': [Tensor(np.ones((3, 4, 5), np.bool_))]}), ('LogicalAnd', { - 'block': P.LogicalAnd(), - 'desc_inputs': [Tensor(np.zeros((2, 3, 4), np.bool_)), Tensor(np.ones((1), np.bool_))], - 'desc_bprop': [Tensor(np.zeros((2, 3, 4), np.bool_))]}), + 'block': P.LogicalAnd(), + 'desc_inputs': [Tensor(np.zeros((2, 3, 4), np.bool_)), Tensor(np.ones((1), np.bool_))], + 'desc_bprop': [Tensor(np.zeros((2, 3, 4), np.bool_))]}), ('LogicalOr', { - 'block': P.LogicalOr(), - 'desc_inputs': [Tensor(np.zeros((3, 4, 5), np.bool_)), Tensor(np.ones((3, 1, 1), np.bool_))], - 'desc_bprop': [Tensor(np.zeros((3, 4, 5), np.bool_))]}), + 'block': P.LogicalOr(), + 'desc_inputs': [Tensor(np.zeros((3, 4, 5), np.bool_)), Tensor(np.ones((3, 1, 1), np.bool_))], + 'desc_bprop': [Tensor(np.zeros((3, 4, 5), np.bool_))]}), ('NpuAllocFloatStatus', { 'block': P.NPUAllocFloatStatus(), 'desc_inputs': [], @@ -476,8 +489,8 @@ test_case_math_ops = [ ('CumSum', { 'block': P.CumSum(), 'desc_const': [0], - 'desc_inputs': [Tensor(np.array([[3, 4],[1, 6]]).astype(np.float16))], - 'desc_bprop': [Tensor(np.array([[3, 4],[4, 10]]).astype(np.float16))]}), + 'desc_inputs': [Tensor(np.array([[3, 4], [1, 6]]).astype(np.float16))], + 'desc_bprop': [Tensor(np.array([[3, 4], [4, 10]]).astype(np.float16))]}), ('ReduceSum_3', { 'block': P.ReduceSum(), 'desc_const': [0], @@ -717,8 +730,8 @@ test_case_nn_ops = [ ('UnsortedSegmentSum', { 'block': P.UnsortedSegmentSum(), 'desc_const': [1280], - 'desc_inputs': [[1280,1024], Tensor(np.ones(1280).astype(np.int32))], - 'desc_bprop': [[8192,1024]], + 'desc_inputs': [[1280, 1024], Tensor(np.ones(1280).astype(np.int32))], + 'desc_bprop': [[8192, 1024]], 'skip': ['backward']}), ('UnsortedSegmentSum_1', { 'block': P.UnsortedSegmentSum(), @@ -821,19 +834,20 @@ test_case_nn_ops = [ 'skip': ['backward']}), ('ArgmaxNet', { 'block': ArgmaxNet(), - 'desc_inputs': [Tensor(np.array([[128, 32, 32, 64],[128, 32, 32, 64]]).astype(np.float16))], - 'desc_bprop': [Tensor(np.array([[128, 32, 32, 64],[128, 32, 32, 64]]).astype(np.float16))], + 'desc_inputs': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))], + 'desc_bprop': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))], 'skip': ['backward']}), ('ArgminNet', { 'block': ArgminNet(), - 'desc_inputs': [Tensor(np.array([[128, 32, 32, 64],[128, 32, 32, 64]]).astype(np.float16))], - 'desc_bprop': [Tensor(np.array([[128, 32, 32, 64],[128, 32, 32, 64]]).astype(np.float16))], + 'desc_inputs': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))], + 'desc_bprop': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))], 'skip': ['backward']}), ('CumSumNet', { 'block': CumSumNet(), 'desc_const': [0], - 'desc_inputs': [Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float16))], - 'desc_bprop': [Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float16))]}), + 'desc_inputs': [Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float16))], + 'desc_bprop': [ + Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float16))]}), ('OneHot', { 'block': P.OneHot(), 'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)], @@ -1021,31 +1035,31 @@ test_case_array_ops = [ 'desc_inputs': [(Tensor(np.array([1], np.float32)), Tensor(np.array([1], np.float32)), Tensor(np.array([1], np.float32)))], - 'desc_bprop': [[3,]]}), + 'desc_bprop': [[3, ]]}), ('Pack_0', { 'block': NetForPackInput(P.Pack()), - 'desc_inputs':[[2, 2], [2, 2], [2, 2]], - 'desc_bprop':[[3, 2, 2]], + 'desc_inputs': [[2, 2], [2, 2], [2, 2]], + 'desc_bprop': [[3, 2, 2]], }), ('Pack_1', { 'block': NetForPackInput(P.Pack(axis=-2)), - 'desc_inputs':[[3, 2, 3], [3, 2, 3], [3, 2, 3]], - 'desc_bprop':[[3, 2, 3, 3]], + 'desc_inputs': [[3, 2, 3], [3, 2, 3], [3, 2, 3]], + 'desc_bprop': [[3, 2, 3, 3]], }), ('Pack_2', { 'block': NetForPackInput(P.Pack()), - 'desc_inputs':[[128, 128], [128, 128]], - 'desc_bprop':[[2, 128, 128]], + 'desc_inputs': [[128, 128], [128, 128]], + 'desc_bprop': [[2, 128, 128]], }), ('Unpack_0', { 'block': NetForUnpackInput(P.Unpack(axis=0)), - 'desc_inputs':[[2, 4]], - 'desc_bprop':[[4], [4]], + 'desc_inputs': [[2, 4]], + 'desc_bprop': [[4], [4]], }), ('Unpack_1', { 'block': NetForUnpackInput(P.Unpack(axis=-1)), - 'desc_inputs':[Tensor(np.array([[1, 1, 1]], np.float32))], - 'desc_bprop':[[1], [1], [1]], + 'desc_inputs': [Tensor(np.array([[1, 1, 1]], np.float32))], + 'desc_bprop': [[1], [1], [1]], }), ('Diag_1', { 'block': P.Diag(), @@ -1117,6 +1131,11 @@ test_case_other_ops = [ 'desc_inputs': (Tensor(np.ones((2, 2), np.int32)), Tensor(np.ones((2,), np.int32))), 'desc_bprop': [([3, 3], {'dtype': np.int32})]}), + ('ScatterMax', { + 'block': ScatterMax(), + 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), + Tensor(np.ones([2, 2, 3], np.float32) * 99)), + 'skip': ['backward']}), ('SmoothL1Loss', { 'block': P.SmoothL1Loss(), 'desc_inputs': [[256, 4], [256, 4]], @@ -1131,17 +1150,17 @@ test_case_other_ops = [ Tensor(np.array([1.2]).astype(np.float32))], 'skip': ['backward']}), ('ConfusionMulGrad_1', { - 'block': P.ConfusionMulGrad(axis = [0], keep_dims = False), + 'block': P.ConfusionMulGrad(axis=[0], keep_dims=False), 'desc_inputs': [[3, 2], [3, 2], [3, 2]], 'desc_bprop': [[3, 2], [2]], 'skip': ['backward']}), ('ConfusionMulGrad_2', { - 'block': P.ConfusionMulGrad(axis = [0], keep_dims = True), + 'block': P.ConfusionMulGrad(axis=[0], keep_dims=True), 'desc_inputs': [[3, 2], [3, 2], [3, 2]], 'desc_bprop': [[3, 2], [1, 2]], 'skip': ['backward']}), ('ConfusionMulGrad_3', { - 'block': P.ConfusionMulGrad(axis = (), keep_dims = True), + 'block': P.ConfusionMulGrad(axis=(), keep_dims=True), 'desc_inputs': [[2, 3, 4], [2, 3, 4], [2, 3, 4]], 'desc_bprop': [[2, 3, 4], [1, 1, 1]], 'skip': ['backward']}), @@ -1150,7 +1169,7 @@ test_case_other_ops = [ 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), Tensor(np.array([1.2]).astype(np.float32))], 'skip': ['backward']}), - + ] test_case_lists = [test_case_nn_ops, test_case_math_ops, test_case_array_ops, test_case_other_ops] @@ -1162,15 +1181,13 @@ test_case = functools.reduce(lambda x, y: x + y, test_case_lists) test_exec_case = test_case test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or - 'backward' not in x[1]['skip'], test_case) + 'backward' not in x[1]['skip'], test_case) -import mindspore.context as context - @non_graph_engine @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) def test_exec(): - context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + context.set_context(mode=context.GRAPH_MODE) return test_exec_case @@ -1207,12 +1224,12 @@ raise_set = [ 'desc_bprop': [[2, 3]]}), ('Pack', { 'block': (NetForPackInput(P.Pack()), {'exception': ValueError}), - 'desc_inputs':[[2, 2]], - 'desc_bprop':[[1, 2, 2]]}), + 'desc_inputs': [[2, 2]], + 'desc_bprop': [[1, 2, 2]]}), ('PReLU', { 'block': (P.PReLU(), {'exception': ValueError}), - 'desc_inputs':[[2], [1]], - 'desc_bprop':[[1]]}), + 'desc_inputs': [[2], [1]], + 'desc_bprop': [[1]]}), ]