!31268 Move UT CPP Helper Functions

Merge pull request !31268 from zetongzhao/ut_cpp_helper
This commit is contained in:
i-robot 2022-03-16 18:20:24 +00:00 committed by Gitee
commit 8e1f6d6b9b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 61 additions and 66 deletions

View File

@ -16,7 +16,6 @@
#include "common/common.h"
#include "include/api/types.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/include/dataset/vision.h"
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
@ -142,68 +141,15 @@ class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
TensorRow VecToRow(const MSTensorVec &v) {
TensorRow row;
for (const mindspore::MSTensor &t : v) {
std::shared_ptr<Tensor> rt;
(void)Tensor::CreateFromMemory(TensorShape(t.Shape()), MSTypeToDEType(static_cast<mindspore::TypeId>(t.DataType())),
(const uchar *)(t.Data().get()), t.DataSize(), &rt);
row.emplace_back(rt);
}
return row;
}
MSTensorVec RowToVec(const TensorRow &v) {
MSTensorVec rv; // std::make_shared<DETensor>(de_tensor)
std::transform(v.begin(), v.end(), std::back_inserter(rv), [](std::shared_ptr<Tensor> t) -> mindspore::MSTensor {
return mindspore::MSTensor(std::make_shared<DETensor>(t));
});
return rv;
}
MSTensorVec BucketBatchTestFunction(MSTensorVec input) {
mindspore::dataset::TensorRow output;
std::shared_ptr<Tensor> out;
(void)Tensor::CreateEmpty(mindspore::dataset::TensorShape({1}),
mindspore::dataset::DataType(mindspore::dataset::DataType::Type::DE_INT32), &out);
(void)out->SetItemAt({0}, 2);
output.push_back(out);
return RowToVec(output);
}
MSTensorVec Predicate1(MSTensorVec in) {
// Return true if input is equal to 3
uint64_t input_value;
TensorRow input = VecToRow(in);
(void)input.at(0)->GetItemAt(&input_value, {0});
bool result = (input_value == 3);
// Convert from boolean to TensorRow
TensorRow output;
std::shared_ptr<Tensor> out;
(void)Tensor::CreateEmpty(mindspore::dataset::TensorShape({}),
mindspore::dataset::DataType(mindspore::dataset::DataType::Type::DE_BOOL), &out);
(void)out->SetItemAt({}, result);
(void)Tensor::CreateEmpty(
TensorShape({1}), DataType(DataType::Type::DE_INT32),
&out);
constexpr int value = 2;
(void)out->SetItemAt({0}, value);
output.push_back(out);
return RowToVec(output);
}
MSTensorVec Predicate2(MSTensorVec in) {
// Return true if label is more than 1
// The index of label in input is 1
uint64_t input_value;
TensorRow input = VecToRow(in);
(void)input.at(1)->GetItemAt(&input_value, {0});
bool result = (input_value > 1);
// Convert from boolean to TensorRow
TensorRow output;
std::shared_ptr<Tensor> out;
(void)Tensor::CreateEmpty(mindspore::dataset::TensorShape({}),
mindspore::dataset::DataType(mindspore::dataset::DataType::Type::DE_BOOL), &out);
(void)out->SetItemAt({}, result);
output.push_back(out);
return RowToVec(output);
}

View File

@ -150,3 +150,42 @@ std::shared_ptr<mindspore::dataset::ExecutionTree> DatasetOpTesting::Build(
#endif
#endif
} // namespace UT
namespace mindspore {
namespace dataset {
MSTensorVec Predicate1(MSTensorVec in) {
// Return true if input is equal to 3
uint64_t input_value;
TensorRow input = VecToRow(in);
(void)input.at(0)->GetItemAt(&input_value, {0});
bool result = (input_value == 3);
// Convert from boolean to TensorRow
TensorRow output;
std::shared_ptr<Tensor> out;
(void)Tensor::CreateEmpty(TensorShape({}), DataType(DataType::Type::DE_BOOL), &out);
(void)out->SetItemAt({}, result);
output.push_back(out);
return RowToVec(output);
}
MSTensorVec Predicate2(MSTensorVec in) {
// Return true if label is more than 1
// The index of label in input is 1
uint64_t input_value;
TensorRow input = VecToRow(in);
(void)input.at(1)->GetItemAt(&input_value, {0});
bool result = (input_value > 1);
// Convert from boolean to TensorRow
TensorRow output;
std::shared_ptr<Tensor> out;
(void)Tensor::CreateEmpty(TensorShape({}), DataType(mindspore::dataset::DataType::Type::DE_BOOL), &out);
(void)out->SetItemAt({}, result);
output.push_back(out);
return RowToVec(output);
}
} // namespace dataset
} // namespace mindspore

View File

@ -27,6 +27,7 @@
#include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
using mindspore::Status;
using mindspore::StatusCode;
@ -118,4 +119,20 @@ class DatasetOpTesting : public Common {
void SetUp() override;
};
} // namespace UT
namespace mindspore {
namespace dataset {
// defined in datasets.cc code, and function prototypes added here for UT purposes
// convert MSTensorVec to DE TensorRow, return empty if fails
TensorRow VecToRow(const MSTensorVec &v);
// defined in datasets.cc code, and function prototypes added here for UT purposes
// convert DE TensorRow to MSTensorVec, won't fail
MSTensorVec RowToVec(const TensorRow &v);
MSTensorVec Predicate1(MSTensorVec in);
MSTensorVec Predicate2(MSTensorVec in);
} // namespace dataset
} // namespace mindspore
#endif // TESTS_UT_CPP_DATASET_COMMON_COMMON_H_

View File

@ -18,7 +18,6 @@
#include <string>
#include "common/common.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pre/skip_pushdown_pass.h"
#include "minddata/dataset/include/dataset/samplers.h"
#include "minddata/dataset/include/dataset/vision.h"
@ -107,12 +106,6 @@ class MindDataSkipPushdownTestOptimizationPass : public UT::DatasetOpTesting {
}
};
TensorRow VecToRow(const MSTensorVec &v);
MSTensorVec RowToVec(const TensorRow &v);
MSTensorVec Predicate1(MSTensorVec in);
/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Test MindData Skip Pushdown Optimization Pass with Sampler in MappableSourceNode
/// Expectation: Skip node is pushed down and removed after optimization pass