!30027 Fix tensor with dynamic shape problem

Merge pull request !30027 from hewei/fix_core
This commit is contained in:
i-robot 2022-02-15 06:26:51 +00:00 committed by Gitee
commit e57633ec7f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 42 additions and 1 deletions

View File

@ -44,7 +44,15 @@ static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) {
} }
static size_t SizeOf(const ShapeVector &shape) { static size_t SizeOf(const ShapeVector &shape) {
return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>()); int64_t data_size = 1;
for (auto dim : shape) {
if (dim < 0) {
// For dynamic shape which has negative dimensions, data size should be zero.
return 0;
}
data_size *= dim;
}
return static_cast<size_t>(data_size);
} }
static std::string ShapeToString(const ShapeVector &shape) { static std::string ShapeToString(const ShapeVector &shape) {

View File

@ -363,6 +363,39 @@ TEST_F(TestMindApi, test_tensor_api) {
ASSERT_EQ(tensor_type->cast<TensorTypePtr>()->element()->type_id(), kNumberTypeFloat32); ASSERT_EQ(tensor_type->cast<TensorTypePtr>()->element()->type_id(), kNumberTypeFloat32);
} }
/// Feature: MindAPI
/// Description: test Tensor with dynamic shape.
/// Expectation: Tensor API work as expected.
TEST_F(TestMindApi, test_tensor_with_dyn_shape) {
ShapeVector shape{1, 2, -1, -2};
auto tensor = MakeShared<Tensor>(kNumberTypeFloat32, shape);
ASSERT_EQ(tensor->data_type(), kNumberTypeFloat32);
ASSERT_EQ(tensor->shape(), shape);
ASSERT_EQ(tensor->DataSize(), 0);
ASSERT_EQ(tensor->Size(), 0);
ShapeVector shape2{2, 3};
tensor->set_data_type(kNumberTypeInt32);
tensor->set_shape(shape2);
ASSERT_EQ(tensor->data_type(), kNumberTypeInt32);
ASSERT_EQ(tensor->shape(), shape2);
ShapeVector shape3{1, -1, 3};
auto tensor2 = MakeShared<Tensor>(kNumberTypeFloat32, shape);
ASSERT_EQ(tensor2->data_type(), kNumberTypeFloat32);
ASSERT_EQ(tensor2->shape(), shape);
ASSERT_EQ(tensor2->DataSize(), 0);
ASSERT_EQ(tensor2->Size(), 0);
ShapeVector shape4{3, 4};
tensor2->set_data_type(kNumberTypeInt32);
tensor2->set_shape(shape4);
ASSERT_EQ(tensor2->data_type(), kNumberTypeInt32);
ASSERT_EQ(tensor2->shape(), shape4);
}
/// Feature: MindAPI /// Feature: MindAPI
/// Description: test utils API. /// Description: test utils API.
/// Expectation: Tensor API work as expected. /// Expectation: Tensor API work as expected.