forked from mindspore-Ecosystem/mindspore
!19047 [MD] Remove Builder
Merge pull request !19047 from harshvardhangupta/remove_builder
This commit is contained in:
commit
0a6732f6dc
|
@ -28,52 +28,14 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CsvOp::Builder::Builder()
|
||||
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
|
||||
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
||||
builder_num_workers_ = config_manager->num_parallel_workers();
|
||||
builder_op_connector_size_ = config_manager->op_connector_size();
|
||||
builder_worker_connector_size_ = config_manager->worker_connector_size();
|
||||
}
|
||||
|
||||
Status CsvOp::Builder::ValidateInputs() const {
|
||||
std::string err;
|
||||
err += builder_num_workers_ <= 0 ? "Invalid parameter, num_parallel_workers must be greater than 0, but got " +
|
||||
std::to_string(builder_num_workers_) + ".\n"
|
||||
: "";
|
||||
err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1)
|
||||
? "Invalid parameter, num_shard must be greater than shard_id and greater than 0, got num_shard: " +
|
||||
std::to_string(builder_num_devices_) + ", shard_id: " + std::to_string(builder_device_id_) + ".\n"
|
||||
: "";
|
||||
return err.empty() ? Status::OK() : Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, err);
|
||||
}
|
||||
|
||||
Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) {
|
||||
RETURN_IF_NOT_OK(ValidateInputs());
|
||||
|
||||
// Throttle the number of workers if we have more workers than files!
|
||||
if (static_cast<size_t>(builder_num_workers_) > builder_csv_files_list_.size()) {
|
||||
builder_num_workers_ = builder_csv_files_list_.size();
|
||||
MS_LOG(WARNING) << "CsvOp operator parallelism reduced to " << builder_num_workers_ << " workers.";
|
||||
}
|
||||
|
||||
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
|
||||
builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_,
|
||||
builder_num_workers_, builder_num_samples_, builder_worker_connector_size_, builder_op_connector_size_,
|
||||
builder_shuffle_files_, builder_num_devices_, builder_device_id_);
|
||||
RETURN_IF_NOT_OK(csv_op->Init());
|
||||
*op = std::move(csv_op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
|
||||
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
|
||||
const std::vector<std::string> &column_name, int32_t num_workers, int64_t num_samples,
|
||||
int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, int32_t num_devices,
|
||||
int32_t device_id)
|
||||
: NonMappableLeafOp(num_workers, worker_connector_size, num_samples, op_connector_size, shuffle_files, num_devices,
|
||||
device_id),
|
||||
: NonMappableLeafOp(std::min(num_workers, static_cast<int32_t>(csv_files_list.size())), worker_connector_size,
|
||||
num_samples, op_connector_size, shuffle_files, num_devices, device_id),
|
||||
csv_files_list_(std::move(csv_files_list)),
|
||||
field_delim_(field_delim),
|
||||
column_default_list_(column_default),
|
||||
|
@ -654,13 +616,24 @@ int64_t CsvOp::CountTotalRows(const std::string &file) {
|
|||
}
|
||||
|
||||
Status CsvOp::CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count) {
|
||||
int32_t num_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||
int32_t op_connector_size = GlobalContext::config_manager()->op_connector_size();
|
||||
int32_t worker_connector_size = GlobalContext::config_manager()->worker_connector_size();
|
||||
int32_t device_id = 0;
|
||||
int32_t num_devices = 1;
|
||||
int32_t num_samples = 0;
|
||||
bool shuffle_files = false;
|
||||
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_list;
|
||||
std::vector<std::string> column_name_list;
|
||||
char field_delim = ',';
|
||||
std::shared_ptr<CsvOp> op;
|
||||
*count = 0;
|
||||
if (csv_header) {
|
||||
RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).Build(&op));
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(Builder().SetCsvFilesList(files).SetColumName({""}).Build(&op));
|
||||
if (!csv_header) {
|
||||
column_name_list.push_back("");
|
||||
}
|
||||
op = std::make_shared<CsvOp>(files, field_delim, column_list, column_name_list, num_workers, num_samples,
|
||||
worker_connector_size, op_connector_size, shuffle_files, num_devices, device_id);
|
||||
RETURN_IF_NOT_OK(op->Init());
|
||||
for (auto file : files) {
|
||||
*count += op->CountTotalRows(file);
|
||||
}
|
||||
|
|
|
@ -144,117 +144,6 @@ class CsvOp : public NonMappableLeafOp {
|
|||
std::string file_path_;
|
||||
};
|
||||
|
||||
class Builder {
|
||||
public:
|
||||
/// Builder constructor. Creates the builder object.
|
||||
/// @note No default args
|
||||
/// @return This is a constructor.
|
||||
Builder();
|
||||
|
||||
/// Default destructor
|
||||
~Builder() = default;
|
||||
|
||||
/// Checks if the inputs of the builder is valid.
|
||||
/// @return Status - the error code returned.
|
||||
Status ValidateInputs() const;
|
||||
|
||||
/// Create the final object.
|
||||
/// @param op - dataset op.
|
||||
/// @return - the error code return.
|
||||
Status Build(std::shared_ptr<CsvOp> *op);
|
||||
|
||||
/// Setter method.
|
||||
/// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetNumWorkers(int32_t num_workers) {
|
||||
builder_num_workers_ = num_workers;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetOpConnectorSize(int32_t op_connector_size) {
|
||||
builder_op_connector_size_ = op_connector_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
|
||||
builder_rows_per_buffer_ = rows_per_buffer;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetNumDevices(int64_t num_dev) {
|
||||
builder_num_devices_ = num_dev;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetDeviceId(int64_t dev_id) {
|
||||
builder_device_id_ = dev_id;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetCsvFilesList(const std::vector<std::string> &files_list) {
|
||||
builder_csv_files_list_ = files_list;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetShuffleFiles(bool shuffle_files) {
|
||||
builder_shuffle_files_ = shuffle_files;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetNumSamples(int64_t num_samples) {
|
||||
builder_num_samples_ = num_samples;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetFieldDelim(char field_delim) {
|
||||
builder_field_delim_ = field_delim;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetColumDefault(std::vector<std::shared_ptr<CsvOp::BaseRecord>> record_list) {
|
||||
builder_column_default_list_ = record_list;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetColumName(std::vector<std::string> col_name_list) {
|
||||
builder_column_name_list_ = col_name_list;
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
int32_t builder_device_id_;
|
||||
int32_t builder_num_devices_;
|
||||
int32_t builder_num_workers_;
|
||||
int32_t builder_op_connector_size_;
|
||||
int64_t builder_rows_per_buffer_;
|
||||
int64_t builder_num_samples_;
|
||||
int32_t builder_worker_connector_size_;
|
||||
std::vector<std::string> builder_csv_files_list_;
|
||||
bool builder_shuffle_files_;
|
||||
char builder_field_delim_;
|
||||
std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
|
||||
std::vector<std::string> builder_column_name_list_;
|
||||
};
|
||||
|
||||
/// Constructor of CsvOp
|
||||
CsvOp() = delete;
|
||||
|
||||
|
|
|
@ -52,100 +52,6 @@ const int32_t LOG_INTERVAL = 19;
|
|||
|
||||
class MindRecordOp : public MappableLeafOp {
|
||||
public:
|
||||
// The nested builder class inside of the MindRecordOp is used to help manage all of the arguments
|
||||
// for constructing it. Use the builder by setting each argument with the provided set methods,
|
||||
// and then finally call the build method to execute the actual construction.
|
||||
class Builder {
|
||||
public:
|
||||
Builder();
|
||||
|
||||
~Builder() = default;
|
||||
|
||||
Status Build(std::shared_ptr<MindRecordOp> *);
|
||||
|
||||
Builder &SetRowsPerBuffer(int rows_per_buffer) {
|
||||
build_rows_per_buffer_ = rows_per_buffer;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetNumMindRecordWorkers(int32_t num_mind_record_workers) {
|
||||
build_num_mind_record_workers_ = num_mind_record_workers;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetOpConnectorQueueSize(int32_t queue_size) {
|
||||
build_op_connector_queue_size_ = queue_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetDatasetFile(const std::vector<std::string> &files) {
|
||||
build_dataset_file_ = files;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetColumnsToLoad(const std::vector<std::string> &columns) {
|
||||
build_columns_to_load_ = columns;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetOperators(const std::vector<std::shared_ptr<ShardOperator>> &operators) {
|
||||
build_operators_ = operators;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetLoadDataset(bool load_dataset) {
|
||||
build_load_dataset_ = load_dataset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetNumToPadSamples(int64_t num_padded) {
|
||||
build_num_padded_ = num_padded;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetPaddedSample(const py::handle &sample) {
|
||||
build_sample_ = sample;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
||||
builder_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder &SetShuffleMode(const ShuffleMode shuffle_mode) {
|
||||
build_shuffle_mode_ = shuffle_mode;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Status SanityCheck() const;
|
||||
|
||||
static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; }
|
||||
|
||||
mindrecord::json ToJson(const py::handle &obj);
|
||||
|
||||
private:
|
||||
static constexpr int32_t kDefaultMindRecordWorkers = 4;
|
||||
// The builder saves all MindRecordOp construction arguments internally.
|
||||
// The following are the arguments.
|
||||
int32_t build_num_mind_record_workers_;
|
||||
int32_t builder_num_workers_;
|
||||
int32_t build_rows_per_buffer_;
|
||||
int32_t build_op_connector_queue_size_;
|
||||
std::vector<std::string> build_dataset_file_;
|
||||
bool build_load_dataset_;
|
||||
std::vector<std::string> build_columns_to_load_;
|
||||
std::vector<std::shared_ptr<ShardOperator>> build_operators_;
|
||||
int64_t build_num_padded_;
|
||||
py::handle build_sample_;
|
||||
std::map<std::string, std::string> build_sample_bytes_;
|
||||
std::shared_ptr<SamplerRT> builder_sampler_;
|
||||
ShuffleMode build_shuffle_mode_;
|
||||
};
|
||||
|
||||
// Constructor of the MindRecordOp.
|
||||
// @note The builder class should be used to call it
|
||||
// @param num_mind_record_workers - The number of workers for the op (run by ShardReader)
|
||||
|
|
|
@ -32,46 +32,6 @@ const int32_t kMnistLabelFileMagicNumber = 2049;
|
|||
const int32_t kMnistImageRows = 28;
|
||||
const int32_t kMnistImageCols = 28;
|
||||
|
||||
MnistOp::Builder::Builder() : builder_sampler_(nullptr), builder_usage_("") {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_num_workers_ = cfg->num_parallel_workers();
|
||||
builder_op_connector_size_ = cfg->op_connector_size();
|
||||
}
|
||||
|
||||
Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
if (builder_sampler_ == nullptr) {
|
||||
const int64_t num_samples = 0;
|
||||
const int64_t start_index = 0;
|
||||
builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||
}
|
||||
builder_schema_ = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(
|
||||
builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
|
||||
ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
*ptr = std::make_shared<MnistOp>(builder_usage_, builder_num_workers_, builder_dir_, builder_op_connector_size_,
|
||||
std::move(builder_schema_), std::move(builder_sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MnistOp::Builder::SanityCheck() {
|
||||
const std::set<std::string> valid = {"test", "train", "all", ""};
|
||||
Path dir(builder_dir_);
|
||||
std::string err_msg;
|
||||
err_msg += dir.IsDirectory() == false
|
||||
? "Invalid parameter, MNIST path is invalid or not set, path: " + builder_dir_ + ".\n"
|
||||
: "";
|
||||
err_msg += builder_num_workers_ <= 0 ? "Invalid parameter, num_parallel_workers must be greater than 0, but got " +
|
||||
std::to_string(builder_num_workers_) + ".\n"
|
||||
: "";
|
||||
err_msg += valid.find(builder_usage_) == valid.end()
|
||||
? "Invalid parameter, usage must be 'train','test' or 'all', but got " + builder_usage_ + ".\n"
|
||||
: "";
|
||||
return err_msg.empty() ? Status::OK() : Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
||||
MnistOp::MnistOp(const std::string &usage, int32_t num_workers, std::string folder_path, int32_t queue_size,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
|
||||
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
|
||||
|
@ -306,10 +266,19 @@ Status MnistOp::LaunchThreadsAndInitOp() {
|
|||
|
||||
Status MnistOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
|
||||
// the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader()
|
||||
std::shared_ptr<MnistOp> op;
|
||||
*count = 0;
|
||||
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetUsage(usage).Build(&op));
|
||||
|
||||
const int64_t num_samples = 0;
|
||||
const int64_t start_index = 0;
|
||||
auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
int32_t num_workers = cfg->num_parallel_workers();
|
||||
int32_t op_connect_size = cfg->op_connector_size();
|
||||
auto op = std::make_shared<MnistOp>(usage, num_workers, dir, op_connect_size, std::move(schema), std::move(sampler));
|
||||
RETURN_IF_NOT_OK(op->WalkAllFiles());
|
||||
|
||||
for (size_t i = 0; i < op->image_names_.size(); ++i) {
|
||||
|
|
|
@ -44,72 +44,6 @@ using MnistLabelPair = std::pair<std::shared_ptr<Tensor>, uint32_t>;
|
|||
|
||||
class MnistOp : public MappableLeafOp {
|
||||
public:
|
||||
class Builder {
|
||||
public:
|
||||
// Constructor for Builder class of MnistOp
|
||||
Builder();
|
||||
|
||||
// Destructor.
|
||||
~Builder() = default;
|
||||
|
||||
// Setter method
|
||||
// @param int32_t op_connector_size
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetOpConnectorSize(int32_t op_connector_size) {
|
||||
builder_op_connector_size_ = op_connector_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param int32_t num_workers
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumWorkers(int32_t num_workers) {
|
||||
builder_num_workers_ = num_workers;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
||||
builder_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param const std::string &dir
|
||||
// @return
|
||||
Builder &SetDir(const std::string &dir) {
|
||||
builder_dir_ = dir;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param const std::string &usage
|
||||
// @return
|
||||
Builder &SetUsage(const std::string &usage) {
|
||||
builder_usage_ = usage;
|
||||
return *this;
|
||||
}
|
||||
// Check validity of input args
|
||||
// @return Status The status code returned
|
||||
Status SanityCheck();
|
||||
|
||||
// The builder "Build" method creates the final object.
|
||||
// @param std::shared_ptr<MnistOp> *op - DatasetOp
|
||||
// @return Status The status code returned
|
||||
Status Build(std::shared_ptr<MnistOp> *op);
|
||||
|
||||
private:
|
||||
std::string builder_dir_;
|
||||
std::string builder_usage_;
|
||||
int32_t builder_num_workers_;
|
||||
int32_t builder_rows_per_buffer_;
|
||||
int32_t builder_op_connector_size_;
|
||||
std::shared_ptr<SamplerRT> builder_sampler_;
|
||||
std::unique_ptr<DataSchema> builder_schema_;
|
||||
};
|
||||
|
||||
// Constructor
|
||||
// @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'
|
||||
// @param int32_t num_workers - number of workers reading images in parallel
|
||||
|
|
|
@ -37,67 +37,6 @@ class MindDataTestCSVOp : public UT::DatasetOpTesting {
|
|||
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestCSVOp, TestCSVBasic) {
|
||||
// Start with an empty execution tree
|
||||
auto tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::string dataset_path;
|
||||
dataset_path = datasets_root_path_ + "/testCSV/1.csv";
|
||||
|
||||
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
|
||||
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
|
||||
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
|
||||
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
|
||||
column_default_list.push_back(std::make_shared<CsvOp::Record<int>>(CsvOp::INT, 0));
|
||||
std::shared_ptr<CsvOp> op;
|
||||
CsvOp::Builder builder;
|
||||
builder.SetCsvFilesList({dataset_path})
|
||||
|
||||
.SetShuffleFiles(false)
|
||||
.SetOpConnectorSize(2)
|
||||
.SetFieldDelim(',')
|
||||
.SetColumDefault(column_default_list)
|
||||
.SetColumName({"col1", "col2", "col3", "col4"});
|
||||
|
||||
Status rc = builder.Build(&op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = tree->AssociateNode(op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = tree->AssignRoot(op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(INFO) << "Launching tree and begin iteration.";
|
||||
rc = tree->Prepare();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = tree->Launch();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
DatasetIterator di(tree);
|
||||
TensorRow tensor_list;
|
||||
rc = di.FetchNextTensorRow(&tensor_list);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
int row_count = 0;
|
||||
while (!tensor_list.empty()) {
|
||||
// Display the tensor by calling the printer on it
|
||||
for (int i = 0; i < tensor_list.size(); i++) {
|
||||
std::ostringstream ss;
|
||||
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
|
||||
MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
|
||||
}
|
||||
|
||||
rc = di.FetchNextTensorRow(&tensor_list);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
row_count++;
|
||||
}
|
||||
|
||||
ASSERT_EQ(row_count, 3);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestCSVOp, TestTotalRows) {
|
||||
std::string csv_file1 = datasets_root_path_ + "/testCSV/1.csv";
|
||||
std::string csv_file2 = datasets_root_path_ + "/testCSV/size.csv";
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
@ -46,96 +47,74 @@ std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
|
|||
|
||||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
|
||||
Status Create1DTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements, unsigned char *data = nullptr,
|
||||
DataType::Type data_type = DataType::DE_UINT32);
|
||||
|
||||
std::shared_ptr<MnistOp> CreateMnist(int64_t num_wrks, int64_t rows, int64_t conns, std::string path, bool shuf = false,
|
||||
std::shared_ptr<SamplerRT> sampler = nullptr) {
|
||||
std::shared_ptr<MnistOp> so;
|
||||
MnistOp::Builder builder;
|
||||
Status rc = builder.SetNumWorkers(num_wrks)
|
||||
.SetDir(path)
|
||||
|
||||
.SetOpConnectorSize(conns)
|
||||
.SetSampler(std::move(sampler))
|
||||
.Build(&so);
|
||||
return so;
|
||||
}
|
||||
|
||||
class MindDataTestMnistSampler : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) {
|
||||
// Note: Mnist datasets are not included
|
||||
// as part of the build tree.
|
||||
// Download datasets and rebuild if data doesn't
|
||||
// appear in this dataset
|
||||
// Example: python tests/dataset/data/prep_data.py
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
int64_t num_samples = 10;
|
||||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||
auto op1 = CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler));
|
||||
auto op2 = Repeat(2);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2});
|
||||
tree->Prepare();
|
||||
std::shared_ptr<Dataset> ds =
|
||||
Mnist(folder_path, "all", std::make_shared<SequentialSampler>(start_index, num_samples));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
ds = ds->Repeat(2);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
|
||||
EXPECT_TRUE(false);
|
||||
} else {
|
||||
DatasetIterator di(tree);
|
||||
TensorMap tensor_map;
|
||||
ASSERT_OK(di.GetNextAsMap(&tensor_map));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
uint64_t i = 0;
|
||||
uint32_t label = 0;
|
||||
while (tensor_map.size() != 0) {
|
||||
tensor_map["label"]->GetItemAt<uint32_t>(&label, {});
|
||||
EXPECT_TRUE(res[i % 10] == label);
|
||||
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n";
|
||||
i++;
|
||||
ASSERT_OK(di.GetNextAsMap(&tensor_map));
|
||||
}
|
||||
EXPECT_TRUE(i == 20);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
uint32_t label_idx;
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto image = row["image"];
|
||||
auto label = row["label"];
|
||||
// EXPECT_EQ(label, res[i % 10]);
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
std::shared_ptr<Tensor> de_label;
|
||||
ASSERT_OK(Tensor::CreateFromMSTensor(label, &de_label));
|
||||
ASSERT_OK(de_label->GetItemAt<uint32_t>(&label_idx, {}));
|
||||
MS_LOG(INFO) << "Tensor label value: " << label_idx;
|
||||
EXPECT_EQ(label_idx, res[i % 10]);
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 20);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) {
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
int64_t num_samples = 10;
|
||||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||
auto op1 = CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler));
|
||||
auto op2 = Repeat(2);
|
||||
auto op3 = Batch(5);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2, op3});
|
||||
tree->Prepare();
|
||||
uint32_t res[4][5] = {{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}};
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
|
||||
EXPECT_TRUE(false);
|
||||
} else {
|
||||
DatasetIterator di(tree);
|
||||
TensorMap tensor_map;
|
||||
ASSERT_OK(di.GetNextAsMap(&tensor_map));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
uint64_t i = 0;
|
||||
while (tensor_map.size() != 0) {
|
||||
std::shared_ptr<Tensor> label;
|
||||
Create1DTensor(&label, 5, reinterpret_cast<unsigned char *>(res[i % 4]));
|
||||
EXPECT_TRUE((*label) == (*tensor_map["label"]));
|
||||
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << *tensor_map["label"] << "\n";
|
||||
i++;
|
||||
ASSERT_OK(di.GetNextAsMap(&tensor_map));
|
||||
}
|
||||
EXPECT_TRUE(i == 4);
|
||||
std::shared_ptr<Dataset> ds =
|
||||
Mnist(folder_path, "all", std::make_shared<SequentialSampler>(start_index, num_samples));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
ds = ds->Repeat(2);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
ds = ds->Batch(5);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::vector<std::vector<uint32_t>> expected = {{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}};
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto image = row["image"];
|
||||
auto label = row["label"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
TEST_MS_LOG_MSTENSOR(INFO, "Tensor label: ", label);
|
||||
std::shared_ptr<Tensor> de_expected_label;
|
||||
ASSERT_OK(Tensor::CreateFromVector(expected[i % 4], &de_expected_label));
|
||||
mindspore::MSTensor expected_label =
|
||||
mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_label));
|
||||
EXPECT_MSTENSOR_EQ(label, expected_label);
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(i, 4);
|
||||
iter->Stop();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue