From 978d7512e71854765b215a4a85018f5f97c05ea2 Mon Sep 17 00:00:00 2001
From: zhupuxu <zhupuxu@huawei.com>
Date: Wed, 31 Mar 2021 16:19:43 +0800
Subject: [PATCH] device tensor

Signed-off-by: zhupuxu <zhupuxu@huawei.com>
---
 include/api/types.h                  |  9 ++++
 mindspore/ccsrc/cxx_api/types.cc     | 30 ++++++++---
 tests/st/cpp/model/test_zero_copy.cc | 78 ++++++++++++++++++++++++++++
 3 files changed, 111 insertions(+), 6 deletions(-)

diff --git a/include/api/types.h b/include/api/types.h
index f2d919f802e..00f410c885d 100644
--- a/include/api/types.h
+++ b/include/api/types.h
@@ -47,6 +47,8 @@ class MS_API MSTensor {
                                        const void *data, size_t data_len) noexcept;
   static inline MSTensor *CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
                                           const void *data, size_t data_len) noexcept;
+  static inline MSTensor *CreateDevTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
+                                          const void *data, size_t data_len) noexcept;
   static inline MSTensor *StringsToTensor(const std::string &name, const std::vector<std::string> &str);
   static inline std::vector<std::string> TensorToStrings(const MSTensor &tensor);
   static void DestroyTensorPtr(MSTensor *tensor) noexcept;
@@ -79,6 +81,8 @@ class MS_API MSTensor {
                                 const void *data, size_t data_len) noexcept;
   static MSTensor *CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
                                    const void *data, size_t data_len) noexcept;
+  static MSTensor *CreateDevTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
+                                   const void *data, size_t data_len) noexcept;
   static MSTensor *CharStringsToTensor(const std::vector<char> &name, const std::vector<std::vector<char>> &str);
   static std::vector<std::vector<char>> TensorToStringChars(const MSTensor &tensor);
 
@@ -120,6 +124,11 @@ MSTensor *MSTensor::CreateRefTensor(const std::string &name, enum DataType type,
   return CreateRefTensor(StringToChar(name), type, shape, data, data_len);
 }
 
+MSTensor *MSTensor::CreateDevTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
+                                    const void *data, size_t data_len) noexcept {
+  return CreateDevTensor(StringToChar(name), type, shape, data, data_len);
+}
+
 MSTensor *MSTensor::StringsToTensor(const std::string &name, const std::vector<std::string> &str) {
   return CharStringsToTensor(StringToChar(name), VectorStringToChar(str));
 }
diff --git a/mindspore/ccsrc/cxx_api/types.cc b/mindspore/ccsrc/cxx_api/types.cc
index ba002486fdd..ef4f32c70de 100644
--- a/mindspore/ccsrc/cxx_api/types.cc
+++ b/mindspore/ccsrc/cxx_api/types.cc
@@ -103,11 +103,12 @@ class TensorDefaultImpl : public MSTensor::Impl {
 
 class TensorReferenceImpl : public MSTensor::Impl {
  public:
-  TensorReferenceImpl() : data_(nullptr), data_size_(0), name_(), type_(DataType::kTypeUnknown), shape_() {}
+  TensorReferenceImpl()
+      : data_(nullptr), data_size_(0), name_(), type_(DataType::kTypeUnknown), shape_(), is_device_(false) {}
   ~TensorReferenceImpl() override = default;
   TensorReferenceImpl(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
-                      size_t data_len)
-      : data_(data), data_size_(data_len), name_(name), type_(type), shape_(shape) {}
+                      size_t data_len, bool is_device)
+      : data_(data), data_size_(data_len), name_(name), type_(type), shape_(shape), is_device_(is_device) {}
 
   const std::string &Name() const override { return name_; }
   enum DataType DataType() const override { return type_; }
@@ -120,10 +121,10 @@ class TensorReferenceImpl : public MSTensor::Impl {
   void *MutableData() override { return const_cast<void *>(data_); }
   size_t DataSize() const override { return data_size_; }
 
-  bool IsDevice() const override { return false; }
+  bool IsDevice() const override { return is_device_; }
 
   std::shared_ptr<Impl> Clone() const override {
-    return std::make_shared<TensorReferenceImpl>(name_, type_, shape_, data_, data_size_);
+    return std::make_shared<TensorReferenceImpl>(name_, type_, shape_, data_, data_size_, is_device_);
   }
 
  protected:
@@ -132,6 +133,7 @@ class TensorReferenceImpl : public MSTensor::Impl {
   std::string name_;
   enum DataType type_;
   std::vector<int64_t> shape_;
+  bool is_device_;
 };
 
 MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
@@ -154,7 +156,23 @@ MSTensor *MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType
                                     const std::vector<int64_t> &shape, const void *data, size_t data_len) noexcept {
   std::string name_str = CharToString(name);
   try {
-    std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name_str, type, shape, data, data_len);
+    std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name_str, type, shape, data, data_len, false);
+    MSTensor *ret = new MSTensor(impl);
+    return ret;
+  } catch (const std::bad_alloc &) {
+    MS_LOG(ERROR) << "Malloc memory failed.";
+    return nullptr;
+  } catch (...) {
+    MS_LOG(ERROR) << "Unknown error occurred.";
+    return nullptr;
+  }
+}
+
+MSTensor *MSTensor::CreateDevTensor(const std::vector<char> &name, enum DataType type,
+                                    const std::vector<int64_t> &shape, const void *data, size_t data_len) noexcept {
+  std::string name_str = CharToString(name);
+  try {
+    std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name_str, type, shape, data, data_len, true);
     MSTensor *ret = new MSTensor(impl);
     return ret;
   } catch (const std::bad_alloc &) {
diff --git a/tests/st/cpp/model/test_zero_copy.cc b/tests/st/cpp/model/test_zero_copy.cc
index ff4d82ab1e8..892412a2ce2 100644
--- a/tests/st/cpp/model/test_zero_copy.cc
+++ b/tests/st/cpp/model/test_zero_copy.cc
@@ -18,6 +18,7 @@
 #include <vector>
 #include <fstream>
 #include <iostream>
+#include <sys/time.h>
 #include "common/common_test.h"
 #include "include/api/types.h"
 #include "minddata/dataset/include/execute.h"
@@ -40,9 +41,12 @@ class TestZeroCopy : public ST::Common {
   TestZeroCopy() {}
 };
 
+typedef timeval TimeValue;
 constexpr auto resnet_file = "/home/workspace/mindspore_dataset/mindir/resnet50/resnet50_imagenet.mindir";
 constexpr auto image_path = "/home/workspace/mindspore_dataset/imagenet/imagenet_original/val/n01440764/";
 constexpr auto aipp_path = "./data/dataset/aipp_resnet50.cfg";
+constexpr uint64_t kUSecondInSecond = 1000000;
+constexpr uint64_t run_nums = 10;
 
 size_t GetMax(mindspore::MSTensor data);
 std::string RealPath(std::string_view path);
@@ -97,6 +101,80 @@ TEST_F(TestZeroCopy, TestMindIR) {
 #endif
 }
 
+TEST_F(TestZeroCopy, TestDeviceTensor) {
+#ifdef ENABLE_ACL
+// Set context
+  auto context = ContextAutoSet();
+  ASSERT_TRUE(context != nullptr);
+  ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
+  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
+  ASSERT_TRUE(ascend310_info != nullptr);
+  ascend310_info->SetInsertOpConfigPath(aipp_path);
+  auto device_id = ascend310_info->GetDeviceID();
+  // Define model
+  Graph graph;
+  ASSERT_TRUE(Serialization::Load(resnet_file, ModelType::kMindIR, &graph) == kSuccess);
+  Model resnet50;
+  ASSERT_TRUE(resnet50.Build(GraphCell(graph), context) == kSuccess);
+  // Get model info
+  std::vector<mindspore::MSTensor> model_inputs = resnet50.GetInputs();
+  ASSERT_EQ(model_inputs.size(), 1);
+  // Define transform operations
+  std::shared_ptr<TensorTransform> decode(new vision::Decode());
+  std::shared_ptr<TensorTransform> resize(new vision::Resize({256}));
+  std::shared_ptr<TensorTransform> center_crop(new vision::CenterCrop({224, 224}));
+  mindspore::dataset::Execute Transform({decode, resize, center_crop}, MapTargetDevice::kAscend310, device_id);
+  // Read images
+  std::vector<std::string> images = GetAllFiles(image_path);
+  uint64_t cost = 0, device_cost = 0;
+  for (const auto &image_file : images) {
+    // prepare input
+    std::vector<mindspore::MSTensor> inputs;
+    std::vector<mindspore::MSTensor> outputs;
+    std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
+    mindspore::dataset::Tensor::CreateFromFile(image_file, &de_tensor);
+    auto image = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));
+    // Apply transform on images
+    Status rc = Transform(image, &image);
+    ASSERT_TRUE(rc == kSuccess);
+    MSTensor *device_tensor =
+      MSTensor::CreateDevTensor(image.Name(), image.DataType(), image.Shape(),
+                                image.MutableData(), image.DataSize());
+    MSTensor *tensor =
+      MSTensor::CreateTensor(image.Name(), image.DataType(), image.Shape(),
+                             image.Data().get(), image.DataSize());
+    inputs.push_back(*tensor);
+    // infer
+    TimeValue start_time, end_time;
+    (void)gettimeofday(&start_time, nullptr);
+    for (size_t i = 0; i < run_nums; ++i) {
+      ASSERT_TRUE(resnet50.Predict(inputs, &outputs) == kSuccess);
+    }
+    (void)gettimeofday(&end_time, nullptr);
+    cost +=
+      (kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec) + static_cast<uint64_t>(end_time.tv_usec)) -
+      (kUSecondInSecond * static_cast<uint64_t>(start_time.tv_sec) + static_cast<uint64_t>(start_time.tv_usec));
+    // clear inputs
+    inputs.clear();
+    start_time = (TimeValue){0};
+    end_time = (TimeValue){0};
+    inputs.push_back(*device_tensor);
+
+    // infer with device tensor
+    (void)gettimeofday(&start_time, nullptr);
+    for (size_t i = 0; i < run_nums; ++i) {
+      ASSERT_TRUE(resnet50.Predict(inputs, &outputs) == kSuccess);
+    }
+    (void)gettimeofday(&end_time, nullptr);
+    device_cost +=
+      (kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec) + static_cast<uint64_t>(end_time.tv_usec)) -
+      (kUSecondInSecond * static_cast<uint64_t>(start_time.tv_sec) + static_cast<uint64_t>(start_time.tv_usec));
+    Transform.DeviceMemoryRelease();
+  }
+  ASSERT_GE(cost, device_cost);
+#endif
+}
+
 size_t GetMax(mindspore::MSTensor data) {
   float max_value = -1;
   size_t max_idx = 0;