From af0d28de487f55a8512bf55c9a34ebf60c608864 Mon Sep 17 00:00:00 2001 From: yangzhenzhang Date: Wed, 16 Jun 2021 14:57:56 +0800 Subject: [PATCH] add parallel op for batchnorm --- .../ccsrc/frontend/parallel/dynamic_creator.h | 1 + .../parallel/ops_info/batchnorm_info.cc | 274 ++++++++++++++++++ .../parallel/ops_info/batchnorm_info.h | 64 ++++ .../frontend/parallel/ops_info/conv2d_info.cc | 56 +--- .../frontend/parallel/ops_info/conv2d_info.h | 3 - .../parallel/ops_info/operator_info.cc | 71 +++++ .../parallel/ops_info/operator_info.h | 5 + .../parallel/ops_info/ops_info_head_files.h | 1 + .../frontend/parallel/ops_info/ops_utils.h | 5 + .../ccsrc/frontend/parallel/step_parallel.cc | 6 + tests/ut/python/parallel/test_batchnorm.py | 125 ++++++++ 11 files changed, 556 insertions(+), 55 deletions(-) create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.cc create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.h create mode 100644 tests/ut/python/parallel/test_batchnorm.py diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h index c4810aa10b7..34251a1e135 100644 --- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h @@ -197,6 +197,7 @@ REGISTER(TopKInfo); REGISTER(ScatterUpdateInfo); REGISTER(VirtualOutputInfo); REGISTER(Conv2DInfo); +REGISTER(BatchNormInfo); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.cc new file mode 100644 index 00000000000..489232fab3c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.cc @@ -0,0 +1,274 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "frontend/parallel/ops_info/batchnorm_info.h"
+
+#include
+#include
+#include
+#include
+
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+#include "pipeline/jit/resource.h"
+
+namespace mindspore {
+namespace parallel {
+Status BatchNormInfo::GetAttrs() {
+  is_training_ = GetBoolAttr(IS_TRAINING);
+
+  epsilon_ = GetFloatAttr(EPSILON);
+
+  momentum_ = GetFloatAttr(MOMENTUM);
+
+  format_ = GetStringAttr(FORMAT);
+  if (format_ != NCHW) {
+    MS_LOG(ERROR) << name_ << ": The data format must be 'NCHW', but got " << format_;
+    return FAILED;
+  }
+
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+
+  if (inputs_shape_[0].size() == 2) {
+    input_is_4d_ = false;
+  } else if (inputs_shape_[0].size() == 4) {
+    input_is_4d_ = true;
+  } else {
+    MS_LOG(ERROR) << name_ << ": The size of input[0]'s shape must be 2 or 4, but got " << inputs_shape_[0].size();
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << ": The is_training is " << is_training_ << ", epsilon is " << epsilon_ << ", momentum is "
+               << momentum_ << ", data format is " << format_;
+
+  return SUCCESS;
+}
+
+Status BatchNormInfo::CheckStrategy(const StrategyPtr &strategy) {
+  MS_EXCEPTION_IF_NULL(strategy);
+  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Invalid strategy";
+    return FAILED;
+  }
+
+  std::vector<Dimensions> stra = strategy->GetInputDim();
+  if (stra.size() != 5) {
+    MS_LOG(ERROR) << name_ << ": The size of strategy must be 5, but got " << stra.size();
+    return FAILED;
+  }
+
+  if ((stra[0].size() != 4) && (stra[0].size() != 2)) {
+    MS_LOG(ERROR) << name_ << ": The size of strategy[0] must be 4 or 2, but got " << stra[0].size();
+    return FAILED;
+  }
+
+  for (size_t i = 1; i < 5; ++i) {
+    if (stra[i].empty()) {
+      MS_LOG(ERROR) << name_ << ": The strategy can not be empty, the index is " << i;
+      return FAILED;
+    }
+    if (stra[0][1] != stra[i][0]) {
+      MS_LOG(ERROR) << name_ << ": Invalid strategy, the index is " << i << ", it must be equal to " << stra[0][1]
+                    << ", but got " << stra[i][0];
+      return FAILED;
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status BatchNormInfo::InferDevMatrixShape() {
+  MS_EXCEPTION_IF_NULL(strategy_);
+  std::vector<Dimensions> stra = strategy_->GetInputDim();
+  if (stra.empty()) {
+    MS_LOG(ERROR) << name_ << ": The strategy can not be empty";
+    return FAILED;
+  }
+
+  dev_matrix_shape_ = stra[0];
+  return SUCCESS;
+}
+
+Status BatchNormInfo::InferTensorMap() {
+  TensorMap input_tensor_map;
+  TensorMap in_other_tensor_map;
+
+  if (input_is_4d_) {
+    // if input is 4d:
+    // input_strategy: ((n, c, h, w), (c), (c), (c), (c))
+    // output_strategy: ((n, c, h, w), (c), (c), (c), (c))
+    // dev_matrix: (n, c, h, w)
+    input_tensor_map = {3, 2, 1, 0};
+    in_other_tensor_map = {2};
+  } else {
+    // if input is 2d:
+    // input_strategy: ((n, c), (c), (c), (c), (c))
+    // output_strategy: ((n, c), (c), (c), (c), (c))
+    // dev_matrix: (n, c)
+    input_tensor_map = {1, 0};
+    in_other_tensor_map = {0};
+  }
+
+  inputs_tensor_map_.push_back(input_tensor_map); // input
+  inputs_tensor_map_.push_back(in_other_tensor_map); // scale
+  inputs_tensor_map_.push_back(in_other_tensor_map); // bias
+  inputs_tensor_map_.push_back(in_other_tensor_map); // mean
+  inputs_tensor_map_.push_back(in_other_tensor_map); // variance
+
+  outputs_tensor_map_ = inputs_tensor_map_;
+  return SUCCESS;
+}
+
+Status BatchNormInfo::InferForwardCommunication() {
+  // if it is not training, no forward allreduce is needed
+  if (!is_training_) {
+    MS_LOG(INFO) << name_ << ": It is not training, no need for forward allreduce";
+    return SUCCESS;
+  }
+
+  TensorMap tmp_map;
+  if (input_is_4d_) {
+    // input is 4d:
+    // if there is no repeated calculation, the dev matrix is [n, c, h, w]
+    // if there is repeated calculation and the repeated num is on the left of the dev matrix, the dev matrix is [repeated_num, n, c, h, w]
+    // if there is repeated calculation and the repeated num is on the right of the dev matrix, the dev matrix is [n, c, h, w, repeated_num]
+    // and the forward allreduce needs to use the n/h/w dimensions
+    if (repeated_calc_num_ == 1) {
+      // no repeated calculation
+      tmp_map = {-1, 2, -1, -1};
+    } else if (!repeated_num_in_dev_matrix_right_) {
+      // repeated calculation, and the repeated num is on the left of the dev matrix
+      tmp_map = {4, -1, 2, -1, -1};
+    } else {
+      // repeated calculation, and the repeated num is on the right of the dev matrix
+      tmp_map = {-1, 3, -1, -1, 0};
+    }
+  } else {
+    // input is 2d:
+    // if there is no repeated calculation, the dev matrix is [n, c]
+    // if there is repeated calculation and the repeated num is on the left of the dev matrix, the dev matrix is [repeated_num, n, c]
+    // if there is repeated calculation and the repeated num is on the right of the dev matrix, the dev matrix is [n, c, repeated_num]
+    // and the forward allreduce needs to use the n dimension
+    if (repeated_calc_num_ == 1) {
+      // no repeated calculation
+      tmp_map = {-1, 0};
+    } else if (!repeated_num_in_dev_matrix_right_) {
+      // repeated calculation, and the repeated num is on the left of the dev matrix
+      tmp_map = {2, -1, 0};
+    } else {
+      // repeated calculation, and the repeated num is on the right of the dev matrix
+      tmp_map = {-1, 1, 0};
+    }
+  }
+
+  std::vector<Group> group_list;
+  if (CreateGroupByTensorMap(tmp_map, &group_list) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Create group failed";
+    return FAILED;
+  }
+
+  if (group_list.empty()) {
+    MS_LOG(INFO) << name_ << ": Forward allreduce is not required";
+    return SUCCESS;
+  } else {
+    MS_LOG(INFO) << name_ << ": The group name of forward allreduce is " << group_list[0].name();
+  }
+
+  forward_allreduce_group_ = group_list;
+  return SUCCESS;
+}
+
+Status BatchNormInfo::InferReplaceOps() {
+  if (!is_training_) {
+    MS_LOG(INFO) << name_ << ": It is not training, no need to replace op";
+    return SUCCESS;
+  }
+
+  if (forward_allreduce_group_.empty()) {
+    MS_LOG(INFO) << name_ << ": The forward allreduce group is empty, no need to replace op";
+    return SUCCESS;
+  }
+
+  ValuePtr epsilon = MakeValue(epsilon_);
+  ValuePtr momentum = MakeValue(momentum_);
+  ValuePtr group = MakeValue(forward_allreduce_group_[0].name());
+  ValuePtr device_num = MakeValue(forward_allreduce_group_[0].GetDevNum());
+
+  Attr attr_epsilon = std::make_pair(EPSILON, epsilon);
+  Attr attr_momentum = std::make_pair(MOMENTUM, momentum);
+  Attr attr_group = std::make_pair(GROUP, group);
+  Attr attr_device_num = std::make_pair(DEVICE_NUM, device_num);
+
+  OperatorAttrs attrs = {attr_epsilon, attr_momentum, attr_group, attr_device_num};
+  OperatorParams params;
+  OperatorArgs args = std::make_pair(attrs, params);
+  replace_op_ = {std::make_pair(SYNC_BATCH_NORM, args)};
+  return SUCCESS;
+}
+
+Status BatchNormInfo::InferAsLossDivisor() {
+  if (outputs_tensor_map_.size() != 5) {
+    MS_LOG(ERROR) << name_ << ": The size of outputs tensor map must be 5, but got " << outputs_tensor_map_.size();
+    return FAILED;
+  }
+  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
+  MS_LOG(INFO) << name_ << ": The dev matrix shape is " << ShapeToString(dev_matrix_shape_)
+               << ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0])
+               << ", as_loss_divisor_ is " << as_loss_divisor_;
+  return SUCCESS;
+}
+
+Status BatchNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
+
+std::vector<StrategyPtr> BatchNormInfo::GenerateOpStrategies(int64_t stage_id) {
+  Strategys strategy;
+  if (input_is_4d_) {
+    strategy = {{stage_device_size_, 1, 1, 1}, {1}, {1}, {1}, {1}};
+  } else {
+    strategy = {{stage_device_size_, 1}, {1}, {1}, {1}, {1}};
+  }
+  StrategyPtr sp = std::make_shared<Strategy>(stage_id, strategy);
+  std::vector<StrategyPtr> sp_vector;
+  sp_vector.push_back(sp);
+  return sp_vector;
+}
+
+Status BatchNormInfo::Init(const StrategyPtr &strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init failed.";
+    return FAILED;
+  }
+
+  (void)InferReplaceOps();
+  MS_LOG(INFO) << name_ << ": Init success.";
+  return SUCCESS;
+}
+
+Status BatchNormInfo::InitForCostModel(const StrategyPtr &strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << ": Init for cost model success.";
+  return SUCCESS;
+}
+} // namespace parallel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.h
new file mode 100644
index 00000000000..9437c236ff5
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BATCHNORM_INFO_H_
+#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BATCHNORM_INFO_H_
+
+#include
+#include
+#include
+#include
+
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+class BatchNormInfo : public OperatorInfo {
+ public:
+  BatchNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                const PrimitiveAttrs &attrs)
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {}
+  ~BatchNormInfo() override = default;
+
+  Status Init(const StrategyPtr &strategy) override;
+  Status InitForCostModel(const StrategyPtr &strategy) override;
+  std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
+  Status SetCostUnderStrategy(const StrategyPtr &) override;
+
+ protected:
+  Status GetAttrs() override;
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status InferForwardCommunication() override;
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+  Status InferReplaceOps();
+  Status InferAsLossDivisor() override;
+
+ private:
+  bool is_training_ = false;
+  float epsilon_ = 0.00001;
+  float momentum_ = 0.1;
+  bool input_is_4d_ = true;
+  std::string format_;
+  std::vector<Group> forward_allreduce_group_;
+};
+} // namespace parallel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BATCHNORM_INFO_H_
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
index a75b481088e..d82f05329b0 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
@@ -28,54 +28,6 @@
 
 namespace mindspore {
 namespace parallel {
-int64_t Conv2DInfo::GetIntAttr(const std::string &attr_name) {
-  auto attr_iter = attrs_.find(attr_name);
-  if (attr_iter == attrs_.end()) {
-    MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << attr_name;
-    return -1;
-  }
-
-  MS_EXCEPTION_IF_NULL(attr_iter->second);
-  if (!attr_iter->second->isa<Int64Imm>()) {
-    MS_LOG(ERROR) << name_ << ": The value of " << attr_name << " is not int";
-    return -1;
-  }
-
-  return attr_iter->second->cast<Int64ImmPtr>()->value();
-}
-
-std::string Conv2DInfo::GetStringAttr(const std::string &attr_name) {
-  std::string string_attr;
-  auto attr_iter = attrs_.find(attr_name);
-  if (attr_iter == attrs_.end()) {
-    MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << attr_name;
-    return string_attr;
-  }
-
-  MS_EXCEPTION_IF_NULL(attr_iter->second);
-  if (!attr_iter->second->isa<StringImm>()) {
-    MS_LOG(ERROR) << name_ << ": The value of " << attr_name << " is not string";
-    return string_attr;
-  }
-
-  string_attr = attr_iter->second->cast<StringImmPtr>()->value();
-  return string_attr;
-}
-
-std::vector<int64_t> Conv2DInfo::GetTupleAttr(const std::string &attr_name) {
-  std::vector<int64_t> tuple_attr;
-  auto tuple_attr_iter = attrs_.find(attr_name);
-  if (tuple_attr_iter == attrs_.end()) {
-    MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << attr_name;
-    return tuple_attr;
-  }
-
-  MS_EXCEPTION_IF_NULL(tuple_attr_iter->second);
-  tuple_attr = GetValue<std::vector<int64_t>>(tuple_attr_iter->second);
-
-  return tuple_attr;
-}
-
 Status Conv2DInfo::GetAttrs() {
   // out_channel
   out_channel_ = GetIntAttr(OUT_CHANNEL);
@@ -121,14 +73,14 @@ Status Conv2DInfo::GetAttrs() {
   }
 
   // pad_list
-  pad_list_ = GetTupleAttr(PAD_LIST);
+  pad_list_ = GetTupleIntAttr(PAD_LIST);
   if (pad_list_.size() != 4) {
     MS_LOG(ERROR) << name_ << ": The size of pad_list must be 4, but got " << pad_list_.size();
     return FAILED;
   }
 
   // stride
-  stride_ = GetTupleAttr(STRIDE);
+  stride_ = GetTupleIntAttr(STRIDE);
   if (stride_.size() != 4) {
     MS_LOG(ERROR) << name_ << ": The size of stride must be 4, but got " << stride_.size();
     return FAILED;
   }
@@ -141,7 +93,7 @@ Status Conv2DInfo::GetAttrs() {
   }
 
   // dilation
-  dilation_ = GetTupleAttr(DILATION);
+  dilation_ = GetTupleIntAttr(DILATION);
   if (dilation_.size() != 4) {
     MS_LOG(ERROR) << name_ << ": The size of dilation must be 4, but got " << dilation_.size();
     return FAILED;
@@ -279,7 +231,7 @@ Status Conv2DInfo::InferDevMatrixShape() {
   MS_EXCEPTION_IF_NULL(strategy_);
   std::vector<Dimensions> stra = strategy_->GetInputDim();
   if (stra.size() != 2) {
-    MS_LOG(ERROR) << name_ << "The size of strategy must be 2, but got " << stra.size();
+    MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
     return FAILED;
   }
 
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
index 64c12c01b1e..1e0179200c3 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
@@ -49,9 +49,6 @@ class Conv2DInfo : public OperatorInfo {
   Status InferForwardCommunication() override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
-  int64_t GetIntAttr(const std::string &attr_name);
-  std::string GetStringAttr(const std::string &attr_name);
-  std::vector<int64_t> GetTupleAttr(const std::string &attr_name);
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
 
  private:
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
index 6de3a990e9d..052a5a808be 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
@@ -1714,6 +1714,77 @@ Status OperatorInfo::GenerateStrategies(int64_t stage_id) {
   return SUCCESS;
 }
 
+int64_t OperatorInfo::GetIntAttr(const std::string &attr_name) {
+  auto attr_iter = attrs_.find(attr_name);
+  if (attr_iter == attrs_.end()) {
+    MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
+  }
+
+  MS_EXCEPTION_IF_NULL(attr_iter->second);
+  if (!attr_iter->second->isa<Int64Imm>()) {
+    MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not int";
+  }
+
+  return attr_iter->second->cast<Int64ImmPtr>()->value();
+}
+
+bool OperatorInfo::GetBoolAttr(const std::string &attr_name) {
+  auto attr_iter = attrs_.find(attr_name);
+  if (attr_iter == attrs_.end()) {
+    MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
+  }
+
+  MS_EXCEPTION_IF_NULL(attr_iter->second);
+  if (!attr_iter->second->isa<BoolImm>()) {
+    MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not bool";
+  }
+
+  return attr_iter->second->cast<BoolImmPtr>()->value();
+}
+
+std::string OperatorInfo::GetStringAttr(const std::string &attr_name) {
+  std::string string_attr;
+  auto attr_iter = attrs_.find(attr_name);
+  if (attr_iter == attrs_.end()) {
+    MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
+  }
+
+  MS_EXCEPTION_IF_NULL(attr_iter->second);
+  if (!attr_iter->second->isa<StringImm>()) {
+    MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not string";
+  }
+
+  string_attr = attr_iter->second->cast<StringImmPtr>()->value();
+  return string_attr;
+}
+
+std::vector<int64_t> OperatorInfo::GetTupleIntAttr(const std::string &attr_name) {
+  std::vector<int64_t> tuple_attr;
+  auto tuple_attr_iter = attrs_.find(attr_name);
+  if (tuple_attr_iter == attrs_.end()) {
+    MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
+  }
+
+  MS_EXCEPTION_IF_NULL(tuple_attr_iter->second);
+  tuple_attr = GetValue<std::vector<int64_t>>(tuple_attr_iter->second);
+
+  return tuple_attr;
+}
+
+float OperatorInfo::GetFloatAttr(const std::string &attr_name) {
+  auto attr_iter = attrs_.find(attr_name);
+  if (attr_iter == attrs_.end()) {
+    MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
+  }
+
+  MS_EXCEPTION_IF_NULL(attr_iter->second);
+  if (!attr_iter->second->isa<FP32Imm>()) {
+    MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not float";
+  }
+
+  return attr_iter->second->cast<FP32ImmPtr>()->value();
+}
+
 std::vector GetValueSequeue(const ValuePtr &sequeue) {
   MS_EXCEPTION_IF_NULL(sequeue);
   std::vector ret;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
index 5cc3b7db78c..4583de949a8 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
@@ -214,6 +214,11 @@ class OperatorInfo {
   Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy,
                          Shapes *inputs_slice_shape, Shapes *outputs_slice_shape);
   void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &);
+  int64_t GetIntAttr(const std::string &attr_name);
+  bool GetBoolAttr(const std::string &attr_name);
+  float GetFloatAttr(const std::string &attr_name);
+  std::string GetStringAttr(const std::string &attr_name);
+  std::vector<int64_t> GetTupleIntAttr(const std::string &attr_name);
 
   std::string name_;
   Shapes inputs_shape_;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
index 044eddad0a4..dfda1bd9505 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
@@ -56,5 +56,6 @@
 #include "frontend/parallel/ops_info/scatter_update_info.h"
 #include "frontend/parallel/ops_info/virtual_output_info.h"
 #include "frontend/parallel/ops_info/conv2d_info.h"
+#include "frontend/parallel/ops_info/batchnorm_info.h"
 
 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index 3b7b62c9b43..a51210fb955 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -197,6 +197,10 @@ constexpr char STRIDE[] = "stride";
 constexpr char DILATION[] = "dilation";
 constexpr char FORMAT[] = "format";
 constexpr char NCHW[] = "NCHW";
+constexpr char IS_TRAINING[] = "is_training";
+constexpr char EPSILON[] = "epsilon";
+constexpr char MOMENTUM[] = "momentum";
+constexpr char DEVICE_NUM[] = "device_num";
 
 // Operator
 constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
@@ -267,6 +271,7 @@ constexpr char CONV2D[] = "Conv2D";
 constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";
 constexpr char FUSE_BATCH_NORM_EX[] = "FusedBatchNormEx";
 constexpr char BATCH_NORM[] = "BatchNorm";
+constexpr char SYNC_BATCH_NORM[] = "SyncBatchNorm";
 constexpr char LAYER_NORM[] = "LayerNorm";
 constexpr char POOLING[] = "Pooling";
 constexpr char CAST[] = "Cast";
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 351a4205cfe..3ebfaf7728b 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -787,9 +787,11 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
     MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
   }
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
+
   if (replace_op.first == EMBEDDING_LOOKUP) {
     replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
   }
+
   if (!params.empty()) {
     Param param_first = *(params.begin());
     int64_t first_position = param_first.second;
@@ -804,6 +806,10 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
       int64_t position = param.second;
       (void)replace_input.insert(replace_input.begin() + position, val);
     }
+  } else if (replace_op.first == SYNC_BATCH_NORM) {
+    for (size_t i = 2; i < node->inputs().size(); ++i) {
+      replace_input.push_back(node->input(i));
+    }
   }
 
   return replace_input;
diff --git a/tests/ut/python/parallel/test_batchnorm.py b/tests/ut/python/parallel/test_batchnorm.py
new file mode 100644
index 00000000000..f9a26f742ef
--- /dev/null
+++ b/tests/ut/python/parallel/test_batchnorm.py
@@ -0,0 +1,125 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import numpy as np + +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.common.api import _executor +from mindspore.nn import Cell, TrainOneStepCell, Momentum, BatchNorm2d, BatchNorm1d +from mindspore.ops import operations as P + + +class Net(Cell): + def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, + strategy1=None, strategy2=None): + super().__init__() + self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size, + pad_mode=pad_mode, stride=stride).shard(strategy1) + self.conv2d_weight = Parameter(conv2d_weight, "w1") + self.bn = BatchNorm2d(8) + self.bn.bn_train.shard(strategy2) + + def construct(self, x, b): + out = self.conv2d(x, self.conv2d_weight) + out = self.bn(out) + return out + + +_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32) +_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32) +_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32) + + +def compile_net(net): + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + train_net.set_train() + _executor.compile(train_net, _x, _b) + context.reset_auto_parallel_context() + + +def test_batchnorm_data_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1)) + strategy2 = ((8, 1, 1, 1), (1,), (1,), (1,), (1,)) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2) + compile_net(net) + + +def test_batchnorm_model_parallel1(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1)) + strategy2 = ((2, 1, 2, 2), (1,), (1,), (1,), (1,)) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2) + compile_net(net) + + +def test_batchnorm_model_parallel2(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0) + strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1)) + strategy2 = ((1, 8, 1, 1), (8,), (8,), (8,), (8,)) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2) + compile_net(net) + + +class Net2(Cell): + def __init__(self, strategy1=None, strategy2=None): + super().__init__() + self.bn = BatchNorm1d(8) + self.bn.bn_train.shard(strategy1) + self.relu = P.ReLU().shard(strategy2) + + def construct(self, x, b): + out = self.bn(x) + out = self.relu(out) + return out + + +_x1 = Tensor(np.ones([32, 8]), dtype=ms.float32) +_b1 = Tensor(np.ones([32, 8]), dtype=ms.float32) + + +def compile_net2(net): + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + train_net.set_train() + _executor.compile(train_net, _x1, _b1) + context.reset_auto_parallel_context() + + +def test_batchnorm1d_data_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((8, 1), (1,), (1,), (1,), (1,)) + strategy2 = ((8, 1),) + net = Net2(strategy1=strategy1, strategy2=strategy2) + compile_net2(net) + + +def test_batchnorm1d_model_parallel1(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 8), (8,), (8,), (8,), (8,)) + strategy2 = ((1, 
8),) + net = Net2(strategy1=strategy1, strategy2=strategy2) + compile_net2(net) + + +def test_batchnorm1d_model_parallel2(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0) + strategy1 = ((2, 4), (4,), (4,), (4,), (4,)) + strategy2 = ((2, 4),) + net = Net2(strategy1=strategy1, strategy2=strategy2) + compile_net2(net)