!27821 neighborexchangev2 fix bug

Merge pull request !27821 from TuDouNi/neighborexchangev2_fix_bug
This commit is contained in:
i-robot 2021-12-20 07:43:41 +00:00 committed by Gitee
commit ac69bac952
2 changed files with 68 additions and 86 deletions

View File

@@ -164,7 +164,7 @@ std::vector<std::vector<size_t>> CalAllToAllvOutputShape(const std::vector<size_
std::vector<AnfNodePtr> CreateAllToAllvInput(const std::vector<std::vector<AnfNodePtr>> &split_outputs,
const std::vector<int64_t> &send_rank_ids) {
std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
std::vector<size_t> split_idx = {0, 2, 1, 3, 0, 3, 1, 2};
std::vector<size_t> split_idx = {0, 5, 3, 7, 1, 6, 2, 4};
std::vector<bool> is_begin = {true, false, false, false, false, true, true, true};
for (size_t idx = 0; idx < send_rank_ids.size(); ++idx) {
if (send_rank_ids[idx] != kInvalidId) {
@@ -374,11 +374,6 @@ std::vector<CNodePtr> NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const Func
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
bool is_top = IsTop(send_rank_ids);
bool is_bottom = IsBottom(send_rank_ids);
bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId);
bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId);
auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_input, 0);
auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
if (SizeToLong(shape.size()) != kShapeSize) { // only support NCHW now
@@ -386,91 +381,68 @@ std::vector<CNodePtr> NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const Func
<< trace::DumpSourceLines(neighbor_exchange_v2);
}
// splitv for top & bottom
int64_t num_split_h = 0;
CNodePtr split_v_top_bottom = nullptr;
if (is_top || is_bottom) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_input};
// splitv for 0, 4, 6, 2
bool is_top = IsTop(send_rank_ids);
bool is_bottom = IsBottom(send_rank_ids);
bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId);
bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId);
std::vector<bool> splitvs_is_first = {true, false, true, false}; // is left or top
std::vector<bool> splitvs_is_exist = {is_top, is_bottom, is_left, is_right};
std::vector<size_t> splitvs_dim = {kHDim, kHDim, kWDim, kWDim};
for (size_t i = 0; i < splitvs_is_first.size(); ++i) {
int64_t num_split = 0;
CNodePtr split_v = nullptr;
if (splitvs_is_exist[i]) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_input};
split_v_top_bottom =
CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h, *this);
split_v = CreateSplitNode(graph, split_input, shape, splitvs_is_first[i], !splitvs_is_first[i], splitvs_dim[i],
send_lens, dtype, &num_split, *this);
}
split_nodes.emplace_back(split_v);
split_num->push_back(num_split);
}
split_nodes.emplace_back(split_v_top_bottom);
split_num->push_back(num_split_h);
// splitv for left & right
int64_t num_split_w = 0;
CNodePtr split_v_left_right = nullptr;
if (is_left || is_right) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_input};
split_v_left_right =
CreateSplitNode(graph, split_input, shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w, *this);
}
split_nodes.emplace_back(split_v_left_right);
split_num->push_back(num_split_w);
// splitv for corner
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId) ||
(send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
// top_bottom_split outputs
std::vector<AnfNodePtr> split_outputs_top_bottom;
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>((*split_num)[0]),
&split_outputs_top_bottom);
if (split_outputs_top_bottom.empty()) {
// splitv for 7, 1, 5, 3
std::vector<bool> corner_splitvs_is_first = {true, false, true, false};
std::vector<bool> corner_splitvs_is_exist = {send_rank_ids[7] != kInvalidId, send_rank_ids[1] != kInvalidId,
send_rank_ids[5] != kInvalidId, send_rank_ids[3] != kInvalidId};
std::vector<bool> corner_splitvs_is_input_top = {true, true, false, false};
std::vector<AnfNodePtr> split_outputs_top;
std::vector<AnfNodePtr> split_outputs_bottom;
if (split_nodes[0] != nullptr) {
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>((*split_num)[0]), &split_outputs_top);
if (split_outputs_top.empty()) {
MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString() << " should have at least one output, but got 0"
<< trace::DumpSourceLines(split_nodes[0]);
}
// for top corner
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId)) {
auto shape_tmp(shape);
shape_tmp[kHDim] = send_lens[0];
bool is_first = (send_rank_ids[kRankIdSeven] != kInvalidId);
bool is_last = (send_rank_ids[kRankIdOne] != kInvalidId);
std::vector<AnfNodePtr> split_v_corner_top_input = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_v_corner_top_input.insert(split_v_corner_top_input.end(), split_outputs_top_bottom.begin(),
split_outputs_top_bottom.begin() + 1);
int64_t num_split_top_corner = 0;
CNodePtr split_v_corner_top = CreateSplitNode(graph, split_v_corner_top_input, shape_tmp, is_first, is_last,
kWDim, send_lens, dtype, &num_split_top_corner, *this);
split_nodes.emplace_back(split_v_corner_top);
split_num->push_back(num_split_top_corner);
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
if (split_nodes[1] != nullptr) {
CreateMultipleOutputsOfAnfNode(graph, split_nodes[1], static_cast<size_t>((*split_num)[1]), &split_outputs_bottom);
if (split_outputs_bottom.empty()) {
MS_LOG(EXCEPTION) << "The node " << split_nodes[1]->DebugString() << " should have at least one output, but got 0"
<< trace::DumpSourceLines(split_nodes[1]);
}
}
for (size_t i = 0; i < corner_splitvs_is_first.size(); ++i) {
int64_t num_split = 0;
CNodePtr split_v = nullptr;
if (corner_splitvs_is_exist[i]) {
auto shape_tmp = shape;
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
if (corner_splitvs_is_input_top[i]) {
split_input.insert(split_input.end(), split_outputs_top.begin(), split_outputs_top.begin() + 1);
shape_tmp[kHDim] = send_lens[0];
} else {
split_input.insert(split_input.end(), split_outputs_bottom.end() - 1, split_outputs_bottom.end());
shape_tmp[kHDim] = send_lens[1];
}
// for bottom corner
if ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
auto shape_tmp(shape);
shape_tmp[kHDim] = send_lens[1];
bool is_first = (send_rank_ids[kRankIdFive] != kInvalidId);
bool is_last = (send_rank_ids[kRankIdThree] != kInvalidId);
std::vector<AnfNodePtr> split_v_corner_bottom_input = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_v_corner_bottom_input.insert(split_v_corner_bottom_input.end(), split_outputs_top_bottom.end() - 1,
split_outputs_top_bottom.end());
int64_t num_split_bottom_corner = 0;
CNodePtr split_v_corner_bottom = CreateSplitNode(graph, split_v_corner_bottom_input, shape_tmp, is_first, is_last,
kWDim, send_lens, dtype, &num_split_bottom_corner, *this);
split_nodes.emplace_back(split_v_corner_bottom);
split_num->push_back(num_split_bottom_corner);
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
split_v = CreateSplitNode(graph, split_input, shape_tmp, corner_splitvs_is_first[i], !corner_splitvs_is_first[i],
kWDim, send_lens, dtype, &num_split, *this);
}
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
split_nodes.emplace_back(split_v);
split_num->push_back(num_split);
}
return split_nodes;

View File

@@ -150,11 +150,21 @@ void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &in
if (input_shape.size() != kInputSize) {
MS_EXCEPTION(ValueError) << "Input size is not 4, only support NCHW now.";
}
if (send_lens[kIdx0] + send_lens[kIdx1] > input_shape[kHDim]) {
MS_EXCEPTION(ValueError) << "send_lens in H dim is larger than input size.";
if (send_lens[kIdx0] > input_shape[kHDim]) {
MS_EXCEPTION(ValueError) << "Attr send_lens[0]: " << send_lens[kIdx0]
<< " is larger than input size in H dim: " << input_shape[kHDim] << ".";
}
if (send_lens[kIdx2] + send_lens[kIdx3] > input_shape[kWDim]) {
MS_EXCEPTION(ValueError) << "send_lens in W dim is larger than input size.";
if (send_lens[kIdx1] > input_shape[kHDim]) {
MS_EXCEPTION(ValueError) << "Attr send_lens[1]: " << send_lens[kIdx1]
<< " is larger than input size in H dim: " << input_shape[kHDim] << ".";
}
if (send_lens[kIdx2] > input_shape[kWDim]) {
MS_EXCEPTION(ValueError) << "Attr send_lens[2]: " << send_lens[kIdx2]
<< " is larger than input size in W dim: " << input_shape[kWDim] << ".";
}
if (send_lens[kIdx3] > input_shape[kWDim]) {
MS_EXCEPTION(ValueError) << "Attr send_lens[3]: " << send_lens[kIdx3]
<< " is larger than input size in W dim: " << input_shape[kWDim] << ".";
}
// check group