forked from mindspore-Ecosystem/mindspore
!31268 Move UT CPP Helper Functions
Merge pull request !31268 from zetongzhao/ut_cpp_helper
This commit is contained in:
commit
8e1f6d6b9b
|
@ -16,7 +16,6 @@
|
|||
#include "common/common.h"
|
||||
#include "include/api/types.h"
|
||||
#include "minddata/dataset/core/tensor_row.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
#include "minddata/dataset/include/dataset/vision.h"
|
||||
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
|
||||
|
@ -142,68 +141,15 @@ class MindDataTestPipeline : public UT::DatasetOpTesting {
|
|||
protected:
|
||||
};
|
||||
|
||||
TensorRow VecToRow(const MSTensorVec &v) {
|
||||
TensorRow row;
|
||||
for (const mindspore::MSTensor &t : v) {
|
||||
std::shared_ptr<Tensor> rt;
|
||||
(void)Tensor::CreateFromMemory(TensorShape(t.Shape()), MSTypeToDEType(static_cast<mindspore::TypeId>(t.DataType())),
|
||||
(const uchar *)(t.Data().get()), t.DataSize(), &rt);
|
||||
row.emplace_back(rt);
|
||||
}
|
||||
return row;
|
||||
}
|
||||
MSTensorVec RowToVec(const TensorRow &v) {
|
||||
MSTensorVec rv; // std::make_shared<DETensor>(de_tensor)
|
||||
std::transform(v.begin(), v.end(), std::back_inserter(rv), [](std::shared_ptr<Tensor> t) -> mindspore::MSTensor {
|
||||
return mindspore::MSTensor(std::make_shared<DETensor>(t));
|
||||
});
|
||||
return rv;
|
||||
}
|
||||
|
||||
MSTensorVec BucketBatchTestFunction(MSTensorVec input) {
|
||||
mindspore::dataset::TensorRow output;
|
||||
std::shared_ptr<Tensor> out;
|
||||
(void)Tensor::CreateEmpty(mindspore::dataset::TensorShape({1}),
|
||||
mindspore::dataset::DataType(mindspore::dataset::DataType::Type::DE_INT32), &out);
|
||||
(void)out->SetItemAt({0}, 2);
|
||||
output.push_back(out);
|
||||
return RowToVec(output);
|
||||
}
|
||||
|
||||
MSTensorVec Predicate1(MSTensorVec in) {
|
||||
// Return true if input is equal to 3
|
||||
uint64_t input_value;
|
||||
TensorRow input = VecToRow(in);
|
||||
(void)input.at(0)->GetItemAt(&input_value, {0});
|
||||
bool result = (input_value == 3);
|
||||
|
||||
// Convert from boolean to TensorRow
|
||||
TensorRow output;
|
||||
std::shared_ptr<Tensor> out;
|
||||
(void)Tensor::CreateEmpty(mindspore::dataset::TensorShape({}),
|
||||
mindspore::dataset::DataType(mindspore::dataset::DataType::Type::DE_BOOL), &out);
|
||||
(void)out->SetItemAt({}, result);
|
||||
(void)Tensor::CreateEmpty(
|
||||
TensorShape({1}), DataType(DataType::Type::DE_INT32),
|
||||
&out);
|
||||
constexpr int value = 2;
|
||||
(void)out->SetItemAt({0}, value);
|
||||
output.push_back(out);
|
||||
|
||||
return RowToVec(output);
|
||||
}
|
||||
|
||||
MSTensorVec Predicate2(MSTensorVec in) {
|
||||
// Return true if label is more than 1
|
||||
// The index of label in input is 1
|
||||
uint64_t input_value;
|
||||
TensorRow input = VecToRow(in);
|
||||
(void)input.at(1)->GetItemAt(&input_value, {0});
|
||||
bool result = (input_value > 1);
|
||||
|
||||
// Convert from boolean to TensorRow
|
||||
TensorRow output;
|
||||
std::shared_ptr<Tensor> out;
|
||||
(void)Tensor::CreateEmpty(mindspore::dataset::TensorShape({}),
|
||||
mindspore::dataset::DataType(mindspore::dataset::DataType::Type::DE_BOOL), &out);
|
||||
(void)out->SetItemAt({}, result);
|
||||
output.push_back(out);
|
||||
|
||||
return RowToVec(output);
|
||||
}
|
||||
|
||||
|
|
|
@ -150,3 +150,42 @@ std::shared_ptr<mindspore::dataset::ExecutionTree> DatasetOpTesting::Build(
|
|||
#endif
|
||||
#endif
|
||||
} // namespace UT
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
MSTensorVec Predicate1(MSTensorVec in) {
|
||||
// Return true if input is equal to 3
|
||||
uint64_t input_value;
|
||||
TensorRow input = VecToRow(in);
|
||||
(void)input.at(0)->GetItemAt(&input_value, {0});
|
||||
bool result = (input_value == 3);
|
||||
|
||||
// Convert from boolean to TensorRow
|
||||
TensorRow output;
|
||||
std::shared_ptr<Tensor> out;
|
||||
(void)Tensor::CreateEmpty(TensorShape({}), DataType(DataType::Type::DE_BOOL), &out);
|
||||
(void)out->SetItemAt({}, result);
|
||||
output.push_back(out);
|
||||
|
||||
return RowToVec(output);
|
||||
}
|
||||
|
||||
MSTensorVec Predicate2(MSTensorVec in) {
|
||||
// Return true if label is more than 1
|
||||
// The index of label in input is 1
|
||||
uint64_t input_value;
|
||||
TensorRow input = VecToRow(in);
|
||||
(void)input.at(1)->GetItemAt(&input_value, {0});
|
||||
bool result = (input_value > 1);
|
||||
|
||||
// Convert from boolean to TensorRow
|
||||
TensorRow output;
|
||||
std::shared_ptr<Tensor> out;
|
||||
(void)Tensor::CreateEmpty(TensorShape({}), DataType(mindspore::dataset::DataType::Type::DE_BOOL), &out);
|
||||
(void)out->SetItemAt({}, result);
|
||||
output.push_back(out);
|
||||
|
||||
return RowToVec(output);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -27,6 +27,7 @@
|
|||
#include "minddata/dataset/engine/datasetops/batch_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
using mindspore::Status;
|
||||
using mindspore::StatusCode;
|
||||
|
@ -118,4 +119,20 @@ class DatasetOpTesting : public Common {
|
|||
void SetUp() override;
|
||||
};
|
||||
} // namespace UT
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// defined in datasets.cc code, and function prototypes added here for UT purposes
|
||||
// convert MSTensorVec to DE TensorRow, return empty if fails
|
||||
TensorRow VecToRow(const MSTensorVec &v);
|
||||
|
||||
// defined in datasets.cc code, and function prototypes added here for UT purposes
|
||||
// convert DE TensorRow to MSTensorVec, won't fail
|
||||
MSTensorVec RowToVec(const TensorRow &v);
|
||||
|
||||
MSTensorVec Predicate1(MSTensorVec in);
|
||||
|
||||
MSTensorVec Predicate2(MSTensorVec in);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // TESTS_UT_CPP_DATASET_COMMON_COMMON_H_
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
#include <string>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/engine/opt/pre/skip_pushdown_pass.h"
|
||||
#include "minddata/dataset/include/dataset/samplers.h"
|
||||
#include "minddata/dataset/include/dataset/vision.h"
|
||||
|
@ -107,12 +106,6 @@ class MindDataSkipPushdownTestOptimizationPass : public UT::DatasetOpTesting {
|
|||
}
|
||||
};
|
||||
|
||||
TensorRow VecToRow(const MSTensorVec &v);
|
||||
|
||||
MSTensorVec RowToVec(const TensorRow &v);
|
||||
|
||||
MSTensorVec Predicate1(MSTensorVec in);
|
||||
|
||||
/// Feature: MindData Skip Pushdown Optimization Pass Test
|
||||
/// Description: Test MindData Skip Pushdown Optimization Pass with Sampler in MappableSourceNode
|
||||
/// Expectation: Skip node is pushed down and removed after optimization pass
|
||||
|
|
Loading…
Reference in New Issue