diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index b060eda0ff5..61988d5f7ed 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -144,6 +144,7 @@ #include "backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h" #include "backend/optimizer/ascend/mindir/bn_grad_unify_mindir.h" #include "backend/optimizer/ascend/mindir/all_to_all_unify_mindir.h" +#include "backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h" #include "backend/optimizer/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h" #include "backend/kernel_compiler/tbe/tbe_kernel_compile.h" #include "utils/ms_context.h" @@ -579,6 +580,8 @@ void AscendUnifyMindIR(const std::shared_ptr &graph) { unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); optimizer->AddPassManager(unify_mindir_pm); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.cc new file mode 100644 index 00000000000..15a86f0dc82 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.cc @@ -0,0 +1,900 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h" +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/hccl_adapter/hccl_adapter.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kCNodePrimitiveIdx = 0; +constexpr size_t kNeighborExchangeV2InputIdx = 1; +constexpr size_t kLenTopIdx = 0; +constexpr size_t kLenBottomIdx = 1; +constexpr size_t kLenLeftIdx = 2; +constexpr size_t kLenRightIdx = 3; +constexpr size_t kHDim = 2; // dim of H in NCHW +constexpr size_t kWDim = 3; // dim of W in NCHW +constexpr int64_t kShapeSize = 4; +constexpr int64_t kRankIdZero = 0; +constexpr int64_t kRankIdOne = 1; +constexpr int64_t kRankIdTwo = 2; +constexpr int64_t kRankIdThree = 3; +constexpr int64_t kRankIdFour = 4; +constexpr int64_t kRankIdFive = 5; +constexpr int64_t kRankIdSix = 6; +constexpr int64_t kRankIdSeven = 7; +constexpr size_t kSizeFour = 4; +constexpr int64_t kInvalidId = -1; + +// cal split attrs size_splits, shapes and num_split +int64_t CalSplitAttrs(const std::vector &base_shape, const bool is_first, const bool is_last, + const int64_t split_dim, const std::vector &send_lens, std::vector *size_splits, + std::vector> *shapes) { + MS_EXCEPTION_IF_NULL(size_splits); + MS_EXCEPTION_IF_NULL(shapes); + if (SizeToLong(base_shape.size()) != kShapeSize) { + MS_LOG(EXCEPTION) << "Wrong base_shape size: " << base_shape.size() << ", it should be equal to 4."; + } + if (split_dim >= kShapeSize) { + 
MS_LOG(EXCEPTION) << "Wrong split_dim: " << split_dim << ", it should less than 4."; + } + int64_t num_split = 0; + int64_t split_middle_size = base_shape[split_dim]; + std::vector shape_tmp(base_shape); + // [top, bottom, left, right] + int64_t first_size = split_dim == kWDim ? send_lens[2] : send_lens[0]; + int64_t last_size = split_dim == kWDim ? send_lens[3] : send_lens[1]; + + if (is_first) { + // first + ++num_split; + size_splits->push_back(first_size); + split_middle_size -= first_size; + shape_tmp[split_dim] = static_cast(first_size); + shapes->push_back(shape_tmp); + } + if (is_last) { + // middle + ++num_split; + split_middle_size -= last_size; + size_splits->push_back(split_middle_size); + shape_tmp[split_dim] = static_cast(split_middle_size); + shapes->push_back(shape_tmp); + // last + ++num_split; + size_splits->push_back(last_size); + shape_tmp[split_dim] = static_cast(last_size); + shapes->push_back(shape_tmp); + } else { + ++num_split; + size_splits->push_back(split_middle_size); + shape_tmp[split_dim] = static_cast(split_middle_size); + shapes->push_back(shape_tmp); + } + return num_split; +} + +CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const std::vector &split_input, + const std::vector &base_shape, bool is_first, bool is_last, int64_t split_dim, + const std::vector &send_lens, TypeId input_dtype, int64_t *num_split) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(num_split); + if (split_input.empty()) { + MS_LOG(EXCEPTION) << "The input is empty, can not create splitv node."; + return nullptr; + } + auto split_v = graph->NewCNode(split_input); + MS_EXCEPTION_IF_NULL(split_v); + std::vector size_splits = {}; + std::vector> shapes = {}; + *num_split = CalSplitAttrs(base_shape, is_first, is_last, split_dim, send_lens, &size_splits, &shapes); + + std::vector dtypes(*num_split, input_dtype); + AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get()); + AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(split_dim), split_v); + 
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(*num_split), split_v); + AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue>(size_splits), split_v); + AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v); + return split_v; +} + +std::vector> CalAllToAllvOutputShape(const std::vector &base_shape, + const std::vector &recv_lens, + const std::vector &recv_rank_ids) { + if (SizeToLong(base_shape.size()) != kShapeSize) { + MS_LOG(EXCEPTION) << "Wrong base_shape size: " << base_shape.size() << ", it should be equal to 4."; + } + std::vector> shapes = {}; + std::vector> ori_shapes = { + {base_shape[0], base_shape[1], static_cast(recv_lens[kLenTopIdx]), base_shape[kWDim]}, + {base_shape[0], base_shape[1], static_cast(recv_lens[kLenTopIdx]), + static_cast(recv_lens[kLenRightIdx])}, + {base_shape[0], base_shape[1], base_shape[kHDim], static_cast(recv_lens[kLenRightIdx])}, + {base_shape[0], base_shape[1], static_cast(recv_lens[kLenBottomIdx]), + static_cast(recv_lens[kLenRightIdx])}, + {base_shape[0], base_shape[1], static_cast(recv_lens[kLenBottomIdx]), base_shape[kWDim]}, + {base_shape[0], base_shape[1], static_cast(recv_lens[kLenBottomIdx]), + static_cast(recv_lens[kLenLeftIdx])}, + {base_shape[0], base_shape[1], base_shape[kHDim], static_cast(recv_lens[kLenLeftIdx])}, + {base_shape[0], base_shape[1], static_cast(recv_lens[kLenTopIdx]), + static_cast(recv_lens[kLenLeftIdx])}}; + + for (size_t idx = 0; idx < recv_rank_ids.size(); ++idx) { + if (recv_rank_ids[idx] != kInvalidId) { + shapes.push_back(ori_shapes[idx]); + } + } + + return shapes; +} + +// returns {top_bottom, left_right, top_corner, bottom_corner}, if no split, set it nullptr +std::vector CreateSplitNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2, + std::vector *split_num) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(neighbor_exchange_v2); + MS_EXCEPTION_IF_NULL(split_num); + std::vector send_rank_ids = + AnfAlgo::GetNodeAttr>(neighbor_exchange_v2, kAttrSendRankIds); + 
std::vector send_lens = AnfAlgo::GetNodeAttr>(neighbor_exchange_v2, kAttrSendLens); + + if (neighbor_exchange_v2->size() <= kNeighborExchangeV2InputIdx) { + MS_LOG(EXCEPTION) << "Invalid cnode " << neighbor_exchange_v2->DebugString() << " input size " + << neighbor_exchange_v2->size(); + } + std::vector split_nodes = {}; + + auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx); + + bool is_top = ((send_rank_ids[kRankIdZero] != kInvalidId) || (send_rank_ids[kRankIdOne] != kInvalidId) || + (send_rank_ids[kRankIdSeven] != kInvalidId)); + bool is_bottom = ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFour] != kInvalidId) || + (send_rank_ids[kRankIdFive] != kInvalidId)); + bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId); + bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId); + + auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_input, 0); + auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0); + if (SizeToLong(shape.size()) != kShapeSize) { // only support NCHW now + MS_LOG(EXCEPTION) << "Invalid shape size " << shape.size() << ", only support NCHW input now!"; + } + + // splitv for top & bottom + int64_t num_split_h = 0; + CNodePtr split_v_top_bottom = nullptr; + if (is_top || is_bottom) { + std::vector split_input = {NewValueNode(std::make_shared(prim::kPrimSplitV->name())), + neighbor_exchange_v2_input}; + + split_v_top_bottom = + CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h); + } + split_nodes.emplace_back(split_v_top_bottom); + split_num->push_back(num_split_h); + + // splitv for left & right + int64_t num_split_w = 0; + CNodePtr split_v_left_right = nullptr; + if (is_left || is_right) { + std::vector split_input = {NewValueNode(std::make_shared(prim::kPrimSplitV->name())), + neighbor_exchange_v2_input}; + split_v_left_right = + CreateSplitNode(graph, split_input, shape, is_left, is_right, kWDim, 
send_lens, dtype, &num_split_w); + } + split_nodes.emplace_back(split_v_left_right); + split_num->push_back(num_split_w); + + // splitv for corner + if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId) || + (send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) { + // top_bottom_split outputs + std::vector split_outputs_top_bottom; + CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast((*split_num)[0]), + &split_outputs_top_bottom); + if (split_outputs_top_bottom.empty()) { + MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString() + << " should have at least one output, but got 0."; + } + + // for top corner + if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId)) { + auto shape_tmp(shape); + shape_tmp[kHDim] = send_lens[0]; + bool is_first = (send_rank_ids[kRankIdSeven] != kInvalidId); + bool is_last = (send_rank_ids[kRankIdOne] != kInvalidId); + + std::vector split_v_corner_top_input = { + NewValueNode(std::make_shared(prim::kPrimSplitV->name()))}; + split_v_corner_top_input.insert(split_v_corner_top_input.end(), split_outputs_top_bottom.begin(), + split_outputs_top_bottom.begin() + 1); + int64_t num_split_top_corner = 0; + CNodePtr split_v_corner_top = CreateSplitNode(graph, split_v_corner_top_input, shape_tmp, is_first, is_last, + kWDim, send_lens, dtype, &num_split_top_corner); + + split_nodes.emplace_back(split_v_corner_top); + split_num->push_back(num_split_top_corner); + } else { + split_nodes.emplace_back(nullptr); + split_num->push_back(0); + } + + // for bottom corner + if ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) { + auto shape_tmp(shape); + shape_tmp[kHDim] = send_lens[1]; + bool is_first = (send_rank_ids[kRankIdFive] != kInvalidId); + bool is_last = (send_rank_ids[kRankIdThree] != kInvalidId); + + std::vector split_v_corner_bottom_input = { + 
NewValueNode(std::make_shared(prim::kPrimSplitV->name()))}; + split_v_corner_bottom_input.insert(split_v_corner_bottom_input.end(), split_outputs_top_bottom.end() - 1, + split_outputs_top_bottom.end()); + + int64_t num_split_bottom_corner = 0; + CNodePtr split_v_corner_bottom = CreateSplitNode(graph, split_v_corner_bottom_input, shape_tmp, is_first, is_last, + kWDim, send_lens, dtype, &num_split_bottom_corner); + split_nodes.emplace_back(split_v_corner_bottom); + split_num->push_back(num_split_bottom_corner); + } else { + split_nodes.emplace_back(nullptr); + split_num->push_back(0); + } + } else { + split_nodes.emplace_back(nullptr); + split_num->push_back(0); + split_nodes.emplace_back(nullptr); + split_num->push_back(0); + } + + return split_nodes; +} + +std::vector CreateAllToAllvInput(const std::vector> &split_outputs, + const std::vector &send_rank_ids) { + std::vector all_to_all_v_input = {NewValueNode(std::make_shared(kAllToAllVOpName))}; + std::vector split_idx = {0, 2, 1, 3, 0, 3, 1, 2}; + std::vector is_begin = {true, false, false, false, false, true, true, true}; + for (size_t idx = 0; idx < send_rank_ids.size(); ++idx) { + if (send_rank_ids[idx] != kInvalidId) { + if (is_begin[idx]) { + all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].begin(), + split_outputs[split_idx[idx]].begin() + 1); + } else { + all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].end() - 1, + split_outputs[split_idx[idx]].end()); + } + } + } + + return all_to_all_v_input; +} + +// get center of input for grad +AnfNodePtr GetCenter(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad, + const std::vector &split_nodes, const std::vector &split_num, + const std::vector &send_rank_ids) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad); + std::vector output; + if (split_nodes[kRankIdTwo] == nullptr) { + if (split_nodes[0] != nullptr) { + CreateMultipleOutputsOfAnfNode(graph, 
split_nodes[0], static_cast(split_num[0]), &output); + if (output.size() < 2) { + MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2."; + } + if (send_rank_ids[kRankIdZero] == kInvalidId) { + return output[0]; + } + return output[1]; + } else { + return neighbor_exchange_v2_grad->input(kNeighborExchangeV2InputIdx); + } + } else { + CreateMultipleOutputsOfAnfNode(graph, split_nodes[2], static_cast(split_num[2]), &output); + if (output.size() < 2) { + MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2."; + } + if (send_rank_ids[kRankIdSix] == kInvalidId) { + return output[0]; + } + return output[1]; + } +} + +std::vector CreateAllToAllvInputForGrad(const std::vector &send_rank_ids, + const std::vector> &split_outputs, + const std::vector &split_nodes) { + if (send_rank_ids.size() != 8) { + MS_LOG(EXCEPTION) << "Wrong send_rank_ids size: " << send_rank_ids.size() << ", expect size: 8."; + } + if (split_outputs.size() != kSizeFour) { + MS_LOG(EXCEPTION) << "Wrong split_outputs size: " << split_outputs.size() << ", expect size: 4."; + } + if (split_nodes.size() != kSizeFour) { + MS_LOG(EXCEPTION) << "Wrong split_nodes size: " << split_nodes.size() << ", expect size: 4."; + } + std::vector all_to_all_v_input = {NewValueNode(std::make_shared(kAllToAllVOpName))}; + // only have top-bottom split + std::vector side_idx = {1, 2, 3, 5, 6, 7}; + bool no_send_side = std::all_of(side_idx.begin(), side_idx.end(), + [&send_rank_ids](int idx) { return send_rank_ids[idx] == kInvalidId; }); + if (no_send_side) { + if (send_rank_ids[kRankIdZero] != kInvalidId) { + all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].begin(), split_outputs[0].begin() + 1); + } + if (send_rank_ids[kRankIdFour] != kInvalidId) { + all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].end() - 1, split_outputs[0].end()); + } + return all_to_all_v_input; + } + // 0, 1 + if (split_nodes[1] != nullptr) { + 
if (send_rank_ids[kRankIdSeven] != kInvalidId) { + all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin() + 1, split_outputs[1].end()); + } else { + all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].end()); + } + } + // 2 + if (split_nodes[2] != nullptr && send_rank_ids[kRankIdTwo] != kInvalidId) { + all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[2].end() - 1, split_outputs[2].end()); + } + // 3, 4, 5 + if (split_nodes[3] != nullptr) { + all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[3].rbegin(), split_outputs[3].rend()); + } + // 6 + if (split_nodes[2] != nullptr && send_rank_ids[kRankIdSix] != kInvalidId) { + all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[2].begin(), split_outputs[2].begin() + 1); + } + // 7 + if (split_nodes[1] != nullptr && send_rank_ids[kRankIdSeven] != kInvalidId) { + all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].begin() + 1); + } + + return all_to_all_v_input; +} +// alltoallv for forward & grad +CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_or_grad, + const std::vector &split_nodes, const std::vector &split_num, + bool is_grad) { + MS_LOG(DEBUG) << "Start to create alltoallv node."; + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_or_grad); + std::vector send_rank_ids = + AnfAlgo::GetNodeAttr>(neighbor_exchange_v2_or_grad, kAttrSendRankIds); + std::vector recv_rank_ids = + AnfAlgo::GetNodeAttr>(neighbor_exchange_v2_or_grad, kAttrRecvRankIds); + std::vector recv_lens = + AnfAlgo::GetNodeAttr>(neighbor_exchange_v2_or_grad, kAttrRecvLens); + std::string group = AnfAlgo::GetNodeAttr(neighbor_exchange_v2_or_grad, kAttrGroup); + + // get split nodes output, split_outputs: [top_bottom, left_right, top_corner, bottom_corner] + std::vector> split_outputs; + for (size_t i = 0; i < split_nodes.size(); ++i) { + 
std::vector output; + if (split_nodes[i] != nullptr) { + CreateMultipleOutputsOfAnfNode(graph, split_nodes[i], static_cast(split_num[i]), &output); + if (output.empty()) { + MS_LOG(EXCEPTION) << "The node " << split_nodes[i]->DebugString() + << " should have at least one output, but got 0."; + } + } + split_outputs.emplace_back(output); + } + + // all_to_all_v input + std::vector all_to_all_v_input; + AnfNodePtr base_node = nullptr; + if (is_grad) { + all_to_all_v_input = CreateAllToAllvInputForGrad(send_rank_ids, split_outputs, split_nodes); + base_node = GetCenter(graph, neighbor_exchange_v2_or_grad, split_nodes, split_num, send_rank_ids); + } else { + all_to_all_v_input = CreateAllToAllvInput(split_outputs, send_rank_ids); + base_node = neighbor_exchange_v2_or_grad->input(kNeighborExchangeV2InputIdx); + } + + // output shapes and dtypes + auto base_dtype = AnfAlgo::GetOutputInferDataType(base_node, 0); + auto base_shape = AnfAlgo::GetOutputInferShape(base_node, 0); + if (SizeToLong(base_shape.size()) != kShapeSize) { + MS_LOG(EXCEPTION) << "Invalid shape size " << base_shape.size() << ", only support NCHW input now!"; + } + std::vector> shapes = CalAllToAllvOutputShape(base_shape, recv_lens, recv_rank_ids); + + // erase -1 in send_rank_ids + std::vector real_send_rank_ids(send_rank_ids.size()); + std::vector real_recv_rank_ids(recv_rank_ids.size()); + auto iter1 = std::copy_if(send_rank_ids.begin(), send_rank_ids.end(), real_send_rank_ids.begin(), + [](const int64_t item) { return item != kInvalidId; }); + auto iter2 = std::copy_if(recv_rank_ids.begin(), recv_rank_ids.end(), real_recv_rank_ids.begin(), + [](const int64_t item) { return item != kInvalidId; }); + real_send_rank_ids.resize(std::distance(real_send_rank_ids.begin(), iter1)); + real_recv_rank_ids.resize(std::distance(real_recv_rank_ids.begin(), iter2)); + + std::vector dtypes(real_recv_rank_ids.size(), base_dtype); + + // create alltoallv node + auto all_to_all_v = graph->NewCNode(all_to_all_v_input); 
+ MS_EXCEPTION_IF_NULL(all_to_all_v); + AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, all_to_all_v.get()); + + AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue>(real_send_rank_ids), all_to_all_v); + AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue>(real_recv_rank_ids), all_to_all_v); + AnfAlgo::SetNodeAttr(kAttrRecvType, TypeIdToType(base_dtype), all_to_all_v); + AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue(group), all_to_all_v); + MS_LOG(INFO) << "Create AllToAllv success, send rank size " << send_rank_ids.size() << ", recv rank size " + << recv_rank_ids.size(); + return all_to_all_v; +} + +int64_t AllToAllRealIds(int64_t ids, const std::vector &recv_rank_ids) { + int64_t real_ids = 0; + for (auto i = 0; i < ids; ++i) { + if (recv_rank_ids[i] != kInvalidId) { + ++real_ids; + } + } + return real_ids; +} + +CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const std::vector &concat_input, + const std::vector> &output_shape, const std::vector &output_dtype, + int64_t axis, int64_t input_nums) { + MS_EXCEPTION_IF_NULL(graph); + auto concat = graph->NewCNode(concat_input); + MS_EXCEPTION_IF_NULL(concat); + AnfAlgo::SetOutputInferTypeAndShape(output_dtype, output_shape, concat.get()); + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), concat); + AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(input_nums), concat); + std::vector dyn_input_size_empty{input_nums}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size_empty), concat); + return concat; +} + +CNodePtr CreateLeftRightConcat(const FuncGraphPtr &graph, const std::vector &all_to_all_v_outputs, + const std::vector &recv_rank_ids, const std::vector &recv_lens, + bool is_left) { + MS_EXCEPTION_IF_NULL(graph); + + std::vector concat_input = {NewValueNode(std::make_shared(kConcatOpName))}; + int64_t input_num = 1; + size_t first_ids = is_left ? 7 : 1; + size_t middle_ids = is_left ? 6 : 2; + size_t last_ids = is_left ? 
5 : 3; + + auto single_shape = AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[AllToAllRealIds(middle_ids, recv_rank_ids)], 0); + + if (recv_rank_ids[first_ids] != kInvalidId) { + ++input_num; + single_shape[2] += static_cast(recv_lens[0]); // H in NCHW + } + if (recv_rank_ids[last_ids] != kInvalidId) { + ++input_num; + single_shape[2] += static_cast(recv_lens[1]); // H in NCHW + } + if (is_left) { + concat_input.insert(concat_input.end(), all_to_all_v_outputs.rbegin(), all_to_all_v_outputs.rbegin() + input_num); + } else { + concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin() + AllToAllRealIds(1, recv_rank_ids), + all_to_all_v_outputs.begin() + input_num + AllToAllRealIds(1, recv_rank_ids)); + } + + std::vector concat_output_dtype = { + AnfAlgo::GetOutputInferDataType(all_to_all_v_outputs[AllToAllRealIds(middle_ids, recv_rank_ids)], 0)}; + auto concat = CreateConcatNode(graph, concat_input, {single_shape}, concat_output_dtype, kHDim, input_num); + + return concat; +} + +CNodePtr CreateMiddleConcat(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2, + const std::vector &all_to_all_v_outputs, + const std::vector &recv_rank_ids, const std::vector &recv_lens, + int64_t concat_dim) { + std::vector concat_input_all = {NewValueNode(std::make_shared(kConcatOpName))}; + int64_t input_num_all = 0; + auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx); + auto single_shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0); + size_t first_idx = concat_dim == kWDim ? 6 : 0; + size_t last_idx = concat_dim == kWDim ? 2 : 4; + size_t first_len = concat_dim == kWDim ? static_cast(recv_lens[2]) : static_cast(recv_lens[0]); + size_t last_len = concat_dim == kWDim ? 
static_cast(recv_lens[3]) : static_cast(recv_lens[1]); + + // left + if (recv_rank_ids[first_idx] != kInvalidId) { + if (concat_dim == kWDim) { + concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.end() - 1, all_to_all_v_outputs.end()); + } else { + concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1); + } + + ++input_num_all; + single_shape[concat_dim] += first_len; + } + + concat_input_all.push_back(neighbor_exchange_v2_input); + ++input_num_all; + // right + if (recv_rank_ids[last_idx] != kInvalidId) { + if (concat_dim == kWDim) { + concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1); + } else { + int64_t bottom_num = AllToAllRealIds(4, recv_rank_ids); + concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin() + bottom_num, + all_to_all_v_outputs.begin() + bottom_num + 1); + } + + ++input_num_all; + single_shape[concat_dim] += last_len; + } + + std::vector concat_output_dtype = {AnfAlgo::GetOutputInferDataType(all_to_all_v_outputs[0], 0)}; + auto concat_all = + CreateConcatNode(graph, concat_input_all, {single_shape}, concat_output_dtype, concat_dim, input_num_all); + return concat_all; +} + +CNodePtr AllToAllvRecvEmpty(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2, + const CNodePtr &all_to_all_v) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(neighbor_exchange_v2); + MS_EXCEPTION_IF_NULL(all_to_all_v); + // add depend for input & alltoallv + auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx); + std::vector depend_input = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), + neighbor_exchange_v2_input, all_to_all_v}; + auto depend = graph->NewCNode(depend_input); + MS_EXCEPTION_IF_NULL(depend); + depend->set_abstract(neighbor_exchange_v2_input->abstract()); + return depend; +} + +CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, 
const CNodePtr &neighbor_exchange_v2, + const CNodePtr &all_to_all_v) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(neighbor_exchange_v2); + MS_EXCEPTION_IF_NULL(all_to_all_v); + std::vector recv_rank_ids = + AnfAlgo::GetNodeAttr>(neighbor_exchange_v2, kAttrRecvRankIds); + std::vector recv_lens = AnfAlgo::GetNodeAttr>(neighbor_exchange_v2, kAttrRecvLens); + + int64_t all_to_all_output_num = + std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; }); + + if (all_to_all_output_num == 0) { + return AllToAllvRecvEmpty(graph, neighbor_exchange_v2, all_to_all_v); + } + + std::vector all_to_all_v_outputs; + CreateMultipleOutputsOfAnfNode(graph, all_to_all_v, static_cast(all_to_all_output_num), + &all_to_all_v_outputs); + if (all_to_all_v_outputs.empty()) { + MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0."; + } + + if (recv_rank_ids[kRankIdZero] == kInvalidId && recv_rank_ids[kRankIdFour] == kInvalidId) { + return CreateMiddleConcat(graph, neighbor_exchange_v2, all_to_all_v_outputs, recv_rank_ids, recv_lens, kWDim); + } + + // top or bottom + // middle concat + auto concat_middle = + CreateMiddleConcat(graph, neighbor_exchange_v2, all_to_all_v_outputs, recv_rank_ids, recv_lens, kHDim); + + bool is_left = recv_rank_ids[kRankIdSix] != kInvalidId || recv_rank_ids[kRankIdFive] != kInvalidId || + recv_rank_ids[kRankIdSeven] != kInvalidId; + bool is_right = recv_rank_ids[kRankIdOne] != kInvalidId || recv_rank_ids[kRankIdTwo] != kInvalidId || + recv_rank_ids[kRankIdThree] != kInvalidId; + if (!is_left && !is_right) { + return concat_middle; + } + + std::vector concat_input_all = {NewValueNode(std::make_shared(kConcatOpName))}; + auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx); + std::vector shape_all = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0); + shape_all[2] = + recv_rank_ids[kRankIdZero] != 
kInvalidId ? shape_all[2] + static_cast(recv_lens[0]) : shape_all[2]; + shape_all[2] = + recv_rank_ids[kRankIdFour] != kInvalidId ? shape_all[2] + static_cast(recv_lens[1]) : shape_all[2]; + int64_t input_nums_all = 0; + // left concat + if (is_left) { + auto concat_left = CreateLeftRightConcat(graph, all_to_all_v_outputs, recv_rank_ids, recv_lens, true); + + // connect to concat_all + std::vector concat_left_outputs; + CreateMultipleOutputsOfAnfNode(graph, concat_left, 1, &concat_left_outputs); + if (concat_left_outputs.empty()) { + MS_LOG(EXCEPTION) << "The node " << concat_left->DebugString() << " should have at least one output, but got 0."; + } + concat_input_all.insert(concat_input_all.end(), concat_left_outputs.begin(), concat_left_outputs.end()); + ++input_nums_all; + shape_all[3] += recv_lens[2]; + } + + // middle concat connect to concat_all + std::vector concat_middle_outputs; + CreateMultipleOutputsOfAnfNode(graph, concat_middle, 1, &concat_middle_outputs); + if (concat_middle_outputs.empty()) { + MS_LOG(EXCEPTION) << "The node " << concat_middle->DebugString() << " should have at least one output, but got 0."; + } + concat_input_all.insert(concat_input_all.end(), concat_middle_outputs.begin(), concat_middle_outputs.end()); + ++input_nums_all; + + if (is_right) { + auto concat_right = CreateLeftRightConcat(graph, all_to_all_v_outputs, recv_rank_ids, recv_lens, false); + + // connect to concat_all + std::vector concat_right_outputs; + CreateMultipleOutputsOfAnfNode(graph, concat_right, 1, &concat_right_outputs); + if (concat_right_outputs.empty()) { + MS_LOG(EXCEPTION) << "The node " << concat_right->DebugString() << " should have at least one output, but got 0."; + } + concat_input_all.insert(concat_input_all.end(), concat_right_outputs.begin(), concat_right_outputs.end()); + ++input_nums_all; + shape_all[3] += recv_lens[3]; + } + + std::vector concat_right_output_dtype = {AnfAlgo::GetOutputInferDataType(concat_input_all[1], 0)}; + auto concat_all = + 
CreateConcatNode(graph, concat_input_all, {shape_all}, concat_right_output_dtype, kWDim, input_nums_all); + return concat_all; +} + +// grad +// returns {top_bottom, left_right, top_corner, bottom_corner}, if no split, set it nullptr +std::vector CreateSplitNodesForGrad(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad, + std::vector *split_num) { + MS_LOG(DEBUG) << "Start create splitv nodes."; + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad); + MS_EXCEPTION_IF_NULL(split_num); + std::vector send_rank_ids = + AnfAlgo::GetNodeAttr>(neighbor_exchange_v2_grad, kAttrSendRankIds); + std::vector send_lens = AnfAlgo::GetNodeAttr>(neighbor_exchange_v2_grad, kAttrSendLens); + + if (neighbor_exchange_v2_grad->size() <= kNeighborExchangeV2InputIdx) { + MS_LOG(EXCEPTION) << "Invalid cnode " << neighbor_exchange_v2_grad->DebugString() << " input size " + << neighbor_exchange_v2_grad->size(); + } + + auto neighbor_exchange_v2_grad_input = neighbor_exchange_v2_grad->input(kNeighborExchangeV2InputIdx); + auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_grad_input, 0); + auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_grad_input, 0); + if (SizeToLong(shape.size()) != kShapeSize) { + MS_LOG(EXCEPTION) << "Invalid shape size " << shape.size() << ", only support NCHW input now!"; + } + + std::vector split_nodes = {}; + // splitv for top & bottom + bool is_top = ((send_rank_ids[kRankIdZero] != kInvalidId) || (send_rank_ids[kRankIdOne] != kInvalidId) || + (send_rank_ids[kRankIdSeven] != kInvalidId)); + bool is_bottom = ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFour] != kInvalidId) || + (send_rank_ids[kRankIdFive] != kInvalidId)); + CNodePtr split_v_top_bottom = nullptr; + int64_t num_split_h = 0; + if (is_top || is_bottom) { + std::vector split_input = {NewValueNode(std::make_shared(prim::kPrimSplitV->name())), + neighbor_exchange_v2_grad_input}; + split_v_top_bottom = + 
CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h); + } + split_nodes.emplace_back(split_v_top_bottom); + split_num->push_back(num_split_h); + + // splitvs for left & right + // inputs + std::vector split_outputs_top_bottom; + std::vector size_split_h; + if (split_nodes[0] != nullptr) { + CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast(num_split_h), &split_outputs_top_bottom); + if (split_outputs_top_bottom.empty()) { + MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString() + << " should have at least one output, but got 0."; + } + size_split_h = AnfAlgo::GetNodeAttr>(split_nodes[0], kAttrSizeSplits); + } else { + // just middle + split_outputs_top_bottom.push_back(neighbor_exchange_v2_grad_input); + size_split_h.push_back(shape[kHDim]); + } + + // left_right splitv nodes from top to bottom + bool is_left = (send_rank_ids[kRankIdFive] != kInvalidId) || (send_rank_ids[kRankIdSix] != kInvalidId) || + (send_rank_ids[kRankIdSeven] != kInvalidId); + bool is_right = (send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdTwo] != kInvalidId) || + (send_rank_ids[kRankIdThree] != kInvalidId); + if (is_left || is_right) { + if (!is_top) { + split_nodes.push_back(nullptr); + split_num->push_back(0); + } + for (size_t i = 0; i < split_outputs_top_bottom.size(); ++i) { + std::vector split_input = {NewValueNode(std::make_shared(prim::kPrimSplitV->name())), + split_outputs_top_bottom[i]}; + + int64_t num_split_w = 0; + std::vector base_shape(shape); + base_shape[kHDim] = size_split_h[i]; + auto split_v_left_right = + CreateSplitNode(graph, split_input, base_shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w); + split_nodes.emplace_back(split_v_left_right); + split_num->push_back(num_split_w); + } + if (!is_bottom) { + split_nodes.push_back(nullptr); + split_num->push_back(0); + } + } else { + split_nodes.push_back(nullptr); + split_num->push_back(0); + 
split_nodes.push_back(nullptr); + split_num->push_back(0); + split_nodes.push_back(nullptr); + split_num->push_back(0); + } + MS_LOG(DEBUG) << "Create splitv nodes success."; + return split_nodes; +} + +CNodePtr CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input, const std::vector &begin, + const std::vector &size, const std::vector &shape, TypeId dtype) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(input); + std::vector pad_inputs = {NewValueNode(std::make_shared(kPadOpName)), input}; + auto pad = graph->NewCNode(pad_inputs); + std::vector> paddings; + for (size_t i = 0; i < shape.size(); ++i) { + paddings.emplace_back(std::vector{begin[i], static_cast(shape[i]) - begin[i] - size[i]}); + } + AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape}, pad.get()); + AnfAlgo::SetNodeAttr(kAttrPaddings, MakeValue(paddings), pad); + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(std::vector{"x"}), pad); + return pad; +} + +CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad, + const CNodePtr &all_to_all_v, const std::vector &split_nodes, + const std::vector &split_num) { + MS_LOG(DEBUG) << "Start create splitvs grad nodes."; + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad); + std::vector send_rank_ids = + AnfAlgo::GetNodeAttr>(neighbor_exchange_v2_grad, kAttrSendRankIds); + std::vector recv_rank_ids = + AnfAlgo::GetNodeAttr>(neighbor_exchange_v2_grad, kAttrRecvRankIds); + std::vector recv_lens = AnfAlgo::GetNodeAttr>(neighbor_exchange_v2_grad, kAttrRecvLens); + + auto centerx = GetCenter(graph, neighbor_exchange_v2_grad, split_nodes, split_num, send_rank_ids); + auto centerx_dtype = AnfAlgo::GetOutputInferDataType(centerx, 0); + auto centerx_shape = AnfAlgo::GetOutputInferShape(centerx, 0); + // empty + int64_t all_to_all_output_num = + std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; }); + + if (all_to_all_output_num == 
0) { + // add depend(alltoallv, centerx) + std::vector depend_input = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), + centerx, all_to_all_v}; + auto depend = graph->NewCNode(depend_input); + MS_EXCEPTION_IF_NULL(depend); + depend->set_abstract(centerx->abstract()); + return depend; + } + // get alltoallv outputs + std::vector all_to_all_v_outputs; + CreateMultipleOutputsOfAnfNode(graph, all_to_all_v, static_cast(all_to_all_output_num), + &all_to_all_v_outputs); + if (all_to_all_v_outputs.empty()) { + MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0."; + } + // create pad nodes + // slice begin & size + std::vector> begins = {{0, 0, 0, 0}, + {0, 0, 0, static_cast(centerx_shape[3]) - recv_lens[3]}, + {0, 0, 0, static_cast(centerx_shape[3]) - recv_lens[3]}, + {0, 0, static_cast(centerx_shape[2]) - recv_lens[1], + static_cast(centerx_shape[3]) - recv_lens[3]}, + {0, 0, static_cast(centerx_shape[2]) - recv_lens[1], 0}, + {0, 0, static_cast(centerx_shape[2]) - recv_lens[1], 0}, + {0, 0, 0, 0}, + {0, 0, 0, 0}}; + std::vector> sizes = { + {static_cast(centerx_shape[0]), static_cast(centerx_shape[1]), recv_lens[0], + static_cast(centerx_shape[3])}, + {static_cast(centerx_shape[0]), static_cast(centerx_shape[1]), recv_lens[0], recv_lens[3]}, + {static_cast(centerx_shape[0]), static_cast(centerx_shape[1]), + static_cast(centerx_shape[2]), recv_lens[3]}, + {static_cast(centerx_shape[0]), static_cast(centerx_shape[1]), recv_lens[1], recv_lens[3]}, + {static_cast(centerx_shape[0]), static_cast(centerx_shape[1]), recv_lens[1], + static_cast(centerx_shape[3])}, + {static_cast(centerx_shape[0]), static_cast(centerx_shape[1]), recv_lens[1], recv_lens[2]}, + {static_cast(centerx_shape[0]), static_cast(centerx_shape[1]), + static_cast(centerx_shape[2]), recv_lens[2]}, + {static_cast(centerx_shape[0]), static_cast(centerx_shape[1]), recv_lens[0], recv_lens[2]}}; + std::vector pad_nodes; + size_t output_index = 
0; + for (size_t i = 0; i < recv_rank_ids.size(); ++i) { + if (recv_rank_ids[i] != kInvalidId) { + auto pad = + CreatePadNode(graph, all_to_all_v_outputs[output_index], begins[i], sizes[i], centerx_shape, centerx_dtype); + ++output_index; + pad_nodes.emplace_back(pad); + } + } + + // create add node + std::vector addn_inputs = {NewValueNode(std::make_shared(kAddNOpName)), centerx}; + int64_t pad_num = 1; + for (auto pad : pad_nodes) { + std::vector pad_outputs; + CreateMultipleOutputsOfAnfNode(graph, pad, 1, &pad_outputs); + if (pad_outputs.empty()) { + MS_LOG(EXCEPTION) << "The node " << pad->DebugString() << " should have at least one output, but got 0."; + } + addn_inputs.insert(addn_inputs.end(), pad_outputs.begin(), pad_outputs.end()); + ++pad_num; + } + auto addn = graph->NewCNode(addn_inputs); + MS_EXCEPTION_IF_NULL(addn); + AnfAlgo::SetOutputInferTypeAndShape({centerx_dtype}, {centerx_shape}, addn.get()); + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue>({pad_num}), addn); + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(pad_num), addn); + MS_LOG(DEBUG) << "Create splitvs grad nodes success."; + return addn; +} +} // namespace + +const BaseRef NeighborExchangeV2UnifyMindIR::DefinePattern() const { + return VectorRef({prim::kPrimNeighborExchangeV2, std::make_shared()}); +} + +const AnfNodePtr NeighborExchangeV2UnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto neighbor_exchange_v2 = node->cast(); + MS_EXCEPTION_IF_NULL(neighbor_exchange_v2); + std::vector split_num; + auto split_nodes = CreateSplitNodes(graph, neighbor_exchange_v2, &split_num); + auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2, split_nodes, split_num, false); + auto concat = CreateConcatNodes(graph, neighbor_exchange_v2, all_to_all_v); + return concat; +} + +const BaseRef NeighborExchangeV2GradUnifyMindIR::DefinePattern() const { + return 
VectorRef({prim::kPrimNeighborExchangeV2Grad, std::make_shared()}); +} +const AnfNodePtr NeighborExchangeV2GradUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto neighbor_exchange_v2_grad = node->cast(); + MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad); + std::vector split_num; + auto split_nodes = CreateSplitNodesForGrad(graph, neighbor_exchange_v2_grad, &split_num); + auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2_grad, split_nodes, split_num, true); + auto add = CreateSplitGradNodes(graph, neighbor_exchange_v2_grad, all_to_all_v, split_nodes, split_num); + return add; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h b/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h new file mode 100644 index 00000000000..c9ed3c08631 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_NEIGHBOR_EXCHANGE_V2_UNIFY_MINDIR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_NEIGHBOR_EXCHANGE_V2_UNIFY_MINDIR_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +class NeighborExchangeV2UnifyMindIR : public PatternProcessPass { + public: + explicit NeighborExchangeV2UnifyMindIR(bool multigraph = true) + : PatternProcessPass("neighbor_exchange_v2_unify_mindir", multigraph) {} + ~NeighborExchangeV2UnifyMindIR() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +class NeighborExchangeV2GradUnifyMindIR : public PatternProcessPass { + public: + explicit NeighborExchangeV2GradUnifyMindIR(bool multigraph = true) + : PatternProcessPass("neighbor_exchange_v2_grad_unify_mindir", multigraph) {} + ~NeighborExchangeV2GradUnifyMindIR() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_NEIGHBOR_EXCHANGE_V2_UNIFY_MINDIR_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index fda7148b800..cc119a48a4a 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -235,6 +235,7 @@ constexpr char GET_TENSOR_SLICE[] = "_GetTensorSlice"; constexpr char SPLIT[] = "Split"; constexpr char ALL_TO_ALL[] = "AlltoAll"; constexpr char NEIGHBOREXCHANGE[] = "NeighborExchange"; +constexpr char NEIGHBOREXCHANGEV2[] = "NeighborExchangeV2"; constexpr char PERMUTE_BY_AXIS[] = "PermuteByAxis"; constexpr char 
CONCAT_BY_AXIS[] = "ConcatByAxis"; constexpr char SPLIT_BY_AXIS[] = "SplitByAxis"; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 3b980ef6ea2..23f4198e294 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -437,6 +437,8 @@ constexpr auto kAttrConcatDim = "concat_dim"; constexpr auto kAttrSplitCount = "split_count"; constexpr auto kAttrSendRankIds = "send_rank_ids"; constexpr auto kAttrRecvRankIds = "recv_rank_ids"; +constexpr auto kAttrSendLens = "send_lens"; +constexpr auto kAttrRecvLens = "recv_lens"; constexpr auto kAttrRankSize = "rank_size"; constexpr auto kAttrPadDimSize = "pad_dim_size"; constexpr auto kAttrPaddings = "paddings"; diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 549e102e35a..37fbf4eaf3e 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -437,6 +437,8 @@ inline const PrimitivePtr kPrimSend = std::make_shared("Send"); inline const PrimitivePtr kPrimReceive = std::make_shared("Receive"); inline const PrimitivePtr kPrimAllReduce = std::make_shared("AllReduce"); inline const PrimitivePtr kPrimNeighborExchange = std::make_shared("NeighborExchange"); +inline const PrimitivePtr kPrimNeighborExchangeV2 = std::make_shared("NeighborExchangeV2"); +inline const PrimitivePtr kPrimNeighborExchangeV2Grad = std::make_shared("NeighborExchangeV2Grad"); inline const PrimitivePtr kPrimAllToAll = std::make_shared("AlltoAll"); inline const PrimitivePtr kPrimAllToAllv = std::make_shared("AllToAllv"); inline const PrimitivePtr kPrimAllSwap = std::make_shared("AllSwap"); diff --git a/mindspore/core/ops/neighborexchangev2.cc b/mindspore/core/ops/neighborexchangev2.cc new file mode 100644 index 00000000000..8336afe0362 --- /dev/null +++ b/mindspore/core/ops/neighborexchangev2.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use 
this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/neighborexchangev2.h" +#include +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" + +namespace mindspore { +namespace ops { +namespace { +constexpr auto kSendRankIds = "send_rank_ids"; +constexpr auto kSendLens = "send_lens"; +constexpr auto kRecvRankIds = "recv_rank_ids"; +constexpr auto kRecvLens = "recv_lens"; +constexpr auto kDataFormat = "format"; +constexpr auto kGroup = "group"; +constexpr size_t kRankIdsSize = 8; +constexpr size_t kLensSize = 4; + +std::vector CheckAttrSize(const PrimitivePtr &primitive, const std::string &attr_name, + const size_t attr_size) { + MS_EXCEPTION_IF_NULL(primitive); + // size of send/recv_rank_ids equal to size of send/recv_shapes + std::vector attr_value; + try { + auto attr = primitive->GetAttr(attr_name); + if (attr->cast() == nullptr) { + MS_EXCEPTION(TypeError); + } + attr_value = GetValue>(attr); + } catch (const std::exception &) { + MS_EXCEPTION(TypeError) << "Attr " << attr_name << " must be a list[int, int, ...]."; + } + + if (attr_value.size() != attr_size) { + MS_EXCEPTION(ValueError) << "Invalid " << primitive->name() << " attr " << attr_name << " size " + << attr_value.size() << " must be equal to size " << attr_size; + } + return attr_value; +} + +void CheckRecvCorner(std::vector recv_rank_ids, int64_t idx1, int64_t idx2, int64_t idx_corner) { + if (recv_rank_ids[idx1] != -1 && recv_rank_ids[idx2] != -1 && recv_rank_ids[idx_corner] == -1) { + MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, 
as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1] + << ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids[" + << idx_corner << "] = " << recv_rank_ids[idx_corner] << "."; + } + if ((recv_rank_ids[idx1] == -1 || recv_rank_ids[idx2] == -1) && recv_rank_ids[idx_corner] != -1) { + MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1] + << ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids[" + << idx_corner << "] = " << recv_rank_ids[idx_corner] << "."; + } +} + +void Check(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + const int64_t input_num = 1; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name); + + // check size of send_rank_ids, recv_rank_ids, send_lens, recv_lens + (void)CheckAttrSize(primitive, kSendRankIds, kRankIdsSize); + auto recv_rank_ids = CheckAttrSize(primitive, kRecvRankIds, kRankIdsSize); + (void)CheckAttrSize(primitive, kSendLens, kLensSize); + (void)CheckAttrSize(primitive, kRecvLens, kLensSize); + + // check recv rankids invalid cond + CheckRecvCorner(recv_rank_ids, 0, 2, 1); + CheckRecvCorner(recv_rank_ids, 2, 4, 3); + CheckRecvCorner(recv_rank_ids, 4, 6, 5); + CheckRecvCorner(recv_rank_ids, 6, 0, 7); + + // check data_format is NCHW + auto format = GetValue(primitive->GetAttr(kDataFormat)); + if (format != "NCHW") { + MS_EXCEPTION(ValueError) << "Attr data_format only support NCHW now."; + } + + // check group + auto group_attr = primitive->GetAttr(kGroup); + try { + MS_EXCEPTION_IF_NULL(group_attr); + (void)GetValue(group_attr); + } catch (const std::exception &) { + MS_EXCEPTION(TypeError) << "Attr " << kGroup << " should be a str."; + } +} + +abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + 
MS_EXCEPTION_IF_NULL(primitive); + auto recv_rank_ids = primitive->GetAttr(kRecvRankIds); + MS_EXCEPTION_IF_NULL(recv_rank_ids); + auto recv_rank_ids_value = recv_rank_ids->cast(); + MS_EXCEPTION_IF_NULL(recv_rank_ids_value); + std::vector recv_rank_ids_v = GetValue>(recv_rank_ids_value); + auto recv_lens = primitive->GetAttr(kRecvLens); + MS_EXCEPTION_IF_NULL(recv_lens); + auto recv_lens_value = recv_lens->cast(); + MS_EXCEPTION_IF_NULL(recv_lens_value); + std::vector recv_lens_v = GetValue>(recv_lens_value); + + std::vector input_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + if (recv_rank_ids_v[0] != -1) { + input_shape[2] += recv_lens_v[0]; + } + if (recv_rank_ids_v[4] != -1) { + input_shape[2] += recv_lens_v[1]; + } + if (recv_rank_ids_v[6] != -1) { + input_shape[3] += recv_lens_v[2]; + } + if (recv_rank_ids_v[2] != -1) { + input_shape[3] += recv_lens_v[3]; + } + BaseShapePtr output_shape = std::make_shared(input_shape); + if (input_shape.empty()) { + return std::make_shared(); + } + return output_shape; +} + +TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + // recv type + TypePtr recv_type = input_args[0]->BuildType(); + if (recv_type == nullptr) { + return std::make_shared(); + } + return recv_type; +} +} // namespace +AbstractBasePtr NeighborExchangeV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + Check(primitive, input_args); + auto type = InferType(primitive, input_args); + auto shape = InferShape(primitive, input_args); + return abstract::MakeAbstract(shape, type); +} +REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchangeV2, prim::kPrimNeighborExchangeV2, NeighborExchangeV2Infer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/neighborexchangev2.h b/mindspore/core/ops/neighborexchangev2.h new file mode 100644 index 
00000000000..0da412c1f9a --- /dev/null +++ b/mindspore/core/ops/neighborexchangev2.h @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_NEIGHBOREXCHANGEV2_H_ +#define MINDSPORE_CORE_OPS_NEIGHBOREXCHANGEV2_H_ +#include +#include +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameNeighborExchangeV2 = "NeighborExchangeV2"; +class MS_CORE_API NeighborExchangeV2 : public PrimitiveC { + public: + NeighborExchangeV2() : PrimitiveC(kNameNeighborExchangeV2) {} + ~NeighborExchangeV2() = default; + MS_DECLARE_PARENT(NeighborExchangeV2, PrimitiveC); + void Init() {} +}; +using kPrimNeighborExchangeV2Ptr = std::shared_ptr; + +AbstractBasePtr NeighborExchangeV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_NEIGHBOREXCHANGEV2_H_ diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py index 202475d82f2..a29296e9bb8 100644 --- a/mindspore/ops/_grad/grad_comm_ops.py +++ b/mindspore/ops/_grad/grad_comm_ops.py @@ -22,12 +22,13 @@ from mindspore.parallel._utils import _get_enable_parallel_optimizer from .. 
import operations as P from ...common.tensor import RowTensor from ..composite.multitype_ops.zeros_like_impl import zeros_like -from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, NeighborExchange, AlltoAll, +from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, NeighborExchange, AlltoAll, NeighborExchangeV2, Broadcast, _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap, _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather) from .grad_base import bprop_getters from ..operations._inner_ops import Send, Receive +from ..operations import _grad_ops as G @bprop_getters.register(AllReduce) @@ -390,6 +391,24 @@ def get_bprop_all_to_all(self): return bprop +@bprop_getters.register(NeighborExchangeV2) +def get_bprop_neighborexchangev2(self): + """Generate bprop for NeighborExchangeV2.""" + group = self.group + send_rank_ids = self.recv_rank_ids + recv_rank_ids = self.send_rank_ids + send_lens = self.recv_lens + recv_lens = self.send_lens + data_format = self.data_format + neighborexchangev2_grad = G.NeighborExchangeV2Grad(send_rank_ids, send_lens, recv_rank_ids, + recv_lens, data_format, group) + + def bprop(x, out, dout): + return (neighborexchangev2_grad(dout),) + + return bprop + + @bprop_getters.register(_MirrorOperator) def get_bprop_mirror_operator(self): """ diff --git a/mindspore/ops/bprop_mindir/Broadcast_bprop.mindir b/mindspore/ops/bprop_mindir/Broadcast_bprop.mindir index 893d3235978..46d60c68d74 100644 --- a/mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +++ b/mindspore/ops/bprop_mindir/Broadcast_bprop.mindir @@ -5,4 +5,4 @@ e bprop.8:x* bprop.8:out* bprop.8:dout2 -bprop.8:[CNode]:1:@13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e8P \ No newline at end of file +bprop.8:[CNode]:1:@74787be4234cdeb03f214519cd8358a5f4ad2f5606dbeb494462cddc448eb4beP \ No 
newline at end of file diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index dc82f2bbab2..f7ad411431e 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted, TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches) -from .comm_ops import (AllGather, AllReduce, NeighborExchange, AlltoAll, AllSwap, ReduceScatter, Broadcast, +from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, AllSwap, ReduceScatter, Broadcast, _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset, _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad, _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather) @@ -486,6 +486,9 @@ __all__ = [ "Trunc", "Complex", "ExtractVolumePatches", + "NeighborExchangeV2", + "NeighborExchange", + "AlltoAll", ] __sponge__ = [ diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 7f3d46bd187..532f3a56062 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -23,6 +23,7 @@ from ..._checkparam import Validator as validator, Rel from .._utils import get_concat_offset from ...common import dtype as mstype from ... 
import context +from ...communication.management import GlobalComm class AbsGrad(PrimitiveWithInfer): @@ -723,6 +724,42 @@ class BNTrainingUpdateGrad(PrimitiveWithInfer): def infer_dtype(self, grads, x, batch_mean, batch_variance): return (batch_mean, batch_variance) +class NeighborExchangeV2Grad(PrimitiveWithInfer): + """"Gradients of NeighborExchangeV2 operation.""" + + @prim_attr_register + def __init__(self, send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, + group=GlobalComm.WORLD_COMM_GROUP): + self.init_prim_io_names(inputs=['dy'], outputs=['dx']) + self.send_rank_ids = send_rank_ids + self.recv_rank_ids = recv_rank_ids + self.send_lens = send_lens + self.recv_lens = recv_lens + self.format = validator.check_string(data_format, ['NCHW'], 'format', self.name) + self.add_prim_attr('no_elimilate', True) + + def __infer__(self, dy): + dy_shape = dy['shape'] + validator.check(f'dy_shape.size()', len(dy_shape), f'4', 4, Rel.EQ, self.name) + if self.send_rank_ids[5] != -1 or self.send_rank_ids[6] != -1 or self.send_rank_ids[7] != -1: + dy_shape[3] -= self.send_lens[2] + + if self.send_rank_ids[1] != -1 or self.send_rank_ids[2] != -1 or self.send_rank_ids[3] != -1: + dy_shape[3] -= self.send_lens[3] + + if self.send_rank_ids[0] != -1 or self.send_rank_ids[1] != -1 or self.send_rank_ids[7] != -1: + dy_shape[2] -= self.send_lens[0] + + if self.send_rank_ids[3] != -1 or self.send_rank_ids[4] != -1 or self.send_rank_ids[5] != -1: + dy_shape[2] -= self.send_lens[1] + + return {'shape': dy_shape, + 'dtype': dy['dtype'], + 'value': None} + + def __call__(self, tensor): + raise NotImplementedError + class GeLUGrad(PrimitiveWithInfer): """Gradients of GeLU operation.""" diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 2a0dd145c93..9aab7e1183f 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -710,6 +710,39 @@ class AlltoAll(PrimitiveWithInfer): def __call__(self, 
tensor): raise NotImplementedError +class NeighborExchangeV2(Primitive): + """ + NeighborExchange is a collective operation. + + NeighborExchange sends data from the local rank to ranks in the send_rank_ids, + as while receive data from recv_rank_ids. + + Args: + send_rank_ids (list(int)): Ranks which the data is sent to. 8 rank_ids represents 8 directions, if one + direction is not send to , set it -1. + recv_rank_ids (list(int)): Ranks which the data is received from. 8 rank_ids represents 8 directions, + if one direction is not recv from , set it -1. + send_lens (list(int)): Data lens which send to the send_rank_ids, 4 numbers represent the lens of + [top, bottom, left, right]. + recv_lens (list(int)): Data lens which received from recv_rank_ids, 4 numbers represent the lens of + [top, bottom, left, right]. + data_format (str): Data format, only support NCHW now. + group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP". + """ + + @prim_attr_register + def __init__(self, send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, + group=GlobalComm.WORLD_COMM_GROUP): + self.init_prim_io_names(inputs=['x'], outputs=['output']) + self.send_rank_ids = send_rank_ids + self.recv_rank_ids = recv_rank_ids + self.send_lens = send_lens + self.recv_lens = recv_lens + self.format = data_format + self.add_prim_attr('no_elimilate', True) + + def __call__(self, tensor): + raise NotImplementedError class _MirrorOperator(PrimitiveWithInfer): """ diff --git a/tests/ut/python/parallel/test_neighborexchangev2.py b/tests/ut/python/parallel/test_neighborexchangev2.py new file mode 100644 index 00000000000..86c0bf8cf65 --- /dev/null +++ b/tests/ut/python/parallel/test_neighborexchangev2.py @@ -0,0 +1,542 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import pytest +import numpy as np +import mindspore as ms +import mindspore.context as context +from mindspore import Tensor +import mindspore.nn as nn +from mindspore.common.api import _cell_graph_executor +from mindspore.nn import TrainOneStepCell, Momentum +from mindspore.ops.operations.comm_ops import NeighborExchangeV2 + +_x1 = Tensor(np.ones([1, 1, 32, 16]), dtype=ms.float32) +_x2 = Tensor(np.ones([1, 1, 33, 16]), dtype=ms.float32) + + +def compile_net(net, x1, x2): + context.set_context(mode=context.GRAPH_MODE) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_train() + _cell_graph_executor.compile(train_net, x1, x2) + + +def test_neighborexchangev2_single_input_success(): + """ + Feature: NeighborExchangeV2 + Description: one inputs and one outputs, with valid arguments + Expectation: success + """ + context.set_auto_parallel_context(device_num=8, global_rank=0) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.linear = nn.Dense(16, 16) + self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + send_lens=[0, 1, 0, 0], + recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + recv_lens=[0, 1, 0, 0], data_format="NCHW") + + def construct(self, x1, x2): + y = self.linear(x1) + y = self.neighborexchangev2(y) + y = y + x2 + return y + + net = Net() + compile_net(net, _x1, _x2) + + +def 
def test_neighborexchangev2_empty_send_success():
    """
    Feature: NeighborExchangeV2
    Description: empty inputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.linear = nn.Dense(16, 16)
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
                                                         send_lens=[1, 2, 3, 4],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x1, x2):
            y = self.linear(x1)
            y = self.neighborexchangev2(y)
            y = y + x2
            return y

    net = Net()
    compile_net(net, _x1, _x2)


def test_neighborexchangev2_empty_recv_success():
    """
    Feature: NeighborExchangeV2
    Description: empty outputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.linear = nn.Dense(16, 16)
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
                                                         recv_lens=[1, 2, 3, 4],
                                                         data_format="NCHW")

        def construct(self, x1, x2):
            y = self.linear(x1)
            y = self.neighborexchangev2(y)
            y = y + x2
            return y

    net = Net()
    compile_net(net, _x1, _x1)


def test_neighborexchangev2_empty_send_empty_recv_success():
    """
    Feature: NeighborExchangeV2
    Description: empty inputs and empty outputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
                                                         recv_lens=[1, 2, 3, 4],
                                                         data_format="NCHW")

        def construct(self, x1):
            y = self.neighborexchangev2(x1)
            return y

    net = Net()
    _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_invalid_dataformat_failed():
    """
    Feature: NeighborExchangeV2
    Description: data_format should be NCHW, but gives NHWC
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NHWC")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_invalid_send_rank_ids_size_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_rank_ids size should be 8, but gives 5
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_invalid_recv_rank_ids_size_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_rank_ids size should be 8, but gives 5
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_invalid_send_lens_size_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_lens size should be 4, but gives 5
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0, 2],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_invalid_recv_lens_size_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_lens size should be 4, but gives 5
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0, 2],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_invalid_input_size_failed():
    """
    Feature: NeighborExchangeV2
    Description: input should be one tensor, but gives 2
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x1, x2):
            out = self.neighborexchangev2(x1, x2)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1, _x2)


def test_neighborexchangev2_recv_rank_ids_invalid_value_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_rank_ids should can be concat, recv_rank_ids[3] and [4] is 1, [5] is -1 given
    Expectation: throw Exception
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, 1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_attr_check_send_rank_ids_is_tuple_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_rank_ids should be list, but a tuple is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=(-1, -1, -1, -1, 1, -1, -1, -1),
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_attr_check_send_lens_is_tuple_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_lens should be list, but a tuple is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=(0, 1, 0, 0),
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_attr_check_recv_rank_ids_is_tuple_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_rank_ids should be list, but a tuple is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=(-1, -1, -1, -1, 1, -1, -1, -1),
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_attr_check_recv_lens_is_tuple_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_lens should be list, but a tuple is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=(0, 1, 0, 0),
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_attr_check_send_rank_ids_is_float_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_rank_ids should be int, but float is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1.0, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_attr_check_send_lens_is_float_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_lens should be int, but float is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1.0, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_attr_check_recv_rank_ids_is_float_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_rank_ids should be int, but float is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1.0, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_attr_check_recv_lens_is_float_failed():
    """
    Feature: NeighborExchangeV2
    Description: lens in recv_lens should be int, but float is given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1.0, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)


def test_neighborexchangev2_group_is_tuple_failed():
    """
    Feature: NeighborExchangeV2
    Description: group should be a string, but tuple given
    Expectation: throw TypeError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW", group=("str",))

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(net, _x1)