Add TensorData::has_sub_data() method

This commit is contained in:
He Wei 2022-06-13 17:07:38 +08:00
parent 32a19e6623
commit 2c24157ccb
4 changed files with 96 additions and 4 deletions

View File

@ -171,6 +171,8 @@ class TensorDataNumpy : public TensorData {
bool is_sub_data() const override { return false; }
bool has_sub_data() const override { return false; }
/// To string.
std::string ToString(const TypeId, const ShapeVector &, bool use_comma) const override {
if (use_comma) {

View File

@ -386,7 +386,7 @@ template <typename T>
class TensorDataImpl : public TensorData {
public:
explicit TensorDataImpl(const ShapeVector &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
~TensorDataImpl() = default;
~TensorDataImpl() override = default;
TensorDataImpl(const ShapeVector &shape, void *data, size_t data_len)
: ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_len)) {}
@ -412,6 +412,8 @@ class TensorDataImpl : public TensorData {
bool is_sub_data() const override { return false; }
bool has_sub_data() const override { return false; }
void *data() override {
if (data_ == nullptr) {
if (data_size_ > INT32_MAX) {
@ -456,6 +458,17 @@ class TensorDataImpl : public TensorData {
std::unique_ptr<T[]> data_;
};
// Tensor chunk data.
//
// A flat (rank-1) buffer of T elements used as the owner storage when several
// small tensors are fused into one chunk (see Tensor::FlattenTensors); the
// fused tensors then reference segments of it through TensorSubData.
template <typename T>
class TensorChunkData : public TensorDataImpl<T> {
public:
// Construct a chunk holding 'size' elements. Allocation is deferred to the
// base class: TensorDataImpl allocates lazily on first data() access.
explicit TensorChunkData(size_t size) : TensorDataImpl<T>(ShapeVector{static_cast<int64_t>(size)}) {}
~TensorChunkData() override = default;
// A chunk owns sub data segments, so it is the one TensorData that reports true.
bool has_sub_data() const override { return true; }
};
// TensorSubData is the base class to provide tensor data as a segment from an owner tensor data.
class TensorSubData : public TensorData {
public:
@ -472,6 +485,8 @@ class TensorSubData : public TensorData {
bool is_sub_data() const override { return true; }
bool has_sub_data() const override { return false; }
void *data() override {
// Set data initialized if data() is called.
data_initialized_ = true;
@ -562,6 +577,49 @@ TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const A
MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
}
template <typename... Args>
// Make a TensorChunkData of 'size' elements for the given data type.
//
// Creates the owner buffer for a chunk tensor (used by the
// Tensor(TypeId, size_t) constructor). Each supported TypeId maps to the
// matching TensorChunkData element type; an unsupported TypeId raises
// an exception via MS_LOG(EXCEPTION).
TensorDataPtr MakeChunkData(TypeId data_type, size_t size) {
switch (data_type) {
case kNumberTypeBool:
return std::make_shared<TensorChunkData<bool>>(size);
case kNumberTypeUInt8:
return std::make_shared<TensorChunkData<uint8_t>>(size);
case kNumberTypeInt8:
return std::make_shared<TensorChunkData<int8_t>>(size);
case kNumberTypeInt16:
return std::make_shared<TensorChunkData<int16_t>>(size);
case kNumberTypeInt32:
return std::make_shared<TensorChunkData<int32_t>>(size);
case kNumberTypeInt64:
return std::make_shared<TensorChunkData<int64_t>>(size);
case kNumberTypeUInt16:
return std::make_shared<TensorChunkData<uint16_t>>(size);
case kNumberTypeUInt32:
return std::make_shared<TensorChunkData<uint32_t>>(size);
case kNumberTypeUInt64:
return std::make_shared<TensorChunkData<uint64_t>>(size);
case kNumberTypeFloat16:
return std::make_shared<TensorChunkData<float16>>(size);
// kNumberTypeFloat and kNumberTypeFloat32 both use float elements.
case kNumberTypeFloat:
return std::make_shared<TensorChunkData<float>>(size);
case kNumberTypeFloat32:
return std::make_shared<TensorChunkData<float>>(size);
case kNumberTypeFloat64:
return std::make_shared<TensorChunkData<double>>(size);
case kNumberTypeComplex64:
return std::make_shared<TensorChunkData<ComplexStorage<float>>>(size);
case kNumberTypeComplex128:
return std::make_shared<TensorChunkData<ComplexStorage<double>>>(size);
// NOTE(review): raw byte storage for strings and 'int' for tensor-type —
// presumably mirrors MakeTensorData's mapping for these object types; confirm.
case kObjectTypeString:
return std::make_shared<TensorChunkData<uint8_t>>(size);
case kObjectTypeTensorType:
return std::make_shared<TensorChunkData<int>>(size);
default:
break;
}
MS_LOG(EXCEPTION) << "Cannot construct chunk data because of unsupported data type: " << data_type << ".";
}
template <typename T>
TensorDataPtr MakeSubData(const TensorPtr &owner, size_t offset, const TensorDataPtr &data) {
const size_t data_bytes = data->nbytes();
@ -704,6 +762,9 @@ Tensor::Tensor(bool input, const TypePtr &data_type)
data_(MakeTensorData(data_type_, {}, input)),
id_(MakeId()) {}
// Create a 1-D chunk tensor with 'data_size' elements whose data is a
// TensorChunkData (so has_sub_data() reports true); the underlying buffer
// is allocated lazily, on first data() access.
Tensor::Tensor(TypeId data_type, size_t data_size)
: Tensor(data_type, ShapeVector{static_cast<int64_t>(data_size)}, MakeChunkData(data_type, data_size)) {}
// Two tensors compare equal when they are the very same object, or when
// their meta info matches and they share the same underlying data object.
bool Tensor::operator==(const Tensor &tensor) const {
  if (this == &tensor) {
    return true;
  }
  return MetaTensor::operator==(tensor) && data_ == tensor.data_;
}
@ -901,11 +962,9 @@ TensorPtrList Tensor::FlattenTensors(const TensorPtrList &tensors, size_t fusion
for (auto &type_group : group_info) {
auto chunk_dtype = normalize_type(type_group.first);
for (auto &chunk : type_group.second) {
// Chunk tensor is always 1 rank.
ShapeVector shape{static_cast<int64_t>(chunk.size)};
// Create the chunk tensor as a lazily initialized tensor; its data buffer
// will be allocated when we begin to copy the small tensors' data into it.
auto chunk_tensor = std::make_shared<Tensor>(chunk_dtype, shape);
auto chunk_tensor = std::make_shared<Tensor>(chunk_dtype, chunk.size);
// Reset and copy tensors data.
size_t offset = 0;
for (auto &tensor : chunk.tensors) {

View File

@ -90,6 +90,11 @@ class MS_CORE_API TensorData {
/// \return Whether this tensor data is sub data.
virtual bool is_sub_data() const = 0;
/// \brief Check whether this tensor data has sub data.
///
/// \return True if this tensor data has sub data, otherwise false.
virtual bool has_sub_data() const = 0;
/// \brief Whether the data are equal.
///
/// \param[in] other Another TensorData.
@ -238,6 +243,12 @@ class MS_CORE_API Tensor final : public MetaTensor {
/// \param[in] data_type [TypeId] data type.
explicit Tensor(bool input, const TypePtr &data_type = nullptr);
/// \brief Create a chunk tensor with the given data size.
///
/// \param[in] data_type [TypeId] Data type of the tensor.
/// \param[in] data_size The tensor chunk data size in number of elements.
Tensor(TypeId data_type, size_t data_size);
/// Destructor of Tensor.
~Tensor() override = default;

View File

@ -135,4 +135,24 @@ TEST_F(TestAnf, test_FlatParameterFinder) {
assert(flat_param6 == nullptr);
assert(offset6 == 0);
}
/// Feature: Flatten tensor
/// Description: Test is_sub_data() & has_sub_data() api
/// Expectation: API works as expected.
TEST_F(TestAnf, test_TensorWithSubData) {
  // Fix: removed unused local 't4'; it was never flattened nor asserted.
  auto t1 = std::make_shared<Tensor>(0.1f);
  auto t2 = std::make_shared<Tensor>(0.2f);
  auto t3 = std::make_shared<Tensor>(0.3f);
  // Before flattening, a standalone tensor is neither sub data nor a chunk owner.
  assert(!t1->data().is_sub_data());
  assert(!t1->data().has_sub_data());
  auto flat_tensors = Tensor::FlattenTensors(TensorPtrList{t1, t2, t3});
  // All three float tensors fuse into a single chunk tensor.
  assert(flat_tensors.size() == 1);
  // The chunk owns sub data but is not itself sub data.
  assert(!flat_tensors[0]->data().is_sub_data());
  assert(flat_tensors[0]->data().has_sub_data());
  // After flattening, each input tensor's data is a sub data segment.
  assert(t1->data().is_sub_data());
  assert(!t1->data().has_sub_data());
  assert(t2->data().is_sub_data());
  assert(!t2->data().has_sub_data());
  assert(t3->data().is_sub_data());
  assert(!t3->data().has_sub_data());
}
} // namespace mindspore