!19047 [MD] Remove Builder

Merge pull request !19047 from harshvardhangupta/remove_builder
This commit is contained in:
i-robot 2021-07-06 17:34:35 +00:00 committed by Gitee
commit 0a6732f6dc
7 changed files with 83 additions and 494 deletions

View File

@ -28,52 +28,14 @@
namespace mindspore {
namespace dataset {
// Builder constructor: seed shard info with single-device defaults and pull
// worker count and connector sizes from the global dataset configuration.
CsvOp::Builder::Builder()
    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
  auto cfg = GlobalContext::config_manager();
  builder_num_workers_ = cfg->num_parallel_workers();
  builder_op_connector_size_ = cfg->op_connector_size();
  builder_worker_connector_size_ = cfg->worker_connector_size();
}
// Validate the builder arguments before Build() constructs the op.
// Checks that the worker count is positive and that the shard id is a valid
// index into the shard count.
// @return Status - OK when all inputs are valid, kMDUnexpectedError otherwise.
Status CsvOp::Builder::ValidateInputs() const {
  std::string err_msg;
  if (builder_num_workers_ <= 0) {
    err_msg += "Invalid parameter, num_parallel_workers must be greater than 0, but got " +
               std::to_string(builder_num_workers_) + ".\n";
  }
  if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) {
    err_msg += "Invalid parameter, num_shard must be greater than shard_id and greater than 0, got num_shard: " +
               std::to_string(builder_num_devices_) + ", shard_id: " + std::to_string(builder_device_id_) + ".\n";
  }
  if (err_msg.empty()) {
    return Status::OK();
  }
  return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, err_msg);
}
// Build the CsvOp from the accumulated builder settings.
// @param op - output: the constructed dataset op.
// @return Status - the error code returned.
Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) {
  RETURN_IF_NOT_OK(ValidateInputs());
  // Throttle parallelism: never run more workers than there are input files.
  const size_t num_files = builder_csv_files_list_.size();
  if (static_cast<size_t>(builder_num_workers_) > num_files) {
    builder_num_workers_ = num_files;
    MS_LOG(WARNING) << "CsvOp operator parallelism reduced to " << builder_num_workers_ << " workers.";
  }
  auto csv_op = std::make_shared<CsvOp>(
    builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_,
    builder_num_workers_, builder_num_samples_, builder_worker_connector_size_, builder_op_connector_size_,
    builder_shuffle_files_, builder_num_devices_, builder_device_id_);
  RETURN_IF_NOT_OK(csv_op->Init());
  *op = std::move(csv_op);
  return Status::OK();
}
CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, int32_t num_workers, int64_t num_samples,
int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, int32_t num_devices,
int32_t device_id)
: NonMappableLeafOp(num_workers, worker_connector_size, num_samples, op_connector_size, shuffle_files, num_devices,
device_id),
: NonMappableLeafOp(std::min(num_workers, static_cast<int32_t>(csv_files_list.size())), worker_connector_size,
num_samples, op_connector_size, shuffle_files, num_devices, device_id),
csv_files_list_(std::move(csv_files_list)),
field_delim_(field_delim),
column_default_list_(column_default),
@ -654,13 +616,24 @@ int64_t CsvOp::CountTotalRows(const std::string &file) {
}
Status CsvOp::CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count) {
int32_t num_workers = GlobalContext::config_manager()->num_parallel_workers();
int32_t op_connector_size = GlobalContext::config_manager()->op_connector_size();
int32_t worker_connector_size = GlobalContext::config_manager()->worker_connector_size();
int32_t device_id = 0;
int32_t num_devices = 1;
int32_t num_samples = 0;
bool shuffle_files = false;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_list;
std::vector<std::string> column_name_list;
char field_delim = ',';
std::shared_ptr<CsvOp> op;
*count = 0;
if (csv_header) {
RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).Build(&op));
} else {
RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).SetColumName({""}).Build(&op));
if (!csv_header) {
column_name_list.push_back("");
}
op = std::make_shared<CsvOp>(files, field_delim, column_list, column_name_list, num_workers, num_samples,
worker_connector_size, op_connector_size, shuffle_files, num_devices, device_id);
RETURN_IF_NOT_OK(op->Init());
for (auto file : files) {
*count += op->CountTotalRows(file);
}

View File

@ -144,117 +144,6 @@ class CsvOp : public NonMappableLeafOp {
std::string file_path_;
};
/// Nested builder for CsvOp. Collect the construction arguments through the
/// chained setters, then call Build() to validate them and create the op.
class Builder {
public:
/// Builder constructor. Creates the builder object.
/// @note No default args
/// @return This is a constructor.
Builder();
/// Default destructor
~Builder() = default;
/// Checks if the inputs of the builder are valid.
/// @return Status - the error code returned.
Status ValidateInputs() const;
/// Create the final object.
/// @param op - dataset op.
/// @return - the error code return.
Status Build(std::shared_ptr<CsvOp> *op);
/// Setter method.
/// @param num_workers - number of parallel workers.
/// @return Builder - setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
builder_num_workers_ = num_workers;
return *this;
}
/// Setter method.
/// @param op_connector_size - size of the op's output connector queue.
/// @return Builder - setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t op_connector_size) {
builder_op_connector_size_ = op_connector_size;
return *this;
}
/// Setter method.
/// @param rows_per_buffer - rows per buffer.
/// @note NOTE(review): this value is stored but not forwarded to CsvOp by
/// Build() in this view — confirm whether it is still needed.
/// @return Builder - setter method returns reference to the builder.
Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
builder_rows_per_buffer_ = rows_per_buffer;
return *this;
}
/// Setter method.
/// @param num_dev - total number of shards (devices).
/// @return Builder - setter method returns reference to the builder.
Builder &SetNumDevices(int64_t num_dev) {
builder_num_devices_ = num_dev;
return *this;
}
/// Setter method.
/// @param dev_id - id of the shard handled by this device.
/// @return Builder - setter method returns reference to the builder.
Builder &SetDeviceId(int64_t dev_id) {
builder_device_id_ = dev_id;
return *this;
}
/// Setter method.
/// @param files_list - paths of the CSV files to read.
/// @return Builder - setter method returns reference to the builder.
Builder &SetCsvFilesList(const std::vector<std::string> &files_list) {
builder_csv_files_list_ = files_list;
return *this;
}
/// Setter method.
/// @param shuffle_files - whether the input files are shuffled.
/// @return Builder - setter method returns reference to the builder.
Builder &SetShuffleFiles(bool shuffle_files) {
builder_shuffle_files_ = shuffle_files;
return *this;
}
/// Setter method.
/// @param num_samples - number of samples to read.
/// @return Builder - setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
/// Setter method.
/// @param field_delim - character that separates fields in the CSV files.
/// @return Builder - setter method returns reference to the builder.
Builder &SetFieldDelim(char field_delim) {
builder_field_delim_ = field_delim;
return *this;
}
/// Setter method.
/// @param record_list - per-column default-value records.
/// @return Builder - setter method returns reference to the builder.
Builder &SetColumDefault(std::vector<std::shared_ptr<CsvOp::BaseRecord>> record_list) {
builder_column_default_list_ = record_list;
return *this;
}
/// Setter method.
/// @param col_name_list - column names.
/// @return Builder - setter method returns reference to the builder.
Builder &SetColumName(std::vector<std::string> col_name_list) {
builder_column_name_list_ = col_name_list;
return *this;
}
private:
int32_t builder_device_id_;              // shard id of this device
int32_t builder_num_devices_;            // total number of shards
int32_t builder_num_workers_;            // parallel worker count
int32_t builder_op_connector_size_;      // op output queue size
int64_t builder_rows_per_buffer_;        // see SetRowsPerBuffer note above
int64_t builder_num_samples_;            // number of samples to read
int32_t builder_worker_connector_size_;  // per-worker queue size
std::vector<std::string> builder_csv_files_list_;  // input CSV paths
bool builder_shuffle_files_;             // shuffle input files?
char builder_field_delim_;               // CSV field delimiter
std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;  // per-column defaults
std::vector<std::string> builder_column_name_list_;                            // column names
};
/// Constructor of CsvOp
CsvOp() = delete;

View File

@ -52,100 +52,6 @@ const int32_t LOG_INTERVAL = 19;
class MindRecordOp : public MappableLeafOp {
public:
// The nested builder class inside of the MindRecordOp is used to help manage all of the arguments
// for constructing it. Use the builder by setting each argument with the provided set methods,
// and then finally call the build method to execute the actual construction.
class Builder {
public:
// Constructor. Implementation (including defaults) lives in the .cc file.
Builder();
// Default destructor.
~Builder() = default;
// Create the final MindRecordOp object from the collected arguments.
// @return Status - the error code returned.
Status Build(std::shared_ptr<MindRecordOp> *);
// Setter method
// @param rows_per_buffer - rows per buffer
// @return Builder setter method returns reference to the builder.
Builder &SetRowsPerBuffer(int rows_per_buffer) {
build_rows_per_buffer_ = rows_per_buffer;
return *this;
}
// Setter method
// @param num_mind_record_workers - worker count for the ShardReader
// @return Builder setter method returns reference to the builder.
Builder &SetNumMindRecordWorkers(int32_t num_mind_record_workers) {
build_num_mind_record_workers_ = num_mind_record_workers;
return *this;
}
// Setter method
// @param queue_size - size of the op connector queue
// @return Builder setter method returns reference to the builder.
Builder &SetOpConnectorQueueSize(int32_t queue_size) {
build_op_connector_queue_size_ = queue_size;
return *this;
}
// Setter method
// @param files - MindRecord dataset file paths
// @return Builder setter method returns reference to the builder.
Builder &SetDatasetFile(const std::vector<std::string> &files) {
build_dataset_file_ = files;
return *this;
}
// Setter method
// @param columns - names of the columns to load
// @return Builder setter method returns reference to the builder.
Builder &SetColumnsToLoad(const std::vector<std::string> &columns) {
build_columns_to_load_ = columns;
return *this;
}
// Setter method
// @param operators - shard operators to apply while reading
// @return Builder setter method returns reference to the builder.
Builder &SetOperators(const std::vector<std::shared_ptr<ShardOperator>> &operators) {
build_operators_ = operators;
return *this;
}
// Setter method
// @param load_dataset - load-dataset flag forwarded to the op
// @return Builder setter method returns reference to the builder.
Builder &SetLoadDataset(bool load_dataset) {
build_load_dataset_ = load_dataset;
return *this;
}
// Setter method
// @param num_padded - number of padded samples to append
// @return Builder setter method returns reference to the builder.
Builder &SetNumToPadSamples(int64_t num_padded) {
build_num_padded_ = num_padded;
return *this;
}
// Setter method
// @param sample - Python object used as the padded sample
// @return Builder setter method returns reference to the builder.
Builder &SetPaddedSample(const py::handle &sample) {
build_sample_ = sample;
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
// Setter method
// @param shuffle_mode - file/sample shuffle mode
// @return Builder setter method returns reference to the builder.
Builder &SetShuffleMode(const ShuffleMode shuffle_mode) {
build_shuffle_mode_ = shuffle_mode;
return *this;
}
// Check validity of the collected input arguments.
// @return Status - the error code returned.
Status SanityCheck() const;
// Default worker count used when the caller does not override it.
static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; }
// Convert a Python object to mindrecord JSON.
mindrecord::json ToJson(const py::handle &obj);
private:
static constexpr int32_t kDefaultMindRecordWorkers = 4;
// The builder saves all MindRecordOp construction arguments internally.
// The following are the arguments.
int32_t build_num_mind_record_workers_;
// NOTE(review): builder_num_workers_ has no setter in this view and appears
// unused — confirm before removing.
int32_t builder_num_workers_;
int32_t build_rows_per_buffer_;
int32_t build_op_connector_queue_size_;
std::vector<std::string> build_dataset_file_;
bool build_load_dataset_;
std::vector<std::string> build_columns_to_load_;
std::vector<std::shared_ptr<ShardOperator>> build_operators_;
int64_t build_num_padded_;
py::handle build_sample_;
std::map<std::string, std::string> build_sample_bytes_;
std::shared_ptr<SamplerRT> builder_sampler_;
ShuffleMode build_shuffle_mode_;
};
// Constructor of the MindRecordOp.
// @note The builder class should be used to call it
// @param num_mind_record_workers - The number of workers for the op (run by ShardReader)

View File

@ -32,46 +32,6 @@ const int32_t kMnistLabelFileMagicNumber = 2049;
const int32_t kMnistImageRows = 28;
const int32_t kMnistImageCols = 28;
// Builder constructor: no sampler yet, empty usage, and worker/connector
// sizes taken from the global dataset configuration.
MnistOp::Builder::Builder() : builder_sampler_(nullptr), builder_usage_("") {
  auto config = GlobalContext::config_manager();
  builder_num_workers_ = config->num_parallel_workers();
  builder_op_connector_size_ = config->op_connector_size();
}
// Build the MnistOp from the collected builder settings.
// @param ptr - output: the constructed dataset op.
// @return Status - the error code returned.
Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  // Fall back to a sequential sampler when the caller supplied none.
  if (!builder_sampler_) {
    const int64_t num_samples = 0;
    const int64_t start_index = 0;
    builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
  }
  // Fixed schema: a uint8 "image" column and a scalar uint32 "label" column.
  builder_schema_ = std::make_unique<DataSchema>();
  ColDescriptor image_col("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1);
  RETURN_IF_NOT_OK(builder_schema_->AddColumn(image_col));
  TensorShape scalar_shape = TensorShape::CreateScalar();
  ColDescriptor label_col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_shape);
  RETURN_IF_NOT_OK(builder_schema_->AddColumn(label_col));
  *ptr = std::make_shared<MnistOp>(builder_usage_, builder_num_workers_, builder_dir_, builder_op_connector_size_,
                                   std::move(builder_schema_), std::move(builder_sampler_));
  return Status::OK();
}
// Validate the builder inputs: the dataset directory must exist, the worker
// count must be positive, and usage must be one of the accepted values.
// @return Status - OK when valid, kMDUnexpectedError with details otherwise.
Status MnistOp::Builder::SanityCheck() {
  const std::set<std::string> valid = {"test", "train", "all", ""};
  Path dir(builder_dir_);
  std::string err_msg;
  if (!dir.IsDirectory()) {
    err_msg += "Invalid parameter, MNIST path is invalid or not set, path: " + builder_dir_ + ".\n";
  }
  if (builder_num_workers_ <= 0) {
    err_msg += "Invalid parameter, num_parallel_workers must be greater than 0, but got " +
               std::to_string(builder_num_workers_) + ".\n";
  }
  if (valid.find(builder_usage_) == valid.end()) {
    err_msg += "Invalid parameter, usage must be 'train','test' or 'all', but got " + builder_usage_ + ".\n";
  }
  if (err_msg.empty()) {
    return Status::OK();
  }
  return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, err_msg);
}
MnistOp::MnistOp(const std::string &usage, int32_t num_workers, std::string folder_path, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
@ -306,10 +266,19 @@ Status MnistOp::LaunchThreadsAndInitOp() {
Status MnistOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
// the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader()
std::shared_ptr<MnistOp> op;
*count = 0;
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetUsage(usage).Build(&op));
const int64_t num_samples = 0;
const int64_t start_index = 0;
auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
int32_t num_workers = cfg->num_parallel_workers();
int32_t op_connect_size = cfg->op_connector_size();
auto op = std::make_shared<MnistOp>(usage, num_workers, dir, op_connect_size, std::move(schema), std::move(sampler));
RETURN_IF_NOT_OK(op->WalkAllFiles());
for (size_t i = 0; i < op->image_names_.size(); ++i) {

View File

@ -44,72 +44,6 @@ using MnistLabelPair = std::pair<std::shared_ptr<Tensor>, uint32_t>;
class MnistOp : public MappableLeafOp {
public:
// Nested builder for MnistOp. Set arguments with the chained setters, then
// call Build() to validate them and create the op.
class Builder {
public:
// Constructor for Builder class of MnistOp
Builder();
// Destructor.
~Builder() = default;
// Setter method
// @param int32_t op_connector_size
// @return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t op_connector_size) {
builder_op_connector_size_ = op_connector_size;
return *this;
}
// Setter method
// @param int32_t num_workers
// @return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
builder_num_workers_ = num_workers;
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
// Setter method
// @param const std::string &dir - directory that holds the MNIST files
// @return Builder setter method returns reference to the builder.
Builder &SetDir(const std::string &dir) {
builder_dir_ = dir;
return *this;
}
// Setter method
// @param const std::string &usage - dataset split, e.g. 'train'/'test'/'all'
// @return Builder setter method returns reference to the builder.
Builder &SetUsage(const std::string &usage) {
builder_usage_ = usage;
return *this;
}
// Check validity of input args
// @return Status The status code returned
Status SanityCheck();
// The builder "Build" method creates the final object.
// @param std::shared_ptr<MnistOp> *op - DatasetOp
// @return Status The status code returned
Status Build(std::shared_ptr<MnistOp> *op);
private:
std::string builder_dir_;    // dataset directory
std::string builder_usage_;  // dataset split
int32_t builder_num_workers_;
// NOTE(review): builder_rows_per_buffer_ has no setter in this view and
// appears unused — confirm before removing.
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::shared_ptr<SamplerRT> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
};
// Constructor
// @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'
// @param int32_t num_workers - number of workers reading images in parallel

View File

@ -37,67 +37,6 @@ class MindDataTestCSVOp : public UT::DatasetOpTesting {
};
TEST_F(MindDataTestCSVOp, TestCSVBasic) {
  // Build an execution tree whose root is a CsvOp over a four-column file.
  auto exec_tree = std::make_shared<ExecutionTree>();
  std::string dataset_path = datasets_root_path_ + "/testCSV/1.csv";
  // Four integer columns, each defaulting to 0.
  std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
  for (int col = 0; col < 4; ++col) {
    column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
  }
  CsvOp::Builder builder;
  builder.SetCsvFilesList({dataset_path})
    .SetShuffleFiles(false)
    .SetOpConnectorSize(2)
    .SetFieldDelim(',')
    .SetColumDefault(column_default_list)
    .SetColumName({"col1", "col2", "col3", "col4"});
  std::shared_ptr<CsvOp> csv_op;
  Status rc = builder.Build(&csv_op);
  ASSERT_TRUE(rc.IsOk());
  rc = exec_tree->AssociateNode(csv_op);
  ASSERT_TRUE(rc.IsOk());
  rc = exec_tree->AssignRoot(csv_op);
  ASSERT_TRUE(rc.IsOk());
  MS_LOG(INFO) << "Launching tree and begin iteration.";
  rc = exec_tree->Prepare();
  ASSERT_TRUE(rc.IsOk());
  rc = exec_tree->Launch();
  ASSERT_TRUE(rc.IsOk());
  // Pull rows until the pipeline is exhausted; the file holds three rows.
  DatasetIterator di(exec_tree);
  TensorRow tensor_list;
  rc = di.FetchNextTensorRow(&tensor_list);
  ASSERT_TRUE(rc.IsOk());
  int row_count = 0;
  while (!tensor_list.empty()) {
    // Log each tensor in the row via its printer.
    for (size_t i = 0; i < tensor_list.size(); i++) {
      std::ostringstream ss;
      ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
      MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
    }
    row_count++;
    rc = di.FetchNextTensorRow(&tensor_list);
    ASSERT_TRUE(rc.IsOk());
  }
  ASSERT_EQ(row_count, 3);
}
TEST_F(MindDataTestCSVOp, TestTotalRows) {
std::string csv_file1 = datasets_root_path_ + "/testCSV/1.csv";
std::string csv_file2 = datasets_root_path_ + "/testCSV/size.csv";

View File

@ -30,6 +30,7 @@
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"
#include "gtest/gtest.h"
@ -46,96 +47,74 @@ std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
Status Create1DTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements, unsigned char *data = nullptr,
DataType::Type data_type = DataType::DE_UINT32);
std::shared_ptr<MnistOp> CreateMnist(int64_t num_wrks, int64_t rows, int64_t conns, std::string path, bool shuf = false,
std::shared_ptr<SamplerRT> sampler = nullptr) {
std::shared_ptr<MnistOp> so;
MnistOp::Builder builder;
Status rc = builder.SetNumWorkers(num_wrks)
.SetDir(path)
.SetOpConnectorSize(conns)
.SetSampler(std::move(sampler))
.Build(&so);
return so;
}
class MindDataTestMnistSampler : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) {
// Note: Mnist datasets are not included
// as part of the build tree.
// Download datasets and rebuild if data doesn't
// appear in this dataset
// Example: python tests/dataset/data/prep_data.py
std::string folder_path = datasets_root_path_ + "/testMnistData/";
int64_t num_samples = 10;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
auto op1 = CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler));
auto op2 = Repeat(2);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2});
tree->Prepare();
std::shared_ptr<Dataset> ds =
Mnist(folder_path, "all", std::make_shared<SequentialSampler>(start_index, num_samples));
EXPECT_NE(ds, nullptr);
ds = ds->Repeat(2);
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
Status rc = tree->Launch();
if (rc.IsError()) {
MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
EXPECT_TRUE(false);
} else {
DatasetIterator di(tree);
TensorMap tensor_map;
ASSERT_OK(di.GetNextAsMap(&tensor_map));
EXPECT_TRUE(rc.IsOk());
uint64_t i = 0;
uint32_t label = 0;
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<uint32_t>(&label, {});
EXPECT_TRUE(res[i % 10] == label);
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n";
i++;
ASSERT_OK(di.GetNextAsMap(&tensor_map));
}
EXPECT_TRUE(i == 20);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint32_t label_idx;
uint64_t i = 0;
while (row.size() != 0) {
auto image = row["image"];
auto label = row["label"];
// EXPECT_EQ(label, res[i % 10]);
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
std::shared_ptr<Tensor> de_label;
ASSERT_OK(Tensor::CreateFromMSTensor(label, &de_label));
ASSERT_OK(de_label->GetItemAt<uint32_t>(&label_idx, {}));
MS_LOG(INFO) << "Tensor label value: " << label_idx;
EXPECT_EQ(label_idx, res[i % 10]);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 20);
iter->Stop();
}
TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) {
std::string folder_path = datasets_root_path_ + "/testMnistData/";
int64_t num_samples = 10;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
auto op1 = CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler));
auto op2 = Repeat(2);
auto op3 = Batch(5);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2, op3});
tree->Prepare();
uint32_t res[4][5] = {{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}};
Status rc = tree->Launch();
if (rc.IsError()) {
MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
EXPECT_TRUE(false);
} else {
DatasetIterator di(tree);
TensorMap tensor_map;
ASSERT_OK(di.GetNextAsMap(&tensor_map));
EXPECT_TRUE(rc.IsOk());
uint64_t i = 0;
while (tensor_map.size() != 0) {
std::shared_ptr<Tensor> label;
Create1DTensor(&label, 5, reinterpret_cast<unsigned char *>(res[i % 4]));
EXPECT_TRUE((*label) == (*tensor_map["label"]));
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << *tensor_map["label"] << "\n";
i++;
ASSERT_OK(di.GetNextAsMap(&tensor_map));
}
EXPECT_TRUE(i == 4);
std::shared_ptr<Dataset> ds =
Mnist(folder_path, "all", std::make_shared<SequentialSampler>(start_index, num_samples));
EXPECT_NE(ds, nullptr);
ds = ds->Repeat(2);
EXPECT_NE(ds, nullptr);
ds = ds->Batch(5);
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
std::vector<std::vector<uint32_t>> expected = {{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}};
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
auto image = row["image"];
auto label = row["label"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
TEST_MS_LOG_MSTENSOR(INFO, "Tensor label: ", label);
std::shared_ptr<Tensor> de_expected_label;
ASSERT_OK(Tensor::CreateFromVector(expected[i % 4], &de_expected_label));
mindspore::MSTensor expected_label =
mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_label));
EXPECT_MSTENSOR_EQ(label, expected_label);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 4);
iter->Stop();
}