forked from mindspore-Ecosystem/mindspore
Convert AbstractList to AbstractTuple recursively
This commit is contained in:
parent
7284534fba
commit
51c147ab12
|
@ -44,9 +44,11 @@ using mindspore::abstract::AbstractClass;
|
|||
using mindspore::abstract::AbstractCOOTensor;
|
||||
using mindspore::abstract::AbstractDictionary;
|
||||
using mindspore::abstract::AbstractList;
|
||||
using mindspore::abstract::AbstractListPtr;
|
||||
using mindspore::abstract::AbstractRowTensor;
|
||||
using mindspore::abstract::AbstractScalar;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
using mindspore::abstract::AbstractTuplePtr;
|
||||
|
||||
namespace {
|
||||
void CheckInputsSize(const CNodePtr &cnode, size_t expect_size) {
|
||||
|
@ -592,11 +594,12 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
return (this->*(iter->second))(cnode);
|
||||
}
|
||||
|
||||
static constexpr size_t kMaxListRecursiveDepth = 5;
|
||||
|
||||
// ValueList --> ValueTuple
|
||||
static ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) {
|
||||
constexpr int64_t max_depth = 5;
|
||||
if (depth > max_depth) {
|
||||
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << max_depth << " levels.";
|
||||
static ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, size_t depth) {
|
||||
if (depth > kMaxListRecursiveDepth) {
|
||||
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels.";
|
||||
}
|
||||
const auto &list_elements = value_list->value();
|
||||
std::vector<ValuePtr> elements;
|
||||
|
@ -619,11 +622,29 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// AbstractList --> AbstractTuple
|
||||
static AbstractTuplePtr ConvertAbstractListToAbstractTuple(const AbstractListPtr &abs_list, size_t depth) {
|
||||
if (depth > kMaxListRecursiveDepth) {
|
||||
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels.";
|
||||
}
|
||||
const auto &list_elements = abs_list->elements();
|
||||
std::vector<AbstractBasePtr> elements;
|
||||
elements.reserve(list_elements.size());
|
||||
for (const auto &element : list_elements) {
|
||||
if (element->isa<AbstractList>()) {
|
||||
(void)elements.emplace_back(ConvertAbstractListToAbstractTuple(element->cast<AbstractListPtr>(), depth + 1));
|
||||
} else {
|
||||
(void)elements.emplace_back(element);
|
||||
}
|
||||
}
|
||||
return std::make_shared<AbstractTuple>(std::move(elements));
|
||||
}
|
||||
|
||||
AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
|
||||
// AbstractList --> AbstractTuple.
|
||||
auto abs_list = abs->cast<abstract::AbstractListPtr>();
|
||||
auto abs_list = abs->cast<AbstractListPtr>();
|
||||
if (abs_list != nullptr) {
|
||||
return std::make_shared<AbstractTuple>(abs_list->elements());
|
||||
return ConvertAbstractListToAbstractTuple(abs_list, 0);
|
||||
}
|
||||
// AbstractCOOTensor --> AbstractTuple.
|
||||
auto abs_sparse = abs->cast<abstract::AbstractCOOTensorPtr>();
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -27,9 +27,6 @@ context.set_context(mode=context.GRAPH_MODE)
|
|||
|
||||
def test_list_index_1D():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [3, 3, 3]]
|
||||
list_[0] = [100]
|
||||
|
@ -37,16 +34,13 @@ def test_list_index_1D():
|
|||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out[0] == [100]
|
||||
assert out[1] == [2, 2]
|
||||
assert out[2] == [3, 3, 3]
|
||||
assert list(out[0]) == [100]
|
||||
assert list(out[1]) == [2, 2]
|
||||
assert list(out[2]) == [3, 3, 3]
|
||||
|
||||
|
||||
def test_list_neg_index_1D():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [3, 3, 3]]
|
||||
list_[-3] = [100]
|
||||
|
@ -54,16 +48,13 @@ def test_list_neg_index_1D():
|
|||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out[0] == [100]
|
||||
assert out[1] == [2, 2]
|
||||
assert out[2] == [3, 3, 3]
|
||||
assert list(out[0]) == [100]
|
||||
assert list(out[1]) == [2, 2]
|
||||
assert list(out[2]) == [3, 3, 3]
|
||||
|
||||
|
||||
def test_list_index_2D():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [3, 3, 3]]
|
||||
list_[1][0] = 200
|
||||
|
@ -72,16 +63,13 @@ def test_list_index_2D():
|
|||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out[0] == [1]
|
||||
assert out[1] == [200, 201]
|
||||
assert out[2] == [3, 3, 3]
|
||||
assert list(out[0]) == [1]
|
||||
assert list(out[1]) == [200, 201]
|
||||
assert list(out[2]) == [3, 3, 3]
|
||||
|
||||
|
||||
def test_list_neg_index_2D():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [3, 3, 3]]
|
||||
list_[1][-2] = 200
|
||||
|
@ -90,16 +78,13 @@ def test_list_neg_index_2D():
|
|||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out[0] == [1]
|
||||
assert out[1] == [200, 201]
|
||||
assert out[2] == [3, 3, 3]
|
||||
assert list(out[0]) == [1]
|
||||
assert list(out[1]) == [200, 201]
|
||||
assert list(out[2]) == [3, 3, 3]
|
||||
|
||||
|
||||
def test_list_index_3D():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [[3, 3, 3]]]
|
||||
list_[2][0][0] = 300
|
||||
|
@ -109,16 +94,13 @@ def test_list_index_3D():
|
|||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out[0] == [1]
|
||||
assert out[1] == [2, 2]
|
||||
assert out[2] == [[300, 301, 302]]
|
||||
assert list(out[0]) == [1]
|
||||
assert list(out[1]) == [2, 2]
|
||||
assert list(out[2][0]) == [300, 301, 302]
|
||||
|
||||
|
||||
def test_list_neg_index_3D():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [[3, 3, 3]]]
|
||||
list_[2][0][-3] = 300
|
||||
|
@ -128,16 +110,13 @@ def test_list_neg_index_3D():
|
|||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out[0] == [1]
|
||||
assert out[1] == [2, 2]
|
||||
assert out[2] == [[300, 301, 302]]
|
||||
assert list(out[0]) == [1]
|
||||
assert list(out[1]) == [2, 2]
|
||||
assert list(out[2][0]) == [300, 301, 302]
|
||||
|
||||
|
||||
def test_list_index_1D_parameter():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
list_ = [x]
|
||||
list_[0] = 100
|
||||
|
@ -149,9 +128,6 @@ def test_list_index_1D_parameter():
|
|||
|
||||
def test_list_index_2D_parameter():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
list_ = [[x, x]]
|
||||
list_[0][0] = 100
|
||||
|
@ -163,9 +139,6 @@ def test_list_index_2D_parameter():
|
|||
|
||||
def test_list_index_3D_parameter():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
list_ = [[[x, x]]]
|
||||
list_[0][0][0] = 100
|
||||
|
|
Loading…
Reference in New Issue