diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.cc index babf9c3de8b..7934a3af53b 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.cc @@ -164,7 +164,7 @@ std::vector> CalAllToAllvOutputShape(const std::vector CreateAllToAllvInput(const std::vector> &split_outputs, const std::vector &send_rank_ids) { std::vector all_to_all_v_input = {NewValueNode(std::make_shared(kAllToAllVOpName))}; - std::vector split_idx = {0, 2, 1, 3, 0, 3, 1, 2}; + std::vector split_idx = {0, 5, 3, 7, 1, 6, 2, 4}; std::vector is_begin = {true, false, false, false, false, true, true, true}; for (size_t idx = 0; idx < send_rank_ids.size(); ++idx) { if (send_rank_ids[idx] != kInvalidId) { @@ -374,11 +374,6 @@ std::vector NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const Func auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx); - bool is_top = IsTop(send_rank_ids); - bool is_bottom = IsBottom(send_rank_ids); - bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId); - bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId); - auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_input, 0); auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0); if (SizeToLong(shape.size()) != kShapeSize) { // only support NCHW now @@ -386,91 +381,68 @@ std::vector NeighborExchangeV2UnifyMindIR::CreateSplitNodes(const Func << trace::DumpSourceLines(neighbor_exchange_v2); } - // splitv for top & bottom - int64_t num_split_h = 0; - CNodePtr split_v_top_bottom = nullptr; - if (is_top || is_bottom) { - std::vector split_input = {NewValueNode(std::make_shared(prim::kPrimSplitV->name())), - neighbor_exchange_v2_input}; + // splitv for 0, 4, 6, 2 + bool is_top = IsTop(send_rank_ids); + bool is_bottom = IsBottom(send_rank_ids); + bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId); + bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId); + std::vector splitvs_is_first = {true, false, true, false}; // is left or top + std::vector splitvs_is_exist = {is_top, is_bottom, is_left, is_right}; + std::vector splitvs_dim = {kHDim, kHDim, kWDim, kWDim}; + for (size_t i = 0; i < splitvs_is_first.size(); ++i) { + int64_t num_split = 0; + CNodePtr split_v = nullptr; + if (splitvs_is_exist[i]) { + std::vector split_input = {NewValueNode(std::make_shared(prim::kPrimSplitV->name())), + neighbor_exchange_v2_input}; - split_v_top_bottom = - CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h, *this); + split_v = CreateSplitNode(graph, split_input, shape, splitvs_is_first[i], !splitvs_is_first[i], splitvs_dim[i], + send_lens, dtype, &num_split, *this); + } + split_nodes.emplace_back(split_v); + split_num->push_back(num_split); } - split_nodes.emplace_back(split_v_top_bottom); - split_num->push_back(num_split_h); - // splitv for left & right - int64_t num_split_w = 0; - CNodePtr split_v_left_right = nullptr; - if (is_left || is_right) { - std::vector split_input = {NewValueNode(std::make_shared(prim::kPrimSplitV->name())), - neighbor_exchange_v2_input}; - split_v_left_right = - CreateSplitNode(graph, split_input, shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w, *this); - } - split_nodes.emplace_back(split_v_left_right); - split_num->push_back(num_split_w); - - // splitv for corner - if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId) || - (send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) { - // top_bottom_split outputs - std::vector split_outputs_top_bottom; - CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast((*split_num)[0]), - &split_outputs_top_bottom); - if (split_outputs_top_bottom.empty()) { + // splitv for 7, 1, 5, 3 + std::vector corner_splitvs_is_first = {true, false, true, false}; + std::vector corner_splitvs_is_exist = {send_rank_ids[7] != kInvalidId, send_rank_ids[1] != kInvalidId, + send_rank_ids[5] != kInvalidId, send_rank_ids[3] != kInvalidId}; + std::vector corner_splitvs_is_input_top = {true, true, false, false}; + std::vector split_outputs_top; + std::vector split_outputs_bottom; + if (split_nodes[0] != nullptr) { + CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast((*split_num)[0]), &split_outputs_top); + if (split_outputs_top.empty()) { MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString() << " should have at least one output, but got 0" << trace::DumpSourceLines(split_nodes[0]); } - - // for top corner - if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId)) { - auto shape_tmp(shape); - shape_tmp[kHDim] = send_lens[0]; - bool is_first = (send_rank_ids[kRankIdSeven] != kInvalidId); - bool is_last = (send_rank_ids[kRankIdOne] != kInvalidId); - - std::vector split_v_corner_top_input = { - NewValueNode(std::make_shared(prim::kPrimSplitV->name()))}; - split_v_corner_top_input.insert(split_v_corner_top_input.end(), split_outputs_top_bottom.begin(), - split_outputs_top_bottom.begin() + 1); - int64_t num_split_top_corner = 0; - CNodePtr split_v_corner_top = CreateSplitNode(graph, split_v_corner_top_input, shape_tmp, is_first, is_last, - kWDim, send_lens, dtype, &num_split_top_corner, *this); - - split_nodes.emplace_back(split_v_corner_top); - split_num->push_back(num_split_top_corner); - } else { - split_nodes.emplace_back(nullptr); - split_num->push_back(0); + } + if (split_nodes[1] != nullptr) { + CreateMultipleOutputsOfAnfNode(graph, split_nodes[1], static_cast((*split_num)[1]), &split_outputs_bottom); + if (split_outputs_bottom.empty()) { + MS_LOG(EXCEPTION) << "The node " << split_nodes[1]->DebugString() << " should have at least one output, but got 0" + << trace::DumpSourceLines(split_nodes[1]); } + } + for (size_t i = 0; i < corner_splitvs_is_first.size(); ++i) { + int64_t num_split = 0; + CNodePtr split_v = nullptr; + if (corner_splitvs_is_exist[i]) { + auto shape_tmp = shape; + std::vector split_input = {NewValueNode(std::make_shared(prim::kPrimSplitV->name()))}; + if (corner_splitvs_is_input_top[i]) { + split_input.insert(split_input.end(), split_outputs_top.begin(), split_outputs_top.begin() + 1); + shape_tmp[kHDim] = send_lens[0]; + } else { + split_input.insert(split_input.end(), split_outputs_bottom.end() - 1, split_outputs_bottom.end()); + shape_tmp[kHDim] = send_lens[1]; + } - // for bottom corner - if ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) { - auto shape_tmp(shape); - shape_tmp[kHDim] = send_lens[1]; - bool is_first = (send_rank_ids[kRankIdFive] != kInvalidId); - bool is_last = (send_rank_ids[kRankIdThree] != kInvalidId); - - std::vector split_v_corner_bottom_input = { - NewValueNode(std::make_shared(prim::kPrimSplitV->name()))}; - split_v_corner_bottom_input.insert(split_v_corner_bottom_input.end(), split_outputs_top_bottom.end() - 1, - split_outputs_top_bottom.end()); - - int64_t num_split_bottom_corner = 0; - CNodePtr split_v_corner_bottom = CreateSplitNode(graph, split_v_corner_bottom_input, shape_tmp, is_first, is_last, - kWDim, send_lens, dtype, &num_split_bottom_corner, *this); - split_nodes.emplace_back(split_v_corner_bottom); - split_num->push_back(num_split_bottom_corner); - } else { - split_nodes.emplace_back(nullptr); - split_num->push_back(0); + split_v = CreateSplitNode(graph, split_input, shape_tmp, corner_splitvs_is_first[i], !corner_splitvs_is_first[i], + kWDim, send_lens, dtype, &num_split, *this); } - } else { - split_nodes.emplace_back(nullptr); - split_num->push_back(0); - split_nodes.emplace_back(nullptr); - split_num->push_back(0); + split_nodes.emplace_back(split_v); + split_num->push_back(num_split); } return split_nodes; diff --git a/mindspore/core/ops/neighborexchangev2.cc b/mindspore/core/ops/neighborexchangev2.cc index 473cbdc8889..6cfacac0e67 100644 --- a/mindspore/core/ops/neighborexchangev2.cc +++ b/mindspore/core/ops/neighborexchangev2.cc @@ -150,11 +150,21 @@ void Check(const PrimitivePtr &primitive, const std::vector &in if (input_shape.size() != kInputSize) { MS_EXCEPTION(ValueError) << "Input size is not 4, only support NCHW now."; } - if (send_lens[kIdx0] + send_lens[kIdx1] > input_shape[kHDim]) { - MS_EXCEPTION(ValueError) << "send_lens in H dim is larger than input size."; + if (send_lens[kIdx0] > input_shape[kHDim]) { + MS_EXCEPTION(ValueError) << "Attr send_lens[0]: " << send_lens[kIdx0] + << " is larger than input size in H dim: " << input_shape[kHDim] << "."; } - if (send_lens[kIdx2] + send_lens[kIdx3] > input_shape[kWDim]) { - MS_EXCEPTION(ValueError) << "send_lens in W dim is larger than input size."; + if (send_lens[kIdx1] > input_shape[kHDim]) { + MS_EXCEPTION(ValueError) << "Attr send_lens[1]: " << send_lens[kIdx1] + << " is larger than input size in H dim: " << input_shape[kHDim] << "."; + } + if (send_lens[kIdx2] > input_shape[kWDim]) { + MS_EXCEPTION(ValueError) << "Attr send_lens[2]: " << send_lens[kIdx2] + << " is larger than input size in W dim: " << input_shape[kWDim] << "."; + } + if (send_lens[kIdx3] > input_shape[kWDim]) { + MS_EXCEPTION(ValueError) << "Attr send_lens[3]: " << send_lens[kIdx3] + << " is larger than input size in W dim: " << input_shape[kWDim] << "."; } // check group