!26279 neighborexchangev2 fix bug

Merge pull request !26279 from TuDouNi/neighborexchangev2_bug
This commit is contained in:
i-robot 2021-11-17 02:29:21 +00:00 committed by Gitee
commit 9f2e1edc00
4 changed files with 294 additions and 26 deletions

View File

@ -84,17 +84,19 @@ int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first
}
if (is_last) {
// middle
++num_split;
split_middle_size -= last_size;
size_splits->push_back(split_middle_size);
shape_tmp[split_dim] = static_cast<size_t>(split_middle_size);
shapes->push_back(shape_tmp);
if (split_middle_size > 0) {
++num_split;
size_splits->push_back(split_middle_size);
shape_tmp[split_dim] = static_cast<size_t>(split_middle_size);
shapes->push_back(shape_tmp);
}
// last
++num_split;
size_splits->push_back(last_size);
shape_tmp[split_dim] = static_cast<size_t>(last_size);
shapes->push_back(shape_tmp);
} else {
} else if (split_middle_size > 0) {
++num_split;
size_splits->push_back(split_middle_size);
shape_tmp[split_dim] = static_cast<size_t>(split_middle_size);

View File

@ -15,6 +15,7 @@
*/
#include "ops/neighborexchangev2.h"
#include <set>
#include <string>
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
@ -30,6 +31,18 @@ constexpr auto kDataFormat = "format";
constexpr auto kGroup = "group";
constexpr size_t kRankIdsSize = 8;
constexpr size_t kLensSize = 4;
constexpr size_t kInputSize = 4;
constexpr size_t kHDim = 2;
constexpr size_t kWDim = 3;
constexpr int64_t kInvalidIds = -1;
constexpr size_t kIdx0 = 0;
constexpr size_t kIdx1 = 1;
constexpr size_t kIdx2 = 2;
constexpr size_t kIdx3 = 3;
constexpr size_t kIdx4 = 4;
constexpr size_t kIdx5 = 5;
constexpr size_t kIdx6 = 6;
constexpr size_t kIdx7 = 7;
std::vector<int64_t> CheckAttrSize(const PrimitivePtr &primitive, const std::string &attr_name,
const size_t attr_size) {
@ -50,22 +63,49 @@ std::vector<int64_t> CheckAttrSize(const PrimitivePtr &primitive, const std::str
MS_EXCEPTION(ValueError) << "Invalid " << primitive->name() << " attr " << attr_name << " size "
<< attr_value.size() << " must be equal to size " << attr_size;
}
return attr_value;
}
void CheckRecvCorner(std::vector<int64_t> recv_rank_ids, int64_t idx1, int64_t idx2, int64_t idx_corner) {
if (recv_rank_ids[idx1] != -1 && recv_rank_ids[idx2] != -1 && recv_rank_ids[idx_corner] == -1) {
if (recv_rank_ids[idx1] != kInvalidIds && recv_rank_ids[idx2] != kInvalidIds &&
recv_rank_ids[idx_corner] == kInvalidIds) {
MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1]
<< ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids["
<< idx_corner << "] = " << recv_rank_ids[idx_corner] << ".";
}
if ((recv_rank_ids[idx1] == -1 || recv_rank_ids[idx2] == -1) && recv_rank_ids[idx_corner] != -1) {
if ((recv_rank_ids[idx1] == kInvalidIds || recv_rank_ids[idx2] == kInvalidIds) &&
recv_rank_ids[idx_corner] != kInvalidIds) {
MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1]
<< ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids["
<< idx_corner << "] = " << recv_rank_ids[idx_corner] << ".";
}
}
// Validates the entries of a send_rank_ids / recv_rank_ids attribute.
// Each entry must be either the placeholder value kInvalidIds (-1) or a non-negative rank id,
// and no non-placeholder rank id may occur more than once.
// Reports a ValueError via MS_EXCEPTION for the first entry that violates either rule.
void CheckIdsValue(const std::vector<int64_t> &rank_ids) {
  std::set<int64_t> seen_ids;
  for (const auto id : rank_ids) {
    if (id == kInvalidIds) {
      // The placeholder may occupy any number of slots, so it is never tracked as a duplicate.
      continue;
    }
    if (id < 0) {
      MS_EXCEPTION(ValueError) << "Invalid send_rank_ids or recv_rank_ids: " << id
                               << ", all the rank id should be >= 0 or -1.";
    }
    // insert() reports whether the id was newly added, so one lookup is enough to detect a repeat.
    if (!seen_ids.insert(id).second) {
      MS_EXCEPTION(ValueError) << "Invalid send_rank_ids or recv_rank_ids: " << id << ", it repeated.";
    }
  }
}
// Validates a send_lens / recv_lens attribute: every length must be non-negative.
// Reports a ValueError via MS_EXCEPTION for the first negative length encountered.
void CheckLensValue(const std::vector<int64_t> &lens) {
  for (const auto len : lens) {
    if (len < 0) {
      MS_EXCEPTION(ValueError) << "Invalid send_lens or recv_lens: " << len << ", the lens should be >=0.";
    }
  }
}
void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
@ -73,23 +113,50 @@ void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &in
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
// check size of send_rank_ids, recv_rank_ids, send_lens, recv_lens
(void)CheckAttrSize(primitive, kSendRankIds, kRankIdsSize);
auto send_rank_ids = CheckAttrSize(primitive, kSendRankIds, kRankIdsSize);
auto recv_rank_ids = CheckAttrSize(primitive, kRecvRankIds, kRankIdsSize);
(void)CheckAttrSize(primitive, kSendLens, kLensSize);
(void)CheckAttrSize(primitive, kRecvLens, kLensSize);
auto send_lens = CheckAttrSize(primitive, kSendLens, kLensSize);
auto recv_lens = CheckAttrSize(primitive, kRecvLens, kLensSize);
// check rank_ids value
CheckIdsValue(send_rank_ids);
CheckIdsValue(recv_rank_ids);
// check lens value
CheckLensValue(send_lens);
CheckLensValue(recv_lens);
// check recv rankids invalid cond
CheckRecvCorner(recv_rank_ids, 0, 2, 1);
CheckRecvCorner(recv_rank_ids, 2, 4, 3);
CheckRecvCorner(recv_rank_ids, 4, 6, 5);
CheckRecvCorner(recv_rank_ids, 6, 0, 7);
CheckRecvCorner(recv_rank_ids, kIdx0, kIdx2, kIdx1);
CheckRecvCorner(recv_rank_ids, kIdx2, kIdx4, kIdx3);
CheckRecvCorner(recv_rank_ids, kIdx4, kIdx6, kIdx5);
CheckRecvCorner(recv_rank_ids, kIdx6, kIdx0, kIdx7);
// check data_format is NCHW
auto format = GetValue<std::string>(primitive->GetAttr(kDataFormat));
auto format_attr = primitive->GetAttr(kDataFormat);
string format = "";
try {
MS_EXCEPTION_IF_NULL(format_attr);
format = GetValue<std::string>(format_attr);
} catch (const std::exception &) {
MS_EXCEPTION(TypeError) << "Attr " << kDataFormat << " should be a str.";
}
if (format != "NCHW") {
MS_EXCEPTION(ValueError) << "Attr data_format only support NCHW now.";
}
// check if send_lens > input_lens
std::vector<int64_t> input_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
if (input_shape.size() != kInputSize) {
MS_EXCEPTION(ValueError) << "Input size is not 4, only support NCHW now.";
}
if (send_lens[kIdx0] + send_lens[kIdx1] > input_shape[kHDim]) {
MS_EXCEPTION(ValueError) << "send_lens in H dim is larger than input size.";
}
if (send_lens[kIdx2] + send_lens[kIdx3] > input_shape[kWDim]) {
MS_EXCEPTION(ValueError) << "send_lens in W dim is larger than input size.";
}
// check group
auto group_attr = primitive->GetAttr(kGroup);
try {
@ -115,17 +182,17 @@ abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vect
std::vector<int64_t> input_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
if (recv_rank_ids_v[0] != -1) {
input_shape[2] += recv_lens_v[0];
if (recv_rank_ids_v[kIdx0] != kInvalidIds) {
input_shape[kIdx2] += recv_lens_v[kIdx0];
}
if (recv_rank_ids_v[4] != -1) {
input_shape[2] += recv_lens_v[1];
if (recv_rank_ids_v[kIdx4] != kInvalidIds) {
input_shape[kIdx2] += recv_lens_v[kIdx1];
}
if (recv_rank_ids_v[6] != -1) {
input_shape[3] += recv_lens_v[2];
if (recv_rank_ids_v[kIdx6] != kInvalidIds) {
input_shape[kIdx3] += recv_lens_v[kIdx2];
}
if (recv_rank_ids_v[2] != -1) {
input_shape[3] += recv_lens_v[3];
if (recv_rank_ids_v[kIdx2] != kInvalidIds) {
input_shape[kIdx3] += recv_lens_v[kIdx3];
}
BaseShapePtr output_shape = std::make_shared<abstract::Shape>(input_shape);
if (input_shape.empty()) {

View File

@ -713,9 +713,9 @@ class AlltoAll(PrimitiveWithInfer):
class NeighborExchangeV2(Primitive):
"""
NeighborExchange is a collective operation.
NeighborExchangeV2 is a collective operation.
NeighborExchange sends data from the local rank to ranks in the send_rank_ids,
NeighborExchangeV2 sends data from the local rank to ranks in the send_rank_ids,
and receives data from the ranks in recv_rank_ids.
Args:

View File

@ -25,7 +25,6 @@ from mindspore.ops.operations.comm_ops import NeighborExchangeV2
_x1 = Tensor(np.ones([1, 1, 32, 16]), dtype=ms.float32)
_x2 = Tensor(np.ones([1, 1, 33, 16]), dtype=ms.float32)
def compile_net(net, x1, x2):
context.set_context(mode=context.GRAPH_MODE)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
@ -60,6 +59,31 @@ def test_neighborexchangev2_single_input_success():
net = Net()
compile_net(net, _x1, _x2)
def test_neighborexchangev2_send_lens_equal_to_input_shape_success():
    """
    Feature: NeighborExchangeV2
    Description: the H-dim send_lens (0 + 32) is exactly equal to the H size of the input (32)
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.linear = nn.Dense(16, 16)
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 32, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0], data_format="NCHW")

        def construct(self, x1, x2):
            dense_out = self.linear(x1)
            exchanged = self.neighborexchangev2(dense_out)
            return exchanged + x2

    compile_net(Net(), _x1, _x2)
def test_neighborexchangev2_empty_send_success():
"""
@ -540,3 +564,178 @@ def test_neighborexchangev2_group_is_tuple_failed():
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_send_lens_larger_than_input_shape_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_lens should be <= input_shape, but a larger one (35 > 32 in the H dim) is given
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 35, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_send_rank_ids_value_invalid_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_rank_ids should be >=0 or -1, but -3 is given
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -3, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_recv_rank_ids_value_invalid_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_rank_ids should be >=0 or -1, but -3 is given
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -3, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_send_lens_value_invalid_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_lens should be >=0, but -3 is given
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, -3, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_recv_lens_value_invalid_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_lens should be >=0, but -3 is given
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, -3, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_send_rank_ids_repeat_failed():
    """
    Feature: NeighborExchangeV2
    Description: send_rank_ids cannot contain a repeated rank id, but rank id 1 is given twice
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)
def test_neighborexchangev2_recv_rank_ids_repeat_failed():
    """
    Feature: NeighborExchangeV2
    Description: recv_rank_ids cannot contain a repeated rank id, but rank id 1 is given twice
    Expectation: throw ValueError
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                                                         send_lens=[0, 1, 0, 0],
                                                         recv_rank_ids=[1, -1, -1, -1, 1, -1, -1, -1],
                                                         recv_lens=[0, 1, 0, 0],
                                                         data_format="NCHW")

        def construct(self, x):
            out = self.neighborexchangev2(x)
            return out[0]

    net = Net()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(net, _x1)