2020-05-09 03:36:06 +08:00
|
|
|
/**
|
|
|
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
|
|
*
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
*
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
#include <memory>
|
|
|
|
#include <string>
|
2020-07-14 12:56:33 +08:00
|
|
|
#include "minddata/dataset/core/client.h"
|
2020-05-09 03:36:06 +08:00
|
|
|
#include "common/common.h"
|
|
|
|
#include "gtest/gtest.h"
|
|
|
|
#include "securec.h"
|
2020-07-14 12:56:33 +08:00
|
|
|
#include "minddata/dataset/core/tensor.h"
|
|
|
|
#include "minddata/dataset/core/cv_tensor.h"
|
|
|
|
#include "minddata/dataset/core/data_type.h"
|
2020-05-09 03:36:06 +08:00
|
|
|
|
|
|
|
using namespace mindspore::dataset;
|
|
|
|
|
|
|
|
namespace py = pybind11;
|
|
|
|
|
|
|
|
class MindDataTestStringTensorDE : public UT::Common {
|
|
|
|
public:
|
|
|
|
MindDataTestStringTensorDE() = default;
|
|
|
|
|
|
|
|
void SetUp() override { GlobalInit(); }
|
|
|
|
};
|
|
|
|
|
|
|
|
// Builds string tensors from a single string, a flat vector and a vector
// reshaped to 2x3, then checks the reported shape and that GetItemAt yields
// every element that was put in.
TEST_F(MindDataTestStringTensorDE, Basics) {
  // Single string: expected to have an empty shape and be read with an empty index.
  auto scalar = std::make_shared<Tensor>("Hi");
  ASSERT_TRUE(scalar->shape() == TensorShape({}));
  std::string_view item = "";
  scalar->GetItemAt(&item, {});
  ASSERT_TRUE(item == "Hi");

  // Flat vector of two strings: shape {2}.
  auto pair = std::make_shared<Tensor>(std::vector<std::string>{"Hi", "Bye"});
  ASSERT_TRUE(pair->shape() == TensorShape({2}));
  pair->GetItemAt(&item, {0});
  ASSERT_TRUE(item == "Hi");
  pair->GetItemAt(&item, {1});
  ASSERT_TRUE(item == "Bye");

  // Six strings laid out as 2 rows x 3 columns.
  std::vector<std::string> words{"abc", "defg", "hi", "klmno", "123", "789"};
  auto matrix = std::make_shared<Tensor>(words, TensorShape({2, 3}));
  ASSERT_TRUE(matrix->shape() == TensorShape({2, 3}));

  const uint32_t kRows = 2;
  const uint32_t kCols = 3;
  for (uint32_t row = 0; row < kRows; ++row) {
    for (uint32_t col = 0; col < kCols; ++col) {
      std::string_view cell = "";
      matrix->GetItemAt(&cell, {row, col});
      // Element (row, col) is compared with the source string at the
      // row-major position row * kCols + col.
      ASSERT_TRUE(cell == words[row * kCols + col]);
    }
  }
}
|
|
|
|
|
|
|
|
// Checks the raw buffer of a 2x3 string tensor: its total size, the table of
// 32-bit offsets at the front of the buffer, and that every view returned by
// GetItemAt points at the expected position inside that buffer.
TEST_F(MindDataTestStringTensorDE, Basics2) {
  auto tensor = std::make_shared<Tensor>(std::vector<std::string>{"abc", "defg", "hi", "klmno", "123", "789"},
                                         TensorShape({2, 3}));

  // 20 characters in total; the remaining 6 * 5 + 4 bytes are what this test
  // expects for the 7 four-byte offsets plus one extra byte after each string.
  ASSERT_TRUE(tensor->SizeInBytes() == 6 * 5 + 20 + 4);

  // Number of bytes the test expects in front of the first string (7 offsets * 4 bytes).
  constexpr uint32_t kOffsetTableBytes = 28;

  // Offsets of each string (plus the final end position) relative to the first string.
  const std::vector<uint32_t> expected_offsets = {0, 4, 9, 12, 18, 22, 26};
  for (uint32_t k = 0; k < expected_offsets.size(); ++k) {
    // Read the k-th 4-byte entry from the front of the buffer.
    const auto *stored = reinterpret_cast<const uint32_t *>(tensor->GetBuffer() + k * 4);
    ASSERT_TRUE(*stored == expected_offsets[k] + kOffsetTableBytes);
  }

  // Expected address of the first string's characters.
  const char *text_base = reinterpret_cast<const char *>(tensor->GetBuffer()) + kOffsetTableBytes;
  const std::vector<uint32_t> expected_starts = {0, 4, 9, 12, 18, 22};

  const uint32_t kRows = 2;
  const uint32_t kCols = 3;
  for (uint32_t row = 0; row < kRows; ++row) {
    for (uint32_t col = 0; col < kCols; ++col) {
      std::string_view cell = "";
      tensor->GetItemAt(&cell, {row, col});
      // The view must point at the expected position in the tensor's own buffer.
      ASSERT_TRUE(cell.data() == text_base + expected_starts[row * kCols + col]);
    }
  }
}
|
|
|
|
|
|
|
|
// Same buffer-layout checks as for a fully populated tensor, but with several
// empty strings mixed in; additionally verifies the content of every element.
TEST_F(MindDataTestStringTensorDE, Empty) {
  std::vector<std::string> words{"abc", "defg", "", "", "123", ""};
  auto tensor = std::make_shared<Tensor>(words, TensorShape({2, 3}));

  // Expected character area ('_' marks the extra byte after each string):
  //   abc_defg___123__
  //   0123456789012345
  // 10 characters in total; the remaining 6 * 5 + 4 bytes are what this test
  // expects for the 7 four-byte offsets plus one extra byte after each string.
  ASSERT_TRUE(tensor->SizeInBytes() == 6 * 5 + 10 + 4);

  // Number of bytes the test expects in front of the first string (7 offsets * 4 bytes).
  constexpr uint32_t kOffsetTableBytes = 28;

  // Offsets of each string (plus the final end position) relative to the first string.
  const std::vector<uint32_t> expected_offsets = {0, 4, 9, 10, 11, 15, 16};
  for (uint32_t k = 0; k < expected_offsets.size(); ++k) {
    // Read the k-th 4-byte entry from the front of the buffer.
    const auto *stored = reinterpret_cast<const uint32_t *>(tensor->GetBuffer() + k * 4);
    ASSERT_TRUE(*stored == expected_offsets[k] + kOffsetTableBytes);
  }

  // Expected address of the first string's characters.
  const char *text_base = reinterpret_cast<const char *>(tensor->GetBuffer()) + kOffsetTableBytes;
  const std::vector<uint32_t> expected_starts = {0, 4, 9, 10, 11, 15};

  const uint32_t kRows = 2;
  const uint32_t kCols = 3;
  for (uint32_t row = 0; row < kRows; ++row) {
    for (uint32_t col = 0; col < kCols; ++col) {
      const uint32_t flat = row * kCols + col;
      std::string_view cell = "";
      tensor->GetItemAt(&cell, {row, col});
      // Check the location first, then the content (empty strings included).
      ASSERT_TRUE(cell.data() == text_base + expected_starts[flat]);
      ASSERT_TRUE(cell == words[flat]);
    }
  }
}
|
|
|
|
|
|
|
|
// Overwrites three elements of a 2x3 string tensor through SetItemAt and then
// checks that reading every element back matches the updated expectations.
TEST_F(MindDataTestStringTensorDE, SetItem) {
  std::vector<std::string> expected{"abc", "defg", "hi", "klmno", "123", "789"};
  auto tensor = std::make_shared<Tensor>(expected, TensorShape({2, 3}));
  ASSERT_TRUE(tensor->shape() == TensorShape({2, 3}));

  // Each replacement has the same length as the string it overwrites; the
  // expectation vector is updated at the matching row-major position.
  tensor->SetItemAt({0, 1}, std::string{"xyzz"});
  expected[1] = "xyzz";

  tensor->SetItemAt({0, 2}, std::string{"07"});
  expected[2] = "07";

  tensor->SetItemAt({1, 2}, std::string{"987"});
  expected[5] = "987";

  const uint32_t kRows = 2;
  const uint32_t kCols = 3;
  for (uint32_t row = 0; row < kRows; ++row) {
    for (uint32_t col = 0; col < kCols; ++col) {
      std::string_view cell = "";
      tensor->GetItemAt(&cell, {row, col});
      ASSERT_TRUE(cell == expected[row * kCols + col]);
    }
  }
}
|
|
|
|
|
|
|
|
// Walks a 2x3 string tensor with its std::string_view iterator, first one
// element at a time and then in steps of two, comparing each visited element
// with the source strings.
//
// Besides the per-element comparison, each pass also checks how far it got:
// without that, a pass whose loop body never runs (begin == end) or that stops
// early would succeed without having verified anything.
TEST_F(MindDataTestStringTensorDE, Iterator) {
  std::vector<std::string> strings{"abc", "defg", "hi", "klmno", "123", "789"};
  std::shared_ptr<Tensor> t = std::make_shared<Tensor>(strings, TensorShape({2, 3}));

  // Pass 1: step of one, every element is expected in order.
  uint32_t index = 0;
  auto itr = t->begin<std::string_view>();
  for (; itr != t->end<std::string_view>(); itr++) {
    // Guard the expectation lookup in case the iterator yields too many elements.
    ASSERT_LT(index, strings.size());
    ASSERT_TRUE(*itr == strings[index++]);
  }
  // All six elements must have been visited.
  ASSERT_EQ(index, strings.size());

  // Pass 2: step of two. The element count is even, so the iterator is
  // expected to land exactly on end() after visiting elements 0, 2 and 4.
  index = 0;
  itr = t->begin<std::string_view>();
  for (; itr != t->end<std::string_view>(); itr += 2) {
    ASSERT_LT(index, strings.size());
    ASSERT_TRUE(*itr == strings[index]);
    index += 2;
  }
  ASSERT_EQ(index, strings.size());
}
|