forked from mindspore-Ecosystem/mindspore
Support user data in Tensor
This commit is contained in:
parent
512a0afe66
commit
6313bc0cab
|
@ -631,7 +631,8 @@ Tensor::Tensor(const Tensor &tensor)
|
|||
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
|
||||
padding_type_(tensor.padding_type()),
|
||||
device_event_(tensor.device_event_),
|
||||
lazy_callback_(tensor.lazy_callback_) {}
|
||||
lazy_callback_(tensor.lazy_callback_),
|
||||
user_data_(tensor.user_data_) {}
|
||||
|
||||
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
|
||||
: MetaTensor(data_type, tensor.shape_),
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "ir/meta_tensor.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "base/float16.h"
|
||||
#include "base/user_data.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "utils/ms_exception.h"
|
||||
#include "ir/device_event.h"
|
||||
|
@ -574,6 +575,32 @@ class MS_CORE_API Tensor final : public MetaTensor {
|
|||
/// \return The memory chunk pointer and offset, nullptr and 0 if no memory chunk exists.
|
||||
std::pair<void *, size_t> GetChunkOffset() const;
|
||||
|
||||
/// \brief Set user data.
|
||||
///
|
||||
/// \param[in] key The key of user data.
|
||||
/// \param[in] value The value of user data, nullptr to erase the user data.
|
||||
template <typename T>
|
||||
void set_user_data(const std::string &key, const std::shared_ptr<T> &value) {
|
||||
user_data_.set<T>(key, value);
|
||||
}
|
||||
|
||||
/// \brief Get user data.
|
||||
///
|
||||
/// \param[in] key The key of user data.
|
||||
///
|
||||
/// \return Pointer to user data.
|
||||
template <typename T>
|
||||
std::shared_ptr<T> user_data(const std::string &key) const {
|
||||
return user_data_.get<T>(key);
|
||||
}
|
||||
|
||||
/// \brief Check whether there is corresponding user data by the given key.
|
||||
///
|
||||
/// \param[in] key The key of user data.
|
||||
///
|
||||
/// \return True if it exists, otherwise false.
|
||||
bool has_user_data(const std::string &key) const { return user_data_.has(key); }
|
||||
|
||||
/// \brief Reset tensors data so that they are using contiguous memory chunks grouped by data type.
|
||||
///
|
||||
/// \param[in] tensors The tensors to be processed.
|
||||
|
@ -619,6 +646,7 @@ class MS_CORE_API Tensor final : public MetaTensor {
|
|||
TypePtr cast_dtype_{nullptr};
|
||||
std::shared_ptr<DeviceEvent> device_event_{nullptr};
|
||||
std::function<void(void)> lazy_callback_{nullptr};
|
||||
UserData user_data_;
|
||||
};
|
||||
|
||||
// CSRTensor entity class
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,6 +16,7 @@
|
|||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
|
@ -28,7 +29,6 @@ using mindspore::tensor::TensorPy;
|
|||
|
||||
namespace mindspore {
|
||||
namespace tensor {
|
||||
|
||||
class TestMetaTensor : public UT::Common {
|
||||
public:
|
||||
TestMetaTensor() {}
|
||||
|
@ -374,5 +374,31 @@ TEST_F(TestTensor, TensorPyCast) {
|
|||
ASSERT_EQ(shape, shape3);
|
||||
}
|
||||
|
||||
/// Feature: Tensor
|
||||
/// Description: Test user data for Tensor.
|
||||
/// Expectation: user data works as expected.
|
||||
TEST_F(TestTensor, TensorWithUserData) {
|
||||
auto tensor = std::make_shared<Tensor>(3.14f);
|
||||
auto mydata = std::make_shared<std::string>("mydata");
|
||||
|
||||
// Set user data.
|
||||
tensor->set_user_data("mykey", mydata);
|
||||
ASSERT_TRUE(tensor->has_user_data("mykey"));
|
||||
ASSERT_EQ(tensor->user_data<std::string>("mykey"), mydata);
|
||||
|
||||
// Copy with user data.
|
||||
auto tensor1 = std::make_shared<Tensor>(*tensor);
|
||||
ASSERT_TRUE(tensor1->has_user_data("mykey"));
|
||||
ASSERT_EQ(tensor1->user_data<std::string>("mykey"), mydata);
|
||||
|
||||
// Erase user data.
|
||||
tensor->set_user_data<std::string>("mykey", nullptr);
|
||||
ASSERT_FALSE(tensor->has_user_data("mykey"));
|
||||
ASSERT_EQ(tensor->user_data<std::string>("mykey"), nullptr);
|
||||
|
||||
// user data in tensor1 is not removed.
|
||||
ASSERT_TRUE(tensor1->has_user_data("mykey"));
|
||||
ASSERT_EQ(tensor1->user_data<std::string>("mykey"), mydata);
|
||||
}
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue