Convert AbstractList to AbstractTuple recursively

This commit is contained in:
He Wei 2022-03-15 21:16:55 +08:00
parent 7284534fba
commit 51c147ab12
2 changed files with 46 additions and 52 deletions

View File

@ -44,9 +44,11 @@ using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractCOOTensor;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractListPtr;
using mindspore::abstract::AbstractRowTensor;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
namespace {
void CheckInputsSize(const CNodePtr &cnode, size_t expect_size) {
@ -592,11 +594,12 @@ class CleanAfterOptARewriter : public BaseRewriter {
return (this->*(iter->second))(cnode);
}
static constexpr size_t kMaxListRecursiveDepth = 5;
// ValueList --> ValueTuple
static ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) {
constexpr int64_t max_depth = 5;
if (depth > max_depth) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << max_depth << " levels.";
static ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, size_t depth) {
if (depth > kMaxListRecursiveDepth) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels.";
}
const auto &list_elements = value_list->value();
std::vector<ValuePtr> elements;
@ -619,11 +622,29 @@ class CleanAfterOptARewriter : public BaseRewriter {
return nullptr;
}
// AbstractList --> AbstractTuple
static AbstractTuplePtr ConvertAbstractListToAbstractTuple(const AbstractListPtr &abs_list, size_t depth) {
if (depth > kMaxListRecursiveDepth) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels.";
}
const auto &list_elements = abs_list->elements();
std::vector<AbstractBasePtr> elements;
elements.reserve(list_elements.size());
for (const auto &element : list_elements) {
if (element->isa<AbstractList>()) {
(void)elements.emplace_back(ConvertAbstractListToAbstractTuple(element->cast<AbstractListPtr>(), depth + 1));
} else {
(void)elements.emplace_back(element);
}
}
return std::make_shared<AbstractTuple>(std::move(elements));
}
AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
// AbstractList --> AbstractTuple.
auto abs_list = abs->cast<abstract::AbstractListPtr>();
auto abs_list = abs->cast<AbstractListPtr>();
if (abs_list != nullptr) {
return std::make_shared<AbstractTuple>(abs_list->elements());
return ConvertAbstractListToAbstractTuple(abs_list, 0);
}
// AbstractCOOTensor --> AbstractTuple.
auto abs_sparse = abs->cast<abstract::AbstractCOOTensorPtr>();

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -27,9 +27,6 @@ context.set_context(mode=context.GRAPH_MODE)
def test_list_index_1D():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self):
list_ = [[1], [2, 2], [3, 3, 3]]
list_[0] = [100]
@ -37,16 +34,13 @@ def test_list_index_1D():
net = Net()
out = net()
assert out[0] == [100]
assert out[1] == [2, 2]
assert out[2] == [3, 3, 3]
assert list(out[0]) == [100]
assert list(out[1]) == [2, 2]
assert list(out[2]) == [3, 3, 3]
def test_list_neg_index_1D():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self):
list_ = [[1], [2, 2], [3, 3, 3]]
list_[-3] = [100]
@ -54,16 +48,13 @@ def test_list_neg_index_1D():
net = Net()
out = net()
assert out[0] == [100]
assert out[1] == [2, 2]
assert out[2] == [3, 3, 3]
assert list(out[0]) == [100]
assert list(out[1]) == [2, 2]
assert list(out[2]) == [3, 3, 3]
def test_list_index_2D():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self):
list_ = [[1], [2, 2], [3, 3, 3]]
list_[1][0] = 200
@ -72,16 +63,13 @@ def test_list_index_2D():
net = Net()
out = net()
assert out[0] == [1]
assert out[1] == [200, 201]
assert out[2] == [3, 3, 3]
assert list(out[0]) == [1]
assert list(out[1]) == [200, 201]
assert list(out[2]) == [3, 3, 3]
def test_list_neg_index_2D():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self):
list_ = [[1], [2, 2], [3, 3, 3]]
list_[1][-2] = 200
@ -90,16 +78,13 @@ def test_list_neg_index_2D():
net = Net()
out = net()
assert out[0] == [1]
assert out[1] == [200, 201]
assert out[2] == [3, 3, 3]
assert list(out[0]) == [1]
assert list(out[1]) == [200, 201]
assert list(out[2]) == [3, 3, 3]
def test_list_index_3D():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self):
list_ = [[1], [2, 2], [[3, 3, 3]]]
list_[2][0][0] = 300
@ -109,16 +94,13 @@ def test_list_index_3D():
net = Net()
out = net()
assert out[0] == [1]
assert out[1] == [2, 2]
assert out[2] == [[300, 301, 302]]
assert list(out[0]) == [1]
assert list(out[1]) == [2, 2]
assert list(out[2][0]) == [300, 301, 302]
def test_list_neg_index_3D():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self):
list_ = [[1], [2, 2], [[3, 3, 3]]]
list_[2][0][-3] = 300
@ -128,16 +110,13 @@ def test_list_neg_index_3D():
net = Net()
out = net()
assert out[0] == [1]
assert out[1] == [2, 2]
assert out[2] == [[300, 301, 302]]
assert list(out[0]) == [1]
assert list(out[1]) == [2, 2]
assert list(out[2][0]) == [300, 301, 302]
def test_list_index_1D_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
list_ = [x]
list_[0] = 100
@ -149,9 +128,6 @@ def test_list_index_1D_parameter():
def test_list_index_2D_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
list_ = [[x, x]]
list_[0][0] = 100
@ -163,9 +139,6 @@ def test_list_index_2D_parameter():
def test_list_index_3D_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
list_ = [[[x, x]]]
list_[0][0][0] = 100