forked from mindspore-Ecosystem/mindspore
!23671 check neighbor attr type
Merge pull request !23671 from zhoufeng/xiu-ba-ge-3
This commit is contained in:
commit
9a257bd487
|
@ -31,6 +31,20 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
namespace {
|
||||
ValuePtr MakeListValue(const std::vector<int64_t> &v) {
|
||||
std::vector<ValuePtr> list;
|
||||
(void)std::transform(v.begin(), v.end(), std::back_inserter(list), [](int64_t ele) { return MakeValue(ele); });
|
||||
return std::make_shared<ValueSequeue>(list);
|
||||
}
|
||||
|
||||
ValuePtr MakeTupleListValue(const Shapes &v) {
|
||||
std::vector<ValuePtr> tuple;
|
||||
(void)std::transform(v.begin(), v.end(), std::back_inserter(tuple),
|
||||
[](const std::vector<int64_t> &list) { return MakeListValue(list); });
|
||||
return std::make_shared<ValueTuple>(tuple);
|
||||
}
|
||||
} // namespace
|
||||
Status Conv2DInfo::GetAttrsBase() {
|
||||
// out_channel
|
||||
out_channel_ = GetIntAttr(OUT_CHANNEL);
|
||||
|
@ -639,10 +653,10 @@ OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) {
|
|||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto dtype = tensor_type->element();
|
||||
MS_EXCEPTION_IF_NULL(dtype);
|
||||
Attr send_ranks = {SEND_RNAK_IDS, MakeValue(send_rank_ids_)};
|
||||
Attr recv_ranks = {RECV_RNAK_IDS, MakeValue(recv_rank_ids_)};
|
||||
Attr send_shapes = {SEND_SHAPES, MakeValue(send_shapes_)};
|
||||
Attr recv_shapes = {RECV_SHAPES, MakeValue(recv_shapes_)};
|
||||
Attr send_ranks = {SEND_RNAK_IDS, MakeListValue(send_rank_ids_)};
|
||||
Attr recv_ranks = {RECV_RNAK_IDS, MakeListValue(recv_rank_ids_)};
|
||||
Attr send_shapes = {SEND_SHAPES, MakeTupleListValue(send_shapes_)};
|
||||
Attr recv_shapes = {RECV_SHAPES, MakeTupleListValue(recv_shapes_)};
|
||||
Attr recv_type = {RECV_TYPE, dtype};
|
||||
OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type};
|
||||
return attrs;
|
||||
|
|
|
@ -48,22 +48,28 @@ void CheckAttr(const PrimitivePtr &primitive, const std::string &shape_attr_name
|
|||
ValuePtrList attr_shapes;
|
||||
try {
|
||||
auto attr = primitive->GetAttr(shape_attr_name);
|
||||
if (attr->cast<ValueTuplePtr>() == nullptr) {
|
||||
MS_EXCEPTION(TypeError);
|
||||
}
|
||||
attr_shapes = GetValue<ValuePtrList>(attr);
|
||||
} catch (const std::exception &) {
|
||||
MS_EXCEPTION(TypeError) << "Attr " << shape_attr_name << " should be a tuple(list, list, ...).";
|
||||
MS_EXCEPTION(TypeError) << "Attr " << shape_attr_name << " must be a tuple(list, list, ...).";
|
||||
}
|
||||
if (!attr_shapes.empty()) {
|
||||
auto ele = attr_shapes[0]->cast<ValueSequeuePtr>();
|
||||
if (ele == nullptr) {
|
||||
MS_EXCEPTION(TypeError) << "Attr " << shape_attr_name << " must be a tuple.";
|
||||
MS_EXCEPTION(TypeError) << "Attr " << shape_attr_name << " must be a tuple(list, list, ...).";
|
||||
}
|
||||
}
|
||||
std::vector<int64_t> attr_rank_ids;
|
||||
try {
|
||||
auto attr = primitive->GetAttr(rank_ids_attr_name);
|
||||
if (attr->cast<ValueTuplePtr>() != nullptr) {
|
||||
MS_EXCEPTION(TypeError);
|
||||
}
|
||||
attr_rank_ids = GetValue<std::vector<int64_t>>(attr);
|
||||
} catch (const std::exception &) {
|
||||
MS_EXCEPTION(TypeError) << "Attr " << rank_ids_attr_name << " should be a list[int, int, ...].";
|
||||
MS_EXCEPTION(TypeError) << "Attr " << rank_ids_attr_name << " must be a list[int, int, ...].";
|
||||
}
|
||||
if (attr_shapes.size() != attr_rank_ids.size()) {
|
||||
MS_EXCEPTION(ValueError) << "Invalid " << primitive->name() << " attr " << shape_attr_name << " size "
|
||||
|
|
|
@ -294,6 +294,30 @@ def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_failed():
|
|||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_2_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: send_rank_ids should be list, but a tuple is given
|
||||
Expectation: throw TypeError
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.alltoallv = NeighborExchange(send_rank_ids=(0,), recv_rank_ids=[1, 2],
|
||||
recv_shapes=([32, 32], [32, 64]),
|
||||
send_shapes=([32, 16],), recv_type=ms.float32)
|
||||
|
||||
def construct(self, x1):
|
||||
out = self.alltoallv((x1,))
|
||||
return out[0]
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError):
|
||||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchange_attr_check_send_rank_ids_is_float_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
|
@ -342,6 +366,30 @@ def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_failed():
|
|||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_2_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: recv_rank_ids should be list, but a tuple is given
|
||||
Expectation: throw TypeError
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=(1, 2,),
|
||||
recv_shapes=([32, 32], [32, 64]),
|
||||
send_shapes=([32, 16],), recv_type=ms.float32)
|
||||
|
||||
def construct(self, x1):
|
||||
out = self.alltoallv((x1,))
|
||||
return out[0]
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError):
|
||||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchange_attr_check_recv_rank_ids_is_float_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
|
@ -390,6 +438,30 @@ def test_NeighborExchange_attr_check_send_shape_not_tuple_failed():
|
|||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchange_attr_check_send_shape_list_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: send_shapes should be tuple(list), but a list(list) is given
|
||||
Expectation: throw TypeError
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.alltoallv = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1, 2],
|
||||
recv_shapes=([32, 32], [32, 64]),
|
||||
send_shapes=[[32, 16]], recv_type=ms.float32)
|
||||
|
||||
def construct(self, x1):
|
||||
out = self.alltoallv((x1,))
|
||||
return out[0]
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError):
|
||||
_cell_graph_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchange_attr_check_recv_type_numpy_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
|
|
Loading…
Reference in New Issue