forked from mindspore-Ecosystem/mindspore
add parallel op for batchnorm
This commit is contained in:
parent
a0c5b56f5f
commit
af0d28de48
@ -197,6 +197,7 @@ REGISTER(TopKInfo);
 REGISTER(ScatterUpdateInfo);
 REGISTER(VirtualOutputInfo);
 REGISTER(Conv2DInfo);
+REGISTER(BatchNormInfo);
 }  // namespace parallel
 }  // namespace mindspore

@ -0,0 +1,274 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/batchnorm_info.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"

namespace mindspore {
namespace parallel {
Status BatchNormInfo::GetAttrs() {
  is_training_ = GetBoolAttr(IS_TRAINING);
  epsilon_ = GetFloatAttr(EPSILON);
  momentum_ = GetFloatAttr(MOMENTUM);
  format_ = GetStringAttr(FORMAT);
  if (format_ != NCHW) {
    MS_LOG(ERROR) << name_ << ": The data format must be 'NCHW', but got " << format_;
    return FAILED;
  }

  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }

  if (inputs_shape_[0].size() == 2) {
    input_is_4d_ = false;
  } else if (inputs_shape_[0].size() == 4) {
    input_is_4d_ = true;
  } else {
    MS_LOG(ERROR) << name_ << ": The size of input[0]'s shape must be 2 or 4, but got " << inputs_shape_[0].size();
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": is_training is " << is_training_ << ", epsilon is " << epsilon_ << ", momentum is "
               << momentum_ << ", data format is " << format_;
  return SUCCESS;
}

Status BatchNormInfo::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy";
    return FAILED;
  }

  std::vector<Dimensions> stra = strategy->GetInputDim();
  if (stra.size() != 5) {
    MS_LOG(ERROR) << name_ << ": The size of strategy must be 5, but got " << stra.size();
    return FAILED;
  }

  if ((stra[0].size() != 4) && (stra[0].size() != 2)) {
    MS_LOG(ERROR) << name_ << ": The size of strategy[0] must be 4 or 2, but got " << stra[0].size();
    return FAILED;
  }

  for (size_t i = 1; i < 5; ++i) {
    if (stra[i].empty()) {
      MS_LOG(ERROR) << name_ << ": The strategy can not be empty, the index is " << i;
      return FAILED;
    }
    if (stra[0][1] != stra[i][0]) {
      MS_LOG(ERROR) << name_ << ": Invalid strategy, the index is " << i << ", it must be equal to " << stra[0][1]
                    << ", but got " << stra[i][0];
      return FAILED;
    }
  }

  return SUCCESS;
}

Status BatchNormInfo::InferDevMatrixShape() {
  MS_EXCEPTION_IF_NULL(strategy_);
  std::vector<Dimensions> stra = strategy_->GetInputDim();
  if (stra.empty()) {
    MS_LOG(ERROR) << name_ << ": The strategy can not be empty";
    return FAILED;
  }

  dev_matrix_shape_ = stra[0];
  return SUCCESS;
}

Status BatchNormInfo::InferTensorMap() {
  TensorMap input_tensor_map;
  TensorMap in_other_tensor_map;

  if (input_is_4d_) {
    // if the input is 4d:
    // input_strategy: ((n, c, h, w), (c), (c), (c), (c))
    // output_strategy: ((n, c, h, w), (c), (c), (c), (c))
    // dev_matrix: (n, c, h, w)
    input_tensor_map = {3, 2, 1, 0};
    in_other_tensor_map = {2};
  } else {
    // if the input is 2d:
    // input_strategy: ((n, c), (c), (c), (c), (c))
    // output_strategy: ((n, c), (c), (c), (c), (c))
    // dev_matrix: (n, c)
    input_tensor_map = {1, 0};
    in_other_tensor_map = {0};
  }

  inputs_tensor_map_.push_back(input_tensor_map);     // input
  inputs_tensor_map_.push_back(in_other_tensor_map);  // scale
  inputs_tensor_map_.push_back(in_other_tensor_map);  // bias
  inputs_tensor_map_.push_back(in_other_tensor_map);  // mean
  inputs_tensor_map_.push_back(in_other_tensor_map);  // variance

  outputs_tensor_map_ = inputs_tensor_map_;
  return SUCCESS;
}

Status BatchNormInfo::InferForwardCommunication() {
  // if it is not training, there is no need for a forward allreduce
  if (!is_training_) {
    MS_LOG(INFO) << name_ << ": It is not training, no need for a forward allreduce";
    return SUCCESS;
  }

  TensorMap tmp_map;
  if (input_is_4d_) {
    // input is 4d:
    // if there is no repeated calculation, the dev matrix is [n, c, h, w]
    // if there is repeated calculation and the repeated num is on the left of the dev matrix,
    // the dev matrix is [repeated_num, n, c, h, w]
    // if there is repeated calculation and the repeated num is on the right of the dev matrix,
    // the dev matrix is [n, c, h, w, repeated_num]
    // the forward allreduce needs to use the n/h/w dimensions
    if (repeated_calc_num_ == 1) {
      // no repeated calculation
      tmp_map = {-1, 2, -1, -1};
    } else if (!repeated_num_in_dev_matrix_right_) {
      // repeated calculation and the repeated num is on the left of the dev matrix
      tmp_map = {4, -1, 2, -1, -1};
    } else {
      // repeated calculation and the repeated num is on the right of the dev matrix
      tmp_map = {-1, 3, -1, -1, 0};
    }
  } else {
    // input is 2d:
    // if there is no repeated calculation, the dev matrix is [n, c]
    // if there is repeated calculation and the repeated num is on the left of the dev matrix,
    // the dev matrix is [repeated_num, n, c]
    // if there is repeated calculation and the repeated num is on the right of the dev matrix,
    // the dev matrix is [n, c, repeated_num]
    // the forward allreduce needs to use the n dimension
    if (repeated_calc_num_ == 1) {
      // no repeated calculation
      tmp_map = {-1, 0};
    } else if (!repeated_num_in_dev_matrix_right_) {
      // repeated calculation and the repeated num is on the left of the dev matrix
      tmp_map = {2, -1, 0};
    } else {
      // repeated calculation and the repeated num is on the right of the dev matrix
      tmp_map = {-1, 1, 0};
    }
  }

  std::vector<Group> group_list;
  if (CreateGroupByTensorMap(tmp_map, &group_list) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group failed";
    return FAILED;
  }

  if (group_list.empty()) {
    MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
    return SUCCESS;
  } else {
    MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list[0].name();
  }

  forward_allreduce_group_ = group_list;
  return SUCCESS;
}

Status BatchNormInfo::InferReplaceOps() {
  if (!is_training_) {
    MS_LOG(INFO) << name_ << ": It is not training, no need to replace op";
    return SUCCESS;
  }

  if (forward_allreduce_group_.empty()) {
    MS_LOG(INFO) << name_ << ": The forward allreduce group is empty, no need to replace op";
    return SUCCESS;
  }

  ValuePtr epsilon = MakeValue(epsilon_);
  ValuePtr momentum = MakeValue(momentum_);
  ValuePtr group = MakeValue(forward_allreduce_group_[0].name());
  ValuePtr device_num = MakeValue(forward_allreduce_group_[0].GetDevNum());

  Attr attr_epsilon = std::make_pair(EPSILON, epsilon);
  Attr attr_momentum = std::make_pair(MOMENTUM, momentum);
  Attr attr_group = std::make_pair(GROUP, group);
  Attr attr_device_num = std::make_pair(DEVICE_NUM, device_num);

  OperatorAttrs attrs = {attr_epsilon, attr_momentum, attr_group, attr_device_num};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  replace_op_ = {std::make_pair(SYNC_BATCH_NORM, args)};
  return SUCCESS;
}

Status BatchNormInfo::InferAsLossDivisor() {
  if (outputs_tensor_map_.size() != 5) {
    MS_LOG(ERROR) << name_ << ": The size of outputs tensor map must be 5, but got " << outputs_tensor_map_.size();
    return FAILED;
  }
  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
  MS_LOG(INFO) << name_ << ": The dev matrix shape is " << ShapeToString(dev_matrix_shape_)
               << ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0])
               << ", as_loss_divisor_ is " << as_loss_divisor_;
  return SUCCESS;
}

Status BatchNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

std::vector<StrategyPtr> BatchNormInfo::GenerateOpStrategies(int64_t stage_id) {
  Strategys strategy;
  if (input_is_4d_) {
    strategy = {{stage_device_size_, 1, 1, 1}, {1}, {1}, {1}, {1}};
  } else {
    strategy = {{stage_device_size_, 1}, {1}, {1}, {1}, {1}};
  }
  StrategyPtr sp = std::make_shared<Strategy>(stage_id, strategy);
  std::vector<StrategyPtr> sp_vector;
  sp_vector.push_back(sp);
  return sp_vector;
}

Status BatchNormInfo::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }

  (void)InferReplaceOps();
  MS_LOG(INFO) << name_ << ": Init success.";
  return SUCCESS;
}

Status BatchNormInfo::InitForCostModel(const StrategyPtr &strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": Init for cost model success.";
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore

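Note on the strategy format: CheckStrategy above requires a five-tuple strategy, one entry per BatchNorm input (input, scale, bias, mean, variance), and the channel split of the input (stra[0][1]) must equal the split of every 1-D parameter (stra[i][0]). A minimal sketch of a conforming shard call, mirroring the test file at the end of this commit (the channel count and device numbers here are illustrative):

import mindspore as ms
from mindspore import context
from mindspore.nn import BatchNorm2d

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)

# Five tuples: input (n, c, h, w) plus scale/bias/mean/variance, each 1-D over c.
# The channel split (2 here) must be identical across all five entries,
# otherwise BatchNormInfo::CheckStrategy returns FAILED.
strategy = ((2, 2, 1, 1), (2,), (2,), (2,), (2,))

bn = BatchNorm2d(16)
bn.bn_train.shard(strategy)  # bn_train is the training-mode BatchNorm primitive
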
@ -0,0 +1,64 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BATCHNORM_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BATCHNORM_INFO_H_

#include <string>
#include <memory>
#include <unordered_map>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
class BatchNormInfo : public OperatorInfo {
 public:
  BatchNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
  ~BatchNormInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;
  Status InitForCostModel(const StrategyPtr &strategy) override;
  std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
  Status SetCostUnderStrategy(const StrategyPtr &) override;

 protected:
  Status GetAttrs() override;
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferForwardCommunication() override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status InferReplaceOps();
  Status InferAsLossDivisor() override;

 private:
  bool is_training_ = false;
  float epsilon_ = 0.00001;
  float momentum_ = 0.1;
  bool input_is_4d_ = true;
  std::string format_;
  std::vector<Group> forward_allreduce_group_;
};
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BATCHNORM_INFO_H_

@ -28,54 +28,6 @@
 namespace mindspore {
 namespace parallel {
-int64_t Conv2DInfo::GetIntAttr(const std::string &attr_name) {
-  auto attr_iter = attrs_.find(attr_name);
-  if (attr_iter == attrs_.end()) {
-    MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << attr_name;
-    return -1;
-  }
-
-  MS_EXCEPTION_IF_NULL(attr_iter->second);
-  if (!attr_iter->second->isa<Int64Imm>()) {
-    MS_LOG(ERROR) << name_ << ": The value of " << attr_name << " is not int";
-    return -1;
-  }
-
-  return attr_iter->second->cast<Int64ImmPtr>()->value();
-}
-
-std::string Conv2DInfo::GetStringAttr(const std::string &attr_name) {
-  std::string string_attr;
-  auto attr_iter = attrs_.find(attr_name);
-  if (attr_iter == attrs_.end()) {
-    MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << attr_name;
-    return string_attr;
-  }
-
-  MS_EXCEPTION_IF_NULL(attr_iter->second);
-  if (!attr_iter->second->isa<StringImm>()) {
-    MS_LOG(ERROR) << name_ << ": The value of " << attr_name << " is not string";
-    return string_attr;
-  }
-
-  string_attr = attr_iter->second->cast<StringImmPtr>()->value();
-  return string_attr;
-}
-
-std::vector<int64_t> Conv2DInfo::GetTupleAttr(const std::string &attr_name) {
-  std::vector<int64_t> tuple_attr;
-  auto tuple_attr_iter = attrs_.find(attr_name);
-  if (tuple_attr_iter == attrs_.end()) {
-    MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << attr_name;
-    return tuple_attr;
-  }
-
-  MS_EXCEPTION_IF_NULL(tuple_attr_iter->second);
-  tuple_attr = GetValue<std::vector<int64_t>>(tuple_attr_iter->second);
-
-  return tuple_attr;
-}
-
 Status Conv2DInfo::GetAttrs() {
   // out_channel
   out_channel_ = GetIntAttr(OUT_CHANNEL);

@ -121,14 +73,14 @@ Status Conv2DInfo::GetAttrs() {
   }

   // pad_list
-  pad_list_ = GetTupleAttr(PAD_LIST);
+  pad_list_ = GetTupleIntAttr(PAD_LIST);
   if (pad_list_.size() != 4) {
     MS_LOG(ERROR) << name_ << ": The size of pad_list must be 4, but got " << pad_list_.size();
     return FAILED;
   }

   // stride
-  stride_ = GetTupleAttr(STRIDE);
+  stride_ = GetTupleIntAttr(STRIDE);
   if (stride_.size() != 4) {
     MS_LOG(ERROR) << name_ << ": The size of stride must be 4, but got " << stride_.size();
     return FAILED;

@ -141,7 +93,7 @@ Status Conv2DInfo::GetAttrs() {
   }

   // dilation
-  dilation_ = GetTupleAttr(DILATION);
+  dilation_ = GetTupleIntAttr(DILATION);
   if (dilation_.size() != 4) {
     MS_LOG(ERROR) << name_ << ": The size of dilation must be 4, but got " << dilation_.size();
     return FAILED;

@ -279,7 +231,7 @@ Status Conv2DInfo::InferDevMatrixShape() {
   MS_EXCEPTION_IF_NULL(strategy_);
   std::vector<Dimensions> stra = strategy_->GetInputDim();
   if (stra.size() != 2) {
-    MS_LOG(ERROR) << name_ << "The size of strategy must be 2, but got " << stra.size();
+    MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
     return FAILED;
   }

@ -49,9 +49,6 @@ class Conv2DInfo : public OperatorInfo {
   Status InferForwardCommunication() override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
-  int64_t GetIntAttr(const std::string &attr_name);
-  std::string GetStringAttr(const std::string &attr_name);
-  std::vector<int64_t> GetTupleAttr(const std::string &attr_name);
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;

  private:

@ -1714,6 +1714,77 @@ Status OperatorInfo::GenerateStrategies(int64_t stage_id) {
   return SUCCESS;
 }

+int64_t OperatorInfo::GetIntAttr(const std::string &attr_name) {
+  auto attr_iter = attrs_.find(attr_name);
+  if (attr_iter == attrs_.end()) {
+    MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
+  }
+
+  MS_EXCEPTION_IF_NULL(attr_iter->second);
+  if (!attr_iter->second->isa<Int64Imm>()) {
+    MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not int";
+  }
+
+  return attr_iter->second->cast<Int64ImmPtr>()->value();
+}
+
+bool OperatorInfo::GetBoolAttr(const std::string &attr_name) {
+  auto attr_iter = attrs_.find(attr_name);
+  if (attr_iter == attrs_.end()) {
+    MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
+  }
+
+  MS_EXCEPTION_IF_NULL(attr_iter->second);
+  if (!attr_iter->second->isa<BoolImm>()) {
+    MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not bool";
+  }
+
+  return attr_iter->second->cast<BoolImmPtr>()->value();
+}
+
+std::string OperatorInfo::GetStringAttr(const std::string &attr_name) {
+  std::string string_attr;
+  auto attr_iter = attrs_.find(attr_name);
+  if (attr_iter == attrs_.end()) {
+    MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
+  }
+
+  MS_EXCEPTION_IF_NULL(attr_iter->second);
+  if (!attr_iter->second->isa<StringImm>()) {
+    MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not string";
+  }
+
+  string_attr = attr_iter->second->cast<StringImmPtr>()->value();
+  return string_attr;
+}
+
+std::vector<int64_t> OperatorInfo::GetTupleIntAttr(const std::string &attr_name) {
+  std::vector<int64_t> tuple_attr;
+  auto tuple_attr_iter = attrs_.find(attr_name);
+  if (tuple_attr_iter == attrs_.end()) {
+    MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
+  }
+
+  MS_EXCEPTION_IF_NULL(tuple_attr_iter->second);
+  tuple_attr = GetValue<std::vector<int64_t>>(tuple_attr_iter->second);
+
+  return tuple_attr;
+}
+
+float OperatorInfo::GetFloatAttr(const std::string &attr_name) {
+  auto attr_iter = attrs_.find(attr_name);
+  if (attr_iter == attrs_.end()) {
+    MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
+  }
+
+  MS_EXCEPTION_IF_NULL(attr_iter->second);
+  if (!attr_iter->second->isa<FP32Imm>()) {
+    MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not float";
+  }
+
+  return attr_iter->second->cast<FP32ImmPtr>()->value();
+}
+
 std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue) {
   MS_EXCEPTION_IF_NULL(sequeue);
   std::vector<ValuePtr> ret;

@ -214,6 +214,11 @@ class OperatorInfo {
   Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy,
                          Shapes *inputs_slice_shape, Shapes *outputs_slice_shape);
   void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &);
+  int64_t GetIntAttr(const std::string &attr_name);
+  bool GetBoolAttr(const std::string &attr_name);
+  float GetFloatAttr(const std::string &attr_name);
+  std::string GetStringAttr(const std::string &attr_name);
+  std::vector<int64_t> GetTupleIntAttr(const std::string &attr_name);

   std::string name_;
   Shapes inputs_shape_;

@ -56,5 +56,6 @@
 #include "frontend/parallel/ops_info/scatter_update_info.h"
 #include "frontend/parallel/ops_info/virtual_output_info.h"
 #include "frontend/parallel/ops_info/conv2d_info.h"
+#include "frontend/parallel/ops_info/batchnorm_info.h"

 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

@ -197,6 +197,10 @@ constexpr char STRIDE[] = "stride";
 constexpr char DILATION[] = "dilation";
 constexpr char FORMAT[] = "format";
 constexpr char NCHW[] = "NCHW";
+constexpr char IS_TRAINING[] = "is_training";
+constexpr char EPSILON[] = "epsilon";
+constexpr char MOMENTUM[] = "momentum";
+constexpr char DEVICE_NUM[] = "device_num";

 // Operator
 constexpr char VIRTUAL_DIV[] = "_VirtualDiv";

@ -267,6 +271,7 @@ constexpr char CONV2D[] = "Conv2D";
 constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";
 constexpr char FUSE_BATCH_NORM_EX[] = "FusedBatchNormEx";
 constexpr char BATCH_NORM[] = "BatchNorm";
+constexpr char SYNC_BATCH_NORM[] = "SyncBatchNorm";
 constexpr char LAYER_NORM[] = "LayerNorm";
 constexpr char POOLING[] = "Pooling";
 constexpr char CAST[] = "Cast";

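The four new keys name the primitive attributes that BatchNormInfo::GetAttrs reads through the new typed getters. As a rough illustration of where those attributes originate on the Python side (a sketch only; the argument values are illustrative, not from this commit):

from mindspore.ops import operations as P

# "is_training", "epsilon" and "momentum" are attributes of the BatchNorm
# primitive; the "format" attribute must resolve to "NCHW" or GetAttrs
# returns FAILED.
bn = P.BatchNorm(is_training=True, epsilon=1e-5, momentum=0.1)
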
@ -787,9 +787,11 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
     MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
   }
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
+
   if (replace_op.first == EMBEDDING_LOOKUP) {
     replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
   }
+
   if (!params.empty()) {
     Param param_first = *(params.begin());
     int64_t first_position = param_first.second;

@ -804,6 +806,10 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
       int64_t position = param.second;
       (void)replace_input.insert(replace_input.begin() + position, val);
     }
+  } else if (replace_op.first == SYNC_BATCH_NORM) {
+    for (size_t i = 2; i < node->inputs().size(); ++i) {
+      replace_input.push_back(node->input(i));
+    }
   }

   return replace_input;

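This hunk forwards the remaining inputs (scale, bias, mean, variance) to the node that replaces BatchNorm. The replacement built by InferReplaceOps is the SyncBatchNorm primitive, which performs the cross-device reduction itself. A hedged sketch of the equivalent Python-level operator (the group name and device number below are illustrative):

from mindspore.ops import operations as P

# InferReplaceOps forwards epsilon/momentum and attaches the forward
# allreduce group plus its device number as the extra attributes.
sync_bn = P.SyncBatchNorm(epsilon=1e-5, momentum=0.1,
                          group="forward_allreduce_group",  # illustrative name
                          device_num=2)
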
@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum, BatchNorm2d, BatchNorm1d
from mindspore.ops import operations as P


class Net(Cell):
    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
                 strategy1=None, strategy2=None):
        super().__init__()
        self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
                               pad_mode=pad_mode, stride=stride).shard(strategy1)
        self.conv2d_weight = Parameter(conv2d_weight, "w1")
        self.bn = BatchNorm2d(8)
        self.bn.bn_train.shard(strategy2)

    def construct(self, x, b):
        out = self.conv2d(x, self.conv2d_weight)
        out = self.bn(out)
        return out


_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)


def compile_net(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_batchnorm_data_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((8, 1, 1, 1), (1,), (1,), (1,), (1,))
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_batchnorm_model_parallel1():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    strategy2 = ((2, 1, 2, 2), (1,), (1,), (1,), (1,))
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_batchnorm_model_parallel2():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
    strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
    strategy2 = ((1, 8, 1, 1), (8,), (8,), (8,), (8,))
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


class Net2(Cell):
    def __init__(self, strategy1=None, strategy2=None):
        super().__init__()
        self.bn = BatchNorm1d(8)
        self.bn.bn_train.shard(strategy1)
        self.relu = P.ReLU().shard(strategy2)

    def construct(self, x, b):
        out = self.bn(x)
        out = self.relu(out)
        return out


_x1 = Tensor(np.ones([32, 8]), dtype=ms.float32)
_b1 = Tensor(np.ones([32, 8]), dtype=ms.float32)


def compile_net2(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, _x1, _b1)
    context.reset_auto_parallel_context()


def test_batchnorm1d_data_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1), (1,), (1,), (1,), (1,))
    strategy2 = ((8, 1),)
    net = Net2(strategy1=strategy1, strategy2=strategy2)
    compile_net2(net)


def test_batchnorm1d_model_parallel1():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 8), (8,), (8,), (8,), (8,))
    strategy2 = ((1, 8),)
    net = Net2(strategy1=strategy1, strategy2=strategy2)
    compile_net2(net)


def test_batchnorm1d_model_parallel2():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
    strategy1 = ((2, 4), (4,), (4,), (4,), (4,))
    strategy2 = ((2, 4),)
    net = Net2(strategy1=strategy1, strategy2=strategy2)
    compile_net2(net)
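The tests above only exercise strategies that pass CheckStrategy. A hypothetical negative case, written against the same Net2/compile_net2 helpers, would split the parameters differently from the input's channel dimension; note the exact exception type raised at compile time is an assumption here:

import pytest

def test_batchnorm1d_mismatched_channel_strategy():
    # Hypothetical: the channel split of the input (1) disagrees with the
    # parameter split (8), violating the stra[0][1] == stra[i][0] check
    # in BatchNormInfo::CheckStrategy.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1), (8,), (8,), (8,), (8,))
    strategy2 = ((8, 1),)
    net = Net2(strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):  # assumed failure mode
        compile_net2(net)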