!25765 neighborExchangeV2 & grad

Merge pull request !25765 from TuDouNi/neighborExchangeV2
i-robot 2021-11-05 09:31:03 +00:00 committed by Gitee
commit ded1c77bbf
14 changed files with 1789 additions and 3 deletions

View File

@@ -144,6 +144,7 @@
#include "backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/bn_grad_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/all_to_all_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h"
#include "backend/optimizer/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_compile.h"
#include "utils/ms_context.h"
@@ -579,6 +580,8 @@ void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &graph) {
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::BatchNormGradUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2UnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2GradUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::AllToAllUnifyMindIR>());
optimizer->AddPassManager(unify_mindir_pm);

View File

@@ -0,0 +1,900 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h"
#include <algorithm>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/hccl_adapter/hccl_adapter.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kCNodePrimitiveIdx = 0;
constexpr size_t kNeighborExchangeV2InputIdx = 1;
constexpr size_t kLenTopIdx = 0;
constexpr size_t kLenBottomIdx = 1;
constexpr size_t kLenLeftIdx = 2;
constexpr size_t kLenRightIdx = 3;
constexpr size_t kHDim = 2; // dim of H in NCHW
constexpr size_t kWDim = 3; // dim of W in NCHW
constexpr int64_t kShapeSize = 4;
constexpr int64_t kRankIdZero = 0;
constexpr int64_t kRankIdOne = 1;
constexpr int64_t kRankIdTwo = 2;
constexpr int64_t kRankIdThree = 3;
constexpr int64_t kRankIdFour = 4;
constexpr int64_t kRankIdFive = 5;
constexpr int64_t kRankIdSix = 6;
constexpr int64_t kRankIdSeven = 7;
constexpr size_t kSizeFour = 4;
constexpr int64_t kInvalidId = -1;
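// The kRankId* indices above name the 8 neighbour directions of send/recv_rank_ids,
// clockwise from the top: 0 = top, 1 = top-right, 2 = right, 3 = bottom-right,
// 4 = bottom, 5 = bottom-left, 6 = left, 7 = top-left.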
// calculate split attrs: size_splits, shapes and num_split
int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first, const bool is_last,
const int64_t split_dim, const std::vector<int64_t> &send_lens, std::vector<int64_t> *size_splits,
std::vector<std::vector<size_t>> *shapes) {
MS_EXCEPTION_IF_NULL(size_splits);
MS_EXCEPTION_IF_NULL(shapes);
if (SizeToLong(base_shape.size()) != kShapeSize) {
MS_LOG(EXCEPTION) << "Wrong base_shape size: " << base_shape.size() << ", it should be equal to 4.";
}
if (split_dim >= kShapeSize) {
MS_LOG(EXCEPTION) << "Wrong split_dim: " << split_dim << ", it should less than 4.";
}
int64_t num_split = 0;
int64_t split_middle_size = base_shape[split_dim];
std::vector<size_t> shape_tmp(base_shape);
// [top, bottom, left, right]
int64_t first_size = split_dim == kWDim ? send_lens[2] : send_lens[0];
int64_t last_size = split_dim == kWDim ? send_lens[3] : send_lens[1];
if (is_first) {
// first
++num_split;
size_splits->push_back(first_size);
split_middle_size -= first_size;
shape_tmp[split_dim] = static_cast<size_t>(first_size);
shapes->push_back(shape_tmp);
}
if (is_last) {
// middle
++num_split;
split_middle_size -= last_size;
size_splits->push_back(split_middle_size);
shape_tmp[split_dim] = static_cast<size_t>(split_middle_size);
shapes->push_back(shape_tmp);
// last
++num_split;
size_splits->push_back(last_size);
shape_tmp[split_dim] = static_cast<size_t>(last_size);
shapes->push_back(shape_tmp);
} else {
++num_split;
size_splits->push_back(split_middle_size);
shape_tmp[split_dim] = static_cast<size_t>(split_middle_size);
shapes->push_back(shape_tmp);
}
return num_split;
}
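// A worked example of CalSplitAttrs (illustrative values, not from this patch):
// base_shape = {1, 1, 32, 16}, split_dim = kHDim, send_lens = {2, 3, 0, 0} and
// is_first = is_last = true give num_split = 3, size_splits = {2, 27, 3} and
// shapes {1,1,2,16}, {1,1,27,16}, {1,1,3,16}: the top 2 rows, the middle, the bottom 3 rows.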
CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &split_input,
const std::vector<size_t> &base_shape, bool is_first, bool is_last, int64_t split_dim,
const std::vector<int64_t> &send_lens, TypeId input_dtype, int64_t *num_split) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(num_split);
if (split_input.empty()) {
MS_LOG(EXCEPTION) << "The input is empty, can not create splitv node.";
return nullptr;
}
auto split_v = graph->NewCNode(split_input);
MS_EXCEPTION_IF_NULL(split_v);
std::vector<int64_t> size_splits = {};
std::vector<std::vector<size_t>> shapes = {};
*num_split = CalSplitAttrs(base_shape, is_first, is_last, split_dim, send_lens, &size_splits, &shapes);
std::vector<TypeId> dtypes(*num_split, input_dtype);
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get());
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue<int64_t>(split_dim), split_v);
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue<int64_t>(*num_split), split_v);
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue<std::vector<int64_t>>(size_splits), split_v);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v);
return split_v;
}
std::vector<std::vector<size_t>> CalAllToAllvOutputShape(const std::vector<size_t> &base_shape,
const std::vector<int64_t> &recv_lens,
const std::vector<int64_t> &recv_rank_ids) {
if (SizeToLong(base_shape.size()) != kShapeSize) {
MS_LOG(EXCEPTION) << "Wrong base_shape size: " << base_shape.size() << ", it should be equal to 4.";
}
std::vector<std::vector<size_t>> shapes = {};
std::vector<std::vector<size_t>> ori_shapes = {
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenTopIdx]), base_shape[kWDim]},
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenTopIdx]),
static_cast<size_t>(recv_lens[kLenRightIdx])},
{base_shape[0], base_shape[1], base_shape[kHDim], static_cast<size_t>(recv_lens[kLenRightIdx])},
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenBottomIdx]),
static_cast<size_t>(recv_lens[kLenRightIdx])},
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenBottomIdx]), base_shape[kWDim]},
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenBottomIdx]),
static_cast<size_t>(recv_lens[kLenLeftIdx])},
{base_shape[0], base_shape[1], base_shape[kHDim], static_cast<size_t>(recv_lens[kLenLeftIdx])},
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenTopIdx]),
static_cast<size_t>(recv_lens[kLenLeftIdx])}};
for (size_t idx = 0; idx < recv_rank_ids.size(); ++idx) {
if (recv_rank_ids[idx] != kInvalidId) {
shapes.push_back(ori_shapes[idx]);
}
}
return shapes;
}
// Returns {top_bottom, left_right, top_corner, bottom_corner}; entries with no split are set to nullptr.
std::vector<CNodePtr> CreateSplitNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
std::vector<int64_t> *split_num) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
MS_EXCEPTION_IF_NULL(split_num);
std::vector<int64_t> send_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendRankIds);
std::vector<int64_t> send_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendLens);
if (neighbor_exchange_v2->size() <= kNeighborExchangeV2InputIdx) {
MS_LOG(EXCEPTION) << "Invalid cnode " << neighbor_exchange_v2->DebugString() << " input size "
<< neighbor_exchange_v2->size();
}
std::vector<CNodePtr> split_nodes = {};
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
bool is_top = ((send_rank_ids[kRankIdZero] != kInvalidId) || (send_rank_ids[kRankIdOne] != kInvalidId) ||
(send_rank_ids[kRankIdSeven] != kInvalidId));
bool is_bottom = ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFour] != kInvalidId) ||
(send_rank_ids[kRankIdFive] != kInvalidId));
bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId);
bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId);
auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_input, 0);
auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
if (SizeToLong(shape.size()) != kShapeSize) { // only support NCHW now
MS_LOG(EXCEPTION) << "Invalid shape size " << shape.size() << ", only support NCHW input now!";
}
// splitv for top & bottom
int64_t num_split_h = 0;
CNodePtr split_v_top_bottom = nullptr;
if (is_top || is_bottom) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_input};
split_v_top_bottom =
CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h);
}
split_nodes.emplace_back(split_v_top_bottom);
split_num->push_back(num_split_h);
// splitv for left & right
int64_t num_split_w = 0;
CNodePtr split_v_left_right = nullptr;
if (is_left || is_right) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_input};
split_v_left_right =
CreateSplitNode(graph, split_input, shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w);
}
split_nodes.emplace_back(split_v_left_right);
split_num->push_back(num_split_w);
// splitv for corner
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId) ||
(send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
// top_bottom_split outputs
std::vector<AnfNodePtr> split_outputs_top_bottom;
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>((*split_num)[0]),
&split_outputs_top_bottom);
if (split_outputs_top_bottom.empty()) {
MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString()
<< " should have at least one output, but got 0.";
}
// for top corner
if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId)) {
auto shape_tmp(shape);
shape_tmp[kHDim] = send_lens[0];
bool is_first = (send_rank_ids[kRankIdSeven] != kInvalidId);
bool is_last = (send_rank_ids[kRankIdOne] != kInvalidId);
std::vector<AnfNodePtr> split_v_corner_top_input = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_v_corner_top_input.insert(split_v_corner_top_input.end(), split_outputs_top_bottom.begin(),
split_outputs_top_bottom.begin() + 1);
int64_t num_split_top_corner = 0;
CNodePtr split_v_corner_top = CreateSplitNode(graph, split_v_corner_top_input, shape_tmp, is_first, is_last,
kWDim, send_lens, dtype, &num_split_top_corner);
split_nodes.emplace_back(split_v_corner_top);
split_num->push_back(num_split_top_corner);
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
// for bottom corner
if ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
auto shape_tmp(shape);
shape_tmp[kHDim] = send_lens[1];
bool is_first = (send_rank_ids[kRankIdFive] != kInvalidId);
bool is_last = (send_rank_ids[kRankIdThree] != kInvalidId);
std::vector<AnfNodePtr> split_v_corner_bottom_input = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_v_corner_bottom_input.insert(split_v_corner_bottom_input.end(), split_outputs_top_bottom.end() - 1,
split_outputs_top_bottom.end());
int64_t num_split_bottom_corner = 0;
CNodePtr split_v_corner_bottom = CreateSplitNode(graph, split_v_corner_bottom_input, shape_tmp, is_first, is_last,
kWDim, send_lens, dtype, &num_split_bottom_corner);
split_nodes.emplace_back(split_v_corner_bottom);
split_num->push_back(num_split_bottom_corner);
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
} else {
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
split_nodes.emplace_back(nullptr);
split_num->push_back(0);
}
return split_nodes;
}
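// CreateAllToAllvInput below gathers the send buffers in direction order 0..7:
// split_idx maps each direction to a split node ({0: top_bottom, 1: left_right,
// 2: top_corner, 3: bottom_corner}) and is_begin selects its first output, otherwise its last.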
std::vector<AnfNodePtr> CreateAllToAllvInput(const std::vector<std::vector<AnfNodePtr>> &split_outputs,
const std::vector<int64_t> &send_rank_ids) {
std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
std::vector<size_t> split_idx = {0, 2, 1, 3, 0, 3, 1, 2};
std::vector<bool> is_begin = {true, false, false, false, false, true, true, true};
for (size_t idx = 0; idx < send_rank_ids.size(); ++idx) {
if (send_rank_ids[idx] != kInvalidId) {
if (is_begin[idx]) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].begin(),
split_outputs[split_idx[idx]].begin() + 1);
} else {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].end() - 1,
split_outputs[split_idx[idx]].end());
}
}
}
return all_to_all_v_input;
}
// get center of input for grad
AnfNodePtr GetCenter(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
const std::vector<CNodePtr> &split_nodes, const std::vector<int64_t> &split_num,
const std::vector<int64_t> &send_rank_ids) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
std::vector<AnfNodePtr> output;
if (split_nodes[kRankIdTwo] == nullptr) {
if (split_nodes[0] != nullptr) {
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>(split_num[0]), &output);
if (output.size() < 2) {
MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2.";
}
if (send_rank_ids[kRankIdZero] == kInvalidId) {
return output[0];
}
return output[1];
} else {
return neighbor_exchange_v2_grad->input(kNeighborExchangeV2InputIdx);
}
} else {
CreateMultipleOutputsOfAnfNode(graph, split_nodes[2], static_cast<size_t>(split_num[2]), &output);
if (output.size() < 2) {
MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2.";
}
if (send_rank_ids[kRankIdSix] == kInvalidId) {
return output[0];
}
return output[1];
}
}
std::vector<AnfNodePtr> CreateAllToAllvInputForGrad(const std::vector<int64_t> &send_rank_ids,
const std::vector<std::vector<AnfNodePtr>> &split_outputs,
const std::vector<CNodePtr> &split_nodes) {
if (send_rank_ids.size() != 8) {
MS_LOG(EXCEPTION) << "Wrong send_rank_ids size: " << send_rank_ids.size() << ", expect size: 8.";
}
if (split_outputs.size() != kSizeFour) {
MS_LOG(EXCEPTION) << "Wrong split_outputs size: " << split_outputs.size() << ", expect size: 4.";
}
if (split_nodes.size() != kSizeFour) {
MS_LOG(EXCEPTION) << "Wrong split_nodes size: " << split_nodes.size() << ", expect size: 4.";
}
std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
// only the top-bottom split exists (no sends to the side directions)
std::vector<size_t> side_idx = {1, 2, 3, 5, 6, 7};
bool no_send_side = std::all_of(side_idx.begin(), side_idx.end(),
[&send_rank_ids](int idx) { return send_rank_ids[idx] == kInvalidId; });
if (no_send_side) {
if (send_rank_ids[kRankIdZero] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].begin(), split_outputs[0].begin() + 1);
}
if (send_rank_ids[kRankIdFour] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].end() - 1, split_outputs[0].end());
}
return all_to_all_v_input;
}
// 0, 1
if (split_nodes[1] != nullptr) {
if (send_rank_ids[kRankIdSeven] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin() + 1, split_outputs[1].end());
} else {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].end());
}
}
// 2
if (split_nodes[2] != nullptr && send_rank_ids[kRankIdTwo] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[2].end() - 1, split_outputs[2].end());
}
// 3, 4, 5
if (split_nodes[3] != nullptr) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[3].rbegin(), split_outputs[3].rend());
}
// 6
if (split_nodes[2] != nullptr && send_rank_ids[kRankIdSix] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[2].begin(), split_outputs[2].begin() + 1);
}
// 7
if (split_nodes[1] != nullptr && send_rank_ids[kRankIdSeven] != kInvalidId) {
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].begin() + 1);
}
return all_to_all_v_input;
}
// alltoallv for forward & grad
CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_or_grad,
const std::vector<CNodePtr> &split_nodes, const std::vector<int64_t> &split_num,
bool is_grad) {
MS_LOG(DEBUG) << "Start to create alltoallv node.";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_or_grad);
std::vector<int64_t> send_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_or_grad, kAttrSendRankIds);
std::vector<int64_t> recv_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_or_grad, kAttrRecvRankIds);
std::vector<int64_t> recv_lens =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_or_grad, kAttrRecvLens);
std::string group = AnfAlgo::GetNodeAttr<std::string>(neighbor_exchange_v2_or_grad, kAttrGroup);
// get split nodes output, split_outputs: [top_bottom, left_right, top_corner, bottom_corner]
std::vector<std::vector<AnfNodePtr>> split_outputs;
for (size_t i = 0; i < split_nodes.size(); ++i) {
std::vector<AnfNodePtr> output;
if (split_nodes[i] != nullptr) {
CreateMultipleOutputsOfAnfNode(graph, split_nodes[i], static_cast<size_t>(split_num[i]), &output);
if (output.empty()) {
MS_LOG(EXCEPTION) << "The node " << split_nodes[i]->DebugString()
<< " should have at least one output, but got 0.";
}
}
split_outputs.emplace_back(output);
}
// all_to_all_v input
std::vector<AnfNodePtr> all_to_all_v_input;
AnfNodePtr base_node = nullptr;
if (is_grad) {
all_to_all_v_input = CreateAllToAllvInputForGrad(send_rank_ids, split_outputs, split_nodes);
base_node = GetCenter(graph, neighbor_exchange_v2_or_grad, split_nodes, split_num, send_rank_ids);
} else {
all_to_all_v_input = CreateAllToAllvInput(split_outputs, send_rank_ids);
base_node = neighbor_exchange_v2_or_grad->input(kNeighborExchangeV2InputIdx);
}
// output shapes and dtypes
auto base_dtype = AnfAlgo::GetOutputInferDataType(base_node, 0);
auto base_shape = AnfAlgo::GetOutputInferShape(base_node, 0);
if (SizeToLong(base_shape.size()) != kShapeSize) {
MS_LOG(EXCEPTION) << "Invalid shape size " << base_shape.size() << ", only support NCHW input now!";
}
std::vector<std::vector<size_t>> shapes = CalAllToAllvOutputShape(base_shape, recv_lens, recv_rank_ids);
// erase -1 entries in send_rank_ids and recv_rank_ids
std::vector<int64_t> real_send_rank_ids(send_rank_ids.size());
std::vector<int64_t> real_recv_rank_ids(recv_rank_ids.size());
auto iter1 = std::copy_if(send_rank_ids.begin(), send_rank_ids.end(), real_send_rank_ids.begin(),
[](const int64_t item) { return item != kInvalidId; });
auto iter2 = std::copy_if(recv_rank_ids.begin(), recv_rank_ids.end(), real_recv_rank_ids.begin(),
[](const int64_t item) { return item != kInvalidId; });
real_send_rank_ids.resize(std::distance(real_send_rank_ids.begin(), iter1));
real_recv_rank_ids.resize(std::distance(real_recv_rank_ids.begin(), iter2));
std::vector<TypeId> dtypes(real_recv_rank_ids.size(), base_dtype);
// create alltoallv node
auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
MS_EXCEPTION_IF_NULL(all_to_all_v);
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, all_to_all_v.get());
AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(real_send_rank_ids), all_to_all_v);
AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(real_recv_rank_ids), all_to_all_v);
AnfAlgo::SetNodeAttr(kAttrRecvType, TypeIdToType(base_dtype), all_to_all_v);
AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>(group), all_to_all_v);
MS_LOG(INFO) << "Create AllToAllv success, send rank size " << send_rank_ids.size() << ", recv rank size "
<< recv_rank_ids.size();
return all_to_all_v;
}
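// AllToAllRealIds below maps a direction index to its position among the valid
// (non -1) entries of recv_rank_ids, i.e. how many valid outputs precede it.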
int64_t AllToAllRealIds(int64_t ids, const std::vector<int64_t> &recv_rank_ids) {
int64_t real_ids = 0;
for (auto i = 0; i < ids; ++i) {
if (recv_rank_ids[i] != kInvalidId) {
++real_ids;
}
}
return real_ids;
}
CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &concat_input,
const std::vector<std::vector<size_t>> &output_shape, const std::vector<TypeId> &output_dtype,
int64_t axis, int64_t input_nums) {
MS_EXCEPTION_IF_NULL(graph);
auto concat = graph->NewCNode(concat_input);
MS_EXCEPTION_IF_NULL(concat);
AnfAlgo::SetOutputInferTypeAndShape(output_dtype, output_shape, concat.get());
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(axis), concat);
AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(input_nums), concat);
std::vector<int64_t> dyn_input_size_empty{input_nums};
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size_empty), concat);
return concat;
}
CNodePtr CreateLeftRightConcat(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
bool is_left) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
int64_t input_num = 1;
size_t first_ids = is_left ? 7 : 1;
size_t middle_ids = is_left ? 6 : 2;
size_t last_ids = is_left ? 5 : 3;
auto single_shape = AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[AllToAllRealIds(middle_ids, recv_rank_ids)], 0);
if (recv_rank_ids[first_ids] != kInvalidId) {
++input_num;
single_shape[2] += static_cast<size_t>(recv_lens[0]); // H in NCHW
}
if (recv_rank_ids[last_ids] != kInvalidId) {
++input_num;
single_shape[2] += static_cast<size_t>(recv_lens[1]); // H in NCHW
}
if (is_left) {
concat_input.insert(concat_input.end(), all_to_all_v_outputs.rbegin(), all_to_all_v_outputs.rbegin() + input_num);
} else {
concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin() + AllToAllRealIds(1, recv_rank_ids),
all_to_all_v_outputs.begin() + input_num + AllToAllRealIds(1, recv_rank_ids));
}
std::vector<TypeId> concat_output_dtype = {
AnfAlgo::GetOutputInferDataType(all_to_all_v_outputs[AllToAllRealIds(middle_ids, recv_rank_ids)], 0)};
auto concat = CreateConcatNode(graph, concat_input, {single_shape}, concat_output_dtype, kHDim, input_num);
return concat;
}
CNodePtr CreateMiddleConcat(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const std::vector<AnfNodePtr> &all_to_all_v_outputs,
const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
int64_t concat_dim) {
std::vector<AnfNodePtr> concat_input_all = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
int64_t input_num_all = 0;
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
auto single_shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
size_t first_idx = concat_dim == kWDim ? 6 : 0;
size_t last_idx = concat_dim == kWDim ? 2 : 4;
size_t first_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[2]) : static_cast<size_t>(recv_lens[0]);
size_t last_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[3]) : static_cast<size_t>(recv_lens[1]);
// left
if (recv_rank_ids[first_idx] != kInvalidId) {
if (concat_dim == kWDim) {
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.end() - 1, all_to_all_v_outputs.end());
} else {
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1);
}
++input_num_all;
single_shape[concat_dim] += first_len;
}
concat_input_all.push_back(neighbor_exchange_v2_input);
++input_num_all;
// right
if (recv_rank_ids[last_idx] != kInvalidId) {
if (concat_dim == kWDim) {
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1);
} else {
int64_t bottom_num = AllToAllRealIds(4, recv_rank_ids);
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin() + bottom_num,
all_to_all_v_outputs.begin() + bottom_num + 1);
}
++input_num_all;
single_shape[concat_dim] += last_len;
}
std::vector<TypeId> concat_output_dtype = {AnfAlgo::GetOutputInferDataType(all_to_all_v_outputs[0], 0)};
auto concat_all =
CreateConcatNode(graph, concat_input_all, {single_shape}, concat_output_dtype, concat_dim, input_num_all);
return concat_all;
}
CNodePtr AllToAllvRecvEmpty(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
MS_EXCEPTION_IF_NULL(all_to_all_v);
// add depend for input & alltoallv
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
std::vector<AnfNodePtr> depend_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
neighbor_exchange_v2_input, all_to_all_v};
auto depend = graph->NewCNode(depend_input);
MS_EXCEPTION_IF_NULL(depend);
depend->set_abstract(neighbor_exchange_v2_input->abstract());
return depend;
}
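// CreateConcatNodes stitches the AllToAllv outputs around the original input: the
// middle column is concatenated along H (top, input, bottom), each side column is
// concatenated along H, and a final concat joins left | middle | right along W.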
CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
const CNodePtr &all_to_all_v) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
MS_EXCEPTION_IF_NULL(all_to_all_v);
std::vector<int64_t> recv_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrRecvRankIds);
std::vector<int64_t> recv_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrRecvLens);
int64_t all_to_all_output_num =
std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });
if (all_to_all_output_num == 0) {
return AllToAllvRecvEmpty(graph, neighbor_exchange_v2, all_to_all_v);
}
std::vector<AnfNodePtr> all_to_all_v_outputs;
CreateMultipleOutputsOfAnfNode(graph, all_to_all_v, static_cast<size_t>(all_to_all_output_num),
&all_to_all_v_outputs);
if (all_to_all_v_outputs.empty()) {
MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0.";
}
if (recv_rank_ids[kRankIdZero] == kInvalidId && recv_rank_ids[kRankIdFour] == kInvalidId) {
return CreateMiddleConcat(graph, neighbor_exchange_v2, all_to_all_v_outputs, recv_rank_ids, recv_lens, kWDim);
}
// top or bottom
// middle concat
auto concat_middle =
CreateMiddleConcat(graph, neighbor_exchange_v2, all_to_all_v_outputs, recv_rank_ids, recv_lens, kHDim);
bool is_left = recv_rank_ids[kRankIdSix] != kInvalidId || recv_rank_ids[kRankIdFive] != kInvalidId ||
recv_rank_ids[kRankIdSeven] != kInvalidId;
bool is_right = recv_rank_ids[kRankIdOne] != kInvalidId || recv_rank_ids[kRankIdTwo] != kInvalidId ||
recv_rank_ids[kRankIdThree] != kInvalidId;
if (!is_left && !is_right) {
return concat_middle;
}
std::vector<AnfNodePtr> concat_input_all = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
std::vector<size_t> shape_all = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
shape_all[2] =
recv_rank_ids[kRankIdZero] != kInvalidId ? shape_all[2] + static_cast<size_t>(recv_lens[0]) : shape_all[2];
shape_all[2] =
recv_rank_ids[kRankIdFour] != kInvalidId ? shape_all[2] + static_cast<size_t>(recv_lens[1]) : shape_all[2];
int64_t input_nums_all = 0;
// left concat
if (is_left) {
auto concat_left = CreateLeftRightConcat(graph, all_to_all_v_outputs, recv_rank_ids, recv_lens, true);
// connect to concat_all
std::vector<AnfNodePtr> concat_left_outputs;
CreateMultipleOutputsOfAnfNode(graph, concat_left, 1, &concat_left_outputs);
if (concat_left_outputs.empty()) {
MS_LOG(EXCEPTION) << "The node " << concat_left->DebugString() << " should have at least one output, but got 0.";
}
concat_input_all.insert(concat_input_all.end(), concat_left_outputs.begin(), concat_left_outputs.end());
++input_nums_all;
shape_all[3] += recv_lens[2];
}
// middle concat connect to concat_all
std::vector<AnfNodePtr> concat_middle_outputs;
CreateMultipleOutputsOfAnfNode(graph, concat_middle, 1, &concat_middle_outputs);
if (concat_middle_outputs.empty()) {
MS_LOG(EXCEPTION) << "The node " << concat_middle->DebugString() << " should have at least one output, but got 0.";
}
concat_input_all.insert(concat_input_all.end(), concat_middle_outputs.begin(), concat_middle_outputs.end());
++input_nums_all;
if (is_right) {
auto concat_right = CreateLeftRightConcat(graph, all_to_all_v_outputs, recv_rank_ids, recv_lens, false);
// connect to concat_all
std::vector<AnfNodePtr> concat_right_outputs;
CreateMultipleOutputsOfAnfNode(graph, concat_right, 1, &concat_right_outputs);
if (concat_right_outputs.empty()) {
MS_LOG(EXCEPTION) << "The node " << concat_right->DebugString() << " should have at least one output, but got 0.";
}
concat_input_all.insert(concat_input_all.end(), concat_right_outputs.begin(), concat_right_outputs.end());
++input_nums_all;
shape_all[3] += recv_lens[3];
}
std::vector<TypeId> concat_right_output_dtype = {AnfAlgo::GetOutputInferDataType(concat_input_all[1], 0)};
auto concat_all =
CreateConcatNode(graph, concat_input_all, {shape_all}, concat_right_output_dtype, kWDim, input_nums_all);
return concat_all;
}
// grad
// Returns {top_bottom, top_row_lr, middle_row_lr, bottom_row_lr}; entries with no split are set to nullptr.
std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
std::vector<int64_t> *split_num) {
MS_LOG(DEBUG) << "Start create splitv nodes.";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
MS_EXCEPTION_IF_NULL(split_num);
std::vector<int64_t> send_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrSendRankIds);
std::vector<int64_t> send_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrSendLens);
if (neighbor_exchange_v2_grad->size() <= kNeighborExchangeV2InputIdx) {
MS_LOG(EXCEPTION) << "Invalid cnode " << neighbor_exchange_v2_grad->DebugString() << " input size "
<< neighbor_exchange_v2_grad->size();
}
auto neighbor_exchange_v2_grad_input = neighbor_exchange_v2_grad->input(kNeighborExchangeV2InputIdx);
auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_grad_input, 0);
auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_grad_input, 0);
if (SizeToLong(shape.size()) != kShapeSize) {
MS_LOG(EXCEPTION) << "Invalid shape size " << shape.size() << ", only support NCHW input now!";
}
std::vector<CNodePtr> split_nodes = {};
// splitv for top & bottom
bool is_top = ((send_rank_ids[kRankIdZero] != kInvalidId) || (send_rank_ids[kRankIdOne] != kInvalidId) ||
(send_rank_ids[kRankIdSeven] != kInvalidId));
bool is_bottom = ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFour] != kInvalidId) ||
(send_rank_ids[kRankIdFive] != kInvalidId));
CNodePtr split_v_top_bottom = nullptr;
int64_t num_split_h = 0;
if (is_top || is_bottom) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
neighbor_exchange_v2_grad_input};
split_v_top_bottom =
CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h);
}
split_nodes.emplace_back(split_v_top_bottom);
split_num->push_back(num_split_h);
// splitvs for left & right
// inputs
std::vector<AnfNodePtr> split_outputs_top_bottom;
std::vector<int64_t> size_split_h;
if (split_nodes[0] != nullptr) {
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>(num_split_h), &split_outputs_top_bottom);
if (split_outputs_top_bottom.empty()) {
MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString()
<< " should have at least one output, but got 0.";
}
size_split_h = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(split_nodes[0], kAttrSizeSplits);
} else {
// just middle
split_outputs_top_bottom.push_back(neighbor_exchange_v2_grad_input);
size_split_h.push_back(shape[kHDim]);
}
// left_right splitv nodes from top to bottom
bool is_left = (send_rank_ids[kRankIdFive] != kInvalidId) || (send_rank_ids[kRankIdSix] != kInvalidId) ||
(send_rank_ids[kRankIdSeven] != kInvalidId);
bool is_right = (send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdTwo] != kInvalidId) ||
(send_rank_ids[kRankIdThree] != kInvalidId);
if (is_left || is_right) {
if (!is_top) {
split_nodes.push_back(nullptr);
split_num->push_back(0);
}
for (size_t i = 0; i < split_outputs_top_bottom.size(); ++i) {
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
split_outputs_top_bottom[i]};
int64_t num_split_w = 0;
std::vector<size_t> base_shape(shape);
base_shape[kHDim] = size_split_h[i];
auto split_v_left_right =
CreateSplitNode(graph, split_input, base_shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w);
split_nodes.emplace_back(split_v_left_right);
split_num->push_back(num_split_w);
}
if (!is_bottom) {
split_nodes.push_back(nullptr);
split_num->push_back(0);
}
} else {
split_nodes.push_back(nullptr);
split_num->push_back(0);
split_nodes.push_back(nullptr);
split_num->push_back(0);
split_nodes.push_back(nullptr);
split_num->push_back(0);
}
MS_LOG(DEBUG) << "Create splitv nodes success.";
return split_nodes;
}
CNodePtr CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input, const std::vector<int64_t> &begin,
const std::vector<int64_t> &size, const std::vector<size_t> &shape, TypeId dtype) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input);
std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)), input};
auto pad = graph->NewCNode(pad_inputs);
std::vector<std::vector<int64_t>> paddings;
for (size_t i = 0; i < shape.size(); ++i) {
paddings.emplace_back(std::vector<int64_t>{begin[i], static_cast<int64_t>(shape[i]) - begin[i] - size[i]});
}
AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape}, pad.get());
AnfAlgo::SetNodeAttr(kAttrPaddings, MakeValue(paddings), pad);
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(std::vector<std::string>{"x"}), pad);
return pad;
}
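// Grad reassembly: every received gradient slice is zero-padded (via Pad) back to the
// center shape at its original position, then all slices are summed with the center
// block through a single AddN.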
CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
const CNodePtr &all_to_all_v, const std::vector<CNodePtr> &split_nodes,
const std::vector<int64_t> &split_num) {
MS_LOG(DEBUG) << "Start create splitvs grad nodes.";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
std::vector<int64_t> send_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrSendRankIds);
std::vector<int64_t> recv_rank_ids =
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrRecvRankIds);
std::vector<int64_t> recv_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrRecvLens);
auto centerx = GetCenter(graph, neighbor_exchange_v2_grad, split_nodes, split_num, send_rank_ids);
auto centerx_dtype = AnfAlgo::GetOutputInferDataType(centerx, 0);
auto centerx_shape = AnfAlgo::GetOutputInferShape(centerx, 0);
// empty
int64_t all_to_all_output_num =
std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });
if (all_to_all_output_num == 0) {
// add depend(alltoallv, centerx)
std::vector<AnfNodePtr> depend_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
centerx, all_to_all_v};
auto depend = graph->NewCNode(depend_input);
MS_EXCEPTION_IF_NULL(depend);
depend->set_abstract(centerx->abstract());
return depend;
}
// get alltoallv outputs
std::vector<AnfNodePtr> all_to_all_v_outputs;
CreateMultipleOutputsOfAnfNode(graph, all_to_all_v, static_cast<size_t>(all_to_all_output_num),
&all_to_all_v_outputs);
if (all_to_all_v_outputs.empty()) {
MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0.";
}
// create pad nodes
// slice begin & size
std::vector<std::vector<int64_t>> begins = {{0, 0, 0, 0},
{0, 0, 0, static_cast<int64_t>(centerx_shape[3]) - recv_lens[3]},
{0, 0, 0, static_cast<int64_t>(centerx_shape[3]) - recv_lens[3]},
{0, 0, static_cast<int64_t>(centerx_shape[2]) - recv_lens[1],
static_cast<int64_t>(centerx_shape[3]) - recv_lens[3]},
{0, 0, static_cast<int64_t>(centerx_shape[2]) - recv_lens[1], 0},
{0, 0, static_cast<int64_t>(centerx_shape[2]) - recv_lens[1], 0},
{0, 0, 0, 0},
{0, 0, 0, 0}};
std::vector<std::vector<int64_t>> sizes = {
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0],
static_cast<int64_t>(centerx_shape[3])},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[3]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]),
static_cast<int64_t>(centerx_shape[2]), recv_lens[3]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[3]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1],
static_cast<int64_t>(centerx_shape[3])},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[2]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]),
static_cast<int64_t>(centerx_shape[2]), recv_lens[2]},
{static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[2]}};
std::vector<CNodePtr> pad_nodes;
size_t output_index = 0;
for (size_t i = 0; i < recv_rank_ids.size(); ++i) {
if (recv_rank_ids[i] != kInvalidId) {
auto pad =
CreatePadNode(graph, all_to_all_v_outputs[output_index], begins[i], sizes[i], centerx_shape, centerx_dtype);
++output_index;
pad_nodes.emplace_back(pad);
}
}
// create add node
std::vector<AnfNodePtr> addn_inputs = {NewValueNode(std::make_shared<Primitive>(kAddNOpName)), centerx};
int64_t pad_num = 1;
for (auto pad : pad_nodes) {
std::vector<AnfNodePtr> pad_outputs;
CreateMultipleOutputsOfAnfNode(graph, pad, 1, &pad_outputs);
if (pad_outputs.empty()) {
MS_LOG(EXCEPTION) << "The node " << pad->DebugString() << " should have at least one output, but got 0.";
}
addn_inputs.insert(addn_inputs.end(), pad_outputs.begin(), pad_outputs.end());
++pad_num;
}
auto addn = graph->NewCNode(addn_inputs);
MS_EXCEPTION_IF_NULL(addn);
AnfAlgo::SetOutputInferTypeAndShape({centerx_dtype}, {centerx_shape}, addn.get());
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue<std::vector<int64_t>>({pad_num}), addn);
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(pad_num), addn);
MS_LOG(DEBUG) << "Create splitvs grad nodes success.";
return addn;
}
} // namespace
const BaseRef NeighborExchangeV2UnifyMindIR::DefinePattern() const {
return VectorRef({prim::kPrimNeighborExchangeV2, std::make_shared<SeqVar>()});
}
const AnfNodePtr NeighborExchangeV2UnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto neighbor_exchange_v2 = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
std::vector<int64_t> split_num;
auto split_nodes = CreateSplitNodes(graph, neighbor_exchange_v2, &split_num);
auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2, split_nodes, split_num, false);
auto concat = CreateConcatNodes(graph, neighbor_exchange_v2, all_to_all_v);
return concat;
}
const BaseRef NeighborExchangeV2GradUnifyMindIR::DefinePattern() const {
return VectorRef({prim::kPrimNeighborExchangeV2Grad, std::make_shared<SeqVar>()});
}
const AnfNodePtr NeighborExchangeV2GradUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto neighbor_exchange_v2_grad = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
std::vector<int64_t> split_num;
auto split_nodes = CreateSplitNodesForGrad(graph, neighbor_exchange_v2_grad, &split_num);
auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2_grad, split_nodes, split_num, true);
auto add = CreateSplitGradNodes(graph, neighbor_exchange_v2_grad, all_to_all_v, split_nodes, split_num);
return add;
}
} // namespace opt
} // namespace mindspore

View File

@@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_NEIGHBOR_EXCHANGE_V2_UNIFY_MINDIR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_NEIGHBOR_EXCHANGE_V2_UNIFY_MINDIR_H_
#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
class NeighborExchangeV2UnifyMindIR : public PatternProcessPass {
public:
explicit NeighborExchangeV2UnifyMindIR(bool multigraph = true)
: PatternProcessPass("neighbor_exchange_v2_unify_mindir", multigraph) {}
~NeighborExchangeV2UnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
class NeighborExchangeV2GradUnifyMindIR : public PatternProcessPass {
public:
explicit NeighborExchangeV2GradUnifyMindIR(bool multigraph = true)
: PatternProcessPass("neighbor_exchange_v2_grad_unify_mindir", multigraph) {}
~NeighborExchangeV2GradUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_NEIGHBOR_EXCHANGE_V2_UNIFY_MINDIR_H_

View File

@@ -235,6 +235,7 @@ constexpr char GET_TENSOR_SLICE[] = "_GetTensorSlice";
constexpr char SPLIT[] = "Split";
constexpr char ALL_TO_ALL[] = "AlltoAll";
constexpr char NEIGHBOREXCHANGE[] = "NeighborExchange";
constexpr char NEIGHBOREXCHANGEV2[] = "NeighborExchangeV2";
constexpr char PERMUTE_BY_AXIS[] = "PermuteByAxis";
constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis";
constexpr char SPLIT_BY_AXIS[] = "SplitByAxis";

View File

@@ -437,6 +437,8 @@ constexpr auto kAttrConcatDim = "concat_dim";
constexpr auto kAttrSplitCount = "split_count";
constexpr auto kAttrSendRankIds = "send_rank_ids";
constexpr auto kAttrRecvRankIds = "recv_rank_ids";
constexpr auto kAttrSendLens = "send_lens";
constexpr auto kAttrRecvLens = "recv_lens";
constexpr auto kAttrRankSize = "rank_size";
constexpr auto kAttrPadDimSize = "pad_dim_size";
constexpr auto kAttrPaddings = "paddings";

View File

@@ -437,6 +437,8 @@ inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
inline const PrimitivePtr kPrimNeighborExchange = std::make_shared<Primitive>("NeighborExchange");
inline const PrimitivePtr kPrimNeighborExchangeV2 = std::make_shared<Primitive>("NeighborExchangeV2");
inline const PrimitivePtr kPrimNeighborExchangeV2Grad = std::make_shared<Primitive>("NeighborExchangeV2Grad");
inline const PrimitivePtr kPrimAllToAll = std::make_shared<Primitive>("AlltoAll");
inline const PrimitivePtr kPrimAllToAllv = std::make_shared<Primitive>("AllToAllv");
inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");

View File

@@ -0,0 +1,156 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/neighborexchangev2.h"
#include <string>
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
constexpr auto kSendRankIds = "send_rank_ids";
constexpr auto kSendLens = "send_lens";
constexpr auto kRecvRankIds = "recv_rank_ids";
constexpr auto kRecvLens = "recv_lens";
constexpr auto kDataFormat = "format";
constexpr auto kGroup = "group";
constexpr size_t kRankIdsSize = 8;
constexpr size_t kLensSize = 4;
std::vector<int64_t> CheckAttrSize(const PrimitivePtr &primitive, const std::string &attr_name,
const size_t attr_size) {
MS_EXCEPTION_IF_NULL(primitive);
// the attr must be a list of ints with exactly attr_size elements (8 for rank ids, 4 for lens)
std::vector<int64_t> attr_value;
try {
auto attr = primitive->GetAttr(attr_name);
if (attr->cast<ValueListPtr>() == nullptr) {
MS_EXCEPTION(TypeError);
}
attr_value = GetValue<std::vector<int64_t>>(attr);
} catch (const std::exception &) {
MS_EXCEPTION(TypeError) << "Attr " << attr_name << " must be a list[int, int, ...].";
}
if (attr_value.size() != attr_size) {
MS_EXCEPTION(ValueError) << "Invalid " << primitive->name() << " attr " << attr_name << " size "
<< attr_value.size() << " must be equal to size " << attr_size;
}
return attr_value;
}
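// A corner direction may be received only when both adjacent edges are received,
// e.g. top-right (idx 1) requires top (idx 0) and right (idx 2); the checks below
// reject any other combination.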
void CheckRecvCorner(std::vector<int64_t> recv_rank_ids, int64_t idx1, int64_t idx2, int64_t idx_corner) {
if (recv_rank_ids[idx1] != -1 && recv_rank_ids[idx2] != -1 && recv_rank_ids[idx_corner] == -1) {
MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1]
<< ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids["
<< idx_corner << "] = " << recv_rank_ids[idx_corner] << ".";
}
if ((recv_rank_ids[idx1] == -1 || recv_rank_ids[idx2] == -1) && recv_rank_ids[idx_corner] != -1) {
MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1]
<< ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids["
<< idx_corner << "] = " << recv_rank_ids[idx_corner] << ".";
}
}
void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 1;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
// check size of send_rank_ids, recv_rank_ids, send_lens, recv_lens
(void)CheckAttrSize(primitive, kSendRankIds, kRankIdsSize);
auto recv_rank_ids = CheckAttrSize(primitive, kRecvRankIds, kRankIdsSize);
(void)CheckAttrSize(primitive, kSendLens, kLensSize);
(void)CheckAttrSize(primitive, kRecvLens, kLensSize);
// check invalid corner combinations in recv_rank_ids
CheckRecvCorner(recv_rank_ids, 0, 2, 1);
CheckRecvCorner(recv_rank_ids, 2, 4, 3);
CheckRecvCorner(recv_rank_ids, 4, 6, 5);
CheckRecvCorner(recv_rank_ids, 6, 0, 7);
// check data_format is NCHW
auto format = GetValue<std::string>(primitive->GetAttr(kDataFormat));
if (format != "NCHW") {
MS_EXCEPTION(ValueError) << "Attr data_format only support NCHW now.";
}
// check group
auto group_attr = primitive->GetAttr(kGroup);
try {
MS_EXCEPTION_IF_NULL(group_attr);
(void)GetValue<std::string>(group_attr);
} catch (const std::exception &) {
MS_EXCEPTION(TypeError) << "Attr " << kGroup << " should be a str.";
}
}
abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto recv_rank_ids = primitive->GetAttr(kRecvRankIds);
MS_EXCEPTION_IF_NULL(recv_rank_ids);
auto recv_rank_ids_value = recv_rank_ids->cast<ValueSequeuePtr>();
MS_EXCEPTION_IF_NULL(recv_rank_ids_value);
std::vector<int64_t> recv_rank_ids_v = GetValue<std::vector<int64_t>>(recv_rank_ids_value);
auto recv_lens = primitive->GetAttr(kRecvLens);
MS_EXCEPTION_IF_NULL(recv_lens);
auto recv_lens_value = recv_lens->cast<ValueSequeuePtr>();
MS_EXCEPTION_IF_NULL(recv_lens_value);
std::vector<int64_t> recv_lens_v = GetValue<std::vector<int64_t>>(recv_lens_value);
std::vector<int64_t> input_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
if (recv_rank_ids_v[0] != -1) {
input_shape[2] += recv_lens_v[0];
}
if (recv_rank_ids_v[4] != -1) {
input_shape[2] += recv_lens_v[1];
}
if (recv_rank_ids_v[6] != -1) {
input_shape[3] += recv_lens_v[2];
}
if (recv_rank_ids_v[2] != -1) {
input_shape[3] += recv_lens_v[3];
}
BaseShapePtr output_shape = std::make_shared<abstract::Shape>(input_shape);
if (input_shape.empty()) {
return std::make_shared<abstract::Shape>();
}
return output_shape;
}
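// An illustrative shape inference (values not from this patch): input [1, 1, 32, 16]
// with recv_rank_ids = [0, -1, -1, -1, 1, -1, -1, -1] and recv_lens = [2, 3, 0, 0]
// yields output [1, 1, 32 + 2 + 3, 16] = [1, 1, 37, 16].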
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
// recv type
TypePtr recv_type = input_args[0]->BuildType();
if (recv_type == nullptr) {
return std::make_shared<TypeNone>();
}
return recv_type;
}
} // namespace
AbstractBasePtr NeighborExchangeV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
Check(primitive, input_args);
auto type = InferType(primitive, input_args);
auto shape = InferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchangeV2, prim::kPrimNeighborExchangeV2, NeighborExchangeV2Infer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_NEIGHBOREXCHANGEV2_H_
#define MINDSPORE_CORE_OPS_NEIGHBOREXCHANGEV2_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameNeighborExchangeV2 = "NeighborExchangeV2";
class MS_CORE_API NeighborExchangeV2 : public PrimitiveC {
public:
NeighborExchangeV2() : PrimitiveC(kNameNeighborExchangeV2) {}
~NeighborExchangeV2() = default;
MS_DECLARE_PARENT(NeighborExchangeV2, PrimitiveC);
void Init() {}
};
using kPrimNeighborExchangeV2Ptr = std::shared_ptr<NeighborExchangeV2>;
AbstractBasePtr NeighborExchangeV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_NEIGHBOREXCHANGEV2_H_

View File

@@ -22,12 +22,13 @@ from mindspore.parallel._utils import _get_enable_parallel_optimizer
from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, NeighborExchange, AlltoAll,
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, NeighborExchange, AlltoAll, NeighborExchangeV2,
Broadcast, _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap,
_VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather)
from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive
from ..operations import _grad_ops as G
@bprop_getters.register(AllReduce)
@@ -390,6 +391,24 @@ def get_bprop_all_to_all(self):
return bprop
@bprop_getters.register(NeighborExchangeV2)
def get_bprop_neighborexchangev2(self):
"""Generate bprop for NeighborExchangeV2."""
group = self.group
send_rank_ids = self.recv_rank_ids
recv_rank_ids = self.send_rank_ids
send_lens = self.recv_lens
recv_lens = self.send_lens
data_format = self.data_format
neighborexchangev2_grad = G.NeighborExchangeV2Grad(send_rank_ids, send_lens, recv_rank_ids,
recv_lens, data_format, group)
def bprop(x, out, dout):
return (neighborexchangev2_grad(dout),)
return bprop
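# The gradient flows in the reverse direction of the forward exchange, so the bprop
# above swaps send/recv rank ids and lens: whatever was received must be sent back.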
@bprop_getters.register(_MirrorOperator)
def get_bprop_mirror_operator(self):
"""

View File

@@ -5,4 +5,4 @@ e
bprop.8:x*
bprop.8:out*
bprop.8:dout2
bprop.8:[CNode]:1:@13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e8P
bprop.8:[CNode]:1:@74787be4234cdeb03f214519cd8358a5f4ad2f5606dbeb494462cddc448eb4beP

View File

@@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches)
from .comm_ops import (AllGather, AllReduce, NeighborExchange, AlltoAll, AllSwap, ReduceScatter, Broadcast,
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
@@ -486,6 +486,9 @@ __all__ = [
"Trunc",
"Complex",
"ExtractVolumePatches",
"NeighborExchangeV2",
"NeighborExchange",
"AlltoAll",
]
__sponge__ = [

View File

@@ -23,6 +23,7 @@ from ..._checkparam import Validator as validator, Rel
from .._utils import get_concat_offset
from ...common import dtype as mstype
from ... import context
from ...communication.management import GlobalComm
class AbsGrad(PrimitiveWithInfer):
@@ -723,6 +724,42 @@ class BNTrainingUpdateGrad(PrimitiveWithInfer):
def infer_dtype(self, grads, x, batch_mean, batch_variance):
return (batch_mean, batch_variance)
class NeighborExchangeV2Grad(PrimitiveWithInfer):
""""Gradients of NeighborExchangeV2 operation."""
@prim_attr_register
def __init__(self, send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format,
group=GlobalComm.WORLD_COMM_GROUP):
self.init_prim_io_names(inputs=['dy'], outputs=['dx'])
self.send_rank_ids = send_rank_ids
self.recv_rank_ids = recv_rank_ids
self.send_lens = send_lens
self.recv_lens = recv_lens
self.format = validator.check_string(data_format, ['NCHW'], 'format', self.name)
self.add_prim_attr('no_elimilate', True)
def __infer__(self, dy):
dy_shape = dy['shape']
validator.check('dy_shape.size()', len(dy_shape), '4', 4, Rel.EQ, self.name)
if self.send_rank_ids[5] != -1 or self.send_rank_ids[6] != -1 or self.send_rank_ids[7] != -1:
dy_shape[3] -= self.send_lens[2]
if self.send_rank_ids[1] != -1 or self.send_rank_ids[2] != -1 or self.send_rank_ids[3] != -1:
dy_shape[3] -= self.send_lens[3]
if self.send_rank_ids[0] != -1 or self.send_rank_ids[1] != -1 or self.send_rank_ids[7] != -1:
dy_shape[2] -= self.send_lens[0]
if self.send_rank_ids[3] != -1 or self.send_rank_ids[4] != -1 or self.send_rank_ids[5] != -1:
dy_shape[2] -= self.send_lens[1]
return {'shape': dy_shape,
'dtype': dy['dtype'],
'value': None}
def __call__(self, tensor):
raise NotImplementedError
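# Note: dy carries the enlarged shape produced by the forward op (this op's send_*
# mirror the forward recv_*), so __infer__ above shrinks it by send_lens to recover
# the forward input shape for dx.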
class GeLUGrad(PrimitiveWithInfer):
"""Gradients of GeLU operation."""

View File

@@ -710,6 +710,39 @@ class AlltoAll(PrimitiveWithInfer):
def __call__(self, tensor):
raise NotImplementedError
class NeighborExchangeV2(Primitive):
"""
NeighborExchange is a collective operation.
NeighborExchange sends data from the local rank to ranks in the send_rank_ids,
as while receive data from recv_rank_ids.
Args:
send_rank_ids (list(int)): Ranks which the data is sent to. 8 rank_ids represents 8 directions, if one
direction is not send to , set it -1.
recv_rank_ids (list(int)): Ranks which the data is received from. 8 rank_ids represents 8 directions,
if one direction is not recv from , set it -1.
send_lens (list(int)): Data lens which send to the send_rank_ids, 4 numbers represent the lens of
[top, bottom, left, right].
recv_lens (list(int)): Data lens which received from recv_rank_ids, 4 numbers represent the lens of
[top, bottom, left, right].
data_format (str): Data format, only support NCHW now.
group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
"""
@prim_attr_register
def __init__(self, send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format,
group=GlobalComm.WORLD_COMM_GROUP):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
self.send_rank_ids = send_rank_ids
self.recv_rank_ids = recv_rank_ids
self.send_lens = send_lens
self.recv_lens = recv_lens
self.format = data_format
self.add_prim_attr('no_elimilate', True)
def __call__(self, tensor):
raise NotImplementedError
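Read together with the grad shape rule above, the 8 direction indices are [0 top, 1 top-right, 2 right, 3 bottom-right, 4 bottom, 5 bottom-left, 6 left, 7 top-left]. A minimal construction sketch, assuming an already-initialized 8-device communication group (the rank id 4 below is illustrative):
from mindspore.ops.operations.comm_ops import NeighborExchangeV2

# Exchange a 1-row halo with the bottom neighbor (direction index 4).
exchange = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 4, -1, -1, -1],
                              send_lens=[0, 1, 0, 0],  # [top, bottom, left, right]
                              recv_rank_ids=[-1, -1, -1, -1, 4, -1, -1, -1],
                              recv_lens=[0, 1, 0, 0],
                              data_format="NCHW")
# On an NCHW input of shape [1, 1, 32, 16], the received bottom rows grow H by
# recv_lens[1], so the output shape is [1, 1, 33, 16].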
class _MirrorOperator(PrimitiveWithInfer):
"""

View File

@ -0,0 +1,542 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops.operations.comm_ops import NeighborExchangeV2
_x1 = Tensor(np.ones([1, 1, 32, 16]), dtype=ms.float32)
_x2 = Tensor(np.ones([1, 1, 33, 16]), dtype=ms.float32)
def compile_net(net, x1, x2):
context.set_context(mode=context.GRAPH_MODE)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_train()
_cell_graph_executor.compile(train_net, x1, x2)
def test_neighborexchangev2_single_input_success():
"""
Feature: NeighborExchangeV2
Description: one input and one output, with valid arguments
Expectation: success
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.linear = nn.Dense(16, 16)
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0], data_format="NCHW")
def construct(self, x1, x2):
y = self.linear(x1)
y = self.neighborexchangev2(y)
y = y + x2
return y
net = Net()
compile_net(net, _x1, _x2)
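The fixture shapes line up with this exchange: receiving recv_lens[1] = 1 rows from the bottom neighbor grows H from 32 to 33, which is why the exchanged output can be added to _x2. The arithmetic, spelled out (illustrative check only):
in_shape = [1, 1, 32, 16]   # shape of the linear output, like _x1
recv_lens = [0, 1, 0, 0]    # [top, bottom, left, right]
out_shape = [in_shape[0], in_shape[1],
             in_shape[2] + recv_lens[0] + recv_lens[1],  # H grows by top+bottom
             in_shape[3] + recv_lens[2] + recv_lens[3]]  # W grows by left+right
assert out_shape == [1, 1, 33, 16]  # matches _x2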
def test_neighborexchangev2_empty_send_success():
"""
Feature: NeighborExchangeV2
Description: no send directions (all send_rank_ids are -1), with valid arguments
Expectation: success
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.linear = nn.Dense(16, 16)
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
send_lens=[1, 2, 3, 4],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0],
data_format="NCHW")
def construct(self, x1, x2):
y = self.linear(x1)
y = self.neighborexchangev2(y)
y = y + x2
return y
net = Net()
compile_net(net, _x1, _x2)
def test_neighborexchangev2_empty_recv_success():
"""
Feature: NeighborExchangeV2
Description: no recv directions (all recv_rank_ids are -1), with valid arguments
Expectation: success
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.linear = nn.Dense(16, 16)
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
recv_lens=[1, 2, 3, 4],
data_format="NCHW")
def construct(self, x1, x2):
y = self.linear(x1)
y = self.neighborexchangev2(y)
y = y + x2
return y
net = Net()
compile_net(net, _x1, _x1)
def test_neighborexchangev2_empty_send_empty_recv_success():
"""
Feature: NeighborExchangeV2
Description: no send and no recv directions, with valid arguments
Expectation: success
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
recv_lens=[1, 2, 3, 4],
data_format="NCHW")
def construct(self, x1):
y = self.neighborexchangev2(x1)
return y
net = Net()
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_invalid_dataformat_failed():
"""
Feature: NeighborExchangeV2
Description: data_format should be NCHW, but NHWC is given
Expectation: throw ValueError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0],
data_format="NHWC")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(ValueError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_invalid_send_rank_ids_size_failed():
"""
Feature: NeighborExchangeV2
Description: send_rank_ids size should be 8, but a size of 5 is given
Expectation: throw ValueError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0],
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(ValueError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_invalid_recv_rank_ids_size_failed():
"""
Feature: NeighborExchangeV2
Description: recv_rank_ids size should be 8, but a size of 5 is given
Expectation: throw ValueError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1],
recv_lens=[0, 1, 0, 0],
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(ValueError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_invalid_send_lens_size_failed():
"""
Feature: NeighborExchangeV2
Description: send_lens size should be 4, but a size of 5 is given
Expectation: throw ValueError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0, 2],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0],
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(ValueError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_invalid_recv_lens_size_failed():
"""
Feature: NeighborExchangeV2
Description: recv_lens size should be 4, but a size of 5 is given
Expectation: throw ValueError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0, 2],
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(ValueError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_invalid_input_size_failed():
"""
Feature: NeighborExchangeV2
Description: input should be one tensor, but 2 are given
Expectation: throw ValueError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0],
data_format="NCHW")
def construct(self, x1, x2):
out = self.neighborexchangev2(x1, x2)
return out[0]
net = Net()
with pytest.raises(ValueError):
_cell_graph_executor.compile(net, _x1, _x2)
def test_neighborexchangev2_recv_rank_ids_invalid_value_failed():
"""
Feature: NeighborExchangeV2
Description: received blocks must be concatenable; recv_rank_ids[3] and recv_rank_ids[4] are 1 while the adjacent directions are -1
Expectation: throw Exception
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, 1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0],
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(ValueError):
_cell_graph_executor.compile(net, _x1)
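The constraint this test exercises, inferred from the failing pattern above (a reading aid, not documented API): a corner direction (odd index) can only be received when both adjacent edge directions (the even indices beside it) are also received; otherwise the received blocks cannot be concatenated into a rectangle. A hypothetical checker:
# Hypothetical helper illustrating the inferred concat constraint.
def corners_consistent(recv_rank_ids):
    for corner in (1, 3, 5, 7):      # corner directions
        if recv_rank_ids[corner] != -1:
            if recv_rank_ids[corner - 1] == -1 or recv_rank_ids[(corner + 1) % 8] == -1:
                return False         # corner received without both adjacent edges
    return True

assert not corners_consistent([-1, -1, -1, 1, 1, -1, -1, -1])  # corner 3 lacks edge 2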
def test_neighborexchangev2_attr_check_send_rank_ids_is_tuple_failed():
"""
Feature: NeighborExchangeV2
Description: send_rank_ids should be a list, but a tuple is given
Expectation: throw TypeError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=(-1, -1, -1, -1, 1, -1, -1, -1),
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0],
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_attr_check_send_lens_is_tuple_failed():
"""
Feature: NeighborExchangeV2
Description: send_lens should be a list, but a tuple is given
Expectation: throw TypeError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=(0, 1, 0, 0),
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0],
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_attr_check_recv_rank_ids_is_tuple_failed():
"""
Feature: NeighborExchangeV2
Description: recv_rank_ids should be a list, but a tuple is given
Expectation: throw TypeError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=(-1, -1, -1, -1, 1, -1, -1, -1),
recv_lens=[0, 1, 0, 0],
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_attr_check_recv_lens_is_tuple_failed():
"""
Feature: NeighborExchangeV2
Description: recv_lens should be a list, but a tuple is given
Expectation: throw TypeError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=(0, 1, 0, 0),
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_attr_check_send_rank_ids_is_float_failed():
"""
Feature: NeighborExchangeV2
Description: values in send_rank_ids should be int, but a float is given
Expectation: throw TypeError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1.0, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0],
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_attr_check_send_lens_is_float_failed():
"""
Feature: NeighborExchangeV2
Description: values in send_lens should be int, but a float is given
Expectation: throw TypeError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1.0, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0],
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_attr_check_recv_rank_ids_is_float_failed():
"""
Feature: NeighborExchangeV2
Description: values in recv_rank_ids should be int, but a float is given
Expectation: throw TypeError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1.0, -1, -1, -1],
recv_lens=[0, 1, 0, 0],
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_attr_check_recv_lens_is_float_failed():
"""
Feature: NeighborExchangeV2
Description: values in recv_lens should be int, but a float is given
Expectation: throw TypeError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1.0, 0, 0],
data_format="NCHW")
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_group_is_tuple_failed():
"""
Feature: NeighborExchangeV2
Description: group should be a string, but a tuple is given
Expectation: throw TypeError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
send_lens=[0, 1, 0, 0],
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
recv_lens=[0, 1, 0, 0],
data_format="NCHW", group=("str",))
def construct(self, x):
out = self.neighborexchangev2(x)
return out[0]
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)