diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index a4e8db7218a..d9bbc4b2565 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -343,15 +343,8 @@ void BroadenArgs(const AbstractBasePtrList &args_abs_list, AbstractBasePtrList * MS_EXCEPTION_IF_NULL(broaded_args); (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(*broaded_args), [](const AbstractBasePtr &arg) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(arg); - auto arg_type = arg->BuildType(); - MS_EXCEPTION_IF_NULL(arg_type); - if (arg->isa() && arg_type->isa()) { - MS_LOG(DEBUG) << "Set variable for scalar arg:" << arg->ToString(); - arg->cast_ptr()->set_is_variable(true); - } if (arg->GetValueTrack() != kAnyValue) { - return arg->Broaden(); + return AbstractBroaden(arg); } return arg; }); diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index fea84685cd0..a94f89c25f1 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -833,8 +833,11 @@ void AbstractSequence::set_dynamic_len_element_abs(const AbstractBasePtr &dynami if (dynamic_len_element_abs == nullptr) { return; } + if (dynamic_len_element_abs->isa()) { + MS_EXCEPTION(TypeError) << "DynamicSequence does not support dictionary type as element type now."; + } // dynamic_len_element_abs should ignore value. - dynamic_len_element_abs_ = BroadenAllValues(dynamic_len_element_abs); + dynamic_len_element_abs_ = AbstractBroaden(dynamic_len_element_abs); } bool AbstractSequence::operator==(const AbstractBase &other) const { diff --git a/mindspore/core/abstract/ops/prim_structures.cc b/mindspore/core/abstract/ops/prim_structures.cc index 327daca75de..6bb0e7ff65e 100644 --- a/mindspore/core/abstract/ops/prim_structures.cc +++ b/mindspore/core/abstract/ops/prim_structures.cc @@ -384,12 +384,10 @@ void CheckMutableArgAbstract(const AbstractBasePtr &abs) { } return; } - if (abs->isa()) { - MS_LOG(DEBUG) << "Set scalar as variable, scalar abstract:" << abs->ToString(); - abs->cast_ptr()->set_is_variable(true); + if (abs->isa()) { return; } - if (abs->isa()) { + if (abs->isa()) { return; } MS_EXCEPTION(TypeError) @@ -416,7 +414,7 @@ AbstractBasePtr InferImplMutable(const AnalysisEnginePtr &, const PrimitivePtr & } if (!variable_len) { CheckMutableArgAbstract(args_spec_list[0]); - return args_spec_list[0]->Broaden(); + return AbstractBroaden(args_spec_list[0]); } auto ret = args_spec_list[0]->Clone(); if (!ret->isa()) { diff --git a/mindspore/core/abstract/utils.cc b/mindspore/core/abstract/utils.cc index 711b985488f..368572cd163 100644 --- a/mindspore/core/abstract/utils.cc +++ b/mindspore/core/abstract/utils.cc @@ -49,38 +49,6 @@ TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) { return kAnyType; } -// Broaden all values within the input abstract. Since for AbstractScalar, calling Broaden() can not set -// the value directly to kAnyValue, this function can not be replaced by calling abs->Broaden(). -AbstractBasePtr BroadenAllValues(const AbstractBasePtr &abs) { - MS_EXCEPTION_IF_NULL(abs); - AbstractBasePtr ret = nullptr; - if (abs->isa()) { - MS_EXCEPTION(TypeError) << "BroadenAllValues does not support dictionary yet"; - } - if (abs->isa()) { - ret = abs->Clone(); - ret->cast()->set_is_variable(true); - return ret->Broaden(); - } - if (abs->isa()) { - auto abs_seq = abs->cast(); - if (abs_seq->dynamic_len()) { - // The value of elements for dynamic length sequence should all be kAnyValue. - return abs->Clone(); - } - AbstractBasePtrList elements = abs_seq->elements(); - AbstractBasePtrList new_elements; - (void)std::transform(elements.begin(), elements.end(), std::back_inserter(new_elements), BroadenAllValues); - if (abs->isa()) { - ret = std::make_shared(new_elements, abs_seq->sequence_nodes()); - } else { - ret = std::make_shared(new_elements, abs_seq->sequence_nodes()); - } - return ret; - } - return abs->Broaden(); -} - bool IsShapesDynamicRank(const std::vector &shapes) { return std::any_of(shapes.begin(), shapes.end(), [](const ShapeVector &shape) { return std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim == Shape::kShapeRankAny; }); @@ -205,6 +173,40 @@ AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const Abstrac return joined_list; } +AbstractBasePtr AbstractBroaden(const AbstractBasePtr &abs) { + MS_EXCEPTION_IF_NULL(abs); + if (abs->isa() && !abs->isa()) { + auto sequence_abs = abs->cast(); + if (sequence_abs->dynamic_len()) { + auto elem_abs = sequence_abs->dynamic_len_element_abs(); + if (elem_abs != nullptr && elem_abs->isa()) { + elem_abs->cast()->set_is_variable(true); + } + return abs->Broaden(); + } + std::vector new_elements; + new_elements.reserve(sequence_abs->elements().size()); + (void)std::transform(sequence_abs->elements().cbegin(), sequence_abs->elements().cend(), + std::back_inserter(new_elements), AbstractBroaden); + if (sequence_abs->isa()) { + return std::make_shared(new_elements, sequence_abs->sequence_nodes()); + } + if (sequence_abs->isa()) { + return std::make_shared(new_elements, sequence_abs->sequence_nodes()); + } + MS_EXCEPTION(TypeError) << "Unknown AbstractSequence type:" << abs->ToString(); + } + if (abs->isa()) { + auto arg_type = abs->BuildType(); + MS_EXCEPTION_IF_NULL(arg_type); + auto abs_scalar = abs->cast(); + if (arg_type->isa()) { + abs_scalar->set_is_variable(true); + } + } + return abs->Broaden(); +} + AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) { auto f_spec = dyn_cast_ptr(spec); if (f_spec != nullptr) { diff --git a/mindspore/core/abstract/utils.h b/mindspore/core/abstract/utils.h index 11d7eee16c1..4b9ce96c2f3 100644 --- a/mindspore/core/abstract/utils.h +++ b/mindspore/core/abstract/utils.h @@ -35,10 +35,10 @@ namespace abstract { ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2); MS_CORE_API TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2); ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2); -AbstractBasePtr BroadenAllValues(const AbstractBasePtr &abs); MS_CORE_API AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list); MS_CORE_API AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2); +MS_CORE_API AbstractBasePtr AbstractBroaden(const AbstractBasePtr &abs); // Return an abstract value for the sensitivity of x. // The sensitivity of a function is an Env diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index a553640eaa9..2dc2a5c6ddb 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -50,8 +50,8 @@ if(ENABLE_MINDDATA) file(GLOB_RECURSE UT_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ./stub/*.cc ./common/*.cc + ./core/abstract/*.cc ./core/utils/*.cc - ./abstract/*.cc ./base/*.cc ./dataset/*.cc ./ir/dtype/*.cc diff --git a/tests/ut/cpp/abstract/abstract_sequence_test.cc b/tests/ut/cpp/core/abstract/abstract_sequence_test.cc similarity index 100% rename from tests/ut/cpp/abstract/abstract_sequence_test.cc rename to tests/ut/cpp/core/abstract/abstract_sequence_test.cc diff --git a/tests/ut/cpp/abstract/dshape_test.cc b/tests/ut/cpp/core/abstract/dshape_test.cc similarity index 99% rename from tests/ut/cpp/abstract/dshape_test.cc rename to tests/ut/cpp/core/abstract/dshape_test.cc index da0e9ed3eef..c013ac8367c 100644 --- a/tests/ut/cpp/abstract/dshape_test.cc +++ b/tests/ut/cpp/core/abstract/dshape_test.cc @@ -74,6 +74,5 @@ TEST_F(TestDShape, Clone) { ASSERT_EQ(*shp_noshp_1.Clone(), shp_noshp_1); ASSERT_EQ(*shp_tuple_2.Clone(), shp_tuple_2); } - } // namespace abstract } // namespace mindspore diff --git a/tests/ut/cpp/abstract/utils_test.cc b/tests/ut/cpp/core/abstract/utils_test.cc similarity index 53% rename from tests/ut/cpp/abstract/utils_test.cc rename to tests/ut/cpp/core/abstract/utils_test.cc index e9ef5f5c58f..95bfb2078f2 100644 --- a/tests/ut/cpp/abstract/utils_test.cc +++ b/tests/ut/cpp/core/abstract/utils_test.cc @@ -103,5 +103,81 @@ TEST_F(TestUtils, TestShapeJoin) { ShapeJoinCheck({-2}, {-2}, {-2}); } +AbstractBasePtr MakeTensorAbstract(const ShapeVector &shape_vec, const TypePtr &elem_type) { + auto shape = std::make_shared(shape_vec); + return std::make_shared(elem_type, shape); +} + +// Feature: AbstractBroaden. +// Description: Check function of AbstractBroaden in utils.cc +// Expectation: Scalar can be successfully broadened. +TEST_F(TestUtils, CheckScalarBroaden) { + auto scalar_abs1 = std::make_shared(1); + scalar_abs1->set_is_variable(true); + auto scalar_abs1_broaden = scalar_abs1->Broaden(); + + auto scalar_abs2 = std::make_shared(2); + auto scalar_abs2_broaden = abstract::AbstractBroaden(scalar_abs2); + ASSERT_TRUE(*scalar_abs1_broaden == *scalar_abs2_broaden); +} + +// Feature: AbstractBroaden. +// Description: Check function of AbstractBroaden in utils.cc +// Expectation: Tensor in dynamic sequence can be successfully broadened. +TEST_F(TestUtils, CheckDynSequenceBroaden) { + // Test tensor as element abs + auto sequence_abs = std::make_shared(std::vector({})); + auto element_abs = MakeTensorAbstract({1, 2, 3}, kFloat32); + auto element_broaden = element_abs->Broaden(); + sequence_abs->set_dynamic_len(true); + sequence_abs->set_dynamic_len_element_abs(element_abs); + auto broadened_sequence_abs = abstract::AbstractBroaden(sequence_abs)->cast(); + ASSERT_TRUE(broadened_sequence_abs != nullptr); + auto equal = *element_broaden == *broadened_sequence_abs->dynamic_len_element_abs(); + ASSERT_TRUE(equal); + // Test scalar as element abs + sequence_abs = std::make_shared(std::vector({})); + element_abs = std::make_shared(1); + auto scalar_abs = std::make_shared(2); + auto scalar_broaden = abstract::AbstractBroaden(scalar_abs); + sequence_abs->set_dynamic_len(true); + sequence_abs->set_dynamic_len_element_abs(element_abs); + broadened_sequence_abs = abstract::AbstractBroaden(sequence_abs)->cast(); + ASSERT_TRUE(broadened_sequence_abs != nullptr); + equal = *scalar_broaden == *broadened_sequence_abs->dynamic_len_element_abs(); + ASSERT_TRUE(equal); +} + +// Feature: AbstractBroaden. +// Description: Check function of AbstractBroaden in utils.cc +// Expectation: Scalar in tuple can be successfully broadened. +TEST_F(TestUtils, CheckScalarInTupleBroaden) { + auto element_abs = std::make_shared(1); + auto tuple_abs = std::make_shared(std::vector({element_abs})); + auto broadened_tuple_abs = abstract::AbstractBroaden(tuple_abs)->cast(); + ASSERT_TRUE(broadened_tuple_abs != nullptr); + ASSERT_TRUE(broadened_tuple_abs->size() == 1); + + auto scalar_abs = std::make_shared(2); + auto broadened_scalar_abs = abstract::AbstractBroaden(scalar_abs); + + ASSERT_TRUE(*(broadened_tuple_abs->elements()[0]) == *broadened_scalar_abs); +} + +// Feature: AbstractBroaden. +// Description: Check function of AbstractBroaden in utils.cc +// Expectation: Scalar in tuple can be successfully broadened. +TEST_F(TestUtils, CheckScalarInListBroaden) { + auto element_abs = std::make_shared(1); + auto list_abs = std::make_shared(std::vector({element_abs})); + auto broadened_list_abs = abstract::AbstractBroaden(list_abs)->cast(); + ASSERT_TRUE(broadened_list_abs != nullptr); + ASSERT_TRUE(broadened_list_abs->size() == 1); + + auto scalar_abs = std::make_shared(2); + auto broadened_scalar_abs = abstract::AbstractBroaden(scalar_abs); + + ASSERT_TRUE(*(broadened_list_abs->elements()[0]) == *broadened_scalar_abs); +} } // namespace abstract } // namespace mindspore