forked from mindspore-Ecosystem/mindspore
broaden element sacalr in tuple
This commit is contained in:
parent
84b83d6834
commit
59523e32f2
|
@ -343,15 +343,8 @@ void BroadenArgs(const AbstractBasePtrList &args_abs_list, AbstractBasePtrList *
|
|||
MS_EXCEPTION_IF_NULL(broaded_args);
|
||||
(void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(*broaded_args),
|
||||
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
auto arg_type = arg->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(arg_type);
|
||||
if (arg->isa<AbstractScalar>() && arg_type->isa<Number>()) {
|
||||
MS_LOG(DEBUG) << "Set variable for scalar arg:" << arg->ToString();
|
||||
arg->cast_ptr<AbstractScalar>()->set_is_variable(true);
|
||||
}
|
||||
if (arg->GetValueTrack() != kAnyValue) {
|
||||
return arg->Broaden();
|
||||
return AbstractBroaden(arg);
|
||||
}
|
||||
return arg;
|
||||
});
|
||||
|
|
|
@ -833,8 +833,11 @@ void AbstractSequence::set_dynamic_len_element_abs(const AbstractBasePtr &dynami
|
|||
if (dynamic_len_element_abs == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (dynamic_len_element_abs->isa<abstract::AbstractDictionary>()) {
|
||||
MS_EXCEPTION(TypeError) << "DynamicSequence does not support dictionary type as element type now.";
|
||||
}
|
||||
// dynamic_len_element_abs should ignore value.
|
||||
dynamic_len_element_abs_ = BroadenAllValues(dynamic_len_element_abs);
|
||||
dynamic_len_element_abs_ = AbstractBroaden(dynamic_len_element_abs);
|
||||
}
|
||||
|
||||
bool AbstractSequence::operator==(const AbstractBase &other) const {
|
||||
|
|
|
@ -384,12 +384,10 @@ void CheckMutableArgAbstract(const AbstractBasePtr &abs) {
|
|||
}
|
||||
return;
|
||||
}
|
||||
if (abs->isa<AbstractScalar>()) {
|
||||
MS_LOG(DEBUG) << "Set scalar as variable, scalar abstract:" << abs->ToString();
|
||||
abs->cast_ptr<AbstractScalar>()->set_is_variable(true);
|
||||
if (abs->isa<AbstractTensor>()) {
|
||||
return;
|
||||
}
|
||||
if (abs->isa<AbstractTensor>()) {
|
||||
if (abs->isa<AbstractScalar>()) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION(TypeError)
|
||||
|
@ -416,7 +414,7 @@ AbstractBasePtr InferImplMutable(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
}
|
||||
if (!variable_len) {
|
||||
CheckMutableArgAbstract(args_spec_list[0]);
|
||||
return args_spec_list[0]->Broaden();
|
||||
return AbstractBroaden(args_spec_list[0]);
|
||||
}
|
||||
auto ret = args_spec_list[0]->Clone();
|
||||
if (!ret->isa<abstract::AbstractSequence>()) {
|
||||
|
|
|
@ -49,38 +49,6 @@ TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) {
|
|||
return kAnyType;
|
||||
}
|
||||
|
||||
// Broaden all values within the input abstract. Since for AbstractScalar, calling Broaden() can not set
|
||||
// the value directly to kAnyValue, this function can not be replaced by calling abs->Broaden().
|
||||
AbstractBasePtr BroadenAllValues(const AbstractBasePtr &abs) {
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
AbstractBasePtr ret = nullptr;
|
||||
if (abs->isa<abstract::AbstractDictionary>()) {
|
||||
MS_EXCEPTION(TypeError) << "BroadenAllValues does not support dictionary yet";
|
||||
}
|
||||
if (abs->isa<abstract::AbstractScalar>()) {
|
||||
ret = abs->Clone();
|
||||
ret->cast<abstract::AbstractScalarPtr>()->set_is_variable(true);
|
||||
return ret->Broaden();
|
||||
}
|
||||
if (abs->isa<abstract::AbstractSequence>()) {
|
||||
auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
|
||||
if (abs_seq->dynamic_len()) {
|
||||
// The value of elements for dynamic length sequence should all be kAnyValue.
|
||||
return abs->Clone();
|
||||
}
|
||||
AbstractBasePtrList elements = abs_seq->elements();
|
||||
AbstractBasePtrList new_elements;
|
||||
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(new_elements), BroadenAllValues);
|
||||
if (abs->isa<abstract::AbstractList>()) {
|
||||
ret = std::make_shared<abstract::AbstractList>(new_elements, abs_seq->sequence_nodes());
|
||||
} else {
|
||||
ret = std::make_shared<abstract::AbstractTuple>(new_elements, abs_seq->sequence_nodes());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
return abs->Broaden();
|
||||
}
|
||||
|
||||
bool IsShapesDynamicRank(const std::vector<ShapeVector> &shapes) {
|
||||
return std::any_of(shapes.begin(), shapes.end(), [](const ShapeVector &shape) {
|
||||
return std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim == Shape::kShapeRankAny; });
|
||||
|
@ -205,6 +173,40 @@ AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const Abstrac
|
|||
return joined_list;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractBroaden(const AbstractBasePtr &abs) {
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
if (abs->isa<AbstractSequence>() && !abs->isa<AbstractSparseTensor>()) {
|
||||
auto sequence_abs = abs->cast<AbstractSequencePtr>();
|
||||
if (sequence_abs->dynamic_len()) {
|
||||
auto elem_abs = sequence_abs->dynamic_len_element_abs();
|
||||
if (elem_abs != nullptr && elem_abs->isa<AbstractScalar>()) {
|
||||
elem_abs->cast<AbstractScalarPtr>()->set_is_variable(true);
|
||||
}
|
||||
return abs->Broaden();
|
||||
}
|
||||
std::vector<AbstractBasePtr> new_elements;
|
||||
new_elements.reserve(sequence_abs->elements().size());
|
||||
(void)std::transform(sequence_abs->elements().cbegin(), sequence_abs->elements().cend(),
|
||||
std::back_inserter(new_elements), AbstractBroaden);
|
||||
if (sequence_abs->isa<AbstractTuple>()) {
|
||||
return std::make_shared<AbstractTuple>(new_elements, sequence_abs->sequence_nodes());
|
||||
}
|
||||
if (sequence_abs->isa<AbstractList>()) {
|
||||
return std::make_shared<AbstractList>(new_elements, sequence_abs->sequence_nodes());
|
||||
}
|
||||
MS_EXCEPTION(TypeError) << "Unknown AbstractSequence type:" << abs->ToString();
|
||||
}
|
||||
if (abs->isa<AbstractScalar>()) {
|
||||
auto arg_type = abs->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(arg_type);
|
||||
auto abs_scalar = abs->cast<AbstractScalarPtr>();
|
||||
if (arg_type->isa<Number>()) {
|
||||
abs_scalar->set_is_variable(true);
|
||||
}
|
||||
}
|
||||
return abs->Broaden();
|
||||
}
|
||||
|
||||
AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
|
||||
auto f_spec = dyn_cast_ptr<AbstractFunction>(spec);
|
||||
if (f_spec != nullptr) {
|
||||
|
|
|
@ -35,10 +35,10 @@ namespace abstract {
|
|||
ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2);
|
||||
MS_CORE_API TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2);
|
||||
ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2);
|
||||
AbstractBasePtr BroadenAllValues(const AbstractBasePtr &abs);
|
||||
|
||||
MS_CORE_API AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list);
|
||||
MS_CORE_API AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2);
|
||||
MS_CORE_API AbstractBasePtr AbstractBroaden(const AbstractBasePtr &abs);
|
||||
|
||||
// Return an abstract value for the sensitivity of x.
|
||||
// The sensitivity of a function is an Env
|
||||
|
|
|
@ -50,8 +50,8 @@ if(ENABLE_MINDDATA)
|
|||
file(GLOB_RECURSE UT_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
./stub/*.cc
|
||||
./common/*.cc
|
||||
./core/abstract/*.cc
|
||||
./core/utils/*.cc
|
||||
./abstract/*.cc
|
||||
./base/*.cc
|
||||
./dataset/*.cc
|
||||
./ir/dtype/*.cc
|
||||
|
|
|
@ -74,6 +74,5 @@ TEST_F(TestDShape, Clone) {
|
|||
ASSERT_EQ(*shp_noshp_1.Clone(), shp_noshp_1);
|
||||
ASSERT_EQ(*shp_tuple_2.Clone(), shp_tuple_2);
|
||||
}
|
||||
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
|
@ -103,5 +103,81 @@ TEST_F(TestUtils, TestShapeJoin) {
|
|||
ShapeJoinCheck({-2}, {-2}, {-2});
|
||||
}
|
||||
|
||||
AbstractBasePtr MakeTensorAbstract(const ShapeVector &shape_vec, const TypePtr &elem_type) {
|
||||
auto shape = std::make_shared<abstract::Shape>(shape_vec);
|
||||
return std::make_shared<abstract::AbstractTensor>(elem_type, shape);
|
||||
}
|
||||
|
||||
// Feature: AbstractBroaden.
|
||||
// Description: Check function of AbstractBroaden in utils.cc
|
||||
// Expectation: Scalar can be successfully broadened.
|
||||
TEST_F(TestUtils, CheckScalarBroaden) {
|
||||
auto scalar_abs1 = std::make_shared<abstract::AbstractScalar>(1);
|
||||
scalar_abs1->set_is_variable(true);
|
||||
auto scalar_abs1_broaden = scalar_abs1->Broaden();
|
||||
|
||||
auto scalar_abs2 = std::make_shared<abstract::AbstractScalar>(2);
|
||||
auto scalar_abs2_broaden = abstract::AbstractBroaden(scalar_abs2);
|
||||
ASSERT_TRUE(*scalar_abs1_broaden == *scalar_abs2_broaden);
|
||||
}
|
||||
|
||||
// Feature: AbstractBroaden.
|
||||
// Description: Check function of AbstractBroaden in utils.cc
|
||||
// Expectation: Tensor in dynamic sequence can be successfully broadened.
|
||||
TEST_F(TestUtils, CheckDynSequenceBroaden) {
|
||||
// Test tensor as element abs
|
||||
auto sequence_abs = std::make_shared<abstract::AbstractTuple>(std::vector<AbstractBasePtr>({}));
|
||||
auto element_abs = MakeTensorAbstract({1, 2, 3}, kFloat32);
|
||||
auto element_broaden = element_abs->Broaden();
|
||||
sequence_abs->set_dynamic_len(true);
|
||||
sequence_abs->set_dynamic_len_element_abs(element_abs);
|
||||
auto broadened_sequence_abs = abstract::AbstractBroaden(sequence_abs)->cast<abstract::AbstractSequencePtr>();
|
||||
ASSERT_TRUE(broadened_sequence_abs != nullptr);
|
||||
auto equal = *element_broaden == *broadened_sequence_abs->dynamic_len_element_abs();
|
||||
ASSERT_TRUE(equal);
|
||||
// Test scalar as element abs
|
||||
sequence_abs = std::make_shared<abstract::AbstractTuple>(std::vector<AbstractBasePtr>({}));
|
||||
element_abs = std::make_shared<abstract::AbstractScalar>(1);
|
||||
auto scalar_abs = std::make_shared<abstract::AbstractScalar>(2);
|
||||
auto scalar_broaden = abstract::AbstractBroaden(scalar_abs);
|
||||
sequence_abs->set_dynamic_len(true);
|
||||
sequence_abs->set_dynamic_len_element_abs(element_abs);
|
||||
broadened_sequence_abs = abstract::AbstractBroaden(sequence_abs)->cast<abstract::AbstractSequencePtr>();
|
||||
ASSERT_TRUE(broadened_sequence_abs != nullptr);
|
||||
equal = *scalar_broaden == *broadened_sequence_abs->dynamic_len_element_abs();
|
||||
ASSERT_TRUE(equal);
|
||||
}
|
||||
|
||||
// Feature: AbstractBroaden.
|
||||
// Description: Check function of AbstractBroaden in utils.cc
|
||||
// Expectation: Scalar in tuple can be successfully broadened.
|
||||
TEST_F(TestUtils, CheckScalarInTupleBroaden) {
|
||||
auto element_abs = std::make_shared<abstract::AbstractScalar>(1);
|
||||
auto tuple_abs = std::make_shared<abstract::AbstractTuple>(std::vector<AbstractBasePtr>({element_abs}));
|
||||
auto broadened_tuple_abs = abstract::AbstractBroaden(tuple_abs)->cast<abstract::AbstractTuplePtr>();
|
||||
ASSERT_TRUE(broadened_tuple_abs != nullptr);
|
||||
ASSERT_TRUE(broadened_tuple_abs->size() == 1);
|
||||
|
||||
auto scalar_abs = std::make_shared<abstract::AbstractScalar>(2);
|
||||
auto broadened_scalar_abs = abstract::AbstractBroaden(scalar_abs);
|
||||
|
||||
ASSERT_TRUE(*(broadened_tuple_abs->elements()[0]) == *broadened_scalar_abs);
|
||||
}
|
||||
|
||||
// Feature: AbstractBroaden.
|
||||
// Description: Check function of AbstractBroaden in utils.cc
|
||||
// Expectation: Scalar in tuple can be successfully broadened.
|
||||
TEST_F(TestUtils, CheckScalarInListBroaden) {
|
||||
auto element_abs = std::make_shared<abstract::AbstractScalar>(1);
|
||||
auto list_abs = std::make_shared<abstract::AbstractList>(std::vector<AbstractBasePtr>({element_abs}));
|
||||
auto broadened_list_abs = abstract::AbstractBroaden(list_abs)->cast<abstract::AbstractListPtr>();
|
||||
ASSERT_TRUE(broadened_list_abs != nullptr);
|
||||
ASSERT_TRUE(broadened_list_abs->size() == 1);
|
||||
|
||||
auto scalar_abs = std::make_shared<abstract::AbstractScalar>(2);
|
||||
auto broadened_scalar_abs = abstract::AbstractBroaden(scalar_abs);
|
||||
|
||||
ASSERT_TRUE(*(broadened_list_abs->elements()[0]) == *broadened_scalar_abs);
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue