From 89e7778497daf590a2643bb5799e4a05b95d9914 Mon Sep 17 00:00:00 2001 From: huangxinjing Date: Thu, 19 Nov 2020 17:32:40 +0800 Subject: [PATCH] Add UnsortedSegmentMax Operation --- .../ccsrc/frontend/parallel/dynamic_creator.h | 1 + .../frontend/parallel/ops_info/ops_utils.h | 1 + .../ops_info/unsorted_segment_op_info.cc | 36 ++++ .../ops_info/unsorted_segment_op_info.h | 14 ++ .../frontend/parallel/step_auto_parallel.cc | 3 +- .../parallel/test_unsortedsegmentmax.py | 162 ++++++++++++++++++ .../parallel/test_unsortedsegmentmin.py | 3 +- .../parallel/test_unsortedsegmentsum.py | 3 +- 8 files changed, 220 insertions(+), 3 deletions(-) create mode 100644 tests/ut/python/parallel/test_unsortedsegmentmax.py diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h index 1f57837fbcb..acf8ecdbd8f 100644 --- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h @@ -179,6 +179,7 @@ REGISTER(SquareInfo); REGISTER(UniformCandidateSamplerInfo); REGISTER(UnsortedSegmentSumInfo); REGISTER(UnsortedSegmentMinInfo); +REGISTER(UnsortedSegmentMaxInfo); REGISTER(GatherV2PInfo); REGISTER(EmbeddingLookupInfo); REGISTER(TileInfo); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 3bd8d50fe69..00497b54acd 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -305,6 +305,7 @@ constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD"; constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD"; constexpr char UNSORTED_SEGMENT_SUM[] = "UnsortedSegmentSum"; constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin"; +constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax"; constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative"; constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; 
constexpr char ADD[] = "Add"; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc index f8316307f9e..0451cf86b36 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc @@ -332,5 +332,41 @@ Status UnsortedSegmentMinInfo::ComputeReplaceGraph(const CNodePtr &cnode) { return SUCCESS; } +// UnsortedSegmentMaxInfo is almost the same as UnsortedSegmentMinInfo, +// except that the ReduceMin op in ComputeReplaceGraph is replaced with a ReduceMax op +ReplaceGraphPtr UnsortedSegmentMaxInfo::replace_graph(const CNodePtr &cnode) { + auto input_id_strategy = strategy_->GetInputDim().at(1); + // 1. the two input shapes are the same, and the strategy is not all ones + if (std::any_of(input_id_strategy.begin(), input_id_strategy.end(), [](const int64_t &shard) { return shard > 1; })) { + if (ComputeReplaceGraph(cnode) != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; + } + } + return replace_graph_; +} + +Status UnsortedSegmentMaxInfo::ComputeReplaceGraph(const CNodePtr &cnode) { + GenerateGraph gen_g = GenerateGraph(); + if (gen_g.Init(cnode) != SUCCESS) { + MS_LOG(ERROR) << "GenerateGraph Init failed"; + return FAILED; + } + // Get the attributes of the UnsortedSegmentMax + auto num_segments = GetValue(input_value_.at(2)); + // Step1: Output branch + auto segment_max = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MAX), gen_g.virtual_input_node(), + gen_g.virtual_input_node(), CreatInt64Imm(num_segments)}); + auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_max, CreatInt64Imm(0)}); + auto all_gather_output = gen_g.PushBack({gen_g.NewOpInst(ALL_GATHER), expandim_output}); + auto final_output = gen_g.PushBack({gen_g.NewOpInst(REDUCE_MAX), all_gather_output, CreatInt64Imm(0)}); + + std::vector> input_nodes = 
{std::make_pair(segment_max, 1), + std::make_pair(segment_max, 2)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, final_output)); + + return SUCCESS; +} + } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h index 2ede42b6a03..f82f3bb0f9c 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h @@ -79,6 +79,20 @@ class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo { Status ComputeReplaceGraph(const CNodePtr &cnode); }; +class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo { + public: + UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~UnsortedSegmentMaxInfo() override = default; + + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; + Status InferForwardCommunication() override { return SUCCESS; } + + protected: + Status ComputeReplaceGraph(const CNodePtr &cnode); +}; + } // namespace parallel } // namespace mindspore #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 08f247bc25a..abc59cdeed2 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -317,7 +317,8 @@ bool IsSplittableOperator(const std::string &op_name) { EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE, BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2, SOFTPLUS, 
SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM, - UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE}; + UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, + UNSORTED_SEGMENT_MAX}; // clang-format on auto iter = splittable_op.find(op_name); diff --git a/tests/ut/python/parallel/test_unsortedsegmentmax.py b/tests/ut/python/parallel/test_unsortedsegmentmax.py new file mode 100644 index 00000000000..b13b83cdd06 --- /dev/null +++ b/tests/ut/python/parallel/test_unsortedsegmentmax.py @@ -0,0 +1,162 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops.operations.comm_ops import _VirtualDataset +from tests.ut.python.ops.test_math_ops import VirtualLoss + +context.set_context(mode=context.GRAPH_MODE) + + +grad_all = C.GradOperation(get_all=True) + + +class Net(nn.Cell): + def __init__(self, strategy1, strategy2, num_segments): + super(Net, self).__init__() + self.virtual_dataset = _VirtualDataset() + self.merge_op = P.UnsortedSegmentMax().shard((strategy1, strategy2)) + self.num_segments = num_segments + + def construct(self, vectors, segment_ids): + predict = self.merge_op(vectors, segment_ids, self.num_segments) + return predict + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x, y): + return grad_all(self.network)(x, y) + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x, y): + predict = self.network(x, y) + return self.loss(predict) + + +def compile_graph(x, y, segments, strategy1, strategy2, auto=False): + net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments))) + net.set_auto_parallel() + net.set_train() + if auto: + context.set_auto_parallel_context(parallel_mode="auto_parallel") + else: + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + _executor.compile(net, x, y) + + +def test_UnsortedSegmentMax_model_parallel_slice_1d(): + context.set_auto_parallel_context(device_num=8, global_rank=0) + x = Tensor(np.ones(8), ms.float32) + y = Tensor(np.ones(8), ms.int32) + num_segments = 16 + strategy1 = (8,) 
+ strategy2 = (8,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_UnsortedSegmentMax_model_parallel_no_slice_1d(): + context.set_auto_parallel_context(device_num=8, global_rank=0) + x = Tensor(np.ones(8), ms.float32) + y = Tensor(np.ones(8), ms.int32) + num_segments = 16 + strategy1 = (1,) + strategy2 = (1,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_UnsortedSegmentMax_model_parallel_index_slice_2d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 8)), ms.float32) + y = Tensor(np.arange(4), ms.int32) + num_segments = 4 + strategy1 = (4, 1) + strategy2 = (4,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_UnsortedSegmentMax_model_parallel_vector_slice_2d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 8)), ms.float32) + y = Tensor(np.ones(4), ms.int32) + num_segments = 4 + strategy1 = (1, 4) + strategy2 = (1,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_UnsortedSegmentMax_model_parallel_vector_slice_3d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 8, 8)), ms.float32) + y = Tensor(np.ones(4), ms.int32) + num_segments = 4 + strategy1 = (1, 2, 2) + strategy2 = (1,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_UnsortedSegmentMax_model_parallel_index_vector_slice_2d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 8)), ms.float32) + y = Tensor(np.ones(4), ms.int32) + num_segments = 4 + strategy1 = (2, 2) + strategy2 = (2,) + compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_UnsortedSegmentMax_model_parallel_index_vector_slice_3d(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 4, 8)), ms.float32) + y = Tensor(np.ones((4)), ms.int32) + num_segments = 16 + strategy1 = (2, 1, 2) + strategy2 = (2,) + 
compile_graph(x, y, num_segments, strategy1, strategy2) + +def test_UnsortedSegmentMax_model_parallel_float16(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 4, 8)), ms.float16) + y = Tensor(np.ones((4)), ms.int32) + num_segments = 16 + strategy1 = (2, 1, 2) + strategy2 = (2,) + compile_graph(x, y, num_segments, strategy1, strategy2) + +def test_UnsortedSegmentMax_model_parallel_int32(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 4, 8)), ms.int32) + y = Tensor(np.ones((4)), ms.int32) + num_segments = 16 + strategy1 = (2, 1, 2) + strategy2 = (2,) + compile_graph(x, y, num_segments, strategy1, strategy2) diff --git a/tests/ut/python/parallel/test_unsortedsegmentmin.py b/tests/ut/python/parallel/test_unsortedsegmentmin.py index 2b55dff5da5..e0fbf943a14 100644 --- a/tests/ut/python/parallel/test_unsortedsegmentmin.py +++ b/tests/ut/python/parallel/test_unsortedsegmentmin.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+# ============================================================================ import numpy as np diff --git a/tests/ut/python/parallel/test_unsortedsegmentsum.py b/tests/ut/python/parallel/test_unsortedsegmentsum.py index 6ea84a1467d..2f7b4648705 100644 --- a/tests/ut/python/parallel/test_unsortedsegmentsum.py +++ b/tests/ut/python/parallel/test_unsortedsegmentsum.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ============================================================================ import numpy as np