From 36ffb66782175388dbe200e74cc0cb61a1881db1 Mon Sep 17 00:00:00 2001 From: yangzhenzhang <285824651@qq.com> Date: Thu, 16 Apr 2020 21:52:18 +0800 Subject: [PATCH] add parallel op for square --- mindspore/ccsrc/parallel/dynamic_creator.h | 1 + .../ccsrc/parallel/ops_info/activation_info.h | 8 ++ mindspore/ccsrc/parallel/ops_info/ops_utils.h | 3 +- .../ccsrc/parallel/step_auto_parallel.cc | 7 +- tests/ut/python/parallel/test_square.py | 85 +++++++++++++++++++ 5 files changed, 100 insertions(+), 4 deletions(-) create mode 100644 tests/ut/python/parallel/test_square.py diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h index 723c018d7f2..bad947687d4 100644 --- a/mindspore/ccsrc/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/parallel/dynamic_creator.h @@ -128,6 +128,7 @@ REGISTER(BatchMatMulInfo); REGISTER(ExpandDimsInfo); REGISTER(SqueezeInfo); REGISTER(SigmoidCrossEntropyWithLogitsInfo); +REGISTER(SquareInfo); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.h b/mindspore/ccsrc/parallel/ops_info/activation_info.h index 8dca036f9e8..887be5ea33b 100644 --- a/mindspore/ccsrc/parallel/ops_info/activation_info.h +++ b/mindspore/ccsrc/parallel/ops_info/activation_info.h @@ -203,6 +203,14 @@ class SqueezeInfo : public ActivationOther { private: ValueTuplePtr axis_; }; + +class SquareInfo : public ActivationOther { + public: + SquareInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, + const PrimitiveAttrs& attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~SquareInfo() override = default; +}; } // namespace parallel } // namespace mindspore #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h index 0b52d8e83c7..89b174d6b0b 100644 --- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h @@ -202,9 +202,10 @@ constexpr char SQRT[] = "Sqrt"; constexpr char ASSIGN[] = "Assign"; constexpr char GET_NEXT[] = "GetNext"; constexpr char SQUEEZE[] = "Squeeze"; -constexpr char Neg[] = "Neg"; +constexpr char NEG[] = "Neg"; constexpr char BATCH_MATMUL[] = "BatchMatMul"; constexpr char EXPAND_DIMS[] = "ExpandDims"; +constexpr char SQUARE[] = "Square"; // Parallel don't care constexpr char TUPLE_GETITEM[] = "tuple_getitem"; diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index 1d52eac82db..81aae04c739 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -104,13 +104,14 @@ std::vector splittable_op_ = {MATMUL, SQRT, GET_NEXT, CAST, - Neg, + NEG, + SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE}; -std::vector elementwise_op_ = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT, - CAST, POW, EXP, LOG, COS, ACOS, LOGICALNOT}; +std::vector elementwise_op_ = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT, CAST, + POW, EXP, LOG, COS, ACOS, LOGICALNOT, NEG, SQUARE}; bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { MS_EXCEPTION_IF_NULL(root); diff --git a/tests/ut/python/parallel/test_square.py b/tests/ut/python/parallel/test_square.py new file mode 100644 index 00000000000..e9c182a439a --- /dev/null +++ b/tests/ut/python/parallel/test_square.py @@ -0,0 +1,85 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.nn import Cell, TrainOneStepCell, Momentum +from mindspore.ops import operations as P +from mindspore.common.api import _executor + + +class Net(Cell): + def __init__(self, mul_weight, strategy1=None, strategy2=None): + super(Net, self).__init__() + self.mul = P.Mul().set_strategy(strategy1) + self.square = P.Square().set_strategy(strategy2) + self.mul2 = P.Mul().set_strategy(strategy1) + self.mul_weight = Parameter(mul_weight, "w1") + + def construct(self, x, b): + out = self.mul(x, self.mul_weight) + out = self.square(out) + out = self.mul2(out, b) + return out + + +_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) +_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) +_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) + + +def compile_net(net): + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + _executor.compile(train_net, _x, _b) + context.reset_auto_parallel_context() + + +def test_square_data_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((16, 1, 1), (16, 1, 1)) + strategy2 = ((16, 1, 1), ) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_square_model_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((1, 1, 16), (1, 1, 16)) + strategy2 = ((1, 1, 16), ) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_square_hybrid_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((2, 2, 4), (2, 2, 4)) + strategy2 = ((2, 2, 4), ) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_square_auto_parallel(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0) + net = Net(_w1) + compile_net(net) + + +def test_square_repeat_calc(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((2, 2, 4), (2, 2, 4)) + strategy2 = ((1, 2, 2), ) + net = Net(_w1, strategy1, strategy2) + compile_net(net)