Fix tensor with dynamic shape problem

Set data size to zero for tensor with dynamic shape,
prevent unexpected memory allocation for tensor data.
This commit is contained in:
He Wei 2022-02-14 17:40:05 +08:00
parent c45e8b5c9b
commit 36ab9c3d4b
2 changed files with 42 additions and 1 deletions

View File

@ -44,7 +44,15 @@ static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) {
}
static size_t SizeOf(const ShapeVector &shape) {
return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
int64_t data_size = 1;
for (auto dim : shape) {
if (dim < 0) {
// For dynamic shape which has negative dimensions, data size should be zero.
return 0;
}
data_size *= dim;
}
return static_cast<size_t>(data_size);
}
static std::string ShapeToString(const ShapeVector &shape) {

View File

@ -347,6 +347,39 @@ TEST_F(TestMindApi, test_tensor_api) {
ASSERT_EQ(tensor_type->cast<TensorTypePtr>()->element()->type_id(), kNumberTypeFloat32);
}
/// Feature: MindAPI
/// Description: test Tensor with dynamic shape.
/// Expectation: Tensor API work as expected.
TEST_F(TestMindApi, test_tensor_with_dyn_shape) {
ShapeVector shape{1, 2, -1, -2};
auto tensor = MakeShared<Tensor>(kNumberTypeFloat32, shape);
ASSERT_EQ(tensor->data_type(), kNumberTypeFloat32);
ASSERT_EQ(tensor->shape(), shape);
ASSERT_EQ(tensor->DataSize(), 0);
ASSERT_EQ(tensor->Size(), 0);
ShapeVector shape2{2, 3};
tensor->set_data_type(kNumberTypeInt32);
tensor->set_shape(shape2);
ASSERT_EQ(tensor->data_type(), kNumberTypeInt32);
ASSERT_EQ(tensor->shape(), shape2);
ShapeVector shape3{1, -1, 3};
auto tensor2 = MakeShared<Tensor>(kNumberTypeFloat32, shape);
ASSERT_EQ(tensor2->data_type(), kNumberTypeFloat32);
ASSERT_EQ(tensor2->shape(), shape);
ASSERT_EQ(tensor2->DataSize(), 0);
ASSERT_EQ(tensor2->Size(), 0);
ShapeVector shape4{3, 4};
tensor2->set_data_type(kNumberTypeInt32);
tensor2->set_shape(shape4);
ASSERT_EQ(tensor2->data_type(), kNumberTypeInt32);
ASSERT_EQ(tensor2->shape(), shape4);
}
/// Feature: MindAPI
/// Description: test utils API.
/// Expectation: Tensor API work as expected.