!17999 add parallel operator for conv2d

Merge pull request !17999 from yangzhenzhang/add-parallel-operator-for-conv2d
i-robot 2021-06-10 20:05:45 +08:00 committed by Gitee
commit dc9720cbb2
14 changed files with 568 additions and 16 deletions
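The user-facing entry point is the existing shard interface on the Conv2D primitive, which now resolves to the new Conv2DInfo. A minimal usage sketch, mirroring the shard calls added to the tests below (the concrete channel and kernel values are illustrative only):

import mindspore.nn as nn
from mindspore import context

# data-parallel sharding: split only the batch dimension of the input across 8 devices,
# keep the convolution weight replicated
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
conv = nn.Conv2d(in_channels=16, out_channels=8, kernel_size=2, pad_mode="same")
conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))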

View File

@@ -196,6 +196,7 @@ REGISTER(GatherNdInfo);
REGISTER(TopKInfo);
REGISTER(ScatterUpdateInfo);
REGISTER(VirtualOutputInfo);
REGISTER(Conv2DInfo);
} // namespace parallel
} // namespace mindspore

View File

@@ -0,0 +1,381 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/conv2d_info.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace parallel {
int64_t Conv2DInfo::GetIntAttr(const std::string &attr_name) {
auto attr_iter = attrs_.find(attr_name);
if (attr_iter == attrs_.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << attr_name;
return -1;
}
MS_EXCEPTION_IF_NULL(attr_iter->second);
if (!attr_iter->second->isa<Int64Imm>()) {
MS_LOG(ERROR) << name_ << ": The value of " << attr_name << " is not int";
return -1;
}
return attr_iter->second->cast<Int64ImmPtr>()->value();
}
std::string Conv2DInfo::GetStringAttr(const std::string &attr_name) {
std::string string_attr;
auto attr_iter = attrs_.find(attr_name);
if (attr_iter == attrs_.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << attr_name;
return string_attr;
}
MS_EXCEPTION_IF_NULL(attr_iter->second);
if (!attr_iter->second->isa<StringImm>()) {
MS_LOG(ERROR) << name_ << ": The value of " << attr_name << " is not string";
return string_attr;
}
string_attr = attr_iter->second->cast<StringImmPtr>()->value();
return string_attr;
}
std::vector<int64_t> Conv2DInfo::GetTupleAttr(const std::string &attr_name) {
std::vector<int64_t> tuple_attr;
auto tuple_attr_iter = attrs_.find(attr_name);
if (tuple_attr_iter == attrs_.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << attr_name;
return tuple_attr;
}
MS_EXCEPTION_IF_NULL(tuple_attr_iter->second);
tuple_attr = GetValue<std::vector<int64_t>>(tuple_attr_iter->second);
return tuple_attr;
}
Status Conv2DInfo::GetAttrs() {
// out_channel
out_channel_ = GetIntAttr(OUT_CHANNEL);
if (out_channel_ <= 0) {
MS_LOG(ERROR) << name_ << ": The attr of out_channel is invalid";
return FAILED;
}
// kernel_size
auto kernel_size_iter = attrs_.find(KERNEL_SIZE);
if (kernel_size_iter == attrs_.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << KERNEL_SIZE;
return FAILED;
}
MS_EXCEPTION_IF_NULL(kernel_size_iter->second);
if (kernel_size_iter->second->isa<Int64Imm>()) {
int64_t kernel_size = kernel_size_iter->second->cast<Int64ImmPtr>()->value();
kernel_size_ = {kernel_size, kernel_size};
} else if (kernel_size_iter->second->isa<ValueTuple>() || kernel_size_iter->second->isa<ValueList>()) {
kernel_size_ = GetValue<std::vector<int64_t>>(kernel_size_iter->second);
if (kernel_size_.size() != 2) {
MS_LOG(ERROR) << name_ << ": The size of kernel_size'tuple must be 2, but got " << kernel_size_.size();
return FAILED;
}
} else {
MS_LOG(ERROR) << name_ << ": The kernel_size must be int or tuple";
return FAILED;
}
// mode
mode_ = GetIntAttr(MODE);
if (mode_ != 1) {
MS_LOG(ERROR) << name_ << ": The mode must be 1, but got " << mode_;
return FAILED;
}
// pad_mode
pad_mode_ = GetIntAttr(PAD_MODE);
if (pad_mode_ < 0 || pad_mode_ > 2) {
MS_LOG(ERROR) << name_ << ": The pad_mode must be in the range of [0, 2], but got " << pad_mode_;
return FAILED;
}
// pad_list
pad_list_ = GetTupleAttr(PAD_LIST);
if (pad_list_.size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of pad_list must be 4, but got " << pad_list_.size();
return FAILED;
}
// stride
stride_ = GetTupleAttr(STRIDE);
if (stride_.size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of stride must be 4, but got " << stride_.size();
return FAILED;
}
if (stride_[0] != 1 || stride_[1] != 1) {
MS_LOG(ERROR) << name_ << ": The first two elements of stride must be 1, but got (" << stride_[0] << ", "
<< stride_[1] << ")";
return FAILED;
}
// dilation
dilation_ = GetTupleAttr(DILATION);
if (dilation_.size() != 4) {
MS_LOG(ERROR) << name_ << ": The size of dilation must be 4, but got " << dilation_.size();
return FAILED;
}
// group
group_ = GetIntAttr(GROUP);
if (group_ != 1) {
MS_LOG(ERROR) << name_ << ": The group must be 1, but got " << group_;
return FAILED;
}
// format
format_ = GetStringAttr(FORMAT);
if (format_ != NCHW) {
MS_LOG(ERROR) << name_ << ": The format must be 'NCHW', but got " << format_;
return FAILED;
}
MS_LOG(INFO) << name_ << ": The out channel is " << out_channel_ << ", kernel size is " << kernel_size_
<< ", mode is " << mode_ << ", pad mode is " << pad_mode_ << ", pad list is " << pad_list_
<< ", stride is " << stride_ << ", dilation is " << dilation_ << ", group is " << group_
<< ", format is " << format_;
return SUCCESS;
}
Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
if (pad_mode_ == 0) { // 'pad' mode
MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W";
return FAILED;
}
if (pad_mode_ == 1) { // 'same' mode
if ((kernel_size_[0] > stride_[2] || kernel_size_[1] > stride_[3]) && h_strategy > 1) {
MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split H when kernel_size > stride";
return FAILED;
}
if (kernel_size_[0] <= stride_[2] || kernel_size_[1] <= stride_[3]) {
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
if (h_slice_shape % stride_[2] != 0 || w_slice_shape % stride_[3] != 0) {
MS_LOG(ERROR) << name_
<< ": The 'same' mode do not support to split H or W when kernel_size <= stride but slice shape "
"is not divisible by stride ";
return FAILED;
}
}
}
if (pad_mode_ == 2) { // 'valid' mode
if ((kernel_size_[0] > stride_[2] && h_strategy > 1) || (kernel_size_[1] > stride_[3] && w_strategy > 1)) {
MS_LOG(ERROR) << name_ << ": The 'valid' mode do not support to split H or W when kernel_size > stride";
return FAILED;
}
if (kernel_size_[0] <= stride_[2]) {
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
if (h_slice_shape % stride_[2] != 0) {
MS_LOG(ERROR) << name_
<< ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is "
"not divisible by stride ";
return FAILED;
}
}
if (kernel_size_[1] <= stride_[3]) {
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
if (w_slice_shape % stride_[3] != 0) {
MS_LOG(ERROR) << name_
<< ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is "
"not divisible by stride ";
return FAILED;
}
}
}
return SUCCESS;
}
Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED;
}
std::vector<Dimensions> stra = strategy->GetInputDim();
if (stra.size() != 2) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
return FAILED;
}
Dimensions input_strategy = stra[0];
Dimensions weight_strategy = stra[1];
if (input_strategy.size() != 4 || weight_strategy.size() != 4) {
MS_LOG(ERROR) << name_
<< ": The size of input strategy or weight strategy must be 4, but the size of input strategy is "
<< input_strategy.size() << ", the size of weight strategy is " << weight_strategy.size();
return FAILED;
}
if (input_strategy[1] != weight_strategy[1]) {
MS_LOG(ERROR) << name_ << ": The shard num of c-in for input strategy is " << input_strategy[1]
<< ", but the shard num of c-in for weight strategy is " << weight_strategy[1];
return FAILED;
}
if (weight_strategy[2] != 1 || weight_strategy[3] != 1) {
MS_LOG(ERROR) << name_ << ": The kernel size can not be split, but the strategy for kernel size is ("
<< weight_strategy[2] << ", " << weight_strategy[3] << ")";
return FAILED;
}
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
return FAILED;
}
}
if (weight_strategy[0] > 1) {
out_channel_shard_ = true;
new_out_channel_ = out_channel_ / weight_strategy[0];
} else {
out_channel_shard_ = false;
}
return SUCCESS;
}
Status Conv2DInfo::InferDevMatrixShape() {
// the strategy is ((n, i, h, w), (o, i, 1, 1))
// the dev matrix is (n, i, h, w, o)
MS_EXCEPTION_IF_NULL(strategy_);
std::vector<Dimensions> stra = strategy_->GetInputDim();
if (stra.size() != 2) {
MS_LOG(ERROR) << name_ << "The size of strategy must be 2, but got " << stra.size();
return FAILED;
}
dev_matrix_shape_ = stra[0];
dev_matrix_shape_.push_back(stra[1][0]);
return SUCCESS;
}
Status Conv2DInfo::InferTensorMap() {
// input_strategy: ((n, i, h, w), (o, i, 1, 1))
// output_strategy: ((n, o, h, w),)
// dev_matrix: (n, i, h, w, o)
TensorMap input_tensor_map = {4, 3, 2, 1};
TensorMap weight_tensor_map = {0, 3, -1, -1};
TensorMap output_tensor_map = {4, 0, 2, 1};
(void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
(void)inputs_tensor_map_.emplace_back(std::move(weight_tensor_map));
(void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
return SUCCESS;
}
// if the in-channel is split, an AllReduce needs to be inserted
Status Conv2DInfo::InferForwardCommunication() {
forward_op_.clear();
size_t relevant_dim_index = IN_CHANNEL_INDEX;
if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
// if there is repeated calculation and the repeated num is on the left of the dev matrix, the relevant dimension index should be increased by 1
relevant_dim_index += 1;
}
if (dev_matrix_shape_[relevant_dim_index] == MIN_SLICE_NUM) {
MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
return SUCCESS;
}
std::vector<Group> group_list;
if (CreateGroupByDim(relevant_dim_index, &group_list) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group failed";
return FAILED;
}
if (group_list.empty()) {
MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
return SUCCESS;
}
Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
forward_op_.push_back(op);
MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list[0].name();
return SUCCESS;
}
ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
if (!out_channel_shard_) {
return nullptr;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
prim->set_attr(OUT_CHANNEL, MakeValue(new_out_channel_));
return nullptr;
}
void Conv2DInfo::ReComputeBatchSplitFlagList() {
split_flag_list_[0] = true;
split_flag_list_[1] = false;
}
Status Conv2DInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
Strategys strategy = {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}};
StrategyPtr sp = std::make_shared<Strategy>(stage_id, strategy);
std::vector<StrategyPtr> sp_vector;
sp_vector.push_back(sp);
return sp_vector;
}
Status Conv2DInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}
Status Conv2DInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
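The H/W splitting rules enforced by CheckHWStrategy above reduce to a handful of divisibility checks. A rough Python rendering, for illustration only, with pad_mode encoded as in the C++ code (0 = 'pad', 1 = 'same', 2 = 'valid'):

def hw_split_allowed(pad_mode, kernel_size, stride, input_shape, h_strategy, w_strategy):
    # mirrors Conv2DInfo::CheckHWStrategy; shapes and strides follow the NCHW layout used above
    kh, kw = kernel_size
    sh, sw = stride[2], stride[3]
    h, w = input_shape[2], input_shape[3]
    if pad_mode == 0:  # 'pad' mode never allows splitting H or W
        return False
    if pad_mode == 1:  # 'same' mode
        if (kh > sh or kw > sw) and h_strategy > 1:
            return False
        if kh <= sh or kw <= sw:
            if (h // h_strategy) % sh != 0 or (w // w_strategy) % sw != 0:
                return False
        return True
    # 'valid' mode
    if (kh > sh and h_strategy > 1) or (kw > sw and w_strategy > 1):
        return False
    if kh <= sh and (h // h_strategy) % sh != 0:
        return False
    if kw <= sw and (w // w_strategy) % sw != 0:
        return False
    return True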

View File

@@ -0,0 +1,75 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV2D_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV2D_INFO_H_
#include <string>
#include <memory>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
class Conv2DInfo : public OperatorInfo {
public:
Conv2DInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
~Conv2DInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
Status InferForwardCommunication() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
int64_t GetIntAttr(const std::string &attr_name);
std::string GetStringAttr(const std::string &attr_name);
std::vector<int64_t> GetTupleAttr(const std::string &attr_name);
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
private:
int64_t out_channel_ = 1;
std::vector<int64_t> kernel_size_; // two integers
int64_t mode_ = 1;
int64_t pad_mode_ = 0; // "pad": 0; "same": 1; "valid": 2;
std::vector<int64_t> pad_list_; // four integers
std::vector<int64_t> stride_; // four integers
std::vector<int64_t> dilation_; // four integers
int64_t group_ = 1;
std::string format_;
bool out_channel_shard_ = false;
int64_t new_out_channel_ = 1;
};
constexpr size_t IN_CHANNEL_INDEX = 1;
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV2D_INFO_H_
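To make the dev matrix and tensor maps concrete: by the convention used in these ops_info classes, a tensor-map value k refers to the dev-matrix axis counted from the right, and -1 means the dimension is not split, so the slice shape of a tensor is each dimension divided by the corresponding dev-matrix entry. A small illustrative sketch, using the shapes from the new test file further down and the model-parallel strategy ((2, 2, 1, 1), (2, 2, 1, 1)):

def slice_shape(shape, tensor_map, dev_matrix):
    # tensor-map value k refers to dev_matrix[len(dev_matrix) - 1 - k]; -1 means "not split"
    sliced = []
    for dim, k in zip(shape, tensor_map):
        shard = 1 if k == -1 else dev_matrix[len(dev_matrix) - 1 - k]
        sliced.append(dim // shard)
    return sliced

dev = [2, 2, 1, 1, 2]                                    # (n, i, h, w, o) from the strategy above
print(slice_shape([32, 16, 8, 8], [4, 3, 2, 1], dev))    # input  -> [16, 8, 8, 8]
print(slice_shape([8, 16, 2, 2], [0, 3, -1, -1], dev))   # weight -> [4, 8, 2, 2]
print(slice_shape([32, 8, 8, 8], [4, 0, 2, 1], dev))     # output -> [16, 4, 8, 8]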

View File

@@ -1693,5 +1693,21 @@ Status OperatorInfo::GenerateStrategies(int64_t stage_id) {
}
return SUCCESS;
}
std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue) {
MS_EXCEPTION_IF_NULL(sequeue);
std::vector<ValuePtr> ret;
if (!sequeue->isa<ValueTuple>() && !sequeue->isa<ValueList>()) {
MS_LOG(ERROR) << "The arg is not value tuple or value list";
return ret;
}
if (sequeue->isa<ValueTuple>()) {
auto val_tuple = sequeue->cast<ValueTuplePtr>();
return val_tuple->value();
}
auto val = sequeue->cast<ValueListPtr>();
return val->value();
}
} // namespace parallel
} // namespace mindspore

View File

@@ -313,6 +313,7 @@ Status GenerateStrategiesWithBroadcast(int64_t stage_id, const Shapes &inputs_sh
std::vector<StrategyPtr> *sp_vector);
Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue);
} // namespace parallel
} // namespace mindspore

View File

@@ -55,5 +55,6 @@
#include "frontend/parallel/ops_info/topk_info.h"
#include "frontend/parallel/ops_info/scatter_update_info.h"
#include "frontend/parallel/ops_info/virtual_output_info.h"
#include "frontend/parallel/ops_info/conv2d_info.h"
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

View File

@@ -188,6 +188,16 @@ constexpr char DEVICE[] = "Device";
constexpr char PARALLEL_OPTIMIZER_ALLGATHER[] = "parallel_optimizer_allgather";
constexpr char CELLLIST_KEYWORD_PATTERN[] = "-CellList/(\\d+)-";
constexpr char OUT_CHANNEL[] = "out_channel";
constexpr char KERNEL_SIZE[] = "kernel_size";
constexpr char MODE[] = "mode";
constexpr char PAD_MODE[] = "pad_mode";
constexpr char PAD_LIST[] = "pad_list";
constexpr char STRIDE[] = "stride";
constexpr char DILATION[] = "dilation";
constexpr char FORMAT[] = "format";
constexpr char NCHW[] = "NCHW";
// Operator
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
constexpr char GET_TENSOR_SLICE[] = "_GetTensorSlice";

View File

@@ -43,22 +43,6 @@ static std::string AxesToString(const std::vector<int32_t> &shape) {
return str + "]";
}
static std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue) {
MS_EXCEPTION_IF_NULL(sequeue);
std::vector<ValuePtr> ret;
if (!sequeue->isa<ValueTuple>() && !sequeue->isa<ValueList>()) {
MS_LOG(ERROR) << "The arg is not value tuple or value list";
return ret;
}
if (sequeue->isa<ValueTuple>()) {
auto val_tuple = sequeue->cast<ValueTuplePtr>();
return val_tuple->value();
}
auto val = sequeue->cast<ValueListPtr>();
return val->value();
}
void TensorDotInfo::ShowAxes() {
if (axes_tuple_.size()) {
MS_LOG(INFO) << name_ << ": The axes tuple is " << AxesToString(axes_tuple_);

View File

@@ -195,6 +195,7 @@ class ParallelStrategySearchNet(Cell):
kernel_size=5, has_bias=True,
weight_init='ones', bias_init='ones',
pad_mode='valid')
self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
self.scalar = 0.5
self.parameter = Parameter(
initializer(0.5, test_size, dtype=mstype.float32),

View File

@@ -81,12 +81,15 @@ class ResidualBlock(nn.Cell):
out_chls = out_channels // self.expansion
self.conv1 = _conv1x1(in_channels, out_chls, stride=1)
self.conv1.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
self.bn1 = _fused_bn(out_chls, momentum=momentum)
self.conv2 = _conv3x3(out_chls, out_chls, stride=stride)
self.conv2.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
self.bn2 = _fused_bn(out_chls, momentum=momentum)
self.conv3 = _conv1x1(out_chls, out_channels, stride=1)
self.conv3.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
self.bn3 = _fused_bn(out_channels, momentum=momentum)
self.relu = P.ReLU()
@@ -95,6 +98,7 @@ class ResidualBlock(nn.Cell):
if self.downsample:
self.conv_down_sample = _conv1x1(in_channels, out_channels,
stride=stride)
self.conv_down_sample.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
self.bn_down_sample = _fused_bn(out_channels, momentum=momentum)
elif self.stride != 1:
self.maxpool_down = nn.MaxPool2d(kernel_size=1, stride=2, pad_mode='same')
@@ -143,6 +147,7 @@ class ResNet(nn.Cell):
"layer_num, inchannel, outchannel list must be 4!")
self.conv1 = _conv7x7(3, 64, stride=2)
self.conv1.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
self.bn1 = _fused_bn(64)
self.relu = P.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

View File

@@ -70,6 +70,7 @@ class Net(nn.Cell):
super().__init__()
self.conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=1, stride=1, pad_mode='valid',
has_bias=True, weight_init='ones', bias_init='ones')
self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((1, 1, 1, 8),))
self.flat = nn.Flatten()

View File

@@ -0,0 +1,74 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
class Net(Cell):
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
strategy1=None, strategy2=None):
super().__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride).shard(strategy1)
self.neg = P.Neg().shard(strategy2)
self.conv2d_weight = Parameter(conv2d_weight, "w1")
def construct(self, x, b):
out = self.conv2d(x, self.conv2d_weight)
out = self.neg(out)
return out
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
def compile_net(net):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_executor.compile(train_net, _x, _b)
context.reset_auto_parallel_context()
def test_conv2d_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_conv2d_model_parallel1():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((8, 1, 1, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_conv2d_model_parallel2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
strategy2 = ((32, 1, 1, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
compile_net(net)
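A hypothetical negative case could also check that a strategy rejected by CheckStrategy (for example, one that splits the kernel dimensions of the weight) fails to compile; this assumes, as in other parallel tests, that the rejection surfaces as a RuntimeError during compilation:

import pytest

def test_conv2d_invalid_kernel_split():
    # hypothetical: splitting the kernel H/W dimensions should be rejected by CheckStrategy
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 1, 1, 1), (1, 1, 2, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)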

View File

@@ -535,6 +535,7 @@ class ParallelReduceMeanNet(nn.Cell):
self.conv = nn.Conv2d(in_channels=conv_in_channel, out_channels=conv_out_channel,
kernel_size=1, stride=1, pad_mode='valid', has_bias=True,
weight_init='ones', bias_init='ones')
self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
self.reduce_mean = P.ReduceMean(keep_dims=reducemean_keep_dims)
self.flat = nn.Flatten()
self.reducemean_axis = reducemean_axis

View File

@@ -223,6 +223,7 @@ def test_reshape_unexpand_7():
self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=5, has_bias=True, weight_init='ones',
bias_init='ones', pad_mode='valid')
self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
self.softmax = nn.Softmax(axis=axis)
self.relu = nn.ReLU()
self.reshape = P.Reshape()