forked from mindspore-Ecosystem/mindspore
!27821 neighborexchangev2 fix bug
Merge pull request !27821 from TuDouNi/neighborexchangev2_fix_bug
This commit is contained in:
commit
ac69bac952
|
@ -164,7 +164,7 @@ std::vector<std::vector<size_t>> CalAllToAllvOutputShape(const std::vector<size_
|
|||
std::vector<AnfNodePtr> CreateAllToAllvInput(const std::vector<std::vector<AnfNodePtr>> &split_outputs,
|
||||
const std::vector<int64_t> &send_rank_ids) {
|
||||
std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
|
||||
std::vector<size_t> split_idx = {0, 2, 1, 3, 0, 3, 1, 2};
|
||||
std::vector<size_t> split_idx = {0, 5, 3, 7, 1, 6, 2, 4};
|
||||
std::vector<bool> is_begin = {true, false, false, false, false, true, true, true};
|
||||
for (size_t idx = 0; idx < send_rank_ids.size(); ++idx) {
|
||||
if (send_rank_ids[idx] != kInvalidId) {
|
||||
|
@ -374,11 +374,6 @@ std::vector<CNodePtr> NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const Func
|
|||
|
||||
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
|
||||
|
||||
bool is_top = IsTop(send_rank_ids);
|
||||
bool is_bottom = IsBottom(send_rank_ids);
|
||||
bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId);
|
||||
bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId);
|
||||
|
||||
auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_input, 0);
|
||||
auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
|
||||
if (SizeToLong(shape.size()) != kShapeSize) { // only support NCHW now
|
||||
|
@ -386,91 +381,68 @@ std::vector<CNodePtr> NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const Func
|
|||
<< trace::DumpSourceLines(neighbor_exchange_v2);
|
||||
}
|
||||
|
||||
// splitv for top & bottom
|
||||
int64_t num_split_h = 0;
|
||||
CNodePtr split_v_top_bottom = nullptr;
|
||||
if (is_top || is_bottom) {
|
||||
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
|
||||
neighbor_exchange_v2_input};
|
||||
// SplitV for the four edge neighbors, in order: top (0), bottom (4), left (6), right (2)
|
||||
bool is_top = IsTop(send_rank_ids);
|
||||
bool is_bottom = IsBottom(send_rank_ids);
|
||||
bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId);
|
||||
bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId);
|
||||
std::vector<bool> splitvs_is_first = {true, false, true, false}; // is left or top
|
||||
std::vector<bool> splitvs_is_exist = {is_top, is_bottom, is_left, is_right};
|
||||
std::vector<size_t> splitvs_dim = {kHDim, kHDim, kWDim, kWDim};
|
||||
for (size_t i = 0; i < splitvs_is_first.size(); ++i) {
|
||||
int64_t num_split = 0;
|
||||
CNodePtr split_v = nullptr;
|
||||
if (splitvs_is_exist[i]) {
|
||||
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
|
||||
neighbor_exchange_v2_input};
|
||||
|
||||
split_v_top_bottom =
|
||||
CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h, *this);
|
||||
split_v = CreateSplitNode(graph, split_input, shape, splitvs_is_first[i], !splitvs_is_first[i], splitvs_dim[i],
|
||||
send_lens, dtype, &num_split, *this);
|
||||
}
|
||||
split_nodes.emplace_back(split_v);
|
||||
split_num->push_back(num_split);
|
||||
}
|
||||
split_nodes.emplace_back(split_v_top_bottom);
|
||||
split_num->push_back(num_split_h);
|
||||
|
||||
// splitv for left & right
|
||||
int64_t num_split_w = 0;
|
||||
CNodePtr split_v_left_right = nullptr;
|
||||
if (is_left || is_right) {
|
||||
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
|
||||
neighbor_exchange_v2_input};
|
||||
split_v_left_right =
|
||||
CreateSplitNode(graph, split_input, shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w, *this);
|
||||
}
|
||||
split_nodes.emplace_back(split_v_left_right);
|
||||
split_num->push_back(num_split_w);
|
||||
|
||||
// splitv for corner
|
||||
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId) ||
|
||||
(send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
|
||||
// top_bottom_split outputs
|
||||
std::vector<AnfNodePtr> split_outputs_top_bottom;
|
||||
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>((*split_num)[0]),
|
||||
&split_outputs_top_bottom);
|
||||
if (split_outputs_top_bottom.empty()) {
|
||||
// SplitV for the four corner neighbors, in order: top-left (7), top-right (1), bottom-left (5), bottom-right (3)
|
||||
std::vector<bool> corner_splitvs_is_first = {true, false, true, false};
|
||||
std::vector<bool> corner_splitvs_is_exist = {send_rank_ids[7] != kInvalidId, send_rank_ids[1] != kInvalidId,
|
||||
send_rank_ids[5] != kInvalidId, send_rank_ids[3] != kInvalidId};
|
||||
std::vector<bool> corner_splitvs_is_input_top = {true, true, false, false};
|
||||
std::vector<AnfNodePtr> split_outputs_top;
|
||||
std::vector<AnfNodePtr> split_outputs_bottom;
|
||||
if (split_nodes[0] != nullptr) {
|
||||
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>((*split_num)[0]), &split_outputs_top);
|
||||
if (split_outputs_top.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString() << " should have at least one output, but got 0"
|
||||
<< trace::DumpSourceLines(split_nodes[0]);
|
||||
}
|
||||
|
||||
// for top corner
|
||||
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId)) {
|
||||
auto shape_tmp(shape);
|
||||
shape_tmp[kHDim] = send_lens[0];
|
||||
bool is_first = (send_rank_ids[kRankIdSeven] != kInvalidId);
|
||||
bool is_last = (send_rank_ids[kRankIdOne] != kInvalidId);
|
||||
|
||||
std::vector<AnfNodePtr> split_v_corner_top_input = {
|
||||
NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
|
||||
split_v_corner_top_input.insert(split_v_corner_top_input.end(), split_outputs_top_bottom.begin(),
|
||||
split_outputs_top_bottom.begin() + 1);
|
||||
int64_t num_split_top_corner = 0;
|
||||
CNodePtr split_v_corner_top = CreateSplitNode(graph, split_v_corner_top_input, shape_tmp, is_first, is_last,
|
||||
kWDim, send_lens, dtype, &num_split_top_corner, *this);
|
||||
|
||||
split_nodes.emplace_back(split_v_corner_top);
|
||||
split_num->push_back(num_split_top_corner);
|
||||
} else {
|
||||
split_nodes.emplace_back(nullptr);
|
||||
split_num->push_back(0);
|
||||
}
|
||||
if (split_nodes[1] != nullptr) {
|
||||
CreateMultipleOutputsOfAnfNode(graph, split_nodes[1], static_cast<size_t>((*split_num)[1]), &split_outputs_bottom);
|
||||
if (split_outputs_bottom.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The node " << split_nodes[1]->DebugString() << " should have at least one output, but got 0"
|
||||
<< trace::DumpSourceLines(split_nodes[1]);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < corner_splitvs_is_first.size(); ++i) {
|
||||
int64_t num_split = 0;
|
||||
CNodePtr split_v = nullptr;
|
||||
if (corner_splitvs_is_exist[i]) {
|
||||
auto shape_tmp = shape;
|
||||
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
|
||||
if (corner_splitvs_is_input_top[i]) {
|
||||
split_input.insert(split_input.end(), split_outputs_top.begin(), split_outputs_top.begin() + 1);
|
||||
shape_tmp[kHDim] = send_lens[0];
|
||||
} else {
|
||||
split_input.insert(split_input.end(), split_outputs_bottom.end() - 1, split_outputs_bottom.end());
|
||||
shape_tmp[kHDim] = send_lens[1];
|
||||
}
|
||||
|
||||
// for bottom corner
|
||||
if ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
|
||||
auto shape_tmp(shape);
|
||||
shape_tmp[kHDim] = send_lens[1];
|
||||
bool is_first = (send_rank_ids[kRankIdFive] != kInvalidId);
|
||||
bool is_last = (send_rank_ids[kRankIdThree] != kInvalidId);
|
||||
|
||||
std::vector<AnfNodePtr> split_v_corner_bottom_input = {
|
||||
NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
|
||||
split_v_corner_bottom_input.insert(split_v_corner_bottom_input.end(), split_outputs_top_bottom.end() - 1,
|
||||
split_outputs_top_bottom.end());
|
||||
|
||||
int64_t num_split_bottom_corner = 0;
|
||||
CNodePtr split_v_corner_bottom = CreateSplitNode(graph, split_v_corner_bottom_input, shape_tmp, is_first, is_last,
|
||||
kWDim, send_lens, dtype, &num_split_bottom_corner, *this);
|
||||
split_nodes.emplace_back(split_v_corner_bottom);
|
||||
split_num->push_back(num_split_bottom_corner);
|
||||
} else {
|
||||
split_nodes.emplace_back(nullptr);
|
||||
split_num->push_back(0);
|
||||
split_v = CreateSplitNode(graph, split_input, shape_tmp, corner_splitvs_is_first[i], !corner_splitvs_is_first[i],
|
||||
kWDim, send_lens, dtype, &num_split, *this);
|
||||
}
|
||||
} else {
|
||||
split_nodes.emplace_back(nullptr);
|
||||
split_num->push_back(0);
|
||||
split_nodes.emplace_back(nullptr);
|
||||
split_num->push_back(0);
|
||||
split_nodes.emplace_back(split_v);
|
||||
split_num->push_back(num_split);
|
||||
}
|
||||
|
||||
return split_nodes;
|
||||
|
|
|
@ -150,11 +150,21 @@ void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &in
|
|||
if (input_shape.size() != kInputSize) {
|
||||
MS_EXCEPTION(ValueError) << "Input size is not 4, only support NCHW now.";
|
||||
}
|
||||
if (send_lens[kIdx0] + send_lens[kIdx1] > input_shape[kHDim]) {
|
||||
MS_EXCEPTION(ValueError) << "send_lens in H dim is larger than input size.";
|
||||
if (send_lens[kIdx0] > input_shape[kHDim]) {
|
||||
MS_EXCEPTION(ValueError) << "Attr send_lens[0]: " << send_lens[kIdx0]
|
||||
<< " is larger than input size in H dim: " << input_shape[kHDim] << ".";
|
||||
}
|
||||
if (send_lens[kIdx2] + send_lens[kIdx3] > input_shape[kWDim]) {
|
||||
MS_EXCEPTION(ValueError) << "send_lens in W dim is larger than input size.";
|
||||
if (send_lens[kIdx1] > input_shape[kHDim]) {
|
||||
MS_EXCEPTION(ValueError) << "Attr send_lens[1]: " << send_lens[kIdx1]
|
||||
<< " is larger than input size in H dim: " << input_shape[kHDim] << ".";
|
||||
}
|
||||
if (send_lens[kIdx2] > input_shape[kWDim]) {
|
||||
MS_EXCEPTION(ValueError) << "Attr send_lens[2]: " << send_lens[kIdx2]
|
||||
<< " is larger than input size in W dim: " << input_shape[kWDim] << ".";
|
||||
}
|
||||
if (send_lens[kIdx3] > input_shape[kWDim]) {
|
||||
MS_EXCEPTION(ValueError) << "Attr send_lens[3]: " << send_lens[kIdx3]
|
||||
<< " is larger than input size in W dim: " << input_shape[kWDim] << ".";
|
||||
}
|
||||
|
||||
// check group
|
||||
|
|
Loading…
Reference in New Issue