forked from mindspore-Ecosystem/mindspore
!26279 neighborexchangev2 fix bug
Merge pull request !26279 from TuDouNi/neighborexchangev2_bug
This commit is contained in:
commit
9f2e1edc00
|
@ -84,17 +84,19 @@ int64_t CalSplitAttrs(const std::vector<size_t> &base_shape, const bool is_first
|
|||
}
|
||||
if (is_last) {
|
||||
// middle
|
||||
++num_split;
|
||||
split_middle_size -= last_size;
|
||||
size_splits->push_back(split_middle_size);
|
||||
shape_tmp[split_dim] = static_cast<size_t>(split_middle_size);
|
||||
shapes->push_back(shape_tmp);
|
||||
if (split_middle_size > 0) {
|
||||
++num_split;
|
||||
size_splits->push_back(split_middle_size);
|
||||
shape_tmp[split_dim] = static_cast<size_t>(split_middle_size);
|
||||
shapes->push_back(shape_tmp);
|
||||
}
|
||||
// last
|
||||
++num_split;
|
||||
size_splits->push_back(last_size);
|
||||
shape_tmp[split_dim] = static_cast<size_t>(last_size);
|
||||
shapes->push_back(shape_tmp);
|
||||
} else {
|
||||
} else if (split_middle_size > 0) {
|
||||
++num_split;
|
||||
size_splits->push_back(split_middle_size);
|
||||
shape_tmp[split_dim] = static_cast<size_t>(split_middle_size);
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "ops/neighborexchangev2.h"
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
@ -30,6 +31,18 @@ constexpr auto kDataFormat = "format";
|
|||
constexpr auto kGroup = "group";
|
||||
constexpr size_t kRankIdsSize = 8;
|
||||
constexpr size_t kLensSize = 4;
|
||||
constexpr size_t kInputSize = 4;
|
||||
constexpr size_t kHDim = 2;
|
||||
constexpr size_t kWDim = 3;
|
||||
constexpr int64_t kInvalidIds = -1;
|
||||
constexpr size_t kIdx0 = 0;
|
||||
constexpr size_t kIdx1 = 1;
|
||||
constexpr size_t kIdx2 = 2;
|
||||
constexpr size_t kIdx3 = 3;
|
||||
constexpr size_t kIdx4 = 4;
|
||||
constexpr size_t kIdx5 = 5;
|
||||
constexpr size_t kIdx6 = 6;
|
||||
constexpr size_t kIdx7 = 7;
|
||||
|
||||
std::vector<int64_t> CheckAttrSize(const PrimitivePtr &primitive, const std::string &attr_name,
|
||||
const size_t attr_size) {
|
||||
|
@ -50,22 +63,49 @@ std::vector<int64_t> CheckAttrSize(const PrimitivePtr &primitive, const std::str
|
|||
MS_EXCEPTION(ValueError) << "Invalid " << primitive->name() << " attr " << attr_name << " size "
|
||||
<< attr_value.size() << " must be equal to size " << attr_size;
|
||||
}
|
||||
|
||||
return attr_value;
|
||||
}
|
||||
|
||||
void CheckRecvCorner(std::vector<int64_t> recv_rank_ids, int64_t idx1, int64_t idx2, int64_t idx_corner) {
|
||||
if (recv_rank_ids[idx1] != -1 && recv_rank_ids[idx2] != -1 && recv_rank_ids[idx_corner] == -1) {
|
||||
if (recv_rank_ids[idx1] != kInvalidIds && recv_rank_ids[idx2] != kInvalidIds &&
|
||||
recv_rank_ids[idx_corner] == kInvalidIds) {
|
||||
MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1]
|
||||
<< ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids["
|
||||
<< idx_corner << "] = " << recv_rank_ids[idx_corner] << ".";
|
||||
}
|
||||
if ((recv_rank_ids[idx1] == -1 || recv_rank_ids[idx2] == -1) && recv_rank_ids[idx_corner] != -1) {
|
||||
if ((recv_rank_ids[idx1] == kInvalidIds || recv_rank_ids[idx2] == kInvalidIds) &&
|
||||
recv_rank_ids[idx_corner] != kInvalidIds) {
|
||||
MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1]
|
||||
<< ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids["
|
||||
<< idx_corner << "] = " << recv_rank_ids[idx_corner] << ".";
|
||||
}
|
||||
}
|
||||
|
||||
void CheckIdsValue(std::vector<int64_t> rank_ids) {
|
||||
// check repeat & invalid value
|
||||
std::set<int64_t> ids_count;
|
||||
for (auto id : rank_ids) {
|
||||
if (id < 0 && id != kInvalidIds) {
|
||||
MS_EXCEPTION(ValueError) << "Invalid send_rank_ids or recv_rank_ids: " << id
|
||||
<< ", all the rank id should be >= 0 or -1.";
|
||||
}
|
||||
if (ids_count.find(id) != ids_count.end() && id != -1) {
|
||||
MS_EXCEPTION(ValueError) << "Invalid send_rank_ids or recv_rank_ids: " << id << ", it repeated.";
|
||||
}
|
||||
ids_count.insert(id);
|
||||
}
|
||||
}
|
||||
|
||||
void CheckLensValue(std::vector<int64_t> lens) {
|
||||
// check len <0
|
||||
for (auto len : lens) {
|
||||
if (len < 0) {
|
||||
MS_EXCEPTION(ValueError) << "Invalid send_lens or recv_lens: " << len << ", the lens should be >=0.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
|
@ -73,23 +113,50 @@ void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &in
|
|||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
|
||||
// check size of send_rank_ids, recv_rank_ids, send_lens, recv_lens
|
||||
(void)CheckAttrSize(primitive, kSendRankIds, kRankIdsSize);
|
||||
auto send_rank_ids = CheckAttrSize(primitive, kSendRankIds, kRankIdsSize);
|
||||
auto recv_rank_ids = CheckAttrSize(primitive, kRecvRankIds, kRankIdsSize);
|
||||
(void)CheckAttrSize(primitive, kSendLens, kLensSize);
|
||||
(void)CheckAttrSize(primitive, kRecvLens, kLensSize);
|
||||
auto send_lens = CheckAttrSize(primitive, kSendLens, kLensSize);
|
||||
auto recv_lens = CheckAttrSize(primitive, kRecvLens, kLensSize);
|
||||
|
||||
// check rank_ids value
|
||||
CheckIdsValue(send_rank_ids);
|
||||
CheckIdsValue(recv_rank_ids);
|
||||
// check lens value
|
||||
CheckLensValue(send_lens);
|
||||
CheckLensValue(recv_lens);
|
||||
|
||||
// check recv rankids invalid cond
|
||||
CheckRecvCorner(recv_rank_ids, 0, 2, 1);
|
||||
CheckRecvCorner(recv_rank_ids, 2, 4, 3);
|
||||
CheckRecvCorner(recv_rank_ids, 4, 6, 5);
|
||||
CheckRecvCorner(recv_rank_ids, 6, 0, 7);
|
||||
CheckRecvCorner(recv_rank_ids, kIdx0, kIdx2, kIdx1);
|
||||
CheckRecvCorner(recv_rank_ids, kIdx2, kIdx4, kIdx3);
|
||||
CheckRecvCorner(recv_rank_ids, kIdx4, kIdx6, kIdx5);
|
||||
CheckRecvCorner(recv_rank_ids, kIdx6, kIdx0, kIdx7);
|
||||
|
||||
// check data_format is NCHW
|
||||
auto format = GetValue<std::string>(primitive->GetAttr(kDataFormat));
|
||||
auto format_attr = primitive->GetAttr(kDataFormat);
|
||||
string format = "";
|
||||
try {
|
||||
MS_EXCEPTION_IF_NULL(format_attr);
|
||||
format = GetValue<std::string>(format_attr);
|
||||
} catch (const std::exception &) {
|
||||
MS_EXCEPTION(TypeError) << "Attr " << kDataFormat << " should be a str.";
|
||||
}
|
||||
if (format != "NCHW") {
|
||||
MS_EXCEPTION(ValueError) << "Attr data_format only support NCHW now.";
|
||||
}
|
||||
|
||||
// check if send_lens > input_lens
|
||||
std::vector<int64_t> input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
if (input_shape.size() != kInputSize) {
|
||||
MS_EXCEPTION(ValueError) << "Input size is not 4, only support NCHW now.";
|
||||
}
|
||||
if (send_lens[kIdx0] + send_lens[kIdx1] > input_shape[kHDim]) {
|
||||
MS_EXCEPTION(ValueError) << "send_lens in H dim is larger than input size.";
|
||||
}
|
||||
if (send_lens[kIdx2] + send_lens[kIdx3] > input_shape[kWDim]) {
|
||||
MS_EXCEPTION(ValueError) << "send_lens in W dim is larger than input size.";
|
||||
}
|
||||
|
||||
// check group
|
||||
auto group_attr = primitive->GetAttr(kGroup);
|
||||
try {
|
||||
|
@ -115,17 +182,17 @@ abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vect
|
|||
|
||||
std::vector<int64_t> input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
if (recv_rank_ids_v[0] != -1) {
|
||||
input_shape[2] += recv_lens_v[0];
|
||||
if (recv_rank_ids_v[kIdx0] != kInvalidIds) {
|
||||
input_shape[kIdx2] += recv_lens_v[kIdx0];
|
||||
}
|
||||
if (recv_rank_ids_v[4] != -1) {
|
||||
input_shape[2] += recv_lens_v[1];
|
||||
if (recv_rank_ids_v[kIdx4] != kInvalidIds) {
|
||||
input_shape[kIdx2] += recv_lens_v[kIdx1];
|
||||
}
|
||||
if (recv_rank_ids_v[6] != -1) {
|
||||
input_shape[3] += recv_lens_v[2];
|
||||
if (recv_rank_ids_v[kIdx6] != kInvalidIds) {
|
||||
input_shape[kIdx3] += recv_lens_v[kIdx2];
|
||||
}
|
||||
if (recv_rank_ids_v[2] != -1) {
|
||||
input_shape[3] += recv_lens_v[3];
|
||||
if (recv_rank_ids_v[kIdx2] != kInvalidIds) {
|
||||
input_shape[kIdx3] += recv_lens_v[kIdx3];
|
||||
}
|
||||
BaseShapePtr output_shape = std::make_shared<abstract::Shape>(input_shape);
|
||||
if (input_shape.empty()) {
|
||||
|
|
|
@ -713,9 +713,9 @@ class AlltoAll(PrimitiveWithInfer):
|
|||
|
||||
class NeighborExchangeV2(Primitive):
|
||||
"""
|
||||
NeighborExchange is a collective operation.
|
||||
NeighborExchangeV2 is a collective operation.
|
||||
|
||||
NeighborExchange sends data from the local rank to ranks in the send_rank_ids,
|
||||
NeighborExchangeV2 sends data from the local rank to ranks in the send_rank_ids,
|
||||
as while receive data from recv_rank_ids.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -25,7 +25,6 @@ from mindspore.ops.operations.comm_ops import NeighborExchangeV2
|
|||
_x1 = Tensor(np.ones([1, 1, 32, 16]), dtype=ms.float32)
|
||||
_x2 = Tensor(np.ones([1, 1, 33, 16]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net, x1, x2):
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
|
@ -60,6 +59,31 @@ def test_neighborexchangev2_single_input_success():
|
|||
net = Net()
|
||||
compile_net(net, _x1, _x2)
|
||||
|
||||
def test_neighborexchangev2_send_lens_equal_to_input_shape_success():
|
||||
"""
|
||||
Feature: NeighborExchangeV2
|
||||
Description: send_lens is equal to input shape
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.linear = nn.Dense(16, 16)
|
||||
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
send_lens=[0, 32, 0, 0],
|
||||
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
recv_lens=[0, 1, 0, 0], data_format="NCHW")
|
||||
|
||||
def construct(self, x1, x2):
|
||||
y = self.linear(x1)
|
||||
y = self.neighborexchangev2(y)
|
||||
y = y + x2
|
||||
return y
|
||||
|
||||
net = Net()
|
||||
compile_net(net, _x1, _x2)
|
||||
|
||||
def test_neighborexchangev2_empty_send_success():
|
||||
"""
|
||||
|
@ -540,3 +564,178 @@ def test_neighborexchangev2_group_is_tuple_failed():
|
|||
net = Net()
|
||||
with pytest.raises(TypeError):
|
||||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
def test_neighborexchangev2_send_lens_larger_than_input_shape_failed():
|
||||
"""
|
||||
Feature: NeighborExchangeV2
|
||||
Description: send_lens should be <= input_shape, but a larger one given
|
||||
Expectation: throw TypeError
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
send_lens=[0, 35, 0, 0],
|
||||
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
recv_lens=[0, 1, 0, 0],
|
||||
data_format="NCHW")
|
||||
|
||||
def construct(self, x):
|
||||
out = self.neighborexchangev2(x)
|
||||
return out[0]
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(ValueError):
|
||||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
def test_neighborexchangev2_send_rank_ids_value_invalid_failed():
|
||||
"""
|
||||
Feature: NeighborExchangeV2
|
||||
Description: send_rank_ids should be >=0 or -1, but -3 is given
|
||||
Expectation: throw TypeError
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -3, 1, -1, -1, -1],
|
||||
send_lens=[0, 1, 0, 0],
|
||||
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
recv_lens=[0, 1, 0, 0],
|
||||
data_format="NCHW")
|
||||
|
||||
def construct(self, x):
|
||||
out = self.neighborexchangev2(x)
|
||||
return out[0]
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(ValueError):
|
||||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
def test_neighborexchangev2_recv_rank_ids_value_invalid_failed():
|
||||
"""
|
||||
Feature: NeighborExchangeV2
|
||||
Description: recv_rank_ids should be >=0 or -1, but -3 is given
|
||||
Expectation: throw TypeError
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
send_lens=[0, 1, 0, 0],
|
||||
recv_rank_ids=[-1, -1, -1, -3, 1, -1, -1, -1],
|
||||
recv_lens=[0, 1, 0, 0],
|
||||
data_format="NCHW")
|
||||
|
||||
def construct(self, x):
|
||||
out = self.neighborexchangev2(x)
|
||||
return out[0]
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(ValueError):
|
||||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
def test_neighborexchangev2_send_lens_value_invalid_failed():
|
||||
"""
|
||||
Feature: NeighborExchangeV2
|
||||
Description: send_lens should be >=0, but -3 is given
|
||||
Expectation: throw TypeError
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
send_lens=[0, -3, 0, 0],
|
||||
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
recv_lens=[0, 1, 0, 0],
|
||||
data_format="NCHW")
|
||||
|
||||
def construct(self, x):
|
||||
out = self.neighborexchangev2(x)
|
||||
return out[0]
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(ValueError):
|
||||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
def test_neighborexchangev2_recv_lens_value_invalid_failed():
|
||||
"""
|
||||
Feature: NeighborExchangeV2
|
||||
Description: recv_lens should be >=0, but -3 is given
|
||||
Expectation: throw TypeError
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
send_lens=[0, 1, 0, 0],
|
||||
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
recv_lens=[0, -3, 0, 0],
|
||||
data_format="NCHW")
|
||||
|
||||
def construct(self, x):
|
||||
out = self.neighborexchangev2(x)
|
||||
return out[0]
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(ValueError):
|
||||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
def test_neighborexchangev2_send_rank_ids_repeat_failed():
|
||||
"""
|
||||
Feature: NeighborExchangeV2
|
||||
Description: send_rank_ids cannot be repeated, but two 1 is given
|
||||
Expectation: throw TypeError
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[1, -1, -1, -1, 1, -1, -1, -1],
|
||||
send_lens=[0, 1, 0, 0],
|
||||
recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
recv_lens=[0, 1, 0, 0],
|
||||
data_format="NCHW")
|
||||
|
||||
def construct(self, x):
|
||||
out = self.neighborexchangev2(x)
|
||||
return out[0]
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(ValueError):
|
||||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
def test_neighborexchangev2_recv_rank_ids_repeat_failed():
|
||||
"""
|
||||
Feature: NeighborExchangeV2
|
||||
Description: recv_rank_ids cannot be repeated, but two 1 is given
|
||||
Expectation: throw TypeError
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
send_lens=[0, 1, 0, 0],
|
||||
recv_rank_ids=[1, -1, -1, -1, 1, -1, -1, -1],
|
||||
recv_lens=[0, 1, 0, 0],
|
||||
data_format="NCHW")
|
||||
|
||||
def construct(self, x):
|
||||
out = self.neighborexchangev2(x)
|
||||
return out[0]
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(ValueError):
|
||||
_cell_graph_executor.compile(net, _x1)
|
||||
|
|
Loading…
Reference in New Issue