!25765 neighborExchangeV2 & grad
Merge pull request !25765 from TuDouNi/neighborExchangeV2
This commit is contained in:
commit
ded1c77bbf
|
@ -144,6 +144,7 @@
|
|||
#include "backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h"
|
||||
#include "backend/optimizer/ascend/mindir/bn_grad_unify_mindir.h"
|
||||
#include "backend/optimizer/ascend/mindir/all_to_all_unify_mindir.h"
|
||||
#include "backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h"
|
||||
#include "backend/optimizer/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h"
|
||||
#include "backend/kernel_compiler/tbe/tbe_kernel_compile.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -579,6 +580,8 @@ void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &graph) {
|
|||
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::BatchNormGradUnifyMindIR>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeUnifyMindIR>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2UnifyMindIR>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2GradUnifyMindIR>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::AllToAllUnifyMindIR>());
|
||||
|
||||
optimizer->AddPassManager(unify_mindir_pm);
|
||||
|
|
|
@ -0,0 +1,900 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h"
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "runtime/hccl_adapter/hccl_adapter.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// Index of the primitive value node inside a CNode's input list.
constexpr size_t kCNodePrimitiveIdx = 0;
// Index of the data input of NeighborExchangeV2 / NeighborExchangeV2Grad.
constexpr size_t kNeighborExchangeV2InputIdx = 1;
// Positions inside the send_lens / recv_lens attributes: [top, bottom, left, right].
constexpr size_t kLenTopIdx = 0;
constexpr size_t kLenBottomIdx = 1;
constexpr size_t kLenLeftIdx = 2;
constexpr size_t kLenRightIdx = 3;
constexpr size_t kHDim = 2;  // dim of H in NCHW
constexpr size_t kWDim = 3;  // dim of W in NCHW
// Rank of the only supported input layout (NCHW).
constexpr int64_t kShapeSize = 4;
// Positions inside send_rank_ids / recv_rank_ids. The is_top/is_bottom/is_left/is_right
// checks in CreateSplitNodes imply a clockwise ordering starting at top:
// 0 top, 1 top-right, 2 right, 3 bottom-right, 4 bottom, 5 bottom-left, 6 left, 7 top-left.
constexpr int64_t kRankIdZero = 0;
constexpr int64_t kRankIdOne = 1;
constexpr int64_t kRankIdTwo = 2;
constexpr int64_t kRankIdThree = 3;
constexpr int64_t kRankIdFour = 4;
constexpr int64_t kRankIdFive = 5;
constexpr int64_t kRankIdSix = 6;
constexpr int64_t kRankIdSeven = 7;
constexpr size_t kSizeFour = 4;  // expected element count of lens / split-node vectors
constexpr int64_t kInvalidId = -1;  // marks "no neighbor in this direction"
|
||||
|
||||
// cal split attrs size_splits, shapes and num_split
|
||||
int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first, const bool is_last,
|
||||
const int64_t split_dim, const std::vector<int64_t> &send_lens, std::vector<int64_t> *size_splits,
|
||||
std::vector<std::vector<size_t>> *shapes) {
|
||||
MS_EXCEPTION_IF_NULL(size_splits);
|
||||
MS_EXCEPTION_IF_NULL(shapes);
|
||||
if (SizeToLong(base_shape.size()) != kShapeSize) {
|
||||
MS_LOG(EXCEPTION) << "Wrong base_shape size: " << base_shape.size() << ", it should be equal to 4.";
|
||||
}
|
||||
if (split_dim >= kShapeSize) {
|
||||
MS_LOG(EXCEPTION) << "Wrong split_dim: " << split_dim << ", it should less than 4.";
|
||||
}
|
||||
int64_t num_split = 0;
|
||||
int64_t split_middle_size = base_shape[split_dim];
|
||||
std::vector<size_t> shape_tmp(base_shape);
|
||||
// [top, bottom, left, right]
|
||||
int64_t first_size = split_dim == kWDim ? send_lens[2] : send_lens[0];
|
||||
int64_t last_size = split_dim == kWDim ? send_lens[3] : send_lens[1];
|
||||
|
||||
if (is_first) {
|
||||
// first
|
||||
++num_split;
|
||||
size_splits->push_back(first_size);
|
||||
split_middle_size -= first_size;
|
||||
shape_tmp[split_dim] = static_cast<size_t>(first_size);
|
||||
shapes->push_back(shape_tmp);
|
||||
}
|
||||
if (is_last) {
|
||||
// middle
|
||||
++num_split;
|
||||
split_middle_size -= last_size;
|
||||
size_splits->push_back(split_middle_size);
|
||||
shape_tmp[split_dim] = static_cast<size_t>(split_middle_size);
|
||||
shapes->push_back(shape_tmp);
|
||||
// last
|
||||
++num_split;
|
||||
size_splits->push_back(last_size);
|
||||
shape_tmp[split_dim] = static_cast<size_t>(last_size);
|
||||
shapes->push_back(shape_tmp);
|
||||
} else {
|
||||
++num_split;
|
||||
size_splits->push_back(split_middle_size);
|
||||
shape_tmp[split_dim] = static_cast<size_t>(split_middle_size);
|
||||
shapes->push_back(shape_tmp);
|
||||
}
|
||||
return num_split;
|
||||
}
|
||||
|
||||
// Builds a SplitV cnode from split_input and sets its attributes (split_dim, num_split,
// size_splits) computed via CalSplitAttrs from base_shape and send_lens.
// *num_split receives the number of outputs of the created node.
CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &split_input,
                         const std::vector<size_t> &base_shape, bool is_first, bool is_last, int64_t split_dim,
                         const std::vector<int64_t> &send_lens, TypeId input_dtype, int64_t *num_split) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(num_split);
  if (split_input.empty()) {
    // MS_LOG(EXCEPTION) throws, so the former `return nullptr;` here was unreachable
    // dead code and has been removed.
    MS_LOG(EXCEPTION) << "The input is empty, can not create splitv node.";
  }
  auto split_v = graph->NewCNode(split_input);
  MS_EXCEPTION_IF_NULL(split_v);
  std::vector<int64_t> size_splits = {};
  std::vector<std::vector<size_t>> shapes = {};
  *num_split = CalSplitAttrs(base_shape, is_first, is_last, split_dim, send_lens, &size_splits, &shapes);

  // one output per split segment, all with the input's dtype
  std::vector<TypeId> dtypes(*num_split, input_dtype);
  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get());
  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue<int64_t>(split_dim), split_v);
  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue<int64_t>(*num_split), split_v);
  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue<std::vector<int64_t>>(size_splits), split_v);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v);
  return split_v;
}
|
||||
|
||||
std::vector<std::vector<size_t>> CalAllToAllvOutputShape(const std::vector<size_t> &base_shape,
|
||||
const std::vector<int64_t> &recv_lens,
|
||||
const std::vector<int64_t> &recv_rank_ids) {
|
||||
if (SizeToLong(base_shape.size()) != kShapeSize) {
|
||||
MS_LOG(EXCEPTION) << "Wrong base_shape size: " << base_shape.size() << ", it should be equal to 4.";
|
||||
}
|
||||
std::vector<std::vector<size_t>> shapes = {};
|
||||
std::vector<std::vector<size_t>> ori_shapes = {
|
||||
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenTopIdx]), base_shape[kWDim]},
|
||||
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenTopIdx]),
|
||||
static_cast<size_t>(recv_lens[kLenRightIdx])},
|
||||
{base_shape[0], base_shape[1], base_shape[kHDim], static_cast<size_t>(recv_lens[kLenRightIdx])},
|
||||
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenBottomIdx]),
|
||||
static_cast<size_t>(recv_lens[kLenRightIdx])},
|
||||
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenBottomIdx]), base_shape[kWDim]},
|
||||
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenBottomIdx]),
|
||||
static_cast<size_t>(recv_lens[kLenLeftIdx])},
|
||||
{base_shape[0], base_shape[1], base_shape[kHDim], static_cast<size_t>(recv_lens[kLenLeftIdx])},
|
||||
{base_shape[0], base_shape[1], static_cast<size_t>(recv_lens[kLenTopIdx]),
|
||||
static_cast<size_t>(recv_lens[kLenLeftIdx])}};
|
||||
|
||||
for (size_t idx = 0; idx < recv_rank_ids.size(); ++idx) {
|
||||
if (recv_rank_ids[idx] != kInvalidId) {
|
||||
shapes.push_back(ori_shapes[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
return shapes;
|
||||
}
|
||||
|
||||
// Creates the SplitV nodes that feed NeighborExchangeV2's AllToAllv.
// Returns {top_bottom, left_right, top_corner, bottom_corner}; an entry is nullptr when
// that split is not needed. (*split_num)[i] records split_nodes[i]'s output count
// (0 for a nullptr entry).
std::vector<CNodePtr> CreateSplitNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
                                       std::vector<int64_t> *split_num) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
  MS_EXCEPTION_IF_NULL(split_num);
  std::vector<int64_t> send_rank_ids =
    AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendRankIds);
  std::vector<int64_t> send_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrSendLens);

  if (neighbor_exchange_v2->size() <= kNeighborExchangeV2InputIdx) {
    MS_LOG(EXCEPTION) << "Invalid cnode " << neighbor_exchange_v2->DebugString() << " input size "
                      << neighbor_exchange_v2->size();
  }
  std::vector<CNodePtr> split_nodes = {};

  auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);

  // send_rank_ids runs clockwise from top: 0 top, 1 top-right, 2 right, 3 bottom-right,
  // 4 bottom, 5 bottom-left, 6 left, 7 top-left.
  bool is_top = ((send_rank_ids[kRankIdZero] != kInvalidId) || (send_rank_ids[kRankIdOne] != kInvalidId) ||
                 (send_rank_ids[kRankIdSeven] != kInvalidId));
  bool is_bottom = ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFour] != kInvalidId) ||
                    (send_rank_ids[kRankIdFive] != kInvalidId));
  bool is_left = (send_rank_ids[kRankIdSix] != kInvalidId);
  bool is_right = (send_rank_ids[kRankIdTwo] != kInvalidId);

  auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_input, 0);
  auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
  if (SizeToLong(shape.size()) != kShapeSize) {  // only support NCHW now
    MS_LOG(EXCEPTION) << "Invalid shape size " << shape.size() << ", only support NCHW input now!";
  }

  // splitv for top & bottom: slices horizontal border strips off the input along H
  int64_t num_split_h = 0;
  CNodePtr split_v_top_bottom = nullptr;
  if (is_top || is_bottom) {
    std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                           neighbor_exchange_v2_input};

    split_v_top_bottom =
      CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h);
  }
  split_nodes.emplace_back(split_v_top_bottom);
  split_num->push_back(num_split_h);

  // splitv for left & right: slices vertical border strips off the input along W
  int64_t num_split_w = 0;
  CNodePtr split_v_left_right = nullptr;
  if (is_left || is_right) {
    std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                           neighbor_exchange_v2_input};
    split_v_left_right =
      CreateSplitNode(graph, split_input, shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w);
  }
  split_nodes.emplace_back(split_v_left_right);
  split_num->push_back(num_split_w);

  // splitv for corner: only needed when a diagonal neighbor (1, 3, 5 or 7) is sent to
  if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId) ||
      (send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
    // top_bottom_split outputs (corner pieces are cut from the H-split strips)
    std::vector<AnfNodePtr> split_outputs_top_bottom;
    CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>((*split_num)[0]),
                                   &split_outputs_top_bottom);
    if (split_outputs_top_bottom.empty()) {
      MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString()
                        << " should have at least one output, but got 0.";
    }

    // for top corner: split the top strip (first output of the H split) along W
    if ((send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdSeven] != kInvalidId)) {
      auto shape_tmp(shape);
      shape_tmp[kHDim] = send_lens[0];  // strip height equals the top send length
      bool is_first = (send_rank_ids[kRankIdSeven] != kInvalidId);
      bool is_last = (send_rank_ids[kRankIdOne] != kInvalidId);

      std::vector<AnfNodePtr> split_v_corner_top_input = {
        NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
      split_v_corner_top_input.insert(split_v_corner_top_input.end(), split_outputs_top_bottom.begin(),
                                      split_outputs_top_bottom.begin() + 1);
      int64_t num_split_top_corner = 0;
      CNodePtr split_v_corner_top = CreateSplitNode(graph, split_v_corner_top_input, shape_tmp, is_first, is_last,
                                                    kWDim, send_lens, dtype, &num_split_top_corner);

      split_nodes.emplace_back(split_v_corner_top);
      split_num->push_back(num_split_top_corner);
    } else {
      split_nodes.emplace_back(nullptr);
      split_num->push_back(0);
    }

    // for bottom corner: split the bottom strip (last output of the H split) along W
    if ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFive] != kInvalidId)) {
      auto shape_tmp(shape);
      shape_tmp[kHDim] = send_lens[1];  // strip height equals the bottom send length
      bool is_first = (send_rank_ids[kRankIdFive] != kInvalidId);
      bool is_last = (send_rank_ids[kRankIdThree] != kInvalidId);

      std::vector<AnfNodePtr> split_v_corner_bottom_input = {
        NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
      split_v_corner_bottom_input.insert(split_v_corner_bottom_input.end(), split_outputs_top_bottom.end() - 1,
                                         split_outputs_top_bottom.end());

      int64_t num_split_bottom_corner = 0;
      CNodePtr split_v_corner_bottom = CreateSplitNode(graph, split_v_corner_bottom_input, shape_tmp, is_first, is_last,
                                                       kWDim, send_lens, dtype, &num_split_bottom_corner);
      split_nodes.emplace_back(split_v_corner_bottom);
      split_num->push_back(num_split_bottom_corner);
    } else {
      split_nodes.emplace_back(nullptr);
      split_num->push_back(0);
    }
  } else {
    // no corner data to send: keep the {top_corner, bottom_corner} slots as nullptr
    split_nodes.emplace_back(nullptr);
    split_num->push_back(0);
    split_nodes.emplace_back(nullptr);
    split_num->push_back(0);
  }

  return split_nodes;
}
|
||||
|
||||
std::vector<AnfNodePtr> CreateAllToAllvInput(const std::vector<std::vector<AnfNodePtr>> &split_outputs,
|
||||
const std::vector<int64_t> &send_rank_ids) {
|
||||
std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
|
||||
std::vector<size_t> split_idx = {0, 2, 1, 3, 0, 3, 1, 2};
|
||||
std::vector<bool> is_begin = {true, false, false, false, false, true, true, true};
|
||||
for (size_t idx = 0; idx < send_rank_ids.size(); ++idx) {
|
||||
if (send_rank_ids[idx] != kInvalidId) {
|
||||
if (is_begin[idx]) {
|
||||
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].begin(),
|
||||
split_outputs[split_idx[idx]].begin() + 1);
|
||||
} else {
|
||||
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[split_idx[idx]].end() - 1,
|
||||
split_outputs[split_idx[idx]].end());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return all_to_all_v_input;
|
||||
}
|
||||
|
||||
// get center of input for grad
|
||||
AnfNodePtr GetCenter(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
|
||||
const std::vector<CNodePtr> &split_nodes, const std::vector<int64_t> &split_num,
|
||||
const std::vector<int64_t> &send_rank_ids) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
|
||||
std::vector<AnfNodePtr> output;
|
||||
if (split_nodes[kRankIdTwo] == nullptr) {
|
||||
if (split_nodes[0] != nullptr) {
|
||||
CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>(split_num[0]), &output);
|
||||
if (output.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2.";
|
||||
}
|
||||
if (send_rank_ids[kRankIdZero] == kInvalidId) {
|
||||
return output[0];
|
||||
}
|
||||
return output[1];
|
||||
} else {
|
||||
return neighbor_exchange_v2_grad->input(kNeighborExchangeV2InputIdx);
|
||||
}
|
||||
} else {
|
||||
CreateMultipleOutputsOfAnfNode(graph, split_nodes[2], static_cast<size_t>(split_num[2]), &output);
|
||||
if (output.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Wrong split output size: " << output.size() << ", except size >= 2.";
|
||||
}
|
||||
if (send_rank_ids[kRankIdSix] == kInvalidId) {
|
||||
return output[0];
|
||||
}
|
||||
return output[1];
|
||||
}
|
||||
}
|
||||
|
||||
// Assembles the AllToAllv input list for NeighborExchangeV2Grad, emitting one piece per
// valid send direction in clockwise order (0..7). The split_outputs / split_nodes layout
// comes from the grad split-node builder (not visible in this chunk); judging by the
// index usage, [0] is the H split and [1]/[2]/[3] are per-strip W splits — confirm
// against the grad split construction.
std::vector<AnfNodePtr> CreateAllToAllvInputForGrad(const std::vector<int64_t> &send_rank_ids,
                                                    const std::vector<std::vector<AnfNodePtr>> &split_outputs,
                                                    const std::vector<CNodePtr> &split_nodes) {
  if (send_rank_ids.size() != 8) {
    MS_LOG(EXCEPTION) << "Wrong send_rank_ids size: " << send_rank_ids.size() << ", expect size: 8.";
  }
  if (split_outputs.size() != kSizeFour) {
    MS_LOG(EXCEPTION) << "Wrong split_outputs size: " << split_outputs.size() << ", expect size: 4.";
  }
  if (split_nodes.size() != kSizeFour) {
    MS_LOG(EXCEPTION) << "Wrong split_nodes size: " << split_nodes.size() << ", expect size: 4.";
  }
  std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
  // only have top-bottom split: no piece is sent to any side/diagonal neighbor
  std::vector<size_t> side_idx = {1, 2, 3, 5, 6, 7};
  bool no_send_side = std::all_of(side_idx.begin(), side_idx.end(),
                                  [&send_rank_ids](int idx) { return send_rank_ids[idx] == kInvalidId; });
  if (no_send_side) {
    // first H-split output goes up (direction 0), last one goes down (direction 4)
    if (send_rank_ids[kRankIdZero] != kInvalidId) {
      all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].begin(), split_outputs[0].begin() + 1);
    }
    if (send_rank_ids[kRankIdFour] != kInvalidId) {
      all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[0].end() - 1, split_outputs[0].end());
    }
    return all_to_all_v_input;
  }
  // 0, 1
  if (split_nodes[1] != nullptr) {
    if (send_rank_ids[kRankIdSeven] != kInvalidId) {
      // skip the first output here: it belongs to direction 7 and is appended last below
      all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin() + 1, split_outputs[1].end());
    } else {
      all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].end());
    }
  }
  // 2
  if (split_nodes[2] != nullptr && send_rank_ids[kRankIdTwo] != kInvalidId) {
    all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[2].end() - 1, split_outputs[2].end());
  }
  // 3, 4, 5: bottom pieces are emitted in reverse output order
  if (split_nodes[3] != nullptr) {
    all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[3].rbegin(), split_outputs[3].rend());
  }
  // 6
  if (split_nodes[2] != nullptr && send_rank_ids[kRankIdSix] != kInvalidId) {
    all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[2].begin(), split_outputs[2].begin() + 1);
  }
  // 7
  if (split_nodes[1] != nullptr && send_rank_ids[kRankIdSeven] != kInvalidId) {
    all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs[1].begin(), split_outputs[1].begin() + 1);
  }

  return all_to_all_v_input;
}
|
||||
// alltoallv for forward & grad: collects the split-node outputs into an AllToAllv cnode,
// infers its output types/shapes from a "base" node (the op's data input for forward,
// the center block for grad) and copies the communication attributes over with the
// invalid (-1) rank ids stripped.
CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_or_grad,
                             const std::vector<CNodePtr> &split_nodes, const std::vector<int64_t> &split_num,
                             bool is_grad) {
  MS_LOG(DEBUG) << "Start to create alltoallv node.";
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_or_grad);
  std::vector<int64_t> send_rank_ids =
    AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_or_grad, kAttrSendRankIds);
  std::vector<int64_t> recv_rank_ids =
    AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_or_grad, kAttrRecvRankIds);
  std::vector<int64_t> recv_lens =
    AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_or_grad, kAttrRecvLens);
  std::string group = AnfAlgo::GetNodeAttr<std::string>(neighbor_exchange_v2_or_grad, kAttrGroup);

  // get split nodes output, split_outputs: [top_bottom, left_right, top_corner, bottom_corner]
  std::vector<std::vector<AnfNodePtr>> split_outputs;
  for (size_t i = 0; i < split_nodes.size(); ++i) {
    std::vector<AnfNodePtr> output;
    if (split_nodes[i] != nullptr) {
      CreateMultipleOutputsOfAnfNode(graph, split_nodes[i], static_cast<size_t>(split_num[i]), &output);
      if (output.empty()) {
        MS_LOG(EXCEPTION) << "The node " << split_nodes[i]->DebugString()
                          << " should have at least one output, but got 0.";
      }
    }
    // a nullptr split keeps an empty slot so indices stay aligned with split_nodes
    split_outputs.emplace_back(output);
  }

  // all_to_all_v input
  std::vector<AnfNodePtr> all_to_all_v_input;
  AnfNodePtr base_node = nullptr;  // node whose dtype/shape seed the output inference
  if (is_grad) {
    all_to_all_v_input = CreateAllToAllvInputForGrad(send_rank_ids, split_outputs, split_nodes);
    base_node = GetCenter(graph, neighbor_exchange_v2_or_grad, split_nodes, split_num, send_rank_ids);
  } else {
    all_to_all_v_input = CreateAllToAllvInput(split_outputs, send_rank_ids);
    base_node = neighbor_exchange_v2_or_grad->input(kNeighborExchangeV2InputIdx);
  }

  // output shapes and dtypes
  auto base_dtype = AnfAlgo::GetOutputInferDataType(base_node, 0);
  auto base_shape = AnfAlgo::GetOutputInferShape(base_node, 0);
  if (SizeToLong(base_shape.size()) != kShapeSize) {
    MS_LOG(EXCEPTION) << "Invalid shape size " << base_shape.size() << ", only support NCHW input now!";
  }
  std::vector<std::vector<size_t>> shapes = CalAllToAllvOutputShape(base_shape, recv_lens, recv_rank_ids);

  // erase -1 in send_rank_ids: the AllToAllv attrs list only real peers
  std::vector<int64_t> real_send_rank_ids(send_rank_ids.size());
  std::vector<int64_t> real_recv_rank_ids(recv_rank_ids.size());
  auto iter1 = std::copy_if(send_rank_ids.begin(), send_rank_ids.end(), real_send_rank_ids.begin(),
                            [](const int64_t item) { return item != kInvalidId; });
  auto iter2 = std::copy_if(recv_rank_ids.begin(), recv_rank_ids.end(), real_recv_rank_ids.begin(),
                            [](const int64_t item) { return item != kInvalidId; });
  real_send_rank_ids.resize(std::distance(real_send_rank_ids.begin(), iter1));
  real_recv_rank_ids.resize(std::distance(real_recv_rank_ids.begin(), iter2));

  // one output per real receiving peer, all with the base dtype
  std::vector<TypeId> dtypes(real_recv_rank_ids.size(), base_dtype);

  // create alltoallv node
  auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
  MS_EXCEPTION_IF_NULL(all_to_all_v);
  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, all_to_all_v.get());

  AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(real_send_rank_ids), all_to_all_v);
  AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(real_recv_rank_ids), all_to_all_v);
  AnfAlgo::SetNodeAttr(kAttrRecvType, TypeIdToType(base_dtype), all_to_all_v);
  AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>(group), all_to_all_v);
  MS_LOG(INFO) << "Create AllToAllv success, send rank size " << send_rank_ids.size() << ", recv rank size "
               << recv_rank_ids.size();
  return all_to_all_v;
}
|
||||
|
||||
int64_t AllToAllRealIds(int64_t ids, const std::vector<int64_t> &recv_rank_ids) {
|
||||
int64_t real_ids = 0;
|
||||
for (auto i = 0; i < ids; ++i) {
|
||||
if (recv_rank_ids[i] != kInvalidId) {
|
||||
++real_ids;
|
||||
}
|
||||
}
|
||||
return real_ids;
|
||||
}
|
||||
|
||||
// Builds a Concat cnode over concat_input with the given inferred output shape/dtype,
// concat axis, and number of (dynamic) inputs.
CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &concat_input,
                          const std::vector<std::vector<size_t>> &output_shape, const std::vector<TypeId> &output_dtype,
                          int64_t axis, int64_t input_nums) {
  MS_EXCEPTION_IF_NULL(graph);
  auto concat_node = graph->NewCNode(concat_input);
  MS_EXCEPTION_IF_NULL(concat_node);
  AnfAlgo::SetOutputInferTypeAndShape(output_dtype, output_shape, concat_node.get());
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(axis), concat_node);
  AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(input_nums), concat_node);
  // Concat takes one dynamic-input group holding all input_nums tensors.
  std::vector<int64_t> dyn_input_sizes{input_nums};
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), concat_node);
  return concat_node;
}
|
||||
|
||||
// Concats one side column of the received pieces along H (NCHW dim 2).
// is_left selects recv directions {7, 6, 5} (top-left, left, bottom-left), otherwise
// {1, 2, 3} (top-right, right, bottom-right). The middle (side) piece is always present;
// the corner pieces are included only when their recv direction is valid.
CNodePtr CreateLeftRightConcat(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &all_to_all_v_outputs,
                               const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
                               bool is_left) {
  MS_EXCEPTION_IF_NULL(graph);

  std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
  int64_t input_num = 1;  // the middle (pure left/right) piece is always concatenated
  size_t first_ids = is_left ? 7 : 1;
  size_t middle_ids = is_left ? 6 : 2;
  size_t last_ids = is_left ? 5 : 3;

  // AllToAllRealIds maps the direction index to its slot in the compacted output list.
  auto single_shape = AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[AllToAllRealIds(middle_ids, recv_rank_ids)], 0);

  // top corner present: column grows by the top recv length
  if (recv_rank_ids[first_ids] != kInvalidId) {
    ++input_num;
    single_shape[2] += static_cast<size_t>(recv_lens[0]);  // H in NCHW
  }
  // bottom corner present: column grows by the bottom recv length
  if (recv_rank_ids[last_ids] != kInvalidId) {
    ++input_num;
    single_shape[2] += static_cast<size_t>(recv_lens[1]);  // H in NCHW
  }
  if (is_left) {
    // Left-column pieces (directions 5,6,7) sit at the tail of the output list; walking
    // the tail in reverse yields 7,6,5 — i.e. top-to-bottom order for the concat.
    concat_input.insert(concat_input.end(), all_to_all_v_outputs.rbegin(), all_to_all_v_outputs.rbegin() + input_num);
  } else {
    // Right-column pieces (directions 1,2,3) directly follow the optional top piece.
    concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin() + AllToAllRealIds(1, recv_rank_ids),
                        all_to_all_v_outputs.begin() + input_num + AllToAllRealIds(1, recv_rank_ids));
  }

  std::vector<TypeId> concat_output_dtype = {
    AnfAlgo::GetOutputInferDataType(all_to_all_v_outputs[AllToAllRealIds(middle_ids, recv_rank_ids)], 0)};
  auto concat = CreateConcatNode(graph, concat_input, {single_shape}, concat_output_dtype, kHDim, input_num);

  return concat;
}
|
||||
|
||||
// Concats the received border strips with the original input along concat_dim:
// for kWDim the result is [left | input | right], for kHDim it is [top | input | bottom].
// Strips whose recv direction is invalid are simply skipped.
CNodePtr CreateMiddleConcat(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
                            const std::vector<AnfNodePtr> &all_to_all_v_outputs,
                            const std::vector<int64_t> &recv_rank_ids, const std::vector<int64_t> &recv_lens,
                            int64_t concat_dim) {
  std::vector<AnfNodePtr> concat_input_all = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
  int64_t input_num_all = 0;
  auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
  auto single_shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
  // direction indices: 6=left / 2=right for a W concat, 0=top / 4=bottom for an H concat
  size_t first_idx = concat_dim == kWDim ? 6 : 0;
  size_t last_idx = concat_dim == kWDim ? 2 : 4;
  // recv_lens layout: [top, bottom, left, right]
  size_t first_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[2]) : static_cast<size_t>(recv_lens[0]);
  size_t last_len = concat_dim == kWDim ? static_cast<size_t>(recv_lens[3]) : static_cast<size_t>(recv_lens[1]);

  // left (W) or top (H) strip
  if (recv_rank_ids[first_idx] != kInvalidId) {
    if (concat_dim == kWDim) {
      // in a W-only exchange the left (direction 6) piece is the last AllToAllv output
      concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.end() - 1, all_to_all_v_outputs.end());
    } else {
      // the top (direction 0) piece is always the first output when present
      concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1);
    }

    ++input_num_all;
    single_shape[concat_dim] += first_len;
  }

  // the original input always sits in the middle
  concat_input_all.push_back(neighbor_exchange_v2_input);
  ++input_num_all;
  // right (W) or bottom (H) strip
  if (recv_rank_ids[last_idx] != kInvalidId) {
    if (concat_dim == kWDim) {
      // in a W-only exchange the right (direction 2) piece is the first output
      concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1);
    } else {
      // the bottom (direction 4) piece follows all valid outputs of directions 0..3
      int64_t bottom_num = AllToAllRealIds(4, recv_rank_ids);
      concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin() + bottom_num,
                              all_to_all_v_outputs.begin() + bottom_num + 1);
    }

    ++input_num_all;
    single_shape[concat_dim] += last_len;
  }

  std::vector<TypeId> concat_output_dtype = {AnfAlgo::GetOutputInferDataType(all_to_all_v_outputs[0], 0)};
  auto concat_all =
    CreateConcatNode(graph, concat_input_all, {single_shape}, concat_output_dtype, concat_dim, input_num_all);
  return concat_all;
}
|
||||
|
||||
// Handles the "nothing received" case: replace the NeighborExchangeV2 output with a
// Depend(input, all_to_all_v) so the result is the unchanged input while the AllToAllv
// send still executes (the control edge keeps it alive in the graph).
CNodePtr AllToAllvRecvEmpty(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
                            const CNodePtr &all_to_all_v) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
  MS_EXCEPTION_IF_NULL(all_to_all_v);
  // add depend for input & alltoallv
  auto input_node = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
  std::vector<AnfNodePtr> depend_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
                                          input_node, all_to_all_v};
  auto depend = graph->NewCNode(depend_input);
  MS_EXCEPTION_IF_NULL(depend);
  // the Depend node forwards the input's value, so it shares the input's abstract
  depend->set_abstract(input_node->abstract());
  return depend;
}
|
||||
|
||||
CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
|
||||
const CNodePtr &all_to_all_v) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
|
||||
MS_EXCEPTION_IF_NULL(all_to_all_v);
|
||||
std::vector<int64_t> recv_rank_ids =
|
||||
AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrRecvRankIds);
|
||||
std::vector<int64_t> recv_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2, kAttrRecvLens);
|
||||
|
||||
int64_t all_to_all_output_num =
|
||||
std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });
|
||||
|
||||
if (all_to_all_output_num == 0) {
|
||||
return AllToAllvRecvEmpty(graph, neighbor_exchange_v2, all_to_all_v);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> all_to_all_v_outputs;
|
||||
CreateMultipleOutputsOfAnfNode(graph, all_to_all_v, static_cast<size_t>(all_to_all_output_num),
|
||||
&all_to_all_v_outputs);
|
||||
if (all_to_all_v_outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0.";
|
||||
}
|
||||
|
||||
if (recv_rank_ids[kRankIdZero] == kInvalidId && recv_rank_ids[kRankIdFour] == kInvalidId) {
|
||||
return CreateMiddleConcat(graph, neighbor_exchange_v2, all_to_all_v_outputs, recv_rank_ids, recv_lens, kWDim);
|
||||
}
|
||||
|
||||
// top or bottom
|
||||
// middle concat
|
||||
auto concat_middle =
|
||||
CreateMiddleConcat(graph, neighbor_exchange_v2, all_to_all_v_outputs, recv_rank_ids, recv_lens, kHDim);
|
||||
|
||||
bool is_left = recv_rank_ids[kRankIdSix] != kInvalidId || recv_rank_ids[kRankIdFive] != kInvalidId ||
|
||||
recv_rank_ids[kRankIdSeven] != kInvalidId;
|
||||
bool is_right = recv_rank_ids[kRankIdOne] != kInvalidId || recv_rank_ids[kRankIdTwo] != kInvalidId ||
|
||||
recv_rank_ids[kRankIdThree] != kInvalidId;
|
||||
if (!is_left && !is_right) {
|
||||
return concat_middle;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> concat_input_all = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
|
||||
auto neighbor_exchange_v2_input = neighbor_exchange_v2->input(kNeighborExchangeV2InputIdx);
|
||||
std::vector<size_t> shape_all = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_input, 0);
|
||||
shape_all[2] =
|
||||
recv_rank_ids[kRankIdZero] != kInvalidId ? shape_all[2] + static_cast<size_t>(recv_lens[0]) : shape_all[2];
|
||||
shape_all[2] =
|
||||
recv_rank_ids[kRankIdFour] != kInvalidId ? shape_all[2] + static_cast<size_t>(recv_lens[1]) : shape_all[2];
|
||||
int64_t input_nums_all = 0;
|
||||
// left concat
|
||||
if (is_left) {
|
||||
auto concat_left = CreateLeftRightConcat(graph, all_to_all_v_outputs, recv_rank_ids, recv_lens, true);
|
||||
|
||||
// connect to concat_all
|
||||
std::vector<AnfNodePtr> concat_left_outputs;
|
||||
CreateMultipleOutputsOfAnfNode(graph, concat_left, 1, &concat_left_outputs);
|
||||
if (concat_left_outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The node " << concat_left->DebugString() << " should have at least one output, but got 0.";
|
||||
}
|
||||
concat_input_all.insert(concat_input_all.end(), concat_left_outputs.begin(), concat_left_outputs.end());
|
||||
++input_nums_all;
|
||||
shape_all[3] += recv_lens[2];
|
||||
}
|
||||
|
||||
// middle concat connect to concat_all
|
||||
std::vector<AnfNodePtr> concat_middle_outputs;
|
||||
CreateMultipleOutputsOfAnfNode(graph, concat_middle, 1, &concat_middle_outputs);
|
||||
if (concat_middle_outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The node " << concat_middle->DebugString() << " should have at least one output, but got 0.";
|
||||
}
|
||||
concat_input_all.insert(concat_input_all.end(), concat_middle_outputs.begin(), concat_middle_outputs.end());
|
||||
++input_nums_all;
|
||||
|
||||
if (is_right) {
|
||||
auto concat_right = CreateLeftRightConcat(graph, all_to_all_v_outputs, recv_rank_ids, recv_lens, false);
|
||||
|
||||
// connect to concat_all
|
||||
std::vector<AnfNodePtr> concat_right_outputs;
|
||||
CreateMultipleOutputsOfAnfNode(graph, concat_right, 1, &concat_right_outputs);
|
||||
if (concat_right_outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The node " << concat_right->DebugString() << " should have at least one output, but got 0.";
|
||||
}
|
||||
concat_input_all.insert(concat_input_all.end(), concat_right_outputs.begin(), concat_right_outputs.end());
|
||||
++input_nums_all;
|
||||
shape_all[3] += recv_lens[3];
|
||||
}
|
||||
|
||||
std::vector<TypeId> concat_right_output_dtype = {AnfAlgo::GetOutputInferDataType(concat_input_all[1], 0)};
|
||||
auto concat_all =
|
||||
CreateConcatNode(graph, concat_input_all, {shape_all}, concat_right_output_dtype, kWDim, input_nums_all);
|
||||
return concat_all;
|
||||
}
|
||||
|
||||
// grad
// Builds the SplitV nodes that carve the incoming gradient (NCHW) into per-direction
// slices to be sent back to neighbor ranks.
// Returns {top_bottom, left_right splits (one per horizontal band, top to bottom)};
// an entry is nullptr where no split is needed. For each created (or skipped) split,
// its output count (or 0) is appended to *split_num in the same order.
std::vector<CNodePtr> CreateSplitNodesForGrad(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
                                              std::vector<int64_t> *split_num) {
  MS_LOG(DEBUG) << "Start create splitv nodes.";
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
  MS_EXCEPTION_IF_NULL(split_num);
  std::vector<int64_t> send_rank_ids =
    AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrSendRankIds);
  std::vector<int64_t> send_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrSendLens);

  if (neighbor_exchange_v2_grad->size() <= kNeighborExchangeV2InputIdx) {
    MS_LOG(EXCEPTION) << "Invalid cnode " << neighbor_exchange_v2_grad->DebugString() << " input size "
                      << neighbor_exchange_v2_grad->size();
  }

  auto neighbor_exchange_v2_grad_input = neighbor_exchange_v2_grad->input(kNeighborExchangeV2InputIdx);
  auto dtype = AnfAlgo::GetOutputInferDataType(neighbor_exchange_v2_grad_input, 0);
  auto shape = AnfAlgo::GetOutputInferShape(neighbor_exchange_v2_grad_input, 0);
  if (SizeToLong(shape.size()) != kShapeSize) {
    MS_LOG(EXCEPTION) << "Invalid shape size " << shape.size() << ", only support NCHW input now!";
  }

  std::vector<CNodePtr> split_nodes = {};
  // splitv for top & bottom: a split along H is needed if any neighbor in the
  // top row (0, 1, 7) or bottom row (3, 4, 5) of the 8-direction layout is set.
  bool is_top = ((send_rank_ids[kRankIdZero] != kInvalidId) || (send_rank_ids[kRankIdOne] != kInvalidId) ||
                 (send_rank_ids[kRankIdSeven] != kInvalidId));
  bool is_bottom = ((send_rank_ids[kRankIdThree] != kInvalidId) || (send_rank_ids[kRankIdFour] != kInvalidId) ||
                    (send_rank_ids[kRankIdFive] != kInvalidId));
  CNodePtr split_v_top_bottom = nullptr;
  int64_t num_split_h = 0;
  if (is_top || is_bottom) {
    std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                           neighbor_exchange_v2_grad_input};
    split_v_top_bottom =
      CreateSplitNode(graph, split_input, shape, is_top, is_bottom, kHDim, send_lens, dtype, &num_split_h);
  }
  split_nodes.emplace_back(split_v_top_bottom);
  split_num->push_back(num_split_h);

  // splitvs for left & right
  // inputs: each output band of the H-split is split again along W; when there was
  // no H-split, the whole input acts as the single (middle) band.
  std::vector<AnfNodePtr> split_outputs_top_bottom;
  std::vector<int64_t> size_split_h;
  if (split_nodes[0] != nullptr) {
    CreateMultipleOutputsOfAnfNode(graph, split_nodes[0], static_cast<size_t>(num_split_h), &split_outputs_top_bottom);
    if (split_outputs_top_bottom.empty()) {
      MS_LOG(EXCEPTION) << "The node " << split_nodes[0]->DebugString()
                        << " should have at least one output, but got 0.";
    }
    size_split_h = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(split_nodes[0], kAttrSizeSplits);
  } else {
    // just middle
    split_outputs_top_bottom.push_back(neighbor_exchange_v2_grad_input);
    size_split_h.push_back(shape[kHDim]);
  }

  // left_right splitv nodes from top to bottom; left row is (5, 6, 7), right row is (1, 2, 3)
  bool is_left = (send_rank_ids[kRankIdFive] != kInvalidId) || (send_rank_ids[kRankIdSix] != kInvalidId) ||
                 (send_rank_ids[kRankIdSeven] != kInvalidId);
  bool is_right = (send_rank_ids[kRankIdOne] != kInvalidId) || (send_rank_ids[kRankIdTwo] != kInvalidId) ||
                  (send_rank_ids[kRankIdThree] != kInvalidId);
  if (is_left || is_right) {
    // keep a fixed positional layout in split_nodes: pad with nullptr for the
    // missing top band so consumers can index by band position
    if (!is_top) {
      split_nodes.push_back(nullptr);
      split_num->push_back(0);
    }
    for (size_t i = 0; i < split_outputs_top_bottom.size(); ++i) {
      std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                             split_outputs_top_bottom[i]};

      int64_t num_split_w = 0;
      // the band keeps the full shape except its own height from the H-split
      std::vector<size_t> base_shape(shape);
      base_shape[kHDim] = size_split_h[i];
      auto split_v_left_right =
        CreateSplitNode(graph, split_input, base_shape, is_left, is_right, kWDim, send_lens, dtype, &num_split_w);
      split_nodes.emplace_back(split_v_left_right);
      split_num->push_back(num_split_w);
    }
    // same positional padding for the missing bottom band
    if (!is_bottom) {
      split_nodes.push_back(nullptr);
      split_num->push_back(0);
    }
  } else {
    // no W-split at all: emit three nullptr placeholders (top / middle / bottom bands)
    split_nodes.push_back(nullptr);
    split_num->push_back(0);
    split_nodes.push_back(nullptr);
    split_num->push_back(0);
    split_nodes.push_back(nullptr);
    split_num->push_back(0);
  }
  MS_LOG(DEBUG) << "Create splitv nodes success.";
  return split_nodes;
}
|
||||
|
||||
// Creates a Pad node that embeds `input` (a block of extent `size` positioned at
// offset `begin`) into a tensor of the full `shape`; the remaining area is padding.
CNodePtr CreatePadNode(const FuncGraphPtr &graph, const AnfNodePtr &input, const std::vector<int64_t> &begin,
                       const std::vector<int64_t> &size, const std::vector<size_t> &shape, TypeId dtype) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(input);
  std::vector<AnfNodePtr> pad_node_inputs = {NewValueNode(std::make_shared<Primitive>(kPadOpName)), input};
  auto pad_node = graph->NewCNode(pad_node_inputs);
  // Per-dimension [pad_before, pad_after]; the after-padding is whatever remains
  // of the full extent once the offset and the block size are accounted for.
  std::vector<std::vector<int64_t>> paddings;
  for (size_t dim = 0; dim < shape.size(); ++dim) {
    int64_t pad_before = begin[dim];
    int64_t pad_after = static_cast<int64_t>(shape[dim]) - pad_before - size[dim];
    paddings.emplace_back(std::vector<int64_t>{pad_before, pad_after});
  }
  AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape}, pad_node.get());
  AnfAlgo::SetNodeAttr(kAttrPaddings, MakeValue(paddings), pad_node);
  AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(std::vector<std::string>{"x"}), pad_node);
  return pad_node;
}
|
||||
|
||||
// Builds the receive side of the grad lowering: each gradient slice received via
// AllToAllv is padded back to the full (center) shape with CreatePadNode, then all
// padded slices are summed with the center gradient through a single AddN.
// Returns the AddN node (or a Depend on the AllToAllv when nothing is received).
CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
                              const CNodePtr &all_to_all_v, const std::vector<CNodePtr> &split_nodes,
                              const std::vector<int64_t> &split_num) {
  MS_LOG(DEBUG) << "Start create splitvs grad nodes.";
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
  std::vector<int64_t> send_rank_ids =
    AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrSendRankIds);
  std::vector<int64_t> recv_rank_ids =
    AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrRecvRankIds);
  std::vector<int64_t> recv_lens = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(neighbor_exchange_v2_grad, kAttrRecvLens);

  // centerx is the part of the gradient this rank keeps for itself
  auto centerx = GetCenter(graph, neighbor_exchange_v2_grad, split_nodes, split_num, send_rank_ids);
  auto centerx_dtype = AnfAlgo::GetOutputInferDataType(centerx, 0);
  auto centerx_shape = AnfAlgo::GetOutputInferShape(centerx, 0);
  // empty: nothing is received, so there is nothing to add to centerx
  int64_t all_to_all_output_num =
    std::count_if(recv_rank_ids.begin(), recv_rank_ids.end(), [](int64_t ids) { return ids != kInvalidId; });

  if (all_to_all_output_num == 0) {
    // add depend(alltoallv, centerx) so the communication is still executed
    // even though its result is unused
    std::vector<AnfNodePtr> depend_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
                                            centerx, all_to_all_v};
    auto depend = graph->NewCNode(depend_input);
    MS_EXCEPTION_IF_NULL(depend);
    depend->set_abstract(centerx->abstract());
    return depend;
  }
  // get alltoallv outputs
  std::vector<AnfNodePtr> all_to_all_v_outputs;
  CreateMultipleOutputsOfAnfNode(graph, all_to_all_v, static_cast<size_t>(all_to_all_output_num),
                                 &all_to_all_v_outputs);
  if (all_to_all_v_outputs.empty()) {
    MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0.";
  }
  // create pad nodes
  // slice begin & size per receive direction (index 0..7 = top, top-right, right,
  // bottom-right, bottom, bottom-left, left, top-left); each entry positions the
  // received slice inside the full centerx shape
  std::vector<std::vector<int64_t>> begins = {{0, 0, 0, 0},
                                              {0, 0, 0, static_cast<int64_t>(centerx_shape[3]) - recv_lens[3]},
                                              {0, 0, 0, static_cast<int64_t>(centerx_shape[3]) - recv_lens[3]},
                                              {0, 0, static_cast<int64_t>(centerx_shape[2]) - recv_lens[1],
                                               static_cast<int64_t>(centerx_shape[3]) - recv_lens[3]},
                                              {0, 0, static_cast<int64_t>(centerx_shape[2]) - recv_lens[1], 0},
                                              {0, 0, static_cast<int64_t>(centerx_shape[2]) - recv_lens[1], 0},
                                              {0, 0, 0, 0},
                                              {0, 0, 0, 0}};
  std::vector<std::vector<int64_t>> sizes = {
    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0],
     static_cast<int64_t>(centerx_shape[3])},
    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[3]},
    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]),
     static_cast<int64_t>(centerx_shape[2]), recv_lens[3]},
    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[3]},
    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1],
     static_cast<int64_t>(centerx_shape[3])},
    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[1], recv_lens[2]},
    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]),
     static_cast<int64_t>(centerx_shape[2]), recv_lens[2]},
    {static_cast<int64_t>(centerx_shape[0]), static_cast<int64_t>(centerx_shape[1]), recv_lens[0], recv_lens[2]}};
  std::vector<CNodePtr> pad_nodes;
  size_t output_index = 0;
  for (size_t i = 0; i < recv_rank_ids.size(); ++i) {
    if (recv_rank_ids[i] != kInvalidId) {
      // AllToAllv outputs are packed: only valid directions produced an output,
      // hence the separate running output_index
      auto pad =
        CreatePadNode(graph, all_to_all_v_outputs[output_index], begins[i], sizes[i], centerx_shape, centerx_dtype);
      ++output_index;
      pad_nodes.emplace_back(pad);
    }
  }

  // create add node: AddN(centerx, padded slice...)
  std::vector<AnfNodePtr> addn_inputs = {NewValueNode(std::make_shared<Primitive>(kAddNOpName)), centerx};
  int64_t pad_num = 1;  // counts centerx plus every padded slice
  for (auto pad : pad_nodes) {
    std::vector<AnfNodePtr> pad_outputs;
    CreateMultipleOutputsOfAnfNode(graph, pad, 1, &pad_outputs);
    if (pad_outputs.empty()) {
      MS_LOG(EXCEPTION) << "The node " << pad->DebugString() << " should have at least one output, but got 0.";
    }
    addn_inputs.insert(addn_inputs.end(), pad_outputs.begin(), pad_outputs.end());
    ++pad_num;
  }
  auto addn = graph->NewCNode(addn_inputs);
  MS_EXCEPTION_IF_NULL(addn);
  AnfAlgo::SetOutputInferTypeAndShape({centerx_dtype}, {centerx_shape}, addn.get());
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue<std::vector<int64_t>>({pad_num}), addn);
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(pad_num), addn);
  MS_LOG(DEBUG) << "Create splitvs grad nodes success.";
  return addn;
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef NeighborExchangeV2UnifyMindIR::DefinePattern() const {
  // Match any NeighborExchangeV2 cnode, whatever its inputs are.
  auto inputs_var = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimNeighborExchangeV2, inputs_var});
}
|
||||
|
||||
const AnfNodePtr NeighborExchangeV2UnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                        const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto neighbor_exchange_v2 = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(neighbor_exchange_v2);
  // Lower NeighborExchangeV2 into SplitV -> AllToAllv -> Concat.
  std::vector<int64_t> split_num;
  auto split_nodes = CreateSplitNodes(graph, neighbor_exchange_v2, &split_num);
  auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2, split_nodes, split_num, false);
  return CreateConcatNodes(graph, neighbor_exchange_v2, all_to_all_v);
}
|
||||
|
||||
const BaseRef NeighborExchangeV2GradUnifyMindIR::DefinePattern() const {
  // Match any NeighborExchangeV2Grad cnode, whatever its inputs are.
  auto inputs_var = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimNeighborExchangeV2Grad, inputs_var});
}
|
||||
const AnfNodePtr NeighborExchangeV2GradUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                            const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto neighbor_exchange_v2_grad = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(neighbor_exchange_v2_grad);
  // Lower NeighborExchangeV2Grad into SplitV -> AllToAllv -> Pad/AddN.
  std::vector<int64_t> split_num;
  auto split_nodes = CreateSplitNodesForGrad(graph, neighbor_exchange_v2_grad, &split_num);
  auto all_to_all_v = CreateAllToAllvNode(graph, neighbor_exchange_v2_grad, split_nodes, split_num, true);
  return CreateSplitGradNodes(graph, neighbor_exchange_v2_grad, all_to_all_v, split_nodes, split_num);
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_NEIGHBOR_EXCHANGE_V2_UNIFY_MINDIR_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_NEIGHBOR_EXCHANGE_V2_UNIFY_MINDIR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Pass that lowers a NeighborExchangeV2 node into SplitV + AllToAllv + Concat
// nodes so it can run on Ascend.
class NeighborExchangeV2UnifyMindIR : public PatternProcessPass {
 public:
  explicit NeighborExchangeV2UnifyMindIR(bool multigraph = true)
      : PatternProcessPass("neighbor_exchange_v2_unify_mindir", multigraph) {}
  ~NeighborExchangeV2UnifyMindIR() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
|
||||
|
||||
// Pass that lowers a NeighborExchangeV2Grad node into SplitV + AllToAllv +
// Pad/AddN nodes so the backward exchange can run on Ascend.
class NeighborExchangeV2GradUnifyMindIR : public PatternProcessPass {
 public:
  explicit NeighborExchangeV2GradUnifyMindIR(bool multigraph = true)
      : PatternProcessPass("neighbor_exchange_v2_grad_unify_mindir", multigraph) {}
  ~NeighborExchangeV2GradUnifyMindIR() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_NEIGHBOR_EXCHANGE_V2_UNIFY_MINDIR_H_
|
|
@ -235,6 +235,7 @@ constexpr char GET_TENSOR_SLICE[] = "_GetTensorSlice";
|
|||
constexpr char SPLIT[] = "Split";
|
||||
constexpr char ALL_TO_ALL[] = "AlltoAll";
|
||||
constexpr char NEIGHBOREXCHANGE[] = "NeighborExchange";
|
||||
constexpr char NEIGHBOREXCHANGEV2[] = "NeighborExchangeV2";
|
||||
constexpr char PERMUTE_BY_AXIS[] = "PermuteByAxis";
|
||||
constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis";
|
||||
constexpr char SPLIT_BY_AXIS[] = "SplitByAxis";
|
||||
|
|
|
@ -437,6 +437,8 @@ constexpr auto kAttrConcatDim = "concat_dim";
|
|||
constexpr auto kAttrSplitCount = "split_count";
|
||||
constexpr auto kAttrSendRankIds = "send_rank_ids";
|
||||
constexpr auto kAttrRecvRankIds = "recv_rank_ids";
|
||||
constexpr auto kAttrSendLens = "send_lens";
|
||||
constexpr auto kAttrRecvLens = "recv_lens";
|
||||
constexpr auto kAttrRankSize = "rank_size";
|
||||
constexpr auto kAttrPadDimSize = "pad_dim_size";
|
||||
constexpr auto kAttrPaddings = "paddings";
|
||||
|
|
|
@ -437,6 +437,8 @@ inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
|
|||
inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");
|
||||
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
|
||||
inline const PrimitivePtr kPrimNeighborExchange = std::make_shared<Primitive>("NeighborExchange");
|
||||
inline const PrimitivePtr kPrimNeighborExchangeV2 = std::make_shared<Primitive>("NeighborExchangeV2");
|
||||
inline const PrimitivePtr kPrimNeighborExchangeV2Grad = std::make_shared<Primitive>("NeighborExchangeV2Grad");
|
||||
inline const PrimitivePtr kPrimAllToAll = std::make_shared<Primitive>("AlltoAll");
|
||||
inline const PrimitivePtr kPrimAllToAllv = std::make_shared<Primitive>("AllToAllv");
|
||||
inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
|
||||
|
|
|
@ -0,0 +1,156 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/neighborexchangev2.h"
|
||||
#include <string>
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr auto kSendRankIds = "send_rank_ids";
|
||||
constexpr auto kSendLens = "send_lens";
|
||||
constexpr auto kRecvRankIds = "recv_rank_ids";
|
||||
constexpr auto kRecvLens = "recv_lens";
|
||||
constexpr auto kDataFormat = "format";
|
||||
constexpr auto kGroup = "group";
|
||||
constexpr size_t kRankIdsSize = 8;
|
||||
constexpr size_t kLensSize = 4;
|
||||
|
||||
// Reads attribute `attr_name` from `primitive`, verifies it is an int list of
// exactly `attr_size` elements, and returns it.
// Throws TypeError when the attribute is missing or not a list of ints, and
// ValueError when the list has the wrong length.
std::vector<int64_t> CheckAttrSize(const PrimitivePtr &primitive, const std::string &attr_name,
                                   const size_t attr_size) {
  MS_EXCEPTION_IF_NULL(primitive);
  // size of send/recv_rank_ids equal to size of send/recv_shapes
  std::vector<int64_t> attr_value;
  try {
    auto attr = primitive->GetAttr(attr_name);
    // the bare MS_EXCEPTION here is only used to jump into the catch below,
    // which attaches the user-facing message
    if (attr->cast<ValueListPtr>() == nullptr) {
      MS_EXCEPTION(TypeError);
    }
    attr_value = GetValue<std::vector<int64_t>>(attr);
  } catch (const std::exception &) {
    MS_EXCEPTION(TypeError) << "Attr " << attr_name << " must be a list[int, int, ...].";
  }

  if (attr_value.size() != attr_size) {
    MS_EXCEPTION(ValueError) << "Invalid " << primitive->name() << " attr " << attr_name << " size "
                             << attr_value.size() << " must be equal to size " << attr_size;
  }
  return attr_value;
}
|
||||
|
||||
void CheckRecvCorner(std::vector<int64_t> recv_rank_ids, int64_t idx1, int64_t idx2, int64_t idx_corner) {
|
||||
if (recv_rank_ids[idx1] != -1 && recv_rank_ids[idx2] != -1 && recv_rank_ids[idx_corner] == -1) {
|
||||
MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1]
|
||||
<< ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids["
|
||||
<< idx_corner << "] = " << recv_rank_ids[idx_corner] << ".";
|
||||
}
|
||||
if ((recv_rank_ids[idx1] == -1 || recv_rank_ids[idx2] == -1) && recv_rank_ids[idx_corner] != -1) {
|
||||
MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1]
|
||||
<< ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids["
|
||||
<< idx_corner << "] = " << recv_rank_ids[idx_corner] << ".";
|
||||
}
|
||||
}
|
||||
|
||||
// Validates the NeighborExchangeV2 primitive: exactly one input, well-formed
// send/recv rank-id and length attributes, consistent corner directions,
// NCHW data format, and a string group attribute. Throws on any violation.
void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  const int64_t input_num = 1;
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);

  // check size of send_rank_ids, recv_rank_ids, send_lens, recv_lens
  (void)CheckAttrSize(primitive, kSendRankIds, kRankIdsSize);
  auto recv_rank_ids = CheckAttrSize(primitive, kRecvRankIds, kRankIdsSize);
  (void)CheckAttrSize(primitive, kSendLens, kLensSize);
  (void)CheckAttrSize(primitive, kRecvLens, kLensSize);

  // check recv rankids invalid cond: each corner (1, 3, 5, 7) must be consistent
  // with its two adjacent sides
  CheckRecvCorner(recv_rank_ids, 0, 2, 1);
  CheckRecvCorner(recv_rank_ids, 2, 4, 3);
  CheckRecvCorner(recv_rank_ids, 4, 6, 5);
  CheckRecvCorner(recv_rank_ids, 6, 0, 7);

  // check data_format is NCHW
  auto format = GetValue<std::string>(primitive->GetAttr(kDataFormat));
  if (format != "NCHW") {
    MS_EXCEPTION(ValueError) << "Attr data_format only support NCHW now.";
  }

  // check group: GetValue throws if the attr is not a string; the catch
  // converts that into a uniform TypeError message
  auto group_attr = primitive->GetAttr(kGroup);
  try {
    MS_EXCEPTION_IF_NULL(group_attr);
    (void)GetValue<std::string>(group_attr);
  } catch (const std::exception &) {
    MS_EXCEPTION(TypeError) << "Attr " << kGroup << " should be a str.";
  }
}
|
||||
|
||||
// Infers the output shape of NeighborExchangeV2: the input NCHW shape grown on
// H by the lens received from the top (idx 0) and bottom (idx 4) neighbors, and
// on W by the lens received from the left (idx 6) and right (idx 2) neighbors.
abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto recv_rank_ids = primitive->GetAttr(kRecvRankIds);
  MS_EXCEPTION_IF_NULL(recv_rank_ids);
  auto recv_rank_ids_value = recv_rank_ids->cast<ValueSequeuePtr>();
  MS_EXCEPTION_IF_NULL(recv_rank_ids_value);
  std::vector<int64_t> recv_rank_ids_v = GetValue<std::vector<int64_t>>(recv_rank_ids_value);
  auto recv_lens = primitive->GetAttr(kRecvLens);
  MS_EXCEPTION_IF_NULL(recv_lens);
  auto recv_lens_value = recv_lens->cast<ValueSequeuePtr>();
  MS_EXCEPTION_IF_NULL(recv_lens_value);
  std::vector<int64_t> recv_lens_v = GetValue<std::vector<int64_t>>(recv_lens_value);

  std::vector<int64_t> input_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  // Bug fix: the empty-shape check must come BEFORE the dimension updates below.
  // The original code indexed input_shape[2]/input_shape[3] first and only then
  // tested emptiness, which is undefined behavior for an empty shape.
  if (input_shape.empty()) {
    return std::make_shared<abstract::Shape>();
  }
  if (recv_rank_ids_v[0] != -1) {
    input_shape[2] += recv_lens_v[0];
  }
  if (recv_rank_ids_v[4] != -1) {
    input_shape[2] += recv_lens_v[1];
  }
  if (recv_rank_ids_v[6] != -1) {
    input_shape[3] += recv_lens_v[2];
  }
  if (recv_rank_ids_v[2] != -1) {
    input_shape[3] += recv_lens_v[3];
  }
  return std::make_shared<abstract::Shape>(input_shape);
}
|
||||
|
||||
// Infers the output type of NeighborExchangeV2: identical to the input type,
// falling back to TypeNone when the input type cannot be built.
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  // recv type
  TypePtr inferred = input_args[0]->BuildType();
  if (inferred != nullptr) {
    return inferred;
  }
  return std::make_shared<TypeNone>();
}
|
||||
} // namespace
|
||||
// Full abstract inference for NeighborExchangeV2: validate first, then combine
// the inferred dtype and shape into one abstract value.
AbstractBasePtr NeighborExchangeV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const std::vector<AbstractBasePtr> &input_args) {
  Check(primitive, input_args);
  auto inferred_type = InferType(primitive, input_args);
  auto inferred_shape = InferShape(primitive, input_args);
  return abstract::MakeAbstract(inferred_shape, inferred_type);
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchangeV2, prim::kPrimNeighborExchangeV2, NeighborExchangeV2Infer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_NEIGHBOREXCHANGEV2_H_
|
||||
#define MINDSPORE_CORE_OPS_NEIGHBOREXCHANGEV2_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameNeighborExchangeV2 = "NeighborExchangeV2";
|
||||
// Core primitive declaration for NeighborExchangeV2. Behavior is configured
// entirely through attributes (send/recv rank ids and lens); shape/type
// inference is provided separately by NeighborExchangeV2Infer.
class MS_CORE_API NeighborExchangeV2 : public PrimitiveC {
 public:
  NeighborExchangeV2() : PrimitiveC(kNameNeighborExchangeV2) {}
  ~NeighborExchangeV2() = default;
  MS_DECLARE_PARENT(NeighborExchangeV2, PrimitiveC);
  // Nothing to initialize: all configuration is carried by attributes.
  void Init() {}
};
|
||||
using kPrimNeighborExchangeV2Ptr = std::shared_ptr<NeighborExchangeV2>;
|
||||
|
||||
AbstractBasePtr NeighborExchangeV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_NEIGHBOREXCHANGEV2_H_
|
|
@ -22,12 +22,13 @@ from mindspore.parallel._utils import _get_enable_parallel_optimizer
|
|||
from .. import operations as P
|
||||
from ...common.tensor import RowTensor
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, NeighborExchange, AlltoAll,
|
||||
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, NeighborExchange, AlltoAll, NeighborExchangeV2,
|
||||
Broadcast, _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
|
||||
ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap,
|
||||
_VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather)
|
||||
from .grad_base import bprop_getters
|
||||
from ..operations._inner_ops import Send, Receive
|
||||
from ..operations import _grad_ops as G
|
||||
|
||||
|
||||
@bprop_getters.register(AllReduce)
|
||||
|
@ -390,6 +391,24 @@ def get_bprop_all_to_all(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(NeighborExchangeV2)
def get_bprop_neighborexchangev2(self):
    """Generate bprop for NeighborExchangeV2."""
    group = self.group
    # The backward exchange mirrors the forward one: whatever was received in
    # the forward pass is sent back in the backward pass, so send/recv rank ids
    # and lens are swapped when building the grad primitive.
    send_rank_ids = self.recv_rank_ids
    recv_rank_ids = self.send_rank_ids
    send_lens = self.recv_lens
    recv_lens = self.send_lens
    data_format = self.data_format
    neighborexchangev2_grad = G.NeighborExchangeV2Grad(send_rank_ids, send_lens, recv_rank_ids,
                                                       recv_lens, data_format, group)

    def bprop(x, out, dout):
        # Only the output gradient is needed; x and out are unused.
        return (neighborexchangev2_grad(dout),)

    return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_MirrorOperator)
|
||||
def get_bprop_mirror_operator(self):
|
||||
"""
|
||||
|
|
|
@ -5,4 +5,4 @@ e
|
|||
bprop.8:x*
|
||||
bprop.8:out*
|
||||
bprop.8:dout2
|
||||
bprop.8:[CNode]:1:@13f1e8534ff98b1256889fc50eb483e05dc629e575bb06330ae1f2238e2372e8P
|
||||
bprop.8:[CNode]:1:@74787be4234cdeb03f214519cd8358a5f4ad2f5606dbeb494462cddc448eb4beP
|
|
@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
|
|||
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
|
||||
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
|
||||
TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches)
|
||||
from .comm_ops import (AllGather, AllReduce, NeighborExchange, AlltoAll, AllSwap, ReduceScatter, Broadcast,
|
||||
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, AllSwap, ReduceScatter, Broadcast,
|
||||
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
|
||||
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
|
||||
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
|
||||
|
@ -486,6 +486,9 @@ __all__ = [
|
|||
"Trunc",
|
||||
"Complex",
|
||||
"ExtractVolumePatches",
|
||||
"NeighborExchangeV2",
|
||||
"NeighborExchange",
|
||||
"AlltoAll",
|
||||
]
|
||||
|
||||
__sponge__ = [
|
||||
|
|
|
@ -23,6 +23,7 @@ from ..._checkparam import Validator as validator, Rel
|
|||
from .._utils import get_concat_offset
|
||||
from ...common import dtype as mstype
|
||||
from ... import context
|
||||
from ...communication.management import GlobalComm
|
||||
|
||||
|
||||
class AbsGrad(PrimitiveWithInfer):
|
||||
|
@ -723,6 +724,42 @@ class BNTrainingUpdateGrad(PrimitiveWithInfer):
|
|||
def infer_dtype(self, grads, x, batch_mean, batch_variance):
|
||||
return (batch_mean, batch_variance)
|
||||
|
||||
class NeighborExchangeV2Grad(PrimitiveWithInfer):
    """Gradients of NeighborExchangeV2 operation."""

    @prim_attr_register
    def __init__(self, send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format,
                 group=GlobalComm.WORLD_COMM_GROUP):
        self.init_prim_io_names(inputs=['dy'], outputs=['dx'])
        # 8-element rank-id lists (one per direction) and 4-element lens lists,
        # already swapped relative to the forward op by the bprop getter.
        self.send_rank_ids = send_rank_ids
        self.recv_rank_ids = recv_rank_ids
        self.send_lens = send_lens
        self.recv_lens = recv_lens
        self.format = validator.check_string(data_format, ['NCHW'], 'format', self.name)
        self.add_prim_attr('no_elimilate', True)

    def __infer__(self, dy):
        # dx has the forward-input shape: strip from dy the lens this grad op
        # sends back, one side per active direction group.
        # Direction indices: 0=top, 1..3=right side, 4=bottom, 5..7=left side.
        dy_shape = dy['shape']
        # Note: the original used no-op f-strings here (f'dy_shape.size()'); plain
        # strings are equivalent.
        validator.check('dy_shape.size()', len(dy_shape), '4', 4, Rel.EQ, self.name)
        if self.send_rank_ids[5] != -1 or self.send_rank_ids[6] != -1 or self.send_rank_ids[7] != -1:
            dy_shape[3] -= self.send_lens[2]

        if self.send_rank_ids[1] != -1 or self.send_rank_ids[2] != -1 or self.send_rank_ids[3] != -1:
            dy_shape[3] -= self.send_lens[3]

        if self.send_rank_ids[0] != -1 or self.send_rank_ids[1] != -1 or self.send_rank_ids[7] != -1:
            dy_shape[2] -= self.send_lens[0]

        if self.send_rank_ids[3] != -1 or self.send_rank_ids[4] != -1 or self.send_rank_ids[5] != -1:
            dy_shape[2] -= self.send_lens[1]

        return {'shape': dy_shape,
                'dtype': dy['dtype'],
                'value': None}

    def __call__(self, tensor):
        # Communication op: only meaningful inside a compiled graph.
        raise NotImplementedError
|
||||
|
||||
|
||||
class GeLUGrad(PrimitiveWithInfer):
|
||||
"""Gradients of GeLU operation."""
|
||||
|
|
|
@ -710,6 +710,39 @@ class AlltoAll(PrimitiveWithInfer):
|
|||
def __call__(self, tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
class NeighborExchangeV2(Primitive):
    """
    NeighborExchangeV2 is a collective communication operation.

    NeighborExchangeV2 sends data from the local rank to ranks in the send_rank_ids,
    as while receive data from recv_rank_ids.

    Args:
        send_rank_ids (list(int)): Ranks which the data is sent to. 8 rank_ids represents 8 directions, if one
                                   direction is not send to , set it -1.
        recv_rank_ids (list(int)): Ranks which the data is received from. 8 rank_ids represents 8 directions,
                                   if one direction is not recv from , set it -1.
        send_lens (list(int)): Data lens which send to the send_rank_ids, 4 numbers represent the lens of
                               [top, bottom, left, right].
        recv_lens (list(int)): Data lens which received from recv_rank_ids, 4 numbers represent the lens of
                               [top, bottom, left, right].
        data_format (str): Data format, only support NCHW now.
        group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
    """

    @prim_attr_register
    def __init__(self, send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format,
                 group=GlobalComm.WORLD_COMM_GROUP):
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        self.send_rank_ids = send_rank_ids
        self.recv_rank_ids = recv_rank_ids
        self.send_lens = send_lens
        self.recv_lens = recv_lens
        # NOTE(review): unlike NeighborExchangeV2Grad, data_format is stored without
        # validation here — presumably it is checked during backend infer; confirm.
        self.format = data_format
        # NOTE: 'no_elimilate' (sic) is the exact attr key the backend looks up; keep spelling.
        self.add_prim_attr('no_elimilate', True)

    def __call__(self, tensor):
        # Communication primitive: not executable eagerly, graph mode only.
        raise NotImplementedError
|
||||
|
||||
class _MirrorOperator(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,542 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
from mindspore.nn import TrainOneStepCell, Momentum
|
||||
from mindspore.ops.operations.comm_ops import NeighborExchangeV2
|
||||
|
||||
_x1 = Tensor(np.ones([1, 1, 32, 16]), dtype=ms.float32)
|
||||
_x2 = Tensor(np.ones([1, 1, 33, 16]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net, x1, x2):
    """Wrap `net` in a one-step Momentum trainer and compile it in graph mode."""
    context.set_context(mode=context.GRAPH_MODE)
    opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    wrapped = TrainOneStepCell(net, opt)
    wrapped.set_train()
    _cell_graph_executor.compile(wrapped, x1, x2)
|
||||
|
||||
|
||||
def test_neighborexchangev2_single_input_success():
    """
    Feature: NeighborExchangeV2
    Description: one input and one output, with valid arguments
    Expectation: compiles successfully
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.linear = nn.Dense(16, 16)
            self.exchange = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                               send_lens=[0, 1, 0, 0],
                                               recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                               recv_lens=[0, 1, 0, 0], data_format="NCHW")

        def construct(self, x1, x2):
            out = self.exchange(self.linear(x1))
            return out + x2

    compile_net(ExchangeNet(), _x1, _x2)
|
||||
|
||||
|
||||
def test_neighborexchangev2_empty_send_success():
    """
    Feature: NeighborExchangeV2
    Description: nothing is sent (all send_rank_ids are -1), with valid arguments
    Expectation: compiles successfully
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.linear = nn.Dense(16, 16)
            self.exchange = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
                                               send_lens=[1, 2, 3, 4],
                                               recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                               recv_lens=[0, 1, 0, 0],
                                               data_format="NCHW")

        def construct(self, x1, x2):
            out = self.exchange(self.linear(x1))
            return out + x2

    compile_net(ExchangeNet(), _x1, _x2)
|
||||
|
||||
|
||||
def test_neighborexchangev2_empty_recv_success():
    """
    Feature: NeighborExchangeV2
    Description: nothing is received (all recv_rank_ids are -1), with valid arguments
    Expectation: compiles successfully
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.linear = nn.Dense(16, 16)
            self.exchange = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                               send_lens=[0, 1, 0, 0],
                                               recv_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
                                               recv_lens=[1, 2, 3, 4],
                                               data_format="NCHW")

        def construct(self, x1, x2):
            out = self.exchange(self.linear(x1))
            return out + x2

    # No recv: the output keeps x1's shape, so the second operand is _x1 as well.
    compile_net(ExchangeNet(), _x1, _x1)
|
||||
|
||||
|
||||
def test_neighborexchangev2_empty_send_empty_recv_success():
    """
    Feature: NeighborExchangeV2
    Description: nothing is sent and nothing is received, with valid arguments
    Expectation: compiles successfully
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
                                               send_lens=[0, 1, 0, 0],
                                               recv_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
                                               recv_lens=[1, 2, 3, 4],
                                               data_format="NCHW")

        def construct(self, x1):
            return self.exchange(x1)

    _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_invalid_dataformat_failed():
    """
    Feature: NeighborExchangeV2
    Description: data_format must be NCHW, but NHWC is given
    Expectation: ValueError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   recv_lens=[0, 1, 0, 0],
                   data_format="NHWC")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(ValueError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_invalid_send_rank_ids_size_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_rank_ids size must be 8, but 5 ids are given
    Expectation: ValueError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1],
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   recv_lens=[0, 1, 0, 0],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(ValueError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_invalid_recv_rank_ids_size_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_rank_ids size must be 8, but 5 ids are given
    Expectation: ValueError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=[-1, -1, -1, -1, 1],
                   recv_lens=[0, 1, 0, 0],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(ValueError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_invalid_send_lens_size_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_lens size must be 4, but 5 values are given
    Expectation: ValueError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=[0, 1, 0, 0, 2],
                   recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   recv_lens=[0, 1, 0, 0],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(ValueError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_invalid_recv_lens_size_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_lens size must be 4, but 5 values are given
    Expectation: ValueError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   recv_lens=[0, 1, 0, 0, 2],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(ValueError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_invalid_input_size_failed():
    """
    Feature: NeighborExchangeV2
    Description: the op takes exactly one tensor, but 2 are given
    Expectation: ValueError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   recv_lens=[0, 1, 0, 0],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x1, x2):
            return self.exchange(x1, x2)[0]

    with pytest.raises(ValueError):
        _cell_graph_executor.compile(ExchangeNet(), _x1, _x2)
|
||||
|
||||
def test_neighborexchangev2_recv_rank_ids_invalid_value_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_rank_ids must form a concatenable pattern; [3] and [4] are 1 while [5] is -1
    Expectation: ValueError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=[-1, -1, -1, 1, 1, -1, -1, -1],
                   recv_lens=[0, 1, 0, 0],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(ValueError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_attr_check_send_rank_ids_is_tuple_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_rank_ids must be a list, but a tuple is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=(-1, -1, -1, -1, 1, -1, -1, -1),
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   recv_lens=[0, 1, 0, 0],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(TypeError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_attr_check_send_lens_is_tuple_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_lens must be a list, but a tuple is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=(0, 1, 0, 0),
                   recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   recv_lens=[0, 1, 0, 0],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(TypeError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_attr_check_recv_rank_ids_is_tuple_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_rank_ids must be a list, but a tuple is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=(-1, -1, -1, -1, 1, -1, -1, -1),
                   recv_lens=[0, 1, 0, 0],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(TypeError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_attr_check_recv_lens_is_tuple_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_lens must be a list, but a tuple is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   recv_lens=(0, 1, 0, 0),
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(TypeError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_attr_check_send_rank_ids_is_float_failed():
    """
    Feature: NeighborExchangeV2
    Description: ids in send_rank_ids must be int, but a float is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1.0, -1, -1, -1],
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   recv_lens=[0, 1, 0, 0],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(TypeError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_attr_check_send_lens_is_float_failed():
    """
    Feature: NeighborExchangeV2
    Description: values in send_lens must be int, but a float is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=[0, 1.0, 0, 0],
                   recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   recv_lens=[0, 1, 0, 0],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(TypeError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_attr_check_recv_rank_ids_is_float_failed():
    """
    Feature: NeighborExchangeV2
    Description: ids in recv_rank_ids must be int, but a float is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=[-1, -1, -1, -1, 1.0, -1, -1, -1],
                   recv_lens=[0, 1, 0, 0],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(TypeError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_attr_check_recv_lens_is_float_failed():
    """
    Feature: NeighborExchangeV2
    Description: values in recv_lens must be int, but a float is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   recv_lens=[0, 1.0, 0, 0],
                   data_format="NCHW")

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(TypeError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
||||
|
||||
def test_neighborexchangev2_group_is_tuple_failed():
    """
    Feature: NeighborExchangeV2
    Description: group must be a string, but a tuple is given
    Expectation: TypeError is raised
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    op_args = dict(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   send_lens=[0, 1, 0, 0],
                   recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                   recv_lens=[0, 1, 0, 0],
                   data_format="NCHW", group=("str",))

    class ExchangeNet(nn.Cell):
        def __init__(self):
            super(ExchangeNet, self).__init__()
            self.exchange = NeighborExchangeV2(**op_args)

        def construct(self, x):
            return self.exchange(x)[0]

    with pytest.raises(TypeError):
        _cell_graph_executor.compile(ExchangeNet(), _x1)
|
Loading…
Reference in New Issue