From 6313bc0cabf51feab309965e964e29fc5dcfcc7d Mon Sep 17 00:00:00 2001 From: He Wei Date: Wed, 18 May 2022 16:43:18 +0800 Subject: [PATCH] Support user data in Tensor --- mindspore/core/ir/tensor.cc | 3 ++- mindspore/core/ir/tensor.h | 28 +++++++++++++++++++++++++++ tests/ut/cpp/ir/meta_tensor_test.cc | 30 +++++++++++++++++++++++++++-- 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index f3e95eedb51..cda263f9336 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -631,7 +631,8 @@ Tensor::Tensor(const Tensor &tensor) hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_), padding_type_(tensor.padding_type()), device_event_(tensor.device_event_), - lazy_callback_(tensor.lazy_callback_) {} + lazy_callback_(tensor.lazy_callback_), + user_data_(tensor.user_data_) {} Tensor::Tensor(const Tensor &tensor, TypeId data_type) : MetaTensor(data_type, tensor.shape_), diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index e71edd9fea7..8cc5c8e29de 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -29,6 +29,7 @@ #include "ir/meta_tensor.h" #include "utils/log_adapter.h" #include "base/float16.h" +#include "base/user_data.h" #include "utils/shape_utils.h" #include "utils/ms_exception.h" #include "ir/device_event.h" @@ -574,6 +575,32 @@ class MS_CORE_API Tensor final : public MetaTensor { /// \return The memory chunk pointer and offset, nullptr and 0 if no memory chunk exists. std::pair GetChunkOffset() const; + /// \brief Set user data. + /// + /// \param[in] key The key of user data. + /// \param[in] value The value of user data, nullptr to erase the user data. + template + void set_user_data(const std::string &key, const std::shared_ptr &value) { + user_data_.set(key, value); + } + + /// \brief Get user data. + /// + /// \param[in] key The key of user data. + /// + /// \return Pointer to user data. + template + std::shared_ptr user_data(const std::string &key) const { + return user_data_.get(key); + } + + /// \brief Check whether there is corresponding user data by the given key. + /// + /// \param[in] key The key of user data. + /// + /// \return True if it exists, otherwise false. + bool has_user_data(const std::string &key) const { return user_data_.has(key); } + /// \brief Reset tensors data so that they are using contiguous memory chunks grouped by data type. /// /// \param[in] tensors The tensors to be processed. @@ -619,6 +646,7 @@ class MS_CORE_API Tensor final : public MetaTensor { TypePtr cast_dtype_{nullptr}; std::shared_ptr device_event_{nullptr}; std::function lazy_callback_{nullptr}; + UserData user_data_; }; // CSRTensor entity class diff --git a/tests/ut/cpp/ir/meta_tensor_test.cc b/tests/ut/cpp/ir/meta_tensor_test.cc index cd0d8e267c3..d2bba672929 100644 --- a/tests/ut/cpp/ir/meta_tensor_test.cc +++ b/tests/ut/cpp/ir/meta_tensor_test.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include #include #include +#include #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" @@ -28,7 +29,6 @@ using mindspore::tensor::TensorPy; namespace mindspore { namespace tensor { - class TestMetaTensor : public UT::Common { public: TestMetaTensor() {} @@ -374,5 +374,31 @@ TEST_F(TestTensor, TensorPyCast) { ASSERT_EQ(shape, shape3); } +/// Feature: Tensor +/// Description: Test user data for Tensor. +/// Expectation: user data works as expected. +TEST_F(TestTensor, TensorWithUserData) { + auto tensor = std::make_shared(3.14f); + auto mydata = std::make_shared("mydata"); + + // Set user data. + tensor->set_user_data("mykey", mydata); + ASSERT_TRUE(tensor->has_user_data("mykey")); + ASSERT_EQ(tensor->user_data("mykey"), mydata); + + // Copy with user data. + auto tensor1 = std::make_shared(*tensor); + ASSERT_TRUE(tensor1->has_user_data("mykey")); + ASSERT_EQ(tensor1->user_data("mykey"), mydata); + + // Erase user data. + tensor->set_user_data("mykey", nullptr); + ASSERT_FALSE(tensor->has_user_data("mykey")); + ASSERT_EQ(tensor->user_data("mykey"), nullptr); + + // user data in tensor1 is not removed. + ASSERT_TRUE(tensor1->has_user_data("mykey")); + ASSERT_EQ(tensor1->user_data("mykey"), mydata); +} } // namespace tensor } // namespace mindspore