add parallel op for resize bilinear
Commit 846db9206f (parent 7ebfbb0278)
@@ -31,20 +31,6 @@
namespace mindspore {
namespace parallel {
namespace {
ValuePtr MakeListValue(const std::vector<int64_t> &v) {
  std::vector<ValuePtr> list;
  (void)std::transform(v.begin(), v.end(), std::back_inserter(list), [](int64_t ele) { return MakeValue(ele); });
  return std::make_shared<ValueSequeue>(list);
}

ValuePtr MakeTupleListValue(const Shapes &v) {
  std::vector<ValuePtr> tuple;
  (void)std::transform(v.begin(), v.end(), std::back_inserter(tuple),
                       [](const std::vector<int64_t> &list) { return MakeListValue(list); });
  return std::make_shared<ValueTuple>(tuple);
}
}  // namespace

Status Conv2DInfo::GetAttrsBase() {
  // format
  format_ = GetStringAttr(FORMAT);
@@ -653,8 +639,11 @@ OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(tensor_type);
  auto dtype = tensor_type->element();
  MS_EXCEPTION_IF_NULL(dtype);
  Attr send_ranks = {SEND_RNAK_IDS, MakeListValue(send_rank_ids_)};
  Attr recv_ranks = {RECV_RNAK_IDS, MakeListValue(recv_rank_ids_)};

  // send_rank_ids, recv_rank_ids, send_shapes and recv_shapes must be lists, not tuples, so MakeValue cannot be
  // used here: MakeValue(vector) returns a tuple
  Attr send_ranks = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)};
  Attr recv_ranks = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)};
  Attr send_shapes = {SEND_SHAPES, MakeTupleListValue(send_shapes_)};
  Attr recv_shapes = {RECV_SHAPES, MakeTupleListValue(recv_shapes_)};
  Attr recv_type = {RECV_TYPE, dtype};
@@ -1890,5 +1890,24 @@ std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue) {
  auto val = sequeue->cast<ValueListPtr>();
  return val->value();
}

ValuePtr MakeListValue(const std::vector<int64_t> &v) {
  std::vector<ValuePtr> list;
  (void)std::transform(v.begin(), v.end(), std::back_inserter(list), [](int64_t ele) { return MakeValue(ele); });
  return std::make_shared<ValueSequeue>(list);
}

ValuePtr MakeTupleListValue(const Shapes &v) {
  std::vector<ValuePtr> tuple;
  (void)std::transform(v.begin(), v.end(), std::back_inserter(tuple),
                       [](const std::vector<int64_t> &list) { return MakeListValue(list); });
  return std::make_shared<ValueTuple>(tuple);
}

AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector<int64_t> &value_tuple) {
  auto value_ptr = MakeValue(value_tuple)->cast<ValueTuplePtr>();
  auto value_node = NewValueNode(value_ptr);
  return value_node->cast<AnfNodePtr>();
}
}  // namespace parallel
}  // namespace mindspore
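The helpers above follow a simple transform-and-wrap pattern. A minimal standalone sketch with stand-in types (IntValue/ListValue/TupleValue are illustrative mocks, not MindSpore's real Value classes):

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <memory>
#include <vector>

// Stand-in value hierarchy; MindSpore's real classes are Value/ValueSequeue/ValueTuple.
struct Value { virtual ~Value() = default; };
using ValuePtr = std::shared_ptr<Value>;

struct IntValue : Value { int64_t v; explicit IntValue(int64_t x) : v(x) {} };
struct ListValue : Value { std::vector<ValuePtr> elems; };   // models ValueSequeue (a "list")
struct TupleValue : Value { std::vector<ValuePtr> elems; };  // models ValueTuple (a "tuple")

// Mirrors MakeListValue: wrap each int64_t, then wrap the whole vector as a list.
ValuePtr MakeList(const std::vector<int64_t> &v) {
  auto list = std::make_shared<ListValue>();
  (void)std::transform(v.begin(), v.end(), std::back_inserter(list->elems),
                       [](int64_t ele) { return std::make_shared<IntValue>(ele); });
  return list;
}

// Mirrors MakeTupleListValue: a tuple whose elements are lists, e.g. a vector of shapes.
ValuePtr MakeTupleOfLists(const std::vector<std::vector<int64_t>> &shapes) {
  auto tuple = std::make_shared<TupleValue>();
  (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(tuple->elems),
                       [](const std::vector<int64_t> &s) { return MakeList(s); });
  return tuple;
}

int main() {
  auto ids = MakeList({-1, -1, 2, -1});                              // list of rank ids
  auto shapes = MakeTupleOfLists({{16, 16, 8, 4}, {16, 16, 8, 4}});  // tuple of shape lists
  return ids && shapes ? 0 : 1;
}
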
@@ -343,6 +343,9 @@ Status GenerateStrategiesWithBroadcast(int64_t stage_id, const Shapes &inputs_sh

Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue);
ValuePtr MakeListValue(const std::vector<int64_t> &v);
ValuePtr MakeTupleListValue(const Shapes &v);
AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector<int64_t> &value_tuple);
}  // namespace parallel
}  // namespace mindspore
@@ -172,11 +172,17 @@ constexpr char REPLACE[] = "replace";
constexpr char CONNSYMBOL[] = "/";
constexpr char INSTANCE_NAME[] = "instance_name";
constexpr char SPLIT_SENS[] = "split_sens";
constexpr char SEND_RNAK_IDS[] = "send_rank_ids";
constexpr char RECV_RNAK_IDS[] = "recv_rank_ids";
constexpr char SEND_RANK_IDS[] = "send_rank_ids";
constexpr char RECV_RANK_IDS[] = "recv_rank_ids";
constexpr char RECV_SHAPES[] = "recv_shapes";
constexpr char SEND_SHAPES[] = "send_shapes";
constexpr char RECV_TYPE[] = "recv_type";
constexpr char SEND_LENS[] = "send_lens";
constexpr char RECV_LENS[] = "recv_lens";
constexpr char ORI_IMAGE_SIZE[] = "ori_image_size";
constexpr char SPLIT_SIZE[] = "split_size";
constexpr char SRC_START_W[] = "src_start_w";
constexpr char DST_START_W[] = "dst_start_w";
constexpr char SPLIT_TENSOR[] = "split_tensor";
constexpr char DEV_MAT[] = "dev_mat";
constexpr char TENSOR_MAP[] = "tensor_map";

@@ -238,6 +244,7 @@ constexpr char SPLIT[] = "Split";
constexpr char ALL_TO_ALL[] = "AlltoAll";
constexpr char NEIGHBOREXCHANGE[] = "NeighborExchange";
constexpr char NEIGHBOREXCHANGEV2[] = "NeighborExchangeV2";
constexpr char PARALLEL_RESIZE_BILINEAR[] = "ParallelResizeBilinear";
constexpr char PERMUTE_BY_AXIS[] = "PermuteByAxis";
constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis";
constexpr char SPLIT_BY_AXIS[] = "SplitByAxis";
@@ -20,6 +20,7 @@
#include <memory>
#include <utility>
#include <vector>
#include <cmath>

#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
@@ -28,6 +29,8 @@

namespace mindspore {
namespace parallel {
// ResizeBilinear: supports splitting N/C/W
// ResizeNearestNeighbor: supports splitting N/C/H/W if align_corners=False, and only N/C if align_corners=True
Status ResizeBilinearInfo::GetAttrs() {
  size_ = GetTupleIntAttr(SIZE);
  if (size_.size() != 2) {
@@ -43,8 +46,9 @@ Status ResizeBilinearInfo::GetAttrs() {

Status ResizeBilinearInfo::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  need_exchange_overlap_ = false;
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy";
    MS_LOG(ERROR) << name_ << ": Check input strategy failed";
    return FAILED;
  }

@@ -60,8 +64,19 @@ Status ResizeBilinearInfo::CheckStrategy(const StrategyPtr &strategy) {
    return FAILED;
  }

  if (input_strategy[2] != 1 || input_strategy[3] != 1) {
    MS_LOG(ERROR) << name_ << ": Do not support split from H or W";
  if (input_strategy[2] != 1) {
    MS_LOG(ERROR) << name_ << ": Do not support splitting the H dimension";
    return FAILED;
  }

  if (input_strategy[3] != 1) {
    need_exchange_overlap_ = true;
    MS_LOG(INFO) << name_ << ": Split the w dimension";
  }

  // check output strategy
  if (CheckStrategyValue(strategy, outputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Check output strategy failed";
    return FAILED;
  }

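A minimal sketch of the check above, assuming an NCHW 4-tuple strategy (names are illustrative, not the MindSpore API):

#include <cstdio>
#include <vector>

// Sketch of the NCHW strategy check: H must not be split; splitting W turns on
// halo exchange between neighboring slices.
bool CheckResizeStrategy(const std::vector<long long> &input_strategy, bool *need_exchange_overlap) {
  if (input_strategy.size() != 4) return false;
  if (input_strategy[2] != 1) return false;          // splitting H is rejected
  *need_exchange_overlap = input_strategy[3] != 1;   // splitting W needs overlap exchange
  return true;
}

int main() {
  bool need_overlap = false;
  // Shard N by 2 and W by 4: valid, and requires exchanging overlapped columns.
  bool ok = CheckResizeStrategy({2, 1, 1, 4}, &need_overlap);
  std::printf("ok=%d need_overlap=%d\n", ok, need_overlap);
  return 0;
}
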
@@ -87,6 +102,7 @@ Status ResizeBilinearInfo::InferDevMatrixShape() {
  slice_size_ = size_;
  slice_size_[0] = slice_size_[0] / dev_matrix_shape_[2];
  slice_size_[1] = slice_size_[1] / dev_matrix_shape_[3];
  w_dimension_shard_num_ = dev_matrix_shape_[3];
  return SUCCESS;
}

@@ -119,10 +135,282 @@ std::vector<StrategyPtr> ResizeBilinearInfo::GenerateOpStrategies(int64_t stage_
}

void ResizeBilinearInfo::ReplaceNodeInputOrAttrs() {
  // if overlap exchange is needed, use replace_graph() instead
  if (need_exchange_overlap_) {
    return;
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode_->input(0));
  prim->set_attr(SIZE, MakeValue(slice_size_));
}

Status ResizeBilinearInfo::InferRankBias() {
  // the origin dev_matrix is [n, c, h, w]
  // if repeated calculation:
  // 1) repeated num on the left of the dev matrix, the dev_matrix is [repeated_num, n, c, h, w]
  // 2) repeated num on the right of the dev matrix, the dev_matrix is [n, c, h, w, repeated_num]
  uint64_t w_index_in_dev_matrix = 3;
  if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
    w_index_in_dev_matrix += 1;
  }

  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
  RankList group_devices;
  if (dev_matrix.GetDevicesAlongDim(w_index_in_dev_matrix, &group_devices) != SUCCESS) {
    return FAILED;
  }

  if (group_devices.size() <= 1) {
    MS_LOG(INFO) << name_ << ": The devices' size of w dimension is " << group_devices.size()
                 << ", no need to infer rank bias";
    return SUCCESS;
  }

  if (group_devices.size() != LongToSize(w_dimension_shard_num_)) {
    MS_LOG(ERROR) << name_ << ": The devices' size of w dimension is " << group_devices.size()
                  << ", but the shard num of w dimension is " << w_dimension_shard_num_;
    return FAILED;
  }

  std::vector<int64_t>::iterator it = std::find(group_devices.begin(), group_devices.end(), rank);
  if (it == group_devices.end()) {
    MS_LOG(ERROR) << name_ << ": Can not find the current rank in device list of w dimension, the current rank is "
                  << rank << ", the device list is " << group_devices;
    return FAILED;
  }

  rank_bias_ = std::distance(group_devices.begin(), it);
  if (it == group_devices.begin()) {
    // the current rank is on the left boundary
    left_rank_bias_ = -1;
    right_rank_bias_ = rank_bias_ + 1;

    left_rank_id_ = -1;
    right_rank_id_ = *(it + 1);
  } else if (it == group_devices.end() - 1) {
    // the current rank is on the right boundary
    left_rank_bias_ = rank_bias_ - 1;
    right_rank_bias_ = -1;

    left_rank_id_ = *(it - 1);
    right_rank_id_ = -1;
  } else {
    // the current rank is a middle rank
    left_rank_bias_ = rank_bias_ - 1;
    right_rank_bias_ = rank_bias_ + 1;

    left_rank_id_ = *(it - 1);
    right_rank_id_ = *(it + 1);
  }

  Group group = g_device_manager->CreateGroup(group_devices);
  all_to_all_group_ = group.name();

  MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
               << ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_
               << ", the right rank bias is " << right_rank_bias_ << ", the left rank id is " << left_rank_id_
               << ", the right rank id is " << right_rank_id_ << ", the all to all group is " << all_to_all_group_;
  return SUCCESS;
}

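A standalone sketch of the neighbor bookkeeping above, assuming group_devices is the rank list that GetDevicesAlongDim would return (all values hypothetical):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <iterator>
#include <vector>

int main() {
  // Ranks along the w dimension of the device matrix, in device-matrix order.
  std::vector<int64_t> group_devices = {4, 5, 6, 7};
  int64_t rank = 5;  // current global rank

  auto it = std::find(group_devices.begin(), group_devices.end(), rank);
  if (it == group_devices.end()) return 1;

  int64_t rank_bias = std::distance(group_devices.begin(), it);
  // -1 marks "no neighbor" at a boundary, matching the convention in InferRankBias.
  int64_t left_rank_id = (it == group_devices.begin()) ? -1 : *(it - 1);
  int64_t right_rank_id = (it == group_devices.end() - 1) ? -1 : *(it + 1);

  std::printf("rank_bias=%lld left=%lld right=%lld\n", (long long)rank_bias,
              (long long)left_rank_id, (long long)right_rank_id);  // rank_bias=1 left=4 right=6
  return 0;
}
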
void ResizeBilinearInfo::InferScale() {
  origin_in_w_shape_ = inputs_shape_[0][3];
  origin_out_w_shape_ = outputs_shape_[0][3];

  if (origin_out_w_shape_ == 1) {
    MS_LOG(EXCEPTION) << name_ << ": Do not support that the w dimension of output shape is 1";
  }

  if (align_corners_) {
    w_scale_ = LongToDouble(origin_in_w_shape_ - 1) / LongToDouble(origin_out_w_shape_ - 1);
  } else {
    w_scale_ = LongToDouble(origin_in_w_shape_) / LongToDouble(origin_out_w_shape_);
  }

  MS_LOG(INFO) << name_ << ": The scale is " << w_scale_;
}

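Spelled out, with \(W_{in}\) the input width and \(W_{out}\) the output width, the scale computed above is

\[
w\_scale =
\begin{cases}
\dfrac{W_{in}-1}{W_{out}-1}, & \text{align\_corners} = \text{true} \\[4pt]
\dfrac{W_{in}}{W_{out}}, & \text{align\_corners} = \text{false}
\end{cases}
\]

For an assumed pair \(W_{in}=5\), \(W_{out}=8\): \(w\_scale = 4/7 \approx 0.571\) with align_corners, \(5/8 = 0.625\) without.
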
int64_t ResizeBilinearInfo::InferOverlapLeftSizeByRankBias(int64_t rank_bias) {
  // left_overlap_size = (rank * ori_in_w / w_shard) - floor(scale * rank * slice_w)
  int64_t map_left_boundary = DoubleToLong(std::floor(w_scale_ * rank_bias * slice_size_[1]));
  int64_t local_left_boundary = rank_bias * origin_in_w_shape_ / w_dimension_shard_num_;

  if (map_left_boundary > local_left_boundary) {
    MS_LOG(EXCEPTION) << name_ << ": Invalid left overlap, the rank bias is " << rank_bias << ", the map boundary is "
                      << map_left_boundary << ", the local boundary is " << local_left_boundary;
  }
  return local_left_boundary - map_left_boundary;
}

int64_t ResizeBilinearInfo::InferOverlapRightSizeByRankBias(int64_t rank_bias) {
  // right_overlap_size = ceil(scale * ((rank + 1) * slice_w - 1)) - ((rank + 1) * ori_in_w / w_shard - 1)
  int64_t map_right_boundary = DoubleToLong(std::ceil(w_scale_ * ((rank_bias + 1) * slice_size_[1] - 1)));
  int64_t local_right_boundary = (rank_bias + 1) * origin_in_w_shape_ / w_dimension_shard_num_ - 1;

  // the mapped boundary may run past the input, clamp it to the last input column
  if (map_right_boundary > origin_in_w_shape_ - 1) {
    map_right_boundary = origin_in_w_shape_ - 1;
  }

  if (map_right_boundary < local_right_boundary) {
    MS_LOG(EXCEPTION) << name_ << ": Invalid right overlap, the rank bias is " << rank_bias << ", the map boundary is "
                      << map_right_boundary << ", the local boundary is " << local_right_boundary;
  }

  return map_right_boundary - local_right_boundary;
}

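To make the two boundary formulas concrete, assume \(W_{in}=5\), \(W_{out}=8\), w_shard \(=2\), align_corners true, so \(w\_scale = 4/7\) and slice_w \(=4\) (all divisions below are integer divisions):

\[
\begin{aligned}
\text{rank 0:}\quad & \text{left} = 0 \cdot 5/2 - \left\lfloor \tfrac{4}{7} \cdot 0 \cdot 4 \right\rfloor = 0,
& \text{right} = \left\lceil \tfrac{4}{7}(1 \cdot 4 - 1) \right\rceil - (1 \cdot 5/2 - 1) = 2 - 1 = 1 \\
\text{rank 1:}\quad & \text{left} = 1 \cdot 5/2 - \left\lfloor \tfrac{4}{7} \cdot 1 \cdot 4 \right\rfloor = 2 - 2 = 0,
& \text{right} = \min\!\left(\left\lceil \tfrac{4}{7}(2 \cdot 4 - 1) \right\rceil,\, 4\right) - (2 \cdot 5/2 - 1) = 4 - 4 = 0
\end{aligned}
\]

So in this assumed configuration only rank 0 needs halo data: one input column owned by rank 1.
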
void ResizeBilinearInfo::InferOverlapSize() {
  overlap_left_size_ = InferOverlapLeftSizeByRankBias(rank_bias_);
  overlap_right_size_ = InferOverlapRightSizeByRankBias(rank_bias_);

  if (rank_bias_ == 0) {
    // there is no left rank
    left_rank_overlap_right_size_ = 0;
    right_rank_overlap_left_size_ = InferOverlapLeftSizeByRankBias(right_rank_bias_);
  } else if (rank_bias_ == w_dimension_shard_num_ - 1) {
    // there is no right rank
    left_rank_overlap_right_size_ = InferOverlapRightSizeByRankBias(left_rank_bias_);
    right_rank_overlap_left_size_ = 0;
  } else {
    // there are both left and right ranks
    left_rank_overlap_right_size_ = InferOverlapRightSizeByRankBias(left_rank_bias_);
    right_rank_overlap_left_size_ = InferOverlapLeftSizeByRankBias(right_rank_bias_);
  }

  MS_LOG(INFO) << name_ << ": the left overlap size of current rank is " << overlap_left_size_
               << ", the right overlap size of current rank is " << overlap_right_size_
               << ", the right overlap size of left rank is " << left_rank_overlap_right_size_
               << ", the left overlap size of right rank is " << right_rank_overlap_left_size_;
}

void ResizeBilinearInfo::InferCommunicationAttrs() {
  // send rank ids: [-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1]
  // recv rank ids: [-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1]
  // send lens: [0, 0, send_left_len, send_right_len]
  // recv lens: [0, 0, recv_left_len, recv_right_len]
  int64_t send_right_rank = -1, send_left_rank = -1, recv_right_rank = -1, recv_left_rank = -1;
  int64_t send_left_len = 0, send_right_len = 0, recv_left_len = 0, recv_right_len = 0;

  if (rank_bias_ == 0) {
    // the first rank
    send_right_len = right_rank_overlap_left_size_;
    send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;

    recv_right_len = overlap_right_size_;
    recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
  } else if (rank_bias_ == w_dimension_shard_num_ - 1) {
    // the last rank
    send_left_len = left_rank_overlap_right_size_;
    send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;

    recv_left_len = overlap_left_size_;
    recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
  } else {
    // a middle rank
    send_right_len = right_rank_overlap_left_size_;
    send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;

    recv_right_len = overlap_right_size_;
    recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
    send_left_len = left_rank_overlap_right_size_;
    send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;

    recv_left_len = overlap_left_size_;
    recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
  }

  send_rank_ids_ = {-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1};
  recv_rank_ids_ = {-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1};
  send_lens_ = {0, 0, send_left_len, send_right_len};
  recv_lens_ = {0, 0, recv_left_len, recv_right_len};
  MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the send lens is " << send_lens_
               << ", the recv rank ids is " << recv_rank_ids_ << ", the recv lens is " << recv_lens_;
}

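Putting the two overlap formulas and the exchange lengths together, a self-contained sketch that reproduces the worked example above end to end (illustrative code, not the MindSpore implementation):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Overlap formulas from InferOverlapLeftSizeByRankBias / InferOverlapRightSizeByRankBias.
int64_t LeftOverlap(double scale, int64_t bias, int64_t slice_w, int64_t in_w, int64_t shard) {
  int64_t map = (int64_t)std::floor(scale * bias * slice_w);
  int64_t local = bias * in_w / shard;
  return local - map;
}

int64_t RightOverlap(double scale, int64_t bias, int64_t slice_w, int64_t in_w, int64_t shard) {
  int64_t map = (int64_t)std::ceil(scale * ((bias + 1) * slice_w - 1));
  if (map > in_w - 1) map = in_w - 1;  // clamp to the last input column
  int64_t local = (bias + 1) * in_w / shard - 1;
  return map - local;
}

int main() {
  const int64_t in_w = 5, out_w = 8, shard = 2, slice_w = out_w / shard;
  const double scale = (double)(in_w - 1) / (double)(out_w - 1);  // align_corners = true

  for (int64_t bias = 0; bias < shard; ++bias) {
    std::printf("rank_bias=%lld left_overlap=%lld right_overlap=%lld\n", (long long)bias,
                (long long)LeftOverlap(scale, bias, slice_w, in_w, shard),
                (long long)RightOverlap(scale, bias, slice_w, in_w, shard));
  }
  // rank_bias=0: left 0, right 1 -> rank 0 recvs 1 column from its right neighbor.
  // rank_bias=1: left 0, right 0 -> it receives nothing, but it still sends 1 column
  //                                  to rank 0 (rank 0's right overlap, seen from rank 1).
  return 0;
}
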
void ResizeBilinearInfo::InferResizeBilinearV2Attrs() {
  origin_image_size_ = {inputs_shape_[0][2], inputs_shape_[0][3]};
  src_start_w_ = DoubleToLong(std::floor(w_scale_ * rank_bias_ * slice_size_[1]));
  dst_start_w_ = rank_bias_ * slice_size_[1];

  MS_LOG(INFO) << name_ << ": The origin image size is " << origin_image_size_ << ", src start index is "
               << src_start_w_ << ", dst start index is " << dst_start_w_;
}

void ResizeBilinearInfo::InferNewOperatorAttrs() {
  InferCommunicationAttrs();
  InferResizeBilinearV2Attrs();
}

OperatorAttrs ResizeBilinearInfo::CreateNeighborExchangeV2Attrs() {
  // send_rank_ids, recv_rank_ids, send_lens and recv_lens must be lists, not tuples, so MakeValue cannot be used
  // here: MakeValue(vector) returns a tuple
  Attr send_rank_ids = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)};
  Attr send_lens = {SEND_LENS, MakeListValue(send_lens_)};
  Attr recv_rank_ids = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)};
  Attr recv_lens = {RECV_LENS, MakeListValue(recv_lens_)};
  Attr data_format = {DATA_FORMAT, MakeValue(NCHW)};
  Attr group = {GROUP, MakeValue(all_to_all_group_)};

  OperatorAttrs attrs = {send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, group};
  return attrs;
}

OperatorAttrs ResizeBilinearInfo::CreateParallelResizeBilinearAttrs() {
  Attr ori_image_size = {ORI_IMAGE_SIZE, MakeValue(origin_image_size_)};
  Attr split_size = {SPLIT_SIZE, MakeValue(slice_size_)};
  Attr src_start_w = {SRC_START_W, MakeValue(src_start_w_)};
  Attr dst_start_w = {DST_START_W, MakeValue(dst_start_w_)};
  Attr align_corners = {ALIGN_CORNERS, MakeValue(align_corners_)};

  OperatorAttrs attrs = {ori_image_size, split_size, src_start_w, dst_start_w, align_corners};
  return attrs;
}

void ResizeBilinearInfo::InferReplaceGraph(const CNodePtr &cnode) {
  auto graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(graph);

  GenerateGraph gen_g = GenerateGraph(attrs_);
  if (gen_g.Init(cnode) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Init generator graph failed";
  }

  auto neighbor_exchange_v2_attrs = CreateNeighborExchangeV2Attrs();
  auto neighbor_exchange_v2_node =
    gen_g.PushBack({gen_g.NewOpInst(NEIGHBOREXCHANGEV2, neighbor_exchange_v2_attrs), gen_g.virtual_input_node()});

  auto size = CreateValueTupleAnfNodePtr(size_);
  auto parallel_resize_bilinear_attrs = CreateParallelResizeBilinearAttrs();
  auto parallel_resize_bilinear_node = gen_g.PushBack(
    {gen_g.NewOpInst(PARALLEL_RESIZE_BILINEAR, parallel_resize_bilinear_attrs), neighbor_exchange_v2_node, size});

  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(neighbor_exchange_v2_node, 1)};
  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, parallel_resize_bilinear_node));
}

ReplaceGraphPtr ResizeBilinearInfo::replace_graph(const CNodePtr &cnode) {
  if (!need_exchange_overlap_) {
    return nullptr;
  }

  if (InferRankBias() != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": infer rank bias failed";
  }

  InferScale();

  InferOverlapSize();

  InferNewOperatorAttrs();

  InferReplaceGraph(cnode);

  return replace_graph_;
}

Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);

@@ -23,6 +23,7 @@

#include "utils/hash_map.h"
#include "ir/value.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
@@ -46,10 +47,66 @@ class ResizeBilinearInfo : public OperatorInfo {
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  void ReplaceNodeInputOrAttrs() override;
  ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;

  std::vector<int64_t> size_;
  std::vector<int64_t> slice_size_;
  bool align_corners_ = false;
  bool need_exchange_overlap_ = false;

 private:
  Status InferRankBias();
  void InferOverlapSize();
  void InferScale();
  void InferNewOperatorAttrs();
  void InferCommunicationAttrs();
  void InferResizeBilinearV2Attrs();
  void InferReplaceGraph(const CNodePtr &cnode);
  int64_t InferOverlapLeftSizeByRankBias(int64_t rank_bias);
  int64_t InferOverlapRightSizeByRankBias(int64_t rank_bias);

  OperatorAttrs CreateNeighborExchangeV2Attrs();
  OperatorAttrs CreateParallelResizeBilinearAttrs();

  // rank_bias_ is the position of the current rank in the w dimension of the dev_matrix (h is not split)
  int64_t rank_bias_ = 0;

  int64_t left_rank_bias_ = -1;
  int64_t right_rank_bias_ = -1;
  int64_t left_rank_id_ = -1;
  int64_t right_rank_id_ = -1;
  int64_t overlap_left_size_ = 0;
  int64_t overlap_right_size_ = 0;
  int64_t left_rank_overlap_right_size_ = 0;
  int64_t right_rank_overlap_left_size_ = 0;

  int64_t origin_in_w_shape_ = 1;
  int64_t origin_out_w_shape_ = 1;
  int64_t w_dimension_shard_num_ = 1;

  // send_rank_ids_ and recv_rank_ids_ are arrays of 8 rank ids; the indices are laid out
  // as follows ('R' is the current rank):
  // +---+---+---+
  // | 7 | 0 | 1 |
  // +---+---+---+
  // | 6 | R | 2 |
  // +---+---+---+
  // | 5 | 4 | 3 |
  // +---+---+---+
  std::vector<int64_t> send_rank_ids_;  // 8 rank ids
  std::vector<int64_t> recv_rank_ids_;  // 8 rank ids

  // send_lens_ and recv_lens_ are arrays of 4 lengths, ordered top, bottom, left, right
  std::vector<int64_t> send_lens_;  // [top, bottom, left, right]
  std::vector<int64_t> recv_lens_;  // [top, bottom, left, right]

  std::string all_to_all_group_;

  std::vector<int64_t> origin_image_size_;  // [H, W]
  int64_t src_start_w_ = 0;
  int64_t dst_start_w_ = 0;

  double w_scale_ = 1.0;  // the scale in the w dimension; currently only splitting w is supported
};

class ResizeNearestNeighborInfo : public ResizeBilinearInfo {
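For reference, the index-to-direction mapping implied by the grid above, written as a hypothetical enum (not part of the commit; it matches indices 2 and 6 used by InferCommunicationAttrs):

// Index layout of send_rank_ids_/recv_rank_ids_ around the current rank R.
enum NeighborIndex : int {
  kTop = 0,
  kTopRight = 1,
  kRight = 2,      // ResizeBilinearInfo only exchanges here and at kLeft
  kBottomRight = 3,
  kBottom = 4,
  kBottomLeft = 5,
  kLeft = 6,
  kTopLeft = 7,
};
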
@@ -184,12 +184,6 @@ ReplaceGraphPtr SliceInfo::replace_graph(const CNodePtr &cnode) {
  return replace_graph_;
}

AnfNodePtr CreateValueTupleAndNodePtr(const std::vector<int64_t> &value_tuple) {
  auto value_ptr = MakeValue(value_tuple)->cast<ValueTuplePtr>();
  auto value_node = NewValueNode(value_ptr);
  return value_node->cast<AnfNodePtr>();
}

Status SliceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
  GenerateGraph gen_g = GenerateGraph(attrs_);
  if (gen_g.Init(cnode) != SUCCESS) {

@@ -207,8 +201,8 @@ Status SliceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
      sliced_size_shape_int.push_back(input_slice_shape[i]);
    }
  }
  auto new_begin = CreateValueTupleAndNodePtr(begin_);
  auto new_size = CreateValueTupleAndNodePtr(sliced_size_shape_int);
  auto new_begin = CreateValueTupleAnfNodePtr(begin_);
  auto new_size = CreateValueTupleAnfNodePtr(sliced_size_shape_int);

  auto slice = gen_g.PushBack({gen_g.NewOpInst(SLICE), gen_g.virtual_input_node(), new_begin, new_size});