add parallel op for batchnorm

This commit is contained in:
yangzhenzhang 2021-06-16 14:57:56 +08:00
parent a0c5b56f5f
commit af0d28de48
11 changed files with 556 additions and 55 deletions

View File

@ -197,6 +197,7 @@ REGISTER(TopKInfo);
REGISTER(ScatterUpdateInfo);
REGISTER(VirtualOutputInfo);
REGISTER(Conv2DInfo);
REGISTER(BatchNormInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,274 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/batchnorm_info.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace parallel {
Status BatchNormInfo::GetAttrs() {
is_training_ = GetBoolAttr(IS_TRAINING);
epsilon_ = GetFloatAttr(EPSILON);
momentum_ = GetFloatAttr(MOMENTUM);
format_ = GetStringAttr(FORMAT);
if (format_ != NCHW) {
MS_LOG(ERROR) << name_ << ": The data format must be 'NCHW', but got " << format_;
return FAILED;
}
if (inputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
return FAILED;
}
if (inputs_shape_[0].size() == 2) {
input_is_4d_ = false;
} else if (inputs_shape_[0].size() == 4) {
input_is_4d_ = true;
} else {
MS_LOG(ERROR) << name_ << ": The size of input[0]'shape must be 2 or 4, but got " << inputs_shape_[0].size();
return FAILED;
}
MS_LOG(INFO) << name_ << ": The is_traing is " << is_training_ << ", epsilon is " << epsilon_ << ", momentum is "
<< momentum_ << ", data format is " << format_;
return SUCCESS;
}
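For reference, the four attributes read above correspond to the arguments of the Python-side BatchNorm primitive; a minimal usage sketch (not part of this commit, argument names assumed from the public mindspore.ops API):

from mindspore.ops import operations as P

# is_training, epsilon and momentum become the primitive attributes read by GetAttrs();
# the data format attribute defaults to "NCHW", the only layout supported here.
bn = P.BatchNorm(is_training=True, epsilon=1e-5, momentum=0.1)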
Status BatchNormInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED;
}
std::vector<Dimensions> stra = strategy->GetInputDim();
if (stra.size() != 5) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 5, but got " << stra.size();
return FAILED;
}
if ((stra[0].size() != 4) && (stra[0].size() != 2)) {
MS_LOG(ERROR) << name_ << ": The size of strategy[0] must be 4 or 2, but got " << stra[0].size();
return FAILED;
}
for (size_t i = 1; i < 5; ++i) {
if (stra[i].empty()) {
MS_LOG(ERROR) << name_ << ": The strategy can not be empty, the index is " << i;
return FAILED;
}
if (stra[0][1] != stra[i][0]) {
MS_LOG(ERROR) << name_ << ": Invalid strategy, the index is " << i << ", it must be equal to " << stra[0][1]
<< ", but got " << stra[i][0];
return FAILED;
}
}
return SUCCESS;
}
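In short: the strategy must cover all five inputs, the first entry must be 2-D or 4-D, and each of the four 1-D entries (scale, bias, mean, variance) must equal the channel split of the first input. A standalone Python restatement of that rule (a hypothetical helper, for illustration only):

def check_batchnorm_strategy(stra):
    # stra is a tuple of 5 per-input strategies, e.g. ((8, 1, 1, 1), (1,), (1,), (1,), (1,))
    if len(stra) != 5:
        return False
    if len(stra[0]) not in (2, 4):
        return False
    channel_split = stra[0][1]
    # scale/bias/mean/variance are 1-D over C and must share the channel split
    return all(len(s) == 1 and s[0] == channel_split for s in stra[1:])

assert check_batchnorm_strategy(((8, 1, 1, 1), (1,), (1,), (1,), (1,)))
assert check_batchnorm_strategy(((1, 8), (8,), (8,), (8,), (8,)))
assert not check_batchnorm_strategy(((2, 4, 1, 1), (1,), (1,), (1,), (1,)))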
Status BatchNormInfo::InferDevMatrixShape() {
MS_EXCEPTION_IF_NULL(strategy_);
std::vector<Dimensions> stra = strategy_->GetInputDim();
if (stra.empty()) {
MS_LOG(ERROR) << name_ << ": The strategy can not be empty";
return FAILED;
}
dev_matrix_shape_ = stra[0];
return SUCCESS;
}
Status BatchNormInfo::InferTensorMap() {
TensorMap input_tensor_map;
TensorMap in_other_tensor_map;
if (input_is_4d_) {
// if input is 4d:
// input_strategy: ((n, c, h, w), (c), (c), (c), (c))
// output_strategy: ((n, c, h, w), (c), (c), (c), (c))
// dev_matrix: (n, c, h, w)
input_tensor_map = {3, 2, 1, 0};
in_other_tensor_map = {2};
} else {
// if input is 2d:
// input_strategy: ((n, c), (c), (c), (c), (c))
// output_strategy: ((n, c), (c), (c), (c), (c))
// dev_matrix: (n, c)
input_tensor_map = {1, 0};
in_other_tensor_map = {0};
}
inputs_tensor_map_.push_back(input_tensor_map); // input
inputs_tensor_map_.push_back(in_other_tensor_map); // scale
inputs_tensor_map_.push_back(in_other_tensor_map); // bias
inputs_tensor_map_.push_back(in_other_tensor_map); // mean
inputs_tensor_map_.push_back(in_other_tensor_map); // variance
outputs_tensor_map_ = inputs_tensor_map_;
return SUCCESS;
}
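Each tensor-map entry names a device-matrix axis counted from the right (-1 means the dimension is not split), so the slice shape of every input and output follows directly from the map. A standalone Python sketch (hypothetical helper, not part of this commit):

def slice_shape(shape, tensor_map, dev_matrix):
    # Map value m refers to dev_matrix[len(dev_matrix) - 1 - m]; -1 means not split.
    out = []
    for dim, m in zip(shape, tensor_map):
        factor = 1 if m == -1 else dev_matrix[len(dev_matrix) - 1 - m]
        out.append(dim // factor)
    return out

dev_matrix = [2, 4, 1, 1]                                     # strategy of input[0]: (n=2, c=4, h=1, w=1)
print(slice_shape([32, 16, 8, 8], [3, 2, 1, 0], dev_matrix))  # input -> [16, 4, 8, 8]
print(slice_shape([16], [2], dev_matrix))                     # scale -> [4]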
Status BatchNormInfo::InferForwardCommunication() {
// if it is not training, no forward allreduce is needed
if (!is_training_) {
MS_LOG(INFO) << name_ << ": It is not training, no need for forward allreduce";
return SUCCESS;
}
TensorMap tmp_map;
if (input_is_4d_) {
// input is 4d:
// if there is no repeated calculation, the dev matrix is [n, c, h, w]
// if there is repeated calculation with the repeated num on the left of the dev matrix, the dev matrix is [repeated_num, n, c, h, w]
// if there is repeated calculation with the repeated num on the right of the dev matrix, the dev matrix is [n, c, h, w, repeated_num]
// and the forward allreduce needs to reduce over the n/h/w dimensions
if (repeated_calc_num_ == 1) {
// no repeated calculation
tmp_map = {-1, 2, -1, -1};
} else if (!repeated_num_in_dev_matrix_right_) {
// repeated calculation, repeated num on the left of the dev matrix
tmp_map = {4, -1, 2, -1, -1};
} else {
// repeated calculation, repeated num on the right of the dev matrix
tmp_map = {-1, 3, -1, -1, 0};
}
} else {
// input is 2d:
// if there is no repeated calculation, the dev matrix is [n, c]
// if there is repeated calculation with the repeated num on the left of the dev matrix, the dev matrix is [repeated_num, n, c]
// if there is repeated calculation with the repeated num on the right of the dev matrix, the dev matrix is [n, c, repeated_num]
// and the forward allreduce needs to reduce over the n dimension
if (repeated_calc_num_ == 1) {
// no repeated calculation
tmp_map = {-1, 0};
} else if (!repeated_num_in_dev_matrix_right_) {
// repeated calculation, repeated num on the left of the dev matrix
tmp_map = {2, -1, 0};
} else {
// repeated calculation, repeated num on the right of the dev matrix
tmp_map = {-1, 1, 0};
}
}
std::vector<Group> group_list;
if (CreateGroupByTensorMap(tmp_map, &group_list) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group failed";
return FAILED;
}
if (group_list.empty()) {
MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
return SUCCESS;
} else {
MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list[0].name();
}
forward_allreduce_group_ = group_list;
return SUCCESS;
}
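tmp_map keeps only the channel axis, so the allreduce group is built over the device-matrix axes that do not appear in the map, i.e. the n/h/w splits; the channel axis and the repeated-calculation axis stay out of the group. The sketch below illustrates that grouping rule in plain Python; how CreateGroupByTensorMap orders device ranks is an assumption made for this example.

from itertools import product

def allreduce_groups(dev_matrix, tensor_map):
    # Devices that share their coordinates on the axes kept in tensor_map form one group;
    # members of a group differ only along the axes absent from the map, which is where
    # the BatchNorm statistics have to be averaged.
    ndim = len(dev_matrix)
    kept_axes = [ndim - 1 - m for m in tensor_map if m != -1]  # left-to-right axis index
    groups = {}
    for rank, coord in enumerate(product(*[range(d) for d in dev_matrix])):
        key = tuple(coord[a] for a in kept_axes)
        groups.setdefault(key, []).append(rank)
    return list(groups.values())

# dev matrix (n=2, c=2, h=2, w=1) and tmp_map = (-1, 2, -1, -1): reduce over n and h
print(allreduce_groups([2, 2, 2, 1], [-1, 2, -1, -1]))  # [[0, 1, 4, 5], [2, 3, 6, 7]]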
Status BatchNormInfo::InferReplaceOps() {
if (!is_training_) {
MS_LOG(INFO) << name_ << ": It is not training, no need to replace op";
return SUCCESS;
}
if (forward_allreduce_group_.empty()) {
MS_LOG(INFO) << name_ << ": The forward allreduce group is empty, no need to replace op";
return SUCCESS;
}
ValuePtr epsilon = MakeValue(epsilon_);
ValuePtr momentum = MakeValue(momentum_);
ValuePtr group = MakeValue(forward_allreduce_group_[0].name());
ValuePtr device_num = MakeValue(forward_allreduce_group_[0].GetDevNum());
Attr attr_epsilon = std::make_pair(EPSILON, epsilon);
Attr attr_momentum = std::make_pair(MOMENTUM, momentum);
Attr attr_group = std::make_pair(GROUP, group);
Attr attr_device_num = std::make_pair(DEVICE_NUM, device_num);
OperatorAttrs attrs = {attr_epsilon, attr_momentum, attr_group, attr_device_num};
OperatorParams params;
OperatorArgs args = std::make_pair(attrs, params);
replace_op_ = {std::make_pair(SYNC_BATCH_NORM, args)};
return SUCCESS;
}
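So during training, whenever the statistics have to be averaged across devices, the BatchNorm node is swapped for a SyncBatchNorm-style op that carries the communication group. The attributes attached to the replacement boil down to the following (a plain-Python illustration; the group name is made up):

# Hypothetical attribute set of the replacement op; group and device_num come from
# the forward allreduce group created above.
replace_op_attrs = {
    "epsilon": 1e-5,
    "momentum": 0.1,
    "group": "2-5004544844489628105",  # made-up group name for illustration
    "device_num": 2,
}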
Status BatchNormInfo::InferAsLossDivisor() {
if (outputs_tensor_map_.size() != 5) {
MS_LOG(ERROR) << name_ << ": The size of outputs tensor map must be 5, but got " << outputs_tensor_map_.size();
return FAILED;
}
as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_)
<< ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0])
<< ", as_loss_divisor_ is " << as_loss_divisor_;
return SUCCESS;
}
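as_loss_divisor_ is the number of devices that hold an identical copy of output[0]. Assuming ComputeRepeatDeviceNumByTensorMap multiplies the dev-matrix dims whose axes are not used by the tensor map, it can be sketched like this (illustration only):

def repeat_device_num(dev_matrix, tensor_map):
    ndim = len(dev_matrix)
    used = {ndim - 1 - m for m in tensor_map if m != -1}
    num = 1
    for axis, size in enumerate(dev_matrix):
        if axis not in used:
            num *= size  # this axis replicates the tensor, so it divides the loss
    return num

print(repeat_device_num([1, 8, 1, 1], [3, 2, 1, 0]))  # 1: every axis is used by the map
print(repeat_device_num([2, 4, 2], [2, 1]))           # 2: the rightmost (repeated) axis is free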
Status BatchNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
std::vector<StrategyPtr> BatchNormInfo::GenerateOpStrategies(int64_t stage_id) {
Strategys strategy;
if (input_is_4d_) {
strategy = {{stage_device_size_, 1, 1, 1}, {1}, {1}, {1}, {1}};
} else {
strategy = {{stage_device_size_, 1}, {1}, {1}, {1}, {1}};
}
StrategyPtr sp = std::make_shared<Strategy>(stage_id, strategy);
std::vector<StrategyPtr> sp_vector;
sp_vector.push_back(sp);
return sp_vector;
}
Status BatchNormInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
(void)InferReplaceOps();
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}
Status BatchNormInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BATCHNORM_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BATCHNORM_INFO_H_
#include <string>
#include <memory>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
class BatchNormInfo : public OperatorInfo {
public:
BatchNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
~BatchNormInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferReplaceOps();
Status InferAsLossDivisor() override;
private:
bool is_training_ = false;
float epsilon_ = 0.00001;
float momentum_ = 0.1;
bool input_is_4d_ = true;
std::string format_;
std::vector<Group> forward_allreduce_group_;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BATCHNORM_INFO_H_

View File

@ -28,54 +28,6 @@
namespace mindspore {
namespace parallel {
int64_t Conv2DInfo::GetIntAttr(const std::string &attr_name) {
auto attr_iter = attrs_.find(attr_name);
if (attr_iter == attrs_.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << attr_name;
return -1;
}
MS_EXCEPTION_IF_NULL(attr_iter->second);
if (!attr_iter->second->isa<Int64Imm>()) {
MS_LOG(ERROR) << name_ << ": The value of " << attr_name << " is not int";
return -1;
}
return attr_iter->second->cast<Int64ImmPtr>()->value();
}
std::string Conv2DInfo::GetStringAttr(const std::string &attr_name) {
std::string string_attr;
auto attr_iter = attrs_.find(attr_name);
if (attr_iter == attrs_.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << attr_name;
return string_attr;
}
MS_EXCEPTION_IF_NULL(attr_iter->second);
if (!attr_iter->second->isa<StringImm>()) {
MS_LOG(ERROR) << name_ << ": The value of " << attr_name << " is not string";
return string_attr;
}
string_attr = attr_iter->second->cast<StringImmPtr>()->value();
return string_attr;
}
std::vector<int64_t> Conv2DInfo::GetTupleAttr(const std::string &attr_name) {
std::vector<int64_t> tuple_attr;
auto tuple_attr_iter = attrs_.find(attr_name);
if (tuple_attr_iter == attrs_.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << attr_name;
return tuple_attr;
}
MS_EXCEPTION_IF_NULL(tuple_attr_iter->second);
tuple_attr = GetValue<std::vector<int64_t>>(tuple_attr_iter->second);
return tuple_attr;
}
Status Conv2DInfo::GetAttrs() {
// out_channel
out_channel_ = GetIntAttr(OUT_CHANNEL);
@ -121,14 +73,14 @@ Status Conv2DInfo::GetAttrs() {
}
// pad_list
pad_list_ = GetTupleAttr(PAD_LIST);
pad_list_ = GetTupleIntAttr(PAD_LIST);
if (pad_list_.size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of pad_list must be 4, but got " << pad_list_.size();
return FAILED;
}
// stride
stride_ = GetTupleAttr(STRIDE);
stride_ = GetTupleIntAttr(STRIDE);
if (stride_.size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of stride must be 4, but got " << stride_.size();
return FAILED;
@ -141,7 +93,7 @@ Status Conv2DInfo::GetAttrs() {
}
// dilation
dilation_ = GetTupleAttr(DILATION);
dilation_ = GetTupleIntAttr(DILATION);
if (dilation_.size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of dilation must be 4, but got " << dilation_.size();
return FAILED;
@ -279,7 +231,7 @@ Status Conv2DInfo::InferDevMatrixShape() {
MS_EXCEPTION_IF_NULL(strategy_);
std::vector<Dimensions> stra = strategy_->GetInputDim();
if (stra.size() != 2) {
MS_LOG(ERROR) << name_ << "The size of strategy must be 2, but got " << stra.size();
MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
return FAILED;
}

View File

@ -49,9 +49,6 @@ class Conv2DInfo : public OperatorInfo {
Status InferForwardCommunication() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
int64_t GetIntAttr(const std::string &attr_name);
std::string GetStringAttr(const std::string &attr_name);
std::vector<int64_t> GetTupleAttr(const std::string &attr_name);
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
private:

View File

@ -1714,6 +1714,77 @@ Status OperatorInfo::GenerateStrategies(int64_t stage_id) {
return SUCCESS;
}
int64_t OperatorInfo::GetIntAttr(const std::string &attr_name) {
auto attr_iter = attrs_.find(attr_name);
if (attr_iter == attrs_.end()) {
MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
}
MS_EXCEPTION_IF_NULL(attr_iter->second);
if (!attr_iter->second->isa<Int64Imm>()) {
MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not int";
}
return attr_iter->second->cast<Int64ImmPtr>()->value();
}
bool OperatorInfo::GetBoolAttr(const std::string &attr_name) {
auto attr_iter = attrs_.find(attr_name);
if (attr_iter == attrs_.end()) {
MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
}
MS_EXCEPTION_IF_NULL(attr_iter->second);
if (!attr_iter->second->isa<BoolImm>()) {
MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not int";
}
return attr_iter->second->cast<BoolImmPtr>()->value();
}
std::string OperatorInfo::GetStringAttr(const std::string &attr_name) {
std::string string_attr;
auto attr_iter = attrs_.find(attr_name);
if (attr_iter == attrs_.end()) {
MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
}
MS_EXCEPTION_IF_NULL(attr_iter->second);
if (!attr_iter->second->isa<StringImm>()) {
MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not string";
}
string_attr = attr_iter->second->cast<StringImmPtr>()->value();
return string_attr;
}
std::vector<int64_t> OperatorInfo::GetTupleIntAttr(const std::string &attr_name) {
std::vector<int64_t> tuple_attr;
auto tuple_attr_iter = attrs_.find(attr_name);
if (tuple_attr_iter == attrs_.end()) {
MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
}
MS_EXCEPTION_IF_NULL(tuple_attr_iter->second);
tuple_attr = GetValue<std::vector<int64_t>>(tuple_attr_iter->second);
return tuple_attr;
}
float OperatorInfo::GetFloatAttr(const std::string &attr_name) {
auto attr_iter = attrs_.find(attr_name);
if (attr_iter == attrs_.end()) {
MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
}
MS_EXCEPTION_IF_NULL(attr_iter->second);
if (!attr_iter->second->isa<FP32Imm>()) {
MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not float";
}
return attr_iter->second->cast<FP32ImmPtr>()->value();
}
std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue) {
MS_EXCEPTION_IF_NULL(sequeue);
std::vector<ValuePtr> ret;

View File

@ -214,6 +214,11 @@ class OperatorInfo {
Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy,
Shapes *inputs_slice_shape, Shapes *outputs_slice_shape);
void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &);
int64_t GetIntAttr(const std::string &attr_name);
bool GetBoolAttr(const std::string &attr_name);
float GetFloatAttr(const std::string &attr_name);
std::string GetStringAttr(const std::string &attr_name);
std::vector<int64_t> GetTupleIntAttr(const std::string &attr_name);
std::string name_;
Shapes inputs_shape_;

View File

@ -56,5 +56,6 @@
#include "frontend/parallel/ops_info/scatter_update_info.h"
#include "frontend/parallel/ops_info/virtual_output_info.h"
#include "frontend/parallel/ops_info/conv2d_info.h"
#include "frontend/parallel/ops_info/batchnorm_info.h"
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

View File

@ -197,6 +197,10 @@ constexpr char STRIDE[] = "stride";
constexpr char DILATION[] = "dilation";
constexpr char FORMAT[] = "format";
constexpr char NCHW[] = "NCHW";
constexpr char IS_TRAINING[] = "is_training";
constexpr char EPSILON[] = "epsilon";
constexpr char MOMENTUM[] = "momentum";
constexpr char DEVICE_NUM[] = "device_num";
// Operator
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
@ -267,6 +271,7 @@ constexpr char CONV2D[] = "Conv2D";
constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";
constexpr char FUSE_BATCH_NORM_EX[] = "FusedBatchNormEx";
constexpr char BATCH_NORM[] = "BatchNorm";
constexpr char SYNC_BATCH_NORM[] = "SyncBatchNorm";
constexpr char LAYER_NORM[] = "LayerNorm";
constexpr char POOLING[] = "Pooling";
constexpr char CAST[] = "Cast";

View File

@ -787,9 +787,11 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
}
std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
if (replace_op.first == EMBEDDING_LOOKUP) {
replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
}
if (!params.empty()) {
Param param_first = *(params.begin());
int64_t first_position = param_first.second;
@ -804,6 +806,10 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
int64_t position = param.second;
(void)replace_input.insert(replace_input.begin() + position, val);
}
} else if (replace_op.first == SYNC_BATCH_NORM) {
for (size_t i = 2; i < node->inputs().size(); ++i) {
replace_input.push_back(node->input(i));
}
}
return replace_input;

View File

@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum, BatchNorm2d, BatchNorm1d
from mindspore.ops import operations as P
class Net(Cell):
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
strategy1=None, strategy2=None):
super().__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride).shard(strategy1)
self.conv2d_weight = Parameter(conv2d_weight, "w1")
self.bn = BatchNorm2d(8)
self.bn.bn_train.shard(strategy2)
def construct(self, x, b):
out = self.conv2d(x, self.conv2d_weight)
out = self.bn(out)
return out
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
def compile_net(net):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_executor.compile(train_net, _x, _b)
context.reset_auto_parallel_context()
def test_batchnorm_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1), (1,), (1,), (1,), (1,))
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_batchnorm_model_parallel1():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((2, 1, 2, 2), (1,), (1,), (1,), (1,))
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_batchnorm_model_parallel2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
strategy2 = ((1, 8, 1, 1), (8,), (8,), (8,), (8,))
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
compile_net(net)
class Net2(Cell):
def __init__(self, strategy1=None, strategy2=None):
super().__init__()
self.bn = BatchNorm1d(8)
self.bn.bn_train.shard(strategy1)
self.relu = P.ReLU().shard(strategy2)
def construct(self, x, b):
out = self.bn(x)
out = self.relu(out)
return out
_x1 = Tensor(np.ones([32, 8]), dtype=ms.float32)
_b1 = Tensor(np.ones([32, 8]), dtype=ms.float32)
def compile_net2(net):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_executor.compile(train_net, _x1, _b1)
context.reset_auto_parallel_context()
def test_batchnorm1d_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1), (1,), (1,), (1,), (1,))
strategy2 = ((8, 1),)
net = Net2(strategy1=strategy1, strategy2=strategy2)
compile_net2(net)
def test_batchnorm1d_model_parallel1():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8), (8,), (8,), (8,), (8,))
strategy2 = ((1, 8),)
net = Net2(strategy1=strategy1, strategy2=strategy2)
compile_net2(net)
def test_batchnorm1d_model_parallel2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
strategy1 = ((2, 4), (4,), (4,), (4,), (4,))
strategy2 = ((2, 4),)
net = Net2(strategy1=strategy1, strategy2=strategy2)
compile_net2(net)