Add TensorData::has_sub_data() method

commit 2c24157ccb (parent 32a19e6623)
@@ -171,6 +171,8 @@ class TensorDataNumpy : public TensorData {
   bool is_sub_data() const override { return false; }
 
+  bool has_sub_data() const override { return false; }
+
   /// To string.
   std::string ToString(const TypeId, const ShapeVector &, bool use_comma) const override {
     if (use_comma) {
@@ -386,7 +386,7 @@ template <typename T>
 class TensorDataImpl : public TensorData {
  public:
   explicit TensorDataImpl(const ShapeVector &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
-  ~TensorDataImpl() = default;
+  ~TensorDataImpl() override = default;
 
   TensorDataImpl(const ShapeVector &shape, void *data, size_t data_len)
       : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_len)) {}
@@ -412,6 +412,8 @@ class TensorDataImpl : public TensorData {
   bool is_sub_data() const override { return false; }
 
+  bool has_sub_data() const override { return false; }
+
   void *data() override {
     if (data_ == nullptr) {
       if (data_size_ > INT32_MAX) {
@@ -456,6 +458,17 @@ class TensorDataImpl : public TensorData {
   std::unique_ptr<T[]> data_;
 };
 
+// Tensor chunk data.
+template <typename T>
+class TensorChunkData : public TensorDataImpl<T> {
+ public:
+  explicit TensorChunkData(size_t size) : TensorDataImpl<T>(ShapeVector{static_cast<int64_t>(size)}) {}
+
+  ~TensorChunkData() override = default;
+
+  bool has_sub_data() const override { return true; }
+};
+
 // TensorSubData is the base class to provide tensor data as a segment from an owner tensor data.
 class TensorSubData : public TensorData {
  public:
@@ -472,6 +485,8 @@ class TensorSubData : public TensorData {
   bool is_sub_data() const override { return true; }
 
+  bool has_sub_data() const override { return false; }
+
   void *data() override {
     // Set data initialized if data() is called.
     data_initialized_ = true;
@@ -562,6 +577,49 @@ TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const A
   MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
 }
 
+template <typename... Args>
+TensorDataPtr MakeChunkData(TypeId data_type, size_t size) {
+  switch (data_type) {
+    case kNumberTypeBool:
+      return std::make_shared<TensorChunkData<bool>>(size);
+    case kNumberTypeUInt8:
+      return std::make_shared<TensorChunkData<uint8_t>>(size);
+    case kNumberTypeInt8:
+      return std::make_shared<TensorChunkData<int8_t>>(size);
+    case kNumberTypeInt16:
+      return std::make_shared<TensorChunkData<int16_t>>(size);
+    case kNumberTypeInt32:
+      return std::make_shared<TensorChunkData<int32_t>>(size);
+    case kNumberTypeInt64:
+      return std::make_shared<TensorChunkData<int64_t>>(size);
+    case kNumberTypeUInt16:
+      return std::make_shared<TensorChunkData<uint16_t>>(size);
+    case kNumberTypeUInt32:
+      return std::make_shared<TensorChunkData<uint32_t>>(size);
+    case kNumberTypeUInt64:
+      return std::make_shared<TensorChunkData<uint64_t>>(size);
+    case kNumberTypeFloat16:
+      return std::make_shared<TensorChunkData<float16>>(size);
+    case kNumberTypeFloat:
+      return std::make_shared<TensorChunkData<float>>(size);
+    case kNumberTypeFloat32:
+      return std::make_shared<TensorChunkData<float>>(size);
+    case kNumberTypeFloat64:
+      return std::make_shared<TensorChunkData<double>>(size);
+    case kNumberTypeComplex64:
+      return std::make_shared<TensorChunkData<ComplexStorage<float>>>(size);
+    case kNumberTypeComplex128:
+      return std::make_shared<TensorChunkData<ComplexStorage<double>>>(size);
+    case kObjectTypeString:
+      return std::make_shared<TensorChunkData<uint8_t>>(size);
+    case kObjectTypeTensorType:
+      return std::make_shared<TensorChunkData<int>>(size);
+    default:
+      break;
+  }
+  MS_LOG(EXCEPTION) << "Cannot construct chunk data because of unsupported data type: " << data_type << ".";
+}
+
 template <typename T>
 TensorDataPtr MakeSubData(const TensorPtr &owner, size_t offset, const TensorDataPtr &data) {
   const size_t data_bytes = data->nbytes();
@@ -704,6 +762,9 @@ Tensor::Tensor(bool input, const TypePtr &data_type)
       data_(MakeTensorData(data_type_, {}, input)),
       id_(MakeId()) {}
 
+Tensor::Tensor(TypeId data_type, size_t data_size)
+    : Tensor(data_type, ShapeVector{static_cast<int64_t>(data_size)}, MakeChunkData(data_type, data_size)) {}
+
 bool Tensor::operator==(const Tensor &tensor) const {
   return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_));
 }
@@ -901,11 +962,9 @@ TensorPtrList Tensor::FlattenTensors(const TensorPtrList &tensors, size_t fusion
   for (auto &type_group : group_info) {
     auto chunk_dtype = normalize_type(type_group.first);
     for (auto &chunk : type_group.second) {
-      // Chunk tensor is always 1 rank.
-      ShapeVector shape{static_cast<int64_t>(chunk.size)};
       // Create chunk tensor as a lazy initialized tensor, the tensor data
       // will be allocated when we begin to copy small tensors data into it.
-      auto chunk_tensor = std::make_shared<Tensor>(chunk_dtype, shape);
+      auto chunk_tensor = std::make_shared<Tensor>(chunk_dtype, chunk.size);
       // Reset and copy tensors data.
       size_t offset = 0;
       for (auto &tensor : chunk.tensors) {
@@ -90,6 +90,11 @@ class MS_CORE_API TensorData {
   /// \return Whether this tensor data is sub data.
   virtual bool is_sub_data() const = 0;
 
+  /// \brief Check whether this tensor data has sub data.
+  ///
+  /// \return True if this tensor data has sub data, otherwise false.
+  virtual bool has_sub_data() const = 0;
+
   /// \brief Whether the data are equal.
   ///
   /// \param[in] other Another TensorData.
@@ -238,6 +243,12 @@ class MS_CORE_API Tensor final : public MetaTensor {
   /// \param[in] data_type [TypeId] data type.
   explicit Tensor(bool input, const TypePtr &data_type = nullptr);
 
+  /// \brief Create a chunk tensor with the given data size.
+  ///
+  /// \param[in] data_type [TypeId] Data type of the tensor.
+  /// \param[in] data_size The tensor chunk data size in number of elements.
+  Tensor(TypeId data_type, size_t data_size);
+
   /// Destructor of Tensor.
   ~Tensor() override = default;
 
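A usage sketch for the new chunk-tensor constructor, mirroring the FlattenTensors change above (the namespace qualification and the element count are assumptions for illustration):

// Sketch only: a lazily allocated 1-D chunk tensor of 1024 float32 elements.
// Its data is a TensorChunkData, so chunk->data().has_sub_data() returns true;
// the underlying buffer is allocated on the first call to data().data().
auto chunk = std::make_shared<mindspore::tensor::Tensor>(mindspore::kNumberTypeFloat32,
                                                         static_cast<size_t>(1024));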
@@ -135,4 +135,24 @@ TEST_F(TestAnf, test_FlatParameterFinder) {
   assert(flat_param6 == nullptr);
   assert(offset6 == 0);
 }
+
+/// Feature: Flatten tensor
+/// Description: Test is_sub_data() & has_sub_data() api
+/// Expectation: API works as expected.
+TEST_F(TestAnf, test_TensorWithSubData) {
+  auto t1 = std::make_shared<Tensor>(0.1f);
+  auto t2 = std::make_shared<Tensor>(0.2f);
+  auto t3 = std::make_shared<Tensor>(0.3f);
+  auto t4 = std::make_shared<Tensor>(0.4f);
+  assert(!t1->data().is_sub_data());
+  assert(!t1->data().has_sub_data());
+  auto flat_tensors = Tensor::FlattenTensors(TensorPtrList{t1, t2, t3});
+  assert(flat_tensors.size() == 1);
+  assert(!flat_tensors[0]->data().is_sub_data());
+  assert(flat_tensors[0]->data().has_sub_data());
+  assert(t1->data().is_sub_data());
+  assert(!t1->data().has_sub_data());
+  assert(t2->data().is_sub_data());
+  assert(!t2->data().has_sub_data());
+}
 } // namespace mindspore