From adfbc891d38900a4b3039c803bb46da13f65fb56 Mon Sep 17 00:00:00 2001 From: hesham Date: Thu, 9 Jul 2020 14:58:12 -0400 Subject: [PATCH] - Add checks and testing for empty tensors - cleanup work on createTensor and Tensor's constructors --- .../ccsrc/minddata/dataset/api/de_pipeline.cc | 13 +- .../minddata/dataset/api/python_bindings.cc | 15 +- .../ccsrc/minddata/dataset/core/cv_tensor.cc | 42 +- .../ccsrc/minddata/dataset/core/cv_tensor.h | 107 ++-- .../ccsrc/minddata/dataset/core/data_type.h | 5 + .../ccsrc/minddata/dataset/core/tensor.cc | 377 +++++------ .../ccsrc/minddata/dataset/core/tensor.h | 584 ++++++++++-------- .../dataset/engine/cache/cache_request.cc | 5 +- .../dataset/engine/datasetops/batch_op.cc | 11 +- .../engine/datasetops/cache_merge_op.cc | 3 +- .../engine/datasetops/device_queue_op.cc | 1 + .../engine/datasetops/source/celeba_op.cc | 7 +- .../engine/datasetops/source/cifar_op.cc | 18 +- .../engine/datasetops/source/clue_op.cc | 19 +- .../engine/datasetops/source/coco_op.cc | 37 +- .../engine/datasetops/source/csv_op.cc | 11 +- .../engine/datasetops/source/generator_op.cc | 2 +- .../datasetops/source/image_folder_op.cc | 6 +- .../engine/datasetops/source/manifest_op.cc | 11 +- .../engine/datasetops/source/mindrecord_op.cc | 6 +- .../engine/datasetops/source/mnist_op.cc | 12 +- .../engine/datasetops/source/mnist_op.h | 2 +- .../datasetops/source/random_data_op.cc | 3 +- .../source/sampler/python_sampler.cc | 2 +- .../datasetops/source/sampler/sampler.cc | 4 +- .../engine/datasetops/source/text_file_op.cc | 2 +- .../engine/datasetops/source/tf_reader_op.cc | 12 +- .../engine/datasetops/source/voc_op.cc | 18 +- .../minddata/dataset/engine/gnn/graph.cc | 9 +- .../dataset/engine/gnn/graph_loader.cc | 10 +- .../ccsrc/minddata/dataset/include/tensor.h | 584 ++++++++++-------- .../dataset/kernels/data/data_utils.cc | 24 +- .../dataset/kernels/data/duplicate_op.cc | 2 +- .../dataset/kernels/image/image_utils.cc | 77 +-- 
.../dataset/kernels/image/invert_op.cc | 4 +- .../dataset/kernels/image/normalize_op.cc | 26 +- .../dataset/kernels/image/normalize_op.h | 4 +- .../minddata/dataset/kernels/py_func_op.cc | 4 +- .../text/kernels/basic_tokenizer_op.cc | 3 +- .../dataset/text/kernels/case_fold_op.cc | 3 +- .../dataset/text/kernels/data_utils.cc | 7 +- .../text/kernels/jieba_tokenizer_op.cc | 11 +- .../dataset/text/kernels/lookup_op.cc | 4 +- .../minddata/dataset/text/kernels/ngram_op.cc | 2 +- .../dataset/text/kernels/normalize_utf8_op.cc | 3 +- .../dataset/text/kernels/regex_replace_op.cc | 3 +- .../text/kernels/regex_tokenizer_op.cc | 10 +- .../kernels/sentence_piece_tokenizer_op.cc | 4 +- .../dataset/text/kernels/to_number_op.cc | 10 +- .../text/kernels/unicode_char_tokenizer_op.cc | 12 +- .../kernels/unicode_script_tokenizer_op.cc | 11 +- .../text/kernels/whitespace_tokenizer_op.cc | 11 +- .../text/kernels/wordpiece_tokenizer_op.cc | 311 +++++----- tests/ut/cpp/dataset/batch_op_test.cc | 54 +- tests/ut/cpp/dataset/cache_op_test.cc | 12 +- tests/ut/cpp/dataset/channel_swap_test.cc | 2 +- tests/ut/cpp/dataset/common/bboxop_common.cc | 7 +- tests/ut/cpp/dataset/common/cvop_common.cc | 6 +- tests/ut/cpp/dataset/concatenate_op_test.cc | 19 +- tests/ut/cpp/dataset/duplicate_op_test.cc | 4 +- tests/ut/cpp/dataset/fill_op_test.cc | 86 ++- tests/ut/cpp/dataset/image_folder_op_test.cc | 3 +- .../ut/cpp/dataset/jieba_tokenizer_op_test.cc | 9 +- tests/ut/cpp/dataset/manifest_op_test.cc | 22 +- tests/ut/cpp/dataset/mask_test.cc | 4 +- tests/ut/cpp/dataset/one_hot_op_test.cc | 16 +- tests/ut/cpp/dataset/pad_end_op_test.cc | 38 +- .../dataset/sentence_piece_vocab_op_test.cc | 8 +- .../ut/cpp/dataset/sliding_window_op_test.cc | 13 +- .../cpp/dataset/stand_alone_samplers_test.cc | 6 +- tests/ut/cpp/dataset/tensor_string_test.cc | 26 +- tests/ut/cpp/dataset/tensor_test.cc | 195 ++++-- tests/ut/cpp/dataset/tokenizer_op_test.cc | 126 ++-- tests/ut/cpp/dataset/trucate_pair_test.cc | 8 +- 
tests/ut/cpp/dataset/type_cast_op_test.cc | 7 +- tests/ut/python/dataset/test_pair_truncate.py | 5 +- tests/ut/python/dataset/test_slice_op.py | 38 +- tests/ut/python/dataset/test_tensor_empty.py | 72 +++ 78 files changed, 1730 insertions(+), 1540 deletions(-) create mode 100644 tests/ut/python/dataset/test_tensor_empty.py diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc index b378d1ee3b2..35c6dae4f98 100644 --- a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc @@ -511,8 +511,9 @@ Status DEPipeline::FetchDataFromTensorRow(const TensorRow &row, RETURN_IF_NOT_OK(s); if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); } else if (column_type == DataType::DE_STRING) { - auto buffer = tensor->GetStringsBuffer(); - std::string ss(reinterpret_cast(buffer)); // assume scalar string tensor + std::string_view sv; + RETURN_IF_NOT_OK(tensor->GetItemAt(&sv, {0})); // assume scalar string tensor + std::string ss(sv); (*row_raw_data)[column_name] = std::move(ss); continue; } else { @@ -1678,13 +1679,13 @@ Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) { if (py::isinstance(tp[1])) { std::string pad_val_string = tp[1].is_none() ? "" : ToString(tp[1]); CHECK_FAIL_RETURN_UNEXPECTED( - Tensor::CreateTensor(&pad_val, std::vector{pad_val_string}, TensorShape::CreateScalar()), + Tensor::CreateFromVector(std::vector{pad_val_string}, TensorShape::CreateScalar(), &pad_val), "Cannot create pad_value Tensor"); } else { float pad_val_float = tp[1].is_none() ? 
0 : ToFloat(tp[1]); - CHECK_FAIL_RETURN_UNEXPECTED(Tensor::CreateTensor(&pad_val, TensorImpl::kFlexible, TensorShape::CreateScalar(), - DataType(DataType::DE_FLOAT32)), - "Cannot create pad_value Tensor"); + CHECK_FAIL_RETURN_UNEXPECTED( + Tensor::CreateEmpty(TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32), &pad_val), + "Cannot create pad_value Tensor"); pad_val->SetItemAt({}, pad_val_float); } (void)pad_info->insert({ToString(p.first), {shape, pad_val}}); diff --git a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc index 16bdd613c92..c6a9eb0aea6 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc @@ -340,7 +340,7 @@ void bindTensor(py::module *m) { (void)py::class_>(*m, "Tensor", py::buffer_protocol()) .def(py::init([](py::array arr) { std::shared_ptr out; - THROW_IF_ERROR(Tensor::CreateTensor(&out, arr)); + THROW_IF_ERROR(Tensor::CreateFromNpArray(arr, &out)); return out; })) .def_buffer([](Tensor &tensor) { @@ -364,7 +364,18 @@ void bindTensor(py::module *m) { }); (void)py::class_(*m, "TensorShape") - .def(py::init()) + .def(py::init([](const py::list &list) { + std::vector list_c; + for (auto &i : list) { + if (!i.is_none()) { + list_c.push_back(i.cast()); + } else { + list_c.push_back(TensorShape::kDimUnknown); + } + } + TensorShape out(list_c); + return out; + })) .def("__str__", &TensorShape::ToString) .def("as_list", &TensorShape::AsPyList) .def("is_known", &TensorShape::known); diff --git a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc index 5af748b5de4..79c27d45cb2 100644 --- a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc +++ b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc @@ -23,18 +23,35 @@ namespace mindspore { namespace dataset { -CVTensor::CVTensor(const TensorShape &shape, const DataType &type) : Tensor(shape, type) { - 
(void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); -} - -CVTensor::CVTensor(const TensorShape &shape, const DataType &type, const uchar *data) : Tensor(shape, type, data) { - (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); -} CVTensor::CVTensor(std::shared_ptr tensor) : Tensor(std::move(*tensor)) { (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); } +Status CVTensor::CreateEmpty(const TensorShape &shape, DataType type, CVTensorPtr *out) { + const CVTensorAlloc *alloc = GlobalContext::Instance()->cv_tensor_allocator(); + *out = std::allocate_shared(*alloc, shape, type); + int64_t byte_size = (*out)->SizeInBytes(); + // Don't allocate if we have a tensor with no elements. + if (byte_size != 0) { + RETURN_IF_NOT_OK((*out)->AllocateBuffer(byte_size)); + } + + return (*out)->MatInit((*out)->GetMutableBuffer(), (*out)->shape_, (*out)->type_, &(*out)->mat_); +} + +Status CVTensor::CreateFromMat(const cv::Mat &mat, CVTensorPtr *out) { + TensorPtr out_tensor; + cv::Mat mat_local = mat; + // if the input Mat's memory is not continuous, copy it to one block of memory + if (!mat.isContinuous()) mat_local = mat.clone(); + TensorShape shape(mat.size, mat_local.type()); + DataType type = DataType::FromCVType(mat_local.type()); + RETURN_IF_NOT_OK(CreateFromMemory(shape, type, mat_local.data, &out_tensor)); + *out = AsCVTensor(out_tensor); + return Status::OK(); +} + std::pair, int> CVTensor::IsValidImage(const TensorShape &shape, const DataType &type) { std::array size = {1, 1}; if (shape.Rank() <= 2 || (shape.Rank() == 3 && shape[2] <= CV_CN_MAX)) { @@ -57,7 +74,8 @@ std::shared_ptr CVTensor::AsCVTensor(std::shared_ptr t) { if (cv_t != nullptr) { return cv_t; } else { - return std::make_shared(t); + const CVTensorAlloc *alloc = GlobalContext::Instance()->cv_tensor_allocator(); + return std::allocate_shared(*alloc, t); } } @@ -97,5 +115,13 @@ void CVTensor::Squeeze() { Tensor::Squeeze(); (void)this->MatInit(GetMutableBuffer(), shape_, type_, 
&mat_); } + +Status CVTensor::MatAtIndex(const std::vector &index, cv::Mat *mat) { + uchar *start = nullptr; + TensorShape remaining({-1}); + RETURN_IF_NOT_OK(this->StartAddrOfIndex(index, &start, &remaining)); + RETURN_IF_NOT_OK(this->MatInit(start, remaining, type_, mat)); + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h index ac0ab72ea44..f32d4226729 100644 --- a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h +++ b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h @@ -30,56 +30,60 @@ namespace mindspore { namespace dataset { +using CVTensorPtr = std::shared_ptr; class CVTensor : public Tensor { public: - // Create an empty CVTensor of shape `shape` and type `type`. - // @note The shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - CVTensor(const TensorShape &shape, const DataType &type); + // Inherit Tensor's constructors + using Tensor::Tensor; - // Create a CVTensor from a given buffer, shape and type. - // @note This constructor allocates a new space in the memory and copies the buffer into it. - // @note The buffer should be valid and the shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - // @param data unsigned char*, pointer to the data. - CVTensor(const TensorShape &shape, const DataType &type, const uchar *data); - - // Create a CVTensor from a given CV::Mat. - // @note This constructor allocates a new space in the memory and copies the CV::Mat buffer into it. - // @param mat CV::Mat - explicit CVTensor(const cv::Mat &mat) - : CVTensor(TensorShape(mat.size, mat.type()), DataType::FromCVType(mat.type()), mat.data) {} - - ~CVTensor() = default; - - // Static function to cast a given Tensor as CVTensor. If the input tensor is already of type CVTensor, - // this function would be treated as a no-op. 
Fot other tensor types, a new CVTensor is created based on the data - // provided. The Passed Tensor will be invalidated. - // @note there is no memory copying here, the buffer will be assigned to the constructed tensor. - // @param tensor - // @return CVTensor - static std::shared_ptr AsCVTensor(std::shared_ptr tensor); - - // Create a CVTensor from a given tensor. The input tensor will be invalidated (i.e., the shape and type will be - // set to unknown and the data buffer will point to null. - // @note there is no memory copying here, the buffer will be assigned to the constructed tensor. - // @param tensor + /// Create a CVTensor from a given tensor. This constructor should not be used directly, use Create* instead. + /// The input tensor will be invalidated (i.e., the shape and type will be + /// set to unknown and the data buffer will point to null. + /// \note there is no memory copying here, the buffer will be assigned to the constructed tensor. + /// \param tensor explicit CVTensor(std::shared_ptr tensor); - // Getter function for the CV::Mat - // @return + /// Create CV tensor with type and shape. Items of the tensor would be uninitialized. + /// \param shape [in] shape of the output tensor + /// \param type [in] type of the output tensor + /// \param out [out] Generated tensor + /// \return Status code + static Status CreateEmpty(const TensorShape &shape, DataType type, CVTensorPtr *out); + + /// Create CV tensor from cv::Mat + /// \note This constructor allocates a new space in the memory and copies the CV::Mat buffer into it. + /// \param mat [in] cv::Mat to be copied into the new tensor. + /// \param out [out] Generated tensor + /// \return Status code + static Status CreateFromMat(const cv::Mat &mat, CVTensorPtr *out); + + ~CVTensor() override = default; + + /// Static function to cast a given Tensor as CVTensor. If the input tensor is already of type CVTensor, + /// this function would be treated as a no-op. 
Fot other tensor types, a new CVTensor is created based on the data + /// provided. The Passed Tensor will be invalidated. + /// \note the input tensor will be invalidated. + /// \note there is no memory copying here, the buffer will be assigned to the constructed tensor. + /// \param tensor [in] + /// \return CVTensor + static std::shared_ptr AsCVTensor(std::shared_ptr tensor); + + /// Get a reference to the CV::Mat + /// \return a reference to the internal CV::Mat cv::Mat mat() const { return mat_; } - // Static function to check if the passed information (shape and type) can be treated as a valid description - // of an image in OpenCV. Moreover, it returns OpenCV shape and type - // For example, if the shape is <512,512,3> and type is DE_UINT8, the output would be [512,512] and CV_8UC3. - // In case of invalid shape or type, the function will return pair - // @param shape TensorShape - // @param type DataType - // @return std::pair of OpenCV shape and type - std::pair, int> IsValidImage(const TensorShape &shape, const DataType &type); + /// Get a copy of the CV::Mat + /// \return a copy of internal CV::Mat + cv::Mat matCopy() const { return mat_.clone(); } + + /// Static function to check if the passed information (shape and type) can be treated as a valid description + /// of an image in OpenCV. Moreover, it returns OpenCV shape and type + /// For example, if the shape is <512,512,3> and type is DE_UINT8, the output would be [512,512] and CV_8UC3. 
+ /// In case of invalid shape or type, the function will return pair + /// \param shape [in] TensorShape + /// \param type [in] DataType + /// \return std::pair of OpenCV shape and type + static std::pair, int> IsValidImage(const TensorShape &shape, const DataType &type); Status Reshape(const TensorShape &shape) override; @@ -87,18 +91,19 @@ class CVTensor : public Tensor { void Squeeze() override; - Status Mat(const std::vector &index, cv::Mat *mat) { - uchar *start = nullptr; - TensorShape remaining({-1}); - RETURN_IF_NOT_OK(this->StartAddrOfIndex(index, &start, &remaining)); - RETURN_IF_NOT_OK(this->MatInit(start, remaining, type_, mat)); - return Status::OK(); - } + Status MatAtIndex(const std::vector &index, cv::Mat *mat); private: + /// Opencv Mat object wrapping the raw data of the tensor. + /// Modifying the content of the matrix, modifies the tensor. cv::Mat mat_; - // Initialize CV::Mat with the data_, shape_ and type_ + /// Create cv::Mat from data, TensorShape and DataType + /// \param data [in] Pointer to the data in memory. + /// \param shape [in] Shape of the tensor. + /// \param type [in] Type of the tensor. + /// \param mat [out] cv::Mat initialized with the provided data. 
+ /// \return Status code Status MatInit(uchar *data, const TensorShape &shape, const DataType &type, cv::Mat *mat); }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/core/data_type.h b/mindspore/ccsrc/minddata/dataset/core/data_type.h index efcbe39794b..ab48c3fc781 100644 --- a/mindspore/ccsrc/minddata/dataset/core/data_type.h +++ b/mindspore/ccsrc/minddata/dataset/core/data_type.h @@ -284,6 +284,11 @@ inline DataType DataType::FromCType() { return DataType(DataType::DE_STRING); } +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_STRING); +} + template <> inline bool DataType::IsLooselyCompatible() const { return type_ == DataType::DE_BOOL; diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor.cc b/mindspore/ccsrc/minddata/dataset/core/tensor.cc index 842615f9e18..b8717c26fa0 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor.cc +++ b/mindspore/ccsrc/minddata/dataset/core/tensor.cc @@ -59,49 +59,11 @@ Tensor::Tensor(const TensorShape &shape, const DataType &type) : shape_(shape), data_allocator_ = std::make_unique>(global_pool); } -Tensor::Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data) : Tensor(shape, type) { - if (type.IsNumeric()) { - // If the data pointer was given, then we can also populate the tensor with data - if (data != nullptr) { - // Given the shape/type of this tensor, compute the data size and copy in the input bytes. 
- int64_t byte_size = this->SizeInBytes(); - Status s = this->AllocateBuffer(byte_size); // Allocates data_ inside itself - if (s.IsOk() && data_ != nullptr) { - int ret_code = memcpy_s(data_, byte_size, data, byte_size); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data into Tensor!"; - } - } else { - MS_LOG(ERROR) << "Failed to create memory for Tensor!"; - } - } - } else { - MS_LOG(ERROR) << "Type should be numeric to use this constructor."; - } -} - -Tensor::Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length) - : Tensor(shape, type) { - // If the data pointer was given, then we can also populate the tensor with data - if (data != nullptr) { - // Allocates data_ inside itself - Status s = AllocateBuffer(length); - if (s.IsError()) { - MS_LOG(ERROR) << "Failed to create memory for Tensor!"; - } - if (data_ != nullptr) { - int ret_code = memcpy_s(data_, length, data, length); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data into Tensor!"; - } - } - } -} - Tensor::Tensor(Tensor &&other) noexcept : shape_(other.shape()), type_(other.type()), data_(other.GetMutableBuffer()), + data_end_(other.data_end_), data_allocator_(std::move(other.data_allocator_)) { other.Invalidate(); } @@ -117,118 +79,61 @@ Tensor &Tensor::operator=(Tensor &&other) noexcept { } return *this; } - -Tensor::Tensor(const std::vector &strings, const TensorShape &shape) - : Tensor(TensorShape({static_cast(strings.size())}), DataType(DataType::DE_STRING)) { - auto length_sum = [](dsize_t sum, const std::string &s) { return s.length() + sum; }; - dsize_t total_length = std::accumulate(strings.begin(), strings.end(), 0, length_sum); - - // total bytes needed = offset array + strings - // offset array needs to store one offset var per element + 1 extra to get the length of the last string. 
- // strings will be null-terminated --> need 1 extra byte per element - dsize_t num_bytes = (kOffsetSize + 1) * shape_.NumOfElements() + kOffsetSize + total_length; - - data_ = data_allocator_->allocate(num_bytes); - - auto offset_arr = reinterpret_cast(data_); - uchar *buf = GetStringsBuffer(); - - offset_t offset = buf - data_; // the first string will start here - uint32_t i = 0; - for (const auto &str : strings) { - // insert the start index of the string. - offset_arr[i++] = offset; - // total bytes are reduced by kOffsetSize - num_bytes -= kOffsetSize; - // insert actual string - int ret_code = memcpy_s(data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); - if (ret_code != 0) MS_LOG(ERROR) << "Cannot copy string into Tensor"; - // next string will be stored right after the current one. - offset = offset + str.length() + 1; - // total bytes are reduced by the length of the string - num_bytes -= str.length() + 1; +Status Tensor::CreateEmpty(const TensorShape &shape, const DataType &type, TensorPtr *out) { + CHECK_FAIL_RETURN_UNEXPECTED(shape.known(), "Invalid shape."); + CHECK_FAIL_RETURN_UNEXPECTED(type != DataType::DE_UNKNOWN, "Invalid data type."); + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *out = std::allocate_shared(*alloc, shape, type); + // if it's a string tensor and it has no elements, Just initialize the shape and type. + if (!type.IsNumeric() && shape.NumOfElements() == 0) { + return Status::OK(); } - // store one more offset value so we can get the length of the last string - // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] - offset_arr[i] = offset; - this->data_end_ = data_ + offset_arr[i]; + CHECK_FAIL_RETURN_UNEXPECTED(type.IsNumeric(), "Number of elements is not 0. 
The type should be numeric."); - MS_ASSERT(num_bytes == 0); - if (shape.known()) Tensor::Reshape(shape); + int64_t byte_size = (*out)->SizeInBytes(); + // Don't allocate if we have a tensor with no elements. + if (byte_size != 0) { + RETURN_IF_NOT_OK((*out)->AllocateBuffer(byte_size)); + } + + return Status::OK(); +} +Status Tensor::CreateFromMemory(const TensorShape &shape, const DataType &type, const uchar *src, TensorPtr *out) { + RETURN_IF_NOT_OK(CreateEmpty(shape, type, out)); + if (src != nullptr) { + // Given the shape/type of this tensor, compute the data size and copy in the input bytes. + int64_t byte_size = (*out)->SizeInBytes(); + int ret_code = memcpy_s((*out)->data_, byte_size, src, byte_size); + CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into tensor."); + } + return Status::OK(); } -Tensor::Tensor(const dataengine::BytesList &bytes_list, const TensorShape &shape) - : Tensor(TensorShape({static_cast(bytes_list.value_size())}), DataType(DataType::DE_STRING)) { - // total bytes needed = offset array + strings - // offset array needs to store one offset var per element + 1 extra to get the length of the last string. - // strings will be null-terminated --> need 1 extra byte per element - dsize_t num_bytes = (kOffsetSize)*shape_.NumOfElements() + kOffsetSize + bytes_list.ByteSizeLong(); - - data_ = data_allocator_->allocate(num_bytes); - - auto offset_arr = reinterpret_cast(data_); - uchar *buf = GetStringsBuffer(); - - offset_t offset = buf - data_; // the first string will start here - uint32_t i = 0; - for (; i < bytes_list.value_size(); i++) { - const std::string &str = bytes_list.value(i); - // insert the start index of the string. 
- offset_arr[i] = offset; - // total bytes are reduced by kOffsetSize - num_bytes -= kOffsetSize; - // insert actual string - int ret_code = memcpy_s(data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); - if (ret_code != 0) { - MS_LOG(ERROR) << "Cannot copy string into Tensor"; - } - // next string will be stored right after the current one. - offset = offset + str.length() + 1; - // total bytes are reduced by the length of the string - num_bytes -= str.length() + 1; - } - // store one more offset value so we can get the length of the last string - // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] - offset_arr[i] = offset; - - data_end_ = data_ + offset_arr[i]; - - MS_ASSERT(num_bytes == 0); - if (shape.known()) Tensor::Reshape(shape); -} - -Status Tensor::CreateTensor(std::shared_ptr *ptr, TensorImpl tensor_impl, const TensorShape &shape, - DataType type, const unsigned char *data) { - if (!shape.known()) { - RETURN_STATUS_UNEXPECTED("Invalid shape."); - } - if (type == DataType::DE_UNKNOWN) { - RETURN_STATUS_UNEXPECTED("Invalid data type."); +Status Tensor::CreateFromMemory(const TensorShape &shape, const DataType &type, const unsigned char *src, + const dsize_t &length, TensorPtr *out) { + CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr, "Pointer to source data is null."); + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *out = std::allocate_shared(*alloc, shape, type); + if (type.IsNumeric()) { + dsize_t calculated_length = (*out)->SizeInBytes(); + CHECK_FAIL_RETURN_UNEXPECTED(calculated_length == length, "Length of source data does not match the shape."); + } else { + // min_length is the length of a tensor with empty strings + // min_length = the number of bytes needed to store the offsets + 1 byte for each element + dsize_t min_length = (shape.NumOfElements() + 1) * kOffsetSize + shape.NumOfElements(); + CHECK_FAIL_RETURN_UNEXPECTED(min_length <= length, "Length of source data does not 
match the shape."); } - switch (tensor_impl) { - case TensorImpl::kFlexible: { - // The flex tensor is really just the base class tensor implementation - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, shape, type, data); - break; - } - case TensorImpl::kCv: { - const CVTensorAlloc *alloc = GlobalContext::Instance()->cv_tensor_allocator(); - *ptr = std::allocate_shared(*alloc, shape, type, data); - break; - } - default: { - std::string err_msg("Invalid tensor implementation type."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - return Status::OK(); // returns base-class shared_ptr + RETURN_IF_NOT_OK((*out)->AllocateBuffer(length)); + int ret_code = memcpy_s((*out)->data_, length, src, length); + CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into tensor."); + + return Status::OK(); } #ifdef ENABLE_PYTHON -Status Tensor::CreateTensorFromNumpyString(std::shared_ptr *ptr, py::array arr) { +Status Tensor::CreateFromNpString(py::array arr, std::shared_ptr *out) { std::vector shape; for (dsize_t i = 0; i < arr.ndim(); i++) { shape.push_back(static_cast(arr.shape()[i])); @@ -244,34 +149,38 @@ Status Tensor::CreateTensorFromNumpyString(std::shared_ptr *ptr, py::arr arr.resize(shape); // resize arr back to the original shape - return CreateTensor(ptr, strings, TensorShape{shape}); + return CreateFromVector(strings, TensorShape{shape}, out); } -Status Tensor::CreateTensor(std::shared_ptr *ptr, py::array arr) { +Status Tensor::CreateFromNpArray(const py::array &arr, std::shared_ptr *out) { if (DataType::FromNpArray(arr) == DataType::DE_STRING) { - return CreateTensorFromNumpyString(ptr, arr); + return CreateFromNpString(arr, out); } const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, TensorShape({}), DataType(DataType::DE_UNKNOWN)); + *out = std::allocate_shared(*alloc, TensorShape::CreateScalar(), DataType(DataType::DE_UNKNOWN)); 
std::vector shape; for (dsize_t i = 0; i < arr.ndim(); i++) { shape.push_back(static_cast(arr.shape()[i])); } - (*ptr)->shape_ = TensorShape(shape); - (*ptr)->type_ = DataType::FromNpArray(arr); - if (!(*ptr)->shape_.known()) RETURN_STATUS_UNEXPECTED("Invalid shape."); + (*out)->shape_ = TensorShape(shape); + (*out)->type_ = DataType::FromNpArray(arr); + if (!(*out)->shape_.known()) RETURN_STATUS_UNEXPECTED("Invalid shape."); - if ((*ptr)->type_ == DataType::DE_UNKNOWN) RETURN_STATUS_UNEXPECTED("Invalid data type."); + if ((*out)->type_ == DataType::DE_UNKNOWN) RETURN_STATUS_UNEXPECTED("Invalid data type."); std::shared_ptr global_pool = GlobalContext::Instance()->mem_pool(); - (*ptr)->data_allocator_ = std::make_unique>(global_pool); - int64_t byte_size = (*ptr)->SizeInBytes(); - RETURN_IF_NOT_OK((*ptr)->AllocateBuffer(byte_size)); + (*out)->data_allocator_ = std::make_unique>(global_pool); + int64_t byte_size = (*out)->SizeInBytes(); + if (byte_size == 0) { + return Status::OK(); + } + + RETURN_IF_NOT_OK((*out)->AllocateBuffer(byte_size)); unsigned char *data = static_cast(arr.request().ptr); - if ((*ptr)->data_ == nullptr) { + if ((*out)->data_ == nullptr) { RETURN_STATUS_UNEXPECTED("Failed to create memory for Tensor."); } @@ -282,61 +191,89 @@ Status Tensor::CreateTensor(std::shared_ptr *ptr, py::array arr) { // check if strides are contiguous bool is_strided = false; - dsize_t count = (*ptr)->shape_.NumOfElements(); + dsize_t count = (*out)->shape_.NumOfElements(); for (size_t i = 0; i < shape.size(); i++) { count /= shape[i]; - if (strides[i] != (*ptr)->type_.SizeInBytes() * count) { + if (strides[i] != (*out)->type_.SizeInBytes() * count) { is_strided = true; break; } } if (is_strided) { - RETURN_IF_NOT_OK(CopyStridedArray((*ptr)->data_, data, shape, strides, (*ptr)->type_.SizeInBytes())); + RETURN_IF_NOT_OK(CopyStridedArray((*out)->data_, data, shape, strides, (*out)->type_.SizeInBytes())); } else { - int ret_code = memcpy_s((*ptr)->data_, byte_size, data, 
byte_size); + int ret_code = memcpy_s((*out)->data_, byte_size, data, byte_size); if (ret_code != 0) { RETURN_STATUS_UNEXPECTED("Failed to copy data into Tensor."); } } - return Status::OK(); // returns base-class shared_ptr + return Status::OK(); } #endif -Status Tensor::CreateTensor(std::shared_ptr *ptr, const std::vector &strings, - const TensorShape &shape) { +Status Tensor::CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, TensorPtr *out) { const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, strings, shape); + *out = std::allocate_shared(*alloc, TensorShape({static_cast(bytes_list.value_size())}), + DataType(DataType::DE_STRING)); + // total bytes needed = offset array + strings + // offset array needs to store one offset var per element + 1 extra to get the length of the last string. + // strings will be null-terminated --> need 1 extra byte per element + dsize_t num_bytes = (kOffsetSize) * (*out)->shape_.NumOfElements() + kOffsetSize + bytes_list.ByteSizeLong(); + + (*out)->data_ = (*out)->data_allocator_->allocate(num_bytes); + + auto offset_arr = reinterpret_cast((*out)->data_); + uchar *buf = (*out)->GetStringsBuffer(); + + offset_t offset = buf - (*out)->data_; // the first string will start here + uint32_t i = 0; + for (; i < bytes_list.value_size(); i++) { + const std::string &str = bytes_list.value(i); + // insert the start index of the string. + offset_arr[i] = offset; + // total bytes are reduced by kOffsetSize + num_bytes -= kOffsetSize; + // insert actual string + int ret_code = memcpy_s((*out)->data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); + if (ret_code != 0) { + MS_LOG(ERROR) << "Cannot copy string into Tensor"; + } + // next string will be stored right after the current one. 
+ offset = offset + str.length() + 1; + // total bytes are reduced by the length of the string + num_bytes -= str.length() + 1; + } + // store one more offset value so we can get the length of the last string + // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] + offset_arr[i] = offset; + + (*out)->data_end_ = (*out)->data_ + offset_arr[i]; + + MS_ASSERT(num_bytes == 0); + (*out)->Reshape(shape); return Status::OK(); } -Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape) { - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, bytes_list, shape); - return Status::OK(); -} - -Status Tensor::CreateTensor(std::shared_ptr *ptr, const std::string &file_path) { +Status Tensor::CreateFromFile(const std::string &path, std::shared_ptr *out) { std::ifstream fs; - fs.open(file_path, std::ios::binary | std::ios::in); - CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + file_path); + fs.open(path, std::ios::binary | std::ios::in); + CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + path); int64_t num_bytes = fs.seekg(0, std::ios::end).tellg(); CHECK_FAIL_RETURN_UNEXPECTED(fs.seekg(0, std::ios::beg).good(), "Fail to find size of file"); - RETURN_IF_NOT_OK( - Tensor::CreateTensor(ptr, TensorImpl::kFlexible, TensorShape{num_bytes}, DataType(DataType::DE_UINT8))); - int64_t written_bytes = fs.read(reinterpret_cast((*ptr)->GetMutableBuffer()), num_bytes).gcount(); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape{num_bytes}, DataType(DataType::DE_UINT8), out)); + int64_t written_bytes = fs.read(reinterpret_cast((*out)->GetMutableBuffer()), num_bytes).gcount(); CHECK_FAIL_RETURN_UNEXPECTED(written_bytes == num_bytes && fs.good(), "Error in writing to tensor"); fs.close(); return Status::OK(); } -Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const 
TensorShape &shape, const DataType &type, dsize_t pad_size) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(ptr, TensorImpl::kFlexible, shape, type)); +Status Tensor::CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, + const DataType &type, dsize_t pad_size, TensorPtr *out) { + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, type, out)); - unsigned char *current_tensor_addr = (*ptr)->GetMutableBuffer(); + unsigned char *current_tensor_addr = (*out)->GetMutableBuffer(); int64_t tensor_bytes_remaining = bytes_list.value_size() * pad_size; for (int i = 0; i < bytes_list.value_size(); i++) { @@ -368,7 +305,7 @@ Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::Byte // Here we convert array C to array A, by memcpy index by index (Note that not all elements in C is copied) Status Tensor::CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, std::vector strides, uint8_t type_size) { - dsize_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + dsize_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); for (dsize_t i = 0; i < size; ++i) { dsize_t offset = 0; dsize_t count = i; @@ -429,29 +366,29 @@ void Tensor::PrintItemAt(const std::vector &index, std::ostream &out) c MS_ASSERT(data_); switch (type_.value()) { - CASE_PRINT_HEX(DataType::DE_BOOL, bool); + CASE_PRINT_HEX(DataType::DE_BOOL, bool) - CASE_PRINT_HEX(DataType::DE_INT8, int8_t); + CASE_PRINT_HEX(DataType::DE_INT8, int8_t) - CASE_PRINT_HEX(DataType::DE_UINT8, uint8_t); + CASE_PRINT_HEX(DataType::DE_UINT8, uint8_t) - CASE_PRINT(DataType::DE_INT16, int16_t); + CASE_PRINT(DataType::DE_INT16, int16_t) - CASE_PRINT(DataType::DE_UINT16, uint16_t); + CASE_PRINT(DataType::DE_UINT16, uint16_t) - CASE_PRINT(DataType::DE_INT32, int32_t); + CASE_PRINT(DataType::DE_INT32, int32_t) - CASE_PRINT(DataType::DE_UINT32, uint32_t); + CASE_PRINT(DataType::DE_UINT32, uint32_t) - CASE_PRINT(DataType::DE_INT64, int64_t); + 
CASE_PRINT(DataType::DE_INT64, int64_t) - CASE_PRINT(DataType::DE_UINT64, uint64_t); + CASE_PRINT(DataType::DE_UINT64, uint64_t) - CASE_PRINT(DataType::DE_FLOAT16, float16); + CASE_PRINT(DataType::DE_FLOAT16, float16) - CASE_PRINT(DataType::DE_FLOAT32, float); + CASE_PRINT(DataType::DE_FLOAT32, float) - CASE_PRINT(DataType::DE_FLOAT64, double); + CASE_PRINT(DataType::DE_FLOAT64, double) case DataType::DE_STRING: { std::string_view o{""}; @@ -501,50 +438,14 @@ void Tensor::Print(std::ostream &out) const { } } Status Tensor::AllocateBuffer(const dsize_t &length) { + RETURN_UNEXPECTED_IF_NULL(data_allocator_); if (data_ == nullptr) { - if (data_allocator_ != nullptr) { - data_ = data_allocator_->allocate(length); - RETURN_UNEXPECTED_IF_NULL(data_); - data_end_ = data_ + length; - } else { - data_ = static_cast(malloc(length)); - data_end_ = data_ + length; - RETURN_UNEXPECTED_IF_NULL(data_); - } + data_ = data_allocator_->allocate(length); + CHECK_FAIL_RETURN_UNEXPECTED(data_ != nullptr, "Failed to allocate memory for tensor."); + data_end_ = data_ + length; } return Status::OK(); } -const unsigned char *Tensor::GetBuffer() const { - // This version cannot modify anything. data_ could possibly be null. - return data_; -} - -// check for empty -bool Tensor::HasData() const { - if (data_ == nullptr) { - return true; - } else { - return false; - } -} - -unsigned char *Tensor::GetMutableBuffer() { - if (!shape_.known() || type_ == DataType::DE_UNKNOWN) { - return nullptr; - } - // If the data area is already created, return the pointer to it - if (data_ != nullptr) { - return data_; - } else { - // If the data area is not created, then identify the memory size based - // on the shape and type and allocate it. 
- if (this->AllocateBuffer(this->SizeInBytes()).IsOk()) { - return data_; - } else { - return nullptr; - } - } -} Status Tensor::Reshape(const TensorShape &shape) { if (shape.NumOfElements() == shape_.NumOfElements()) { @@ -628,7 +529,7 @@ Status Tensor::InsertTensor(const std::vector &ind, const std::shared_p err_msg += (ind.size() + tensor->Rank() != this->Rank()) ? "[Tensor] incorrect index\n" : ""; err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : ""; uchar *start_addr_of_ind = nullptr; - TensorShape remaining_shape({-1}); + TensorShape remaining_shape = TensorShape::CreateUnknownRankShape(); err_msg += (!StartAddrOfIndex(ind, &start_addr_of_ind, &remaining_shape).IsOk()) ? "[Tensor] incorrect index\n" : ""; err_msg += !(remaining_shape == tensor->shape()) ? "[Tensor] memory error\n" : ""; if (!err_msg.empty()) { @@ -697,7 +598,7 @@ Status Tensor::ExpandDim(const dsize_t &axis) { return Status::OK(); } -std::vector Tensor::Strides() { +std::vector Tensor::Strides() const { std::vector strides = shape_.Strides(); uint8_t size = type_.SizeInBytes(); std::transform(strides.begin(), strides.end(), strides.begin(), [&size](const auto &c) { return c * size; }); @@ -765,7 +666,6 @@ Status Tensor::GetItemAt(std::string_view *o, const std::vector &index) #ifdef ENABLE_PYTHON // return data as numpy, should return status Status Tensor::GetDataAsNumpy(py::array *data) { - RETURN_UNEXPECTED_IF_NULL(data_); RETURN_UNEXPECTED_IF_NULL(data); if (type_ == DataType::DE_BOOL) { *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); @@ -974,7 +874,9 @@ Status Tensor::CopyLastDimAt(const std::shared_ptr &src, const std::vect } Status Tensor::Slice(std::shared_ptr *out, const std::vector &indices) { CHECK_FAIL_RETURN_UNEXPECTED(shape_.Rank() == 1, "Currently Slice work with rank 1 tensors only."); - CHECK_FAIL_RETURN_UNEXPECTED(!indices.empty(), "Indices are empty, generated tensor would be empty."); + if 
(indices.empty()) { + return CreateEmpty(TensorShape({0}), type_, out); + } if (type_.IsNumeric()) { return SliceNumeric(out, indices); } else { @@ -982,8 +884,7 @@ Status Tensor::Slice(std::shared_ptr *out, const std::vector &i } } Status Tensor::SliceNumeric(std::shared_ptr *out, const std::vector &indices) { - RETURN_IF_NOT_OK( - CreateTensor(out, TensorImpl::kFlexible, TensorShape({static_cast(indices.size())}), type_)); + RETURN_IF_NOT_OK(CreateEmpty(TensorShape({static_cast(indices.size())}), type_, out)); (*out)->GetMutableBuffer(); dsize_t out_index = 0; dsize_t dim_length = shape_[0]; @@ -1027,7 +928,7 @@ Status Tensor::SliceString(std::shared_ptr *out, const std::vector(strings.size())}), out); } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor.h b/mindspore/ccsrc/minddata/dataset/core/tensor.h index 9996e8bfcf2..888c542cf99 100644 --- a/mindspore/ccsrc/minddata/dataset/core/tensor.h +++ b/mindspore/ccsrc/minddata/dataset/core/tensor.h @@ -33,6 +33,7 @@ #include "pybind11/stl.h" #endif +#include "common/utils.h" #include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/data_type.h" #include "minddata/dataset/core/tensor_shape.h" @@ -50,170 +51,155 @@ class Allocator; using CharAllocPtr = std::unique_ptr>; using TensorAllocPtr = std::shared_ptr>; // An allocator shared_ptr for Tensors +using offset_t = uint32_t; // type of offset values to store strings locations +using TensorPtr = std::shared_ptr; class Tensor { public: Tensor() = delete; - - // Create a new tensor, does not internally allocate storage. This constructor is protected, use CreateTensor. - // @note The shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - Tensor(const TensorShape &shape, const DataType &type); - - // Create a new tensor, allocates storage and copies in data. This constructor is protected, use CreateTensor. 
- // @note The buffer should be valid and the shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - // @param data unsigned char*, pointer to the data. - Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data); - - Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length); - Tensor(const Tensor &other) = delete; - Tensor &operator=(const Tensor &other) = delete; + /// Create a tensor using shape and type. This constructor should not be used directly, use CreateFromTensor instead + /// \note The shape and type information should be known and valid + /// \note The constructor does not allocate data + /// \param shape TensorShape + /// \param type DataType + Tensor(const TensorShape &shape, const DataType &type); + + /// Move constructor + /// \param other Tensor to be moved Tensor(Tensor &&other) noexcept; + /// Move assigment operator + /// \param other Tensor to be moved Tensor &operator=(Tensor &&other) noexcept; - Status AllocateBuffer(const dsize_t &length); + /// Create a numeric tensor with type and shape. Items of the tensor would be uninitialized. + /// \param[in] shape shape of the output tensor + /// \param[in] type type of the output tensor + /// \param[out] out Generated tensor + /// \return Status code + static Status CreateEmpty(const TensorShape &shape, const DataType &type, TensorPtr *out); - // type of offest values to store strings information - using offset_t = uint32_t; - // const of the size of the offset variable - static constexpr uint8_t kOffsetSize = sizeof(offset_t); - // Tensor base class which holds the data in an unsigned char* buffer. + /// Create a numeric tensor from a pointer in memory. Length of the source data is determined from the shape and type. + /// Data will be copied into the new created tensor. 
+ /// \param[in] shape shape of the output tensor + /// \param[in] type type of the output tensor + /// \param[in] src pointer to the source data + /// \param[out] out Generated tensor + /// \return Status code + static Status CreateFromMemory(const TensorShape &shape, const DataType &type, const uchar *src, TensorPtr *out); - // Construct a scalar string Tensor - explicit Tensor(const std::string &str) : Tensor(std::vector{str}, TensorShape::CreateScalar()) {} + /// Create a tensor from a pointer in memory and length. Data will be copied into the new created tensor. + /// \param[in] shape shape of the output tensor + /// \param[in] type type of the output tensor + /// \param[in] src pointer to the source data + /// \param[in] length length of the src data + /// \param[out] out Generated tensor + /// \return Status code + static Status CreateFromMemory(const TensorShape &shape, const DataType &type, const uchar *src, + const dsize_t &length, TensorPtr *out); - // Construct a tensor from a list of strings. Reshape the tensor with `shape` if given, otherwise assume the shape is - // the size of the vector `strings`. - // The memory layout of a Tensor of strings consists of the Offset_array followed by the strings. - // Thr offset array will store one extra value to find the length of the last string. 
- // OFFSET1, OFFSET2, ..., OFFSETn+1, STRING1, STRING2, ..., STRINGn - // The value of each offset is the start index of the corresponding string - // Offsets is of type offest_t - // strings will ne null-terminated - // example: Tensor(['abc', 'de'], shape={2}, type=DE_STRING) - // |----------------------------------------------------------------| - // | OFFSET ARRAY | STRINGS | - // | bytes 0-3 | bytes 3-6 | bytes 7-10 | bytes 11-14 | bytes 15-17 | - // | 11 | 15 | 18 | abc\0 | de\0 | - // |----------------------------------------------------------------| - explicit Tensor(const std::vector &strings, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // Same as Tensor(vector) but the input is protobuf bytelist - explicit Tensor(const dataengine::BytesList &bytes_list, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // A static factory method to create the given flavour of derived Tensor - // Returns the base class reference for the Tensor. 
- // @param ptr output argument to hold the created Tensor of given tensor_impl - // @param tensor_impl - which implementation of Tensor - // @param shape - shape of the tensor - // @param type - datatype of the tensor - // @param data - data to be copied to Tensor new allocation - // @return Status Code - static Status CreateTensor(std::shared_ptr *, TensorImpl tensor_impl, const TensorShape &shape, DataType type, - const unsigned char *data = nullptr); - - // Create a copy of the input tensor - // @param out [out] output tensor to be generated - // @param in [in] orginal tensor to be copied - // @return Status - static Status CreateTensor(std::shared_ptr *out, const std::shared_ptr &in) { - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *out = std::allocate_shared(*alloc, in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes()); - return Status::OK(); + /// Create a copy of the input tensor + /// \param[in] in original tensor to be copied + /// \param[out] out output tensor to be generated + /// \return Status + static Status CreateFromTensor(const TensorPtr &in, TensorPtr *out) { + return CreateFromMemory(in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes(), out); } #ifdef ENABLE_PYTHON - // A static factory method to create a Tensor from a given py::array. - // @param ptr output argument to hold the created Tensor - // @param arr py::array - // @return Status Code - static Status CreateTensor(std::shared_ptr *ptr, py::array arr); - - // Helper function to create a tensor from Numpy of strings - static Status CreateTensorFromNumpyString(std::shared_ptr *ptr, py::array arr); + /// Create a Tensor from a given py::array + /// \param[in] arr py::array + /// \param[out] out Created tensor + /// \return Status Code + static Status CreateFromNpArray(const py::array &arr, TensorPtr *out); #endif - // A static factory method to create a Tensor from a given list of strings. 
- // @param ptr output argument to hold the created Tensor - // @param strings elements of the tensor - // @param shape shape of the tensor - // @return Status Code - static Status CreateTensor(std::shared_ptr *ptr, const std::vector &strings, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); + /// Create a tensor of type DE_STRING from a BytesList. + /// \param[in] bytes_list protobuf's Bytelist + /// \param[in] shape shape of the outout tensor + /// \param[out] out created Tensor + /// \return Status Code + static Status CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, TensorPtr *out); - // create tensor from protobuf bytelist with strings - static Status CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape); + /// Create a tensor of type UINT8 or INT8 from a BytesList. + /// The tensor will be padded with ' ' to reach the required pad_size. + /// \param[in] bytes_list protobuf's Bytelist + /// \param[in] shape shape of the output tensor + /// \param[in] type type of created tensor. Should be DE_UINT8 or INT8 + /// \param[in] pad_size The size of the tensor after padding + /// \param[out] out created Tensor + /// \return Status Code + static Status CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, + const DataType &type, dsize_t pad_size, TensorPtr *out); - // A static factory method to create a Tensor from a given list of numbers. - // @param ptr output argument to hold the created Tensor - // @param items elements of the tensor - // @param shape shape of the tensor - // @return Status Code + /// Create a Tensor from a given list of values. + /// \tparam type of the values to be inserted. 
+ /// \param[in] items elements of the tensor + /// \param[in] shape shape of the output tensor + /// \param[out] out output argument to hold the created Tensor + /// \return Status Code template - static Status CreateTensor(std::shared_ptr *ptr, const std::vector &items, - const TensorShape &shape_req = TensorShape::CreateUnknownRankShape()) { + static Status CreateFromVector(const std::vector &items, const TensorShape &shape, TensorPtr *out) { + CHECK_FAIL_RETURN_UNEXPECTED( + items.size() == shape.NumOfElements(), + "Number of elements in the vector does not match the number of elements of the shape required"); DataType type = DataType::FromCType(); + // if items is empty, items_ptr would be nullptr. CreateFromMemory will handle this case. auto items_ptr = reinterpret_cast(&items[0]); - TensorShape shape = shape_req; - if (!shape.known()) { - shape = TensorShape({static_cast(items.size())}); - } - return CreateTensor(ptr, TensorImpl::kFlexible, shape, type, items_ptr); + return CreateFromMemory(shape, type, items_ptr, out); } - // A static factory method to create a Tensor from a given number. - // @param ptr output argument to hold the created Tensor - // @param item value - // @return Status Code + /// Create a 1D Tensor from a given list of values. + /// \tparam type of the values to be inserted. 
+ /// \param[in] items elements of the tensor + /// \param[out] out output argument to hold the created Tensor + /// \return Status Code template - static Status CreateTensor(std::shared_ptr *ptr, const T &item) { - return CreateTensor(ptr, {item}, TensorShape::CreateScalar()); + static Status CreateFromVector(const std::vector &items, TensorPtr *out) { + return CreateFromVector(items, TensorShape({static_cast(items.size())}), out); } - // Create tensor from protobuf bytelist with uint8 or int8 types - static Status CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape, const DataType &type, dsize_t pad_size); + /// Create a numeric scalar Tensor from the given value. + /// \tparam T type of value + /// \param[in] item value + /// \param[out] out Created tensor + /// \return Status code + template + static Status CreateScalar(const T &item, TensorPtr *out) { + DataType type = DataType::FromCType(); + auto item_ptr = reinterpret_cast(&item); + return CreateFromMemory(TensorShape::CreateScalar(), type, item_ptr, out); + } - static Status CreateTensor(std::shared_ptr *ptr, const std::string &path); + /// Create a tensor from a binary file on disk. 
+ /// \param[in] path file to be read + /// \param[out] out Created Tensor + /// \return Status code + static Status CreateFromFile(const std::string &path, TensorPtr *out); - // Copy raw data of a array based on shape and strides to the destination pointer - // @param dst Pointer to the destination array where the content is to be copied - // @param src Pointer to the source of strided array to be copied - // @param shape - shape of the source array - // @param strides - strides of the source array - // @param type_size - number of bytes needed to store one array element's type - // @return Status Code - static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, - std::vector strides, uint8_t type_size); - - // Release the memory using the allocator + /// Destruct the tensor and release the memory using the allocator virtual ~Tensor(); - // compare the tensor shape and data + /// Equality operator. compares tensor shape, type and data + /// \param[in] rhs Tensor to be compared with + /// \return bool bool operator==(const Tensor &rhs) const; bool operator!=(const Tensor &rhs) const { return !((*this) == rhs); } - // Get item located at `index`, caller needs to provide the type. - // @tparam T - // @param index vector - // @return return the item specified at index + /// Get item located at `index`, caller needs to provide the type. + /// \tparam T + /// \param[in] index vector + /// \return return the item specified at index template Status GetItemAt(T *o, const std::vector &index) const; - // Get string located at `index`. - // @param index vector - // @return return std::string_view specified at index + /// Get string located at `index`. 
+ /// \param[in] index vector + /// \return return std::string_view specified at index Status GetItemAt(std::string_view *o, const std::vector &index) const; template @@ -225,22 +211,21 @@ class Tensor { template Status GetFloatAt(T *o, const std::vector &index) const; - // set item at location specified by index - // @tparam `T` - // @param index - // @param value of type `T` + /// set item at location specified by index + /// \tparam `T` + /// \param[in] index + /// \param[in] value of type `T` template Status SetItemAt(const std::vector &index, const T &value) { - RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes())); T *ptr = nullptr; RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); *ptr = value; return Status::OK(); } - // set string item at location specified by index - // @param index - // @param value of type std::string + /// set string item at location specified by index + /// \param[in] index + /// \param[in] value of type std::string Status SetItemAt(const std::vector &index, const std::string &value) { RETURN_UNEXPECTED_IF_NULL(data_); uchar *ptr = nullptr; @@ -253,7 +238,8 @@ class Tensor { return Status::OK(); } - // fill tensor with Zeros. Does not support strings. + + /// fill tensor with Zeros. Does not support strings. Status Zero() { CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use Zero on tensor of strings.."); dsize_t size = SizeInBytes(); @@ -262,13 +248,12 @@ class Tensor { return Status::OK(); } - // Fill all elements in the Tensor with the given value of type `T`. Does not support strings. - // @tparam T - // @param value + /// Fill all elements in the Tensor with the given value of type `T`. Does not support strings. 
+ /// \tparam T + /// \param value[in] template Status Fill(const T &value) { CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use fill on tensor of strings."); - RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes())); int64_t cellSize = type_.SizeInBytes(); if ((data_ != nullptr) && type_.IsCompatible()) { for (dsize_t i = 0; i < Size(); i++) { @@ -283,91 +268,86 @@ class Tensor { } } - // Getter function for shape - // @return + /// Getter function for shape + /// \return const TensorShape &shape() const { return shape_; } /// Check if tensor has data /// \return bool - true if tensor is empty - bool HasData() const; + bool HasData() const { return data_ != nullptr; } - // Reshape the tensor. The given shape should have the same number of elements in the Tensor - // @param shape + /// Reshape the tensor. The given shape should have the same number of elements in the Tensor + /// \param shape virtual Status Reshape(const TensorShape &shape); - // @return number of elements in this tensor + /// \return number of elements in this tensor dsize_t Size() const { return shape().NumOfElements(); } - // @return the number of bytes this tensor is needs + /// \return the number of bytes this tensor is needs dsize_t SizeInBytes() const { if (data_end_ == nullptr) return type_.SizeInBytes() * shape_.NumOfElements(); return data_end_ - data_; } - // @return the rank of the tensor + /// \return the rank of the tensor dsize_t Rank() const { return shape().Rank(); } - // Get the starting memory address as a constant for the data of the tensor. This potentially - // drives an allocation if the data area. - // @return const unsigned char* - const unsigned char *GetBuffer() const; + /// Get the starting memory address as a constant for the data of the tensor. This potentially + /// drives an allocation if the data area. 
+ /// \return const unsigned char* + const unsigned char *GetBuffer() const { return data_; } - // Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the - // tensor's type is a string, otherwise undefined address would be returned. - // @return address of the first string of the tensor. - uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } - - // Getter of the type - // @return + /// Getter of the type + /// \return DataType type() const { return type_; } - // Provide stream operator for displaying it - // @param output stream - // @param so the Tensor object to be printed - // @return output stream + /// Provide stream operator for displaying it + /// \param output stream + /// \param so the Tensor object to be printed + /// \return output stream friend std::ostream &operator<<(std::ostream &out, const Tensor &so) { so.Print(out); return out; } - // Invalidate this Tensor by setting the type and shape to unknown and MData to null. - // Calling this method will make the Tensor and its data inaccessible, use it with caution. + /// Invalidate this Tensor by setting the type and shape to unknown and MData to null. + /// Calling this method will make the Tensor and its data inaccessible, use it with caution. void Invalidate(); - // Copy input tensor into self at the location index. - // Index is a vector of axises which can be incomplete: - // Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell. - // @param index - // @param input - // @return Status code + /// Copy input tensor into self at the location index. + /// Index is a vector of axises which can be incomplete: + /// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell. 
+ /// \param index + /// \param input + /// \return Status code Status InsertTensor(const std::vector &index, const std::shared_ptr &input); - // Find the address of the given index. Used in InsertTensor. - // Example: - // Tensor t= [[1,2],[3,4]] , StartAddrOfIndex({0}) -> &1 - // @param index incomplete index - // @param output: startAddrofIndex - // @param output: remaining - // @return Status code + /// Find the address of the given index. Used in InsertTensor. + /// Example: + /// Tensor t= [[1,2],[3,4]] , StartAddrOfIndex({0}) -> &1 + /// \param index incomplete index + /// \param output: startAddrofIndex + /// \param output: remaining + /// \return Status code Status StartAddrOfIndex(std::vector ind, uchar **start_addr_of_index, TensorShape *remaining); - // Expand the shape of the Tensor with one extra dimension. - // For example, if the shape is <512,512,3>: - // *- ExpandDim(0) gives: <1,512,512,3> - // *- ExpandDim(1) gives: <512,1,512,3> - // *- ExpandDim(3) gives: <512,512,3,1> - // @param axis location of the dim + /// Expand the shape of the Tensor with one extra dimension. + /// For example, if the shape is <512,512,3>: + /// *- ExpandDim(0) gives: <1,512,512,3> + /// *- ExpandDim(1) gives: <512,1,512,3> + /// *- ExpandDim(3) gives: <512,512,3,1> + /// \param axis location of the dim virtual Status ExpandDim(const dsize_t &axis); virtual void Squeeze(); - // Calculates the strides of the Tensor - // Ex: Tensor of shape <4,2,2> and type DE_UINT8 (1 byte) - // The strides will be {6,2,1}. - // Ex: Tensor of shape <4,2,2> and type DE_UINT32 (4 byte) - // The strides will be {24,8,4}. - // @return vector of integers - std::vector Strides(); + /// Calculates the strides of the Tensor + /// Ex: Tensor of shape <4,2,2> and type DE_UINT8 (1 byte) + /// The strides will be {6,2,1}. + /// Ex: Tensor of shape <4,2,2> and type DE_UINT32 (4 byte) + /// The strides will be {24,8,4}. 
+ /// \return vector of integers + std::vector Strides() const; std::string ToString() { std::stringstream ss; @@ -375,26 +355,26 @@ class Tensor { return ss.str(); } - // Handle negative indices. + /// Handle negative indices. static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; } - // Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported. - // Based on the type of tensor, SliceNumeric or SliceString will be called - // @param out Tensor - // @param indices vector of indices - // @return Status error code - Status Slice(std::shared_ptr *out, const std::vector &indices); + /// Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported. + /// Based on the type of tensor, SliceNumeric or SliceString will be called + /// \param[out] out Tensor + /// \param[in] indices vector of indices + /// \return Status error code + Status Slice(TensorPtr *out, const std::vector &indices); - // Slice numeric tensors. - Status SliceNumeric(std::shared_ptr *out, const std::vector &indices); + /// Slice numeric tensors. 
+ Status SliceNumeric(TensorPtr *out, const std::vector &indices); - // Slice string tensors - Status SliceString(std::shared_ptr *out, const std::vector &indices); + /// Slice string tensors + Status SliceString(TensorPtr *out, const std::vector &indices); #ifdef ENABLE_PYTHON - // Constructs numpy array from input tensor - // @param data this data is the location of python data - // @return Status code + /// Constructs numpy array from input tensor + /// \param[in] data this data is the location of python data + /// \return Status code Status GetDataAsNumpy(py::array *data); Status GetDataAsNumpyStrings(py::array *data); @@ -402,12 +382,12 @@ class Tensor { static Status GetBufferInfo(Tensor *t, py::buffer_info *out); #endif - // Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor + /// Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor Status Concatenate(const std::vector &index, const std::shared_ptr &input); - // TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor - // The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 - // @tparam T type of values in the Tensor Iterator + /// TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor + /// The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 + /// \tparam T type of values in the Tensor Iterator template class TensorIterator { public: @@ -498,7 +478,7 @@ class Tensor { }; // Specialization of TensorIterator for strings. It returns std::string_view for every item. 
- // @tparam DUMMY, used to mbe able to specialize the inner class + // \tparam DUMMY, used to mbe able to specialize the inner class template class TensorIterator { public: @@ -585,84 +565,192 @@ class Tensor { const char *data_; }; - // Return a TensorIterator that points to the start of the Tensor. - // It's the user responsibility to use the correct type that matches the Tensor type - // @param T The type of values in the Tensor - // @return TensorIterator + /// Return a TensorIterator that points to the start of the Tensor. + /// It's the user responsibility to use the correct type that matches the Tensor type + /// \tparam T The type of values in the Tensor + /// \return TensorIterator template TensorIterator begin() { - AllocateBuffer(SizeInBytes()); return TensorIterator(data_); } - // Return a linear iterator that points to the place after the last element of the Tensor. - // @tparam T The type of values in the Tensor - // @return TensorIterator + /// Return a linear iterator that points to the place after the last element of the Tensor. + /// \tparam T The type of values in the Tensor + /// \return TensorIterator template TensorIterator end() { return TensorIterator(data_end_); } - // Copies the last dimension at `index` from Tensor `src` to this Tensor. - // @param src Tensor - // @param index vector to the start of the dimension. The last dim should be 0 - // @return Status + /// Copies the last dimension at `index` from Tensor `src` to this Tensor. + /// \param[in] src Tensor + /// \param[in] index vector to the start of the dimension. The last dim should be 0 + /// \return Status Status CopyLastDimAt(const std::shared_ptr &src, const std::vector &index); protected: - // Get the starting memory address for the data of the tensor. This potentially - // drives an allocation if the data is null. 
- // @return unsigned char* - unsigned char *GetMutableBuffer(); + /// Allocate memory for the tensor using the data_allocator + /// \param[in] length number of bytes to be allocated + /// \return Error Status + Status AllocateBuffer(const dsize_t &length); - // A function that prints Tensor recursively, first called by print - // @param out - // @param cur_dim - // @param cur_index + /// Get the starting memory address for the data of the tensor. This potentially + /// drives an allocation if the data is null. + /// \return unsigned char* + unsigned char *GetMutableBuffer() { return data_; } + + /// A function that prints Tensor recursively, first called by print + /// \param[in] out + /// \param[in] cur_dim + /// \param[in] cur_index void PrintRecursive(std::ostream &out, int32_t cur_dim, const std::vector &cur_index) const; - // A function that prints info about the tensor - // @param out output stream + /// A function that prints info about the tensor + /// \param[out] out output stream void Print(std::ostream &out) const; - // A function that print the value as specified by its index - // @param index vector representing the index - // @param out + /// A function that print the value as specified by its index + /// \param[in] index vector representing the index + /// \param[out] out void PrintItemAt(const std::vector &index, std::ostream &out) const; - // Get pointer to item located at `index`, caller needs to provide the type. - // @tparam T - // @param index vector - // @return return a pointer to the item specified at index of type `T` + /// Get pointer to item located at `index`, caller needs to provide the type. 
+ /// \tparam T + /// \param[in] index vector + /// \return return a pointer to the item specified at index of type `T` template Status GetItemPtr(T **, const std::vector &index) const; - // Get pointer to string located at `index` and the length of string - // @param index vector - // @return return a pointer to the string specified at index and the length of the string + /// Get pointer to string located at `index` and the length of string + /// \param[in] index vector + /// \return return a pointer to the string specified at index and the length of the string Status GetItemPtr(uchar **, const std::vector &index, offset_t *length = nullptr) const; - // Given a flat index of an item string, return the start and length of the item - // @param index flat index of the item - // @return start address of the ths string - // @return length of the string + /// Given a flat index of an item string, return the start and length of the item + /// \param[in] index flat index of the item + /// \param[out] start address of the ths string + /// \param[out] length of the string Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const; - // all access to shape_ should be via shape + /// Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if + /// the tensor's type is a string, otherwise undefined address would be returned. \return address of the first string + /// of the tensor. 
+ uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } + + /// all access to shape_ should be via shape TensorShape shape_; - // data type of tensor + /// data type of tensor DataType type_; - // pointer to the start of the physical data + /// pointer to the start of the physical data unsigned char *data_; - // An allocator for data_ + /// An allocator for data_ CharAllocPtr data_allocator_; - // pointer to the end of the physical data + /// pointer to the end of the physical data unsigned char *data_end_ = nullptr; + + private: + /// Helper function to create a tensor from Numpy array of strings + /// \param[in] arr Numpy array + /// \param[out] out Created Tensor + /// \return Status + static Status CreateFromNpString(py::array arr, TensorPtr *out); + + /// Copy raw data of a array based on shape and strides to the destination pointer + /// \param dst [out] Pointer to the destination array where the content is to be copied + /// \param[in] src Pointer to the source of strided array to be copied + /// \param[in] shape shape of the source array + /// \param[in] strides strides of the source array + /// \param[in] type_size number of bytes needed to store one array element's type + /// \return Status Code + static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, + std::vector strides, uint8_t type_size); + + /// const of the size of the offset variable + static constexpr uint8_t kOffsetSize = sizeof(offset_t); }; template <> inline Tensor::TensorIterator Tensor::end() { return TensorIterator(data_, shape_.NumOfElements()); } + +/// Create a Tensor from a given list of strings. +/// @note: The memory layout of a Tensor of strings consists of the Offset_array followed by the strings. +/// The offset array will store one extra value to find the length of the last string. 
+/// OFFSET_1, OFFSET_2, ..., OFFSET_n+1, STRING_1, STRING_2, ..., STRING_n +/// The value of each offset is the start index of the corresponding string +/// Offsets is of type offset_t +/// strings will ne null-terminated +/// example: Tensor(['abc', 'de'], shape={2}, type=DE_STRING) +/// |----------------------------------------------------------------| +/// | OFFSET ARRAY | STRINGS | +/// | bytes 0-3 | bytes 3-6 | bytes 7-10 | bytes 11-14 | bytes 15-17 | +/// | 11 | 15 | 18 | abc\0 | de\0 | +/// |----------------------------------------------------------------| +/// \param[in] items elements of the tensor +/// \param[in] shape shape of the output tensor +/// \param[out] out output argument to hold the created Tensor +/// \return Status Code +template <> +inline Status Tensor::CreateFromVector(const std::vector &items, const TensorShape &shape, + TensorPtr *out) { + CHECK_FAIL_RETURN_UNEXPECTED( + items.size() == shape.NumOfElements(), + "Number of elements in the vector does not match the number of elements of the shape required"); + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *out = std::allocate_shared(*alloc, TensorShape({static_cast(items.size())}), + DataType(DataType::DE_STRING)); + if (items.size() == 0) { + if (shape.known()) { + return (*out)->Reshape(shape); + } + } + auto length_sum = [](dsize_t sum, const std::string &s) { return s.length() + sum; }; + dsize_t total_length = std::accumulate(items.begin(), items.end(), 0, length_sum); + + // total bytes needed = offset array + strings + // offset array needs to store one offset var per element + 1 extra to get the length of the last string. 
+ // strings will be null-terminated --> need 1 extra byte per element + dsize_t num_bytes = (kOffsetSize + 1) * (*out)->shape_.NumOfElements() + kOffsetSize + total_length; + + (*out)->AllocateBuffer(num_bytes); + auto offset_arr = reinterpret_cast((*out)->data_); + uchar *buf = (*out)->GetStringsBuffer(); + + offset_t offset = buf - (*out)->data_; // the first string will start here + uint32_t i = 0; + for (const auto &str : items) { + // insert the start index of the string. + offset_arr[i++] = offset; + // total bytes are reduced by kOffsetSize + num_bytes -= kOffsetSize; + // insert actual string + int ret_code = memcpy_s((*out)->data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); + if (ret_code != 0) MS_LOG(ERROR) << "Cannot copy string into Tensor"; + // next string will be stored right after the current one. + offset = offset + str.length() + 1; + // total bytes are reduced by the length of the string + num_bytes -= str.length() + 1; + } + // store one more offset value so we can get the length of the last string + // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] + offset_arr[i] = offset; + + (*out)->data_end_ = (*out)->data_ + offset_arr[i]; + + MS_ASSERT(num_bytes == 0); + if (shape.known()) { + RETURN_IF_NOT_OK((*out)->Reshape(shape)); + } + return Status::OK(); +} +/// Create a string scalar Tensor from the given value. 
+/// \param[in] item value +/// \param[out] out Created tensor +/// \return Status code +template <> +inline Status Tensor::CreateScalar(const std::string &item, TensorPtr *out) { + return CreateFromVector({item}, TensorShape::CreateScalar(), out); +} } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc index 3b7fc057a2c..a460e43aead 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc @@ -141,8 +141,9 @@ Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const Re #undef CASE DataType type(dest); - std::shared_ptr ts = - std::make_shared(shape, type, static_cast(data.GetPointer()), data.GetSize()); + std::shared_ptr ts; + RETURN_IF_NOT_OK( + Tensor::CreateFromMemory(shape, type, static_cast(data.GetPointer()), data.GetSize(), &ts)); // Next we restore the real data which can be embedded or stored separately. if (ts->SizeInBytes() != data.GetSize()) { MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". 
Expected " << ts->SizeInBytes() << ".\n" diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc index 844d0543074..63dcd4e9c8e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc @@ -176,12 +176,15 @@ Status BatchOp::BatchRows(const std::unique_ptr *src, const std::u std::shared_ptr new_tensor; if (first_type.IsNumeric()) { // numeric tensor - RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, TensorImpl::kFlexible, new_shape, first_type)); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, first_type, &new_tensor)); dsize_t j = 0; for (auto row : **src) { std::shared_ptr old_tensor = row.at(i); // row j, column i if (old_tensor->shape() == first_shape) { // check the newly popped rows have the same dim as the first - RETURN_IF_NOT_OK(new_tensor->InsertTensor({j++}, old_tensor)); + if (new_shape.NumOfElements() != 0) { + RETURN_IF_NOT_OK(new_tensor->InsertTensor({j++}, old_tensor)); + } + // Don't do anything if the tensor has no data } else { RETURN_STATUS_UNEXPECTED("[Batch ERROR] Inconsistent TensorShapes of Column " + std::to_string(i)); } @@ -194,7 +197,7 @@ Status BatchOp::BatchRows(const std::unique_ptr *src, const std::u strings.emplace_back(*itr); } } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, strings, new_shape)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, new_shape, &new_tensor)); } batched_row.emplace_back(new_tensor); } @@ -352,7 +355,7 @@ Status BatchOp::InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *ou py::list output_list = py::cast(ret_tuple[i]); for (size_t j = 0; j < output_list.size(); j++) { std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, py::cast(output_list[j]))); + RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::cast(output_list[j]), &out)); output_batch.push_back(std::move(out)); } 
output->push_back(std::move(output_batch)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc index 39029918e82..3b450fb04d7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -226,7 +226,8 @@ void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) { if (GetState() == State::kEmpty) { // We will do a deep copy for (auto &ts : row) { - auto out_ts = std::make_shared(ts->shape(), ts->type(), ts->GetBuffer(), ts->SizeInBytes()); + std::shared_ptr out_ts; + Tensor::CreateFromTensor(ts, &out_ts); cleaner_copy_.push_back(out_ts); } cleaner_copy_.setId(row.getId()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc index d9cda4d4562..e145237a877 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc @@ -72,6 +72,7 @@ Status DeviceQueueOp::CheckExceptions(const std::unique_ptr &buffer) buffer->GetRow(0, &row); for (const auto &item : row) { CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device."); + CHECK_FAIL_RETURN_UNEXPECTED(item->HasData(), "Cannot send tensor with no data."); } } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc index 9d7d5622a67..7b374c40752 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc @@ -359,7 +359,7 @@ Status CelebAOp::LoadTensorRow(row_id_type row_id, const std::paircolumn(1).tensorImpl(), - TensorShape({1, 
(uint32_t)image_label.second.size()}), - data_schema_->column(1).type())); + RETURN_IF_NOT_OK( + Tensor::CreateEmpty(TensorShape({1, (uint32_t)image_label.second.size()}), data_schema_->column(1).type(), &label)); RETURN_IF_NOT_OK(label->Zero()); for (uint32_t index = 0; index < image_label.second.size(); index++) { if (image_label.second[index] == 1) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc index 06be682bfd5..b06fcdb55d4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc @@ -190,15 +190,12 @@ Status CifarOp::LoadTensorRow(uint64_t index, TensorRow *trow) { std::shared_ptr label; std::shared_ptr fine_label; std::shared_ptr ori_image = cifar_image_label_pairs_[index].first; - std::shared_ptr copy_image = - std::make_shared(ori_image->shape(), ori_image->type(), ori_image->GetBuffer()); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), - data_schema_->column(1).type(), - reinterpret_cast(&cifar_image_label_pairs_[index].second[0]))); + std::shared_ptr copy_image; + RETURN_IF_NOT_OK(Tensor::CreateFromTensor(ori_image, ©_image)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(cifar_image_label_pairs_[index].second[0], &label)); + if (cifar_image_label_pairs_[index].second.size() > 1) { - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &fine_label, data_schema_->column(2).tensorImpl(), data_schema_->column(2).shape(), - data_schema_->column(2).type(), reinterpret_cast(&cifar_image_label_pairs_[index].second[1]))); + RETURN_IF_NOT_OK(Tensor::CreateScalar(cifar_image_label_pairs_[index].second[1], &fine_label)); (*trow) = TensorRow(index, {copy_image, std::move(label), std::move(fine_label)}); } else { (*trow) = TensorRow(index, {copy_image, std::move(label)}); @@ -359,9 +356,8 @@ Status 
CifarOp::ParseCifarData() { } std::shared_ptr image_tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image_tensor, data_schema_->column(0).tensorImpl(), - TensorShape({kCifarImageHeight, kCifarImageWidth, kCifarImageChannel}), - data_schema_->column(0).type())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({kCifarImageHeight, kCifarImageWidth, kCifarImageChannel}), + data_schema_->column(0).type(), &image_tensor)); auto itr = image_tensor->begin(); uint32_t total_pix = kCifarImageHeight * kCifarImageWidth; for (int pix = 0; pix < total_pix; ++pix) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc index 958514583ae..239d323043b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc @@ -127,7 +127,7 @@ Status ClueOp::LoadTensor(const std::string &line, std::unique_ptr (*tensor_table)->push_back(std::move(tRow)); std::shared_ptr tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar())); + RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor)); (**tensor_table)[row][0] = std::move(tensor); return Status::OK(); } @@ -144,26 +144,19 @@ Status ClueOp::GetValue(const nlohmann::json &js, std::vector key_c std::string final_str = key_chain.back(); switch (cursor.type()) { case nlohmann::detail::value_t::string: - RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get()}, TensorShape::CreateScalar())); + RETURN_IF_NOT_OK(Tensor::CreateScalar(cursor.get(), t)); break; - case nlohmann::detail::value_t::number_integer: - RETURN_IF_NOT_OK( - Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32))); - (*t)->SetItemAt({0}, cursor.get()); + RETURN_IF_NOT_OK(Tensor::CreateScalar(cursor.get(), t)); break; case nlohmann::detail::value_t::number_unsigned: - RETURN_IF_NOT_OK( - 
Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32))); - (*t)->SetItemAt({0}, cursor.get()); + RETURN_IF_NOT_OK(Tensor::CreateScalar(cursor.get(), t)); break; case nlohmann::detail::value_t::number_float: - RETURN_IF_NOT_OK( - Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32))); - (*t)->SetItemAt({0}, cursor.get()); + RETURN_IF_NOT_OK(Tensor::CreateScalar(cursor.get(), t)); break; case nlohmann::detail::value_t::array: - RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get>()}, TensorShape::CreateScalar())); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(cursor.get>(), t)); break; default: break; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc index da298dabf2b..dac2f8f57dd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc @@ -239,9 +239,8 @@ Status CocoOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Te } std::vector bbox_dim = {bbox_row_num, bbox_column_num}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&coordinate, data_schema_->column(1).tensorImpl(), TensorShape(bbox_dim), - data_schema_->column(1).type(), - reinterpret_cast(&bbox_row[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(bbox_row, TensorShape(bbox_dim), &coordinate)); + if (task_type_ == TaskType::Detection) { RETURN_IF_NOT_OK(LoadDetectionTensorRow(row_id, image_id, image, coordinate, trow)); } else if (task_type_ == TaskType::Stuff || task_type_ == TaskType::Keypoint) { @@ -278,13 +277,12 @@ Status CocoOp::LoadDetectionTensorRow(row_id_type row_id, const std::string &ima iscrowd_row.push_back(annotation[i]); } } - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &category_id, data_schema_->column(2).tensorImpl(), TensorShape({static_cast(category_id_row.size()), 1}), - 
data_schema_->column(2).type(), reinterpret_cast(&category_id_row[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector( + category_id_row, TensorShape({static_cast(category_id_row.size()), 1}), &category_id)); + + RETURN_IF_NOT_OK( + Tensor::CreateFromVector(iscrowd_row, TensorShape({static_cast(iscrowd_row.size()), 1}), &iscrowd)); - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &iscrowd, data_schema_->column(3).tensorImpl(), TensorShape({static_cast(iscrowd_row.size()), 1}), - data_schema_->column(3).type(), reinterpret_cast(&iscrowd_row[0]))); (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd)}); return Status::OK(); } @@ -302,9 +300,8 @@ Status CocoOp::LoadSimpleTensorRow(row_id_type row_id, const std::string &image_ item_queue = itr_item->second; std::vector bbox_dim = {static_cast(item_queue.size()), 1}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&item, data_schema_->column(2).tensorImpl(), TensorShape(bbox_dim), - data_schema_->column(2).type(), - reinterpret_cast(&item_queue[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(item_queue, TensorShape(bbox_dim), &item)); + (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(item)}); return Status::OK(); } @@ -334,18 +331,14 @@ Status CocoOp::LoadMixTensorRow(row_id_type row_id, const std::string &image_id, area_row.push_back(annotation[i]); } } + RETURN_IF_NOT_OK(Tensor::CreateFromVector( + category_id_row, TensorShape({static_cast(category_id_row.size()), 1}), &category_id)); - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &category_id, data_schema_->column(2).tensorImpl(), TensorShape({static_cast(category_id_row.size()), 1}), - data_schema_->column(2).type(), reinterpret_cast(&category_id_row[0]))); + RETURN_IF_NOT_OK( + Tensor::CreateFromVector(iscrowd_row, TensorShape({static_cast(iscrowd_row.size()), 1}), &iscrowd)); - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &iscrowd, data_schema_->column(3).tensorImpl(), 
TensorShape({static_cast(iscrowd_row.size()), 1}), - data_schema_->column(3).type(), reinterpret_cast(&iscrowd_row[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(area_row, TensorShape({static_cast(area_row.size()), 1}), &area)); - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &area, data_schema_->column(4).tensorImpl(), TensorShape({static_cast(area_row.size()), 1}), - data_schema_->column(4).type(), reinterpret_cast(&area_row[0]))); (*trow) = TensorRow( row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd), std::move(area)}); return Status::OK(); @@ -596,7 +589,7 @@ Status CocoOp::LaunchThreadsAndInitOp() { } Status CocoOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, path)); + RETURN_IF_NOT_OK(Tensor::CreateFromFile(path, tensor)); if (decode_ == true) { Status rc = Decode(*tensor, tensor); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc index 0871b3f30cd..1e0347b4b6b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc @@ -102,18 +102,13 @@ int CsvOp::CsvParser::put_record(char c) { std::shared_ptr t; switch (column_default_[cur_col_]->type) { case CsvOp::INT: - Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)); - t->SetItemAt({0}, std::stoi(s)); + Tensor::CreateScalar(std::stoi(s), &t); break; case CsvOp::FLOAT: - Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32)); - t->SetItemAt({0}, std::stof(s)); - break; - case CsvOp::STRING: - Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar()); + Tensor::CreateScalar(std::stof(s), &t); break; default: - Tensor::CreateTensor(&t, {s}, TensorShape::CreateScalar()); + 
Tensor::CreateScalar(s, &t); break; } (*tensor_table_)[cur_row_][cur_col_] = std::move(t); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc index 773dfc78b66..4af30428616 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -129,7 +129,7 @@ Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row) "Generator should return a tuple of numpy arrays."); } std::shared_ptr tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, ret_py_ele.cast())); + RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(ret_py_ele.cast(), &tensor)); if ((!column_types_.empty()) && (column_types_[i] != DataType::DE_UNKNOWN) && (column_types_[i] != tensor->type())) { return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, "Generator type check failed."); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc index 85839303db9..9a3bbccdcf9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc @@ -201,10 +201,8 @@ Status ImageFolderOp::WorkerEntry(int32_t worker_id) { // Load 1 TensorRow (image,label) using 1 ImageLabelPair. 
1 function call produces 1 TensorTow in a DataBuffer Status ImageFolderOp::LoadTensorRow(row_id_type row_id, ImageLabelPair pairPtr, TensorRow *trow) { std::shared_ptr image, label; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), - data_schema_->column(1).type(), - reinterpret_cast(&pairPtr->second))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, folder_path_ + (pairPtr->first))); + RETURN_IF_NOT_OK(Tensor::CreateScalar(pairPtr->second, &label)); + RETURN_IF_NOT_OK(Tensor::CreateFromFile(folder_path_ + (pairPtr->first), &image)); if (decode_ == true) { Status rc = Decode(image, &image); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc index 0476baf56f1..7982c63d106 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc @@ -185,17 +185,14 @@ Status ManifestOp::LoadTensorRow(row_id_type row_id, const std::pair label_index(data.second.size()); (void)std::transform(data.second.begin(), data.second.end(), label_index.begin(), [this](const std::string &label_name) { return label_index_[label_name]; }); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(label_index, &label)); if (label_index.size() == 1) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), TensorShape({}), - data_schema_->column(1).type(), - reinterpret_cast(&label_index[0]))); + label->Reshape(TensorShape({})); } else { - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &label, data_schema_->column(1).tensorImpl(), TensorShape(std::vector(1, label_index.size())), - data_schema_->column(1).type(), reinterpret_cast(&label_index[0]))); + label->Reshape(TensorShape(std::vector(1, label_index.size()))); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data.first)); + 
RETURN_IF_NOT_OK(Tensor::CreateFromFile(data.first, &image)); if (decode_ == true) { Status rc = Decode(image, &image); if (rc.IsError()) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc index 0886f751424..25327cea657 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc @@ -381,15 +381,15 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector(num_elements), &new_shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(new_shape, type, data, &tensor)); } else { std::vector shapeDetails = {static_cast(num_elements)}; auto new_shape = TensorShape(shapeDetails); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(new_shape, type, data, &tensor)); } tensor_row->push_back(std::move(tensor)); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc index 11ad18865e1..b3c52be60ee 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc @@ -160,12 +160,10 @@ Status MnistOp::WorkerEntry(int32_t worker_id) { // Load 1 TensorRow (image,label) using 1 MnistLabelPair. 
Status MnistOp::LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *trow) { std::shared_ptr image, label; - int32_t l = mnist_pair.second; // make a copy of cached tensor - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data_schema_->column(0).tensorImpl(), mnist_pair.first->shape(), - mnist_pair.first->type(), mnist_pair.first->GetBuffer())); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), - data_schema_->column(1).type(), reinterpret_cast(&l))); + RETURN_IF_NOT_OK(Tensor::CreateFromTensor(mnist_pair.first, &image)); + RETURN_IF_NOT_OK(Tensor::CreateScalar(mnist_pair.second, &label)); + (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); return Status::OK(); } @@ -325,8 +323,8 @@ Status MnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *la pixels[m] = (pixels[m] == 0) ? 0 : 255; } std::shared_ptr image; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data_schema_->column(0).tensorImpl(), img_tensor_shape, - data_schema_->column(0).type(), reinterpret_cast(pixels))); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(img_tensor_shape, data_schema_->column(0).type(), + reinterpret_cast(pixels), &image)); image_label_pairs_.emplace_back(std::make_pair(image, labels_buf[j])); } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h index 071e39a764a..1f2e3dd7309 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h @@ -40,7 +40,7 @@ namespace dataset { template class Queue; -using MnistLabelPair = std::pair, int32_t>; +using MnistLabelPair = std::pair, uint32_t>; class MnistOp : public ParallelOp, public RandomAccessOp { public: diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc 
b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc index 46f3adfa62e..8e09cc2b6c6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc @@ -361,8 +361,7 @@ Status RandomDataOp::CreateRandomRow(int32_t worker_id, TensorRow *new_row) { return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Failed to set random bytes for a tensor."); } - RETURN_IF_NOT_OK( - Tensor::CreateTensor(&new_tensor, current_col.tensorImpl(), *new_shape, current_col.type(), buf.get())); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(*new_shape, current_col.type(), buf.get(), &new_tensor)); // Add this tensor to the tensor row for output (*new_row).push_back(std::move(new_tensor)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc index 50c67bca6c9..a501a2dcb02 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -41,7 +41,7 @@ Status PythonSampler::GetNextSample(std::unique_ptr *out_buffer) { try { py::object py_ret = py_sampler_instance.attr("_get_indices")(); py::array np_sample_ids = py_ret.cast(); - Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor + Tensor::CreateFromNpArray(np_sample_ids, &sample_ids); // copy numpy to tensor if (HasChildSampler()) { for (auto it = sample_ids->begin(); it != sample_ids->end(); ++it) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc index 60d75d2eec7..eb952b1608e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc +++ 
b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc @@ -73,9 +73,7 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr *sample_ids, int64_t col_desc_ = std::make_unique("sampleIds", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1); } TensorShape shape(std::vector(1, num_elements)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(sample_ids, col_desc_->tensorImpl(), shape, col_desc_->type())); - RETURN_IF_NOT_OK( - (*sample_ids)->AllocateBuffer((*sample_ids)->SizeInBytes())); // allocate memory in case user forgets! + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, col_desc_->type(), sample_ids)); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc index e0b32262f5b..104d7919cee 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -146,7 +146,7 @@ Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptrpush_back(std::move(tRow)); std::shared_ptr tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar())); + RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor)); (**tensor_table)[row][0] = std::move(tensor); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc index ae7907b5ceb..f0d40a9ba8d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -677,8 +677,7 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr *tensor_table // into the tensor TensorShape current_shape = TensorShape::CreateUnknownRankShape(); RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(num_elements, ¤t_shape)); - 
RETURN_IF_NOT_OK( - Tensor::CreateTensor(&ts, current_col.tensorImpl(), current_shape, current_col.type(), data_ptr)); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(current_shape, current_col.type(), data_ptr, &ts)); break; } case dataengine::Feature::KindCase::kInt64List: { @@ -735,7 +734,7 @@ Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataeng if (current_col.type() == DataType::DE_STRING) { TensorShape shape = TensorShape::CreateScalar(); RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, shape)); + RETURN_IF_NOT_OK(Tensor::CreateFromByteList(bytes_list, shape, tensor)); return Status::OK(); } @@ -763,7 +762,7 @@ Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataeng // know how many elements there are and the total bytes, create tensor here: TensorShape current_shape = TensorShape::CreateScalar(); RETURN_IF_NOT_OK(current_col.MaterializeTensorShape((*num_elements) * pad_size, &current_shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, current_shape, current_col.type(), pad_size)); + RETURN_IF_NOT_OK(Tensor::CreateFromByteList(bytes_list, current_shape, current_col.type(), pad_size, tensor)); return Status::OK(); } @@ -836,10 +835,7 @@ Status TFReaderOp::LoadIntList(const ColDescriptor &current_col, const dataengin // know how many elements there are, create tensor here: TensorShape current_shape = TensorShape::CreateUnknownRankShape(); RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &current_shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, current_col.tensorImpl(), current_shape, current_col.type())); - - // Tensors are lazily allocated, this eagerly allocates memory for the tensor.
- RETURN_IF_NOT_OK((*tensor)->AllocateBuffer((*tensor)->SizeInBytes())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(current_shape, current_col.type(), tensor)); int64_t i = 0; auto it = (*tensor)->begin(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc index fcc529b6bd9..bb48d5e418d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc @@ -375,7 +375,7 @@ Status VOCOp::LaunchThreadsAndInitOp() { } Status VOCOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, path)); + RETURN_IF_NOT_OK(Tensor::CreateFromFile(path, tensor)); if (decode_ == true) { Status rc = Decode(*tensor, tensor); if (rc.IsError()) { @@ -412,18 +412,10 @@ Status VOCOp::ReadAnnotationToTensor(const std::string &path, TensorRow *row) { bbox_num++; } } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&bbox, data_schema_->column(1).tensorImpl(), TensorShape({bbox_num, 4}), - data_schema_->column(1).type(), - reinterpret_cast(&bbox_data[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(2).tensorImpl(), TensorShape({bbox_num, 1}), - data_schema_->column(2).type(), - reinterpret_cast(&label_data[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&difficult, data_schema_->column(3).tensorImpl(), TensorShape({bbox_num, 1}), - data_schema_->column(3).type(), - reinterpret_cast(&difficult_data[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&truncate, data_schema_->column(4).tensorImpl(), TensorShape({bbox_num, 1}), - data_schema_->column(4).type(), - reinterpret_cast(&truncate_data[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(bbox_data, TensorShape({bbox_num, 4}), &bbox)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(label_data, TensorShape({bbox_num, 1}), &label)); + 
RETURN_IF_NOT_OK(Tensor::CreateFromVector(difficult_data, TensorShape({bbox_num, 1}), &difficult)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(truncate_data, TensorShape({bbox_num, 1}), &truncate)); (*row) = TensorRow({std::move(bbox), std::move(label), std::move(difficult), std::move(truncate)}); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc index 9083eb4c4b3..7cbfedcf465 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc @@ -57,8 +57,7 @@ Status Graph::CreateTensorByVector(const std::vector> &data, Data std::shared_ptr tensor; size_t m = data.size(); size_t n = data[0].size(); - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &tensor, TensorImpl::kFlexible, TensorShape({static_cast(m), static_cast(n)}), type, nullptr)); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({static_cast(m), static_cast(n)}), type, &tensor)); auto ptr = tensor->begin(); for (const auto &id_m : data) { CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size"); @@ -310,8 +309,7 @@ Status Graph::GetNodeFeature(const std::shared_ptr &nodes, const std::ve dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); shape = shape.PrependDim(size); std::shared_ptr fea_tensor; - RETURN_IF_NOT_OK( - Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->Value()->type(), &fea_tensor)); dsize_t index = 0; for (auto node_itr = nodes->begin(); node_itr != nodes->end(); ++node_itr) { @@ -358,8 +356,7 @@ Status Graph::GetEdgeFeature(const std::shared_ptr &edges, const std::ve dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); shape = shape.PrependDim(size); std::shared_ptr fea_tensor; - RETURN_IF_NOT_OK( - 
Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->Value()->type(), &fea_tensor)); dsize_t index = 0; for (auto edge_itr = edges->begin(); edge_itr != edges->end(); ++edge_itr) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc index 9d2c6211f40..2339b02de21 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc @@ -125,7 +125,7 @@ Status GraphLoader::LoadNode(const std::vector &col_blob, const mindrec (*feature_map)[node_type].insert(ind); if ((*default_feature)[ind] == nullptr) { std::shared_ptr zero_tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); RETURN_IF_NOT_OK(zero_tensor->Zero()); (*default_feature)[ind] = std::make_shared(ind, zero_tensor); } @@ -151,7 +151,7 @@ Status GraphLoader::LoadEdge(const std::vector &col_blob, const mindrec (*feature_map)[edge_type].insert(ind); if ((*default_feature)[ind] == nullptr) { std::shared_ptr zero_tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); RETURN_IF_NOT_OK(zero_tensor->Zero()); (*default_feature)[ind] = std::make_shared(ind, zero_tensor); } @@ -170,9 +170,9 @@ Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector< key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); - 
RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, TensorImpl::kFlexible, - std::move(TensorShape({static_cast(n_bytes / col_type_size)})), - std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), data)); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast(n_bytes / col_type_size)})), + std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), + data, tensor)); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/include/tensor.h b/mindspore/ccsrc/minddata/dataset/include/tensor.h index 9996e8bfcf2..888c542cf99 100644 --- a/mindspore/ccsrc/minddata/dataset/include/tensor.h +++ b/mindspore/ccsrc/minddata/dataset/include/tensor.h @@ -33,6 +33,7 @@ #include "pybind11/stl.h" #endif +#include "common/utils.h" #include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/data_type.h" #include "minddata/dataset/core/tensor_shape.h" @@ -50,170 +51,155 @@ class Allocator; using CharAllocPtr = std::unique_ptr>; using TensorAllocPtr = std::shared_ptr>; // An allocator shared_ptr for Tensors +using offset_t = uint32_t; // type of offset values to store strings locations +using TensorPtr = std::shared_ptr; class Tensor { public: Tensor() = delete; - - // Create a new tensor, does not internally allocate storage. This constructor is protected, use CreateTensor. - // @note The shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - Tensor(const TensorShape &shape, const DataType &type); - - // Create a new tensor, allocates storage and copies in data. This constructor is protected, use CreateTensor. - // @note The buffer should be valid and the shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - // @param data unsigned char*, pointer to the data. 
- Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data); - - Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length); - Tensor(const Tensor &other) = delete; - Tensor &operator=(const Tensor &other) = delete; + /// Create a tensor using shape and type. This constructor should not be used directly, use CreateEmpty instead + /// \note The shape and type information should be known and valid + /// \note The constructor does not allocate data + /// \param shape TensorShape + /// \param type DataType + Tensor(const TensorShape &shape, const DataType &type); + + /// Move constructor + /// \param other Tensor to be moved Tensor(Tensor &&other) noexcept; + /// Move assignment operator + /// \param other Tensor to be moved Tensor &operator=(Tensor &&other) noexcept; - Status AllocateBuffer(const dsize_t &length); + /// Create a numeric tensor with type and shape. Items of the tensor would be uninitialized. + /// \param[in] shape shape of the output tensor + /// \param[in] type type of the output tensor + /// \param[out] out Generated tensor + /// \return Status code + static Status CreateEmpty(const TensorShape &shape, const DataType &type, TensorPtr *out); - // type of offest values to store strings information - using offset_t = uint32_t; - // const of the size of the offset variable - static constexpr uint8_t kOffsetSize = sizeof(offset_t); - // Tensor base class which holds the data in an unsigned char* buffer. + /// Create a numeric tensor from a pointer in memory. Length of the source data is determined from the shape and type. + /// Data will be copied into the newly created tensor.
+ /// \param[in] shape shape of the output tensor + /// \param[in] type type of the output tensor + /// \param[in] src pointer to the source data + /// \param[out] out Generated tensor + /// \return Status code + static Status CreateFromMemory(const TensorShape &shape, const DataType &type, const uchar *src, TensorPtr *out); - // Construct a scalar string Tensor - explicit Tensor(const std::string &str) : Tensor(std::vector{str}, TensorShape::CreateScalar()) {} + /// Create a tensor from a pointer in memory and length. Data will be copied into the new created tensor. + /// \param[in] shape shape of the output tensor + /// \param[in] type type of the output tensor + /// \param[in] src pointer to the source data + /// \param[in] length length of the src data + /// \param[out] out Generated tensor + /// \return Status code + static Status CreateFromMemory(const TensorShape &shape, const DataType &type, const uchar *src, + const dsize_t &length, TensorPtr *out); - // Construct a tensor from a list of strings. Reshape the tensor with `shape` if given, otherwise assume the shape is - // the size of the vector `strings`. - // The memory layout of a Tensor of strings consists of the Offset_array followed by the strings. - // Thr offset array will store one extra value to find the length of the last string. 
- // OFFSET1, OFFSET2, ..., OFFSETn+1, STRING1, STRING2, ..., STRINGn - // The value of each offset is the start index of the corresponding string - // Offsets is of type offest_t - // strings will ne null-terminated - // example: Tensor(['abc', 'de'], shape={2}, type=DE_STRING) - // |----------------------------------------------------------------| - // | OFFSET ARRAY | STRINGS | - // | bytes 0-3 | bytes 3-6 | bytes 7-10 | bytes 11-14 | bytes 15-17 | - // | 11 | 15 | 18 | abc\0 | de\0 | - // |----------------------------------------------------------------| - explicit Tensor(const std::vector &strings, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // Same as Tensor(vector) but the input is protobuf bytelist - explicit Tensor(const dataengine::BytesList &bytes_list, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // A static factory method to create the given flavour of derived Tensor - // Returns the base class reference for the Tensor. 
- // @param ptr output argument to hold the created Tensor of given tensor_impl - // @param tensor_impl - which implementation of Tensor - // @param shape - shape of the tensor - // @param type - datatype of the tensor - // @param data - data to be copied to Tensor new allocation - // @return Status Code - static Status CreateTensor(std::shared_ptr *, TensorImpl tensor_impl, const TensorShape &shape, DataType type, - const unsigned char *data = nullptr); - - // Create a copy of the input tensor - // @param out [out] output tensor to be generated - // @param in [in] orginal tensor to be copied - // @return Status - static Status CreateTensor(std::shared_ptr *out, const std::shared_ptr &in) { - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *out = std::allocate_shared(*alloc, in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes()); - return Status::OK(); + /// Create a copy of the input tensor + /// \param[in] in original tensor to be copied + /// \param[out] out output tensor to be generated + /// \return Status + static Status CreateFromTensor(const TensorPtr &in, TensorPtr *out) { + return CreateFromMemory(in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes(), out); } #ifdef ENABLE_PYTHON - // A static factory method to create a Tensor from a given py::array. - // @param ptr output argument to hold the created Tensor - // @param arr py::array - // @return Status Code - static Status CreateTensor(std::shared_ptr *ptr, py::array arr); - - // Helper function to create a tensor from Numpy of strings - static Status CreateTensorFromNumpyString(std::shared_ptr *ptr, py::array arr); + /// Create a Tensor from a given py::array + /// \param[in] arr py::array + /// \param[out] out Created tensor + /// \return Status Code + static Status CreateFromNpArray(const py::array &arr, TensorPtr *out); #endif - // A static factory method to create a Tensor from a given list of strings. 
- // @param ptr output argument to hold the created Tensor - // @param strings elements of the tensor - // @param shape shape of the tensor - // @return Status Code - static Status CreateTensor(std::shared_ptr *ptr, const std::vector &strings, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); + /// Create a tensor of type DE_STRING from a BytesList. + /// \param[in] bytes_list protobuf's BytesList + /// \param[in] shape shape of the output tensor + /// \param[out] out created Tensor + /// \return Status Code + static Status CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, TensorPtr *out); - // create tensor from protobuf bytelist with strings - static Status CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape); + /// Create a tensor of type UINT8 or INT8 from a BytesList. + /// The tensor will be padded with ' ' to reach the required pad_size. + /// \param[in] bytes_list protobuf's BytesList + /// \param[in] shape shape of the output tensor + /// \param[in] type type of created tensor. Should be DE_UINT8 or INT8 + /// \param[in] pad_size The size of the tensor after padding + /// \param[out] out created Tensor + /// \return Status Code + static Status CreateFromByteList(const dataengine::BytesList &bytes_list, const TensorShape &shape, + const DataType &type, dsize_t pad_size, TensorPtr *out); - // A static factory method to create a Tensor from a given list of numbers. - // @param ptr output argument to hold the created Tensor - // @param items elements of the tensor - // @param shape shape of the tensor - // @return Status Code + /// Create a Tensor from a given list of values. + /// \tparam type of the values to be inserted.
+ /// \param[in] items elements of the tensor + /// \param[in] shape shape of the output tensor + /// \param[out] out output argument to hold the created Tensor + /// \return Status Code template - static Status CreateTensor(std::shared_ptr *ptr, const std::vector &items, - const TensorShape &shape_req = TensorShape::CreateUnknownRankShape()) { + static Status CreateFromVector(const std::vector &items, const TensorShape &shape, TensorPtr *out) { + CHECK_FAIL_RETURN_UNEXPECTED( + items.size() == shape.NumOfElements(), + "Number of elements in the vector does not match the number of elements of the shape required"); DataType type = DataType::FromCType(); + // items.data() is well-defined for an empty vector (unlike &items[0]). CreateFromMemory will handle this case. - auto items_ptr = reinterpret_cast(&items[0]); + auto items_ptr = reinterpret_cast(items.data()); - TensorShape shape = shape_req; - if (!shape.known()) { - shape = TensorShape({static_cast(items.size())}); - } - return CreateTensor(ptr, TensorImpl::kFlexible, shape, type, items_ptr); + return CreateFromMemory(shape, type, items_ptr, out); } - // A static factory method to create a Tensor from a given number. - // @param ptr output argument to hold the created Tensor - // @param item value - // @return Status Code + /// Create a 1D Tensor from a given list of values. + /// \tparam type of the values to be inserted.
+ /// \param[in] items elements of the tensor + /// \param[out] out output argument to hold the created Tensor + /// \return Status Code template - static Status CreateTensor(std::shared_ptr *ptr, const T &item) { - return CreateTensor(ptr, {item}, TensorShape::CreateScalar()); + static Status CreateFromVector(const std::vector &items, TensorPtr *out) { + return CreateFromVector(items, TensorShape({static_cast(items.size())}), out); } - // Create tensor from protobuf bytelist with uint8 or int8 types - static Status CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape, const DataType &type, dsize_t pad_size); + /// Create a numeric scalar Tensor from the given value. + /// \tparam T type of value + /// \param[in] item value + /// \param[out] out Created tensor + /// \return Status code + template + static Status CreateScalar(const T &item, TensorPtr *out) { + DataType type = DataType::FromCType(); + auto item_ptr = reinterpret_cast(&item); + return CreateFromMemory(TensorShape::CreateScalar(), type, item_ptr, out); + } - static Status CreateTensor(std::shared_ptr *ptr, const std::string &path); + /// Create a tensor from a binary file on disk. 
+ /// \param[in] path file to be read + /// \param[out] out Created Tensor + /// \return Status code + static Status CreateFromFile(const std::string &path, TensorPtr *out); - // Copy raw data of a array based on shape and strides to the destination pointer - // @param dst Pointer to the destination array where the content is to be copied - // @param src Pointer to the source of strided array to be copied - // @param shape - shape of the source array - // @param strides - strides of the source array - // @param type_size - number of bytes needed to store one array element's type - // @return Status Code - static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, - std::vector strides, uint8_t type_size); - - // Release the memory using the allocator + /// Destruct the tensor and release the memory using the allocator virtual ~Tensor(); - // compare the tensor shape and data + /// Equality operator. compares tensor shape, type and data + /// \param[in] rhs Tensor to be compared with + /// \return bool bool operator==(const Tensor &rhs) const; bool operator!=(const Tensor &rhs) const { return !((*this) == rhs); } - // Get item located at `index`, caller needs to provide the type. - // @tparam T - // @param index vector - // @return return the item specified at index + /// Get item located at `index`, caller needs to provide the type. + /// \tparam T + /// \param[in] index vector + /// \return return the item specified at index template Status GetItemAt(T *o, const std::vector &index) const; - // Get string located at `index`. - // @param index vector - // @return return std::string_view specified at index + /// Get string located at `index`. 
+ /// \param[in] index vector + /// \return return std::string_view specified at index Status GetItemAt(std::string_view *o, const std::vector &index) const; template @@ -225,22 +211,21 @@ class Tensor { template Status GetFloatAt(T *o, const std::vector &index) const; - // set item at location specified by index - // @tparam `T` - // @param index - // @param value of type `T` + /// set item at location specified by index + /// \tparam `T` + /// \param[in] index + /// \param[in] value of type `T` template Status SetItemAt(const std::vector &index, const T &value) { - RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes())); T *ptr = nullptr; RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); *ptr = value; return Status::OK(); } - // set string item at location specified by index - // @param index - // @param value of type std::string + /// set string item at location specified by index + /// \param[in] index + /// \param[in] value of type std::string Status SetItemAt(const std::vector &index, const std::string &value) { RETURN_UNEXPECTED_IF_NULL(data_); uchar *ptr = nullptr; @@ -253,7 +238,8 @@ class Tensor { return Status::OK(); } - // fill tensor with Zeros. Does not support strings. + + /// fill tensor with Zeros. Does not support strings. Status Zero() { CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use Zero on tensor of strings.."); dsize_t size = SizeInBytes(); @@ -262,13 +248,12 @@ class Tensor { return Status::OK(); } - // Fill all elements in the Tensor with the given value of type `T`. Does not support strings. - // @tparam T - // @param value + /// Fill all elements in the Tensor with the given value of type `T`. Does not support strings. 
+ /// \tparam T + /// \param[in] value template Status Fill(const T &value) { CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use fill on tensor of strings."); - RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes())); int64_t cellSize = type_.SizeInBytes(); if ((data_ != nullptr) && type_.IsCompatible()) { for (dsize_t i = 0; i < Size(); i++) { @@ -283,91 +268,86 @@ class Tensor { } } - // Getter function for shape - // @return + /// Getter function for shape + /// \return const TensorShape &shape() const { return shape_; } /// Check if tensor has data - /// \return bool - true if tensor is empty + /// \return bool - true if the tensor has data (i.e. its data pointer is not null) - bool HasData() const; + bool HasData() const { return data_ != nullptr; } - // Reshape the tensor. The given shape should have the same number of elements in the Tensor - // @param shape + /// Reshape the tensor. The given shape should have the same number of elements in the Tensor + /// \param shape virtual Status Reshape(const TensorShape &shape); - // @return number of elements in this tensor + /// \return number of elements in this tensor dsize_t Size() const { return shape().NumOfElements(); } - // @return the number of bytes this tensor is needs + /// \return the number of bytes this tensor needs dsize_t SizeInBytes() const { if (data_end_ == nullptr) return type_.SizeInBytes() * shape_.NumOfElements(); return data_end_ - data_; } - // @return the rank of the tensor + /// \return the rank of the tensor dsize_t Rank() const { return shape().Rank(); } - // Get the starting memory address as a constant for the data of the tensor. This potentially - // drives an allocation if the data area. + /// Get the starting memory address as a constant for the data of the tensor. No allocation is + /// performed; the pointer is null if the tensor has no data.
+ /// \return const unsigned char* + const unsigned char *GetBuffer() const { return data_; } - // Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the - // tensor's type is a string, otherwise undefined address would be returned. - // @return address of the first string of the tensor. - uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } - - // Getter of the type - // @return + /// Getter of the type + /// \return DataType type() const { return type_; } - // Provide stream operator for displaying it - // @param output stream - // @param so the Tensor object to be printed - // @return output stream + /// Provide stream operator for displaying it + /// \param output stream + /// \param so the Tensor object to be printed + /// \return output stream friend std::ostream &operator<<(std::ostream &out, const Tensor &so) { so.Print(out); return out; } - // Invalidate this Tensor by setting the type and shape to unknown and MData to null. - // Calling this method will make the Tensor and its data inaccessible, use it with caution. + /// Invalidate this Tensor by setting the type and shape to unknown and MData to null. + /// Calling this method will make the Tensor and its data inaccessible, use it with caution. void Invalidate(); - // Copy input tensor into self at the location index. - // Index is a vector of axises which can be incomplete: - // Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell. - // @param index - // @param input - // @return Status code + /// Copy input tensor into self at the location index. + /// Index is a vector of axises which can be incomplete: + /// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell. 
+ /// \param index + /// \param input + /// \return Status code Status InsertTensor(const std::vector &index, const std::shared_ptr &input); - // Find the address of the given index. Used in InsertTensor. - // Example: - // Tensor t= [[1,2],[3,4]] , StartAddrOfIndex({0}) -> &1 - // @param index incomplete index - // @param output: startAddrofIndex - // @param output: remaining - // @return Status code + /// Find the address of the given index. Used in InsertTensor. + /// Example: + /// Tensor t= [[1,2],[3,4]] , StartAddrOfIndex({0}) -> &1 + /// \param[in] ind incomplete index + /// \param[out] start_addr_of_index start address of the given index + /// \param[out] remaining the remaining shape + /// \return Status code Status StartAddrOfIndex(std::vector ind, uchar **start_addr_of_index, TensorShape *remaining); - // Expand the shape of the Tensor with one extra dimension. - // For example, if the shape is <512,512,3>: - // *- ExpandDim(0) gives: <1,512,512,3> - // *- ExpandDim(1) gives: <512,1,512,3> - // *- ExpandDim(3) gives: <512,512,3,1> - // @param axis location of the dim + /// Expand the shape of the Tensor with one extra dimension. + /// For example, if the shape is <512,512,3>: + /// *- ExpandDim(0) gives: <1,512,512,3> + /// *- ExpandDim(1) gives: <512,1,512,3> + /// *- ExpandDim(3) gives: <512,512,3,1> + /// \param axis location of the dim virtual Status ExpandDim(const dsize_t &axis); virtual void Squeeze(); - // Calculates the strides of the Tensor - // Ex: Tensor of shape <4,2,2> and type DE_UINT8 (1 byte) - // The strides will be {6,2,1}. - // Ex: Tensor of shape <4,2,2> and type DE_UINT32 (4 byte) - // The strides will be {24,8,4}. - // @return vector of integers - std::vector Strides(); + /// Calculates the strides of the Tensor + /// Ex: Tensor of shape <4,3,2> and type DE_UINT8 (1 byte) + /// The strides will be {6,2,1}. + /// Ex: Tensor of shape <4,3,2> and type DE_UINT32 (4 byte)
+ /// \return vector of integers + std::vector Strides() const; std::string ToString() { std::stringstream ss; @@ -375,26 +355,26 @@ class Tensor { return ss.str(); } - // Handle negative indices. + /// Handle negative indices. static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; } - // Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported. - // Based on the type of tensor, SliceNumeric or SliceString will be called - // @param out Tensor - // @param indices vector of indices - // @return Status error code - Status Slice(std::shared_ptr *out, const std::vector &indices); + /// Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported. + /// Based on the type of tensor, SliceNumeric or SliceString will be called + /// \param[out] out Tensor + /// \param[in] indices vector of indices + /// \return Status error code + Status Slice(TensorPtr *out, const std::vector &indices); - // Slice numeric tensors. - Status SliceNumeric(std::shared_ptr *out, const std::vector &indices); + /// Slice numeric tensors. 
+ Status SliceNumeric(TensorPtr *out, const std::vector &indices); - // Slice string tensors - Status SliceString(std::shared_ptr *out, const std::vector &indices); + /// Slice string tensors + Status SliceString(TensorPtr *out, const std::vector &indices); #ifdef ENABLE_PYTHON - // Constructs numpy array from input tensor - // @param data this data is the location of python data - // @return Status code + /// Constructs numpy array from input tensor + /// \param[in] data this data is the location of python data + /// \return Status code Status GetDataAsNumpy(py::array *data); Status GetDataAsNumpyStrings(py::array *data); @@ -402,12 +382,12 @@ class Tensor { static Status GetBufferInfo(Tensor *t, py::buffer_info *out); #endif - // Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor + /// Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor Status Concatenate(const std::vector &index, const std::shared_ptr &input); - // TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor - // The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 - // @tparam T type of values in the Tensor Iterator + /// TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor + /// The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 + /// \tparam T type of values in the Tensor Iterator template class TensorIterator { public: @@ -498,7 +478,7 @@ class Tensor { }; // Specialization of TensorIterator for strings. It returns std::string_view for every item. 
- // @tparam DUMMY, used to mbe able to specialize the inner class + // \tparam DUMMY, used to mbe able to specialize the inner class template class TensorIterator { public: @@ -585,84 +565,192 @@ class Tensor { const char *data_; }; - // Return a TensorIterator that points to the start of the Tensor. - // It's the user responsibility to use the correct type that matches the Tensor type - // @param T The type of values in the Tensor - // @return TensorIterator + /// Return a TensorIterator that points to the start of the Tensor. + /// It's the user responsibility to use the correct type that matches the Tensor type + /// \tparam T The type of values in the Tensor + /// \return TensorIterator template TensorIterator begin() { - AllocateBuffer(SizeInBytes()); return TensorIterator(data_); } - // Return a linear iterator that points to the place after the last element of the Tensor. - // @tparam T The type of values in the Tensor - // @return TensorIterator + /// Return a linear iterator that points to the place after the last element of the Tensor. + /// \tparam T The type of values in the Tensor + /// \return TensorIterator template TensorIterator end() { return TensorIterator(data_end_); } - // Copies the last dimension at `index` from Tensor `src` to this Tensor. - // @param src Tensor - // @param index vector to the start of the dimension. The last dim should be 0 - // @return Status + /// Copies the last dimension at `index` from Tensor `src` to this Tensor. + /// \param[in] src Tensor + /// \param[in] index vector to the start of the dimension. The last dim should be 0 + /// \return Status Status CopyLastDimAt(const std::shared_ptr &src, const std::vector &index); protected: - // Get the starting memory address for the data of the tensor. This potentially - // drives an allocation if the data is null. 
- // @return unsigned char* - unsigned char *GetMutableBuffer(); + /// Allocate memory for the tensor using the data_allocator + /// \param[in] length number of bytes to be allocated + /// \return Error Status + Status AllocateBuffer(const dsize_t &length); - // A function that prints Tensor recursively, first called by print - // @param out - // @param cur_dim - // @param cur_index + /// Get the starting memory address for the data of the tensor. This potentially + /// drives an allocation if the data is null. + /// \return unsigned char* + unsigned char *GetMutableBuffer() { return data_; } + + /// A function that prints Tensor recursively, first called by print + /// \param[in] out + /// \param[in] cur_dim + /// \param[in] cur_index void PrintRecursive(std::ostream &out, int32_t cur_dim, const std::vector &cur_index) const; - // A function that prints info about the tensor - // @param out output stream + /// A function that prints info about the tensor + /// \param[out] out output stream void Print(std::ostream &out) const; - // A function that print the value as specified by its index - // @param index vector representing the index - // @param out + /// A function that print the value as specified by its index + /// \param[in] index vector representing the index + /// \param[out] out void PrintItemAt(const std::vector &index, std::ostream &out) const; - // Get pointer to item located at `index`, caller needs to provide the type. - // @tparam T - // @param index vector - // @return return a pointer to the item specified at index of type `T` + /// Get pointer to item located at `index`, caller needs to provide the type. 
+ /// \tparam T + /// \param[in] index vector + /// \return return a pointer to the item specified at index of type `T` template Status GetItemPtr(T **, const std::vector &index) const; - // Get pointer to string located at `index` and the length of string - // @param index vector - // @return return a pointer to the string specified at index and the length of the string + /// Get pointer to string located at `index` and the length of string + /// \param[in] index vector + /// \return return a pointer to the string specified at index and the length of the string Status GetItemPtr(uchar **, const std::vector &index, offset_t *length = nullptr) const; - // Given a flat index of an item string, return the start and length of the item - // @param index flat index of the item - // @return start address of the ths string - // @return length of the string + /// Given a flat index of an item string, return the start and length of the item + /// \param[in] index flat index of the item + /// \param[out] start address of the ths string + /// \param[out] length of the string Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const; - // all access to shape_ should be via shape + /// Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if + /// the tensor's type is a string, otherwise undefined address would be returned. \return address of the first string + /// of the tensor. 
+ uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } + + /// all access to shape_ should be via shape TensorShape shape_; - // data type of tensor + /// data type of tensor DataType type_; - // pointer to the start of the physical data + /// pointer to the start of the physical data unsigned char *data_; - // An allocator for data_ + /// An allocator for data_ CharAllocPtr data_allocator_; - // pointer to the end of the physical data + /// pointer to the end of the physical data unsigned char *data_end_ = nullptr; + + private: + /// Helper function to create a tensor from Numpy array of strings + /// \param[in] arr Numpy array + /// \param[out] out Created Tensor + /// \return Status + static Status CreateFromNpString(py::array arr, TensorPtr *out); + + /// Copy raw data of a array based on shape and strides to the destination pointer + /// \param dst [out] Pointer to the destination array where the content is to be copied + /// \param[in] src Pointer to the source of strided array to be copied + /// \param[in] shape shape of the source array + /// \param[in] strides strides of the source array + /// \param[in] type_size number of bytes needed to store one array element's type + /// \return Status Code + static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, + std::vector strides, uint8_t type_size); + + /// const of the size of the offset variable + static constexpr uint8_t kOffsetSize = sizeof(offset_t); }; template <> inline Tensor::TensorIterator Tensor::end() { return TensorIterator(data_, shape_.NumOfElements()); } + +/// Create a Tensor from a given list of strings. +/// @note: The memory layout of a Tensor of strings consists of the Offset_array followed by the strings. +/// The offset array will store one extra value to find the length of the last string. 
+/// OFFSET_1, OFFSET_2, ..., OFFSET_n+1, STRING_1, STRING_2, ..., STRING_n +/// The value of each offset is the start index of the corresponding string +/// Offsets is of type offset_t +/// strings will ne null-terminated +/// example: Tensor(['abc', 'de'], shape={2}, type=DE_STRING) +/// |----------------------------------------------------------------| +/// | OFFSET ARRAY | STRINGS | +/// | bytes 0-3 | bytes 3-6 | bytes 7-10 | bytes 11-14 | bytes 15-17 | +/// | 11 | 15 | 18 | abc\0 | de\0 | +/// |----------------------------------------------------------------| +/// \param[in] items elements of the tensor +/// \param[in] shape shape of the output tensor +/// \param[out] out output argument to hold the created Tensor +/// \return Status Code +template <> +inline Status Tensor::CreateFromVector(const std::vector &items, const TensorShape &shape, + TensorPtr *out) { + CHECK_FAIL_RETURN_UNEXPECTED( + items.size() == shape.NumOfElements(), + "Number of elements in the vector does not match the number of elements of the shape required"); + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *out = std::allocate_shared(*alloc, TensorShape({static_cast(items.size())}), + DataType(DataType::DE_STRING)); + if (items.size() == 0) { + if (shape.known()) { + return (*out)->Reshape(shape); + } + } + auto length_sum = [](dsize_t sum, const std::string &s) { return s.length() + sum; }; + dsize_t total_length = std::accumulate(items.begin(), items.end(), 0, length_sum); + + // total bytes needed = offset array + strings + // offset array needs to store one offset var per element + 1 extra to get the length of the last string. 
+ // strings will be null-terminated --> need 1 extra byte per element + dsize_t num_bytes = (kOffsetSize + 1) * (*out)->shape_.NumOfElements() + kOffsetSize + total_length; + + (*out)->AllocateBuffer(num_bytes); + auto offset_arr = reinterpret_cast((*out)->data_); + uchar *buf = (*out)->GetStringsBuffer(); + + offset_t offset = buf - (*out)->data_; // the first string will start here + uint32_t i = 0; + for (const auto &str : items) { + // insert the start index of the string. + offset_arr[i++] = offset; + // total bytes are reduced by kOffsetSize + num_bytes -= kOffsetSize; + // insert actual string + int ret_code = memcpy_s((*out)->data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); + if (ret_code != 0) MS_LOG(ERROR) << "Cannot copy string into Tensor"; + // next string will be stored right after the current one. + offset = offset + str.length() + 1; + // total bytes are reduced by the length of the string + num_bytes -= str.length() + 1; + } + // store one more offset value so we can get the length of the last string + // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] + offset_arr[i] = offset; + + (*out)->data_end_ = (*out)->data_ + offset_arr[i]; + + MS_ASSERT(num_bytes == 0); + if (shape.known()) { + RETURN_IF_NOT_OK((*out)->Reshape(shape)); + } + return Status::OK(); +} +/// Create a string scalar Tensor from the given value. 
+/// \param[in] item value +/// \param[out] out Created tensor +/// \return Status code +template <> +inline Status Tensor::CreateScalar(const std::string &item, TensorPtr *out) { + return CreateFromVector({item}, TensorShape::CreateScalar(), out); +} } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc index b1d51a6c081..267120851b1 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc @@ -97,7 +97,7 @@ Status OneHotEncoding(std::shared_ptr input, std::shared_ptr *ou if (input->Rank() == 1) num_elements = input->shape()[0]; TensorShape out_shape({num_elements, num_classes}); std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, out_shape, input->type())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(out_shape, input->type(), &out)); RETURN_IF_NOT_OK(out->Zero()); for (dsize_t i = 0; i < num_elements; ++i) { if (input->type().IsUnsignedInt()) { @@ -133,7 +133,9 @@ Status Fill(const std::shared_ptr input, std::shared_ptr *output fill_output = fill_value; } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input_shape, input_type)); + if (input_type.IsNumeric()) { + RETURN_IF_NOT_OK(Tensor::CreateEmpty(input_shape, input_type, &out)); + } switch (input_type.value()) { case DataType::DE_BOOL: { @@ -216,7 +218,7 @@ Status Fill(const std::shared_ptr input, std::shared_ptr *output for (int i = 0; i < input_shape.NumOfElements(); i++) { strings.emplace_back(fill_string); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input_shape)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, input_shape, &out)); break; } case DataType::DE_UNKNOWN: { @@ -285,9 +287,8 @@ void CastFrom(const std::shared_ptr &input, std::shared_ptr *out // Type cast operator 
Status TypeCast(const std::shared_ptr &input, std::shared_ptr *output, const DataType &data_type) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), data_type)); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), data_type, output)); - RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes())); switch (input->type().value()) { case DataType::DE_BOOL: CastFrom(input, output); @@ -335,8 +336,7 @@ Status TypeCast(const std::shared_ptr &input, std::shared_ptr *o Status ToFloat16(const std::shared_ptr &input, std::shared_ptr *output) { // initiate new tensor for type cast DataType new_type = DataType("float16"); - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), new_type)); - RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), new_type, output)); auto in_itr = input->begin(); auto out_itr = (*output)->begin(); @@ -387,7 +387,7 @@ Status PadEndNumeric(const std::shared_ptr &src, std::shared_ptr (*dst) = src; // if no padding, copy the pointer } else { CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed"); - RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape(pad_shape), src->type(), dst)); auto tensor_type = src->type().value(); if (pad_val == 0) { // if pad with zero, don't care what type it is RETURN_IF_NOT_OK((*dst)->Zero()); @@ -447,7 +447,7 @@ Status PadEndString(const std::shared_ptr &src, std::shared_ptr std::vector cur_ind(src->Rank(), 0); std::vector strings; RETURN_IF_NOT_OK(PadEndStringHelper(src, &strings, TensorShape(pad_shape), cur_ind, 0, pad_val)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, strings, TensorShape(pad_shape))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, TensorShape(pad_shape), dst)); } return Status::OK(); } @@ -521,7 +521,7 @@ 
Status Mask(const std::shared_ptr &input, std::shared_ptr *outpu "Cannot convert constant value to the type of the input tensor."); CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar"); - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL))); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_BOOL), output)); std::unique_ptr value_cast_op(new TypeCastOp(input->type())); std::shared_ptr casted_value; @@ -629,7 +629,7 @@ Status ConcatenateHelper(const std::shared_ptr &input, std::shared_ptr out; if (input->type().IsNumeric()) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, t, input->type())); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(t, input->type(), &out)); RETURN_IF_NOT_OK(out->Concatenate({0}, input)); RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append)); @@ -645,7 +645,7 @@ Status ConcatenateHelper(const std::shared_ptr &input, std::shared_ptrend(); itr++) { strings.emplace_back(*itr); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, t)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, t, &out)); *output = out; } diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc index 57a424704f0..c7fc6c1d7ec 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc @@ -26,7 +26,7 @@ Status DuplicateOp::Compute(const TensorRow &input, TensorRow *output) { IO_CHECK_VECTOR(input, output); CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, input[0])); + RETURN_IF_NOT_OK(Tensor::CreateFromTensor(input[0], &out)); output->push_back(input[0]); output->push_back(out); return Status::OK(); diff --git 
a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc index f0f2fcb8523..acddc765d90 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc @@ -63,9 +63,8 @@ int GetCVBorderType(BorderType type) { Status Flip(std::shared_ptr input, std::shared_ptr *output, int flip_code) { std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input)); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - RETURN_IF_NOT_OK(output_cv->AllocateBuffer(output_cv->SizeInBytes())); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); if (input_cv->mat().data) { try { @@ -110,8 +109,9 @@ Status Resize(const std::shared_ptr &input, std::shared_ptr *out TensorShape shape{output_height, output_width}; int num_channels = input_cv->shape()[2]; if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); - std::shared_ptr output_cv = std::make_shared(shape, input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(shape, input_cv->type(), &output_cv)); + auto cv_mode = GetCVInterpolationMode(mode); cv::resize(in_image, output_cv->mat(), cv::Size(output_width, output_height), fx, fy, cv_mode); *output = std::static_pointer_cast(output_cv); @@ -147,8 +147,8 @@ Status DecodeCv(const std::shared_ptr &input, std::shared_ptr *o RETURN_STATUS_UNEXPECTED(err); } cv::cvtColor(img_mat, img_mat, static_cast(cv::COLOR_BGR2RGB)); - std::shared_ptr output_cv = std::make_shared(img_mat); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateFromMat(img_mat, &output_cv)); *output = std::static_pointer_cast(output_cv); return Status::OK(); } catch (const cv::Exception &e) { @@ -309,7 
+309,8 @@ Status JpegCropAndDecode(const std::shared_ptr &input, std::shared_ptr(ts, DataType(DataType::DE_UINT8)); + std::shared_ptr output_tensor; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(ts, DataType(DataType::DE_UINT8), &output_tensor)); const int buffer_size = output_tensor->SizeInBytes(); JSAMPLE *buffer = reinterpret_cast(&(*output_tensor->begin())); const int max_scanlines_to_read = skipped_scanlines + crop_h; @@ -331,8 +332,8 @@ Status Rescale(const std::shared_ptr &input, std::shared_ptr *ou RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); } cv::Mat input_image = input_cv->mat(); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), DataType(DataType::DE_FLOAT32)); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), DataType(DataType::DE_FLOAT32), &output_cv)); try { input_image.convertTo(output_cv->mat(), CV_32F, rescale, shift); *output = std::static_pointer_cast(output_cv); @@ -354,8 +355,8 @@ Status Crop(const std::shared_ptr &input, std::shared_ptr *outpu TensorShape shape{h, w}; int num_channels = input_cv->shape()[2]; if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); - std::shared_ptr output_cv = std::make_shared(shape, input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(shape, input_cv->type(), &output_cv)); cv::Rect roi(x, y, w, h); (input_cv->mat())(roi).copyTo(output_cv->mat()); *output = std::static_pointer_cast(output_cv); @@ -386,10 +387,11 @@ Status HwcToChw(std::shared_ptr input, std::shared_ptr *output) int height = input_cv->shape()[0]; int width = input_cv->shape()[1]; - auto output_cv = std::make_unique(TensorShape{num_channels, height, width}, input_cv->type()); + std::shared_ptr output_cv; + CVTensor::CreateEmpty(TensorShape{num_channels, height, width}, input_cv->type(), &output_cv); for (int i = 0; i < num_channels; ++i) { cv::Mat mat; - 
RETURN_IF_NOT_OK(output_cv->Mat({i}, &mat)); + RETURN_IF_NOT_OK(output_cv->MatAtIndex({i}, &mat)); cv::extractChannel(input_cv->mat(), mat, i); } *output = std::move(output_cv); @@ -406,8 +408,9 @@ Status SwapRedAndBlue(std::shared_ptr input, std::shared_ptr *ou if (input_cv->shape().Size() != 3 || num_channels != 3) { RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); + cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast(cv::COLOR_BGR2RGB)); *output = std::static_pointer_cast(output_cv); return Status::OK(); @@ -440,8 +443,8 @@ Status CropAndResize(const std::shared_ptr &input, std::shared_ptrshape()[2]; if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); - std::shared_ptr cvt_out = std::make_shared(shape, input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(cvt_out); + std::shared_ptr cvt_out; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(shape, input_cv->type(), &cvt_out)); cv::resize(cv_in(roi), cvt_out->mat(), cv::Size(target_width, target_height), 0, 0, cv_mode); *output = std::static_pointer_cast(cvt_out); return Status::OK(); @@ -475,8 +478,7 @@ Status Rotate(const std::shared_ptr &input, std::shared_ptr *out if (!expand) { // this case means that the shape doesn't change, size stays the same // We may not need this memcpy if it is in place. 
- output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); // using inter_nearest to comply with python default cv::warpAffine(input_img, output_cv->mat(), rot, input_img.size(), GetCVInterpolationMode(interpolation), cv::BORDER_CONSTANT, fill_color); @@ -489,7 +491,7 @@ Status Rotate(const std::shared_ptr &input, std::shared_ptr *out // use memcpy and don't compute the new shape since openCV has a rounding problem cv::warpAffine(input_img, output_img, rot, bbox.size(), GetCVInterpolationMode(interpolation), cv::BORDER_CONSTANT, fill_color); - output_cv = std::make_shared(output_img); + RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &output_cv)); RETURN_UNEXPECTED_IF_NULL(output_cv); } *output = std::static_pointer_cast(output_cv); @@ -506,8 +508,8 @@ Status Normalize(const std::shared_ptr &input, std::shared_ptr * RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); } cv::Mat in_image = input_cv->mat(); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), DataType(DataType::DE_FLOAT32)); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), DataType(DataType::DE_FLOAT32), &output_cv)); mean->Squeeze(); if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) { std::string err_msg = "Mean tensor should be of size 3 and type float."; @@ -548,8 +550,8 @@ Status AdjustBrightness(const std::shared_ptr &input, std::shared_ptrRank() != 3 || num_channels != 3) { RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); output_cv->mat() = 
input_img * alpha; *output = std::static_pointer_cast(output_cv); } catch (const cv::Exception &e) { @@ -572,8 +574,8 @@ Status AdjustContrast(const std::shared_ptr &input, std::shared_ptr(cv::mean(gray).val[0] + 0.5); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); output_img = cv::Mat::zeros(input_img.rows, input_img.cols, CV_8UC1); output_img = output_img + mean_img; cv::cvtColor(output_img, output_img, CV_GRAY2RGB); @@ -680,7 +682,9 @@ Status AutoContrast(const std::shared_ptr &input, std::shared_ptrmat().type()); - std::shared_ptr output_cv = std::make_shared(result); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, &output_cv)); + (*output) = std::static_pointer_cast(output_cv); (*output) = std::static_pointer_cast(output_cv); (*output)->Reshape(input->shape()); } catch (const cv::Exception &e) { @@ -700,8 +704,8 @@ Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptrRank() != 3 || num_channels != 3) { RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); cv::Mat output_img = output_cv->mat(); cv::Mat gray; cv::cvtColor(input_img, gray, CV_RGB2GRAY); @@ -729,8 +733,8 @@ Status AdjustHue(const std::shared_ptr &input, std::shared_ptr * if (input_cv->Rank() != 3 || num_channels != 3) { RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + 
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); cv::Mat output_img; cv::cvtColor(input_img, output_img, CV_RGB2HSV_FULL); for (int y = 0; y < output_img.cols; y++) { @@ -781,7 +785,8 @@ Status Equalize(const std::shared_ptr &input, std::shared_ptr *o } cv::Mat result; cv::merge(image_result, result); - std::shared_ptr output_cv = std::make_shared(result); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, &output_cv)); (*output) = std::static_pointer_cast(output_cv); (*output)->Reshape(input->shape()); } catch (const cv::Exception &e) { @@ -867,8 +872,8 @@ Status Pad(const std::shared_ptr &input, std::shared_ptr *output } else { cv::copyMakeBorder(input_cv->mat(), out_image, pad_top, pad_bottom, pad_left, pad_right, b_type); } - std::shared_ptr output_cv = std::make_shared(out_image); - RETURN_UNEXPECTED_IF_NULL(output_cv); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_image, &output_cv)); // pad the dimension if shape information is only 2 dimensional, this is grayscale int num_channels = input_cv->shape()[2]; if (input_cv->Rank() == 3 && num_channels == 1 && output_cv->Rank() == 2) output_cv->ExpandDim(2); @@ -932,7 +937,7 @@ Status UpdateBBoxesForCrop(std::shared_ptr *bboxList, size_t *bboxCount, } } std::shared_ptr retV; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&retV, copyVals, TensorShape({static_cast(*bboxCount), bboxDim}))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(copyVals, TensorShape({static_cast(*bboxCount), bboxDim}), &retV)); (*bboxList) = retV; // reset pointer return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/invert_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/invert_op.cc index 44a7f1f5b49..ed46194baa3 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/invert_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/invert_op.cc @@ -40,8 +40,8 @@ Status InvertOp::Compute(const std::shared_ptr 
&input, std::shared_ptr(input_cv->shape(), input_cv->type()); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); RETURN_UNEXPECTED_IF_NULL(output_cv); output_cv->mat() = cv::Scalar::all(255) - input_img; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc index de5deb31efa..56593e33ca4 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc @@ -24,20 +24,14 @@ namespace mindspore { namespace dataset { NormalizeOp::NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b) { - int size[] = {3}; - cv::Mat mean_cv(1, size, CV_32F); - mean_cv.at(0) = mean_r; - mean_cv.at(1) = mean_g; - mean_cv.at(2) = mean_b; - mean_ = std::make_shared(mean_cv); - mean_->Squeeze(); - - cv::Mat std_cv(1, size, CV_32F); - std_cv.at(0) = std_r; - std_cv.at(1) = std_g; - std_cv.at(2) = std_b; - std_ = std::make_shared(std_cv); - std_->Squeeze(); + Status s = Tensor::CreateFromVector({mean_r, mean_g, mean_b}, &mean_); + if (s.IsError()) { + MS_LOG(ERROR) << "Could not create mean tensor."; + } + s = Tensor::CreateFromVector({std_r, std_g, std_b}, &std_); + if (s.IsError()) { + MS_LOG(ERROR) << "Could not create std tensor."; + } } Status NormalizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { @@ -47,9 +41,7 @@ Status NormalizeOp::Compute(const std::shared_ptr &input, std::shared_pt } void NormalizeOp::Print(std::ostream &out) const { - out << "NormalizeOp, mean: " << mean_->mat().at(0) << ", " << mean_->mat().at(1) << ", " - << mean_->mat().at(2) << "std: " << std_->mat().at(0) << ", " << std_->mat().at(1) << ", " - << std_->mat().at(2) << std::endl; + out << "NormalizeOp, mean: " << mean_ << std::endl << "std: " << std_ << std::endl; } } // namespace dataset } // namespace mindspore diff 
--git a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h index 2884776de4c..4e4b760abd5 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h @@ -39,8 +39,8 @@ class NormalizeOp : public TensorOp { std::string Name() const override { return kNormalizeOp; } private: - std::shared_ptr mean_; - std::shared_ptr std_; + std::shared_ptr mean_; + std::shared_ptr std_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc index f501dd4b4f0..dbf2dfe73e3 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc @@ -49,7 +49,7 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) { if (py::isinstance(ret_py_obj)) { // In case of a n-1 mapping, the return value will be a numpy array std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_obj.cast())); + RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(ret_py_obj.cast(), &out)); output->push_back(out); } else if (py::isinstance(ret_py_obj)) { // In case of a n-m mapping, the return value will be a tuple of numpy arrays @@ -61,7 +61,7 @@ Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) { goto ShapeMisMatch; } std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_ele.cast())); + RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(ret_py_ele.cast(), &out)); output->push_back(out); } } else { diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc index 6195572944e..f530edf779d 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc +++ 
b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc @@ -136,8 +136,7 @@ Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptrbegin(); iter != input->end(); iter++) { RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(*iter, kUnusedWords, &strs[i++])); } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); + return Tensor::CreateFromVector(strs, input->shape(), output); } Status BasicTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc index 0ea5cadedb6..b38df2f0f6a 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc @@ -39,8 +39,7 @@ Status CaseFoldOp::Compute(const std::shared_ptr &input, std::shared_ptr nfkc_case_fold->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed."); } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); + return Tensor::CreateFromVector(strs, input->shape(), output); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.cc index 74b1d930775..17b4c613ba5 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/data_utils.cc @@ -33,12 +33,7 @@ Status SlidingWindowHelper(const std::shared_ptr &input, std::shared_ptr // if the data row has fewer items than width, the corresponding result row will be empty if (out_shape.Size() == 0) { MS_LOG(WARNING) << "The data row has fewer items than width, the result will be empty."; - if (input->type().value() == DataType::DE_STRING) { - 
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, std::vector{}, TensorShape({0}))); - } else { - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, TensorShape({0}), input->type())); - } - return Status::OK(); + return Tensor::CreateEmpty(TensorShape({0}), input->type(), output); } axis = Tensor::HandleNeg(axis, input->shape().Size()); diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc index 0a1ae92d144..abcf72c9daa 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc @@ -68,15 +68,12 @@ Status JiebaTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { offsets_limit.push_back(static_cast(item.offset + item.word.length())); } } - token_tensor = std::make_shared(words, TensorShape({(dsize_t)words.size()})); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(words, &token_tensor)); output->push_back(token_tensor); if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor)); + output->push_back(offsets_start_tensor); output->push_back(offsets_limit_tensor); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc index d1b4ad24b86..03178044160 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc +++ 
b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc @@ -36,9 +36,7 @@ Status LookupOp::Compute(const std::shared_ptr &input, std::shared_ptrshape(), type_, - reinterpret_cast(word_ids.data()))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(word_ids, input->shape(), output)); return Status::OK(); } Status LookupOp::OutputType(const std::vector &inputs, std::vector &outputs) { diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.cc index 36781b9b4d6..27b8cb60653 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.cc @@ -67,7 +67,7 @@ Status NgramOp::Compute(const std::shared_ptr &input, std::shared_ptr(res.size())}))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(res, TensorShape({static_cast(res.size())}), output)); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc index 0c0aa5fa2da..b669ca9a8a4 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc @@ -68,8 +68,7 @@ Status NormalizeUTF8Op::Compute(const std::shared_ptr &input, std::share normalize->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed."); } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); + return Tensor::CreateFromVector(strs, input->shape(), output); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc index c370393e768..b36afba8fc5 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc +++ 
b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc @@ -50,8 +50,7 @@ Status RegexReplaceOp::Compute(const std::shared_ptr &input, std::shared for (auto iter = input->begin(); iter != input->end(); iter++) { RETURN_IF_NOT_OK(RegexReplace(&matcher, *iter, &strs[i])); } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); + return Tensor::CreateFromVector(strs, input->shape(), output); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc index 7ff1d994bed..95cb4552761 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc @@ -120,15 +120,11 @@ Status RegexTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; RETURN_IF_NOT_OK(input[0]->GetItemAt(&text, {})); RETURN_IF_NOT_OK(GetRegexTokens(std::string(text.data(), text.size()), &tokens, &offsets_start, &offsets_limit)); - token_tensor = std::make_shared(std::move(tokens), TensorShape({(dsize_t)tokens.size()})); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(std::move(tokens), &token_tensor)); output->push_back(token_tensor); if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor)); 
output->push_back(offsets_start_tensor); output->push_back(offsets_limit_tensor); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc index 42fefa20068..e972fd53268 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc @@ -64,14 +64,14 @@ Status SentencePieceTokenizerOp::Compute(const std::shared_ptr &input, s if (!status.ok()) { RETURN_STATUS_UNEXPECTED("sentence piece tokenizer error"); } - *output = std::make_unique(pieces, TensorShape({(dsize_t)pieces.size()})); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(pieces, output)); } else { std::vector ids; auto status = processor_.Encode(sentence, &ids); if (!status.ok()) { RETURN_STATUS_UNEXPECTED("sentence piece tokenizer error"); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, ids, TensorShape({(dsize_t)ids.size()}))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(ids, output)); } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc index a6685a2d643..3fda769ea23 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc @@ -114,7 +114,7 @@ Status ToNumberOp::ToSignedIntegral(const std::shared_ptr &input, std::s casted.push_back(casted_result); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(casted, input->shape(), output)); return Status::OK(); } @@ -157,7 +157,7 @@ Status ToNumberOp::ToUnsignedIntegral(const std::shared_ptr &input, std: casted.push_back(casted_result); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(casted, input->shape(), output)); 
return Status::OK(); } @@ -165,7 +165,7 @@ Status ToNumberOp::ToFloat16(const std::shared_ptr &input, std::shared_p // special case, float16 does not exist in c++, no native support for // casting, so cast to float first then use this method, which use Eigen. std::shared_ptr temp; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&temp, TensorImpl::kFlexible, input->shape(), DataType("float32"))); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType("float32"), &temp)); RETURN_IF_NOT_OK(ToFloat(input, &temp)); RETURN_IF_NOT_OK(mindspore::dataset::ToFloat16(temp, output)); return Status::OK(); @@ -200,7 +200,7 @@ Status ToNumberOp::ToFloat(const std::shared_ptr &input, std::shared_ptr casted.push_back(casted_result); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(casted, input->shape(), output)); return Status::OK(); } @@ -233,7 +233,7 @@ Status ToNumberOp::ToDouble(const std::shared_ptr &input, std::shared_pt casted.push_back(casted_result); } - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(casted, input->shape(), output)); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc index e08f61100b1..c8b33d0ce4a 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc @@ -55,15 +55,13 @@ Status UnicodeCharTokenizerOp::Compute(const TensorRow &input, TensorRow *output offsets_start.push_back(0); offsets_limit.push_back(0); } - token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor)); + output->push_back(token_tensor); if (with_offsets_) { - 
RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor)); + output->push_back(offsets_start_tensor); output->push_back(offsets_limit_tensor); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc index 60fe8dd0e41..43ebbda42f9 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc @@ -96,15 +96,12 @@ Status UnicodeScriptTokenizerOp::Compute(const TensorRow &input, TensorRow *outp offsets_start.push_back(0); offsets_limit.push_back(0); } - token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor)); output->push_back(token_tensor); if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor)); + 
output->push_back(offsets_start_tensor); output->push_back(offsets_limit_tensor); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc index d3bb32081e5..c8727778138 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc @@ -79,15 +79,12 @@ Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) offsets_start.push_back(0); offsets_limit.push_back(0); } - token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor)); output->push_back(token_tensor); if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor)); + output->push_back(offsets_start_tensor); output->push_back(offsets_limit_tensor); } diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc index f0bd448e398..04a1274b030 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc @@ -1,157 +1,154 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file 
except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" -#include -#include - -namespace mindspore { -namespace dataset { - -const char WordpieceTokenizerOp::kDefSuffixIndicator[] = "##"; -const int WordpieceTokenizerOp::kDefMaxBytesPerToken = 100; -const char WordpieceTokenizerOp::kDefUnknownToken[] = "[UNK]"; -const bool WordpieceTokenizerOp::kDefWithOffsets = false; - -WordpieceTokenizerOp::WordpieceTokenizerOp(const std::shared_ptr &vocab, const std::string &suffix_indicator, - const int &max_bytes_per_token, const std::string &unknown_token, - const bool &with_offsets) - : vocab_(vocab), - suffix_indicator_(suffix_indicator), - max_bytes_per_token_(max_bytes_per_token), - unknown_token_(unknown_token), - with_offsets_(with_offsets) {} - -Status WordpieceTokenizerOp::LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, - bool *out_found, int *out_end) const { - CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && start < input_token.size(), "Out of range"); - *out_found = false; - for (int i = runes.size() - 1; i >= 0; i--) { - *out_end = runes[i].offset + runes[i].len; - int len = *out_end - start; - std::string word = input_token.substr(start, len); - if (start > 0) { - word = suffix_indicator_ + word; - } - if (vocab_->Lookup(word) != Vocab::kNoTokenExists) { - *out_found = true; - break; - } - } - return Status::OK(); -} - -Status WordpieceTokenizerOp::FoundNoToken(const std::string &input_token, const uint32_t &basic_start, - std::vector *out_tokens, 
std::vector *offsets_start, - std::vector *offsets_limit) const { - out_tokens->clear(); - offsets_start->push_back(basic_start); - if (unknown_token_.empty()) { - out_tokens->emplace_back(input_token); - offsets_limit->push_back(basic_start + input_token.length()); - } else { - out_tokens->emplace_back(unknown_token_); - offsets_limit->push_back(basic_start + input_token.length()); - } - return Status::OK(); -} - -Status WordpieceTokenizerOp::AddSubword(const std::string &input_token, const int &start, const int &end, - std::vector *out_tokens) const { - CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && end > start && end <= input_token.size(), "Out of range"); - std::string subword = input_token.substr(start, end - start); - if (start > 0) { - subword = suffix_indicator_ + subword; - } - out_tokens->emplace_back(subword); - return Status::OK(); -} - -Status WordpieceTokenizerOp::GetTokens(const std::string &input_token, const uint32_t &basic_start, - std::vector *out_tokens, std::vector *offsets_start, - std::vector *offsets_limit) const { - if (input_token.size() > max_bytes_per_token_) { - offsets_start->push_back(basic_start); - if (!unknown_token_.empty()) { - offsets_limit->push_back(basic_start + unknown_token_.size()); - out_tokens->emplace_back(unknown_token_); - } else { - out_tokens->emplace_back(input_token); - offsets_limit->push_back(basic_start + input_token.size()); - } - return Status::OK(); - } - RuneStrArray runes; - if (!DecodeRunesInString(input_token.data(), input_token.size(), runes)) { - RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); - } - int end = 0; - for (int start = 0; start < input_token.size();) { - bool found = false; - RETURN_IF_NOT_OK(LookupWord(input_token, runes, start, &found, &end)); - if (found) { - RETURN_IF_NOT_OK(AddSubword(input_token, start, end, out_tokens)); - offsets_start->push_back(static_cast(basic_start + start)); - offsets_limit->push_back(static_cast(basic_start + end)); - start = end; - } else { - return 
FoundNoToken(input_token, basic_start, out_tokens, offsets_start, offsets_limit); - } - } - return Status::OK(); -} - -Status WordpieceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - if (input[0]->Rank() > 1 || input[0]->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("The input tensor should be scalar or 1-D string tensor"); - } - dsize_t count = 0; - std::vector out_tokens; - std::vector offsets_start, offsets_limit; - std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; - for (auto iter = input[0]->begin(); iter != input[0]->end(); iter++) { - uint32_t basic_start = 0; - std::vector temp_tokens; - if (with_offsets_ && input.size() == 3) { - RETURN_IF_NOT_OK(input[1]->GetItemAt(&basic_start, {count, 0})); - } - RETURN_IF_NOT_OK(GetTokens(std::string(*iter), basic_start, &temp_tokens, &offsets_start, &offsets_limit)); - out_tokens.insert(out_tokens.end(), temp_tokens.begin(), temp_tokens.end()); - count++; - } - if (out_tokens.empty()) { - out_tokens.emplace_back(""); - offsets_start.push_back(0); - offsets_limit.push_back(0); - } - token_tensor = std::make_shared(out_tokens, TensorShape({(dsize_t)out_tokens.size()})); - output->push_back(token_tensor); - if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); - output->push_back(offsets_start_tensor); - output->push_back(offsets_limit_tensor); - } - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" +#include +#include + +namespace mindspore { +namespace dataset { + +const char WordpieceTokenizerOp::kDefSuffixIndicator[] = "##"; +const int WordpieceTokenizerOp::kDefMaxBytesPerToken = 100; +const char WordpieceTokenizerOp::kDefUnknownToken[] = "[UNK]"; +const bool WordpieceTokenizerOp::kDefWithOffsets = false; + +WordpieceTokenizerOp::WordpieceTokenizerOp(const std::shared_ptr &vocab, const std::string &suffix_indicator, + const int &max_bytes_per_token, const std::string &unknown_token, + const bool &with_offsets) + : vocab_(vocab), + suffix_indicator_(suffix_indicator), + max_bytes_per_token_(max_bytes_per_token), + unknown_token_(unknown_token), + with_offsets_(with_offsets) {} + +Status WordpieceTokenizerOp::LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, + bool *out_found, int *out_end) const { + CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && start < input_token.size(), "Out of range"); + *out_found = false; + for (int i = runes.size() - 1; i >= 0; i--) { + *out_end = runes[i].offset + runes[i].len; + int len = *out_end - start; + std::string word = input_token.substr(start, len); + if (start > 0) { + word = suffix_indicator_ + word; + } + if (vocab_->Lookup(word) != Vocab::kNoTokenExists) { + *out_found = true; + break; + } + } + return Status::OK(); +} + +Status WordpieceTokenizerOp::FoundNoToken(const std::string &input_token, const uint32_t &basic_start, + std::vector 
*out_tokens, std::vector *offsets_start, + std::vector *offsets_limit) const { + out_tokens->clear(); + offsets_start->push_back(basic_start); + if (unknown_token_.empty()) { + out_tokens->emplace_back(input_token); + offsets_limit->push_back(basic_start + input_token.length()); + } else { + out_tokens->emplace_back(unknown_token_); + offsets_limit->push_back(basic_start + input_token.length()); + } + return Status::OK(); +} + +Status WordpieceTokenizerOp::AddSubword(const std::string &input_token, const int &start, const int &end, + std::vector *out_tokens) const { + CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && end > start && end <= input_token.size(), "Out of range"); + std::string subword = input_token.substr(start, end - start); + if (start > 0) { + subword = suffix_indicator_ + subword; + } + out_tokens->emplace_back(subword); + return Status::OK(); +} + +Status WordpieceTokenizerOp::GetTokens(const std::string &input_token, const uint32_t &basic_start, + std::vector *out_tokens, std::vector *offsets_start, + std::vector *offsets_limit) const { + if (input_token.size() > max_bytes_per_token_) { + offsets_start->push_back(basic_start); + if (!unknown_token_.empty()) { + offsets_limit->push_back(basic_start + unknown_token_.size()); + out_tokens->emplace_back(unknown_token_); + } else { + out_tokens->emplace_back(input_token); + offsets_limit->push_back(basic_start + input_token.size()); + } + return Status::OK(); + } + RuneStrArray runes; + if (!DecodeRunesInString(input_token.data(), input_token.size(), runes)) { + RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); + } + int end = 0; + for (int start = 0; start < input_token.size();) { + bool found = false; + RETURN_IF_NOT_OK(LookupWord(input_token, runes, start, &found, &end)); + if (found) { + RETURN_IF_NOT_OK(AddSubword(input_token, start, end, out_tokens)); + offsets_start->push_back(static_cast(basic_start + start)); + offsets_limit->push_back(static_cast(basic_start + end)); + start = end; + } else { 
+ return FoundNoToken(input_token, basic_start, out_tokens, offsets_start, offsets_limit); + } + } + return Status::OK(); +} + +Status WordpieceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + if (input[0]->Rank() > 1 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar or 1-D string tensor"); + } + dsize_t count = 0; + std::vector out_tokens; + std::vector offsets_start, offsets_limit; + std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; + for (auto iter = input[0]->begin(); iter != input[0]->end(); iter++) { + uint32_t basic_start = 0; + std::vector temp_tokens; + if (with_offsets_ && input.size() == 3) { + RETURN_IF_NOT_OK(input[1]->GetItemAt(&basic_start, {count, 0})); + } + RETURN_IF_NOT_OK(GetTokens(std::string(*iter), basic_start, &temp_tokens, &offsets_start, &offsets_limit)); + out_tokens.insert(out_tokens.end(), temp_tokens.begin(), temp_tokens.end()); + count++; + } + if (out_tokens.empty()) { + out_tokens.emplace_back(""); + offsets_start.push_back(0); + offsets_limit.push_back(0); + } + RETURN_IF_NOT_OK(Tensor::CreateFromVector(out_tokens, &token_tensor)); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor)); + RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor)); + + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/tests/ut/cpp/dataset/batch_op_test.cc b/tests/ut/cpp/dataset/batch_op_test.cc index 3e1f3c0b320..05686d55139 100644 --- a/tests/ut/cpp/dataset/batch_op_test.cc +++ b/tests/ut/cpp/dataset/batch_op_test.cc @@ -90,8 +90,8 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatch) { rc = di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); std::shared_ptr t; - rc = de::Tensor::CreateTensor(&t,
TensorImpl::kFlexible, de::TensorShape({12, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)payload); + rc = de::Tensor::CreateFromMemory(de::TensorShape({12, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload, &t); EXPECT_TRUE(rc.IsOk()); // verify the actual data in Tensor is correct EXPECT_EQ(*t == *tensor_map["col_sint64"], true); @@ -119,14 +119,14 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) { -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807}; de::DatasetIterator di(tree); std::shared_ptr t1, t2, t3; - rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)payload); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload, &t1); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 7)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 7), &t2); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t3, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 2)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 2), &t3); EXPECT_TRUE(rc.IsOk()); TensorMap tensor_map; @@ -164,17 +164,17 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) { -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807}; de::DatasetIterator di(tree); std::shared_ptr t1, t2, t3, t4; - rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)payload); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned 
char *)payload, &t1); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 7)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 7), &t2); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t3, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 2)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 2), &t3); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t4, TensorImpl::kFlexible, de::TensorShape({3, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 9)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({3, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 9), &t4); EXPECT_TRUE(rc.IsOk()); TensorMap tensor_map; @@ -216,11 +216,11 @@ TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) { -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807}; de::DatasetIterator di(tree); std::shared_ptr t1, t2; - rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)payload); + rc = de::Tensor::CreateFromMemory(de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload, &t1); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 7)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 7), &t2); EXPECT_TRUE(rc.IsOk()); TensorMap tensor_map; @@ -262,11 +262,11 @@ TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) { -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 
9223372036854775807}; de::DatasetIterator di(tree); std::shared_ptr t1, t2; - rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)payload); + rc = de::Tensor::CreateFromMemory(de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload, &t1); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), - (unsigned char *)(payload + 5)); + rc = de::Tensor::CreateFromMemory(de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 5), &t2); EXPECT_TRUE(rc.IsOk()); TensorMap tensor_map; @@ -300,7 +300,7 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) { std::shared_ptr op; PadInfo m; std::shared_ptr pad_value; - Tensor::CreateTensor(&pad_value, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32)); + Tensor::CreateEmpty(TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32), &pad_value); pad_value->SetItemAt({}, -1); m.insert({"col_1d", std::make_pair(TensorShape({4}), pad_value)}); de::BatchOp::Builder(12).SetDrop(false).SetPaddingMap(m, true).Build(&op); @@ -359,8 +359,8 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) { -1, -1}; std::shared_ptr t; - rc = de::Tensor::CreateTensor(&t, TensorImpl::kFlexible, de::TensorShape({12, 4}), de::DataType(DataType::DE_INT64), - (unsigned char *)payload); + rc = de::Tensor::CreateFromMemory(de::TensorShape({12, 4}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload, &t); de::DatasetIterator di(tree); TensorMap tensor_map; rc = di.GetNextAsMap(&tensor_map); diff --git a/tests/ut/cpp/dataset/cache_op_test.cc b/tests/ut/cpp/dataset/cache_op_test.cc index 1d07a6e0c98..26db41ef66b 100644 --- a/tests/ut/cpp/dataset/cache_op_test.cc +++ b/tests/ut/cpp/dataset/cache_op_test.cc @@ -75,7 +75,8 @@ TEST_F(MindDataTestCacheOp, TestCacheServer) { 
EXPECT_TRUE(rc.IsOk()); // Create a tensor, take a snapshot and restore it back, and compare. - std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t); t->SetItemAt({0, 0}, 1); t->SetItemAt({0, 1}, 2); t->SetItemAt({0, 2}, 3); @@ -129,7 +130,8 @@ TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { rc = myClient.CreateCache(1, true); EXPECT_TRUE(rc.IsOk()); std::cout << myClient << std::endl; - std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t); t->SetItemAt({0, 0}, 1); t->SetItemAt({0, 1}, 2); t->SetItemAt({0, 2}, 3); @@ -403,11 +405,7 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { // replace it with the required tree structures for cache lookup op and cache merge op. std::shared_ptr myCacheOp; - rc = CacheOp::Builder() - .SetNumWorkers(4) - .SetClient(myClient) - .SetRowsPerBuffer(3) - .Build(&myCacheOp); + rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp); std::shared_ptr so; ImageFolderOp::Builder builder; diff --git a/tests/ut/cpp/dataset/channel_swap_test.cc b/tests/ut/cpp/dataset/channel_swap_test.cc index 2000de15b2d..0fae8417e1c 100644 --- a/tests/ut/cpp/dataset/channel_swap_test.cc +++ b/tests/ut/cpp/dataset/channel_swap_test.cc @@ -36,7 +36,7 @@ TEST_F(MindDataTestChannelSwap, TestOp) { int size_buffer = s[0] * s[1] * s[2]; std::unique_ptr output_buffer(new uchar[size_buffer]); - std::shared_ptr output_tensor(new Tensor(s, DataType(DataType::DE_UINT8))); + std::shared_ptr output_tensor; // Decoding std::unique_ptr op(new HwcToChwOp()); diff --git a/tests/ut/cpp/dataset/common/bboxop_common.cc b/tests/ut/cpp/dataset/common/bboxop_common.cc index 62c9f853488..29324c928ce 100644 --- a/tests/ut/cpp/dataset/common/bboxop_common.cc 
+++ b/tests/ut/cpp/dataset/common/bboxop_common.cc @@ -163,8 +163,11 @@ void BBoxOpCommon::CompareActualAndExpected(const std::string &op_name) { // after comparison is done remove temporary file EXPECT_TRUE(remove(actual_path.c_str()) == 0); // compare using ==operator by Tensor + std::shared_ptr expect_img_t, actual_img_t; + CVTensor::CreateFromMat(expect_img, &expect_img_t); + CVTensor::CreateFromMat(actual_img, &actual_img_t); if (actual_img.data) { - EXPECT_EQ(CVTensor(expect_img) == CVTensor(actual_img), true); + EXPECT_EQ(*expect_img_t == *actual_img_t, true); } else { MS_LOG(ERROR) << "Not pass verification! Image data is null."; EXPECT_EQ(0, 1); @@ -223,7 +226,7 @@ bool BBoxOpCommon::LoadAnnotationFile(const std::string &path, std::shared_ptrNextSiblingElement("object"); // Read next BBox if exists } std::shared_ptr ret_value; - Status s = Tensor::CreateTensor(&ret_value, return_value_list, TensorShape({bbox_count, bbox_val_count})); + Status s = Tensor::CreateFromVector(return_value_list, TensorShape({bbox_count, bbox_val_count}), &ret_value); EXPECT_TRUE(s.IsOk()); (*target_BBox) = ret_value; // load bbox from file into return return true; diff --git a/tests/ut/cpp/dataset/common/cvop_common.cc b/tests/ut/cpp/dataset/common/cvop_common.cc index 48d69564fd3..9b5d7606720 100644 --- a/tests/ut/cpp/dataset/common/cvop_common.cc +++ b/tests/ut/cpp/dataset/common/cvop_common.cc @@ -52,9 +52,11 @@ std::string CVOpCommon::GetFilename() { void CVOpCommon::GetInputImage(std::string filename) { try { - Tensor::CreateTensor(&raw_input_tensor_, filename); + Tensor::CreateFromFile(filename, &raw_input_tensor_); raw_cv_image_ = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); - input_tensor_ = std::dynamic_pointer_cast(std::make_shared(raw_cv_image_)); + std::shared_ptr input_cv_tensor; + CVTensor::CreateFromMat(raw_cv_image_, &input_cv_tensor); + input_tensor_ = std::dynamic_pointer_cast(input_cv_tensor); SwapRedAndBlue(input_tensor_, &input_tensor_); if 
(raw_cv_image_.data) { MS_LOG(INFO) << "Reading was successful. Height:" << raw_cv_image_.rows << " Width: " << raw_cv_image_.cols diff --git a/tests/ut/cpp/dataset/concatenate_op_test.cc b/tests/ut/cpp/dataset/concatenate_op_test.cc index dc2fc69266c..bdad5bbd6de 100644 --- a/tests/ut/cpp/dataset/concatenate_op_test.cc +++ b/tests/ut/cpp/dataset/concatenate_op_test.cc @@ -29,14 +29,14 @@ class MindDataTestConcatenateOp : public UT::Common { TEST_F(MindDataTestConcatenateOp, TestOp) { MS_LOG(INFO) << "Doing MindDataTestConcatenate-TestOp."; - uint64_t labels[3] = {1, 1, 2}; + std::vector labels = {1, 1, 2}; TensorShape shape({3}); - std::shared_ptr input = - std::make_shared(shape, DataType(DataType::DE_UINT64), reinterpret_cast(labels)); + std::shared_ptr input; + Tensor::CreateFromVector(labels, &input); - uint64_t append_labels[3] = {4, 4, 4}; - std::shared_ptr append = - std::make_shared(shape, DataType(DataType::DE_UINT64), reinterpret_cast(append_labels)); + std::vector append_labels = {4, 4, 4}; + std::shared_ptr append; + Tensor::CreateFromVector(append_labels, &append); std::shared_ptr output; std::unique_ptr op(new ConcatenateOp(0, nullptr, append)); @@ -44,10 +44,11 @@ TEST_F(MindDataTestConcatenateOp, TestOp) { in.push_back(input); TensorRow out_row; Status s = op->Compute(in, &out_row); - uint64_t out[6] = {1, 1, 2, 4, 4, 4}; + std::vector out = {1, 1, 2, 4, 4, 4}; + + std::shared_ptr expected; + Tensor::CreateFromVector(out, &expected); - std::shared_ptr expected = - std::make_shared(TensorShape{6}, DataType(DataType::DE_UINT64), reinterpret_cast(out)); output = out_row[0]; EXPECT_TRUE(s.IsOk()); ASSERT_TRUE(output->shape() == expected->shape()); diff --git a/tests/ut/cpp/dataset/duplicate_op_test.cc b/tests/ut/cpp/dataset/duplicate_op_test.cc index 93779b084d6..afad66f6201 100644 --- a/tests/ut/cpp/dataset/duplicate_op_test.cc +++ b/tests/ut/cpp/dataset/duplicate_op_test.cc @@ -32,9 +32,9 @@ class MindDataTestDuplicateOp : public UT::Common { 
TEST_F(MindDataTestDuplicateOp, Basics) { std::shared_ptr t; - Tensor::CreateTensor(&t, std::vector({1, 2, 3, 4, 5, 6})); + Tensor::CreateFromVector(std::vector({1, 2, 3, 4, 5, 6}), &t); std::shared_ptr v; - Tensor::CreateTensor(&v, std::vector({3}), TensorShape::CreateScalar()); + Tensor::CreateFromVector(std::vector({3}), TensorShape::CreateScalar(), &v); std::shared_ptr op = std::make_shared(); TensorRow in; in.push_back(t); diff --git a/tests/ut/cpp/dataset/fill_op_test.cc b/tests/ut/cpp/dataset/fill_op_test.cc index 20e323cc8d8..795db705af7 100644 --- a/tests/ut/cpp/dataset/fill_op_test.cc +++ b/tests/ut/cpp/dataset/fill_op_test.cc @@ -29,23 +29,20 @@ class MindDataTestFillOp : public UT::Common { TEST_F(MindDataTestFillOp, TestOp) { MS_LOG(INFO) << "Doing MindDataTestFillOp-TestOp."; - uint64_t labels[3] = {1, 1, 2}; - TensorShape shape({3}); - std::shared_ptr input = - std::make_shared(shape, DataType(DataType::DE_UINT64), reinterpret_cast(labels)); + std::vector labels = {1, 1, 2}; + std::shared_ptr input; + Tensor::CreateFromVector(labels, &input); - TensorShape fill_shape({}); - std::shared_ptr fill_tensor = std::make_shared(fill_shape, DataType(DataType::DE_UINT64)); - fill_tensor->SetItemAt({}, 4); + std::shared_ptr fill_tensor; + Tensor::CreateScalar(4, &fill_tensor); std::shared_ptr output; std::unique_ptr op(new FillOp(fill_tensor)); Status s = op->Compute(input, &output); - uint64_t out[3] = {4, 4, 4}; - - std::shared_ptr expected = - std::make_shared(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast(out)); + std::vector out = {4, 4, 4}; + std::shared_ptr expected; + Tensor::CreateFromVector(out, &expected); EXPECT_TRUE(s.IsOk()); ASSERT_TRUE(output->shape() == expected->shape()); @@ -59,23 +56,20 @@ TEST_F(MindDataTestFillOp, TestOp) { TEST_F(MindDataTestFillOp, TestCasting) { MS_LOG(INFO) << "Doing MindDataTestFillOp-TestCasting."; - uint64_t labels[3] = {0, 1, 2}; - TensorShape shape({3}); - std::shared_ptr input = - 
std::make_shared(shape, DataType(DataType::DE_UINT64), reinterpret_cast(labels)); + std::vector labels = {0, 1, 2}; + std::shared_ptr input; + Tensor::CreateFromVector(labels, &input); - TensorShape fill_shape({}); - std::shared_ptr fill_tensor = std::make_shared(fill_shape, DataType(DataType::DE_FLOAT32)); - fill_tensor->SetItemAt({}, 2.0); + std::shared_ptr fill_tensor; + Tensor::CreateScalar(2.0, &fill_tensor); std::shared_ptr output; std::unique_ptr op(new FillOp(fill_tensor)); Status s = op->Compute(input, &output); - uint64_t out[3] = {2, 2, 2}; - - std::shared_ptr expected = - std::make_shared(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast(out)); + std::vector out = {2, 2, 2}; + std::shared_ptr expected; + Tensor::CreateFromVector(out, &expected); ASSERT_TRUE(output->shape() == expected->shape()); ASSERT_TRUE(output->type() == expected->type()); @@ -90,15 +84,15 @@ TEST_F(MindDataTestFillOp, TestCasting) { TEST_F(MindDataTestFillOp, ScalarFill) { MS_LOG(INFO) << "Doing MindDataTestFillOp-ScalarFill."; - uint64_t labels[3] = {0, 1, 2}; - TensorShape shape({3}); - std::shared_ptr input = - std::make_shared(shape, DataType(DataType::DE_UINT64), reinterpret_cast(labels)); + std::vector labels = {0, 1, 2}; + std::shared_ptr input; + Tensor::CreateFromVector(labels, &input); TensorShape fill_shape({2}); - uint64_t fill_labels[3] = {0, 1}; - std::shared_ptr fill_tensor = - std::make_shared(fill_shape, DataType(DataType::DE_UINT64), reinterpret_cast(fill_labels)); + std::vector fill_labels = {0, 1}; + std::shared_ptr fill_tensor; + Tensor::CreateFromVector(fill_labels, &fill_tensor); + std::shared_ptr output; std::unique_ptr op(new FillOp(fill_tensor)); Status s = op->Compute(input, &output); @@ -112,12 +106,11 @@ TEST_F(MindDataTestFillOp, ScalarFill) { TEST_F(MindDataTestFillOp, StringFill) { MS_LOG(INFO) << "Doing MindDataTestFillOp-StringFill."; std::vector strings = {"xyzzy", "plugh", "abracadabra"}; - TensorShape shape({3}); - std::shared_ptr 
input = std::make_shared(strings, shape); + std::shared_ptr input; + Tensor::CreateFromVector(strings, &input); - TensorShape fill_shape({}); - std::string fill_string = "hello"; - std::shared_ptr fill_tensor = std::make_shared(fill_string); + std::shared_ptr fill_tensor; + Tensor::CreateScalar("hello", &fill_tensor); std::shared_ptr output; @@ -125,8 +118,8 @@ TEST_F(MindDataTestFillOp, StringFill) { Status s = op->Compute(input, &output); std::vector expected_strings = {"hello", "hello", "hello"}; - TensorShape expected_shape({3}); - std::shared_ptr expected = std::make_shared(expected_strings, expected_shape); + std::shared_ptr expected; + Tensor::CreateFromVector(expected_strings, &expected); EXPECT_TRUE(s.IsOk()); ASSERT_TRUE(output->shape() == expected->shape()); @@ -142,12 +135,11 @@ TEST_F(MindDataTestFillOp, StringFill) { TEST_F(MindDataTestFillOp, NumericToString) { MS_LOG(INFO) << "Doing MindDataTestFillOp-NumericToString."; std::vector strings = {"xyzzy", "plugh", "abracadabra"}; - TensorShape shape({3}); - std::shared_ptr input = std::make_shared(strings, shape); + std::shared_ptr input; + Tensor::CreateFromVector(strings, &input); - TensorShape fill_shape({}); - std::shared_ptr fill_tensor = std::make_shared(fill_shape, DataType(DataType::DE_FLOAT32)); - fill_tensor->SetItemAt({}, 2.0); + std::shared_ptr fill_tensor; + Tensor::CreateScalar(2.0, &fill_tensor); std::shared_ptr output; @@ -162,14 +154,12 @@ TEST_F(MindDataTestFillOp, NumericToString) { TEST_F(MindDataTestFillOp, StringToNumeric) { MS_LOG(INFO) << "Doing MindDataTestFillOp-StringToNumeric."; - uint64_t labels[3] = {0, 1, 2}; - TensorShape shape({3}); - std::shared_ptr input = - std::make_shared(shape, DataType(DataType::DE_UINT64), reinterpret_cast(labels)); + std::vector labels = {0, 1, 2}; + std::shared_ptr input; + Tensor::CreateFromVector(labels, &input); - TensorShape fill_shape({}); - std::string fill_string = "hello"; - std::shared_ptr fill_tensor = std::make_shared(fill_string); + 
std::shared_ptr fill_tensor; + Tensor::CreateScalar("hello", &fill_tensor); std::shared_ptr output; diff --git a/tests/ut/cpp/dataset/image_folder_op_test.cc b/tests/ut/cpp/dataset/image_folder_op_test.cc index 3168efa1965..768d0d834e2 100644 --- a/tests/ut/cpp/dataset/image_folder_op_test.cc +++ b/tests/ut/cpp/dataset/image_folder_op_test.cc @@ -68,8 +68,7 @@ std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int6 Status Create1DTensor(std::shared_ptr *sample_ids, int64_t num_elements, unsigned char *data = nullptr, DataType::Type data_type = DataType::DE_UINT32) { TensorShape shape(std::vector(1, num_elements)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(sample_ids, TensorImpl::kFlexible, shape, DataType(data_type), data)); - (*sample_ids)->AllocateBuffer((*sample_ids)->SizeInBytes()); // allocate memory in case user forgets! + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(shape, DataType(data_type), data, sample_ids)); return Status::OK(); } diff --git a/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc b/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc index 85b3384d364..21f2483ffba 100644 --- a/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc +++ b/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc @@ -42,7 +42,8 @@ TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opFuntions) { TensorRow input, output; std::unique_ptr op(new JiebaTokenizerOp(hmm_path, mp_path)); - std::shared_ptr input_tensor = std::make_shared("今天天气太好了我们一起去外面玩吧"); + std::shared_ptr input_tensor; + Tensor::CreateScalar("今天天气太好了我们一起去外面玩吧", &input_tensor); input.push_back(input_tensor); Status s = op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); @@ -66,7 +67,8 @@ TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opAdd) { std::unique_ptr op(new JiebaTokenizerOp(hmm_path, mp_path)); op->AddWord("男默女泪"); - std::shared_ptr input_tensor = std::make_shared("男默女泪"); + std::shared_ptr input_tensor; + Tensor::CreateScalar("男默女泪", &input_tensor); input.push_back(input_tensor); Status s = op->Compute(input, 
&output); EXPECT_TRUE(s.IsOk()); @@ -84,7 +86,8 @@ TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opEmpty) { std::unique_ptr op(new JiebaTokenizerOp(hmm_path, mp_path)); op->AddWord("男默女泪"); - std::shared_ptr input_tensor = std::make_shared(""); + std::shared_ptr input_tensor; + Tensor::CreateScalar("", &input_tensor); input.push_back(input_tensor); Status s = op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); diff --git a/tests/ut/cpp/dataset/manifest_op_test.cc b/tests/ut/cpp/dataset/manifest_op_test.cc index a6eef4aaa24..63ad4f44c2b 100644 --- a/tests/ut/cpp/dataset/manifest_op_test.cc +++ b/tests/ut/cpp/dataset/manifest_op_test.cc @@ -71,9 +71,9 @@ TEST_F(MindDataTestManifest, TestSequentialManifestWithRepeat) { di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); uint64_t i = 0; - uint32_t label = 0; + int32_t label = 0; while (tensor_map.size() != 0) { - tensor_map["label"]->GetItemAt(&label, {}); + tensor_map["label"]->GetItemAt(&label, {}); EXPECT_TRUE(res[i] == label); MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; i++; @@ -101,9 +101,9 @@ TEST_F(MindDataTestManifest, TestSubsetRandomSamplerManifest) { rc = di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); uint64_t i = 0; - uint32_t label = 0; + int32_t label = 0; while (tensor_map.size() != 0) { - tensor_map["label"]->GetItemAt(&label, {}); + tensor_map["label"]->GetItemAt(&label, {}); i++; di.GetNextAsMap(&tensor_map); EXPECT_EQ(label, 1); @@ -131,9 +131,9 @@ TEST_F(MindDataTestManifest, MindDataTestManifestClassIndex) { di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); uint64_t i = 0; - uint32_t label = 0; + int32_t label = 0; while (tensor_map.size() != 0) { - tensor_map["label"]->GetItemAt(&label, {}); + tensor_map["label"]->GetItemAt(&label, {}); EXPECT_TRUE(label == res[i]); MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; i++; @@ -160,9 +160,9 @@ TEST_F(MindDataTestManifest, 
MindDataTestManifestNumSamples) { di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); uint64_t i = 0; - uint32_t label = 0; + int32_t label = 0; while (tensor_map.size() != 0) { - tensor_map["label"]->GetItemAt(&label, {}); + tensor_map["label"]->GetItemAt(&label, {}); EXPECT_TRUE(0 == label); MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; i++; @@ -176,7 +176,7 @@ TEST_F(MindDataTestManifest, MindDataTestManifestEval) { std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; int64_t num_samples = 1; int64_t start_index = 0; - auto seq_sampler = std::make_shared(num_samples, start_index); + auto seq_sampler = std::make_shared(num_samples, start_index); auto tree = Build({Manifest(16, 2, 32, file, "eval", std::move(seq_sampler), {})}); tree->Prepare(); Status rc = tree->Launch(); @@ -189,9 +189,9 @@ TEST_F(MindDataTestManifest, MindDataTestManifestEval) { di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); uint64_t i = 0; - uint32_t label = 0; + int32_t label = 0; while (tensor_map.size() != 0) { - tensor_map["label"]->GetItemAt(&label, {}); + tensor_map["label"]->GetItemAt(&label, {}); EXPECT_TRUE(0 == label); MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; i++; diff --git a/tests/ut/cpp/dataset/mask_test.cc b/tests/ut/cpp/dataset/mask_test.cc index 609d5bf4477..c6279acdb9a 100644 --- a/tests/ut/cpp/dataset/mask_test.cc +++ b/tests/ut/cpp/dataset/mask_test.cc @@ -38,9 +38,9 @@ class MindDataTestMaskOp : public UT::Common { TEST_F(MindDataTestMaskOp, Basics) { std::shared_ptr t; - Tensor::CreateTensor(&t, std::vector({1, 2, 3, 4, 5, 6})); + Tensor::CreateFromVector(std::vector({1, 2, 3, 4, 5, 6}), &t); std::shared_ptr v; - Tensor::CreateTensor(&v, std::vector({3}), TensorShape::CreateScalar()); + Tensor::CreateFromVector(std::vector({3}), TensorShape::CreateScalar(), &v); std::shared_ptr op = std::make_shared(RelationalOp::kEqual, v, 
DataType(DataType::DE_UINT16)); std::shared_ptr out; ASSERT_TRUE(op->Compute(t, &out).IsOk()); diff --git a/tests/ut/cpp/dataset/one_hot_op_test.cc b/tests/ut/cpp/dataset/one_hot_op_test.cc index 2617ae4536f..9dd5139dac1 100644 --- a/tests/ut/cpp/dataset/one_hot_op_test.cc +++ b/tests/ut/cpp/dataset/one_hot_op_test.cc @@ -29,19 +29,17 @@ class MindDataTestOneHotOp : public UT::Common { TEST_F(MindDataTestOneHotOp, TestOp) { MS_LOG(INFO) << "Doing MindDataTestOneHotOp."; - uint64_t labels[3] = {0, 1, 2}; - TensorShape shape({3}); - std::shared_ptr input = std::make_shared(shape, DataType(DataType::DE_UINT64), - reinterpret_cast (labels)); + std::vector labels = {0, 1, 2}; + std::shared_ptr input; + Tensor::CreateFromVector(labels, &input); std::shared_ptr output; std::unique_ptr op(new OneHotOp(5)); Status s = op->Compute(input, &output); - uint64_t out[15] = {1, 0, 0, 0, 0, - 0, 1, 0, 0, 0, - 0, 0, 1, 0, 0}; - std::shared_ptr expected = std::make_shared(TensorShape{3, 5}, DataType(DataType::DE_UINT64), - reinterpret_cast (out)); + std::vector out = {1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0}; + std::shared_ptr expected; + Tensor::CreateFromVector(out, TensorShape{3, 5}, &expected); + EXPECT_TRUE(s.IsOk()); ASSERT_TRUE(output->shape() == expected->shape()); ASSERT_TRUE(output->type() == expected->type()); diff --git a/tests/ut/cpp/dataset/pad_end_op_test.cc b/tests/ut/cpp/dataset/pad_end_op_test.cc index 1c838da8e86..b4bd993f387 100644 --- a/tests/ut/cpp/dataset/pad_end_op_test.cc +++ b/tests/ut/cpp/dataset/pad_end_op_test.cc @@ -35,44 +35,40 @@ TEST_F(MindDataTestPadEndOp, TestOp) { TensorShape pad_data_shape({1}); // prepare input tensor - float_t orig1[4] = {1, 1, 1, 1}; + std::vector orig1 = {1, 1, 1, 1}; TensorShape input_shape1({2, 2}); std::vector input_shape1_vector = {input_shape1}; - std::shared_ptr input1 = - std::make_shared(input_shape1, DataType(DataType::DE_FLOAT32), reinterpret_cast(orig1)); + std::shared_ptr input1; + 
Tensor::CreateFromVector(orig1, input_shape1, &input1); // pad_shape TensorShape pad_shape1[3] = {TensorShape({3, 3}), TensorShape({2, 4}), TensorShape({4, 2})}; // value to pad - float_t pad_data1[3][1] = {0, 3.5, 3.5}; + std::vector> pad_data1 = {{0}, {3.5}, {3.5}}; std::shared_ptr expected1[3]; // expected tensor output for testunit 1 - float_t out1[9] = {1, 1, 0, 1, 1, 0, 0, 0, 0}; - - expected1[0] = - std::make_shared(pad_shape1[0], DataType(DataType::DE_FLOAT32), reinterpret_cast(out1)); + std::vector out1 = {1, 1, 0, 1, 1, 0, 0, 0, 0}; + Tensor::CreateFromVector(out1, pad_shape1[0], &(expected1[0])); // expected tensor output for testunit 2 - float_t out2[8] = {1, 1, 3.5, 3.5, 1, 1, 3.5, 3.5}; - - expected1[1] = - std::make_shared(pad_shape1[1], DataType(DataType::DE_FLOAT32), reinterpret_cast(out2)); + std::vector out2 = {1, 1, 3.5, 3.5, 1, 1, 3.5, 3.5}; + Tensor::CreateFromVector(out2, pad_shape1[1], &(expected1[1])); // expected tensor output for testunit 3 - float_t out3[8] = {1, 1, 1, 1, 3.5, 3.5, 3.5, 3.5}; - - expected1[2] = - std::make_shared(pad_shape1[2], DataType(DataType::DE_FLOAT32), reinterpret_cast(out3)); + std::vector out3 = {1, 1, 1, 1, 3.5, 3.5, 3.5, 3.5}; + Tensor::CreateFromVector(out3, pad_shape1[2], &(expected1[2])); // run the PadEndOp for (auto i = 0; i < 3; i++) { std::shared_ptr output; std::vector output_shape = {TensorShape({})}; - std::shared_ptr pad_value1 = std::make_shared(pad_data_shape, DataType(DataType::DE_FLOAT32), - reinterpret_cast(pad_data1[i])); + + std::shared_ptr pad_value1; + Tensor::CreateFromVector(pad_data1[i], pad_data_shape, &pad_value1); + std::unique_ptr op(new PadEndOp(pad_shape1[i], pad_value1)); Status s = op->Compute(input1, &output); @@ -96,7 +92,7 @@ TEST_F(MindDataTestPadEndOp, TestOp) { TensorShape input_shape2({2}); std::vector input_shape2_vector = {input_shape2}; std::shared_ptr input2; - Tensor::CreateTensor(&input2, orig2, input_shape2); + Tensor::CreateFromVector(orig2, input_shape2, &input2); 
// pad_shape TensorShape pad_shape2[3] = {TensorShape({5}), TensorShape({2}), TensorShape({10})}; @@ -112,7 +108,7 @@ TEST_F(MindDataTestPadEndOp, TestOp) { for (auto i = 0; i < 3; i++) { // pad value - Tensor::CreateTensor(&pad_value2[i], pad_data2[i], pad_data_shape); + Tensor::CreateFromVector(pad_data2[i], pad_data_shape, &pad_value2[i]); std::shared_ptr output; std::vector output_shape = {TensorShape({})}; @@ -121,7 +117,7 @@ TEST_F(MindDataTestPadEndOp, TestOp) { Status s = op->Compute(input2, &output); - Tensor::CreateTensor(&expected2[i], outstring[i], pad_shape2[i]); + Tensor::CreateFromVector(outstring[i], pad_shape2[i], &expected2[i]); EXPECT_TRUE(s.IsOk()); ASSERT_TRUE(output->shape() == expected2[i]->shape()); diff --git a/tests/ut/cpp/dataset/sentence_piece_vocab_op_test.cc b/tests/ut/cpp/dataset/sentence_piece_vocab_op_test.cc index 4261950ec00..19f7291079d 100644 --- a/tests/ut/cpp/dataset/sentence_piece_vocab_op_test.cc +++ b/tests/ut/cpp/dataset/sentence_piece_vocab_op_test.cc @@ -93,7 +93,6 @@ TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceFromDatasetFuntions) { rc = di.FetchNextTensorRow(&tensor_list); } ASSERT_TRUE(rc.IsOk()); - } TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceFromFileFuntions) { @@ -166,9 +165,10 @@ TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceTokenizerFuntions) { rc = di.FetchNextTensorRow(&tensor_list); } std::shared_ptr output_tensor; - std::unique_ptr op(new SentencePieceTokenizerOp(spm, - SPieceTokenizerLoadType::kModel, SPieceTokenizerOutType::kString)); - std::shared_ptr input_tensor = std::make_shared("I saw a girl with a telescope."); + std::unique_ptr op( + new SentencePieceTokenizerOp(spm, SPieceTokenizerLoadType::kModel, SPieceTokenizerOutType::kString)); + std::shared_ptr input_tensor; + Tensor::CreateScalar("I saw a girl with a telescope.", &input_tensor); Status s = op->Compute(input_tensor, &output_tensor); std::vector expect; diff --git 
a/tests/ut/cpp/dataset/sliding_window_op_test.cc b/tests/ut/cpp/dataset/sliding_window_op_test.cc index 7020229d9af..b39a131460f 100644 --- a/tests/ut/cpp/dataset/sliding_window_op_test.cc +++ b/tests/ut/cpp/dataset/sliding_window_op_test.cc @@ -31,15 +31,17 @@ TEST_F(MindDataTestSlidingWindowOp, Compute) { MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->Compute."; std::vector strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"}; TensorShape shape({static_cast(strings.size())}); - std::shared_ptr input = std::make_shared(strings, shape); + std::shared_ptr input; + Tensor::CreateFromVector(strings, shape, &input); std::shared_ptr output; std::unique_ptr op(new SlidingWindowOp(3, 0)); Status s = op->Compute(input, &output); - std::vector out = {"one", "two", "three", "two", "three", "four", "three", "four", "five", - "four", "five", "six", "five", "six", "seven", "six", "seven", "eight"}; - std::shared_ptr expected = std::make_shared(out, TensorShape({6, 3})); + std::vector out = {"one", "two", "three", "two", "three", "four", "three", "four", "five", + "four", "five", "six", "five", "six", "seven", "six", "seven", "eight"}; + std::shared_ptr expected; + Tensor::CreateFromVector(out, TensorShape({6, 3}), &expected); ASSERT_TRUE(output->shape() == expected->shape()); ASSERT_TRUE(output->type() == expected->type()); @@ -54,7 +56,8 @@ TEST_F(MindDataTestSlidingWindowOp, OutputShape) { MS_LOG(INFO) << "Doing MindDataTestSlidingWindowOp->OutputShape."; std::vector strings = {"one", "two", "three", "four", "five", "six", "seven", "eight"}; TensorShape shape({static_cast(strings.size())}); - std::shared_ptr input = std::make_shared(strings, shape); + std::shared_ptr input; + Tensor::CreateFromVector(strings, shape, &input); std::vector input_shape = {input->shape()}; std::vector output_shape = {TensorShape({})}; diff --git a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc index 
96e9652bbc5..79464b732b9 100644 --- a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc +++ b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc @@ -30,8 +30,7 @@ using namespace mindspore::dataset; Status CreateINT64Tensor(std::shared_ptr *sample_ids, int64_t num_elements, unsigned char *data = nullptr) { TensorShape shape(std::vector(1, num_elements)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(sample_ids, TensorImpl::kFlexible, shape, DataType(DataType::DE_INT64), data)); - (*sample_ids)->AllocateBuffer((*sample_ids)->SizeInBytes()); // allocate memory in case user forgets! + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(shape, DataType(DataType::DE_INT64), data, sample_ids)); return Status::OK(); } @@ -54,8 +53,7 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) { {0, 17, 4, 10, 14, 8, 15}, {13, 9, 16, 3, 2, 19, 12}, {1, 11, 6, 18, 7, 5, 0}}; for (int i = 0; i < 6; i++) { std::shared_ptr t; - Tensor::CreateTensor(&t, TensorImpl::kFlexible, TensorShape({7}), - DataType(DataType::DE_INT64), (unsigned char *)(res[i])); + Tensor::CreateFromMemory(TensorShape({7}), DataType(DataType::DE_INT64), (unsigned char *)(res[i]), &t); row.push_back(t); } MockStorageOp mock(20); diff --git a/tests/ut/cpp/dataset/tensor_string_test.cc b/tests/ut/cpp/dataset/tensor_string_test.cc index fe336a34c52..232fefc2ae9 100644 --- a/tests/ut/cpp/dataset/tensor_string_test.cc +++ b/tests/ut/cpp/dataset/tensor_string_test.cc @@ -35,13 +35,15 @@ class MindDataTestStringTensorDE : public UT::Common { }; TEST_F(MindDataTestStringTensorDE, Basics) { - std::shared_ptr t = std::make_shared("Hi"); + std::shared_ptr t; + Tensor::CreateScalar("Hi", &t); ASSERT_TRUE(t->shape() == TensorShape({})); std::string_view s = ""; t->GetItemAt(&s, {}); ASSERT_TRUE(s == "Hi"); - std::shared_ptr t2 = std::make_shared(std::vector{"Hi", "Bye"}); + std::shared_ptr t2; + Tensor::CreateFromVector(std::vector{"Hi", "Bye"}, &t2); ASSERT_TRUE(t2->shape() == TensorShape({2})); t2->GetItemAt(&s, {0}); 
ASSERT_TRUE(s == "Hi"); @@ -49,7 +51,9 @@ TEST_F(MindDataTestStringTensorDE, Basics) { ASSERT_TRUE(s == "Bye"); std::vector strings{"abc", "defg", "hi", "klmno", "123", "789"}; - std::shared_ptr t3 = std::make_shared(strings, TensorShape({2, 3})); + std::shared_ptr t3; + Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t3); + ASSERT_TRUE(t3->shape() == TensorShape({2, 3})); uint32_t index = 0; for (uint32_t i = 0; i < 2; i++) { @@ -62,8 +66,10 @@ TEST_F(MindDataTestStringTensorDE, Basics) { } TEST_F(MindDataTestStringTensorDE, Basics2) { - std::shared_ptr t = - std::make_shared(std::vector{"abc", "defg", "hi", "klmno", "123", "789"}, TensorShape({2, 3})); + std::shared_ptr t; + Tensor::CreateFromVector(std::vector{"abc", "defg", "hi", "klmno", "123", "789"}, TensorShape({2, 3}), + &t); + ASSERT_TRUE(t->SizeInBytes() == 6 * 5 + 20 + 4); std::vector offsets = {0, 4, 9, 12, 18, 22, 26}; uint32_t ctr = 0; @@ -86,7 +92,8 @@ TEST_F(MindDataTestStringTensorDE, Basics2) { TEST_F(MindDataTestStringTensorDE, Empty) { std::vector strings{"abc", "defg", "", "", "123", ""}; - std::shared_ptr t = std::make_shared(strings, TensorShape({2, 3})); + std::shared_ptr t; + Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t); // abc_defg___123__ // 0123456789012345 ASSERT_TRUE(t->SizeInBytes() == 6 * 5 + 10 + 4); @@ -112,7 +119,9 @@ TEST_F(MindDataTestStringTensorDE, Empty) { TEST_F(MindDataTestStringTensorDE, SetItem) { std::vector strings{"abc", "defg", "hi", "klmno", "123", "789"}; - std::shared_ptr t3 = std::make_shared(strings, TensorShape({2, 3})); + std::shared_ptr t3; + Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t3); + ASSERT_TRUE(t3->shape() == TensorShape({2, 3})); t3->SetItemAt({0, 1}, std::string{"xyzz"}); @@ -136,7 +145,8 @@ TEST_F(MindDataTestStringTensorDE, SetItem) { TEST_F(MindDataTestStringTensorDE, Iterator) { std::vector strings{"abc", "defg", "hi", "klmno", "123", "789"}; - std::shared_ptr t = std::make_shared(strings, TensorShape({2, 
3})); + std::shared_ptr t; + Tensor::CreateFromVector(strings, TensorShape({2, 3}), &t); uint32_t index = 0; auto itr = t->begin(); for (; itr != t->end(); itr++) { diff --git a/tests/ut/cpp/dataset/tensor_test.cc b/tests/ut/cpp/dataset/tensor_test.cc index fce4652b47a..47279874252 100644 --- a/tests/ut/cpp/dataset/tensor_test.cc +++ b/tests/ut/cpp/dataset/tensor_test.cc @@ -35,8 +35,9 @@ class MindDataTestTensorDE : public UT::Common { }; TEST_F(MindDataTestTensorDE, Basics) { - std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); - ASSERT_TRUE((t->AllocateBuffer(t->SizeInBytes())).IsOk()); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t); + ASSERT_EQ(t->shape(), TensorShape({2, 3})); ASSERT_EQ(t->type(), DataType::DE_UINT64); ASSERT_EQ(t->SizeInBytes(), 2 * 3 * 8); @@ -67,28 +68,30 @@ TEST_F(MindDataTestTensorDE, Basics) { ASSERT_EQ(t->ToString(), "Tensor (shape: <2,3>, Type: uint64)\n[[1,2,3],[4,5,6]]"); std::vector x = {1, 2, 3, 4, 5, 6}; std::shared_ptr t2; - Tensor::CreateTensor(&t2, x, TensorShape({2, 3})); + Tensor::CreateFromVector(x, TensorShape({2, 3}), &t2); ASSERT_EQ(*t == *t2, true); ASSERT_EQ(*t != *t2, false); } TEST_F(MindDataTestTensorDE, Fill) { - std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_FLOAT32)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_FLOAT32), &t); t->Fill(2.5); std::vector x = {2.5, 2.5, 2.5, 2.5}; std::shared_ptr t2; - Tensor::CreateTensor(&t2, x, TensorShape({2, 2})); + Tensor::CreateFromVector(x, TensorShape({2, 2}), &t2); ASSERT_EQ(*t == *t2, true); } TEST_F(MindDataTestTensorDE, Reshape) { - std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t); t->Fill(254); t->Reshape(TensorShape({4})); std::vector x = {254, 254, 254, 254}; 
std::shared_ptr t2; - Tensor::CreateTensor(&t2, x); + Tensor::CreateFromVector(x, &t2); ASSERT_EQ(*t == *t2, true); Status rc = t->Reshape(TensorShape({5})); @@ -102,7 +105,8 @@ TEST_F(MindDataTestTensorDE, Reshape) { } TEST_F(MindDataTestTensorDE, CopyTensor) { - std::shared_ptr t = std::make_shared(TensorShape({}), DataType(DataType::DE_INT16)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({}), DataType(DataType::DE_INT16), &t); t->SetItemAt({}, -66); ASSERT_EQ(t->shape(), TensorShape({})); ASSERT_EQ(t->type(), DataType::DE_INT16); @@ -125,30 +129,31 @@ TEST_F(MindDataTestTensorDE, CopyTensor) { } TEST_F(MindDataTestTensorDE, InsertTensor) { - std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_FLOAT64)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_FLOAT64), &t); std::vector x = {1.1, 2.1, 3.1}; std::shared_ptr t2; - Tensor::CreateTensor(&t2, x); + Tensor::CreateFromVector(x, &t2); std::vector y = {1.2, 2.2, 3.2}; std::shared_ptr t3; - Tensor::CreateTensor(&t3, y); + Tensor::CreateFromVector(y, &t3); ASSERT_TRUE(t->InsertTensor({0}, t2).OK()); ASSERT_TRUE(t->InsertTensor({1}, t3).OK()); std::vector z = {1.1, 2.1, 3.1, 1.2, 2.2, 3.2}; std::shared_ptr t4; - Tensor::CreateTensor(&t4, z, TensorShape({2, 3})); + Tensor::CreateFromVector(z, TensorShape({2, 3}), &t4); ASSERT_EQ(*t == *t4, true); std::shared_ptr t5; - Tensor::CreateTensor(&t5, 0); + Tensor::CreateScalar(0, &t5); ASSERT_TRUE(t->InsertTensor({1, 2}, t5).OK()); z[5] = 0; std::shared_ptr t6; - Tensor::CreateTensor(&t6, z, TensorShape({2, 3})); + Tensor::CreateFromVector(z, TensorShape({2, 3}), &t6); ASSERT_EQ(*t == *t6, true); ASSERT_EQ(t->InsertTensor({2}, t5).get_code(), StatusCode::kUnexpectedError); @@ -161,7 +166,8 @@ TEST_F(MindDataTestTensorDE, InsertTensor) { // Test the bug of Tensor::ToString will exec failed for Tensor which store bool values TEST_F(MindDataTestTensorDE, BoolTensor) { - std::shared_ptr t = 
std::make_shared(TensorShape({2}), DataType(DataType::DE_BOOL)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2}), DataType(DataType::DE_BOOL), &t); t->SetItemAt({0}, true); t->SetItemAt({1}, true); std::string out = t->ToString(); @@ -169,7 +175,8 @@ TEST_F(MindDataTestTensorDE, BoolTensor) { } TEST_F(MindDataTestTensorDE, GetItemAt) { - std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t); t->Fill(254); uint64_t o1; t->GetItemAt(&o1, {0, 0}); @@ -183,7 +190,8 @@ TEST_F(MindDataTestTensorDE, GetItemAt) { uint8_t o4; t->GetItemAt(&o4, {1, 1}); ASSERT_EQ(o4, 254); - std::shared_ptr t2 = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_INT8)); + std::shared_ptr t2; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_INT8), &t2); t2->Fill(-10); int64_t o5; t2->GetItemAt(&o5, {0, 0}); @@ -197,7 +205,8 @@ TEST_F(MindDataTestTensorDE, GetItemAt) { int8_t o8; t2->GetItemAt(&o8, {1, 1}); ASSERT_EQ(o8, -10); - std::shared_ptr t3 = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_FLOAT32)); + std::shared_ptr t3; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_FLOAT32), &t3); t3->Fill(1.1); double o9; t3->GetItemAt(&o9, {0, 0}); @@ -208,9 +217,11 @@ TEST_F(MindDataTestTensorDE, GetItemAt) { } TEST_F(MindDataTestTensorDE, OperatorAssign) { - std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t); t->Fill(1); - std::shared_ptr t2 = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); + std::shared_ptr t2; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t2); *t2 = std::move(*t); uint8_t o; t2->GetItemAt(&o, {0, 0}); @@ -224,18 +235,20 @@ TEST_F(MindDataTestTensorDE, OperatorAssign) { } TEST_F(MindDataTestTensorDE, 
Strides) { - std::shared_ptr t = std::make_shared(TensorShape({4, 2, 2}), DataType(DataType::DE_UINT8)); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({4, 2, 2}), DataType(DataType::DE_UINT8), &t); std::vector x1 = t->Strides(); std::vector x2 = {4, 2, 1}; ASSERT_EQ(x1, x2); - t = std::make_shared(TensorShape({4, 2, 2}), DataType(DataType::DE_UINT32)); + Tensor::CreateEmpty(TensorShape({4, 2, 2}), DataType(DataType::DE_UINT32), &t); x1 = t->Strides(); x2 = {16, 8, 4}; ASSERT_EQ(x1, x2); } void checkCvMat(TensorShape shape, DataType type) { - std::shared_ptr t = std::make_shared(shape, type); + std::shared_ptr t; + CVTensor::CreateEmpty(shape, type, &t); cv::Mat m = t->mat(); ASSERT_EQ(m.data, t->GetBuffer()); ASSERT_EQ(static_cast(m.type()) & static_cast(CV_MAT_DEPTH_MASK), type.AsCVType()); @@ -289,8 +302,10 @@ TEST_F(MindDataTestTensorDE, CVTensorFromMat) { m.at(0, 1) = 20; m.at(1, 0) = 30; m.at(1, 1) = 40; - std::shared_ptr cvt = std::make_shared(m); - std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); + std::shared_ptr cvt; + CVTensor::CreateFromMat(m, &cvt); + std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t); t->SetItemAt({0, 0}, 10); t->SetItemAt({0, 1}, 20); t->SetItemAt({1, 0}, 30); @@ -302,8 +317,10 @@ TEST_F(MindDataTestTensorDE, CVTensorFromMat) { m2.at(1) = 20; m2.at(2) = 30; m2.at(3) = 40; - std::shared_ptr cvt2 = std::make_shared(m2); - std::shared_ptr t2 = std::make_shared(TensorShape({4}), DataType(DataType::DE_UINT8)); + std::shared_ptr cvt2; + CVTensor::CreateFromMat(m2, &cvt2); + std::shared_ptr t2; + Tensor::CreateEmpty(TensorShape({4}), DataType(DataType::DE_UINT8), &t2); t2->SetItemAt({0}, 10); t2->SetItemAt({1}, 20); t2->SetItemAt({2}, 30); @@ -313,10 +330,12 @@ TEST_F(MindDataTestTensorDE, CVTensorFromMat) { } TEST_F(MindDataTestTensorDE, CVTensorAs) { - std::shared_ptr t = std::make_shared(TensorShape({3, 2}), DataType(DataType::DE_FLOAT64)); + 
std::shared_ptr t; + Tensor::CreateEmpty(TensorShape({3, 2}), DataType(DataType::DE_FLOAT64), &t); t->Fill(2.2); const unsigned char *addr = t->GetBuffer(); - std::shared_ptr t2 = std::make_shared(TensorShape({3, 2}), DataType(DataType::DE_FLOAT64)); + std::shared_ptr t2; + Tensor::CreateEmpty(TensorShape({3, 2}), DataType(DataType::DE_FLOAT64), &t2); t2->Fill(4.4); std::shared_ptr ctv = CVTensor::AsCVTensor(t); ASSERT_EQ(t->GetBuffer(), nullptr); @@ -326,6 +345,10 @@ TEST_F(MindDataTestTensorDE, CVTensorAs) { ASSERT_EQ(ctv->GetBuffer(), addr); ASSERT_TRUE(*t2 == *ctv); MS_LOG(DEBUG) << *t2 << std::endl << *ctv; + cv::Mat m2 = ctv->matCopy(); + m2 = 2 * m2; + ASSERT_EQ(ctv->GetBuffer(), addr); + ASSERT_TRUE(*t2 == *ctv); } TEST_F(MindDataTestTensorDE, CVTensorMatSlice) { @@ -336,23 +359,26 @@ TEST_F(MindDataTestTensorDE, CVTensorMatSlice) { m.at(1, 0) = 40; m.at(1, 1) = 50; m.at(1, 2) = 60; - std::shared_ptr cvt = std::make_shared(m); + std::shared_ptr cvt; + CVTensor::CreateFromMat(m, &cvt); cv::Mat mat; - cvt->Mat({1}, &mat); + cvt->MatAtIndex({1}, &mat); cv::Mat m2(3, 1, CV_32S); m2.at(0) = 40; m2.at(1) = 50; m2.at(2) = 60; - std::shared_ptr cvt2 = std::make_shared(mat); - std::shared_ptr cvt3 = std::make_shared(m2); + std::shared_ptr cvt2; + CVTensor::CreateFromMat(mat, &cvt2); + std::shared_ptr cvt3; + CVTensor::CreateFromMat(m2, &cvt3); ASSERT_TRUE(*cvt2 == *cvt3); - cvt->Mat({0}, &mat); + cvt->MatAtIndex({0}, &mat); m2.at(0) = 10; m2.at(1) = 20; m2.at(2) = 30; - cvt2 = std::make_shared(mat); - cvt3 = std::make_shared(m2); + CVTensor::CreateFromMat(mat, &cvt2); + CVTensor::CreateFromMat(m2, &cvt3); ASSERT_TRUE(*cvt2 == *cvt3); } @@ -361,7 +387,7 @@ TEST_F(MindDataTestTensorDE, TensorIterator) { std::vector values2 = {2, 3, 4, 5, 6, 7}; std::shared_ptr t; - Tensor::CreateTensor(&t, values); + Tensor::CreateFromVector(values, &t); auto i = t->begin(); auto j = values.begin(); @@ -395,11 +421,11 @@ TEST_F(MindDataTestTensorDE, TensorIterator) { 
TEST_F(MindDataTestTensorDE, TensorSlice) { std::shared_ptr t; - Tensor::CreateTensor(&t, std::vector{0, 1, 2, 3, 4}); + Tensor::CreateFromVector(std::vector{0, 1, 2, 3, 4}, &t); std::shared_ptr t2; auto x = std::vector{0, 3, 4}; std::shared_ptr expected; - Tensor::CreateTensor(&expected, x); + Tensor::CreateFromVector(x, &expected); t->Slice(&t2, x); ASSERT_EQ(*t2, *expected); t->Slice(&t2, std::vector{0, 1, 2, 3, 4}); @@ -412,13 +438,13 @@ TEST_F(MindDataTestTensorDE, TensorConcatenate) { std::vector expected = {1, 2, 3, 4, 5, 6}; std::shared_ptr t1; - Tensor::CreateTensor(&t1, values1); + Tensor::CreateFromVector(values1, &t1); std::shared_ptr t2; - Tensor::CreateTensor(&t2, values2); + Tensor::CreateFromVector(values2, &t2); std::shared_ptr out; - Tensor::CreateTensor(&out, expected); + Tensor::CreateFromVector(expected, &out); Status s = t1->Concatenate({3}, t2); EXPECT_TRUE(s.IsOk()); @@ -434,15 +460,80 @@ TEST_F(MindDataTestTensorDE, TensorConcatenate) { } TEST_F(MindDataTestTensorDE, TensorEmpty) { - std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); - ASSERT_TRUE(t->HasData()); -} + TensorPtr t; + Status rc = Tensor::CreateEmpty(TensorShape({0}), DataType(DataType::DE_UINT64), &t); + ASSERT_TRUE(rc.IsOk()); -TEST_F(MindDataTestTensorDE, TensorEmptyInvalidate) { - std::vector values1 = {1, 2, 3, 0, 0, 0}; - std::shared_ptr t; - Tensor::CreateTensor(&t, values1); - t->Invalidate(); - ASSERT_TRUE(t->HasData()); -} + ASSERT_EQ(t->shape(), TensorShape({0})); + ASSERT_EQ(t->type(), DataType::DE_UINT64); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + rc = t->SetItemAt({0}, 7); + ASSERT_TRUE(rc.IsError()); + + rc = Tensor::CreateEmpty(TensorShape({1, 0}), DataType(DataType::DE_STRING), &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({1, 0})); + ASSERT_EQ(t->type(), DataType::DE_STRING); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), 
nullptr); + ASSERT_TRUE(!t->HasData()); + + std::vector data; + rc = Tensor::CreateFromVector(data, &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({0})); + ASSERT_EQ(t->type(), DataType::DE_UINT16); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + + std::vector data2; + rc = Tensor::CreateFromVector(data2, &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({0})); + ASSERT_EQ(t->type(), DataType::DE_STRING); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + + rc = Tensor::CreateFromVector(data, TensorShape({0, 2}), &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({0, 2})); + ASSERT_EQ(t->type(), DataType::DE_UINT16); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + + rc = Tensor::CreateFromVector(data2, TensorShape({0, 0, 6}), &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({0, 0, 6})); + ASSERT_EQ(t->type(), DataType::DE_STRING); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + + rc = Tensor::CreateFromMemory(TensorShape({0}), DataType(DataType::DE_INT8), nullptr, &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({0})); + ASSERT_EQ(t->type(), DataType::DE_INT8); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + ASSERT_TRUE(!t->HasData()); + + rc = Tensor::CreateFromMemory(TensorShape({0}), DataType(DataType::DE_STRING), nullptr, &t); + ASSERT_TRUE(rc.IsOk()); + ASSERT_EQ(t->shape(), TensorShape({0})); + ASSERT_EQ(t->type(), DataType::DE_STRING); + ASSERT_EQ(t->SizeInBytes(), 0); + ASSERT_EQ(t->GetBuffer(), nullptr); + + std::vector values = {1, 2, 3, 0, 0, 0}; + std::shared_ptr t2; + Tensor::CreateFromVector(values, &t2); + ASSERT_TRUE(t2->HasData()); + t2->Invalidate(); + ASSERT_TRUE(!t2->HasData()); +} diff --git 
a/tests/ut/cpp/dataset/tokenizer_op_test.cc b/tests/ut/cpp/dataset/tokenizer_op_test.cc index cc2d7473ff8..df3a5435de9 100644 --- a/tests/ut/cpp/dataset/tokenizer_op_test.cc +++ b/tests/ut/cpp/dataset/tokenizer_op_test.cc @@ -46,8 +46,8 @@ class MindDataTestTokenizerOp : public UT::Common { TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) { MS_LOG(INFO) << "Doing TestUnicodeCharTokenizerOp."; std::unique_ptr op(new UnicodeCharTokenizerOp(true)); - std::shared_ptr input = std::make_shared("Hello World!"); - TensorRow output; +std::shared_ptr input; + Tensor::CreateScalar("Hello World!", &input); TensorRow output; Status s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 12); @@ -66,7 +66,7 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) { CheckEqual(output[0], {10}, "d"); CheckEqual(output[0], {11}, "!"); - input = std::make_shared("中国 你好!"); + Tensor::CreateScalar("中国 你好!", &input); output.clear(); s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); @@ -80,38 +80,38 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) { CheckEqual(output[0], {4}, "好"); CheckEqual(output[0], {5}, "!"); - input = std::make_shared("中"); - output.clear(); + Tensor::CreateScalar("中", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor3: " << output[0]->ToString(); CheckEqual(output[0], {0}, "中"); - input = std::make_shared("H"); - output.clear(); + Tensor::CreateScalar("H", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor4: " << output[0]->ToString(); CheckEqual(output[0], {0}, "H"); - input = std::make_shared(" "); - output.clear(); + 
Tensor::CreateScalar(" ", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 2); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor5: " << output[0]->ToString(); CheckEqual(output[0], {0}, " "); CheckEqual(output[0], {1}, " "); - input = std::make_shared(""); - output.clear(); + Tensor::CreateScalar("", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor6: " << output[0]->ToString(); @@ -121,10 +121,10 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) { TEST_F(MindDataTestTokenizerOp, TestWhitespaceTokenizerOp) { MS_LOG(INFO) << "Doing TestWhitespaceTokenizerOp."; std::unique_ptr op(new WhitespaceTokenizerOp(true)); - std::shared_ptr input = std::make_shared("Welcome to China."); - TensorRow output; +std::shared_ptr input; + Tensor::CreateScalar("Welcome to China.", &input); TensorRow output; Status s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 3); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor1: " << output[0]->ToString(); @@ -132,37 +132,37 @@ TEST_F(MindDataTestTokenizerOp, TestWhitespaceTokenizerOp) { CheckEqual(output[0], {1}, "to"); CheckEqual(output[0], {2}, "China."); - input = std::make_shared(" hello"); - output.clear(); + Tensor::CreateScalar(" hello", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor2: " << output[0]->ToString(); CheckEqual(output[0], {0}, "hello"); - input = std::make_shared("hello"); - output.clear(); + Tensor::CreateScalar("hello", &input); +output.clear(); s = op->Compute(TensorRow(0, 
{input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor3: " << output[0]->ToString(); CheckEqual(output[0], {0}, "hello"); - input = std::make_shared("hello "); - output.clear(); + Tensor::CreateScalar("hello ", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor4: " << output[0]->ToString(); CheckEqual(output[0], {0}, "hello"); - input = std::make_shared(" "); - output.clear(); + Tensor::CreateScalar(" ", &input); +output.clear(); s = op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor5: " << output[0]->ToString(); @@ -174,8 +174,9 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { std::unique_ptr keep_whitespace_op(new UnicodeScriptTokenizerOp(true, true)); std::unique_ptr skip_whitespace_op(new UnicodeScriptTokenizerOp(false, true)); - std::shared_ptr input = std::make_shared("Welcome to China. \n 中国\t北京"); - TensorRow output; + std::shared_ptr input; + Tensor::CreateScalar("Welcome to China. \n 中国\t北京", &input); + TensorRow output; Status s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 10); @@ -204,10 +205,9 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { CheckEqual(output[0], {4}, "中国"); CheckEqual(output[0], {5}, "北京"); - input = std::make_shared(" Welcome to 中国. "); - output.clear(); - s = skip_whitespace_op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + Tensor::CreateScalar(" Welcome to 中国. 
", &input); + output.clear(); + s = skip_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 4); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor3: " << output[0]->ToString(); @@ -230,25 +230,23 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { CheckEqual(output[0], {6}, "."); CheckEqual(output[0], {7}, " "); - input = std::make_shared("Hello"); - output.clear(); - s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + Tensor::CreateScalar("Hello", &input); +output.clear(); + s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor5: " << output[0]->ToString(); CheckEqual(output[0], {0}, "Hello"); - input = std::make_shared("H"); - output.clear(); - s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + Tensor::CreateScalar("H", &input); +output.clear(); + s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor6: " << output[0]->ToString(); CheckEqual(output[0], {0}, "H"); - input = std::make_shared(""); + Tensor::CreateScalar("", &input); output.clear(); s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); @@ -257,10 +255,9 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { MS_LOG(INFO) << "Out tensor7: " << output[0]->ToString(); CheckEqual(output[0], {0}, ""); - input = std::make_shared("Hello中国Hello世界"); - output.clear(); - s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output[0]->Size(), 4); + Tensor::CreateScalar("Hello中国Hello世界", &input); + output.clear(); + s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); 
EXPECT_EQ(output[0]->Size(), 4); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor8: " << output[0]->ToString(); CheckEqual(output[0], {0}, "Hello"); @@ -268,15 +265,15 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { CheckEqual(output[0], {2}, "Hello"); CheckEqual(output[0], {3}, "世界"); - input = std::make_shared(" "); - output.clear(); + Tensor::CreateScalar(" ", &input); + output.clear(); s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); - EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(s.IsOk()); EXPECT_EQ(output[0]->Size(), 1); EXPECT_EQ(output[0]->Rank(), 1); MS_LOG(INFO) << "Out tensor10: " << output[0]->ToString(); CheckEqual(output[0], {0}, " "); - input = std::make_shared(" "); + Tensor::CreateScalar(" ", &input); output.clear(); s = skip_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); @@ -289,7 +286,9 @@ TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { TEST_F(MindDataTestTokenizerOp, TestCaseFold) { MS_LOG(INFO) << "Doing TestCaseFold."; std::unique_ptr case_fold_op(new CaseFoldOp()); - std::shared_ptr input = std::make_shared("Welcome to China. \n 中国\t北京"); + std::shared_ptr input; + Tensor::CreateScalar("Welcome to China. 
\n 中国\t北京", &input); + std::shared_ptr output; Status s = case_fold_op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); @@ -305,7 +304,8 @@ TEST_F(MindDataTestTokenizerOp, TestNormalize) { std::unique_ptr nfkc_normalize_op(new NormalizeUTF8Op(NormalizeForm::kNfkc)); std::unique_ptr nfd_normalize_op(new NormalizeUTF8Op(NormalizeForm::kNfd)); std::unique_ptr nfkd_normalize_op(new NormalizeUTF8Op(NormalizeForm::kNfkd)); - std::shared_ptr input = std::make_shared("ṩ"); + std::shared_ptr input; + Tensor::CreateScalar("ṩ", &input); std::shared_ptr output; Status s = nfc_normalize_op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); @@ -327,7 +327,8 @@ TEST_F(MindDataTestTokenizerOp, TestNormalize) { TEST_F(MindDataTestTokenizerOp, TestRegexReplace) { MS_LOG(INFO) << "Doing TestRegexReplace."; std::unique_ptr regex_replace_op(new RegexReplaceOp("\\s+", "_", true)); - std::shared_ptr input = std::make_shared("Welcome to China. \n 中国\t北京"); + std::shared_ptr input; + Tensor::CreateScalar("Welcome to China. \n 中国\t北京", &input); std::shared_ptr output; Status s = regex_replace_op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); @@ -340,19 +341,20 @@ TEST_F(MindDataTestTokenizerOp, TestRegexReplace) { TEST_F(MindDataTestTokenizerOp, TestRegexTokenizer) { MS_LOG(INFO) << "Doing TestRegexTokenizerOp."; std::unique_ptr regex_tokenizer_op(new RegexTokenizerOp("\\p{Cc}|\\p{Cf}|\\s+", "", true)); - std::shared_ptr input = std::make_shared("Welcome to China. \n 中国\t北京"); - TensorRow output; +std::shared_ptr input; + Tensor::CreateScalar("Welcome to China. 
\n 中国\t北京", &input); + TensorRow output; Status s = regex_tokenizer_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); } TEST_F(MindDataTestTokenizerOp, TestBasicTokenizer) { MS_LOG(INFO) << "Doing TestBasicTokenizer."; - //bool lower_case, bool keep_whitespace, + // bool lower_case, bool keep_whitespace, // NormalizeForm normalization_form, bool preserve_unused_token - std::unique_ptr basic_tokenizer(new BasicTokenizerOp(true, true, NormalizeForm::kNone, false, - true)); - std::shared_ptr input = std::make_shared("Welcome to China. 中国\t北京"); + std::unique_ptr basic_tokenizer(new BasicTokenizerOp(true, true, NormalizeForm::kNone, false,true)); +std::shared_ptr input; + Tensor::CreateScalar("Welcome to China. 中国\t北京", &input); TensorRow output; Status s = basic_tokenizer->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); diff --git a/tests/ut/cpp/dataset/trucate_pair_test.cc b/tests/ut/cpp/dataset/trucate_pair_test.cc index af7e61c16aa..48d30cf2f5a 100644 --- a/tests/ut/cpp/dataset/trucate_pair_test.cc +++ b/tests/ut/cpp/dataset/trucate_pair_test.cc @@ -35,17 +35,17 @@ class MindDataTestTruncatePairOp : public UT::Common { TEST_F(MindDataTestTruncatePairOp, Basics) { std::shared_ptr t1; - Tensor::CreateTensor(&t1, std::vector({1, 2, 3})); + Tensor::CreateFromVector(std::vector({1, 2, 3}), &t1); std::shared_ptr t2; - Tensor::CreateTensor(&t2, std::vector({4, 5})); + Tensor::CreateFromVector(std::vector({4, 5}), &t2); TensorRow in({t1, t2}); std::shared_ptr op = std::make_shared(4); TensorRow out; ASSERT_TRUE(op->Compute(in, &out).IsOk()); std::shared_ptr out1; - Tensor::CreateTensor(&out1, std::vector({1, 2})); + Tensor::CreateFromVector(std::vector({1, 2}), &out1); std::shared_ptr out2; - Tensor::CreateTensor(&out2, std::vector({4, 5})); + Tensor::CreateFromVector(std::vector({4, 5}), &out2); ASSERT_EQ(*out1, *out[0]); ASSERT_EQ(*out2, *out[1]); } diff --git a/tests/ut/cpp/dataset/type_cast_op_test.cc 
b/tests/ut/cpp/dataset/type_cast_op_test.cc index a94a7fedbab..371b589c570 100644 --- a/tests/ut/cpp/dataset/type_cast_op_test.cc +++ b/tests/ut/cpp/dataset/type_cast_op_test.cc @@ -43,16 +43,15 @@ class MindDataTestTypeCast : public UT::Common { template void testCast(std::vector values, const DataType &from, const DataType &to) { - std::shared_ptr t = std::make_shared(TensorShape({static_cast(values.size())}), - DataType(from), - reinterpret_cast(&values[0])); + std::shared_ptr t; + Tensor::CreateFromVector(values, &t); std::unique_ptr op(new TypeCastOp(to)); EXPECT_TRUE(op->OneToOne()); std::shared_ptr output; EXPECT_TRUE(op->Compute(t, &output)); ASSERT_TRUE(t->shape() == output->shape()); - ASSERT_TRUE(DataType(to)==output->type()); + ASSERT_TRUE(DataType(to) == output->type()); MS_LOG(DEBUG) << *output << std::endl; auto out = output->begin(); auto v = values.begin(); diff --git a/tests/ut/python/dataset/test_pair_truncate.py b/tests/ut/python/dataset/test_pair_truncate.py index 6b1138e5a9c..8cc40ee1264 100644 --- a/tests/ut/python/dataset/test_pair_truncate.py +++ b/tests/ut/python/dataset/test_pair_truncate.py @@ -16,7 +16,6 @@ Testing Mask op in DE """ import numpy as np -import pytest import mindspore.dataset as ds import mindspore.dataset.text as text @@ -55,9 +54,7 @@ def test_basics_str(): def test_exceptions(): - with pytest.raises(RuntimeError) as info: - compare(in1=[1, 2, 3, 4], in2=[5, 6, 7, 8], length=1, out1=[1, 2], out2=[5]) - assert "Indices are empty, generated tensor would be empty" in str(info.value) + compare(in1=[1, 2, 3, 4], in2=[5, 6, 7, 8], length=1, out1=[1], out2=[]) if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_slice_op.py b/tests/ut/python/dataset/test_slice_op.py index 6e81133a2a2..72417bff710 100644 --- a/tests/ut/python/dataset/test_slice_op.py +++ b/tests/ut/python/dataset/test_slice_op.py @@ -121,21 +121,10 @@ def test_slice_exceptions(): slice_compare([1, 2, 3, 4, 5], 5) assert "Index 5 is out of bounds 
[0,5)" in str(info.value) - with pytest.raises(RuntimeError) as info: - slice_compare([1, 2, 3, 4, 5], slice(0)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) - - with pytest.raises(RuntimeError) as info: - slice_compare([1, 2, 3, 4, 5], slice(3, 1, 1)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) - - with pytest.raises(RuntimeError) as info: - slice_compare([1, 2, 3, 4, 5], slice(5, 10, 1)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) - - with pytest.raises(RuntimeError) as info: - slice_compare([1, 2, 3, 4, 5], slice(-1, -5, 1)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) + slice_compare([1, 2, 3, 4, 5], slice(0)) + slice_compare([1, 2, 3, 4, 5], slice(3, 1, 1)) + slice_compare([1, 2, 3, 4, 5], slice(5, 10, 1)) + slice_compare([1, 2, 3, 4, 5], slice(-1, -5, 1)) def test_slice_all_str(): @@ -198,21 +187,10 @@ def test_slice_exceptions_str(): slice_compare([b"1", b"2", b"3", b"4", b"5"], 5) assert "Index 5 is out of bounds [0,5)" in str(info.value) - with pytest.raises(RuntimeError) as info: - slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) - - with pytest.raises(RuntimeError) as info: - slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(3, 1, 1)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) - - with pytest.raises(RuntimeError) as info: - slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(5, 10, 1)) - assert "Indices are empty, generated tensor would be empty." in str(info.value) - - with pytest.raises(RuntimeError) as info: - slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, 1)) - assert "Indices are empty, generated tensor would be empty." 
in str(info.value) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(3, 1, 1)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(5, 10, 1)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, 1)) if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_tensor_empty.py b/tests/ut/python/dataset/test_tensor_empty.py new file mode 100644 index 00000000000..f6810555443 --- /dev/null +++ b/tests/ut/python/dataset/test_tensor_empty.py @@ -0,0 +1,72 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import numpy as np + +import mindspore.dataset as ds + + +def test_tensor_empty(): + def gen(): + for _ in range(4): + (yield np.array([], dtype=np.int64), np.array([], dtype='S').reshape([0, 4]), np.array([1], + dtype=np.float64)) + + data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"]) + + for d in data: + np.testing.assert_array_equal(np.array([], dtype=np.int64), d[0]) + np.testing.assert_array_equal(np.array([], dtype='S').reshape([0, 4]), d[1]) + np.testing.assert_array_equal(np.array([1], dtype=np.float64), d[2]) + + +def test_tensor_empty_map(): + def gen(): + for _ in range(4): + (yield np.array([], dtype=np.int64), np.array([], dtype='S'), np.array([1], dtype=np.float64)) + + data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"]) + + def func(x, y, z): + x = np.array([1], dtype=np.int64) + y = np.array(["Hi"], dtype='S') + z = np.array([], dtype=np.float64) + return x, y, z + + data = data.map(input_columns=["col1", "col2", "col3"], operations=func) + + for d in data: + np.testing.assert_array_equal(np.array([1], dtype=np.int64), d[0]) + np.testing.assert_array_equal(np.array(["Hi"], dtype='S'), d[1]) + np.testing.assert_array_equal(np.array([], dtype=np.float64), d[2]) + + +def test_tensor_empty_batch(): + def gen(): + for _ in range(4): + (yield np.array([], dtype=np.int64), np.array([], dtype='S').reshape([0, 4]), np.array([1], + dtype=np.float64)) + + data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"]).batch(2) + + for d in data: + np.testing.assert_array_equal(np.array([], dtype=np.int64).reshape([2, 0]), d[0]) + np.testing.assert_array_equal(np.array([], dtype='S').reshape([2, 0, 4]), d[1]) + np.testing.assert_array_equal(np.array([[1], [1]], dtype=np.float64), d[2]) + + +if __name__ == '__main__': + test_tensor_empty() + test_tensor_empty_map() + test_tensor_empty_batch()