forked from mindspore-Ecosystem/mindspore
!17999 add parallel operator for conv2d
Merge pull request !17999 from yangzhenzhang/add-parallel-operator-for-conv2d
commit dc9720cbb2

@@ -196,6 +196,7 @@ REGISTER(GatherNdInfo);
REGISTER(TopKInfo);
REGISTER(ScatterUpdateInfo);
REGISTER(VirtualOutputInfo);
REGISTER(Conv2DInfo);
}  // namespace parallel
}  // namespace mindspore

@@ -0,0 +1,381 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/conv2d_info.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"

namespace mindspore {
namespace parallel {
int64_t Conv2DInfo::GetIntAttr(const std::string &attr_name) {
  auto attr_iter = attrs_.find(attr_name);
  if (attr_iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Can not find the attribute " << attr_name;
    return -1;
  }

  MS_EXCEPTION_IF_NULL(attr_iter->second);
  if (!attr_iter->second->isa<Int64Imm>()) {
    MS_LOG(ERROR) << name_ << ": The value of " << attr_name << " is not int";
    return -1;
  }

  return attr_iter->second->cast<Int64ImmPtr>()->value();
}

std::string Conv2DInfo::GetStringAttr(const std::string &attr_name) {
  std::string string_attr;
  auto attr_iter = attrs_.find(attr_name);
  if (attr_iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Can not find the attribute " << attr_name;
    return string_attr;
  }

  MS_EXCEPTION_IF_NULL(attr_iter->second);
  if (!attr_iter->second->isa<StringImm>()) {
    MS_LOG(ERROR) << name_ << ": The value of " << attr_name << " is not a string";
    return string_attr;
  }

  string_attr = attr_iter->second->cast<StringImmPtr>()->value();
  return string_attr;
}

std::vector<int64_t> Conv2DInfo::GetTupleAttr(const std::string &attr_name) {
  std::vector<int64_t> tuple_attr;
  auto tuple_attr_iter = attrs_.find(attr_name);
  if (tuple_attr_iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Can not find the attribute " << attr_name;
    return tuple_attr;
  }

  MS_EXCEPTION_IF_NULL(tuple_attr_iter->second);
  tuple_attr = GetValue<std::vector<int64_t>>(tuple_attr_iter->second);

  return tuple_attr;
}

Status Conv2DInfo::GetAttrs() {
  // out_channel
  out_channel_ = GetIntAttr(OUT_CHANNEL);
  if (out_channel_ <= 0) {
    MS_LOG(ERROR) << name_ << ": The attr out_channel is invalid";
    return FAILED;
  }

  // kernel_size
  auto kernel_size_iter = attrs_.find(KERNEL_SIZE);
  if (kernel_size_iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Can not find the attribute " << KERNEL_SIZE;
    return FAILED;
  }

  MS_EXCEPTION_IF_NULL(kernel_size_iter->second);
  if (kernel_size_iter->second->isa<Int64Imm>()) {
    int64_t kernel_size = kernel_size_iter->second->cast<Int64ImmPtr>()->value();
    kernel_size_ = {kernel_size, kernel_size};
  } else if (kernel_size_iter->second->isa<ValueTuple>() || kernel_size_iter->second->isa<ValueList>()) {
    kernel_size_ = GetValue<std::vector<int64_t>>(kernel_size_iter->second);
    if (kernel_size_.size() != 2) {
      MS_LOG(ERROR) << name_ << ": The size of the kernel_size tuple must be 2, but got " << kernel_size_.size();
      return FAILED;
    }
  } else {
    MS_LOG(ERROR) << name_ << ": The kernel_size must be an int or a tuple";
    return FAILED;
  }

  // mode
  mode_ = GetIntAttr(MODE);
  if (mode_ != 1) {
    MS_LOG(ERROR) << name_ << ": The mode must be 1, but got " << mode_;
    return FAILED;
  }

  // pad_mode
  pad_mode_ = GetIntAttr(PAD_MODE);
  if (pad_mode_ < 0 || pad_mode_ > 2) {
    MS_LOG(ERROR) << name_ << ": The pad_mode must be in the range [0, 2], but got " << pad_mode_;
    return FAILED;
  }

  // pad_list
  pad_list_ = GetTupleAttr(PAD_LIST);
  if (pad_list_.size() != 4) {
    MS_LOG(ERROR) << name_ << ": The size of pad_list must be 4, but got " << pad_list_.size();
    return FAILED;
  }

  // stride
  stride_ = GetTupleAttr(STRIDE);
  if (stride_.size() != 4) {
    MS_LOG(ERROR) << name_ << ": The size of stride must be 4, but got " << stride_.size();
    return FAILED;
  }

  if (stride_[0] != 1 || stride_[1] != 1) {
    MS_LOG(ERROR) << name_ << ": The first two elements of stride must be 1, but got (" << stride_[0] << ", "
                  << stride_[1] << ")";
    return FAILED;
  }

  // dilation
  dilation_ = GetTupleAttr(DILATION);
  if (dilation_.size() != 4) {
    MS_LOG(ERROR) << name_ << ": The size of dilation must be 4, but got " << dilation_.size();
    return FAILED;
  }

  // group
  group_ = GetIntAttr(GROUP);
  if (group_ != 1) {
    MS_LOG(ERROR) << name_ << ": The group must be 1, but got " << group_;
    return FAILED;
  }

  // format
  format_ = GetStringAttr(FORMAT);
  if (format_ != NCHW) {
    MS_LOG(ERROR) << name_ << ": The format must be 'NCHW', but got " << format_;
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": The out channel is " << out_channel_ << ", kernel size is " << kernel_size_
               << ", mode is " << mode_ << ", pad mode is " << pad_mode_ << ", pad list is " << pad_list_
               << ", stride is " << stride_ << ", dilation is " << dilation_ << ", group is " << group_
               << ", format is " << format_;

  return SUCCESS;
}

Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
  if (pad_mode_ == 0) {  // 'pad' mode
    MS_LOG(ERROR) << name_ << ": The 'pad' mode does not support splitting H or W";
    return FAILED;
  }

  if (pad_mode_ == 1) {  // 'same' mode
    if ((kernel_size_[0] > stride_[2] || kernel_size_[1] > stride_[3]) && h_strategy > 1) {
      MS_LOG(ERROR) << name_ << ": The 'same' mode does not support splitting H when kernel_size > stride";
      return FAILED;
    }

    if (kernel_size_[0] <= stride_[2] || kernel_size_[1] <= stride_[3]) {
      int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
      int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
      if (h_slice_shape % stride_[2] != 0 || w_slice_shape % stride_[3] != 0) {
        MS_LOG(ERROR) << name_
                      << ": The 'same' mode does not support splitting H or W when kernel_size <= stride but the "
                         "slice shape is not divisible by the stride";
        return FAILED;
      }
    }
  }

  if (pad_mode_ == 2) {  // 'valid' mode
    if ((kernel_size_[0] > stride_[2] && h_strategy > 1) || (kernel_size_[1] > stride_[3] && w_strategy > 1)) {
      MS_LOG(ERROR) << name_ << ": The 'valid' mode does not support splitting H or W when kernel_size > stride";
      return FAILED;
    }

    if (kernel_size_[0] <= stride_[2]) {
      int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
      if (h_slice_shape % stride_[2] != 0) {
        MS_LOG(ERROR) << name_
                      << ": The 'valid' mode does not support splitting H when kernel_size <= stride but the slice "
                         "shape is not divisible by the stride";
        return FAILED;
      }
    }

    if (kernel_size_[1] <= stride_[3]) {
      int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
      if (w_slice_shape % stride_[3] != 0) {
        MS_LOG(ERROR) << name_
                      << ": The 'valid' mode does not support splitting W when kernel_size <= stride but the slice "
                         "shape is not divisible by the stride";
        return FAILED;
      }
    }
  }

  return SUCCESS;
}

Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy";
    return FAILED;
  }

  std::vector<Dimensions> stra = strategy->GetInputDim();
  if (stra.size() != 2) {
    MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
    return FAILED;
  }

  Dimensions input_strategy = stra[0];
  Dimensions weight_strategy = stra[1];
  if (input_strategy.size() != 4 || weight_strategy.size() != 4) {
    MS_LOG(ERROR) << name_
                  << ": The size of input strategy or weight strategy must be 4, but the size of input strategy is "
                  << input_strategy.size() << ", the size of weight strategy is " << weight_strategy.size();
    return FAILED;
  }

  if (input_strategy[1] != weight_strategy[1]) {
    MS_LOG(ERROR) << name_ << ": The shard num of c-in for the input strategy is " << input_strategy[1]
                  << ", but the shard num of c-in for the weight strategy is " << weight_strategy[1];
    return FAILED;
  }

  if (weight_strategy[2] != 1 || weight_strategy[3] != 1) {
    MS_LOG(ERROR) << name_ << ": The kernel size can not be split, but the strategy for the kernel size is ("
                  << weight_strategy[2] << ", " << weight_strategy[3] << ")";
    return FAILED;
  }

  if (input_strategy[2] != 1 || input_strategy[3] != 1) {
    if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
      return FAILED;
    }
  }

  if (weight_strategy[0] > 1) {
    out_channel_shard_ = true;
    // the per-shard out_channel is the original out_channel divided by the shard num of the out-channel dimension
    new_out_channel_ = out_channel_ / weight_strategy[0];
  } else {
    out_channel_shard_ = false;
  }

  return SUCCESS;
}

Status Conv2DInfo::InferDevMatrixShape() {
  // the strategy is ((n, i, h, w), (o, i, 1, 1))
  // the dev matrix is (n, i, h, w, o)
  MS_EXCEPTION_IF_NULL(strategy_);
  std::vector<Dimensions> stra = strategy_->GetInputDim();
  if (stra.size() != 2) {
    MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
    return FAILED;
  }

  dev_matrix_shape_ = stra[0];
  dev_matrix_shape_.push_back(stra[1][0]);
  return SUCCESS;
}

Status Conv2DInfo::InferTensorMap() {
  // input_strategy: ((n, i, h, w), (o, i, 1, 1))
  // output_strategy: ((n, o, h, w),)
  // dev_matrix: (n, i, h, w, o)
  TensorMap input_tensor_map = {4, 3, 2, 1};
  TensorMap weight_tensor_map = {0, 3, -1, -1};
  TensorMap output_tensor_map = {4, 0, 2, 1};

  (void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
  (void)inputs_tensor_map_.emplace_back(std::move(weight_tensor_map));
  (void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
  return SUCCESS;
}

// if the in-channel dimension is split, an AllReduce needs to be inserted
Status Conv2DInfo::InferForwardCommunication() {
  forward_op_.clear();
  size_t relevant_dim_index = IN_CHANNEL_INDEX;
  if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
    // if there is repeated calculation and the repeated num is placed on the left of the dev matrix,
    // the index of the relevant dimension should be increased by 1
    relevant_dim_index += 1;
  }

  if (dev_matrix_shape_[relevant_dim_index] == MIN_SLICE_NUM) {
    MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
    return SUCCESS;
  }

  std::vector<Group> group_list;
  if (CreateGroupByDim(relevant_dim_index, &group_list) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group failed";
    return FAILED;
  }

  if (group_list.empty()) {
    MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
    return SUCCESS;
  }

  Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
  forward_op_.push_back(op);
  MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list[0].name();

  return SUCCESS;
}

ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
  if (!out_channel_shard_) {
    return nullptr;
  }

  // when the out-channel is sharded, rewrite the primitive's out_channel attr to the per-shard value
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  prim->set_attr(OUT_CHANNEL, MakeValue(new_out_channel_));
  return nullptr;
}

void Conv2DInfo::ReComputeBatchSplitFlagList() {
  split_flag_list_[0] = true;
  split_flag_list_[1] = false;
}

Status Conv2DInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
  Strategys strategy = {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}};
  StrategyPtr sp = std::make_shared<Strategy>(stage_id, strategy);
  std::vector<StrategyPtr> sp_vector;
  sp_vector.push_back(sp);
  return sp_vector;
}

Status Conv2DInfo::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init success.";
  return SUCCESS;
}

Status Conv2DInfo::InitForCostModel(const StrategyPtr &strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": Init for cost model success.";
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore

@@ -0,0 +1,75 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV2D_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV2D_INFO_H_

#include <string>
#include <memory>
#include <unordered_map>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
class Conv2DInfo : public OperatorInfo {
 public:
  Conv2DInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
  ~Conv2DInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;
  Status InitForCostModel(const StrategyPtr &strategy) override;
  std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
  Status SetCostUnderStrategy(const StrategyPtr &) override;
  void ReComputeBatchSplitFlagList() override;

 protected:
  Status GetAttrs() override;
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
  Status InferForwardCommunication() override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  int64_t GetIntAttr(const std::string &attr_name);
  std::string GetStringAttr(const std::string &attr_name);
  std::vector<int64_t> GetTupleAttr(const std::string &attr_name);
  ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;

 private:
  int64_t out_channel_ = 1;
  std::vector<int64_t> kernel_size_;  // two integers
  int64_t mode_ = 1;
  int64_t pad_mode_ = 0;  // "pad": 0; "same": 1; "valid": 2;
  std::vector<int64_t> pad_list_;  // four integers
  std::vector<int64_t> stride_;    // four integers
  std::vector<int64_t> dilation_;  // four integers
  int64_t group_ = 1;
  std::string format_;
  bool out_channel_shard_ = false;
  int64_t new_out_channel_ = 1;
};

constexpr size_t IN_CHANNEL_INDEX = 1;
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV2D_INFO_H_

@@ -1693,5 +1693,21 @@ Status OperatorInfo::GenerateStrategies(int64_t stage_id) {
  }
  return SUCCESS;
}

std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue) {
  MS_EXCEPTION_IF_NULL(sequeue);
  std::vector<ValuePtr> ret;
  if (!sequeue->isa<ValueTuple>() && !sequeue->isa<ValueList>()) {
    MS_LOG(ERROR) << "The arg is not value tuple or value list";
    return ret;
  }

  if (sequeue->isa<ValueTuple>()) {
    auto val_tuple = sequeue->cast<ValueTuplePtr>();
    return val_tuple->value();
  }
  auto val = sequeue->cast<ValueListPtr>();
  return val->value();
}
}  // namespace parallel
}  // namespace mindspore

@@ -313,6 +313,7 @@ Status GenerateStrategiesWithBroadcast(int64_t stage_id, const Shapes &inputs_sh
                                       std::vector<StrategyPtr> *sp_vector);

Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue);
}  // namespace parallel
}  // namespace mindspore

@@ -55,5 +55,6 @@
#include "frontend/parallel/ops_info/topk_info.h"
#include "frontend/parallel/ops_info/scatter_update_info.h"
#include "frontend/parallel/ops_info/virtual_output_info.h"
#include "frontend/parallel/ops_info/conv2d_info.h"

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

@@ -188,6 +188,16 @@ constexpr char DEVICE[] = "Device";
constexpr char PARALLEL_OPTIMIZER_ALLGATHER[] = "parallel_optimizer_allgather";
constexpr char CELLLIST_KEYWORD_PATTERN[] = "-CellList/(\\d+)-";

constexpr char OUT_CHANNEL[] = "out_channel";
constexpr char KERNEL_SIZE[] = "kernel_size";
constexpr char MODE[] = "mode";
constexpr char PAD_MODE[] = "pad_mode";
constexpr char PAD_LIST[] = "pad_list";
constexpr char STRIDE[] = "stride";
constexpr char DILATION[] = "dilation";
constexpr char FORMAT[] = "format";
constexpr char NCHW[] = "NCHW";

// Operator
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
constexpr char GET_TENSOR_SLICE[] = "_GetTensorSlice";

@@ -43,22 +43,6 @@ static std::string AxesToString(const std::vector<int32_t> &shape) {
  return str + "]";
}

static std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue) {
  MS_EXCEPTION_IF_NULL(sequeue);
  std::vector<ValuePtr> ret;
  if (!sequeue->isa<ValueTuple>() && !sequeue->isa<ValueList>()) {
    MS_LOG(ERROR) << "The arg is not value tuple or value list";
    return ret;
  }

  if (sequeue->isa<ValueTuple>()) {
    auto val_tuple = sequeue->cast<ValueTuplePtr>();
    return val_tuple->value();
  }
  auto val = sequeue->cast<ValueListPtr>();
  return val->value();
}

void TensorDotInfo::ShowAxes() {
  if (axes_tuple_.size()) {
    MS_LOG(INFO) << name_ << ": The axes tuple is " << AxesToString(axes_tuple_);

@@ -195,6 +195,7 @@ class ParallelStrategySearchNet(Cell):
                              kernel_size=5, has_bias=True,
                              weight_init='ones', bias_init='ones',
                              pad_mode='valid')
        self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.scalar = 0.5
        self.parameter = Parameter(
            initializer(0.5, test_size, dtype=mstype.float32),

@@ -81,12 +81,15 @@ class ResidualBlock(nn.Cell):

        out_chls = out_channels // self.expansion
        self.conv1 = _conv1x1(in_channels, out_chls, stride=1)
        self.conv1.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.bn1 = _fused_bn(out_chls, momentum=momentum)

        self.conv2 = _conv3x3(out_chls, out_chls, stride=stride)
        self.conv2.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.bn2 = _fused_bn(out_chls, momentum=momentum)

        self.conv3 = _conv1x1(out_chls, out_channels, stride=1)
        self.conv3.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.bn3 = _fused_bn(out_channels, momentum=momentum)

        self.relu = P.ReLU()

@@ -95,6 +98,7 @@ class ResidualBlock(nn.Cell):
        if self.downsample:
            self.conv_down_sample = _conv1x1(in_channels, out_channels,
                                             stride=stride)
            self.conv_down_sample.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
            self.bn_down_sample = _fused_bn(out_channels, momentum=momentum)
        elif self.stride != 1:
            self.maxpool_down = nn.MaxPool2d(kernel_size=1, stride=2, pad_mode='same')

@@ -143,6 +147,7 @@ class ResNet(nn.Cell):
                             "layer_num, inchannel, outchannel list must be 4!")

        self.conv1 = _conv7x7(3, 64, stride=2)
        self.conv1.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.bn1 = _fused_bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

@@ -70,6 +70,7 @@ class Net(nn.Cell):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=1, stride=1, pad_mode='valid',
                              has_bias=True, weight_init='ones', bias_init='ones')
        self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((1, 1, 1, 8),))
        self.flat = nn.Flatten()

@@ -0,0 +1,74 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
                 strategy1=None, strategy2=None):
        super().__init__()
        self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
                               pad_mode=pad_mode, stride=stride).shard(strategy1)
        self.neg = P.Neg().shard(strategy2)
        self.conv2d_weight = Parameter(conv2d_weight, "w1")

    def construct(self, x, b):
        out = self.conv2d(x, self.conv2d_weight)
        out = self.neg(out)
        return out


_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)


def compile_net(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_conv2d_data_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_model_parallel1():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_model_parallel2():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
    strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
    strategy2 = ((32, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)

@@ -535,6 +535,7 @@ class ParallelReduceMeanNet(nn.Cell):
        self.conv = nn.Conv2d(in_channels=conv_in_channel, out_channels=conv_out_channel,
                              kernel_size=1, stride=1, pad_mode='valid', has_bias=True,
                              weight_init='ones', bias_init='ones')
        self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
        self.reduce_mean = P.ReduceMean(keep_dims=reducemean_keep_dims)
        self.flat = nn.Flatten()
        self.reducemean_axis = reducemean_axis

@@ -223,6 +223,7 @@ def test_reshape_unexpand_7():
            self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                                  kernel_size=5, has_bias=True, weight_init='ones',
                                  bias_init='ones', pad_mode='valid')
            self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
            self.softmax = nn.Softmax(axis=axis)
            self.relu = nn.ReLU()
            self.reshape = P.Reshape()