From 148d645915fb62057fd40105c4124fa1e39ce096 Mon Sep 17 00:00:00 2001
From: Yi Huaijie
Date: Thu, 29 Oct 2020 17:02:35 +0800
Subject: [PATCH] fix ReLUV2 error

---
 .../parallel/ops_info/activation_info.h       |   8 -
 .../parallel/ops_info/ops_info_head_files.h   |   1 +
 .../frontend/parallel/ops_info/reluv2_info.cc | 183 ++++++++++++++++++
 .../frontend/parallel/ops_info/reluv2_info.h  |  60 ++++++
 tests/ut/python/parallel/test_reluv2.py       |  76 ++++++++
 5 files changed, 320 insertions(+), 8 deletions(-)
 create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.cc
 create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.h
 create mode 100644 tests/ut/python/parallel/test_reluv2.py

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h
index 10ba8475e95..a2978bfcfd3 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h
@@ -154,14 +154,6 @@ class ReLU6Info : public ActivationOther {
   ~ReLU6Info() override = default;
 };
 
-class ReLUV2Info : public ActivationOther {
- public:
-  ReLUV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-             const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
-  ~ReLUV2Info() override = default;
-};
-
 class SoftsignInfo : public ActivationOther {
  public:
   SoftsignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
index a5c9b2e0b4b..30abd841ac0 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
@@ -45,5 +45,6 @@
 #include "frontend/parallel/ops_info/pack_info.h"
 #include "frontend/parallel/ops_info/broadcast_to_info.h"
 #include "frontend/parallel/ops_info/unique_info.h"
+#include "frontend/parallel/ops_info/reluv2_info.h"
 
 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
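ReLUV2 was previously modeled as a plain single-output activation (ActivationOther), but the primitive actually returns a pair (output, mask), so the parallel module must register tensor maps and tensor infos for both outputs; that is what the dedicated operator info below does. As a rough reference for the primitive's logical semantics, here is a sketch only: on Ascend the real kernel emits the mask in a packed layout with a different shape, so this numpy model covers just the per-element view.

    import numpy as np

    def reluv2_reference(x):
        # Logical behaviour: elementwise ReLU plus a positivity mask.
        output = np.maximum(x, 0.0)
        mask = (x > 0).astype(np.uint8)
        return output, mask

    out, mask = reluv2_reference(np.array([[-1.0, 2.0], [0.0, 3.0]]))
    # out  -> [[0., 2.], [0., 3.]]
    # mask -> [[0, 1], [0, 1]]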
+ */ + +#include "frontend/parallel/ops_info/reluv2_info.h" + +#include +#include +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/costmodel.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +Status ReLUV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } + +Status ReLUV2Info::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } + +Status ReLUV2Info::GetAttrs() { return SUCCESS; } + +Status ReLUV2Info::GenerateStrategies(int32_t stage_id) { + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status ReLUV2Info::InferDevMatrixShape() { + Strategys stra = strategy_->GetInputDim(); + Dimensions input_strategy = stra.at(0); + + dev_matrix_shape_ = input_strategy; + + return SUCCESS; +} + +Status ReLUV2Info::InferMirrorOps() { + mirror_ops_.clear(); + + Shape tensor_map = inputs_tensor_map_[0]; + std::vector group; + if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group failed."; + return FAILED; + } + + OperatorVector mirror_op; + if (group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror ops is empty."; + return SUCCESS; + } else { + mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); + mirror_ops_.push_back(mirror_op); + std::string group_name = group[0].name(); + MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name; + } + + return SUCCESS; +} + +Status ReLUV2Info::InferForwardCommunication() { + // do nothing + return SUCCESS; +} + +Status ReLUV2Info::InferTensorMap() { + Shape tensor_map_index; + size_t size = inputs_shape_.at(0).size(); + // such as 4: tensor_map_index [3,2,1,0] + for (size_t i = 0; i < size; ++i) { + tensor_map_index.push_back((int64_t)(size - i - 1)); + } + + inputs_tensor_map_.push_back(tensor_map_index); + // output and mask + outputs_tensor_map_.push_back(tensor_map_index); + outputs_tensor_map_.push_back(tensor_map_index); + return SUCCESS; +} + +Status ReLUV2Info::InferTensorInfo() { + if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": Invalid args"; + return FAILED; + } + + TensorLayout input_layout, output_layout; + // infer tensor layout + if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed."; + return FAILED; + } + TensorInfo input_tensor_info(input_layout); + inputs_tensor_info_.push_back(input_tensor_info); + + if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed."; + return FAILED; + } + TensorInfo 
+
+Status ReLUV2Info::InferAsLossDivisor() {
+  if (!ParallelContext::GetInstance()->loss_repeated_mean()) {
+    as_loss_divisor_ = 1;
+    return SUCCESS;
+  }
+
+  if (outputs_tensor_map_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty.";
+    return FAILED;
+  }
+
+  if (outputs_tensor_map_[0].empty()) {
+    as_loss_divisor_ = SizeToInt(global_device_list_.size());
+    MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_
+                 << " as the loss divisor.";
+    return SUCCESS;
+  }
+
+  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
+  MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(dev_matrix_shape_)
+               << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[0]) << ", loss divisor is "
+               << as_loss_divisor_;
+  return SUCCESS;
+}
+
+Status ReLUV2Info::Init(const StrategyPtr &strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Init failed.";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << " : Init success.";
+  return SUCCESS;
+}
+
+Status ReLUV2Info::InitForCostModel(const StrategyPtr &strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << " : Init for cost model success.";
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
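ComputeRepeatDeviceNumByTensorMap returns the number of devices holding identical copies of the output, i.e. the product of the device-matrix dimensions that the output tensor map does not reference. A hedged Python sketch of that computation, with a hypothetical helper name and assuming the usual MindSpore convention that tensor-map value m addresses device-matrix axis len(dev_matrix) - 1 - m:

    def repeated_device_num(dev_matrix, tensor_map):
        # Device-matrix axes actually referenced by the tensor map.
        used = {len(dev_matrix) - 1 - m for m in tensor_map if m >= 0}
        repeat = 1
        for axis, dim in enumerate(dev_matrix):
            if axis not in used:
                repeat *= dim  # unreferenced axes replicate the tensor
        return repeat

    # Fully split output: every axis is referenced, so nothing repeats.
    print(repeated_device_num([2, 1, 2, 2], [3, 2, 1, 0]))  # 1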
+ */ +class ReLUV2Info : public OperatorInfo { + public: + ReLUV2Info(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~ReLUV2Info() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorMap() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override; + Status InferAsLossDivisor() override; +}; +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_RELUV2_INFO_H_ diff --git a/tests/ut/python/parallel/test_reluv2.py b/tests/ut/python/parallel/test_reluv2.py new file mode 100644 index 00000000000..2dcf5ca3fe6 --- /dev/null +++ b/tests/ut/python/parallel/test_reluv2.py @@ -0,0 +1,76 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np +import mindspore as ms +import mindspore.context as context +from mindspore import Tensor, Parameter +import mindspore.nn as nn +from mindspore.common.api import _executor +from mindspore.nn import TrainOneStepCell, Momentum +from mindspore.ops import operations as P + + +class Net(nn.Cell): + def __init__(self, mul_weight, strategy=None): + super(Net, self).__init__() + self.reluv2 = P.ReLUV2().shard(strategy) + self.mul = P.Mul() + self.weight = Parameter(mul_weight, "w1") + + def construct(self, x): + out = self.mul(x, self.weight) + output, _ = self.reluv2(out) + return output + + +_w1 = Tensor(np.ones([32, 16, 48, 64]), dtype=ms.float32) +_x = Tensor(np.ones([32, 16, 48, 64]), dtype=ms.float32) + + +def compile_net(net): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + train_net.set_train() + _executor.compile(train_net, _x) + context.reset_auto_parallel_context() + + +def test_reluv2(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy = ((2, 1, 2, 2),) + net = Net(_w1, strategy) + compile_net(net) + + +def test_reluv2_no_full(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy = ((2, 1, 2, 1),) + net = Net(_w1, strategy) + compile_net(net) + + +def test_reluv2_no_strategy(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy = None + net = Net(_w1, strategy) + compile_net(net) + + +def test_reluv2_auto_parallel(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net(_w1) + compile_net(net)