broaden element sacalr in tuple

This commit is contained in:
chenfei 2022-12-20 14:55:30 +08:00
parent 84b83d6834
commit 59523e32f2
9 changed files with 120 additions and 49 deletions

View File

@ -343,15 +343,8 @@ void BroadenArgs(const AbstractBasePtrList &args_abs_list, AbstractBasePtrList *
MS_EXCEPTION_IF_NULL(broaded_args);
(void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(*broaded_args),
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg);
auto arg_type = arg->BuildType();
MS_EXCEPTION_IF_NULL(arg_type);
if (arg->isa<AbstractScalar>() && arg_type->isa<Number>()) {
MS_LOG(DEBUG) << "Set variable for scalar arg:" << arg->ToString();
arg->cast_ptr<AbstractScalar>()->set_is_variable(true);
}
if (arg->GetValueTrack() != kAnyValue) {
return arg->Broaden();
return AbstractBroaden(arg);
}
return arg;
});

View File

@ -833,8 +833,11 @@ void AbstractSequence::set_dynamic_len_element_abs(const AbstractBasePtr &dynami
if (dynamic_len_element_abs == nullptr) {
return;
}
if (dynamic_len_element_abs->isa<abstract::AbstractDictionary>()) {
MS_EXCEPTION(TypeError) << "DynamicSequence does not support dictionary type as element type now.";
}
// dynamic_len_element_abs should ignore value.
dynamic_len_element_abs_ = BroadenAllValues(dynamic_len_element_abs);
dynamic_len_element_abs_ = AbstractBroaden(dynamic_len_element_abs);
}
bool AbstractSequence::operator==(const AbstractBase &other) const {

View File

@ -384,12 +384,10 @@ void CheckMutableArgAbstract(const AbstractBasePtr &abs) {
}
return;
}
if (abs->isa<AbstractScalar>()) {
MS_LOG(DEBUG) << "Set scalar as variable, scalar abstract:" << abs->ToString();
abs->cast_ptr<AbstractScalar>()->set_is_variable(true);
if (abs->isa<AbstractTensor>()) {
return;
}
if (abs->isa<AbstractTensor>()) {
if (abs->isa<AbstractScalar>()) {
return;
}
MS_EXCEPTION(TypeError)
@ -416,7 +414,7 @@ AbstractBasePtr InferImplMutable(const AnalysisEnginePtr &, const PrimitivePtr &
}
if (!variable_len) {
CheckMutableArgAbstract(args_spec_list[0]);
return args_spec_list[0]->Broaden();
return AbstractBroaden(args_spec_list[0]);
}
auto ret = args_spec_list[0]->Clone();
if (!ret->isa<abstract::AbstractSequence>()) {

View File

@ -49,38 +49,6 @@ TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) {
return kAnyType;
}
// Broaden all values within the input abstract. Since for AbstractScalar, calling Broaden() can not set
// the value directly to kAnyValue, this function can not be replaced by calling abs->Broaden().
AbstractBasePtr BroadenAllValues(const AbstractBasePtr &abs) {
MS_EXCEPTION_IF_NULL(abs);
AbstractBasePtr ret = nullptr;
if (abs->isa<abstract::AbstractDictionary>()) {
MS_EXCEPTION(TypeError) << "BroadenAllValues does not support dictionary yet";
}
if (abs->isa<abstract::AbstractScalar>()) {
ret = abs->Clone();
ret->cast<abstract::AbstractScalarPtr>()->set_is_variable(true);
return ret->Broaden();
}
if (abs->isa<abstract::AbstractSequence>()) {
auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
if (abs_seq->dynamic_len()) {
// The value of elements for dynamic length sequence should all be kAnyValue.
return abs->Clone();
}
AbstractBasePtrList elements = abs_seq->elements();
AbstractBasePtrList new_elements;
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(new_elements), BroadenAllValues);
if (abs->isa<abstract::AbstractList>()) {
ret = std::make_shared<abstract::AbstractList>(new_elements, abs_seq->sequence_nodes());
} else {
ret = std::make_shared<abstract::AbstractTuple>(new_elements, abs_seq->sequence_nodes());
}
return ret;
}
return abs->Broaden();
}
bool IsShapesDynamicRank(const std::vector<ShapeVector> &shapes) {
return std::any_of(shapes.begin(), shapes.end(), [](const ShapeVector &shape) {
return std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim == Shape::kShapeRankAny; });
@ -205,6 +173,40 @@ AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const Abstrac
return joined_list;
}
AbstractBasePtr AbstractBroaden(const AbstractBasePtr &abs) {
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<AbstractSequence>() && !abs->isa<AbstractSparseTensor>()) {
auto sequence_abs = abs->cast<AbstractSequencePtr>();
if (sequence_abs->dynamic_len()) {
auto elem_abs = sequence_abs->dynamic_len_element_abs();
if (elem_abs != nullptr && elem_abs->isa<AbstractScalar>()) {
elem_abs->cast<AbstractScalarPtr>()->set_is_variable(true);
}
return abs->Broaden();
}
std::vector<AbstractBasePtr> new_elements;
new_elements.reserve(sequence_abs->elements().size());
(void)std::transform(sequence_abs->elements().cbegin(), sequence_abs->elements().cend(),
std::back_inserter(new_elements), AbstractBroaden);
if (sequence_abs->isa<AbstractTuple>()) {
return std::make_shared<AbstractTuple>(new_elements, sequence_abs->sequence_nodes());
}
if (sequence_abs->isa<AbstractList>()) {
return std::make_shared<AbstractList>(new_elements, sequence_abs->sequence_nodes());
}
MS_EXCEPTION(TypeError) << "Unknown AbstractSequence type:" << abs->ToString();
}
if (abs->isa<AbstractScalar>()) {
auto arg_type = abs->BuildType();
MS_EXCEPTION_IF_NULL(arg_type);
auto abs_scalar = abs->cast<AbstractScalarPtr>();
if (arg_type->isa<Number>()) {
abs_scalar->set_is_variable(true);
}
}
return abs->Broaden();
}
AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
auto f_spec = dyn_cast_ptr<AbstractFunction>(spec);
if (f_spec != nullptr) {

View File

@ -35,10 +35,10 @@ namespace abstract {
ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2);
MS_CORE_API TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2);
ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2);
AbstractBasePtr BroadenAllValues(const AbstractBasePtr &abs);
MS_CORE_API AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list);
MS_CORE_API AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2);
MS_CORE_API AbstractBasePtr AbstractBroaden(const AbstractBasePtr &abs);
// Return an abstract value for the sensitivity of x.
// The sensitivity of a function is an Env

View File

@ -50,8 +50,8 @@ if(ENABLE_MINDDATA)
file(GLOB_RECURSE UT_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
./stub/*.cc
./common/*.cc
./core/abstract/*.cc
./core/utils/*.cc
./abstract/*.cc
./base/*.cc
./dataset/*.cc
./ir/dtype/*.cc

View File

@ -74,6 +74,5 @@ TEST_F(TestDShape, Clone) {
ASSERT_EQ(*shp_noshp_1.Clone(), shp_noshp_1);
ASSERT_EQ(*shp_tuple_2.Clone(), shp_tuple_2);
}
} // namespace abstract
} // namespace mindspore

View File

@ -103,5 +103,81 @@ TEST_F(TestUtils, TestShapeJoin) {
ShapeJoinCheck({-2}, {-2}, {-2});
}
AbstractBasePtr MakeTensorAbstract(const ShapeVector &shape_vec, const TypePtr &elem_type) {
auto shape = std::make_shared<abstract::Shape>(shape_vec);
return std::make_shared<abstract::AbstractTensor>(elem_type, shape);
}
// Feature: AbstractBroaden.
// Description: Check function of AbstractBroaden in utils.cc
// Expectation: Scalar can be successfully broadened.
TEST_F(TestUtils, CheckScalarBroaden) {
auto scalar_abs1 = std::make_shared<abstract::AbstractScalar>(1);
scalar_abs1->set_is_variable(true);
auto scalar_abs1_broaden = scalar_abs1->Broaden();
auto scalar_abs2 = std::make_shared<abstract::AbstractScalar>(2);
auto scalar_abs2_broaden = abstract::AbstractBroaden(scalar_abs2);
ASSERT_TRUE(*scalar_abs1_broaden == *scalar_abs2_broaden);
}
// Feature: AbstractBroaden.
// Description: Check function of AbstractBroaden in utils.cc
// Expectation: Tensor in dynamic sequence can be successfully broadened.
TEST_F(TestUtils, CheckDynSequenceBroaden) {
// Test tensor as element abs
auto sequence_abs = std::make_shared<abstract::AbstractTuple>(std::vector<AbstractBasePtr>({}));
auto element_abs = MakeTensorAbstract({1, 2, 3}, kFloat32);
auto element_broaden = element_abs->Broaden();
sequence_abs->set_dynamic_len(true);
sequence_abs->set_dynamic_len_element_abs(element_abs);
auto broadened_sequence_abs = abstract::AbstractBroaden(sequence_abs)->cast<abstract::AbstractSequencePtr>();
ASSERT_TRUE(broadened_sequence_abs != nullptr);
auto equal = *element_broaden == *broadened_sequence_abs->dynamic_len_element_abs();
ASSERT_TRUE(equal);
// Test scalar as element abs
sequence_abs = std::make_shared<abstract::AbstractTuple>(std::vector<AbstractBasePtr>({}));
element_abs = std::make_shared<abstract::AbstractScalar>(1);
auto scalar_abs = std::make_shared<abstract::AbstractScalar>(2);
auto scalar_broaden = abstract::AbstractBroaden(scalar_abs);
sequence_abs->set_dynamic_len(true);
sequence_abs->set_dynamic_len_element_abs(element_abs);
broadened_sequence_abs = abstract::AbstractBroaden(sequence_abs)->cast<abstract::AbstractSequencePtr>();
ASSERT_TRUE(broadened_sequence_abs != nullptr);
equal = *scalar_broaden == *broadened_sequence_abs->dynamic_len_element_abs();
ASSERT_TRUE(equal);
}
// Feature: AbstractBroaden.
// Description: Check function of AbstractBroaden in utils.cc
// Expectation: Scalar in tuple can be successfully broadened.
TEST_F(TestUtils, CheckScalarInTupleBroaden) {
auto element_abs = std::make_shared<abstract::AbstractScalar>(1);
auto tuple_abs = std::make_shared<abstract::AbstractTuple>(std::vector<AbstractBasePtr>({element_abs}));
auto broadened_tuple_abs = abstract::AbstractBroaden(tuple_abs)->cast<abstract::AbstractTuplePtr>();
ASSERT_TRUE(broadened_tuple_abs != nullptr);
ASSERT_TRUE(broadened_tuple_abs->size() == 1);
auto scalar_abs = std::make_shared<abstract::AbstractScalar>(2);
auto broadened_scalar_abs = abstract::AbstractBroaden(scalar_abs);
ASSERT_TRUE(*(broadened_tuple_abs->elements()[0]) == *broadened_scalar_abs);
}
// Feature: AbstractBroaden.
// Description: Check function of AbstractBroaden in utils.cc
// Expectation: Scalar in tuple can be successfully broadened.
TEST_F(TestUtils, CheckScalarInListBroaden) {
auto element_abs = std::make_shared<abstract::AbstractScalar>(1);
auto list_abs = std::make_shared<abstract::AbstractList>(std::vector<AbstractBasePtr>({element_abs}));
auto broadened_list_abs = abstract::AbstractBroaden(list_abs)->cast<abstract::AbstractListPtr>();
ASSERT_TRUE(broadened_list_abs != nullptr);
ASSERT_TRUE(broadened_list_abs->size() == 1);
auto scalar_abs = std::make_shared<abstract::AbstractScalar>(2);
auto broadened_scalar_abs = abstract::AbstractBroaden(scalar_abs);
ASSERT_TRUE(*(broadened_list_abs->elements()[0]) == *broadened_scalar_abs);
}
} // namespace abstract
} // namespace mindspore