!42918 add conv3d parallel op
Merge pull request !42918 from yangzhenzhang/add-conv3d-parallel-op
commit 0b86281d0a
mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
@@ -33,14 +33,52 @@
namespace mindspore {
namespace parallel {
+Status Conv2DInfo::CheckAttrsBase() {
+  if (format_ != NCHW) {
+    MS_LOG(ERROR) << name_ << ": The format must be 'NCHW', but got " << format_;
+    return FAILED;
+  }
+
+  if (kernel_size_.size() != 2) {
+    MS_LOG(ERROR) << name_ << ": The size of the kernel_size tuple must be 2, but got " << kernel_size_.size();
+    return FAILED;
+  }
+
+  if (pad_list_.size() != 4) {
+    MS_LOG(ERROR) << name_ << ": The size of pad_list must be 4, but got " << pad_list_.size();
+    return FAILED;
+  }
+
+  if (stride_.size() != 4) {
+    MS_LOG(ERROR) << name_ << ": The size of stride must be 4, but got " << stride_.size();
+    return FAILED;
+  }
+
+  if (stride_[0] != 1 || stride_[1] != 1) {
+    MS_LOG(ERROR) << name_ << ": The first two elements of stride must be 1, but got (" << stride_[0] << ", "
+                  << stride_[1] << ")";
+    return FAILED;
+  }
+
+  if (dilation_.size() != 4) {
+    MS_LOG(ERROR) << name_ << ": The size of dilation must be 4, but got " << dilation_.size();
+    return FAILED;
+  }
+  MS_LOG(INFO) << name_ << ": The out channel is " << out_channel_ << ", kernel size is " << kernel_size_
+               << ", mode is " << mode_ << ", pad mode is " << pad_mode_ << ", pad list is " << pad_list_
+               << ", stride is " << stride_ << ", dilation is " << dilation_ << ", group is " << group_
+               << ", format is " << format_ << ", the kernel size use dilation is " << kernel_size_use_dilation_;
+  return SUCCESS;
+}
+
+std::vector<int64_t> Conv2DInfo::GetStrideAttr() { return GetTupleIntAttr(STRIDE); }
+
+std::vector<int64_t> Conv2DInfo::GetDilationAttr() { return GetTupleIntAttr(DILATION); }
+
Status Conv2DInfo::GetAttrsBase() {
  // format
  format_ = GetStringAttr(FORMAT);

  // out_channel
  out_channel_ = GetIntAttr(OUT_CHANNEL);
  if (out_channel_ <= 0) {
@@ -58,13 +96,9 @@ Status Conv2DInfo::GetAttrsBase() {
  MS_EXCEPTION_IF_NULL(kernel_size_iter->second);
  if (kernel_size_iter->second->isa<Int64Imm>()) {
    int64_t kernel_size = kernel_size_iter->second->cast<Int64ImmPtr>()->value();
-    kernel_size_ = {kernel_size, kernel_size};
+    kernel_size_ = Shape(inputs_shape_[1].size() - 2, kernel_size);
  } else if (kernel_size_iter->second->isa<ValueTuple>() || kernel_size_iter->second->isa<ValueList>()) {
    kernel_size_ = GetValue<std::vector<int64_t>>(kernel_size_iter->second);
-    if (kernel_size_.size() != 2) {
-      MS_LOG(ERROR) << name_ << ": The size of kernel_size'tuple must be 2, but got " << kernel_size_.size();
-      return FAILED;
-    }
  } else {
    MS_LOG(ERROR) << name_ << ": The kernel_size must be int or tuple";
    return FAILED;
@@ -86,30 +120,12 @@ Status Conv2DInfo::GetAttrsBase() {
  // pad_list
  pad_list_ = GetTupleIntAttr(PAD_LIST);
-  if (pad_list_.size() != 4) {
-    MS_LOG(ERROR) << name_ << ": The size of pad_list must be 4, but got " << pad_list_.size();
-    return FAILED;
-  }

  // stride
-  stride_ = GetTupleIntAttr(STRIDE);
-  if (stride_.size() != 4) {
-    MS_LOG(ERROR) << name_ << ": The size of stride must be 4, but got " << stride_.size();
-    return FAILED;
-  }
-
-  if (stride_[0] != 1 || stride_[1] != 1) {
-    MS_LOG(ERROR) << name_ << ": The first two elements of stride must be 1, but got (" << stride_[0] << ", "
-                  << stride_[1] << ")";
-    return FAILED;
-  }
+  stride_ = GetStrideAttr();

  // dilation
-  dilation_ = GetTupleIntAttr(DILATION);
-  if (dilation_.size() != 4) {
-    MS_LOG(ERROR) << name_ << ": The size of dilation must be 4, but got " << dilation_.size();
-    return FAILED;
-  }
+  dilation_ = GetDilationAttr();

  for (size_t i = 0; i < kernel_size_.size(); ++i) {
    kernel_size_use_dilation_.push_back(dilation_[i + 2] * (kernel_size_[i] - 1) + 1);
@@ -118,28 +134,57 @@ Status Conv2DInfo::GetAttrsBase() {
  // group
  group_ = GetIntAttr(GROUP);

-  MS_LOG(INFO) << name_ << ": The out channel is " << out_channel_ << ", kernel size is " << kernel_size_
-               << ", mode is " << mode_ << ", pad mode is " << pad_mode_ << ", pad list is " << pad_list_
-               << ", stride is " << stride_ << ", dilation is " << dilation_ << ", group is " << group_
-               << ", format is " << format_ << ", the kernel size use dilation is " << kernel_size_use_dilation_;
+  infer_strategy_mode_ = INDIVIDUAL_MODE;
  return SUCCESS;
}

-Status Conv2DInfo::GetAttrs() { return GetAttrsBase(); }
+void Conv2DInfo::AdjustPadList() {
+  // adjust the pad list for the 'pad' mode:
+  // because output_len = (in_len + pad_all - k) / s, the useless_len = (in_len + pad_all - k) % s,
+  // and the bottom_pad/right_pad needs to be adjusted if useless_len != 0
+  if (pad_mode_ != 0) {
+    return;
+  }
+
+  int64_t useless_len_2nd_dim = (inputs_shape_[0][2] + pad_list_[0] + pad_list_[1] - kernel_size_[1]) % stride_[2];
+  int64_t useless_len_3rd_dim = (inputs_shape_[0][3] + pad_list_[2] + pad_list_[3] - kernel_size_[2]) % stride_[3];
+  if (useless_len_2nd_dim == 0 && useless_len_3rd_dim == 0) {
+    return;
+  }
+
+  if (useless_len_2nd_dim > pad_list_[1]) {
+    MS_LOG(EXCEPTION) << name_ << ": The useless len for the 2nd dim (" << useless_len_2nd_dim
+                      << ") can not be larger than pad_list[1] (" << pad_list_[1] << ")";
+  }
+  if (useless_len_3rd_dim > pad_list_[3]) {
+    MS_LOG(EXCEPTION) << name_ << ": The useless len for the 3rd dim (" << useless_len_3rd_dim
+                      << ") can not be larger than pad_list[3] (" << pad_list_[3] << ")";
+  }
+  pad_list_[1] -= useless_len_2nd_dim;
+  pad_list_[3] -= useless_len_3rd_dim;
+  MS_LOG(INFO) << name_ << ": After adjusting, the pad_list is " << pad_list_;
+}
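// Illustrative arithmetic for the adjustment above (assumed numbers, not part of the patch): with
// in_len = 10, pads (1, 1), an effective kernel of 3 and stride 2, output_len = (10 + 2 - 3) / 2 = 4
// and useless_len = (10 + 2 - 3) % 2 = 1, so pad_list[1] is reduced from 1 to 0, trimming padding
// that no kernel window ever reads.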

+Status Conv2DInfo::GetAttrs() {
+  if (GetAttrsBase() != SUCCESS || CheckAttrsBase() != SUCCESS) {
+    return FAILED;
+  }
+
+  return SUCCESS;
+}

Status Conv2DInfo::CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) const {
  if (outputs_shape_[0][2] % h_strategy != 0) {
    FILTER_LOG(is_auto_parallel_) << name_
-                                  << ": Do not support to split h dimension when out_shape of h dimension is not"
-                                     " divisible by strategy of h dimension";
+                                  << ": Do not support to split the 2nd dimension when out_shape of the 2nd dimension"
+                                     " is not divisible by the strategy of the 2nd dimension";
    return FAILED;
  }

  if (outputs_shape_[0][3] % w_strategy != 0) {
    FILTER_LOG(is_auto_parallel_) << name_
-                                  << ": Do not support to split w dimension when out_shape of w dimension is not"
-                                     " divisible by strategy of w dimension";
+                                  << ": Do not support to split the 3rd dimension when out_shape of the 3rd dimension"
+                                     " is not divisible by the strategy of the 3rd dimension";
    return FAILED;
  }
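// Worked check (illustrative numbers): an output length of 28 on the 2nd dimension with h_strategy = 4
// passes since 28 % 4 == 0, while h_strategy = 3 is rejected because 28 rows cannot be split evenly.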

@@ -153,7 +198,7 @@ Status Conv2DInfo::CheckHWStrategyValidMode
  if ((kernel_size_use_dilation_[0] > stride_[2] && h_strategy > 1) ||
      (kernel_size_use_dilation_[1] > stride_[3] && w_strategy > 1)) {
    FILTER_LOG(is_auto_parallel_) << name_
-                                  << ": The 'valid' mode do not support to split H or W when"
+                                  << ": The 'valid' mode does not support splitting the 2nd or 3rd dimension when"
                                     " kernel_size_use_dilation_ > stride";
    return FAILED;
  }

@@ -161,7 +206,7 @@ Status Conv2DInfo::CheckHWStrategyValidMode
  if (kernel_size_use_dilation_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0) {
    FILTER_LOG(is_auto_parallel_)
      << name_
-      << ": The 'valid' mode do not support to split H when kernel_size_use_dilation_ <= stride but slice shape is "
+      << ": The 'valid' mode does not support splitting the 2nd dimension when kernel_size_use_dilation_ <= stride but slice shape is "
         "not divisible by stride ";
    return FAILED;
  }

@@ -169,7 +214,7 @@ Status Conv2DInfo::CheckHWStrategyValidMode
  if (kernel_size_use_dilation_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0) {
    FILTER_LOG(is_auto_parallel_)
      << name_
-      << ": The 'valid' mode do not support to split W when kernel_size_use_dilation_ <= stride but slice shape is "
+      << ": The 'valid' mode does not support splitting the 3rd dimension when kernel_size_use_dilation_ <= stride but slice shape is "
         "not divisible by stride ";
    return FAILED;
  }

@@ -177,13 +222,13 @@ Status Conv2DInfo::CheckHWStrategyValidMode
  return SUCCESS;
}
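// Example of the 'valid'-mode rules above (assumed numbers): with kernel_size_use_dilation_ = 3 and
// stride = 2, a split is rejected outright since 3 > 2; with kernel 2, stride 2 and a slice shape of
// 8, 8 % 2 == 0 and the split is accepted, while a slice shape of 9 would fail the divisibility test.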

-Status Conv2DInfo::CheckHWStrategyPadModeByDimension(int64_t strategy, const std::string &dimension) {
+Status Conv2DInfo::CheckHWStrategyPadModeByDimension(int64_t strategy, int64_t dimension_id) {
  if (strategy == 1) {
    return SUCCESS;
  }

  int64_t h_or_w_input_shape = 0, h_or_w_output_shape = 0, h_or_w_kernel_size = 0, h_or_w_stride = 0, pad_all = 0;
-  if (dimension == H_DIMENSION) {
+  if (dimension_id == 2) {
    h_or_w_input_shape = inputs_shape_[0][2];
    h_or_w_output_shape = outputs_shape_[0][2];
    h_or_w_kernel_size = kernel_size_use_dilation_[0];

@@ -198,25 +243,27 @@ Status Conv2DInfo::CheckHWStrategyPadModeByDimension
  }

  if ((h_or_w_input_shape + pad_all - h_or_w_kernel_size) % h_or_w_stride != 0) {
-    FILTER_LOG(is_auto_parallel_) << name_ << ": The 'pad' or 'same' mode do not support to split " << dimension
-                                  << " when input_shape + pad_all - k is not divisible by stride ";
+    FILTER_LOG(is_auto_parallel_) << name_ << ": The 'pad' or 'same' mode does not support splitting the " << dimension_id
+                                  << "th dimension when input_shape + pad_all - k is not divisible by stride ";
    return FAILED;
  }

  if ((h_or_w_output_shape * h_or_w_stride - h_or_w_input_shape) % strategy != 0) {
-    FILTER_LOG(is_auto_parallel_) << name_ << ": The 'pad' or 'same' mode do not support to split " << dimension
-                                  << " when output_shape * s - input_shape is not divisible by stride ";
+    FILTER_LOG(is_auto_parallel_) << name_ << ": The 'pad' or 'same' mode does not support splitting the " << dimension_id
+                                  << "th dimension when output_shape * s - input_shape is not divisible by stride ";
    return FAILED;
  }
  return SUCCESS;
}
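// Illustrative pad/same-mode arithmetic (assumed numbers): input = 16, pad_all = 2, k = 3, stride = 1
// gives (16 + 2 - 3) % 1 == 0, and output = 16 so (16 * 1 - 16) % strategy == 0 for any strategy,
// hence such a split passes both checks.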

Status Conv2DInfo::CheckHWStrategyPadMode(int64_t h_strategy, int64_t w_strategy) {
-  if (CheckHWStrategyPadModeByDimension(h_strategy, H_DIMENSION) != SUCCESS) {
+  AdjustPadList();
+
+  if (CheckHWStrategyPadModeByDimension(h_strategy, 2) != SUCCESS) {
    return FAILED;
  }

-  if (CheckHWStrategyPadModeByDimension(w_strategy, W_DIMENSION) != SUCCESS) {
+  if (CheckHWStrategyPadModeByDimension(w_strategy, 3) != SUCCESS) {
    return FAILED;
  }
  return SUCCESS;

@@ -251,21 +298,7 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
    return FAILED;
  }

-  Dimensions input_strategy = stra[0];
  Dimensions weight_strategy = stra[1];
-  if (input_strategy.size() != 4 || weight_strategy.size() != 4) {
-    MS_LOG(ERROR) << name_
-                  << ": The size of input strategy or weight strategy must be 4, but the size of input strategy is "
-                  << input_strategy.size() << ", the size of weight strategy is " << weight_strategy.size();
-    return FAILED;
-  }
-
-  if (weight_strategy[2] != 1 || weight_strategy[3] != 1) {
-    MS_LOG(ERROR) << name_ << ": The kernel size can not be split, but the strategy for kernel size is ("
-                  << weight_strategy[2] << ", " << weight_strategy[3] << ")";
-    return FAILED;
-  }
-
  if (weight_strategy[0] > 1) {
    out_channel_shard_ = true;
    new_out_channel_ = out_channel_ / weight_strategy[0];

@@ -293,12 +326,25 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
  std::vector<Dimensions> stra = strategy->GetInputDim();
  Dimensions input_strategy = stra[0];
  Dimensions weight_strategy = stra[1];
+  if (input_strategy.size() != 4 || weight_strategy.size() != 4) {
+    MS_LOG(ERROR) << name_
+                  << ": The size of input strategy or weight strategy must be 4, but the size of input strategy is "
+                  << input_strategy.size() << ", the size of weight strategy is " << weight_strategy.size();
+    return FAILED;
+  }

  if (input_strategy[1] != weight_strategy[1]) {
    MS_LOG(ERROR) << name_ << ": The shard num of c-in for input strategy is " << input_strategy[1]
                  << ", but the shard num of c-in for weight strategy is " << weight_strategy[1];
    return FAILED;
  }

+  if (weight_strategy[2] != 1 || weight_strategy[3] != 1) {
+    MS_LOG(ERROR) << name_ << ": The kernel size can not be split, but the strategy for kernel size is ("
+                  << weight_strategy[2] << ", " << weight_strategy[3] << ")";
+    return FAILED;
+  }

  if (input_strategy[2] != 1 || input_strategy[3] != 1) {
    if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
      return FAILED;

@@ -317,8 +363,9 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
}

Status Conv2DInfo::InferDevMatrixShape() {
-  // the strategy is ((n, i, h, w), (o, i, 1, 1))
-  // the dev matrix is (n, i, h, w, o)
+  // conv2d: the strategy is ((n, i, a, b), (o, i, 1, 1))
+  // conv3d: the strategy is ((n, i, a, b, 1), (o, i, 1, 1, 1))
+  // the dev matrix is (n, i, a, b, o)
  MS_EXCEPTION_IF_NULL(strategy_);
  std::vector<Dimensions> stra = strategy_->GetInputDim();
  if (stra.size() != 2) {

@@ -326,7 +373,7 @@ Status Conv2DInfo::InferDevMatrixShape() {
    return FAILED;
  }

-  dev_matrix_shape_ = stra[0];
+  dev_matrix_shape_ = {stra[0][0], stra[0][1], stra[0][2], stra[0][3]};
  dev_matrix_shape_.push_back(stra[1][0]);
  h_dimension_shard_num_ = stra[0][2];
  w_dimension_shard_num_ = stra[0][3];

@@ -334,18 +381,18 @@ Status Conv2DInfo::InferDevMatrixShape() {
  return SUCCESS;
}
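// For example (hypothetical strategy): ((2, 4, 2, 1), (2, 4, 1, 1)) yields dev_matrix_shape_ =
// (2, 4, 2, 1, 2); the first four entries copy the input strategy and the trailing 2 is the
// out-channel shard taken from the weight strategy, so the h/w shard nums here are 2 and 1.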

-std::vector<int64_t> Conv2DInfo::GetAdjacentRankIdsAndBiases(int64_t rank_id, const std::string &dimension) {
+std::vector<int64_t> Conv2DInfo::GetAdjacentRankIdsAndBiases(int64_t rank_id, int64_t dimension) {
  std::vector<int64_t> ret;
  if (rank_id < 0) {
    ret = {-1, -1, -1, -1, -1};
    return ret;
  }

-  MS_LOG(INFO) << name_ << ": The rank id is " << rank_id << ", the dimension is " << dimension;
+  MS_LOG(INFO) << name_ << ": The rank id is " << rank_id << ", the dimension is " << dimension << "th dimension";

  uint64_t index_in_dev_matrix = 0;
  int64_t dimension_shard_num = 1;
-  if (dimension == H_DIMENSION) {
+  if (dimension == 2) {
    index_in_dev_matrix = 2;
    dimension_shard_num = h_dimension_shard_num_;
  } else {

@@ -364,21 +411,21 @@ std::vector<int64_t> Conv2DInfo::GetAdjacentRankIdsAndBiases
  }

  if (group_devices.size() <= 1) {
-    MS_LOG(INFO) << name_ << ": The devices' size of " << dimension << " is " << group_devices.size()
+    MS_LOG(INFO) << name_ << ": The devices' size of " << dimension << "th dimension is " << group_devices.size()
                 << ", no need to infer rank bias";
    ret = {-1, -1, -1, -1, -1};
    return ret;
  }

  if (group_devices.size() != LongToSize(dimension_shard_num)) {
-    MS_LOG(EXCEPTION) << name_ << ": The devices' size of " << dimension << " is " << group_devices.size()
-                      << ", but the shard num of w dimension is " << dimension_shard_num;
+    MS_LOG(EXCEPTION) << name_ << ": The devices' size of " << dimension << "th dimension is " << group_devices.size()
+                      << ", but the shard num of this dimension is " << dimension_shard_num;
  }

  std::vector<int64_t>::iterator it = std::find(group_devices.begin(), group_devices.end(), rank_id);
  if (it == group_devices.end()) {
    MS_LOG(EXCEPTION) << name_ << ": Can not find the current rank in device list of " << dimension
-                      << ", the current rank is " << rank_id << ", the device list is " << group_devices;
+                      << "th dimension, the current rank is " << rank_id << ", the device list is " << group_devices;
  }

  int64_t left_or_top_rank_id = -1;

@@ -425,25 +472,25 @@ void Conv2DInfo::InferAdjacentRankInfo() {

  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
-  std::vector<int64_t> h_dim_rank_info = GetAdjacentRankIdsAndBiases(rank, H_DIMENSION);
+  std::vector<int64_t> h_dim_rank_info = GetAdjacentRankIdsAndBiases(rank, 2);
  top_rank_id_ = h_dim_rank_info[0];
  bottom_rank_id_ = h_dim_rank_info[1];
  top_rank_bias_ = h_dim_rank_info[2];
  bottom_rank_bias_ = h_dim_rank_info[3];
  h_rank_bias_ = h_dim_rank_info[4];

-  std::vector<int64_t> w_dim_rank_info = GetAdjacentRankIdsAndBiases(rank, W_DIMENSION);
+  std::vector<int64_t> w_dim_rank_info = GetAdjacentRankIdsAndBiases(rank, 3);
  left_rank_id_ = w_dim_rank_info[0];
  right_rank_id_ = w_dim_rank_info[1];
  left_rank_bias_ = w_dim_rank_info[2];
  right_rank_bias_ = w_dim_rank_info[3];
  w_rank_bias_ = w_dim_rank_info[4];

-  std::vector<int64_t> top_w_dim_rank_info = GetAdjacentRankIdsAndBiases(top_rank_id_, W_DIMENSION);
+  std::vector<int64_t> top_w_dim_rank_info = GetAdjacentRankIdsAndBiases(top_rank_id_, 3);
  top_left_rank_id_ = top_w_dim_rank_info[0];
  top_right_rank_id_ = top_w_dim_rank_info[1];

-  std::vector<int64_t> bottom_w_dim_rank_info = GetAdjacentRankIdsAndBiases(bottom_rank_id_, W_DIMENSION);
+  std::vector<int64_t> bottom_w_dim_rank_info = GetAdjacentRankIdsAndBiases(bottom_rank_id_, 3);
  bottom_left_rank_id_ = bottom_w_dim_rank_info[0];
  bottom_right_rank_id_ = bottom_w_dim_rank_info[1];

@@ -561,13 +608,13 @@ void Conv2DInfo::InferOverlapSizeForWDim() {

void Conv2DInfo::CheckHDimensionOverlapSizeNonNegative() {
  if (h_dimension_shard_num_ == 1) {
-    MS_LOG(INFO) << name_ << ": The h dimension is not shard";
+    MS_LOG(INFO) << name_ << ": The 2nd dimension is not sharded";
    return;
  }

  int64_t h_first_rank_bottom_size = ComputeOverlapBottomSizeByRankBias(0);
  if (h_first_rank_bottom_size < 0) {
-    MS_LOG(EXCEPTION) << name_ << ": The bottom overlap size of h dimension rank bias 0 must be positive, but it is "
+    MS_LOG(EXCEPTION) << name_ << ": The bottom overlap size of 2nd dimension rank bias 0 must be non-negative, but it is "
                      << h_first_rank_bottom_size;
  }

@@ -575,7 +622,7 @@ void Conv2DInfo::CheckHDimensionOverlapSizeNonNegative() {
    auto top_size = ComputeOverlapTopSizeByRankBias(h_rank_bias);
    auto bottom_size = ComputeOverlapBottomSizeByRankBias(h_rank_bias);
    if (top_size < 0 || bottom_size < 0) {
-      MS_LOG(EXCEPTION) << name_ << ": The overlap size of h dimension rank bias " << h_rank_bias
+      MS_LOG(EXCEPTION) << name_ << ": The overlap size of 2nd dimension rank bias " << h_rank_bias
                        << " must be non-negative, but top overlap size is " << top_size << ", bottom overlap size is "
                        << bottom_size;
    }

@@ -583,19 +630,19 @@ void Conv2DInfo::CheckHDimensionOverlapSizeNonNegative() {

  int64_t h_last_rank_top_size = ComputeOverlapTopSizeByRankBias(h_dimension_shard_num_ - 1);
  if (h_last_rank_top_size < 0) {
-    MS_LOG(EXCEPTION) << name_ << ": The top overlap size of h dimension last rank bias must be positive, but it is "
+    MS_LOG(EXCEPTION) << name_ << ": The top overlap size of 2nd dimension last rank bias must be non-negative, but it is "
                      << h_last_rank_top_size;
  }
}

void Conv2DInfo::CheckWDimensionOverlapSizeNonNegative() {
  if (w_dimension_shard_num_ == 1) {
-    MS_LOG(INFO) << name_ << ": The w dimension is not shard";
+    MS_LOG(INFO) << name_ << ": The 3rd dimension is not sharded";
    return;
  }
  int64_t w_first_rank_right_size = ComputeOverlapRightSizeByRankBias(0);
  if (w_first_rank_right_size < 0) {
-    MS_LOG(EXCEPTION) << name_ << ": The right overlap size of w dimension rank bias 0 must be positive, but it is "
+    MS_LOG(EXCEPTION) << name_ << ": The right overlap size of 3rd dimension rank bias 0 must be non-negative, but it is "
                      << w_first_rank_right_size;
  }

@@ -603,7 +650,7 @@ void Conv2DInfo::CheckWDimensionOverlapSizeNonNegative() {
    auto left_size = ComputeOverlapLeftSizeByRankBias(w_rank_bias);
    auto right_size = ComputeOverlapRightSizeByRankBias(w_rank_bias);
    if (left_size < 0 || right_size < 0) {
-      MS_LOG(EXCEPTION) << name_ << ": The overlap size of w dimension rank bias " << w_rank_bias
+      MS_LOG(EXCEPTION) << name_ << ": The overlap size of 3rd dimension rank bias " << w_rank_bias
                        << " must be non-negative, but left overlap size is " << left_size << ", right overlap size is "
                        << right_size;
    }

@@ -611,7 +658,7 @@ void Conv2DInfo::CheckWDimensionOverlapSizeNonNegative() {

  int64_t w_last_rank_left_size = ComputeOverlapLeftSizeByRankBias(w_dimension_shard_num_ - 1);
  if (w_last_rank_left_size < 0) {
-    MS_LOG(EXCEPTION) << name_ << ": The left overlap size of w dimension last rank bias must be positive, but it is "
+    MS_LOG(EXCEPTION) << name_ << ": The left overlap size of 3rd dimension last rank bias must be non-negative, but it is "
                      << w_last_rank_left_size;
  }
}

@@ -651,7 +698,7 @@ Status Conv2DInfo::InferTensorMap() {
  return SUCCESS;
}

-// Conv2d: dev_matrix is (n, i, h, w, o), if in channel is split, it need to insert all reduce
+// Conv2d/Conv3d: dev_matrix is (n, i, h, w, o); if the in channel is split, an all reduce needs to be inserted
// Conv2DBackpropInputInfo: dev_matrix is (n, o, h, w, i); if the out channel is split, an all reduce needs to be inserted
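// Put differently (illustrative): if c-in is split k ways, each device convolves over only c_in / k
// input channels and holds a partial sum, so an AllReduce over the i axis of the device matrix is
// required to recover the full output.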
Status Conv2DInfo::InferForwardCommunication() {
  forward_op_.clear();

@@ -785,13 +832,13 @@ void Conv2DInfo::InferCommunicationAttrs() {
  int64_t h_slice_shape = input_slice_shape_[2];
  if (send_top_len > h_slice_shape || send_bottom_len > h_slice_shape || recv_top_len > h_slice_shape ||
      recv_bottom_len > h_slice_shape) {
-    MS_LOG(EXCEPTION) << name_ << ": The send or recv len larger than slice shape of h dimension " << h_slice_shape;
+    MS_LOG(EXCEPTION) << name_ << ": The send or recv len is larger than the slice shape of the 2nd dimension " << h_slice_shape;
  }

  int64_t w_slice_shape = input_slice_shape_[3];
  if (send_left_len > w_slice_shape || send_right_len > w_slice_shape || recv_left_len > w_slice_shape ||
      recv_right_len > w_slice_shape) {
-    MS_LOG(EXCEPTION) << name_ << ": The send or recv len larger than slice shape of w dimension " << w_slice_shape;
+    MS_LOG(EXCEPTION) << name_ << ": The send or recv len is larger than the slice shape of the 3rd dimension " << w_slice_shape;
  }
}

@@ -827,7 +874,7 @@ OperatorAttrs Conv2DInfo::CreateConv2DAttrs() {
  Attr data_format = {DATA_FORMAT, MakeValue(format_)};

  OperatorAttrs attrs;
-  if (name_.find(CONV2D_INFO) != std::string::npos) {
+  if (name_.find(CONV2D_INFO) != std::string::npos || name_.find(CONV3D_INFO) != std::string::npos) {
    attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format};
  } else {  // Conv2DTranspose
    attrs = {out_channel, kernel_size, pad_mode, pad, pad, mode, stride, dilation, group, data_format};

@@ -937,7 +984,9 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
  auto search_mode = parallel_context->strategy_search_mode();
  // generate data parallel strategy when the search mode is not sharding propagation
  if (parallel_mode == parallel::kAutoParallel && search_mode != parallel::kShardingPropagation) {
-    Strategies strategy = {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}};
+    Shape input_strategy(inputs_shape_[0].size(), 1);
+    input_strategy[0] = stage_device_size_;
+    Strategies strategy = {input_strategy, Shape(inputs_shape_[1].size(), 1)};
    StrategyPtr data_parallel_sp = std::make_shared<Strategy>(stage_id, strategy);
    sp_vector.push_back(data_parallel_sp);
    return sp_vector;

@@ -966,11 +1015,19 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
  if (tmp_strategy.size() != 5) {
    MS_LOG(EXCEPTION) << name_ << ": The size of the first tmp strategy must be 5, but got " << tmp_strategy.size();
  }
-  Dimensions input0_strategy = {tmp_strategy[0], tmp_strategy[1], tmp_strategy[2], tmp_strategy[3]};

+  Dimensions input0_strategy;
  Dimensions input1_strategy;
  if (name_.find(CONV2D_INFO) != std::string::npos) {  // conv2d
+    input0_strategy = {tmp_strategy[0], tmp_strategy[1], tmp_strategy[2], tmp_strategy[3]};
    input1_strategy = {tmp_strategy[4], tmp_strategy[1], 1, 1};  // (C-out, C-in, k1, k2), the k1/k2 can not be split
-  } else {  // conv2d-transpose
+  } else if (name_.find(CONV3D_INFO) != std::string::npos) {  // conv3d
+    input0_strategy = {tmp_strategy[0], tmp_strategy[1], tmp_strategy[2], tmp_strategy[3], 1};
+    input1_strategy = {tmp_strategy[4], tmp_strategy[1], 1, 1, 1};
+  } else if (name_.find(CONV2D_TRANSPOSE) != std::string::npos ||
+             name_.find(CONV2D_BACK_PROP_INPUT) != std::string::npos) {  // conv2d-transpose
+    input0_strategy = {tmp_strategy[0], tmp_strategy[1], tmp_strategy[2], tmp_strategy[3]};
    input1_strategy = {tmp_strategy[1], tmp_strategy[4], 1, 1};
  }
  replace_strategy.push_back(input0_strategy);

@@ -990,8 +1047,9 @@ Shapes Conv2DInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {

  Shape tmp_strategy;
  if (!in_strategy[0].empty()) {
-    if (in_strategy[0].size() != 4) {
-      MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be 4, but got " << in_strategy[0].size();
+    if (in_strategy[0].size() != inputs_shape_[0].size()) {
+      MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[0] must be " << inputs_shape_[0].size() << ", but got "
+                        << in_strategy[0].size();
    }
    tmp_strategy = Shape(inputs_shape_[1].size(), 1);
    tmp_strategy[1] = in_strategy[0][1];

@@ -999,8 +1057,9 @@ Shapes Conv2DInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
  }

  if (!in_strategy[1].empty()) {
-    if (in_strategy[1].size() != 4) {
-      MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[1] must be 4, but got " << in_strategy[1].size();
+    if (in_strategy[1].size() != inputs_shape_[1].size()) {
+      MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy[1] must be " << inputs_shape_[1].size() << ", but got "
+                        << in_strategy[1].size();
    }
    tmp_strategy = Shape(inputs_shape_[0].size(), 1);
    tmp_strategy[1] = in_strategy[1][1];

@@ -1078,7 +1137,7 @@ Status Conv2DBackpropInputInfo::GetOutShape() {
}

Status Conv2DBackpropInputInfo::GetAttrs() {
-  if (GetAttrsBase() != SUCCESS) {
+  if (GetAttrsBase() != SUCCESS || CheckAttrsBase() != SUCCESS) {
    return FAILED;
  }

mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
@@ -44,6 +44,9 @@ class Conv2DInfo : public OperatorInfo {

 protected:
  Status GetAttrsBase();
+  virtual Status CheckAttrsBase();
+  virtual std::vector<int64_t> GetStrideAttr();
+  virtual std::vector<int64_t> GetDilationAttr();
  Status GetAttrs() override;
  Status CheckStrategyBase(const StrategyPtr &strategy);
  Status CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) const;

@@ -52,7 +55,7 @@ class Conv2DInfo : public OperatorInfo {
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  void InferAdjacentRankInfo();
-  std::vector<int64_t> GetAdjacentRankIdsAndBiases(int64_t rank_id, const std::string &dimension);
+  std::vector<int64_t> GetAdjacentRankIdsAndBiases(int64_t rank_id, int64_t dimension);
  void InferOverlapSize();
  void CheckHDimensionOverlapSizeNonNegative();
  void CheckWDimensionOverlapSizeNonNegative();

@@ -63,11 +66,11 @@ class Conv2DInfo : public OperatorInfo {
  void InferSendRankIds();
  void InferRecvRankIds();
  void InferCommunicationAttrs();
-  std::string ReplaceNodeName() const;
+  virtual std::string ReplaceNodeName() const;
  AnfNodePtr GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode);
  OperatorAttrs CreateNeighborExchangeV2Attrs();
  OperatorAttrs CreateConv2DAttrs();
-  void ComputeReplaceGraph(const CNodePtr &cnode);
+  virtual void ComputeReplaceGraph(const CNodePtr &cnode);

  int64_t out_channel_ = 1;
  std::vector<int64_t> kernel_size_;  // two integers

@@ -147,8 +150,9 @@ class Conv2DInfo : public OperatorInfo {

 private:
  Status CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy);
-  Status CheckHWStrategyPadModeByDimension(int64_t strategy, const std::string &dimension);
+  Status CheckHWStrategyPadModeByDimension(int64_t strategy, int64_t dimension_id);
  Status CheckHWStrategyPadMode(int64_t h_strategy, int64_t w_strategy);
+  void AdjustPadList();
};

class Conv2DBackpropInputInfo : public Conv2DInfo {

mindspore/ccsrc/frontend/parallel/ops_info/conv3d_info.cc (new file)
@@ -0,0 +1,230 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/conv3d_info.h"

#include <algorithm>
#include <functional>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>

#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/dynamic_creator.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "include/common/utils/parallel_context.h"
#include "pipeline/jit/resource.h"

namespace mindspore {
namespace parallel {
std::vector<int64_t> Conv3DInfo::GetStrideAttr() { return GetTupleIntAttr(STRIDES); }

std::vector<int64_t> Conv3DInfo::GetDilationAttr() { return GetTupleIntAttr(DILATIONS); }

Status Conv3DInfo::CheckAttrsBase() {
  if (format_ != NCDHW) {
    MS_LOG(ERROR) << name_ << ": The format must be 'NCDHW', but got " << format_;
    return FAILED;
  }

  if (kernel_size_.size() != 3) {
    MS_LOG(ERROR) << name_ << ": The size of the kernel_size tuple must be 3, but got " << kernel_size_.size();
    return FAILED;
  }

  if (pad_list_.size() != 6) {
    MS_LOG(ERROR) << name_ << ": The size of pad_list must be 6, but got " << pad_list_.size();
    return FAILED;
  }

  if (stride_.size() != 5) {
    MS_LOG(ERROR) << name_ << ": The size of stride must be 5, but got " << stride_.size();
    return FAILED;
  }

  if (stride_[0] != 1 || stride_[1] != 1) {
    MS_LOG(ERROR) << name_ << ": The first two elements of stride must be 1, but the stride is " << stride_;
    return FAILED;
  }

  if (dilation_.size() != 5) {
    MS_LOG(ERROR) << name_ << ": The size of dilation must be 5, but got " << dilation_.size();
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": The out channel is " << out_channel_ << ", kernel size is " << kernel_size_
               << ", mode is " << mode_ << ", pad mode is " << pad_mode_ << ", pad list is " << pad_list_
               << ", stride is " << stride_ << ", dilation is " << dilation_ << ", group is " << group_
               << ", format is " << format_ << ", the kernel size use dilation is " << kernel_size_use_dilation_;
  return SUCCESS;
}

Status Conv3DInfo::CheckStrategy(const StrategyPtr &strategy) {
  h_dim_need_exchange_overlap_ = false;
  w_dim_need_exchange_overlap_ = false;
  if (CheckStrategyBase(strategy) != SUCCESS) {
    return FAILED;
  }

  std::vector<Dimensions> stra = strategy->GetInputDim();
  Dimensions input_strategy = stra[0];
  Dimensions weight_strategy = stra[1];
  if (input_strategy.size() != 5 || weight_strategy.size() != 5) {
    MS_LOG(ERROR) << name_
                  << ": The size of input strategy or weight strategy must be 5, but the size of input strategy is "
                  << input_strategy.size() << ", the size of weight strategy is " << weight_strategy.size();
    return FAILED;
  }

  if (input_strategy[1] != weight_strategy[1]) {
    MS_LOG(ERROR) << name_ << ": The shard num of c-in for input strategy is " << input_strategy[1]
                  << ", but the shard num of c-in for weight strategy is " << weight_strategy[1];
    return FAILED;
  }

  if (weight_strategy[2] != 1 || weight_strategy[3] != 1 || weight_strategy[4] != 1) {
    MS_LOG(ERROR) << name_ << ": The kernel size can not be split, but the strategy for kernel size is ("
                  << weight_strategy[2] << ", " << weight_strategy[3] << ", " << weight_strategy[4] << ")";
    return FAILED;
  }

  if (input_strategy[4] != 1) {
    MS_LOG(ERROR) << name_
                  << ": Do not support to split the last dimension of input, but the strategy for this dimension is "
                  << input_strategy[4];
    return FAILED;
  }

  if (input_strategy[2] != 1 || input_strategy[3] != 1) {
    if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
      return FAILED;
    }
  }

  // if the d/h dimension is split, and the pad mode is not "valid", need to exchange overlap
  if (input_strategy[2] > 1 && pad_mode_ != 2) {
    h_dim_need_exchange_overlap_ = true;
  }

  if (input_strategy[3] > 1 && pad_mode_ != 2) {
    w_dim_need_exchange_overlap_ = true;
  }
  return SUCCESS;
}

Status Conv3DInfo::InferTensorMap() {
  // input_strategy: ((n, i, a, b, 1), (o, i, 1, 1, 1))
  // output_strategy: ((n, o, a, b, 1),)
  // dev_matrix: (n, i, a, b, o)
  TensorMap input_tensor_map = {4, 3, 2, 1, -1};
  TensorMap weight_tensor_map = {0, 3, -1, -1, -1};
  TensorMap output_tensor_map = {4, 0, 2, 1, -1};

  (void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
  (void)inputs_tensor_map_.emplace_back(std::move(weight_tensor_map));
  (void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
  return SUCCESS;
}
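// Reading the maps above: dev_matrix (n, i, a, b, o) is indexed 4..0 from the left, so the input's
// n/c/d/h map to axes 4/3/2/1 and the replicated w maps to -1; the weight's out channel maps to
// axis 0, and the output keeps n on axis 4 while putting its channel on the o axis 0.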

std::string Conv3DInfo::ReplaceNodeName() const {
  if (name_.find(CONV3D_INFO) != std::string::npos) {
    return CONV3D;
  }

  MS_LOG(EXCEPTION) << "Invalid name: " << name_;
}

OperatorAttrs Conv3DInfo::CreateConv3DAttrs() {
  auto node_stride = stride_;
  (void)node_stride.erase(node_stride.begin(), node_stride.begin() + 2);
  auto node_dilation = dilation_;
  (void)node_dilation.erase(node_dilation.begin(), node_dilation.begin() + 2);
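  // e.g. an ncdhw stride of (1, 1, 2, 2, 2) becomes (2, 2, 2) here, the (d, h, w)-only form the
  // rebuilt Conv3D node expects (illustrative values, not taken from the patch)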

  Attr out_channel = {OUT_CHANNEL, MakeValue(new_out_channel_)};
  Attr kernel_size = {KERNEL_SIZE, MakeValue(kernel_size_)};
  Attr mode = {MODE, MakeValue(mode_)};
  Attr pad_mode = {PAD_MODE, MakeValue("pad")};
  Attr pad = {PAD, MakeValue(new_pad_list_)};
  Attr stride = {STRIDE, MakeValue(node_stride)};
  Attr dilation = {DILATION, MakeValue(node_dilation)};
  Attr group = {GROUP, MakeValue(group_)};
  Attr data_format = {DATA_FORMAT, MakeValue(format_)};

  OperatorAttrs attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format};
  return attrs;
}

AnfNodePtr Conv3DInfo::GenerateConv3DNode(const AnfNodePtr &new_input, const CNodePtr &cnode) {
  auto conv3d_attrs = CreateConv3DAttrs();
  auto node_name = ReplaceNodeName();

  if (cnode->size() < 3) {
    MS_LOG(EXCEPTION) << name_ << ": The size of cnode is invalid: " << cnode->size();
  }
  return gen_g_.PushBack({gen_g_.NewOpInst(node_name, conv3d_attrs), new_input, cnode->input(2)});
}

void Conv3DInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
  // Because NeighborExchangeV2 only supports 4-dim input and only exchanges the last 2 dims, while the
  // input of conv3d is 5-dim and its 3rd/4th dims need to be exchanged, several operators are used to build the graph:
  // slice input (ncdhw) -> transpose(in, (4, 0, 1, 2, 3)) -> reshape(in, (w*n, c, d, h)) -> neighborexchangev2(in)
  // -> reshape(in, (w, n, c, d', h')) -> transpose(in, (1, 2, 3, 4, 0)) -> conv3d
  auto graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(graph);

  if (gen_g_.Init(cnode) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": GenerateGraph Init failed";
  }

  // transpose-1
  std::vector<int64_t> t1 = {4, 0, 1, 2, 3};
  auto transpose_1 = gen_g_.PushBack({gen_g_.NewOpInst(TRANSPOSE), gen_g_.virtual_input_node(), CreateTuple(t1)});

  // reshape-1
  auto s = input_slice_shape_;
  if (s.size() != 5) {
    MS_LOG(EXCEPTION) << name_ << ": The size of input slice shape must be 5, but got " << s.size();
  }
  Shape s1 = {s[4] * s[0], s[1], s[2], s[3]};
  auto reshape_1 = gen_g_.PushBack({gen_g_.NewOpInst(RESHAPE), transpose_1, CreateTuple(s1)});

  // neighborexchangev2
  auto neighbor_exchange_v2_attrs = CreateNeighborExchangeV2Attrs();
  auto neighbor_exchange_v2 =
    gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGEV2, neighbor_exchange_v2_attrs), reshape_1});

  // reshape-2
  Shape s2 = {s[4], s[0], s[1], s[2] + recv_lens_[0] + recv_lens_[1], s[3] + recv_lens_[2] + recv_lens_[3]};
  auto reshape_2 = gen_g_.PushBack({gen_g_.NewOpInst(RESHAPE), neighbor_exchange_v2, CreateTuple(s2)});

  // transpose-2
  std::vector<int64_t> t2 = {1, 2, 3, 4, 0};
  auto transpose_2 = gen_g_.PushBack({gen_g_.NewOpInst(TRANSPOSE), reshape_2, CreateTuple(t2)});

  // conv3d
  auto conv3d = GenerateConv3DNode(transpose_2, cnode);

  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(transpose_1, 1)};
  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, conv3d));
}
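
// Shape walk-through taken from the pad-mode unit test below: a (32, 16, 8, 4, 16) ncdhw slice is
// transposed to (16, 32, 16, 8, 4) and reshaped to (512, 16, 8, 4), so NeighborExchangeV2 sees a
// 4-dim tensor whose last two axes are d and h; with one received halo row on d and none on h it
// grows to (512, 16, 9, 4), is reshaped back to (16, 32, 16, 9, 4) and transposed to
// (32, 16, 9, 4, 16) before being fed to the new Conv3D.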

REGISTER(Conv3DInfo);
}  // namespace parallel
}  // namespace mindspore

mindspore/ccsrc/frontend/parallel/ops_info/conv3d_info.h (new file)
@@ -0,0 +1,55 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV3D_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV3D_INFO_H_

#include <string>
#include <memory>
#include <vector>

#include "utils/hash_map.h"
#include "ir/value.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/ops_info/conv2d_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
class Conv3DInfo : public Conv2DInfo {
 public:
  Conv3DInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
      : Conv2DInfo(name, inputs_shape, outputs_shape, attrs) {}
  ~Conv3DInfo() override = default;

 protected:
  Status CheckAttrsBase() override;
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferTensorMap() override;
  std::string ReplaceNodeName() const override;
  AnfNodePtr GenerateConv3DNode(const AnfNodePtr &new_input, const CNodePtr &cnode);
  void ComputeReplaceGraph(const CNodePtr &cnode) override;
  std::vector<int64_t> GetStrideAttr() override;
  std::vector<int64_t> GetDilationAttr() override;
  OperatorAttrs CreateConv3DAttrs();
};
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV3D_INFO_H_

mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -248,8 +248,10 @@ constexpr char PAD[] = "pad";
constexpr char DATA_FORMAT[] = "data_format";
constexpr char STRIDE[] = "stride";
constexpr char DILATION[] = "dilation";
+constexpr char DILATIONS[] = "dilations";
constexpr char FORMAT[] = "format";
constexpr char NCHW[] = "NCHW";
+constexpr char NCDHW[] = "NCDHW";
constexpr char H_DIMENSION[] = "h_dimension";
constexpr char W_DIMENSION[] = "w_dimension";
constexpr char IS_TRAINING[] = "is_training";

@@ -350,9 +352,11 @@ constexpr char ARGMAX[] = "Argmax";
constexpr char ARGMIN[] = "Argmin";
constexpr char ARGMINV2[] = "ArgminV2";
constexpr char CONV2D[] = "Conv2D";
+constexpr char CONV3D[] = "Conv3D";
constexpr char CONV2D_BACK_PROP_INPUT[] = "Conv2DBackpropInput";
constexpr char CONV2D_TRANSPOSE[] = "Conv2DTranspose";
constexpr char CONV2D_INFO[] = "Conv2DInfo";
+constexpr char CONV3D_INFO[] = "Conv3DInfo";
constexpr char CONV2D_BACK_PROP_INPUT_INFO[] = "Conv2DBackpropInputInfo";
constexpr char CONV2D_TRANSPOSE_INFO[] = "Conv2DTransposeInfo";
constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";

tests/ut/python/parallel/test_conv3d.py (new file)
@@ -0,0 +1,168 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell
from mindspore.ops import operations as P
from parallel.utils.utils import ParallelValidator, compile_net


def setup_function():
    context.set_auto_parallel_context(dataset_strategy="full_batch")


class Net(Cell):
    def __init__(self, conv3d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1, pad=0,
                 strategy1=None, strategy2=None):
        super().__init__()
        self.conv3d = P.Conv3D(out_channel=out_channel, kernel_size=kernel_size, pad_mode=pad_mode, pad=pad,
                               stride=stride, dilation=dilation, group=group).shard(strategy1)
        self.neg = P.Neg().shard(strategy2)
        self.conv3d_weight = Parameter(conv3d_weight, "w1")

    def construct(self, x, b):
        out = self.conv3d(x, self.conv3d_weight)
        out = self.neg(out)
        return out


_x = Tensor(np.ones([32, 16, 8, 8, 8]), dtype=ms.float32)
_x3 = Tensor(np.ones([32, 16, 16, 16, 16]), dtype=ms.float32)
_x4 = Tensor(np.ones([2, 16, 56, 56, 24]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2, 2]), dtype=ms.float32)
_w2 = Tensor(np.ones([8, 16, 3, 3, 3]), dtype=ms.float32)
_w5 = Tensor(np.ones([8, 16, 4, 4, 4]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8, 8]), dtype=ms.float32)


def test_conv3d_data_parallel():
    """
    Feature: test conv3d data parallel
    Description: shard n dimension
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1, 1), (1, 1, 1, 1, 1))
    strategy2 = ((8, 1, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    phase = compile_net(net, _x, _b)
    validator = ParallelValidator(net, phase)
    assert validator.check_node_inputs('Neg-0', ['Conv3D-0'])


def test_conv3d_pad_mode_overlap_is_negative():
    """
    Feature: test conv3d pad mode and overlap is negative
    Description: shard d/h
    Expectation: compile failed
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 1, 4, 4, 1), (1, 1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 1, 1),)
    net = Net(_w5, out_channel=8, kernel_size=4, pad_mode="pad", stride=5, pad=(3, 0, 3, 0, 3, 0),
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net, _x3, _b)


def test_conv3d_pad_mode():
    """
    Feature: test pad mode and overlap is non-negative
    Description: shard d/h
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 2, 4, 1), (1, 1, 1, 1, 1))
    strategy2 = ((1, 1, 2, 4, 1),)
    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="pad", stride=1, pad=(3, 3, 3, 3, 3, 3),
              strategy1=strategy1, strategy2=strategy2)
    phase = compile_net(net, _x3, _b)
    validator = ParallelValidator(net, phase)
    assert validator.check_node_inputs('Neg-0', ['Conv3D-0'])
    assert validator.check_node_inputs('Conv3D-0', ['Transpose-1', 'Load-0'])
    assert validator.check_node_inputs_fuzzy_match('Transpose-0', ['StridedSlice', '(4, 0, 1, 2, 3)'])
    assert validator.check_node_inputs_fuzzy_match('Transpose-1', ['Reshape', '(1, 2, 3, 4, 0)'])
    assert validator.check_node_inputs_fuzzy_match('Reshape-0', ['Transpose', '(512, 16, 8, 4)'])
    assert validator.check_node_inputs_fuzzy_match('Reshape-1', ['NeighborExchangeV2', '(16, 32, 16, 9, 4)'])
    assert validator.check_node_attrs('NeighborExchangeV2-0', {'send_lens': '[0, 1, 0, 2]'})


def test_conv3d_pad_mode_unet_3d_rank0():
    """
    Feature: test pad mode unet 3d
    Description: shard d/h
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 2, 4, 1), (1, 1, 1, 1, 1))
    strategy2 = ((1, 1, 2, 4, 1),)
    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="pad", stride=2, pad=1,
              strategy1=strategy1, strategy2=strategy2)
    phase = compile_net(net, _x4, _b)
    validator = ParallelValidator(net, phase)
    assert validator.check_node_attrs('NeighborExchangeV2-0', {'send_lens': '[0, 1, 0, 1]'})
    assert validator.check_node_attrs('NeighborExchangeV2-0', {'recv_lens': '[0, 0, 0, 0]'})


def test_conv3d_pad_mode_unet_3d_rank1():
    """
    Feature: test pad mode unet 3d
    Description: shard d/h
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
    strategy1 = ((1, 1, 2, 4, 1), (1, 1, 1, 1, 1))
    strategy2 = ((1, 1, 2, 4, 1),)
    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="pad", stride=2, pad=1,
              strategy1=strategy1, strategy2=strategy2)
    phase = compile_net(net, _x4, _b)
    validator = ParallelValidator(net, phase)
    assert validator.check_node_attrs('NeighborExchangeV2-0', {'send_lens': '[0, 1, 0, 1]'})
    assert validator.check_node_attrs('NeighborExchangeV2-0', {'recv_lens': '[0, 0, 1, 0]'})


def test_conv3d_pad_mode_unet_3d_rank7():
    """
    Feature: test pad mode unet 3d
    Description: shard d/h
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=7)
    strategy1 = ((1, 1, 2, 4, 1), (1, 1, 1, 1, 1))
    strategy2 = ((1, 1, 2, 4, 1),)
    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="pad", stride=2, pad=1,
              strategy1=strategy1, strategy2=strategy2)
    phase = compile_net(net, _x4, _b)
    validator = ParallelValidator(net, phase)
    assert validator.check_node_attrs('NeighborExchangeV2-0', {'send_lens': '[0, 0, 0, 0]'})
    assert validator.check_node_attrs('NeighborExchangeV2-0', {'recv_lens': '[1, 0, 1, 0]'})


def test_conv3d_valid_mode_output_shape_cannot_div_by_strategy():
    """
    Feature: test valid mode, and the output shape can not be divided evenly by the strategy
    Description: shard d
    Expectation: compile failed
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 1, 8, 1), (1, 1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="valid", stride=4,
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net, _x3, _b)
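

def test_conv3d_split_n_and_cout_sketch():
    """
    Feature: hypothetical extra case sketched by the editor, not part of this PR
    Description: shard n on the input and c-out on the weight, which CheckStrategy permits
    Expectation: compile success (assumed; the strategies satisfy the documented constraints)
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((4, 1, 1, 1, 1), (2, 1, 1, 1, 1))  # dev matrix (4, 1, 1, 1, 2) uses all 8 devices
    strategy2 = ((4, 2, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
              strategy1=strategy1, strategy2=strategy2)
    compile_net(net, _x, _b)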