diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc index 7d9cf452273..b3bba806704 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc @@ -31,20 +31,6 @@ namespace mindspore { namespace parallel { -namespace { -ValuePtr MakeListValue(const std::vector &v) { - std::vector list; - (void)std::transform(v.begin(), v.end(), std::back_inserter(list), [](int64_t ele) { return MakeValue(ele); }); - return std::make_shared(list); -} - -ValuePtr MakeTupleListValue(const Shapes &v) { - std::vector tuple; - (void)std::transform(v.begin(), v.end(), std::back_inserter(tuple), - [](const std::vector &list) { return MakeListValue(list); }); - return std::make_shared(tuple); -} -} // namespace Status Conv2DInfo::GetAttrsBase() { // format format_ = GetStringAttr(FORMAT); @@ -653,8 +639,11 @@ OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(tensor_type); auto dtype = tensor_type->element(); MS_EXCEPTION_IF_NULL(dtype); - Attr send_ranks = {SEND_RNAK_IDS, MakeListValue(send_rank_ids_)}; - Attr recv_ranks = {RECV_RNAK_IDS, MakeListValue(recv_rank_ids_)}; + + // the type of send_rank_ids, recv_rank_ids, send_shapes, recv_shapes is list, is not tuple, can not use MakeValue + // the MakeValue(vector) return a tuple + Attr send_ranks = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)}; + Attr recv_ranks = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)}; Attr send_shapes = {SEND_SHAPES, MakeTupleListValue(send_shapes_)}; Attr recv_shapes = {RECV_SHAPES, MakeTupleListValue(recv_shapes_)}; Attr recv_type = {RECV_TYPE, dtype}; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc index e203b1829d2..26174cfbe22 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -1890,5 +1890,24 @@ std::vector GetValueSequeue(const ValuePtr &sequeue) { auto val = sequeue->cast(); return val->value(); } + +ValuePtr MakeListValue(const std::vector &v) { + std::vector list; + (void)std::transform(v.begin(), v.end(), std::back_inserter(list), [](int64_t ele) { return MakeValue(ele); }); + return std::make_shared(list); +} + +ValuePtr MakeTupleListValue(const Shapes &v) { + std::vector tuple; + (void)std::transform(v.begin(), v.end(), std::back_inserter(tuple), + [](const std::vector &list) { return MakeListValue(list); }); + return std::make_shared(tuple); +} + +AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector &value_tuple) { + auto value_ptr = MakeValue(value_tuple)->cast(); + auto value_node = NewValueNode(value_ptr); + return value_node->cast(); +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h index a4f64b82fde..5ce0aa74c6a 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -343,6 +343,9 @@ Status GenerateStrategiesWithBroadcast(int64_t stage_id, const Shapes &inputs_sh Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph); std::vector GetValueSequeue(const ValuePtr &sequeue); +ValuePtr MakeListValue(const std::vector &v); +ValuePtr MakeTupleListValue(const Shapes &v); +AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector &value_tuple); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 98d85bdafa6..278cb314559 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -172,11 +172,17 @@ constexpr char REPLACE[] = "replace"; constexpr char CONNSYMBOL[] = "/"; constexpr char INSTANCE_NAME[] = "instance_name"; constexpr char SPLIT_SENS[] = "split_sens"; -constexpr char SEND_RNAK_IDS[] = "send_rank_ids"; -constexpr char RECV_RNAK_IDS[] = "recv_rank_ids"; +constexpr char SEND_RANK_IDS[] = "send_rank_ids"; +constexpr char RECV_RANK_IDS[] = "recv_rank_ids"; constexpr char RECV_SHAPES[] = "recv_shapes"; constexpr char SEND_SHAPES[] = "send_shapes"; constexpr char RECV_TYPE[] = "recv_type"; +constexpr char SEND_LENS[] = "send_lens"; +constexpr char RECV_LENS[] = "recv_lens"; +constexpr char ORI_IMAGE_SIZE[] = "ori_image_size"; +constexpr char SPLIT_SIZE[] = "split_size"; +constexpr char SRC_START_W[] = "src_start_w"; +constexpr char DST_START_W[] = "dst_start_w"; constexpr char SPLIT_TENSOR[] = "split_tensor"; constexpr char DEV_MAT[] = "dev_mat"; constexpr char TENSOR_MAP[] = "tensor_map"; @@ -238,6 +244,7 @@ constexpr char SPLIT[] = "Split"; constexpr char ALL_TO_ALL[] = "AlltoAll"; constexpr char NEIGHBOREXCHANGE[] = "NeighborExchange"; constexpr char NEIGHBOREXCHANGEV2[] = "NeighborExchangeV2"; +constexpr char PARALLEL_RESIZE_BILINEAR[] = "ParallelResizeBilinear"; constexpr char PERMUTE_BY_AXIS[] = "PermuteByAxis"; constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis"; constexpr char SPLIT_BY_AXIS[] = "SplitByAxis"; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.cc index 8031c2cad09..4931c767cd2 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "frontend/parallel/device_matrix.h" #include "frontend/parallel/strategy.h" @@ -28,6 +29,8 @@ namespace mindspore { namespace parallel { +// ResizeBilinear: support to split N/C/W +// ResizeNearestNeighbor: support to split N/C/H/W if align_corners=False, support to split N/C if align_corners=True Status ResizeBilinearInfo::GetAttrs() { size_ = GetTupleIntAttr(SIZE); if (size_.size() != 2) { @@ -43,8 +46,9 @@ Status ResizeBilinearInfo::GetAttrs() { Status ResizeBilinearInfo::CheckStrategy(const StrategyPtr &strategy) { MS_EXCEPTION_IF_NULL(strategy); + need_exchange_overlap_ = false; if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Invalid strategy"; + MS_LOG(ERROR) << name_ << ": Check input strategy failed"; return FAILED; } @@ -60,8 +64,19 @@ Status ResizeBilinearInfo::CheckStrategy(const StrategyPtr &strategy) { return FAILED; } - if (input_strategy[2] != 1 || input_strategy[3] != 1) { - MS_LOG(ERROR) << name_ << ": Do not support split from H or W"; + if (input_strategy[2] != 1) { + MS_LOG(ERROR) << name_ << ": Do not support split H dimension"; + return FAILED; + } + + if (input_strategy[3] != 1) { + need_exchange_overlap_ = true; + MS_LOG(INFO) << name_ << ": Split the w dimension"; + } + + // check output strategy + if (CheckStrategyValue(strategy, outputs_shape_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Check output strategy failed"; return FAILED; } @@ -87,6 +102,7 @@ Status ResizeBilinearInfo::InferDevMatrixShape() { slice_size_ = size_; slice_size_[0] = slice_size_[0] / dev_matrix_shape_[2]; slice_size_[1] = slice_size_[1] / dev_matrix_shape_[3]; + w_dimension_shard_num_ = dev_matrix_shape_[3]; return SUCCESS; } @@ -119,10 +135,282 @@ std::vector ResizeBilinearInfo::GenerateOpStrategies(int64_t stage_ } void ResizeBilinearInfo::ReplaceNodeInputOrAttrs() { + // if need exchange overlap, use replace_graph() + if (need_exchange_overlap_) { + return; + } auto prim = GetValueNode(cnode_->input(0)); prim->set_attr(SIZE, MakeValue(slice_size_)); } +Status ResizeBilinearInfo::InferRankBias() { + // the origin dev_matrix is [n, c, h, w] + // if repeated calculation + // 1) repeated num in the left of dev matrix, the dev_matrix is [repeated_num, n, c, h, w] + // 2) repeated num in the right of dev matrix, the dev_matrix is [n, c, h, w, repeated_num] + uint64_t w_index_in_dev_matrix = 3; + if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) { + w_index_in_dev_matrix += 1; + } + + CheckGlobalDeviceManager(); + int64_t rank = g_device_manager->global_rank(); + DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_); + RankList group_devices; + if (dev_matrix.GetDevicesAlongDim(w_index_in_dev_matrix, &group_devices) != SUCCESS) { + return FAILED; + } + + if (group_devices.size() <= 1) { + MS_LOG(INFO) << name_ << ": The devices' size of w dimension is " << group_devices.size() + << ", no need to infer rank bias"; + return SUCCESS; + } + + if (group_devices.size() != LongToSize(w_dimension_shard_num_)) { + MS_LOG(ERROR) << name_ << ": The devices' size of w dimension is " << group_devices.size() + << ", but the shard num of w dimension is " << w_dimension_shard_num_; + return FAILED; + } + + std::vector::iterator it = std::find(group_devices.begin(), group_devices.end(), rank); + if (it == group_devices.end()) { + MS_LOG(ERROR) << name_ << ": Can not find the current rank in device list of w dimension, the current rank is " + << rank << ", the device list is " << group_devices; + return FAILED; + } + + rank_bias_ = std::distance(group_devices.begin(), it); + if (it == group_devices.begin()) { + // the current rank is on the left boundary + left_rank_bias_ = -1; + right_rank_bias_ = rank_bias_ + 1; + + left_rank_id_ = -1; + right_rank_id_ = *(it + 1); + } else if (it == group_devices.end() - 1) { + // the current rank is on the right boundary + left_rank_bias_ = rank_bias_ - 1; + right_rank_bias_ = -1; + + left_rank_id_ = *(it - 1); + right_rank_id_ = -1; + } else { + // the current rank is middle rank + left_rank_bias_ = rank_bias_ - 1; + right_rank_bias_ = rank_bias_ + 1; + + left_rank_id_ = *(it - 1); + right_rank_id_ = *(it + 1); + } + + Group group = g_device_manager->CreateGroup(group_devices); + all_to_all_group_ = group.name(); + + MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices + << ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_ + << ", the right rank bias is " << right_rank_bias_ << ", the left rank id is " << left_rank_id_ + << ", the right rank id is " << right_rank_id_ << ", the all to all group is " << all_to_all_group_; + return SUCCESS; +} + +void ResizeBilinearInfo::InferScale() { + origin_in_w_shape_ = inputs_shape_[0][3]; + origin_out_w_shape_ = outputs_shape_[0][3]; + + if (origin_out_w_shape_ == 1) { + MS_LOG(EXCEPTION) << name_ << ": Do not support that the w dimension of output shape is 1"; + } + + if (align_corners_) { + w_scale_ = LongToDouble(origin_in_w_shape_ - 1) / LongToDouble(origin_out_w_shape_ - 1); + } else { + w_scale_ = LongToDouble(origin_in_w_shape_) / LongToDouble(origin_out_w_shape_); + } + + MS_LOG(INFO) << name_ << ": The scale is " << w_scale_; +} + +int64_t ResizeBilinearInfo::InferOverlapLeftSizeByRankBias(int64_t rank_bias) { + // left_overlap_size = (rank * ori_in_w / w_shard) - floor(scale * rank * slice_w) + int64_t map_left_boundary = DoubleToLong(std::floor(w_scale_ * rank_bias * slice_size_[1])); + int64_t local_left_boundary = rank_bias * origin_in_w_shape_ / w_dimension_shard_num_; + + if (map_left_boundary > local_left_boundary) { + MS_LOG(EXCEPTION) << name_ << ": Invalid left overlap, the rank bias is " << rank_bias << ", the map boundary is " + << map_left_boundary << ", the local boundary is " << local_left_boundary; + } + return local_left_boundary - map_left_boundary; +} + +int64_t ResizeBilinearInfo::InferOverlapRightSizeByRankBias(int64_t rank_bias) { + // right_overlap_size = ceil(scale * (rank + 1) * slice_w - 1) - ((rank + 1) * ori_in_w / w_shard - 1) + int64_t map_right_boundary = DoubleToLong(std::ceil(w_scale_ * ((rank_bias + 1) * slice_size_[1] - 1))); + int64_t local_right_boundary = (rank_bias + 1) * origin_in_w_shape_ / w_dimension_shard_num_ - 1; + + // need to handle this special condition + if (map_right_boundary > origin_in_w_shape_ - 1) { + map_right_boundary = origin_in_w_shape_ - 1; + } + + if (map_right_boundary < local_right_boundary) { + MS_LOG(EXCEPTION) << name_ << ": Invalid right overlap, the rank bias is " << rank_bias << ", the map boundary is " + << map_right_boundary << ", the local boundary is " << local_right_boundary; + } + + return map_right_boundary - local_right_boundary; +} + +void ResizeBilinearInfo::InferOverlapSize() { + overlap_left_size_ = InferOverlapLeftSizeByRankBias(rank_bias_); + overlap_right_size_ = InferOverlapRightSizeByRankBias(rank_bias_); + + if (rank_bias_ == 0) { + // it has not left rank + left_rank_overlap_right_size_ = 0; + right_rank_overlap_left_size_ = InferOverlapLeftSizeByRankBias(right_rank_bias_); + } else if (rank_bias_ == w_dimension_shard_num_ - 1) { + // it has not right rank + left_rank_overlap_right_size_ = InferOverlapRightSizeByRankBias(left_rank_bias_); + right_rank_overlap_left_size_ = 0; + } else { + // it has left rank and right rank + left_rank_overlap_right_size_ = InferOverlapRightSizeByRankBias(left_rank_bias_); + right_rank_overlap_left_size_ = InferOverlapLeftSizeByRankBias(right_rank_bias_); + } + + MS_LOG(INFO) << name_ << ": the left overlap size of current rank is " << overlap_left_size_ + << ", the right overlap size of current rank is " << overlap_right_size_ + << ", the right overlap size of left rank is " << left_rank_overlap_right_size_ + << ", the left overlap size of right rank is " << right_rank_overlap_left_size_; +} + +void ResizeBilinearInfo::InferCommunicationAttrs() { + // send rank ids: [-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1] + // recv rank ids: [-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1] + // send lens: [0, 0, send_left_len, send_right_len] + // recv lens: [0, 0, recv_left_len, recv_right_len] + int64_t send_right_rank = -1, send_left_rank = -1, recv_right_rank = -1, recv_left_rank = -1; + int64_t send_left_len = 0, send_right_len = 0, recv_left_len = 0, recv_right_len = 0; + + if (rank_bias_ == 0) { + // the first rank + send_right_len = right_rank_overlap_left_size_; + send_right_rank = send_right_len > 0 ? right_rank_id_ : -1; + + recv_right_len = overlap_right_size_; + recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1; + } else if (rank_bias_ == w_dimension_shard_num_ - 1) { + // the last rank + send_left_len = left_rank_overlap_right_size_; + send_left_rank = send_left_len > 0 ? left_rank_id_ : -1; + + recv_left_len = overlap_left_size_; + recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1; + } else { + // the middle rank + send_right_len = right_rank_overlap_left_size_; + send_right_rank = send_right_len > 0 ? right_rank_id_ : -1; + + recv_right_len = overlap_right_size_; + recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1; + send_left_len = left_rank_overlap_right_size_; + send_left_rank = send_left_len > 0 ? left_rank_id_ : -1; + + recv_left_len = overlap_left_size_; + recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1; + } + + send_rank_ids_ = {-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1}; + recv_rank_ids_ = {-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1}; + send_lens_ = {0, 0, send_left_len, send_right_len}; + recv_lens_ = {0, 0, recv_left_len, recv_right_len}; + MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the send lens is " << send_lens_ + << ", the recv rank ids is " << recv_rank_ids_ << ", the recv lens is " << recv_lens_; +} + +void ResizeBilinearInfo::InferResizeBilinearV2Attrs() { + origin_image_size_ = {inputs_shape_[0][2], inputs_shape_[0][3]}; + src_start_w_ = DoubleToLong(std::floor(w_scale_ * rank_bias_ * slice_size_[1])); + dst_start_w_ = rank_bias_ * slice_size_[1]; + + MS_LOG(INFO) << name_ << ": The origin image size is " << origin_image_size_ << ", src start index is " + << src_start_w_ << ", dst start index is " << dst_start_w_; +} + +void ResizeBilinearInfo::InferNewOperatorAttrs() { + InferCommunicationAttrs(); + InferResizeBilinearV2Attrs(); +} + +OperatorAttrs ResizeBilinearInfo::CreateNeighborExchangeV2Attrs() { + // the type of send_rank_ids, recv_rank_ids, send_lens, recv_lens is list, is not tuple, can not use MakeValue + // the MakeValue(vector) return a tuple + Attr send_rank_ids = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)}; + Attr send_lens = {SEND_LENS, MakeListValue(send_lens_)}; + Attr recv_rank_ids = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)}; + Attr recv_lens = {RECV_LENS, MakeListValue(recv_lens_)}; + Attr data_format = {DATA_FORMAT, MakeValue(NCHW)}; + Attr group = {GROUP, MakeValue(all_to_all_group_)}; + + OperatorAttrs attrs = {send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, group}; + return attrs; +} + +OperatorAttrs ResizeBilinearInfo::CreateParallelResizeBilinearAttrs() { + Attr ori_image_size = {ORI_IMAGE_SIZE, MakeValue(origin_image_size_)}; + Attr split_size = {SPLIT_SIZE, MakeValue(slice_size_)}; + Attr src_start_w = {SRC_START_W, MakeValue(src_start_w_)}; + Attr dst_start_w = {DST_START_W, MakeValue(dst_start_w_)}; + Attr align_corners = {ALIGN_CORNERS, MakeValue(align_corners_)}; + + OperatorAttrs attrs = {ori_image_size, split_size, src_start_w, dst_start_w, align_corners}; + return attrs; +} + +void ResizeBilinearInfo::InferReplaceGraph(const CNodePtr &cnode) { + auto graph = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(graph); + + GenerateGraph gen_g = GenerateGraph(attrs_); + if (gen_g.Init(cnode) != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": Init generator graph failed"; + } + + auto neighbor_exchange_v2_attrs = CreateNeighborExchangeV2Attrs(); + auto neighbor_exchange_v2_node = + gen_g.PushBack({gen_g.NewOpInst(NEIGHBOREXCHANGEV2, neighbor_exchange_v2_attrs), gen_g.virtual_input_node()}); + + auto size = CreateValueTupleAnfNodePtr(size_); + auto parallel_resize_bilinear_attrs = CreateParallelResizeBilinearAttrs(); + auto parallel_resize_bilinear_node = gen_g.PushBack( + {gen_g.NewOpInst(PARALLEL_RESIZE_BILINEAR, parallel_resize_bilinear_attrs), neighbor_exchange_v2_node, size}); + + std::vector> input_nodes = {std::make_pair(neighbor_exchange_v2_node, 1)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, parallel_resize_bilinear_node)); +} + +ReplaceGraphPtr ResizeBilinearInfo::replace_graph(const CNodePtr &cnode) { + if (!need_exchange_overlap_) { + return nullptr; + } + + if (InferRankBias() != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": infer rank bias failed"; + } + + InferScale(); + + InferOverlapSize(); + + InferNewOperatorAttrs(); + + InferReplaceGraph(cnode); + + return replace_graph_; +} + Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) { MS_EXCEPTION_IF_NULL(strategy); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.h index 173004b84eb..9370fbd5365 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.h @@ -23,6 +23,7 @@ #include "utils/hash_map.h" #include "ir/value.h" +#include "frontend/parallel/graph_util/generate_graph.h" #include "frontend/parallel/auto_parallel/operator_costmodel.h" #include "frontend/parallel/ops_info/operator_info.h" #include "frontend/parallel/strategy.h" @@ -46,10 +47,66 @@ class ResizeBilinearInfo : public OperatorInfo { Status InferDevMatrixShape() override; Status InferTensorMap() override; void ReplaceNodeInputOrAttrs() override; + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; std::vector size_; std::vector slice_size_; bool align_corners_ = false; + bool need_exchange_overlap_ = false; + + private: + Status InferRankBias(); + void InferOverlapSize(); + void InferScale(); + void InferNewOperatorAttrs(); + void InferCommunicationAttrs(); + void InferResizeBilinearV2Attrs(); + void InferReplaceGraph(const CNodePtr &cnode); + int64_t InferOverlapLeftSizeByRankBias(int64_t rank_bias); + int64_t InferOverlapRightSizeByRankBias(int64_t rank_bias); + + OperatorAttrs CreateNeighborExchangeV2Attrs(); + OperatorAttrs CreateParallelResizeBilinearAttrs(); + + // rank_bias_ is the position of the current rank in the w dimension of the dev_matrix(have not split h dimension) + int64_t rank_bias_ = 0; + + int64_t left_rank_bias_ = -1; + int64_t right_rank_bias_ = -1; + int64_t left_rank_id_ = -1; + int64_t right_rank_id_ = -1; + int64_t overlap_left_size_ = 0; + int64_t overlap_right_size_ = 0; + int64_t left_rank_overlap_right_size_ = 0; + int64_t right_rank_overlap_left_size_ = 0; + + int64_t origin_in_w_shape_ = 1; + int64_t origin_out_w_shape_ = 1; + int64_t w_dimension_shard_num_ = 1; + + // the send_rank_ids_ or recv_rank_ids is an array with 8 rank ids, the order of index in the array is organized in + // the following format(the 'R' is current rank) + // +++++++++++++ + // | 7 | 0 | 1 | + // +++++++++++++ + // | 6 | R | 2 | + // +++++++++++++ + // | 5 | 4 | 3 | + // +++++++++++++ + std::vector send_rank_ids_; // 8 rank ids + std::vector recv_rank_ids_; // 8 rank ids + + // the send_lens_ or recv_lens_ is an array with 4 lens, the order in the array represents top, bottom, left, right + std::vector send_lens_; // [top, bottom, left, right] + std::vector recv_lens_; // [top, bottom, left, right] + + std::string all_to_all_group_; + + std::vector origin_image_size_; // [H, W] + int64_t src_start_w_ = 0; + int64_t dst_start_w_ = 0; + + double w_scale_ = 1.0; // the scale in w dimension, now only support to split w dimension }; class ResizeNearestNeighborInfo : public ResizeBilinearInfo { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/slice_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/slice_info.cc index 3d91cc7e9e3..d2d73c6871f 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/slice_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/slice_info.cc @@ -184,12 +184,6 @@ ReplaceGraphPtr SliceInfo::replace_graph(const CNodePtr &cnode) { return replace_graph_; } -AnfNodePtr CreateValueTupleAndNodePtr(const std::vector &value_tuple) { - auto value_ptr = MakeValue(value_tuple)->cast(); - auto value_node = NewValueNode(value_ptr); - return value_node->cast(); -} - Status SliceInfo::ComputeReplaceGraph(const CNodePtr &cnode) { GenerateGraph gen_g = GenerateGraph(attrs_); if (gen_g.Init(cnode) != SUCCESS) { @@ -207,8 +201,8 @@ Status SliceInfo::ComputeReplaceGraph(const CNodePtr &cnode) { sliced_size_shape_int.push_back(input_slice_shape[i]); } } - auto new_begin = CreateValueTupleAndNodePtr(begin_); - auto new_size = CreateValueTupleAndNodePtr(sliced_size_shape_int); + auto new_begin = CreateValueTupleAnfNodePtr(begin_); + auto new_size = CreateValueTupleAnfNodePtr(sliced_size_shape_int); auto slice = gen_g.PushBack({gen_g.NewOpInst(SLICE), gen_g.virtual_input_node(), new_begin, new_size});