Support user data in Tensor

This commit is contained in:
He Wei 2022-05-18 16:43:18 +08:00
parent 512a0afe66
commit 6313bc0cab
3 changed files with 58 additions and 3 deletions

View File

@ -631,7 +631,8 @@ Tensor::Tensor(const Tensor &tensor)
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()),
device_event_(tensor.device_event_),
lazy_callback_(tensor.lazy_callback_) {}
lazy_callback_(tensor.lazy_callback_),
user_data_(tensor.user_data_) {}
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
: MetaTensor(data_type, tensor.shape_),

View File

@ -29,6 +29,7 @@
#include "ir/meta_tensor.h"
#include "utils/log_adapter.h"
#include "base/float16.h"
#include "base/user_data.h"
#include "utils/shape_utils.h"
#include "utils/ms_exception.h"
#include "ir/device_event.h"
@ -574,6 +575,32 @@ class MS_CORE_API Tensor final : public MetaTensor {
/// \return The memory chunk pointer and offset, nullptr and 0 if no memory chunk exists.
std::pair<void *, size_t> GetChunkOffset() const;
/// \brief Set user data.
///
/// \param[in] key The key of user data.
/// \param[in] value The value of user data, nullptr to erase the user data.
template <typename T>
void set_user_data(const std::string &key, const std::shared_ptr<T> &value) {
user_data_.set<T>(key, value);
}
/// \brief Get user data.
///
/// \param[in] key The key of user data.
///
/// \return Pointer to user data.
template <typename T>
std::shared_ptr<T> user_data(const std::string &key) const {
return user_data_.get<T>(key);
}
/// \brief Check whether there is corresponding user data by the given key.
///
/// \param[in] key The key of user data.
///
/// \return True if it exists, otherwise false.
bool has_user_data(const std::string &key) const { return user_data_.has(key); }
/// \brief Reset tensors data so that they are using contiguous memory chunks grouped by data type.
///
/// \param[in] tensors The tensors to be processed.
@ -619,6 +646,7 @@ class MS_CORE_API Tensor final : public MetaTensor {
TypePtr cast_dtype_{nullptr};
std::shared_ptr<DeviceEvent> device_event_{nullptr};
std::function<void(void)> lazy_callback_{nullptr};
UserData user_data_;
};
// CSRTensor entity class

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
#include <iostream>
#include <memory>
#include <vector>
#include <string>
#include "common/common_test.h"
#include "common/py_func_graph_fetcher.h"
@ -28,7 +29,6 @@ using mindspore::tensor::TensorPy;
namespace mindspore {
namespace tensor {
class TestMetaTensor : public UT::Common {
public:
TestMetaTensor() {}
@ -374,5 +374,31 @@ TEST_F(TestTensor, TensorPyCast) {
ASSERT_EQ(shape, shape3);
}
/// Feature: Tensor
/// Description: Test user data for Tensor.
/// Expectation: user data works as expected.
TEST_F(TestTensor, TensorWithUserData) {
auto tensor = std::make_shared<Tensor>(3.14f);
auto mydata = std::make_shared<std::string>("mydata");
// Set user data.
tensor->set_user_data("mykey", mydata);
ASSERT_TRUE(tensor->has_user_data("mykey"));
ASSERT_EQ(tensor->user_data<std::string>("mykey"), mydata);
// Copy with user data.
auto tensor1 = std::make_shared<Tensor>(*tensor);
ASSERT_TRUE(tensor1->has_user_data("mykey"));
ASSERT_EQ(tensor1->user_data<std::string>("mykey"), mydata);
// Erase user data.
tensor->set_user_data<std::string>("mykey", nullptr);
ASSERT_FALSE(tensor->has_user_data("mykey"));
ASSERT_EQ(tensor->user_data<std::string>("mykey"), nullptr);
// user data in tensor1 is not removed.
ASSERT_TRUE(tensor1->has_user_data("mykey"));
ASSERT_EQ(tensor1->user_data<std::string>("mykey"), mydata);
}
} // namespace tensor
} // namespace mindspore