diff --git a/mindspore/ccsrc/frontend/optimizer/clean.cc b/mindspore/ccsrc/frontend/optimizer/clean.cc index a495287d21d..ea196d34129 100644 --- a/mindspore/ccsrc/frontend/optimizer/clean.cc +++ b/mindspore/ccsrc/frontend/optimizer/clean.cc @@ -44,9 +44,11 @@ using mindspore::abstract::AbstractClass; using mindspore::abstract::AbstractCOOTensor; using mindspore::abstract::AbstractDictionary; using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractListPtr; using mindspore::abstract::AbstractRowTensor; using mindspore::abstract::AbstractScalar; using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractTuplePtr; namespace { void CheckInputsSize(const CNodePtr &cnode, size_t expect_size) { @@ -592,11 +594,12 @@ class CleanAfterOptARewriter : public BaseRewriter { return (this->*(iter->second))(cnode); } + static constexpr size_t kMaxListRecursiveDepth = 5; + // ValueList --> ValueTuple - static ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) { - constexpr int64_t max_depth = 5; - if (depth > max_depth) { - MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << max_depth << " levels."; + static ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, size_t depth) { + if (depth > kMaxListRecursiveDepth) { + MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels."; } const auto &list_elements = value_list->value(); std::vector elements; @@ -619,11 +622,29 @@ class CleanAfterOptARewriter : public BaseRewriter { return nullptr; } + // AbstractList --> AbstractTuple + static AbstractTuplePtr ConvertAbstractListToAbstractTuple(const AbstractListPtr &abs_list, size_t depth) { + if (depth > kMaxListRecursiveDepth) { + MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxListRecursiveDepth << " levels."; + } + const auto &list_elements = abs_list->elements(); + std::vector elements; + elements.reserve(list_elements.size()); + for (const auto &element : list_elements) { + if (element->isa()) { + (void)elements.emplace_back(ConvertAbstractListToAbstractTuple(element->cast(), depth + 1)); + } else { + (void)elements.emplace_back(element); + } + } + return std::make_shared(std::move(elements)); + } + AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override { // AbstractList --> AbstractTuple. - auto abs_list = abs->cast(); + auto abs_list = abs->cast(); if (abs_list != nullptr) { - return std::make_shared(abs_list->elements()); + return ConvertAbstractListToAbstractTuple(abs_list, 0); } // AbstractCOOTensor --> AbstractTuple. auto abs_sparse = abs->cast(); diff --git a/tests/ut/python/pipeline/parse/test_sequence_assign.py b/tests/ut/python/pipeline/parse/test_sequence_assign.py index 255f40a4aa0..8fde2d6820b 100644 --- a/tests/ut/python/pipeline/parse/test_sequence_assign.py +++ b/tests/ut/python/pipeline/parse/test_sequence_assign.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,9 +27,6 @@ context.set_context(mode=context.GRAPH_MODE) def test_list_index_1D(): class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - def construct(self): list_ = [[1], [2, 2], [3, 3, 3]] list_[0] = [100] @@ -37,16 +34,13 @@ def test_list_index_1D(): net = Net() out = net() - assert out[0] == [100] - assert out[1] == [2, 2] - assert out[2] == [3, 3, 3] + assert list(out[0]) == [100] + assert list(out[1]) == [2, 2] + assert list(out[2]) == [3, 3, 3] def test_list_neg_index_1D(): class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - def construct(self): list_ = [[1], [2, 2], [3, 3, 3]] list_[-3] = [100] @@ -54,16 +48,13 @@ def test_list_neg_index_1D(): net = Net() out = net() - assert out[0] == [100] - assert out[1] == [2, 2] - assert out[2] == [3, 3, 3] + assert list(out[0]) == [100] + assert list(out[1]) == [2, 2] + assert list(out[2]) == [3, 3, 3] def test_list_index_2D(): class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - def construct(self): list_ = [[1], [2, 2], [3, 3, 3]] list_[1][0] = 200 @@ -72,16 +63,13 @@ def test_list_index_2D(): net = Net() out = net() - assert out[0] == [1] - assert out[1] == [200, 201] - assert out[2] == [3, 3, 3] + assert list(out[0]) == [1] + assert list(out[1]) == [200, 201] + assert list(out[2]) == [3, 3, 3] def test_list_neg_index_2D(): class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - def construct(self): list_ = [[1], [2, 2], [3, 3, 3]] list_[1][-2] = 200 @@ -90,16 +78,13 @@ def test_list_neg_index_2D(): net = Net() out = net() - assert out[0] == [1] - assert out[1] == [200, 201] - assert out[2] == [3, 3, 3] + assert list(out[0]) == [1] + assert list(out[1]) == [200, 201] + assert list(out[2]) == [3, 3, 3] def test_list_index_3D(): class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - def construct(self): list_ = [[1], [2, 2], [[3, 3, 3]]] list_[2][0][0] = 300 @@ -109,16 +94,13 @@ def test_list_index_3D(): net = Net() out = net() - assert out[0] == [1] - assert out[1] == [2, 2] - assert out[2] == [[300, 301, 302]] + assert list(out[0]) == [1] + assert list(out[1]) == [2, 2] + assert list(out[2][0]) == [300, 301, 302] def test_list_neg_index_3D(): class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - def construct(self): list_ = [[1], [2, 2], [[3, 3, 3]]] list_[2][0][-3] = 300 @@ -128,16 +110,13 @@ def test_list_neg_index_3D(): net = Net() out = net() - assert out[0] == [1] - assert out[1] == [2, 2] - assert out[2] == [[300, 301, 302]] + assert list(out[0]) == [1] + assert list(out[1]) == [2, 2] + assert list(out[2][0]) == [300, 301, 302] def test_list_index_1D_parameter(): class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - def construct(self, x): list_ = [x] list_[0] = 100 @@ -149,9 +128,6 @@ def test_list_index_1D_parameter(): def test_list_index_2D_parameter(): class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - def construct(self, x): list_ = [[x, x]] list_[0][0] = 100 @@ -163,9 +139,6 @@ def test_list_index_2D_parameter(): def test_list_index_3D_parameter(): class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - def construct(self, x): list_ = [[[x, x]]] list_[0][0][0] = 100