!23671 check neighbor attr type

Merge pull request !23671 from zhoufeng/xiu-ba-ge-3
This commit is contained in:
i-robot 2021-09-18 01:24:22 +00:00 committed by Gitee
commit 9a257bd487
3 changed files with 99 additions and 7 deletions

View File

@ -31,6 +31,20 @@
namespace mindspore {
namespace parallel {
namespace {
ValuePtr MakeListValue(const std::vector<int64_t> &v) {
std::vector<ValuePtr> list;
(void)std::transform(v.begin(), v.end(), std::back_inserter(list), [](int64_t ele) { return MakeValue(ele); });
return std::make_shared<ValueSequeue>(list);
}
ValuePtr MakeTupleListValue(const Shapes &v) {
std::vector<ValuePtr> tuple;
(void)std::transform(v.begin(), v.end(), std::back_inserter(tuple),
[](const std::vector<int64_t> &list) { return MakeListValue(list); });
return std::make_shared<ValueTuple>(tuple);
}
} // namespace
Status Conv2DInfo::GetAttrsBase() {
// out_channel
out_channel_ = GetIntAttr(OUT_CHANNEL);
@ -639,10 +653,10 @@ OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(tensor_type);
auto dtype = tensor_type->element();
MS_EXCEPTION_IF_NULL(dtype);
Attr send_ranks = {SEND_RNAK_IDS, MakeValue(send_rank_ids_)};
Attr recv_ranks = {RECV_RNAK_IDS, MakeValue(recv_rank_ids_)};
Attr send_shapes = {SEND_SHAPES, MakeValue(send_shapes_)};
Attr recv_shapes = {RECV_SHAPES, MakeValue(recv_shapes_)};
Attr send_ranks = {SEND_RNAK_IDS, MakeListValue(send_rank_ids_)};
Attr recv_ranks = {RECV_RNAK_IDS, MakeListValue(recv_rank_ids_)};
Attr send_shapes = {SEND_SHAPES, MakeTupleListValue(send_shapes_)};
Attr recv_shapes = {RECV_SHAPES, MakeTupleListValue(recv_shapes_)};
Attr recv_type = {RECV_TYPE, dtype};
OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type};
return attrs;

View File

@ -48,22 +48,28 @@ void CheckAttr(const PrimitivePtr &primitive, const std::string &shape_attr_name
ValuePtrList attr_shapes;
try {
auto attr = primitive->GetAttr(shape_attr_name);
if (attr->cast<ValueTuplePtr>() == nullptr) {
MS_EXCEPTION(TypeError);
}
attr_shapes = GetValue<ValuePtrList>(attr);
} catch (const std::exception &) {
MS_EXCEPTION(TypeError) << "Attr " << shape_attr_name << " should be a tuple(list, list, ...).";
MS_EXCEPTION(TypeError) << "Attr " << shape_attr_name << " must be a tuple(list, list, ...).";
}
if (!attr_shapes.empty()) {
auto ele = attr_shapes[0]->cast<ValueSequeuePtr>();
if (ele == nullptr) {
MS_EXCEPTION(TypeError) << "Attr " << shape_attr_name << " must be a tuple.";
MS_EXCEPTION(TypeError) << "Attr " << shape_attr_name << " must be a tuple(list, list, ...).";
}
}
std::vector<int64_t> attr_rank_ids;
try {
auto attr = primitive->GetAttr(rank_ids_attr_name);
if (attr->cast<ValueTuplePtr>() != nullptr) {
MS_EXCEPTION(TypeError);
}
attr_rank_ids = GetValue<std::vector<int64_t>>(attr);
} catch (const std::exception &) {
MS_EXCEPTION(TypeError) << "Attr " << rank_ids_attr_name << " should be a list[int, int, ...].";
MS_EXCEPTION(TypeError) << "Attr " << rank_ids_attr_name << " must be a list[int, int, ...].";
}
if (attr_shapes.size() != attr_rank_ids.size()) {
MS_EXCEPTION(ValueError) << "Invalid " << primitive->name() << " attr " << shape_attr_name << " size "

View File

@ -294,6 +294,30 @@ def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_failed():
_cell_graph_executor.compile(net, _x1)
def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_2_failed():
"""
Feature: NeighborExchange
Description: send_rank_ids should be list, but a tuple is given
Expectation: throw TypeError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.alltoallv = NeighborExchange(send_rank_ids=(0,), recv_rank_ids=[1, 2],
recv_shapes=([32, 32], [32, 64]),
send_shapes=([32, 16],), recv_type=ms.float32)
def construct(self, x1):
out = self.alltoallv((x1,))
return out[0]
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)
def test_NeighborExchange_attr_check_send_rank_ids_is_float_failed():
"""
Feature: NeighborExchange
@ -342,6 +366,30 @@ def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_failed():
_cell_graph_executor.compile(net, _x1)
def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_2_failed():
"""
Feature: NeighborExchange
Description: recv_rank_ids should be list, but a tuple is given
Expectation: throw TypeError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=(1, 2,),
recv_shapes=([32, 32], [32, 64]),
send_shapes=([32, 16],), recv_type=ms.float32)
def construct(self, x1):
out = self.alltoallv((x1,))
return out[0]
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)
def test_NeighborExchange_attr_check_recv_rank_ids_is_float_failed():
"""
Feature: NeighborExchange
@ -390,6 +438,30 @@ def test_NeighborExchange_attr_check_send_shape_not_tuple_failed():
_cell_graph_executor.compile(net, _x1)
def test_NeighborExchange_attr_check_send_shape_list_failed():
"""
Feature: NeighborExchange
Description: send_shapes should be tuple(list), but a list(list) is given
Expectation: throw TypeError
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.alltoallv = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1, 2],
recv_shapes=([32, 32], [32, 64]),
send_shapes=[[32, 16]], recv_type=ms.float32)
def construct(self, x1):
out = self.alltoallv((x1,))
return out[0]
net = Net()
with pytest.raises(TypeError):
_cell_graph_executor.compile(net, _x1)
def test_NeighborExchange_attr_check_recv_type_numpy_failed():
"""
Feature: NeighborExchange