From 33ac1de06230c6e3fbe9aad0b9dd265f91c1d58b Mon Sep 17 00:00:00 2001 From: ttudu Date: Mon, 15 Nov 2021 10:52:04 +0800 Subject: [PATCH] fix bug --- .../neighbor_exchange_v2_unify_mindir.cc | 12 +- mindspore/core/ops/neighborexchangev2.cc | 103 +++++++-- mindspore/ops/operations/comm_ops.py | 4 +- .../parallel/test_neighborexchangev2.py | 201 +++++++++++++++++- 4 files changed, 294 insertions(+), 26 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.cc index 3afe315d3a1..e21792ffd81 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.cc @@ -84,17 +84,19 @@ int64_t CalSplitAttrs(const std::vector &base_shape, const bool is_first } if (is_last) { // middle - ++num_split; split_middle_size -= last_size; - size_splits->push_back(split_middle_size); - shape_tmp[split_dim] = static_cast(split_middle_size); - shapes->push_back(shape_tmp); + if (split_middle_size > 0) { + ++num_split; + size_splits->push_back(split_middle_size); + shape_tmp[split_dim] = static_cast(split_middle_size); + shapes->push_back(shape_tmp); + } // last ++num_split; size_splits->push_back(last_size); shape_tmp[split_dim] = static_cast(last_size); shapes->push_back(shape_tmp); - } else { + } else if (split_middle_size > 0) { ++num_split; size_splits->push_back(split_middle_size); shape_tmp[split_dim] = static_cast(split_middle_size); diff --git a/mindspore/core/ops/neighborexchangev2.cc b/mindspore/core/ops/neighborexchangev2.cc index 8336afe0362..e1fb36b779e 100644 --- a/mindspore/core/ops/neighborexchangev2.cc +++ b/mindspore/core/ops/neighborexchangev2.cc @@ -15,6 +15,7 @@ */ #include "ops/neighborexchangev2.h" +#include #include #include "utils/check_convert_utils.h" #include "abstract/primitive_infer_map.h" @@ -30,6 +31,18 @@ constexpr auto kDataFormat = "format"; constexpr auto kGroup = "group"; constexpr size_t kRankIdsSize = 8; constexpr size_t kLensSize = 4; +constexpr size_t kInputSize = 4; +constexpr size_t kHDim = 2; +constexpr size_t kWDim = 3; +constexpr int64_t kInvalidIds = -1; +constexpr size_t kIdx0 = 0; +constexpr size_t kIdx1 = 1; +constexpr size_t kIdx2 = 2; +constexpr size_t kIdx3 = 3; +constexpr size_t kIdx4 = 4; +constexpr size_t kIdx5 = 5; +constexpr size_t kIdx6 = 6; +constexpr size_t kIdx7 = 7; std::vector CheckAttrSize(const PrimitivePtr &primitive, const std::string &attr_name, const size_t attr_size) { @@ -50,22 +63,49 @@ std::vector CheckAttrSize(const PrimitivePtr &primitive, const std::str MS_EXCEPTION(ValueError) << "Invalid " << primitive->name() << " attr " << attr_name << " size " << attr_value.size() << " must be equal to size " << attr_size; } + return attr_value; } void CheckRecvCorner(std::vector recv_rank_ids, int64_t idx1, int64_t idx2, int64_t idx_corner) { - if (recv_rank_ids[idx1] != -1 && recv_rank_ids[idx2] != -1 && recv_rank_ids[idx_corner] == -1) { + if (recv_rank_ids[idx1] != kInvalidIds && recv_rank_ids[idx2] != kInvalidIds && + recv_rank_ids[idx_corner] == kInvalidIds) { MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1] << ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids[" << idx_corner << "] = " << recv_rank_ids[idx_corner] << "."; } - if ((recv_rank_ids[idx1] == -1 || recv_rank_ids[idx2] == -1) && recv_rank_ids[idx_corner] != -1) { + if ((recv_rank_ids[idx1] == kInvalidIds || recv_rank_ids[idx2] == kInvalidIds) && + recv_rank_ids[idx_corner] != kInvalidIds) { MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1] << ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids[" << idx_corner << "] = " << recv_rank_ids[idx_corner] << "."; } } +void CheckIdsValue(std::vector rank_ids) { + // check repeat & invalid value + std::set ids_count; + for (auto id : rank_ids) { + if (id < 0 && id != kInvalidIds) { + MS_EXCEPTION(ValueError) << "Invalid send_rank_ids or recv_rank_ids: " << id + << ", all the rank id should be >= 0 or -1."; + } + if (ids_count.find(id) != ids_count.end() && id != -1) { + MS_EXCEPTION(ValueError) << "Invalid send_rank_ids or recv_rank_ids: " << id << ", it repeated."; + } + ids_count.insert(id); + } +} + +void CheckLensValue(std::vector lens) { + // check len <0 + for (auto len : lens) { + if (len < 0) { + MS_EXCEPTION(ValueError) << "Invalid send_lens or recv_lens: " << len << ", the lens should be >=0."; + } + } +} + void Check(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); @@ -73,23 +113,50 @@ void Check(const PrimitivePtr &primitive, const std::vector &in (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name); // check size of send_rank_ids, recv_rank_ids, send_lens, recv_lens - (void)CheckAttrSize(primitive, kSendRankIds, kRankIdsSize); + auto send_rank_ids = CheckAttrSize(primitive, kSendRankIds, kRankIdsSize); auto recv_rank_ids = CheckAttrSize(primitive, kRecvRankIds, kRankIdsSize); - (void)CheckAttrSize(primitive, kSendLens, kLensSize); - (void)CheckAttrSize(primitive, kRecvLens, kLensSize); + auto send_lens = CheckAttrSize(primitive, kSendLens, kLensSize); + auto recv_lens = CheckAttrSize(primitive, kRecvLens, kLensSize); + + // check rank_ids value + CheckIdsValue(send_rank_ids); + CheckIdsValue(recv_rank_ids); + // check lens value + CheckLensValue(send_lens); + CheckLensValue(recv_lens); // check recv rankids invalid cond - CheckRecvCorner(recv_rank_ids, 0, 2, 1); - CheckRecvCorner(recv_rank_ids, 2, 4, 3); - CheckRecvCorner(recv_rank_ids, 4, 6, 5); - CheckRecvCorner(recv_rank_ids, 6, 0, 7); + CheckRecvCorner(recv_rank_ids, kIdx0, kIdx2, kIdx1); + CheckRecvCorner(recv_rank_ids, kIdx2, kIdx4, kIdx3); + CheckRecvCorner(recv_rank_ids, kIdx4, kIdx6, kIdx5); + CheckRecvCorner(recv_rank_ids, kIdx6, kIdx0, kIdx7); // check data_format is NCHW - auto format = GetValue(primitive->GetAttr(kDataFormat)); + auto format_attr = primitive->GetAttr(kDataFormat); + string format = ""; + try { + MS_EXCEPTION_IF_NULL(format_attr); + format = GetValue(format_attr); + } catch (const std::exception &) { + MS_EXCEPTION(TypeError) << "Attr " << kDataFormat << " should be a str."; + } if (format != "NCHW") { MS_EXCEPTION(ValueError) << "Attr data_format only support NCHW now."; } + // check if send_lens > input_lens + std::vector input_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + if (input_shape.size() != kInputSize) { + MS_EXCEPTION(ValueError) << "Input size is not 4, only support NCHW now."; + } + if (send_lens[kIdx0] + send_lens[kIdx1] > input_shape[kHDim]) { + MS_EXCEPTION(ValueError) << "send_lens in H dim is larger than input size."; + } + if (send_lens[kIdx2] + send_lens[kIdx3] > input_shape[kWDim]) { + MS_EXCEPTION(ValueError) << "send_lens in W dim is larger than input size."; + } + // check group auto group_attr = primitive->GetAttr(kGroup); try { @@ -115,17 +182,17 @@ abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vect std::vector input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - if (recv_rank_ids_v[0] != -1) { - input_shape[2] += recv_lens_v[0]; + if (recv_rank_ids_v[kIdx0] != kInvalidIds) { + input_shape[kIdx2] += recv_lens_v[kIdx0]; } - if (recv_rank_ids_v[4] != -1) { - input_shape[2] += recv_lens_v[1]; + if (recv_rank_ids_v[kIdx4] != kInvalidIds) { + input_shape[kIdx2] += recv_lens_v[kIdx1]; } - if (recv_rank_ids_v[6] != -1) { - input_shape[3] += recv_lens_v[2]; + if (recv_rank_ids_v[kIdx6] != kInvalidIds) { + input_shape[kIdx3] += recv_lens_v[kIdx2]; } - if (recv_rank_ids_v[2] != -1) { - input_shape[3] += recv_lens_v[3]; + if (recv_rank_ids_v[kIdx2] != kInvalidIds) { + input_shape[kIdx3] += recv_lens_v[kIdx3]; } BaseShapePtr output_shape = std::make_shared(input_shape); if (input_shape.empty()) { diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index e128330770b..73233a69431 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -713,9 +713,9 @@ class AlltoAll(PrimitiveWithInfer): class NeighborExchangeV2(Primitive): """ - NeighborExchange is a collective operation. + NeighborExchangeV2 is a collective operation. - NeighborExchange sends data from the local rank to ranks in the send_rank_ids, + NeighborExchangeV2 sends data from the local rank to ranks in the send_rank_ids, as while receive data from recv_rank_ids. Args: diff --git a/tests/ut/python/parallel/test_neighborexchangev2.py b/tests/ut/python/parallel/test_neighborexchangev2.py index 86c0bf8cf65..b560522dce4 100644 --- a/tests/ut/python/parallel/test_neighborexchangev2.py +++ b/tests/ut/python/parallel/test_neighborexchangev2.py @@ -25,7 +25,6 @@ from mindspore.ops.operations.comm_ops import NeighborExchangeV2 _x1 = Tensor(np.ones([1, 1, 32, 16]), dtype=ms.float32) _x2 = Tensor(np.ones([1, 1, 33, 16]), dtype=ms.float32) - def compile_net(net, x1, x2): context.set_context(mode=context.GRAPH_MODE) optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) @@ -60,6 +59,31 @@ def test_neighborexchangev2_single_input_success(): net = Net() compile_net(net, _x1, _x2) +def test_neighborexchangev2_send_lens_equal_to_input_shape_success(): + """ + Feature: NeighborExchangeV2 + Description: send_lens is equal to input shape + Expectation: success + """ + context.set_auto_parallel_context(device_num=8, global_rank=0) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.linear = nn.Dense(16, 16) + self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + send_lens=[0, 32, 0, 0], + recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + recv_lens=[0, 1, 0, 0], data_format="NCHW") + + def construct(self, x1, x2): + y = self.linear(x1) + y = self.neighborexchangev2(y) + y = y + x2 + return y + + net = Net() + compile_net(net, _x1, _x2) def test_neighborexchangev2_empty_send_success(): """ @@ -540,3 +564,178 @@ def test_neighborexchangev2_group_is_tuple_failed(): net = Net() with pytest.raises(TypeError): _cell_graph_executor.compile(net, _x1) + +def test_neighborexchangev2_send_lens_larger_than_input_shape_failed(): + """ + Feature: NeighborExchangeV2 + Description: send_lens should be <= input_shape, but a larger one given + Expectation: throw TypeError + """ + context.set_auto_parallel_context(device_num=8, global_rank=0) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + send_lens=[0, 35, 0, 0], + recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + recv_lens=[0, 1, 0, 0], + data_format="NCHW") + + def construct(self, x): + out = self.neighborexchangev2(x) + return out[0] + + net = Net() + with pytest.raises(ValueError): + _cell_graph_executor.compile(net, _x1) + +def test_neighborexchangev2_send_rank_ids_value_invalid_failed(): + """ + Feature: NeighborExchangeV2 + Description: send_rank_ids should be >=0 or -1, but -3 is given + Expectation: throw TypeError + """ + context.set_auto_parallel_context(device_num=8, global_rank=0) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -3, 1, -1, -1, -1], + send_lens=[0, 1, 0, 0], + recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + recv_lens=[0, 1, 0, 0], + data_format="NCHW") + + def construct(self, x): + out = self.neighborexchangev2(x) + return out[0] + + net = Net() + with pytest.raises(ValueError): + _cell_graph_executor.compile(net, _x1) + +def test_neighborexchangev2_recv_rank_ids_value_invalid_failed(): + """ + Feature: NeighborExchangeV2 + Description: recv_rank_ids should be >=0 or -1, but -3 is given + Expectation: throw TypeError + """ + context.set_auto_parallel_context(device_num=8, global_rank=0) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + send_lens=[0, 1, 0, 0], + recv_rank_ids=[-1, -1, -1, -3, 1, -1, -1, -1], + recv_lens=[0, 1, 0, 0], + data_format="NCHW") + + def construct(self, x): + out = self.neighborexchangev2(x) + return out[0] + + net = Net() + with pytest.raises(ValueError): + _cell_graph_executor.compile(net, _x1) + +def test_neighborexchangev2_send_lens_value_invalid_failed(): + """ + Feature: NeighborExchangeV2 + Description: send_lens should be >=0, but -3 is given + Expectation: throw TypeError + """ + context.set_auto_parallel_context(device_num=8, global_rank=0) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + send_lens=[0, -3, 0, 0], + recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + recv_lens=[0, 1, 0, 0], + data_format="NCHW") + + def construct(self, x): + out = self.neighborexchangev2(x) + return out[0] + + net = Net() + with pytest.raises(ValueError): + _cell_graph_executor.compile(net, _x1) + +def test_neighborexchangev2_recv_lens_value_invalid_failed(): + """ + Feature: NeighborExchangeV2 + Description: recv_lens should be >=0, but -3 is given + Expectation: throw TypeError + """ + context.set_auto_parallel_context(device_num=8, global_rank=0) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + send_lens=[0, 1, 0, 0], + recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + recv_lens=[0, -3, 0, 0], + data_format="NCHW") + + def construct(self, x): + out = self.neighborexchangev2(x) + return out[0] + + net = Net() + with pytest.raises(ValueError): + _cell_graph_executor.compile(net, _x1) + +def test_neighborexchangev2_send_rank_ids_repeat_failed(): + """ + Feature: NeighborExchangeV2 + Description: send_rank_ids cannot be repeated, but two 1 is given + Expectation: throw TypeError + """ + context.set_auto_parallel_context(device_num=8, global_rank=0) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[1, -1, -1, -1, 1, -1, -1, -1], + send_lens=[0, 1, 0, 0], + recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + recv_lens=[0, 1, 0, 0], + data_format="NCHW") + + def construct(self, x): + out = self.neighborexchangev2(x) + return out[0] + + net = Net() + with pytest.raises(ValueError): + _cell_graph_executor.compile(net, _x1) + +def test_neighborexchangev2_recv_rank_ids_repeat_failed(): + """ + Feature: NeighborExchangeV2 + Description: recv_rank_ids cannot be repeated, but two 1 is given + Expectation: throw TypeError + """ + context.set_auto_parallel_context(device_num=8, global_rank=0) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + send_lens=[0, 1, 0, 0], + recv_rank_ids=[1, -1, -1, -1, 1, -1, -1, -1], + recv_lens=[0, 1, 0, 0], + data_format="NCHW") + + def construct(self, x): + out = self.neighborexchangev2(x) + return out[0] + + net = Net() + with pytest.raises(ValueError): + _cell_graph_executor.compile(net, _x1)