From 2c24157ccbbc9035294259ede774d0793731ac1a Mon Sep 17 00:00:00 2001
From: He Wei <hewei66@huawei.com>
Date: Mon, 13 Jun 2022 17:07:38 +0800
Subject: [PATCH] Add TensorData::has_sub_data() method

---
 mindspore/ccsrc/pybind_api/ir/tensor_py.cc |  2 +
 mindspore/core/ir/tensor.cc                | 67 ++++++++++++++++++++--
 mindspore/core/ir/tensor.h                 | 11 ++++
 tests/ut/cpp/ir/anf_test.cc                | 20 +++++++
 4 files changed, 96 insertions(+), 4 deletions(-)

diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc
index 17176440139..195a78feaf0 100644
--- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc
+++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc
@@ -171,6 +171,8 @@ class TensorDataNumpy : public TensorData {
 
   bool is_sub_data() const override { return false; }
 
+  bool has_sub_data() const override { return false; }
+
   /// To string.
   std::string ToString(const TypeId, const ShapeVector &, bool use_comma) const override {
     if (use_comma) {
diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc
index ad4e6aee7e3..9aa50247c67 100644
--- a/mindspore/core/ir/tensor.cc
+++ b/mindspore/core/ir/tensor.cc
@@ -386,7 +386,7 @@ template <typename T>
 class TensorDataImpl : public TensorData {
  public:
   explicit TensorDataImpl(const ShapeVector &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
-  ~TensorDataImpl() = default;
+  ~TensorDataImpl() override = default;
 
   TensorDataImpl(const ShapeVector &shape, void *data, size_t data_len)
       : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData(shape, data, data_len)) {}
@@ -412,6 +412,8 @@ class TensorDataImpl : public TensorData {
 
   bool is_sub_data() const override { return false; }
 
+  bool has_sub_data() const override { return false; }
+
   void *data() override {
     if (data_ == nullptr) {
       if (data_size_ > INT32_MAX) {
@@ -456,6 +458,17 @@ class TensorDataImpl : public TensorData {
   std::unique_ptr<T[]> data_;
 };
 
+// Tensor chunk data.
+template <typename T>
+class TensorChunkData : public TensorDataImpl<T> {
+ public:
+  explicit TensorChunkData(size_t size) : TensorDataImpl<T>(ShapeVector{static_cast<int64_t>(size)}) {}
+
+  ~TensorChunkData() override = default;
+
+  bool has_sub_data() const override { return true; }
+};
+
 // TensorSubData is the base class to provide tensor data as a segment from an owner tensor data.
 class TensorSubData : public TensorData {
  public:
@@ -472,6 +485,8 @@ class TensorSubData : public TensorData {
 
   bool is_sub_data() const override { return true; }
 
+  bool has_sub_data() const override { return false; }
+
   void *data() override {
     // Set data initialized if data() is called.
     data_initialized_ = true;
@@ -562,6 +577,49 @@ TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const A
   MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
 }
 
+template <typename... Args>
+TensorDataPtr MakeChunkData(TypeId data_type, size_t size) {
+  switch (data_type) {
+    case kNumberTypeBool:
+      return std::make_shared<TensorChunkData<bool>>(size);
+    case kNumberTypeUInt8:
+      return std::make_shared<TensorChunkData<uint8_t>>(size);
+    case kNumberTypeInt8:
+      return std::make_shared<TensorChunkData<int8_t>>(size);
+    case kNumberTypeInt16:
+      return std::make_shared<TensorChunkData<int16_t>>(size);
+    case kNumberTypeInt32:
+      return std::make_shared<TensorChunkData<int32_t>>(size);
+    case kNumberTypeInt64:
+      return std::make_shared<TensorChunkData<int64_t>>(size);
+    case kNumberTypeUInt16:
+      return std::make_shared<TensorChunkData<uint16_t>>(size);
+    case kNumberTypeUInt32:
+      return std::make_shared<TensorChunkData<uint32_t>>(size);
+    case kNumberTypeUInt64:
+      return std::make_shared<TensorChunkData<uint64_t>>(size);
+    case kNumberTypeFloat16:
+      return std::make_shared<TensorChunkData<float16>>(size);
+    case kNumberTypeFloat:
+      return std::make_shared<TensorChunkData<float>>(size);
+    case kNumberTypeFloat32:
+      return std::make_shared<TensorChunkData<float>>(size);
+    case kNumberTypeFloat64:
+      return std::make_shared<TensorChunkData<double>>(size);
+    case kNumberTypeComplex64:
+      return std::make_shared<TensorChunkData<ComplexStorage<float>>>(size);
+    case kNumberTypeComplex128:
+      return std::make_shared<TensorChunkData<ComplexStorage<double>>>(size);
+    case kObjectTypeString:
+      return std::make_shared<TensorChunkData<uint8_t>>(size);
+    case kObjectTypeTensorType:
+      return std::make_shared<TensorChunkData<int64_t>>(size);
+    default:
+      break;
+  }
+  MS_LOG(EXCEPTION) << "Cannot construct chunk data because of unsupported data type: " << data_type << ".";
+}
+
 template <typename T>
 TensorDataPtr MakeSubData(const TensorPtr &owner, size_t offset, const TensorDataPtr &data) {
   const size_t data_bytes = data->nbytes();
@@ -704,6 +762,9 @@ Tensor::Tensor(bool input, const TypePtr &data_type)
       data_(MakeTensorData(data_type_, {}, input)),
       id_(MakeId()) {}
 
+Tensor::Tensor(TypeId data_type, size_t data_size)
+    : Tensor(data_type, ShapeVector{static_cast<int64_t>(data_size)}, MakeChunkData(data_type, data_size)) {}
+
 bool Tensor::operator==(const Tensor &tensor) const {
   return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_));
 }
@@ -901,11 +962,9 @@ TensorPtrList Tensor::FlattenTensors(const TensorPtrList &tensors, size_t fusion
   for (auto &type_group : group_info) {
     auto chunk_dtype = normalize_type(type_group.first);
     for (auto &chunk : type_group.second) {
-      // Chunk tensor is always 1 rank.
-      ShapeVector shape{static_cast<int64_t>(chunk.size)};
       // Create chunk thensor as a lazy initialized tensor, the tensor data
       // will be allocated when we begin to copy small tensors data into it.
-      auto chunk_tensor = std::make_shared<Tensor>(chunk_dtype, shape);
+      auto chunk_tensor = std::make_shared<Tensor>(chunk_dtype, chunk.size);
       // Reset and copy tensors data.
       size_t offset = 0;
       for (auto &tensor : chunk.tensors) {
diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h
index bbb600324f6..2bb3a33844b 100644
--- a/mindspore/core/ir/tensor.h
+++ b/mindspore/core/ir/tensor.h
@@ -90,6 +90,11 @@ class MS_CORE_API TensorData {
   /// \return Whether this tensor data is sub data.
   virtual bool is_sub_data() const = 0;
 
+  /// \brief Check whether this tensor data has sub data.
+  ///
+  /// \return True if this tensor data has sub data, otherwise false.
+  virtual bool has_sub_data() const = 0;
+
   /// \brief Whether the data are equal.
   ///
   /// \param[in] other Another TensorData.
@@ -238,6 +243,12 @@ class MS_CORE_API Tensor final : public MetaTensor {
   /// \param[in] data_type [TypeId] data type.
   explicit Tensor(bool input, const TypePtr &data_type = nullptr);
 
+  /// \brief Create a chunk tensor with the given data size.
+  ///
+  /// \param[in] data_type [TypeId] Data type of the tensor.
+  /// \param[in] data_size The tensor chunk data size in number of elements.
+  Tensor(TypeId data_type, size_t data_size);
+
   /// Destructor of Tensor.
   ~Tensor() override = default;
 
diff --git a/tests/ut/cpp/ir/anf_test.cc b/tests/ut/cpp/ir/anf_test.cc
index abf2406e695..cef4eb62f5f 100644
--- a/tests/ut/cpp/ir/anf_test.cc
+++ b/tests/ut/cpp/ir/anf_test.cc
@@ -135,4 +135,24 @@ TEST_F(TestAnf, test_FlatParameterFinder) {
   assert(flat_param6 == nullptr);
   assert(offset6 == 0);
 }
+
+/// Feature: Flatten tensor
+/// Description: Test is_sub_data() & has_sub_data() api
+/// Expectation: API works as expected.
+TEST_F(TestAnf, test_TensorWithSubData) {
+  auto t1 = std::make_shared<Tensor>(0.1f);
+  auto t2 = std::make_shared<Tensor>(0.2f);
+  auto t3 = std::make_shared<Tensor>(0.3f);
+  auto t4 = std::make_shared<Tensor>(0.4f);
+  assert(!t1->data().is_sub_data());
+  assert(!t1->data().has_sub_data());
+  auto flat_tensors = Tensor::FlattenTensors(TensorPtrList{t1, t2, t3});
+  assert(flat_tensors.size() == 1);
+  assert(!flat_tensors[0]->data().is_sub_data());
+  assert(flat_tensors[0]->data().has_sub_data());
+  assert(t1->data().is_sub_data());
+  assert(!t1->data().has_sub_data());
+  assert(t2->data().is_sub_data());
+  assert(!t2->data().has_sub_data());
+}
 }  // namespace mindspore