forked from mindspore-Ecosystem/mindspore
153 lines
4.9 KiB
C++
153 lines
4.9 KiB
C++
|
/**
|
||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
#include <memory>
|
||
|
#include <string>
|
||
|
#include "dataset/core/client.h"
|
||
|
#include "common/common.h"
|
||
|
#include "gtest/gtest.h"
|
||
|
#include "securec.h"
|
||
|
#include "dataset/core/tensor.h"
|
||
|
#include "dataset/core/cv_tensor.h"
|
||
|
#include "dataset/core/data_type.h"
|
||
|
#include "dataset/util/de_error.h"
|
||
|
|
||
|
using namespace mindspore::dataset;
|
||
|
|
||
|
namespace py = pybind11;
|
||
|
|
||
|
class MindDataTestStringTensorDE : public UT::Common {
|
||
|
public:
|
||
|
MindDataTestStringTensorDE() = default;
|
||
|
|
||
|
void SetUp() override { GlobalInit(); }
|
||
|
};
|
||
|
|
||
|
// Verifies construction and element access for scalar, 1-D and 2-D string tensors.
TEST_F(MindDataTestStringTensorDE, Basics) {
  // Scalar (rank-0) string tensor.
  std::shared_ptr<Tensor> t = std::make_shared<Tensor>("Hi");
  ASSERT_TRUE(t->shape() == TensorShape({}));
  std::string_view s = "";
  // Always check the returned Status: on failure `s` keeps its previous value, so a
  // later comparison against that same value could otherwise pass by accident.
  ASSERT_TRUE(t->GetItemAt(&s, {}).IsOk());
  ASSERT_EQ(s, "Hi");

  // 1-D tensor built from a vector of strings.
  std::shared_ptr<Tensor> t2 = std::make_shared<Tensor>(std::vector<std::string>{"Hi", "Bye"});
  ASSERT_TRUE(t2->shape() == TensorShape({2}));
  s = "";  // reset so a stale "Hi" from the scalar check cannot satisfy the next assert
  ASSERT_TRUE(t2->GetItemAt(&s, {0}).IsOk());
  ASSERT_EQ(s, "Hi");
  s = "";
  ASSERT_TRUE(t2->GetItemAt(&s, {1}).IsOk());
  ASSERT_EQ(s, "Bye");

  // 2-D tensor: element (i, j) is expected to be strings[i * 3 + j].
  std::vector<std::string> strings{"abc", "defg", "hi", "klmno", "123", "789"};
  std::shared_ptr<Tensor> t3 = std::make_shared<Tensor>(strings, TensorShape({2, 3}));
  ASSERT_TRUE(t3->shape() == TensorShape({2, 3}));
  uint32_t index = 0;
  for (uint32_t i = 0; i < 2; i++) {
    for (uint32_t j = 0; j < 3; j++) {
      std::string_view item = "";
      ASSERT_TRUE(t3->GetItemAt(&item, {i, j}).IsOk());
      ASSERT_EQ(item, strings[index]);
      ++index;
    }
  }
}
|
||
|
|
||
|
// Verifies the raw memory layout of a 2-D string tensor: a table of six uint32_t offsets
// followed by the NUL-terminated string payload, and that GetItemAt returns views into it.
TEST_F(MindDataTestStringTensorDE, Basics2) {
  std::vector<std::string> strings{"abc", "defg", "hi", "klmno", "123", "789"};
  std::shared_ptr<Tensor> t = std::make_shared<Tensor>(strings, TensorShape({2, 3}));

  // 6 offsets * 4 bytes (24) + 20 characters + 6 NUL terminators = 50 bytes.
  ASSERT_TRUE(t->SizeInBytes() == 6 * 5 + 20);

  // Expected offset table: for each element, the index of its terminating NUL
  // measured from the start of the string payload.
  const std::vector<uint32_t> offsets = {3, 8, 11, 17, 21, 25};
  uint32_t ctr = 0;
  for (uint32_t expected : offsets) {
    ASSERT_EQ(*(reinterpret_cast<uint32_t *>(t->GetMutableBuffer() + ctr)), expected);
    ctr += 4;  // one uint32_t entry per element
  }

  // The string payload begins right after the 6 * 4 byte offset table.
  const char *buf = reinterpret_cast<char *>(t->GetMutableBuffer()) + 6 * 4;
  const std::vector<uint32_t> starts = {0, 4, 9, 12, 18, 22};

  uint32_t index = 0;
  for (uint32_t i = 0; i < 2; i++) {
    for (uint32_t j = 0; j < 3; j++) {
      std::string_view s = "";
      ASSERT_TRUE(t->GetItemAt(&s, {i, j}).IsOk());
      // The returned view must point into the tensor's buffer (no copy) ...
      ASSERT_TRUE(s.data() == buf + starts[index]);
      // ... and must hold the expected contents.
      ASSERT_EQ(s, strings[index]);
      ++index;
    }
  }
}
|
||
|
|
||
|
// Verifies layout and element access when some elements are empty strings: each empty
// element still occupies one byte (its NUL terminator) in the payload.
TEST_F(MindDataTestStringTensorDE, Empty) {
  std::vector<std::string> strings{"abc", "defg", "", "", "123", ""};
  std::shared_ptr<Tensor> t = std::make_shared<Tensor>(strings, TensorShape({2, 3}));
  // Payload layout ('_' marks a NUL terminator):
  // abc_defg___123__
  // 0123456789012345
  // Total: 6 offsets * 4 bytes (24) + 16 payload bytes = 40.
  ASSERT_TRUE(t->SizeInBytes() == 6 * 5 + 10);

  // Expected offset table: index of each element's NUL terminator in the payload.
  const std::vector<uint32_t> offsets = {3, 8, 9, 10, 14, 15};
  uint32_t ctr = 0;
  for (uint32_t expected : offsets) {
    ASSERT_EQ(*(reinterpret_cast<uint32_t *>(t->GetMutableBuffer() + ctr)), expected);
    ctr += 4;  // one uint32_t entry per element
  }

  // The string payload begins right after the 6 * 4 byte offset table.
  const char *buf = reinterpret_cast<char *>(t->GetMutableBuffer()) + 6 * 4;
  const std::vector<uint32_t> starts = {0, 4, 9, 10, 11, 15};

  uint32_t index = 0;
  for (uint32_t i = 0; i < 2; i++) {
    for (uint32_t j = 0; j < 3; j++) {
      std::string_view s = "";
      ASSERT_TRUE(t->GetItemAt(&s, {i, j}).IsOk());
      ASSERT_TRUE(s.data() == buf + starts[index]);
      ASSERT_EQ(s, strings[index]);
      ++index;
    }
  }
}
|
||
|
|
||
|
// Verifies that SetItemAt replaces string elements (with shorter, longer and equal-length
// values) and that every element reads back correctly afterwards.
TEST_F(MindDataTestStringTensorDE, SetItem) {
  std::vector<std::string> strings{"abc", "defg", "hi", "klmno", "123", "789"};
  std::shared_ptr<Tensor> t3 = std::make_shared<Tensor>(strings, TensorShape({2, 3}));
  ASSERT_TRUE(t3->shape() == TensorShape({2, 3}));

  // Check each write's Status so a failed SetItemAt is reported at the call site
  // rather than as a confusing content mismatch further down.
  ASSERT_TRUE(t3->SetItemAt({0, 1}, std::string{"xyzz"}).IsOk());
  strings[1] = "xyzz";

  ASSERT_TRUE(t3->SetItemAt({0, 2}, std::string{"07"}).IsOk());
  strings[2] = "07";

  ASSERT_TRUE(t3->SetItemAt({1, 2}, std::string{"987"}).IsOk());
  strings[5] = "987";

  // Element (i, j) is expected to match strings[i * 3 + j].
  uint32_t index = 0;
  for (uint32_t i = 0; i < 2; i++) {
    for (uint32_t j = 0; j < 3; j++) {
      std::string_view s = "";
      ASSERT_TRUE(t3->GetItemAt(&s, {i, j}).IsOk());
      ASSERT_EQ(s, strings[index]);
      ++index;
    }
  }
}
|
||
|
|
||
|
// Verifies string_view iteration over a 2-D string tensor, both element by element and
// with a stride of 2 via operator+=.
TEST_F(MindDataTestStringTensorDE, Iterator) {
  std::vector<std::string> strings{"abc", "defg", "hi", "klmno", "123", "789"};
  std::shared_ptr<Tensor> t = std::make_shared<Tensor>(strings, TensorShape({2, 3}));

  // Full traversal: must visit exactly strings.size() elements, in order.
  uint32_t index = 0;
  auto itr = t->begin<std::string_view>();
  for (; itr != t->end<std::string_view>(); itr++) {
    // Guard the lookup below: an over-long range would otherwise index out of bounds.
    ASSERT_LT(index, strings.size());
    ASSERT_EQ(*itr, strings[index]);
    ++index;
  }
  // Without this, an empty or truncated range would pass vacuously.
  ASSERT_EQ(index, strings.size());

  // Strided traversal. Stepping by 2 lands exactly on end() only because the element
  // count (6) is even; an odd count would step past end().
  index = 0;
  itr = t->begin<std::string_view>();
  for (; itr != t->end<std::string_view>(); itr += 2) {
    ASSERT_LT(index, strings.size());
    ASSERT_EQ(*itr, strings[index]);
    index += 2;
  }
  ASSERT_EQ(index, strings.size());
}
|