forked from mindspore-Ecosystem/mindspore
!2143 dataset: remove storage_op c++ part
Merge pull request !2143 from ms_yan/del_storage_c++
This commit is contained in:
commit
3401e1c80b
|
@ -48,7 +48,6 @@ namespace dataset {
|
|||
using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *);
|
||||
|
||||
static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
|
||||
{kStorage, &DEPipeline::ParseStorageOp},
|
||||
{kShuffle, &DEPipeline::ParseShuffleOp},
|
||||
{kMindrecord, &DEPipeline::ParseMindRecordOp},
|
||||
{kMap, &DEPipeline::ParseMapOp},
|
||||
|
@ -301,70 +300,6 @@ Status DEPipeline::SetBatchParameters(const py::dict &args) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ValidateArgStorageOp(const py::dict &args) {
|
||||
// Required arguments
|
||||
if (((args.contains("dataset_files") && args["dataset_files"].is_none()) || args["schema"].is_none()) &&
|
||||
((args.contains("dataset_dir") && args["dataset_dir"].is_none()) ||
|
||||
(args["schema"].is_none() && args["schema_json_string"].is_none()))) {
|
||||
std::string err_msg = "Error: at least one of dataset_files or schema_file is missing";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseStorageOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(ValidateArgStorageOp(args));
|
||||
std::shared_ptr<StorageOp::Builder> builder;
|
||||
if (args.contains("dataset_files") && !args["dataset_files"].is_none()) {
|
||||
builder = std::make_shared<StorageOp::Builder>();
|
||||
(void)builder->SetDatasetFileList(ToStringVector(args["dataset_files"]));
|
||||
(void)builder->SetSchemaFile(ToString(args["schema"]));
|
||||
} else if (args.contains("dataset_dir") && !args["dataset_dir"].is_none()) {
|
||||
builder = std::make_shared<StorageOp::Builder>();
|
||||
(void)builder->SetDatasetFilesDir(ToString(args["dataset_dir"]));
|
||||
if (!args["schema"].is_none()) {
|
||||
(void)builder->SetSchemaFile(ToString(args["schema"]));
|
||||
} else if (!args["schema_json_string"].is_none()) {
|
||||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
std::string s = ToString(args["schema_json_string"]);
|
||||
RETURN_IF_NOT_OK(schema->LoadSchemaString(s, std::vector<std::string>()));
|
||||
(void)builder->SetNumRows(schema->num_rows());
|
||||
(void)builder->SetSchema(std::move(schema));
|
||||
}
|
||||
}
|
||||
|
||||
// Optional arguments
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "prefetch_size") {
|
||||
(void)builder->SetOpConnectorSize(ToInt(value));
|
||||
} else if (key == "columns_list") {
|
||||
(void)builder->SetColumnsToLoad(ToStringVector(value));
|
||||
} else if (key == "distribution") {
|
||||
(void)builder->SetDataDistributionFile(ToString(value));
|
||||
} else if (key == "labels_filename") {
|
||||
(void)builder->setLabelsFileName(ToString(value));
|
||||
} else if (key == "dataset_usage") {
|
||||
(void)builder->SetDatasetUsage(ToString(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
(void)builder->SetBatchSize(temp_batch_size_);
|
||||
(void)builder->SetDropRemainder(temp_drop_remainder_);
|
||||
|
||||
std::shared_ptr<StorageOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
num_rows_ = op->num_rows();
|
||||
num_classes_ = op->num_classes();
|
||||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
|
||||
if (!args["buffer_size"].is_none()) {
|
||||
|
|
|
@ -37,7 +37,6 @@ using DsOpPtr = std::shared_ptr<DatasetOp>;
|
|||
|
||||
// enum for the dataset operator names
|
||||
enum OpName {
|
||||
kStorage = 0,
|
||||
kShuffle,
|
||||
kMindrecord,
|
||||
kBatch,
|
||||
|
@ -105,8 +104,6 @@ class DEPipeline {
|
|||
|
||||
int GetRepeatCount() const;
|
||||
|
||||
Status ParseStorageOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
@ -181,9 +178,6 @@ class DEPipeline {
|
|||
|
||||
std::unique_ptr<DatasetIterator> iterator_;
|
||||
|
||||
// Validate required args passed to storage op.
|
||||
Status ValidateArgStorageOp(const py::dict &args);
|
||||
|
||||
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
|
||||
|
||||
int batch_size_;
|
||||
|
|
|
@ -826,7 +826,6 @@ PYBIND11_MODULE(_c_dataengine, m) {
|
|||
(void)py::class_<DatasetOp, std::shared_ptr<DatasetOp>>(m, "DatasetOp");
|
||||
|
||||
(void)py::enum_<OpName>(m, "OpName", py::arithmetic())
|
||||
.value("STORAGE", OpName::kStorage)
|
||||
.value("SHUFFLE", OpName::kShuffle)
|
||||
.value("BATCH", OpName::kBatch)
|
||||
.value("BUCKETBATCH", OpName::kBucketBatch)
|
||||
|
|
|
@ -39,7 +39,6 @@
|
|||
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "dataset/engine/datasetops/source/generator_op.h"
|
||||
#include "dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#include "dataset/engine/datasetops/source/storage_op.h"
|
||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "dataset/engine/datasetops/take_op.h"
|
||||
#include "dataset/engine/datasetops/zip_op.h"
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
#include "dataset/util/allocator.h"
|
||||
#include "dataset/core/global_context.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/engine/datasetops/source/storage_client.h"
|
||||
#include "dataset/engine/datasetops/source/tf_buffer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -26,37 +24,6 @@ namespace dataset {
|
|||
// Description: This is the main constructor that is used for making a buffer
|
||||
DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {}
|
||||
|
||||
// Name: CreateDataBuffer()
|
||||
// Description: A static factory method to create the appropriate type of derived class
|
||||
// buffer. Returns the base class reference for DataBuffer.
|
||||
Status DataBuffer::CreateDataBuffer(
|
||||
int32_t id, // In: The id for the new buffer
|
||||
std::shared_ptr<StorageClient> storage_client, // In: The storage client that is related to this buffer type
|
||||
std::unique_ptr<DataBuffer> *ptr) {
|
||||
std::unique_ptr<DataBuffer> new_data_buffer;
|
||||
try {
|
||||
DatasetType ds_type = storage_client->schema()->dataset_type();
|
||||
switch (ds_type) {
|
||||
case DatasetType::kTf: {
|
||||
// This type of buffer is for TF record data.
|
||||
// Allocate derived class version for a TF buffers
|
||||
new_data_buffer = std::make_unique<TFBuffer>(id, kDeBFlagNone, storage_client);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
std::string errMsg("Invalid buffer type");
|
||||
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||
}
|
||||
}
|
||||
} catch (std::bad_alloc &e) {
|
||||
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, e.what());
|
||||
} catch (std::exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED(e.what());
|
||||
}
|
||||
*ptr = std::move(new_data_buffer);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Name: print()
|
||||
// Description: A function that prints info about the DataBuffer (base class version)
|
||||
void DataBuffer::Print(std::ostream &out, // In: The output stream to print to
|
||||
|
|
|
@ -29,9 +29,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Forward declares
|
||||
class StorageClient;
|
||||
|
||||
// The DataBuffer class is a base class that will represent the data for n values based
|
||||
// on a unique row id for each row of data.
|
||||
// There can be different types of DataBuffers to abstract over how the data is stored
|
||||
|
@ -53,14 +50,6 @@ class DataBuffer {
|
|||
// Destructor
|
||||
virtual ~DataBuffer();
|
||||
|
||||
// Name: CreateDataBuffer()
|
||||
// Description: A factory method to create the appropriate type of derived class
|
||||
// buffer. Returns the base class reference for DataBuffer.
|
||||
static Status CreateDataBuffer(
|
||||
int32_t id, // In: The id for the new buffer
|
||||
std::shared_ptr<StorageClient>, // In: The StorageClient is used to choose the buffer type to create
|
||||
std::unique_ptr<DataBuffer> *);
|
||||
|
||||
// Name: print()
|
||||
// Description: A function that prints info about the DataBuffer (base class version)
|
||||
virtual void Print(std::ostream &out, // In: The output stream to print to
|
||||
|
|
|
@ -53,7 +53,7 @@ class IteratorBase {
|
|||
// messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
|
||||
// @return Status - The error code return
|
||||
// @note The position of a Tensor/column might be different from the initial column order
|
||||
// in the storageOp. User must be aware that MapOp, ZipOps, and others might change
|
||||
// in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change
|
||||
// the column ordering.
|
||||
virtual Status FetchNextTensorRow(TensorRow *out_row);
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ class ConcatOp : public PipelineOp {
|
|||
~Builder() = default;
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @return shared_ptr to the new StorageOp object
|
||||
// @return shared_ptr to the new ConcatOp object
|
||||
Status Build(std::shared_ptr<ConcatOp> *);
|
||||
|
||||
private:
|
||||
|
|
|
@ -40,7 +40,7 @@ class ProjectOp : public PipelineOp {
|
|||
~Builder() = default;
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @return shared_ptr to the new StorageOp object.
|
||||
// @return shared_ptr to the new ProjectOp object.
|
||||
Status Build(std::shared_ptr<ProjectOp> *);
|
||||
|
||||
private:
|
||||
|
|
|
@ -67,7 +67,7 @@ class RenameOp : public PipelineOp {
|
|||
}
|
||||
|
||||
// The builder "build" method creates the ZipOp dataset Operator.
|
||||
// @return shared_ptr to the new StorageOp object
|
||||
// @return shared_ptr to the new RenameOp object
|
||||
Status Build(std::shared_ptr<RenameOp> *);
|
||||
|
||||
private:
|
||||
|
|
|
@ -42,7 +42,7 @@ class RepeatOp : public PipelineOp {
|
|||
~Builder() = default;
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @return shared_ptr to the new StorageOp object
|
||||
// @return shared_ptr to the new RepeatOp object
|
||||
Status Build(std::shared_ptr<RepeatOp> *);
|
||||
|
||||
private:
|
||||
|
|
|
@ -101,7 +101,7 @@ class ShuffleOp : public PipelineOp {
|
|||
}
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @return shared_ptr to the new StorageOp object
|
||||
// @return shared_ptr to the new ShuffleOp object
|
||||
Status Build(std::shared_ptr<ShuffleOp> *);
|
||||
|
||||
private:
|
||||
|
|
|
@ -37,7 +37,7 @@ class SkipOp : public PipelineOp {
|
|||
~Builder() = default;
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @return shared_ptr to the new StorageOp object
|
||||
// @return shared_ptr to the new SkipOp object
|
||||
Status Build(std::shared_ptr<SkipOp> *);
|
||||
|
||||
private:
|
||||
|
|
|
@ -5,10 +5,6 @@ add_library(engine-datasetops-source OBJECT
|
|||
generator_op.cc
|
||||
io_block.cc
|
||||
mindrecord_op.cc
|
||||
storage_client.cc
|
||||
storage_op.cc
|
||||
tf_buffer.cc
|
||||
tf_client.cc
|
||||
tf_reader_op.cc
|
||||
image_folder_op.cc
|
||||
mnist_op.cc
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
GeneratorOp::Builder::Builder() {
|
||||
// Some arguments to the StorageOp constructor have a default argument that is taken
|
||||
// Some arguments to the GeneratorOp constructor have a default argument that is taken
|
||||
// from the client config.
|
||||
build_buffer_size_ = kCfgRowsPerBuffer;
|
||||
build_op_connector_size_ = kCfgOpConnectorSize;
|
||||
|
|
|
@ -72,7 +72,7 @@ class GeneratorOp : public PipelineOp {
|
|||
}
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @return shared_ptr to the new StorageOp object
|
||||
// @return shared_ptr to the new GeneratorOp object
|
||||
Status Build(std::shared_ptr<GeneratorOp> *);
|
||||
|
||||
private:
|
||||
|
|
|
@ -198,7 +198,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param show_all
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
// This function is a hack! It is to return the num_class and num_rows the old storageOp does. The result
|
||||
// This function is a hack! It is to return the num_class and num_rows. The result
|
||||
// returned by this function may not be consistent with what image_folder_op is going to return
|
||||
// user this at your own risk!
|
||||
static Status CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
|
||||
|
|
|
@ -44,7 +44,7 @@ using mindrecord::ShardReader;
|
|||
MindRecordOp::Builder::Builder() : build_dataset_file_({}) {
|
||||
// Some arguments to the MindRecordOp constructor have a default argument that is taken
|
||||
// from the client config.
|
||||
// The user may choose to change these values for the construction of the StorageOp by
|
||||
// The user may choose to change these values for the construction of the MindRecordOp by
|
||||
// using the various builder set methods.
|
||||
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
|
|
|
@ -45,7 +45,7 @@ class PythonSampler : public Sampler {
|
|||
Status ResetSampler() override;
|
||||
|
||||
// Op calls this to get next Buffer that contains all the sampleIds
|
||||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
|
||||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op
|
||||
// @param int32_t workerId - not meant to be used
|
||||
// @return - The error code return
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
|
|
@ -38,7 +38,7 @@ class RandomAccessOp {
|
|||
// @return - The error code return
|
||||
Status GetNumRowsInDataset(int64_t *num_rows) const;
|
||||
|
||||
// sampler gets label , imageIds from storageOp, this function is unique to PK
|
||||
// sampler gets label , imageIds from corresponding Dataset Op, this function is unique to PK
|
||||
// @param std::map<int64_t, std::vector<int64_t>> * map
|
||||
// @return - The error code return
|
||||
virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const {
|
||||
|
|
|
@ -44,7 +44,7 @@ class SequentialSampler : public Sampler {
|
|||
Status ResetSampler() override;
|
||||
|
||||
// Op calls this to get next Buffer that contains all the sampleIds
|
||||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
|
||||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op
|
||||
// @param int32_t workerId - not meant to be used
|
||||
// @return - The error code return
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
|
|
@ -1,190 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#define MAX_INTEGER_INT32 2147483647
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include "dataset/core/constants.h"
|
||||
#include "dataset/engine/datasetops/source/storage_client.h"
|
||||
#include "dataset/engine/datasetops/source/storage_op.h"
|
||||
#include "dataset/engine/datasetops/source/tf_client.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Name: Constructor
|
||||
// Description:
|
||||
StorageClient::StorageClient(std::unique_ptr<DataSchema> schema, // In: The schema for this storage client.
|
||||
StorageOp *store_op) // In: The StorageOp that's using this client
|
||||
: data_schema_(std::move(schema)), num_rows_in_dataset_(0), storage_op_(store_op), num_classes_(0) {}
|
||||
|
||||
// Name: Print()
|
||||
// Description: A function that prints info about the StorageClient
|
||||
// In: The output stream to print to
|
||||
void StorageClient::Print(std::ostream &out) const {
|
||||
// not much to show here folks!
|
||||
// out << "Storage client:\n";
|
||||
}
|
||||
|
||||
// This is a local-only static function to drive the switch statement for creating
|
||||
// the storage client (not a static member function)
|
||||
static Status CreateStorageClientSwitch(
|
||||
std::unique_ptr<DataSchema> schema, // In: The schema to set into the client
|
||||
StorageOp *store_op, // In: The StorageOp we are operating on
|
||||
std::shared_ptr<StorageClient> *out_client) { // Out: the created storage client
|
||||
switch (schema->dataset_type()) {
|
||||
case DatasetType::kArrow: {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||
"Storage client not implemented yet for arrow dataset type.");
|
||||
}
|
||||
case DatasetType::kTf: {
|
||||
// Construct the derived class TFClient, stored as base class StorageClient
|
||||
store_op->set_rows_per_buffer(32);
|
||||
*out_client = std::make_unique<TFClient>(std::move(schema), store_op);
|
||||
break;
|
||||
}
|
||||
case DatasetType::kUnknown:
|
||||
default: {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid dataset type.");
|
||||
}
|
||||
}
|
||||
if (*out_client) {
|
||||
RETURN_IF_NOT_OK((*out_client)->Init());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Name: CreateStorageClient()
|
||||
// Description: A factory method to create the derived storage client.
|
||||
// Every dataset has a required field for the dataset type in a config
|
||||
// file. This type will determine the child class to return for the
|
||||
// type of storage client. It also creates the schema and sticks it
|
||||
// into the cache.
|
||||
Status StorageClient::CreateStorageClient(
|
||||
StorageOp *store_op, // In: A backpointer to the owning cache for this client.
|
||||
std::string dataset_schema_path, // In: The path to the schema
|
||||
std::shared_ptr<StorageClient> *out_client) { // Out: the created storage client
|
||||
// Make a new schema first. This only assigns the dataset type. It does not
|
||||
// create the columns yet.
|
||||
auto new_schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(new_schema->LoadDatasetType(dataset_schema_path));
|
||||
RETURN_IF_NOT_OK(CreateStorageClientSwitch(std::move(new_schema), store_op, out_client));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Name: CreateStorageClient()
|
||||
// Description: A factory method to create the derived storage client.
|
||||
// This creator is a user-override for the schema properties where
|
||||
// the user has input the layout of the data (typically used in testcases)
|
||||
Status StorageClient::CreateStorageClient(
|
||||
StorageOp *store_op, // In: A backpointer to the owning cache for this client.
|
||||
DatasetType in_type, // In: The type of dataset
|
||||
std::shared_ptr<StorageClient> *out_client) { // Out: the created storage client
|
||||
// The dataset type is passed in by the user. Create an empty schema with only
|
||||
// only the dataset type filled in and then create the client with it.
|
||||
auto new_schema = std::make_unique<DataSchema>();
|
||||
new_schema->set_dataset_type(in_type);
|
||||
RETURN_IF_NOT_OK(CreateStorageClientSwitch(std::move(new_schema), store_op, out_client));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Name: LoadDatasetLayout()
|
||||
// Description: There are 2 ways to define the properties of the data in the storage
|
||||
// layer: LoadDatasetLayout() and AssignDatasetLayout().
|
||||
// LoadDatasetLayout() will parse the json config file that comes with
|
||||
// the dataset.
|
||||
Status StorageClient::LoadDatasetLayout() {
|
||||
// Access the json file to populate our schema, assume the json file is accessible
|
||||
// locally.
|
||||
RETURN_IF_NOT_OK(data_schema_->LoadSchemaFile(storage_op_->schema_file(), storage_op_->columns_to_load()));
|
||||
|
||||
// The number of rows in the schema file is an optional config. For example,
|
||||
// maybe the derived storage client will know how to determine the total number
|
||||
// of rows a different way rather than having it in the schema config json file.
|
||||
// Thus, mNumRowsInDataset can still be zero and force the derived class override
|
||||
// to determine it another way.
|
||||
uint32_t num_rows = 0;
|
||||
RETURN_IF_NOT_OK(this->numRowsFromFile(num_rows));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows <= MAX_INTEGER_INT32, "numRows exceeds the boundary numRows>2147483647");
|
||||
if (num_rows_in_dataset_ == 0 || num_rows < num_rows_in_dataset_) {
|
||||
num_rows_in_dataset_ = num_rows;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Name: AssignDatasetLayout()
|
||||
// Description: There are 2 ways to define the properties of the data in the storage
|
||||
// layer: LoadDatasetLayout() and AssignDatasetLayout().
|
||||
// AssignDatasetLayout() will take input from the caller and assign that
|
||||
// info into the storage client.
|
||||
Status StorageClient::AssignDatasetLayout(uint32_t num_rows, // In: The number of rows in the dataset
|
||||
const DataSchema &schema) { // In: The schema for the dataset
|
||||
// Since this is just an assignment into the storage client, you probably won't need
|
||||
// to override this one in a derived class. First some sanity checks
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(data_schema_->dataset_type() == schema.dataset_type(),
|
||||
"Assigning a schema into StorageClient with mismatched dataset types!");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(data_schema_->NumColumns() == 0,
|
||||
"Assigning a schema into StorageClient that already has non-empty schema!");
|
||||
|
||||
// The current schema was just an empty one with only the dataset field populated.
|
||||
// Let's copy construct a new one that will be a copy of the input schema (releasing the old
|
||||
// one) and then set the number of rows that the user requested.
|
||||
data_schema_ = std::make_unique<DataSchema>(schema);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows <= MAX_INTEGER_INT32, "numRows exceeds the boundary numRows>2147483647");
|
||||
num_rows_in_dataset_ = num_rows;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Name: numRowsFromFile()
|
||||
// Description: Reads the schema json file to see if the optional numRows field has
|
||||
// been set and returns it.
|
||||
Status StorageClient::numRowsFromFile(uint32_t &num_rows) const {
|
||||
std::string schemaFile = storage_op_->schema_file();
|
||||
try {
|
||||
std::ifstream in(schemaFile);
|
||||
nlohmann::json js;
|
||||
in >> js;
|
||||
if (js.find("numRows") == js.end()) {
|
||||
num_rows = MAX_INTEGER_INT32;
|
||||
} else {
|
||||
num_rows = js.value("numRows", 0);
|
||||
}
|
||||
if (num_rows == 0) {
|
||||
std::string err_msg =
|
||||
"Storage client has not properly done dataset "
|
||||
"handshake to initialize schema and number of rows.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
// Catch any exception and rethrow it as our own
|
||||
catch (const std::exception &err) {
|
||||
std::ostringstream ss;
|
||||
ss << "Schema file failed to load:\n" << err.what();
|
||||
std::string err_msg = ss.str();
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get'r function
|
||||
DataSchema *StorageClient::schema() const { return data_schema_.get(); }
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -1,128 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_STORAGE_CLIENT_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_SOURCE_STORAGE_CLIENT_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "dataset/engine/data_schema.h"
|
||||
#include "dataset/engine/datasetops/source/storage_op.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// The Storage Client is the interface and base class that the StorageOp
|
||||
// will use to perform any interactions with the storage layer.
|
||||
// The different types of datasets will have different derived classes
|
||||
// under that storage client super class.
|
||||
class StorageClient {
|
||||
public:
|
||||
// Name: Constructor
|
||||
// Description:
|
||||
StorageClient(std::unique_ptr<DataSchema> schema, // In: The schema for this storage client.
|
||||
StorageOp *store_op); // In: The StorageOp that's using this client
|
||||
|
||||
// Destructor
|
||||
virtual ~StorageClient() { storage_op_ = nullptr; }
|
||||
|
||||
virtual Status Init() { return Status::OK(); }
|
||||
|
||||
// Name: CreateStorageClient()
|
||||
// Description: A factory method to create the derived storage client.
|
||||
// Every dataset has a required field for the dataset type in a config
|
||||
// file. This type will determine the child class to return for the
|
||||
// type of storage client.
|
||||
static Status CreateStorageClient(StorageOp *store_op, // In: A backpointer to the owning storage op for this client.
|
||||
std::string dataset_schema_path, // In: The path to the dataset
|
||||
std::shared_ptr<StorageClient> *out_client); // Out: the created storage client
|
||||
|
||||
// Name: CreateStorageClient()
|
||||
// Description: A factory method to create the derived storage client.
|
||||
// This creator is a user-override for the schema properties where
|
||||
// the user has input the layout of the data (typically used in testcases)
|
||||
static Status CreateStorageClient(StorageOp *store_op, // In: A backpointer to the owning cache for this client.
|
||||
DatasetType in_type, // In: The type of dataset
|
||||
std::shared_ptr<StorageClient> *out_client); // Out: the created storage client
|
||||
|
||||
// Name: Print()
|
||||
// Description: A function that prints info about the StorageClient
|
||||
virtual void Print(std::ostream &out) const; // In: The output stream to print to
|
||||
|
||||
// Provide stream operator for displaying
|
||||
friend std::ostream &operator<<(std::ostream &out, const StorageClient &storage_client) {
|
||||
storage_client.Print(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
// Name: LoadDatasetLayout()
|
||||
// Description: There are 2 ways to define the properties of the data in the storage
|
||||
// layer: LoadDatasetLayout() and AssignDatasetLayout().
|
||||
// LoadDatasetLayout() will parse the json config file that comes with
|
||||
// the dataset and internally populate row counts and schema.
|
||||
virtual Status LoadDatasetLayout();
|
||||
|
||||
// Name: AssignDatasetLayout()
|
||||
// Description: There are 2 ways to define the properties of the data in the storage
|
||||
// layer: LoadDatasetLayout() and AssignDatasetLayout().
|
||||
// AssignDatasetLayout() will take input from the caller and assign that
|
||||
virtual Status AssignDatasetLayout(uint32_t num_rows, // In: The number of rows in the dataset
|
||||
const DataSchema &schema); // In: The schema for the dataset
|
||||
|
||||
// Name: Reset()
|
||||
// Description: Resets any state info inside the client back to it's initialized
|
||||
// state.
|
||||
virtual Status Reset() = 0;
|
||||
|
||||
// Name: IsMoreData
|
||||
// Description: General routine to ask if more data exists in the storage side for
|
||||
// a given buffer id.
|
||||
virtual bool IsMoreData(uint32_t id) { return true; }
|
||||
|
||||
// Name: numRowsFromFile()
|
||||
// Description: Reads the schema json file to see if the optional numRows field has
|
||||
// been set and returns it.
|
||||
Status numRowsFromFile(uint32_t &num_rows) const;
|
||||
|
||||
// Get'r functions
|
||||
DataSchema *schema() const;
|
||||
|
||||
uint32_t num_rows() const { return num_rows_in_dataset_; }
|
||||
|
||||
// Name: rows_per_buffer()
|
||||
// Description: This default version simply gives you the count of the requested
|
||||
// rows per buffer that the user defined in the storage op.
|
||||
// However, if some condition down in the storage client layers
|
||||
// could result in a buffer that has a different number of rows,
|
||||
// then the derived class can override this method to provide their
|
||||
// own implementation.
|
||||
virtual uint32_t rows_per_buffer() { return storage_op_->rows_per_buffer(); }
|
||||
|
||||
// Description: Get the label classes num. Only manifest and Imagenet dataset support this parameter
|
||||
virtual uint32_t num_classes() const { return 0; }
|
||||
|
||||
protected:
|
||||
std::unique_ptr<DataSchema> data_schema_; // The schema for the data
|
||||
uint32_t num_rows_in_dataset_; // The number of rows in the dataset
|
||||
StorageOp *storage_op_; // Back pointer to the owning storage operator.
|
||||
std::vector<std::string> col_names_;
|
||||
uint32_t num_classes_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_STORAGE_CLIENT_H_
|
|
@ -1,607 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#define MAX_INTEGER_UINT32 4294967295
|
||||
#define MAX_INTEGER_INT32 2147483647
|
||||
|
||||
#include "dataset/engine/datasetops/source/storage_client.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "common/utils.h"
|
||||
#include "dataset/core/config_manager.h"
|
||||
#include "dataset/core/constants.h"
|
||||
#include "dataset/core/global_context.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/datasetops/dataset_op.h"
|
||||
#include "dataset/engine/datasetops/parallel_op.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/data_schema.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/util/queue.h"
|
||||
#include "dataset/engine/datasetops/source/storage_op.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Builder constructor. Creates the builder object.
|
||||
StorageOp::Builder::Builder()
|
||||
: build_dataset_files_dir_(""),
|
||||
build_schema_file_(""),
|
||||
build_num_rows_(0),
|
||||
build_data_distribution_file_(""),
|
||||
build_batch_size_(1),
|
||||
build_drop_remainder_(false) {
|
||||
// Some arguments to the StorageOp constructor have a default argument that is taken
|
||||
// from the client config.
|
||||
// The user may choose to change these values for the construction of the StorageOp by
|
||||
// using the various builder set methods.
|
||||
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
build_rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
build_worker_connector_size_ = cfg->worker_connector_size();
|
||||
build_num_workers_ = cfg->num_parallel_workers();
|
||||
build_op_connector_size_ = cfg->op_connector_size();
|
||||
}
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
Status StorageOp::Builder::Build(std::shared_ptr<StorageOp> *ptr) {
|
||||
// There are 2 "flavours" of construction for a StorageOp:
|
||||
//
|
||||
// 1) Does a handshake with the dataset to identify row ranges and to identify
|
||||
// the schema (internally the handshake does lookup against a json file in the dataset)
|
||||
//
|
||||
// 2) The user manually creates a schema and defines the row ranges, so there is no real
|
||||
// dataset handshake.
|
||||
//
|
||||
// The decision about which style is called will depend on if the user supplied the
|
||||
// schema and row range fields.
|
||||
|
||||
const std::string dataset_schema_file("datasetSchema.json");
|
||||
if (build_schema_ != nullptr && build_num_rows_ == 0) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||
"Building a StorageOp with a given schema, but the number of rows not specified!");
|
||||
}
|
||||
if (build_schema_ == nullptr && build_num_rows_ != 0) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||
"Building a StorageOp with a given number of rows but schema not specified!");
|
||||
}
|
||||
if (build_dataset_files_dir_.empty() && build_dataset_file_list_.empty()) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||
"Building a StorageOp that has not provided the location of the data files.");
|
||||
}
|
||||
if (!build_dataset_files_dir_.empty() && !build_dataset_file_list_.empty()) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||
"Building a StorageOp that has provided conflicting location of the data files.");
|
||||
}
|
||||
|
||||
std::shared_ptr<StorageOp> new_storage_op = std::make_shared<StorageOp>(
|
||||
build_num_workers_, build_worker_connector_size_, build_rows_per_buffer_, build_op_connector_size_,
|
||||
build_columns_to_load_, build_data_distribution_file_, build_batch_size_, build_drop_remainder_);
|
||||
|
||||
// If there is no schema or number of rows given, then we go with construction method 1
|
||||
// where we need to handshake with storage client to find out what the schema (and
|
||||
// number of rows) are based on schema file.
|
||||
if (build_schema_ == nullptr && build_num_rows_ == 0) {
|
||||
if (!build_dataset_files_dir_.empty()) {
|
||||
// We have a dataset files dir, but do not have a schema file.
|
||||
// Set the default schema file to be inside the same path as the dataset files dir.
|
||||
if (build_schema_file_.empty()) {
|
||||
build_schema_file_ = build_dataset_files_dir_ + "/" + dataset_schema_file;
|
||||
}
|
||||
RETURN_IF_NOT_OK(new_storage_op->InitOp(build_dataset_files_dir_, build_schema_file_, build_labels_file_name_,
|
||||
build_dataset_usage_));
|
||||
} else {
|
||||
// dataset is provided by list of files not dir_path
|
||||
RETURN_IF_NOT_OK(new_storage_op->InitOp(build_dataset_file_list_, build_schema_file_));
|
||||
}
|
||||
} else {
|
||||
// else, the user gave us a schema and a row range, go with construction method 2, where we use
|
||||
// the user-provided schema, but we still need to identify our data files.
|
||||
RETURN_IF_NOT_OK(new_storage_op->InitOp(build_num_rows_, build_dataset_files_dir_, std::move(build_schema_),
|
||||
build_labels_file_name_, build_dataset_usage_));
|
||||
}
|
||||
|
||||
// Call the actual workhorse of the constructor
|
||||
RETURN_IF_NOT_OK(new_storage_op->init());
|
||||
*ptr = std::move(new_storage_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StorageOp::StorageOp(int32_t num_workers, int32_t worker_connector_size, int32_t rows_per_buffer,
|
||||
int32_t op_connector_size, std::vector<std::string> columns_to_load,
|
||||
std::string data_distribution_file, int32_t batch_size, bool drop_remainder)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
worker_conn_size_(worker_connector_size),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
num_rows_(0),
|
||||
buffers_fetched_(0),
|
||||
columns_to_load_(columns_to_load),
|
||||
data_distribution_file_(data_distribution_file),
|
||||
device_num_(1),
|
||||
device_id_(0),
|
||||
shard_config_("ALL"),
|
||||
seed_(0),
|
||||
shuffle_config_(false),
|
||||
num_classes_(0),
|
||||
batch_size_(batch_size),
|
||||
drop_remainder_(drop_remainder) {}
|
||||
|
||||
// Init of the StorageOp. This is 1 of 3 init.
|
||||
// This version of the init does not take the schema in it's arguments. It must perform an
|
||||
// internal handshake with the dataset to produce the schema.
|
||||
Status StorageOp::InitOp(const std::string &dataset_files_dir, const std::string &schema_file,
|
||||
const std::string &labels_file_name, const std::string &dataset_usage) {
|
||||
dataset_files_dir_ = dataset_files_dir;
|
||||
schema_file_ = schema_file;
|
||||
labels_file_name_ = labels_file_name;
|
||||
dataset_usage_ = dataset_usage;
|
||||
|
||||
// Storage ops require the internal master/worker connector. create it here
|
||||
RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_conn_size_));
|
||||
|
||||
// Get parameter for distribution.
|
||||
RETURN_IF_NOT_OK(LoadParallelConfig());
|
||||
|
||||
// Create the storage client. This will read the json file to determine what
|
||||
// type of client we're creating.
|
||||
RETURN_IF_NOT_OK(StorageClient::CreateStorageClient(this, schema_file_, &store_client_));
|
||||
|
||||
// Perform the initial handshake with the storage client to further read the
|
||||
// dataset info to populate schema info and the number of rows in the client.
|
||||
RETURN_IF_NOT_OK(store_client_->LoadDatasetLayout());
|
||||
|
||||
// Pull out the number of rows from the client and save into the op.
|
||||
num_rows_ = store_client_->num_rows();
|
||||
num_classes_ = store_client_->num_classes();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Init of the StorageOp. This is 2 of 3 init.
|
||||
// This version of the init allows the user to input the schema and other dataset properties rather
|
||||
// than get it from the dataset itself.
|
||||
Status StorageOp::InitOp(int32_t num_rows, const std::string &dataset_files_dir,
|
||||
std::unique_ptr<DataSchema> data_schema, const std::string &labels_file_name,
|
||||
const std::string &dataset_usage) {
|
||||
num_rows_ = num_rows;
|
||||
dataset_files_dir_ = dataset_files_dir;
|
||||
labels_file_name_ = labels_file_name;
|
||||
dataset_usage_ = dataset_usage;
|
||||
|
||||
// Storage ops require the internal master/worker connector. create it here
|
||||
RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_conn_size_));
|
||||
|
||||
// Get parameter for distribution.
|
||||
RETURN_IF_NOT_OK(LoadParallelConfig());
|
||||
|
||||
// Create the storage client based on the dataset type given from the input schema.
|
||||
RETURN_IF_NOT_OK(StorageClient::CreateStorageClient(this, data_schema->dataset_type(), &store_client_));
|
||||
|
||||
// Perform the initial handshake with the storage client to initialize the schema
|
||||
// and the number of rows in the set. In this case, since the schema and the number
|
||||
// of rows is input by the user directly, it's not much of a "handshake", it's more
|
||||
// like an assign.
|
||||
RETURN_IF_NOT_OK(store_client_->AssignDatasetLayout(num_rows_, *data_schema));
|
||||
num_classes_ = store_client_->num_classes();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Init of the StorageOp. This is 3 of 3 init.
|
||||
// This version of the init does not take the schema in it's arguments. It must perform an
|
||||
// internal handshake with the dataset to produce the schema. Unlike constructor 1, it takes a
|
||||
// list of files rather than a directory.
|
||||
Status StorageOp::InitOp(const std::vector<std::string> &files_list, const std::string &schema_file) {
|
||||
dataset_file_list_ = files_list;
|
||||
schema_file_ = schema_file;
|
||||
|
||||
// Storage ops require the internal master/worker connector. create it here
|
||||
RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_conn_size_));
|
||||
|
||||
// Get parameter for distribution.
|
||||
RETURN_IF_NOT_OK(LoadParallelConfig());
|
||||
|
||||
// Create the storage client. This will read the json file to determine what
|
||||
// type of client we're creating.
|
||||
RETURN_IF_NOT_OK(StorageClient::CreateStorageClient(this, schema_file_, &store_client_));
|
||||
|
||||
// Perform the initial handshake with the storage client to further read the
|
||||
// dataset info to populate schema info and the number of rows in the client.
|
||||
RETURN_IF_NOT_OK(store_client_->LoadDatasetLayout());
|
||||
|
||||
// Pull out the number of rows from the client and save into the op.
|
||||
num_rows_ = store_client_->num_rows();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Private helper method. This one encapsulates some common construction/reset tasks and is
|
||||
// designed to be re-entrant so that you can re-init a previously used StorageOp without needing
|
||||
// to redo the storage client handshake.
|
||||
Status StorageOp::init() {
|
||||
// First a sanity check to make sure the StorageClient initialization has done the proper
|
||||
// handshake and initialized both the schema and the number of rows for the dataset.
|
||||
const DataSchema *the_schema = store_client_->schema();
|
||||
if (the_schema->NumColumns() == 0 || num_rows_ == 0) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||
"Storage client did not run handshake to init schema and number of rows.");
|
||||
}
|
||||
|
||||
// Now that we have schema, generate the column name map (base class field)
|
||||
for (int32_t i = 0; i < the_schema->NumColumns(); ++i) {
|
||||
column_name_id_map_[the_schema->column(i).name()] = i;
|
||||
}
|
||||
|
||||
// If the data buffer vector is not empty, then we may be redoing a scan again after a repeat.
|
||||
// In such a case, we have vector of nullptrs that used to hold the buffers. get rid of this
|
||||
// so we can reuse the vector.
|
||||
if (!data_buffers_.empty()) {
|
||||
data_buffers_.clear();
|
||||
}
|
||||
int32_t buffers_needed;
|
||||
|
||||
// We have our range of row id's, but we must carve this up into buffers now so that
|
||||
// each buffer holds a subset of the overall range.
|
||||
// Instantiate the buffers now, but this does not actually drive a load of actual
|
||||
// data at this point.
|
||||
|
||||
// First, compute how many buffers we would need to accomplish rowsPerBuffer
|
||||
buffers_needed = this->num_rows() / rows_per_buffer_;
|
||||
|
||||
// If an extra partial buffer is needed, adjust for that.
|
||||
if (this->num_rows() % rows_per_buffer_ != 0) {
|
||||
buffers_needed++;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Master: Initializing StorageOp. Dataset files dir: " << dataset_files_dir_ << " Dataset type: "
|
||||
<< static_cast<std::underlying_type<DatasetType>::type>(store_client_->schema()->dataset_type())
|
||||
<< " Dataset schema file: " << schema_file_ << " Number of rows: " << num_rows_
|
||||
<< " Rows per buffer: " << rows_per_buffer_ << " Num buffers (computed): " << buffers_needed
|
||||
<< " Number of workers: " << num_workers_ << ".";
|
||||
|
||||
// Next, create each buffer in a loop.
|
||||
int32_t buff_id = 0;
|
||||
for (buff_id = 0; buff_id < buffers_needed; buff_id++) {
|
||||
// Create a new data buffer as a base class pointer, using the factory method from
|
||||
// DataBuffer class
|
||||
std::unique_ptr<DataBuffer> new_data_buffer;
|
||||
RETURN_IF_NOT_OK(DataBuffer::CreateDataBuffer(buff_id, store_client_, &new_data_buffer));
|
||||
|
||||
// Insert the buffer into our vector
|
||||
data_buffers_.push_back(std::move(new_data_buffer));
|
||||
}
|
||||
|
||||
// Instantiate the action queues. If this was a re-entrant call then these already exist.
|
||||
// We cannot drop and recreate them because there are threads waiting on them currently.
|
||||
// They should be empty anyway in a reset codepath
|
||||
if (action_queue_.empty()) {
|
||||
// The max size of these queues should ensure they will never get full and they support
|
||||
// precisely the amount of data that we know they will hold (the total number of buffers).
|
||||
// There needs to be one queue for each worker, to support the Connector design for how
|
||||
// data will be fetched and pushed into a Connector in parallel.
|
||||
//
|
||||
// Say the total buffers is 5, and we have 2 workers.
|
||||
// To support this, we'd need 1 queue of size 2 and the other of size 3.
|
||||
// For simplicity, we'll make both of them 3 so they are the same size.
|
||||
int32_t action_queue_size = (buffers_needed / num_workers_) + 1;
|
||||
for (int32_t i = 0; i < num_workers_; ++i) {
|
||||
auto new_queue = std::make_unique<Queue<int32_t>>(action_queue_size);
|
||||
action_queue_.push_back(std::move(new_queue));
|
||||
}
|
||||
}
|
||||
|
||||
// Extract the list of buffer id's from the vector and use this as our starting action
|
||||
// queue of buffers.
|
||||
RETURN_IF_NOT_OK(this->FillActionQueue(false));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Destructor
|
||||
StorageOp::~StorageOp() {}
|
||||
|
||||
// A print method typically used for debugging
|
||||
void StorageOp::Print(std::ostream &out, bool show_all) const {
|
||||
// Always show the id and name as first line regardless if this summary or detailed print
|
||||
out << "(" << std::setw(2) << operator_id_ << ") <StorageOp>:";
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nDetailed operator printing has not been implemented for this op.\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Private helper method. This one posts a control indicator for each worker thread to consume
|
||||
// from the action queue. When the worker pops this msg, it will shut itself down gracefully.
|
||||
Status StorageOp::PostEndOfData() {
|
||||
MS_LOG(DEBUG) << "Master: Processed all of the buffers. Send end-of-data message to workers.";
|
||||
|
||||
// For each worker we add the message so that they can all get the memo
|
||||
for (int32_t i = 0; i < num_workers_; ++i) {
|
||||
RETURN_IF_NOT_OK(action_queue_[i]->Add(kEndOfActions));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Private helper method. This one populates the action queue with the list of buffer ids.
|
||||
Status StorageOp::FillActionQueue(bool randomize) {
|
||||
// We only support adding the new list of id's to the queue if we are sure the old list
|
||||
// of actions is already done. This might change in the future though
|
||||
for (int32_t i = 0; i < num_workers_; ++i) {
|
||||
if (!(action_queue_[i]->empty())) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||
"Attempt to get buffer id's into a queue, but the queue not empty!");
|
||||
}
|
||||
}
|
||||
if (!data_buffers_.empty()) {
|
||||
// Add buffer id's to the queue. Buffer id's in our vector are just numbers from 0 up, so
|
||||
// basically just a list of consecutive numbers starting from 0 (incremented by 1).
|
||||
// If randomize is requested, the list of id's will be jumbled up (so not consecutive
|
||||
// order)
|
||||
if (!randomize) {
|
||||
// Round robin of filling each worker with the buffer id's
|
||||
int32_t curr_worker = 0;
|
||||
for (int32_t i = 0; i < data_buffers_.size(); ++i) {
|
||||
RETURN_IF_NOT_OK(action_queue_[curr_worker]->Add(i));
|
||||
curr_worker++;
|
||||
if (curr_worker == num_workers_) {
|
||||
curr_worker = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::vector<int32_t> random_ids;
|
||||
int32_t i;
|
||||
for (i = 0; i < data_buffers_.size(); ++i) {
|
||||
random_ids.push_back(i);
|
||||
}
|
||||
uint32_t seed = std::chrono::system_clock::now().time_since_epoch().count();
|
||||
std::shuffle(random_ids.begin(), random_ids.end(), std::default_random_engine(seed));
|
||||
|
||||
// Round robin of filling each worker with the buffer id's from randomized list
|
||||
int32_t curr_worker = 0;
|
||||
for (i = 0; i < random_ids.size(); ++i) {
|
||||
RETURN_IF_NOT_OK(action_queue_[curr_worker]->Add(random_ids[i]));
|
||||
curr_worker++;
|
||||
if (curr_worker == num_workers_) {
|
||||
curr_worker = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The entry point code for when workers are launched.
|
||||
// Given the input bufferId, it returns a shared_ptr to that buffer back to you by driving a
|
||||
// load operation. This function is intended to be run by worker threads, when they are
|
||||
// populating the memory with the actual data of the buffer.
|
||||
Status StorageOp::GetBuffer(int32_t buffer_id, std::unique_ptr<DataBuffer> *ptr) {
|
||||
if (!data_buffers_.empty()) {
|
||||
if (static_cast<size_t>(buffer_id) >= data_buffers_.size()) {
|
||||
std::ostringstream ss;
|
||||
ss << "Error. Buffer id " << buffer_id << " is out of range.";
|
||||
std::string err_msg = ss.str();
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
// execute a load operation to fill this buffer (may result in call to storage layers)
|
||||
RETURN_IF_NOT_OK(data_buffers_[buffer_id]->Load());
|
||||
|
||||
// Return the buffer
|
||||
// Important: The share pointer remains counted for the caller as well as locally in the
|
||||
// mDataBuffers array. Later when the buffer is sent on it's way up the pipeline, the
|
||||
// shared_ptr in the array will be reset so that the StorageOp will not hang on to old
|
||||
// buffers that it has already passed up the pipeline.
|
||||
*ptr = std::move(data_buffers_[buffer_id]);
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Requested to get a buffer from an empty cache.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Class functor operator () override.
|
||||
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
Status StorageOp::operator()() {
|
||||
// Before we enter our master loop, kick off our workers and assign them to
|
||||
// use the StorageOp worker entry code.
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&StorageOp::WorkerEntry, this, std::placeholders::_1)));
|
||||
// Handshake with TaskManager to synchronize thread creation
|
||||
TaskManager::FindMe()->Post();
|
||||
int32_t num_buffers_to_fetch = data_buffers_.size();
|
||||
|
||||
// The storage op is the bottom node in the tree, so it does not listen to an input
|
||||
// queue from an operator below us. Instead, we'll will read from the internal queue
|
||||
// that our workers produce into, and then push that into output queue.
|
||||
bool done = false;
|
||||
std::unique_ptr<DataBuffer> fetched_buffer;
|
||||
while (!done) {
|
||||
// Get the next buffer. We are single thread master so thread id hard coded to 0
|
||||
// on the connector pop. Count this buffer towards our count, and then push
|
||||
// it up to the output connector.
|
||||
RETURN_IF_NOT_OK(worker_connector_->PopWithRetry(0, &fetched_buffer));
|
||||
buffers_fetched_++;
|
||||
    int32_t buffer_id = fetched_buffer->id();
    if (buffers_fetched_ == 1) {
      num_buffers_to_fetch = static_cast<int32_t>(data_buffers_.size());
    }

    // There should be 2 holders of this buffer currently. We have one in the mDataBuffers
    // table, and then ourselves right now with fetchedBuffer.
    // Reduce the shared_ptr ref count of this buffer by removing it from the mDataBuffers
    // table first before we push the buffer to output connector.
    data_buffers_[buffer_id].reset();
    MS_LOG(DEBUG) << "StorageOp master: Consumed buffer " << buffer_id << " from internal worker connector.";
    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
    MS_LOG(DEBUG) << "StorageOp master: pushed buffer " << buffer_id << " to output connector.";

    // Now, check our loop exit conditions and perform appropriate end of data handling if
    // we've reached the end of our scan.
    if (buffers_fetched_ == num_buffers_to_fetch) {
      MS_LOG(DEBUG) << "StorageOp master: Reached end of data.";

      // If we are not inside of a Repeat path in the tree, or we are in a repeat path but
      // this was our last repeat, then we do a full quit here with eof control message.
      if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
        // Post the control message to tell the workers to stop waiting on action queue
        // because we are done!
        RETURN_IF_NOT_OK(this->PostEndOfData());
        std::unique_ptr<DataBuffer> eoeBuffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoeBuffer)));
        MS_LOG(DEBUG) << "StorageOp master: Flow end-of-data eof message.";
        std::unique_ptr<DataBuffer> eofBuffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eofBuffer)));
        MS_LOG(DEBUG) << "StorageOp master: Main execution loop complete.";
        done = true;  // while loop exit
      } else {
        // We are in a repeat path and it's not the last repeat.
        // Flow an end-of-epoch control message up the pipeline.
        // RepeatOp above us somewhere in the tree will re-init us with the data to fetch again
        // once it gets the end-of-epoch message.
        MS_LOG(DEBUG) << "StorageOp master: Flow end-of-epoch eoe message.";
        std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));

        // Reset our buffer count and go to loop again.
        buffers_fetched_ = 0;

        // This is a bit of a cheat. Only the repeat op should perform resetting actions
        // against us (currently). However, if we go to block/wait on the worker_connector_
        // right now before the reset is done (driven from the repeat op), then we end
        // up using stale connector index info and blocking on the wrong thing, causing
        // invalid order during the next epoch.
        // For now then, do a quick reset of just the connector queue so that we block
        // at a safe starting point in the connector.
        worker_connector_->Reset();
      }
    }
  }
  return Status::OK();
}

// The entry point code for when workers are launched.
Status StorageOp::WorkerEntry(int32_t worker_id) {
  int32_t next_action_id = 0;
  MS_LOG(DEBUG) << "Worker: StorageOp worker entry point.";

  // Handshake with TaskManager to synchronize the creation
  TaskManager::FindMe()->Post();

  // While there is still some actions to perform
  RETURN_IF_NOT_OK(action_queue_[worker_id]->PopFront(&next_action_id));
  while (next_action_id != kEndOfActions) {
    // Drive a load of this buffer and get a pointer to the buffer after it's loaded in
    std::unique_ptr<DataBuffer> dB;
    RETURN_IF_NOT_OK(this->GetBuffer(next_action_id, &dB));
    MS_LOG(DEBUG) << "Worker: Loaded buffer " << next_action_id << ".";

    // Add the buffer to the internal queue for master to consume from later.
    // This could end up blocking if the queue is full in which case it waits here
    // until the master can drain a buffer off the queue.
    RETURN_IF_NOT_OK(worker_connector_->Add(worker_id, std::move(dB)));
    MS_LOG(DEBUG) << "Worker: Pushed buffer " << next_action_id << " to internal worker connector.";

    // Get the next action id and loop
    RETURN_IF_NOT_OK(action_queue_[worker_id]->PopFront(&next_action_id));
  }
  MS_LOG(DEBUG) << "Worker: Received end-of-data message. Worker complete.";
  return Status::OK();
}

const DataSchema *StorageOp::schema() const { return store_client_->schema(); }

// Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from its previous execution and then initializes itself so that it can be executed
// again.
Status StorageOp::Reset() {
  RETURN_IF_NOT_OK(ParallelOp::Reset());  // Call our super class reset first.

  // We do not need to redo the handshake with the storage client, since that
  // info should be the same as the last time. However there may be stale
  // state info in the client from the last execution. The client provides
  // a reset method as well to re-initialize.
  RETURN_IF_NOT_OK(store_client_->Reset());

  // init method is re-entrant and will refresh everything.
  RETURN_IF_NOT_OK(this->init());
  return Status::OK();
}

// Name: LoadParallelConfig
// Description: Load parallel config info from a specific config file. In multi-P cases (or single-P cases), we
//              need to know deviceID, rank, device number, shard mode, shuffle (or not) and seed to prepare to
//              scatter files.
Status StorageOp::LoadParallelConfig() {
  if (data_distribution_file_ == "") {
    return Status::OK();
  }
  try {
    std::ifstream in(data_distribution_file_);
    nlohmann::json js;
    in >> js;
    device_num_ = js.value("deviceNum", 0);
    device_id_ = js.value("deviceId", 0);
    if (device_num_ == 0 || device_num_ > MAX_INTEGER_INT32) {
      RETURN_STATUS_UNEXPECTED("Invalid deviceNum");
    }
    if (device_id_ > MAX_INTEGER_INT32 || device_id_ >= device_num_) {
      MS_LOG(DEBUG) << "In parallel config file " << data_distribution_file_ << ", wrong deviceID provided.";
      RETURN_STATUS_UNEXPECTED("Invalid deviceId");
    }
    shard_config_ = js.value("shardConfig", "");
    if (shard_config_ != "ALL" && shard_config_ != "UNIQUE" && shard_config_ != "RANDOM") {
      MS_LOG(DEBUG) << "In parallel config file " << data_distribution_file_ << " wrong mShardConfig provided.";
      RETURN_STATUS_UNEXPECTED("Invalid shardConfig");
    }
    std::string shuffle_str = js.value("shuffle", "");
    if (shuffle_str == "ON") {
      shuffle_config_ = true;
    } else if (shuffle_str == "OFF") {
      shuffle_config_ = false;
    } else {
      MS_LOG(DEBUG) << "In parallel config file " << data_distribution_file_
                    << ", shuffle config is wrong: it's not ON or OFF";
      RETURN_STATUS_UNEXPECTED("Invalid shuffle option");
    }
    seed_ = js.value("seed", 0);
    if (seed_ > MAX_INTEGER_UINT32) {
      RETURN_STATUS_UNEXPECTED("Invalid seed");
    }
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED("Load parallel config failed");
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
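For reference, LoadParallelConfig above reads a small JSON distribution file with the keys deviceNum, deviceId, shardConfig, shuffle and seed. A minimal sketch of producing such a file with nlohmann::json is shown below; the file name and the concrete values are illustrative assumptions, not taken from the removed sources.

// Sketch only: writes a distribution config with the keys that
// StorageOp::LoadParallelConfig reads. File name and values are assumptions.
#include <fstream>
#include <nlohmann/json.hpp>

int main() {
  nlohmann::json js;
  js["deviceNum"] = 8;           // total number of devices
  js["deviceId"] = 0;            // id of this device, must be < deviceNum
  js["shardConfig"] = "UNIQUE";  // one of ALL, UNIQUE, RANDOM
  js["shuffle"] = "ON";          // ON or OFF
  js["seed"] = 42;               // shuffle seed
  std::ofstream("distribution.json") << js.dump(2);
  return 0;
}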
@ -1,389 +0,0 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_STORAGE_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_STORAGE_OP_H_

#include <condition_variable>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "dataset/engine/data_schema.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/util/status.h"

namespace mindspore {
namespace dataset {
// Forward declares
template <typename T>
class Queue;

// A type for a container of DataBuffer shared_ptr's
using DataBuffers = std::vector<std::unique_ptr<DataBuffer>>;

// A type for the queue of buffer id's for workers to fetch.
using ActionQueue = std::vector<std::unique_ptr<Queue<int32_t>>>;

// Forward declare
class DataBuffer;

class StorageClient;

class StorageOp : public ParallelOp {
 public:
  // The nested builder class inside of the StorageOp is used to help manage all of the arguments
  // for constructing it. Use the builder by setting each argument with the provided set methods,
  // and then finally call the build method to execute the actual construction.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetNumRows(int num_rows) {
      build_num_rows_ = num_rows;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetRowsPerBuffer(int rows_per_buffer) {
      build_rows_per_buffer_ = rows_per_buffer;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetSchema(std::unique_ptr<DataSchema> schema) {
      build_schema_ = std::move(schema);
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      build_num_workers_ = num_workers;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetWorkerConnectorSize(int32_t connector_size) {
      build_worker_connector_size_ = connector_size;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t connector_size) {
      build_op_connector_size_ = connector_size;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetSchemaDir(const std::string &schema_dir) {
      build_schema_file_ = schema_dir + "/datasetSchema.json";
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetSchemaFile(const std::string &schema_file) {
      build_schema_file_ = schema_file;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetDatasetFilesDir(const std::string &files_dir) {
      build_dataset_files_dir_ = files_dir;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetDatasetFileList(const std::vector<std::string> &file_list) {
      build_dataset_file_list_ = file_list;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetColumnsToLoad(const std::vector<std::string> &columns) {
      build_columns_to_load_ = columns;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetDataDistributionFile(const std::string &data_distribution_file) {
      build_data_distribution_file_ = data_distribution_file;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &setLabelsFileName(const std::string &labels_file_name) {
      build_labels_file_name_ = labels_file_name;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetDatasetUsage(const std::string &dataset_usage) {
      build_dataset_usage_ = dataset_usage;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetBatchSize(int32_t batch_size) {
      build_batch_size_ = batch_size;
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &SetDropRemainder(bool drop_remainder) {
      build_drop_remainder_ = drop_remainder;
      return *this;
    }

    // The builder "build" method creates the final object.
    // @param shared_ptr to the new StorageOp object
    // @return Status - The error code return
    Status Build(std::shared_ptr<StorageOp> *);

   private:
    // The builder saves all StorageOp construction arguments internally.
    // The following are the arguments.
    std::string build_dataset_files_dir_;
    std::string build_schema_file_;
    int32_t build_num_rows_;
    std::string build_data_distribution_file_;
    int32_t build_rows_per_buffer_;
    int32_t build_worker_connector_size_;
    int32_t build_num_workers_;
    int32_t build_op_connector_size_;
    std::unique_ptr<DataSchema> build_schema_;
    std::vector<std::string> build_dataset_file_list_;
    std::vector<std::string> build_columns_to_load_;
    std::string build_labels_file_name_;
    std::string build_dataset_usage_;
    int32_t build_batch_size_;
    bool build_drop_remainder_;
  };

  // Constructor of the StorageOp.
  // @note The builder class should be used to call it
  // @param num_workers - The number of workers for the op
  // @param worker_connector_size - The internal connector size between workers and master
  // @param rows_per_buffer - The requested number of rows per buffer
  // @param op_connector_size - The output connector queue size
  // @param columns_to_load - The list of columns to use (column name)
  StorageOp(int32_t num_workers, int32_t worker_connector_size, int32_t rows_per_buffer, int32_t op_connector_size,
            std::vector<std::string> columns_to_load, std::string data_distribution_file, int32_t batch_size,
            bool drop_remainder);

  // Init the StorageOp. This is 1 of 3 init.
  // This version of the init does not take the schema in its arguments. It must perform an
  // internal handshake with the dataset to produce the schema.
  // @note The builder class should be used to call it
  // @param dataset_files_dir - The directory that has the dataset files
  // @param schema_file - The schema file for providing column info
  Status InitOp(const std::string &dataset_files_dir, const std::string &schema_file,
                const std::string &labels_file_name, const std::string &dataset_usage);

  // Init the StorageOp. This is 2 of 3 init.
  // This version of the init allows the user to input the schema and other dataset properties rather
  // than get it from the dataset itself.
  // @note The builder class should be used to call it
  // @param num_rows - The number of rows in the dataset
  // @param dataset_files_dir - The directory that has the dataset files
  // @param data_schema - The schema to use
  Status InitOp(int32_t num_rows, const std::string &dataset_files_dir, std::unique_ptr<DataSchema> data_schema,
                const std::string &labels_file_name, const std::string &dataset_usage);

  // Init the StorageOp. This is 3 of 3 init.
  // This version of the init does not take the schema in its arguments. It must perform an
  // internal handshake with the dataset to produce the schema. Unlike constructor 1, it takes a
  // list of files rather than a directory.
  // @note The builder class should be used to call it
  // @param files_list - The list of files to use for the dataset
  // @param schema_file - The schema file for providing column info
  Status InitOp(const std::vector<std::string> &files_list, const std::string &schema_file);

  // Destructor
  ~StorageOp();

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // << Stream output operator overload
  // @notes This allows you to write the debug print info using stream operators
  // @param out - reference to the output stream being overloaded
  // @param storage_op - reference to the StorageOp to display
  // @return - the output stream must be returned
  friend std::ostream &operator<<(std::ostream &out, const StorageOp &storage_op) {
    storage_op.Print(out, false);
    return out;
  }

  // Class functor operator () override.
  // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work.
  // @return Status - The error code return
  Status operator()() override;

  // The entry point code for when workers are launched.
  // @param worker_id - The worker id
  // @return Status - The error code return
  Status WorkerEntry(int32_t worker_id) override;

  // The entry point code for when workers are launched.
  // Given the input bufferId, it returns a shared_ptr to that buffer back to you by driving a
  // load operation. This function is intended to be run by worker threads, when they are
  // populating the memory with the actual data of the buffer.
  // @param buffer_id - The buffer id to get.
  // @param ptr - Pointer to shared_ptr to the buffer that was loaded in.
  // @return Status - The error code return
  Status GetBuffer(int32_t buffer_id, std::unique_ptr<DataBuffer> *ptr);

  // Overrides base class reset method. When an operator does a reset, it cleans up any state
  // info from its previous execution and then initializes itself so that it can be executed
  // again.
  // @return Status - The error code return
  Status Reset() override;

  // Getter method
  int32_t num_rows() const { return num_rows_; }

  // Setter method
  void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; }

  // Getter method
  int32_t rows_per_buffer() const { return rows_per_buffer_; }

  // Setter method
  void set_rows_per_buffer(int32_t rows_per_buffer) { rows_per_buffer_ = rows_per_buffer; }

  // Getter method
  std::string dataset_files_dir() const { return dataset_files_dir_; }

  // Getter method
  std::vector<std::string> dataset_file_list() const { return dataset_file_list_; }

  // Getter method
  std::string schema_file() const { return schema_file_; }

  // Getter method
  const DataSchema *schema() const;

  // Getter method
  const std::vector<std::string> columns_to_load() const { return columns_to_load_; }

  // Getter method
  std::string data_distribution_file() const { return data_distribution_file_; }

  // Getter method
  int32_t device_num() const { return device_num_; }

  // Getter method
  int32_t device_id() const { return device_id_; }

  // Getter method
  std::string shard_config() const { return shard_config_; }

  // Getter method
  uint32_t seed() const { return seed_; }

  // Getter method
  bool shuffle_config() const { return shuffle_config_; }

  // Getter method
  int32_t num_classes() const { return num_classes_; }

  // Getter method
  std::string labels_file_name() const { return labels_file_name_; }

  // Getter method
  std::string dataset_usage() const { return dataset_usage_; }

  // Getter method
  int32_t batch_size() const { return batch_size_; }

  // Getter method
  bool drop_remainder() const { return drop_remainder_; }

 private:
  // Private helper method. This one populates the action queue with the list of buffer ids.
  // @param randomize - T/F if the id's in the action queue should be randomized or sequential.
  Status FillActionQueue(bool randomize);

  // Private helper method. This one encapsulates some common construction/reset tasks and is
  // designed to be re-entrant so that you can re-init a previously used StorageOp without needing
  // to redo the storage client handshake.
  // @return Status - The error code return
  Status init();

  // Private helper method. This one posts a control indicator for each worker thread to consume
  // from the action queue. When the worker pops this msg, it will shut itself down gracefully.
  // @return Status - The error code return
  Status PostEndOfData();

  Status LoadParallelConfig();

  DataBuffers data_buffers_;                     // A vector of pointers to buffers
  std::shared_ptr<StorageClient> store_client_;  // The client for interacting with storage
  ActionQueue action_queue_;                     // The queues of buffer id's for workers to fetch.
  int32_t worker_conn_size_;                     // connector size for internal worker queue
  int32_t rows_per_buffer_;                      // The number of requested rows per buffer.
  int32_t num_rows_;                             // One more than the last row id in the range for this cache
  std::string dataset_files_dir_;                // The path for the dataset files
  std::vector<std::string> dataset_file_list_;   // List of paths to files for the dataset
  int32_t buffers_fetched_;                      // Counter for the buffers that were fetched
  std::string schema_file_;                      // Path to the schema json file
  std::vector<std::string> columns_to_load_;     // Columns to load from dataset
  std::string data_distribution_file_;           // Distribution configuration file
  int32_t device_num_;                           // All device number
  int32_t device_id_;                            // Device id
  std::string shard_config_;                     // ALL UNIQUE RANDOM
  uint32_t seed_;                                // Used for shuffle
  bool shuffle_config_;                          // True or false
  std::string labels_file_name_;                 // File name of labels
  int32_t num_classes_;                          // Label class number
  std::string dataset_usage_;                    // train/eval/inference
  int32_t batch_size_;
  bool drop_remainder_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_STORAGE_OP_H_
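As the comment at the top of the nested Builder class notes, construction goes through the setter chain and a final Build call. A hedged usage sketch follows; the paths and sizes are illustrative assumptions and the snippet relies on the (now removed) header above rather than being standalone.

// Sketch only: exercises the removed StorageOp::Builder interface.
// Paths and numeric values are assumptions, not taken from the original code.
std::shared_ptr<StorageOp> storage_op;
Status rc = StorageOp::Builder()
              .SetDatasetFilesDir("/data/tfrecords")      // directory of .data/.tfrecord files (hypothetical)
              .SetSchemaFile("/data/datasetSchema.json")  // column schema file (hypothetical)
              .SetNumWorkers(4)                           // parallel worker threads
              .SetRowsPerBuffer(16)                       // rows per DataBuffer
              .SetOpConnectorSize(32)                     // output connector depth
              .SetBatchSize(32)
              .SetDropRemainder(false)
              .Build(&storage_op);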
@ -1,326 +0,0 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/datasetops/source/tf_buffer.h"
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <utility>

#include "common/utils.h"
#include "utils/log_adapter.h"

#include "dataset/engine/datasetops/source/tf_client.h"
#include "dataset/core/data_type.h"
#include "dataset/engine/datasetops/source/storage_client.h"
#include "dataset/engine/data_schema.h"

namespace mindspore {
namespace dataset {
// constructor
TFBuffer::TFBuffer(
  uint32_t id,                                            // In: The id for this buffer
  BufferFlags flags,                                      // In: The flags for this buffer
  const std::shared_ptr<StorageClient> &storage_client)   // In: Storage client that is related to this buffer type
    : DataBuffer(id, flags), storage_client_(storage_client) {}

// destructor
TFBuffer::~TFBuffer() {}

// Name: print()
// Description: A function that prints info
void TFBuffer::Print(std::ostream &out,      // In: The output stream to print to
                     bool show_all) const {  // In: T/F if it should print everything
  out << "TFBuffer print\n";

  // Call base class printer
  DataBuffer::Print(out, show_all);
}

// Name: load()
// Description: populates the DataBuffer with data
// Overrides base-class method.
Status TFBuffer::Load() {
  const DataSchema *the_schema = storage_client_->schema();
  uint32_t num_columns = the_schema->NumColumns();
  uint32_t num_rows_requested = storage_client_->rows_per_buffer();
  uint32_t remaining_rows = storage_client_->num_rows() > buffer_id_ * storage_client_->rows_per_buffer()
                              ? storage_client_->num_rows() - buffer_id_ * storage_client_->rows_per_buffer()
                              : 0;
  if (remaining_rows < num_rows_requested) {
    num_rows_requested = remaining_rows;
  }

  // Construct the Tensor table for this buffer.
  tensor_table_ = std::make_unique<TensorQTable>();

  // At each position in the tensor table, instantiate the shared pointer to its Tensor.
  uint32_t row = 0;
  while (row < num_rows_requested && (cur_reader_.peek() != EOF || storage_client_->IsMoreData(buffer_id_))) {
    TensorRow new_row;

    // Read the data from storage into a tf_file format
    dataengine::Example tf_file;
    RETURN_IF_NOT_OK(ParseSingleExample(&tf_file));
    for (uint32_t col = 0; col < num_columns; ++col) {
      std::shared_ptr<Tensor> new_t;
      const ColDescriptor current_col = the_schema->column(col);
      const dataengine::Features &example_features = tf_file.features();
      const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
      const dataengine::Feature &column_values_list = feature_map.at(current_col.name());
      const dataengine::Feature::KindCase column_list_type = column_values_list.kind_case();
      RETURN_IF_NOT_OK(LoadFeature(column_list_type, column_values_list, current_col, &new_t));

      // Add the column to the current tensor row
      new_row.push_back(std::move(new_t));
    }

    // Add the new row of tensors to the end of our tensor table
    tensor_table_->push_back(new_row);
    row++;
  }
  cur_reader_.close();
  return Status::OK();
}

// Name: ParseSingleExample()
// Description: Drives the calls to TFClient for fetching the tf_file info from
//              the tf_file files. Returns a single row of data from the tf_file
//              files.
Status TFBuffer::ParseSingleExample(dataengine::Example *ptr) {
  if (cur_reader_.peek() == EOF) {
    auto client = std::dynamic_pointer_cast<TFClient>(storage_client_);
    if (client == nullptr) {
      std::string errMsg = "Unexpected storage client type for TFBuffer";
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
    RETURN_IF_NOT_OK(client->NextFileInfo(buffer_id_, &cur_f_info_));
    cur_reader_.close();
    cur_reader_.open(cur_f_info_.fileName);
    // Seek to the offset
    (void)cur_reader_.seekg(static_cast<std::streamsize>(cur_f_info_.startOffset));
    MS_LOG(DEBUG) << "got new file " << cur_f_info_.fileName << ".";
  }

  // one record in tf_file looks like:
  // Format of a single record:
  //   uint64    length
  //   uint32    masked crc of length
  //   byte      data[length]
  //   uint32    masked crc of data
  // read length
  if (cur_reader_.peek() == EOF) {
    MS_LOG(ERROR) << "ParseSingleExample failed";
  }

  dataengine::Example tf_file;
  try {
    uint64_t record_length = 0;
    (void)cur_reader_.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(uint64_t)));

    // ignore crc header
    (void)cur_reader_.ignore(static_cast<std::streamsize>(sizeof(uint32_t)));

    // read serialized Example
    std::string serialized_example;
    serialized_example.resize(record_length);
    (void)cur_reader_.read(&serialized_example[0], static_cast<std::streamsize>(record_length));

    // ignore crc footer
    (void)cur_reader_.ignore(static_cast<std::streamsize>(sizeof(uint32_t)));

    if (!tf_file.ParseFromString(serialized_example)) {
      std::string err_msg = "parse tf_file failed";
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
  } catch (const std::exception &err) {
    std::string err_msg = "Please check if the data file is complete!";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  *ptr = tf_file;
  return Status::OK();
}

// Name: LoadFeature()
// Description: Given the column type of the tf record and the values list,
//              constructs the tensor and returns it.
Status TFBuffer::LoadFeature(const dataengine::Feature::KindCase &column_list_type,
                             const dataengine::Feature &column_values_list, const ColDescriptor &current_col,
                             std::shared_ptr<Tensor> *out_tensor) {
  std::string element_str;                  // For staging data from protobuf deserialization
  std::unique_ptr<int64_t[]> int_array;     // For staging data from protobuf deserialization
  std::unique_ptr<float[]> float_array;     // For staging data from protobuf deserialization
  const unsigned char *data_ptr = nullptr;  // Generic pointer used for populating the Tensor
                                            // This variable will point into the above staging
                                            // variables.
  uint32_t num_elements = 0;                // Generic counter used for setting shape attributes

  // Depending on the type of data from the tf_file, we want to extract 2 things:
  // 1) A pointer to the data as a const unsigned char *
  // 2) The number of elements of the data
  // After those are determined, we can then build the tensor to represent this data.

  switch (column_list_type) {
    // CASE : TF record type: kBytesList
    case dataengine::Feature::KindCase::kBytesList: {
      RETURN_IF_NOT_OK(LoadBytesList(current_col, column_values_list, &element_str));

      // Get the const pointer representation of this data, and the number of elements
      // (number of bytes) for this tensor.
      data_ptr = reinterpret_cast<const unsigned char *>(common::SafeCStr(element_str));
      num_elements = element_str.length();
      break;
    }

    // CASE : TF record type: kFloatList
    case dataengine::Feature::KindCase::kFloatList: {
      RETURN_IF_NOT_OK(LoadFloatList(current_col, column_values_list, &num_elements, &float_array));

      data_ptr = reinterpret_cast<const unsigned char *>(float_array.get());
      break;
    }

    // CASE : TF record type: kInt64List
    case dataengine::Feature::KindCase::kInt64List: {
      RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, &num_elements, &int_array));

      data_ptr = reinterpret_cast<const unsigned char *>(int_array.get());
      break;
    }
    case dataengine::Feature::KindCase::KIND_NOT_SET: {
      std::string errMsg = "tf_file column list type enum is KIND_NOT_SET";
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
    default: {
      std::string errMsg = "tf_file column list type enum does not match any known DE type";
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
  }

  // At this point we have a raw pointer to the data, and we have the number of elements.
  // Along with the tensor implementation type and the data type from the schema, we have
  // enough info to construct the Tensor for it.
  TensorShape current_shape = TensorShape::CreateUnknownRankShape();
  RETURN_IF_NOT_OK(CreateTensorShapeForColumn(current_col, num_elements, &current_shape));

  // Now, create this tensor directly into the appropriate slot in our tensor
  // table.
  RETURN_IF_NOT_OK(
    Tensor::CreateTensor(out_tensor, current_col.tensorImpl(), current_shape, current_col.type(), data_ptr));

  return Status::OK();
}

Status TFBuffer::LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                               std::string *element_str) {
  // kBytesList can map to the following DE types ONLY!
  // DE_UINT8, DE_INT8
  // Must be single byte type for each element!
  if (current_col.type() != DataType::DE_UINT8 && current_col.type() != DataType::DE_INT8) {
    std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  const dataengine::BytesList &bytes_list = column_values_list.bytes_list();

  // A bytesList is a special case where the entire list of data can be
  // deserialized into a single string. For example, it is not a list
  // of bytes, it is a list of strings, where each string represents
  // a list of bytes (this is different from the other cases like IntList etc)
  // As such, if there is more than one string in this list, that is invalid.
  if (bytes_list.value_size() > 1) {
    std::string err_msg = "Bytes list contains more than one element for column: " + current_col.name();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  // Extract the string that contains the bytes we need. Position 0 is the only
  // valid string here.
  *element_str = bytes_list.value(0);

  return Status::OK();
}

Status TFBuffer::LoadFloatList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                               uint32_t *num_elements, std::unique_ptr<float[]> *float_array) {
  // KFloatList can only map to DE types:
  // DE_FLOAT32
  if (current_col.type() != DataType::DE_FLOAT32) {
    std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  const dataengine::FloatList &float_list = column_values_list.float_list();

  // Identify how many values we have and then create a local array of these
  // to deserialize into
  *num_elements = float_list.value_size();
  *float_array = std::make_unique<float[]>(*num_elements);
  for (int i = 0; i < float_list.value_size(); i++) {
    (*float_array)[i] = float_list.value(i);
  }

  return Status::OK();
}

Status TFBuffer::LoadIntList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                             uint32_t *num_elements, std::unique_ptr<int64_t[]> *int_array) {
  // KInt64List can only map to DE types:
  // DE_UINT64, DE_INT64, DE_UINT32, DE_INT32, DE_UINT16, DE_INT16, DE_UINT8, DE_INT8
  if (!(current_col.type().IsInt())) {
    std::string err_msg = "Invalid datatype/rank for column label in TFBuffer.";
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  const dataengine::Int64List &int64_list = column_values_list.int64_list();

  // Identify how many values we have and then create a local array of these
  // to deserialize into
  *num_elements = int64_list.value_size();
  *int_array = std::make_unique<int64_t[]>(*num_elements);
  for (int i = 0; i < int64_list.value_size(); i++) {
    (*int_array)[i] = int64_list.value(i);
  }

  return Status::OK();
}

Status TFBuffer::CreateTensorShapeForColumn(const ColDescriptor &current_col, uint32_t num_elements,
                                            TensorShape *current_shape) {
  // If the shape is assigned by user, we have an assumption that the data is
  // already in the appropriate format that we can copy into the Tensor as-is.
  if (current_col.hasShape()) {
    *current_shape = current_col.shape();
  } else if (current_col.rank() == 1) {
    // If shape was not given, then we support 2 possible shapes.
    // 1) It's a scalar (rank 0), in which case the shape is empty but we need to flag
    //    it as a scalar value (empty shape but has a single value)
    // 2) It's a rank 1 shape, and the dimension value for that single dimension will
    //    be comprised of the entire bytes-size of the input data.
    *current_shape = TensorShape({num_elements});
  } else if (current_col.rank() == 0) {
    // Make this shape into a single value scalar.
    *current_shape = TensorShape::CreateScalar();
  } else if (current_col.rank() > 1) {
    // All other ranks are invalid because we cannot guess
    // what the shape will be. For example, if we have rank 3 and 12 bytes
    // of data, is it shape {2,2,3} or is it {2,6,1}? We can't guess at
    // the shape dimensions.
    const std::string kErrMsg = "Invalid rank (rank>1) for dynamic shape construction. Specify shape in schema.";
    RETURN_STATUS_UNEXPECTED(kErrMsg);
  }

  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
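Both Load and ParseSingleExample above depend on the TFRecord framing spelled out in the comments (uint64 length, masked crc of the length, the serialized bytes, masked crc of the data). A self-contained hedged sketch of walking a file record by record under that layout follows; the file name is an assumption and, as in the original, the crc words are skipped rather than verified.

// Sketch only: iterates the TFRecord framing described above and reports each
// record's payload size. CRC words are skipped, not validated.
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>

int main() {
  std::ifstream reader("example.tfrecord", std::ios::binary);  // hypothetical file
  while (reader.peek() != EOF) {
    uint64_t record_length = 0;
    reader.read(reinterpret_cast<char *>(&record_length), sizeof(record_length));  // uint64 length
    reader.ignore(sizeof(uint32_t));                                               // masked crc of length
    std::string payload(record_length, '\0');
    reader.read(&payload[0], static_cast<std::streamsize>(record_length));         // serialized Example bytes
    reader.ignore(sizeof(uint32_t));                                               // masked crc of data
    std::cout << "record of " << record_length << " bytes\n";
  }
  return 0;
}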
@ -1,91 +0,0 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TF_BUFFER_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_TF_BUFFER_H_

#include <fstream>
#include <memory>
#include <string>
#include <vector>
#include "dataset/engine/data_buffer.h"
#include "proto/example.pb.h"
#include "dataset/engine/datasetops/source/tf_client.h"

namespace mindspore {
namespace dataset {
// This TFBuffer is the buffer type for dealing with tf record data.
class TFBuffer : public DataBuffer {
 public:
  // constructor
  TFBuffer(uint32_t id,                    // In: The id for this buffer
           DataBuffer::BufferFlags flags,  // In: The flags for this buffer
           const std::shared_ptr<StorageClient>
             &storage_client);             // In: The storage client that is related to this buffer type

  // destructor
  ~TFBuffer() override;

  // Name: print()
  // Description: A function that prints info
  void Print(std::ostream &out,              // In: The output stream to print to
             bool show_all) const override;  // In: T/F if it should print everything

  // Provide stream operator for displaying it
  friend std::ostream &operator<<(std::ostream &out, const TFBuffer &tf_buffer) {
    tf_buffer.Print(out, false);  // Show meta info only
    return out;
  }

  // Name: load()
  // Description: populates the DataBuffer with data.
  // Overrides base-class method.
  Status Load() override;

 private:
  std::ifstream cur_reader_;
  FileInfo cur_f_info_;

  std::shared_ptr<StorageClient> storage_client_;  // The storage client for populating the buffer initially.

  // Name: ParseSingleExample()
  // Description: Drives the calls to TFClient for fetching the tf_file info from
  //              the tf_file files. Returns a single row of data from the tf_file
  //              files.
  Status ParseSingleExample(dataengine::Example *ptr);

  // Name: LoadFeature()
  // Description: Given the column type of the tf record and the values list,
  //              constructs the tensor and returns it.
  Status LoadFeature(const dataengine::Feature::KindCase &column_list_type,
                     const dataengine::Feature &column_values_list, const ColDescriptor &current_col,
                     std::shared_ptr<Tensor> *out_tensor);

  Status LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                       std::string *element_str);

  Status LoadFloatList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                       uint32_t *num_elements, std::unique_ptr<float[]> *float_array);

  Status LoadIntList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                     uint32_t *num_elements, std::unique_ptr<int64_t[]> *int_array);

  Status CreateTensorShapeForColumn(const ColDescriptor &current_col, uint32_t num_elements,
                                    TensorShape *current_shape);
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_TF_BUFFER_H_
@ -1,376 +0,0 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dataset/engine/datasetops/source/tf_client.h"

#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <limits>
#include <algorithm>

#include "common/utils.h"
#include "proto/example.pb.h"
#include "dataset/engine/datasetops/source/storage_client.h"
#include "dataset/util/path.h"
#include "dataset/util/status.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
// Name: Constructor
// Description: Creates the TFClient.
TFClient::TFClient(std::unique_ptr<DataSchema> schema,  // In: The schema for this storage client.
                   StorageOp *so)                       // In: The StorageOp that's using this client
    : StorageClient(std::move(schema), so),
      rows_per_buffer_(so->rows_per_buffer()),
      random_seed_generator_(so->seed()),
      random_seed_distribution_(0, std::numeric_limits<uint32_t>::max()),
      rows_per_shard_(0) {}

Status TFClient::Init() {
  // Initialize queue to hold the tf file names
  const std::string kExtensionData = ".data";
  const std::string kExtensionTF = ".tfrecord";
  bool schema_init = false;
  if (!storage_op_->dataset_files_dir().empty()) {
    MS_LOG(DEBUG) << "Reading dataset using datasetPath.";
    Path data_set_directory(storage_op_->dataset_files_dir());
    auto dirIt = Path::DirIterator::OpenDirectory(&data_set_directory);
    if (dirIt) {
      while (dirIt->hasNext()) {
        Path file = dirIt->next();
        std::string filename = file.toString();
        if ((file.Extension() == kExtensionData) || (file.Extension() == kExtensionTF)) {
          const std::vector<uint64_t> recs_lengths = ParseTfFileLines(filename);
          v_total_file_rows_.emplace_back(
            std::pair<std::string, std::vector<uint64_t>>(filename, std::move(recs_lengths)));

          // schema
          if (!schema_init) {
            RETURN_IF_NOT_OK(ParseTfFileSchema(filename));
            schema_init = true;
          }
          MS_LOG(INFO) << "found tf file: " << filename << ", num rows " << recs_lengths.size() << ".";
        }
      }
    } else {
      RETURN_STATUS_UNEXPECTED("Unable to open directory " + data_set_directory.toString());
    }
  } else {
    MS_LOG(DEBUG) << "Reading dataset using dataset files list.";
    for (auto filename : storage_op_->dataset_file_list()) {
      const std::vector<uint64_t> recs_lengths = ParseTfFileLines(filename);
      v_total_file_rows_.emplace_back(std::pair<std::string, std::vector<uint64_t>>(filename, std::move(recs_lengths)));

      // schema
      if (!schema_init) {
        RETURN_IF_NOT_OK(ParseTfFileSchema(filename));
        schema_init = true;
      }
      MS_LOG(INFO) << "Processed tf file: " << filename << ", num rows " << recs_lengths.size() << ".";
    }
  }

  RETURN_IF_NOT_OK(CalculateRowsPerDevice());
  std::sort(v_total_file_rows_.begin(), v_total_file_rows_.end());
  RETURN_IF_NOT_OK(ScatterFileRows(static_cast<uint32_t>(storage_op_->device_id()), storage_op_->shard_config(),
                                   storage_op_->seed(), storage_op_->shuffle_config()));

  CalculateNumRows();
  InitStateInfo();
  return Status::OK();
}

// Sharding will reduce the number of rows. Doing this in constructor as we only want to do this once.
void TFClient::CalculateNumRows() {
  num_rows_in_dataset_ = 0;
  for (auto rows : file_start_end_offset_) {
    num_rows_in_dataset_ += (rows.second - rows.first);
  }
}

Status TFClient::CalculateRowsPerDevice() {
  uint64_t num = std::accumulate(
    v_total_file_rows_.begin(), v_total_file_rows_.end(), 0,
    [](uint64_t value, const std::pair<std::string, std::vector<uint64_t>> &a) { return value + a.second.size(); });
  if (static_cast<uint64_t>(std::floor(num * 1.0 / storage_op_->device_num())) == 0) {
    RETURN_STATUS_UNEXPECTED("Num rows of dataset is less than device number");
  }
  rows_per_shard_ = static_cast<uint64_t>(std::ceil(num * 1.0 / storage_op_->device_num()));
  return Status::OK();
}

bool TFClient::ValidFileForShard(const uint64_t file_rows, uint64_t *start_offset, uint64_t *end_offset,
                                 const uint64_t &pre_count, uint32_t device_id) const {
  *start_offset = 0;
  *end_offset = 0;
  bool valid = false;
  uint64_t start_index = device_id * rows_per_shard_;
  uint64_t end_index = (device_id + 1) * rows_per_shard_;

  // First valid file
  if (pre_count <= start_index && pre_count + file_rows > start_index) {
    *start_offset = start_index - pre_count;
    valid = true;
    if (pre_count < end_index && pre_count + file_rows >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = file_rows;
    }
  }

  // Second and subsequent files
  if (pre_count > start_index && pre_count < end_index) {
    *start_offset = 0;
    valid = true;
    if (pre_count + file_rows >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = file_rows;
    }
  }

  return valid;
}
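To make the shard-window arithmetic above concrete, here is a hedged, standalone sketch (not from the removed sources) that mirrors the same math for a hypothetical dataset of two files with 6 and 4 rows split across two devices:

// Illustrative only: shows which row range of each file lands on a given device
// under the ceil(total / device_num) sharding used above. Values are assumptions.
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<uint64_t> file_rows = {6, 4};  // rows per file (hypothetical)
  uint64_t total = 10, device_num = 2, device_id = 1;
  uint64_t rows_per_shard = static_cast<uint64_t>(std::ceil(total * 1.0 / device_num));  // 5
  uint64_t start_index = device_id * rows_per_shard;                                     // 5
  uint64_t end_index = (device_id + 1) * rows_per_shard;                                 // 10
  uint64_t pre_count = 0;
  for (uint64_t rows : file_rows) {
    // A file is selected when its global row range [pre_count, pre_count + rows)
    // overlaps the device's window [start_index, end_index).
    if (pre_count + rows > start_index && pre_count < end_index) {
      uint64_t start = pre_count <= start_index ? start_index - pre_count : 0;
      uint64_t end = pre_count + rows >= end_index ? end_index - pre_count : rows;
      std::cout << "rows [" << start << ", " << end << ") of this file go to device " << device_id << "\n";
    }
    pre_count += rows;
  }
  return 0;
}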

void TFClient::GetValidFileForShard(const std::vector<std::pair<std::string, std::vector<uint64_t>>> &v_files,
                                    uint32_t device_id) {
  uint64_t start_offset = 0;
  uint64_t end_offset = 0;
  uint64_t pre_count = 0;
  bool finish = false;
  while (!finish) {
    for (const auto &file : v_files) {
      if (ValidFileForShard(file.second.size(), &start_offset, &end_offset, pre_count, device_id)) {
        std::pair<uint32_t, uint32_t> offset(start_offset, end_offset);
        file_start_end_offset_.emplace_back(offset);
        v_file_rows_.emplace_back(file);
      }
      pre_count += file.second.size();
    }
    if (pre_count < (device_id + 1) * rows_per_shard_) {
      finish = false;
    } else {
      finish = true;
    }
  }
}

// Description: Scatter file rows to local single-P according to config info.
//              There are 3 modes: ALL, UNIQUE, RANDOM. For UNIQUE and RANDOM mode, shuffleConfig controls
//              whether the file row vector is shuffled or not before a new epoch.
//              For ALL mode, temporarily, we deal with epoch in python part.
Status TFClient::ScatterFileRows(uint32_t device_id, const std::string &shard_config, uint32_t seed,
                                 bool shuffle_config) {
  if (shard_config == "UNIQUE" || shard_config == "RANDOM") {
    std::vector<std::pair<std::string, std::vector<uint64_t>>> v_shuffled_total_file_rows =
      ShuffleVector(v_total_file_rows_, seed);
    GetValidFileForShard(v_shuffled_total_file_rows, device_id);
    if (shuffle_config) {
      v_total_file_rows_ = v_shuffled_total_file_rows;
    }
  } else if (shard_config == "ALL") {
    v_file_rows_.insert(v_file_rows_.end(), v_total_file_rows_.begin(), v_total_file_rows_.end());
    if (shuffle_config) {
      v_total_file_rows_ = ShuffleVector(v_total_file_rows_, seed);
    }

    for (const auto &file : v_file_rows_) {
      std::pair<uint32_t, uint32_t> offset(0, file.second.size());
      file_start_end_offset_.emplace_back(offset);
    }
  } else {
    RETURN_STATUS_UNEXPECTED("In parallel config file, wrong shuffleConfig or shardConfig provided.");
  }

  return Status::OK();
}

std::vector<std::pair<std::string, std::vector<uint64_t>>> TFClient::ShuffleVector(
  std::vector<std::pair<std::string, std::vector<uint64_t>>> v, uint32_t seed = 1) {
  std::default_random_engine randomEngine(seed);
  std::shuffle(std::begin(v), std::end(v), randomEngine);
  return v;
}

void TFClient::CalculateStartOffset(const uint64_t start_index, const uint64_t end_index,
                                    const std::vector<uint64_t> &vec_length, uint64_t *start_offset) const {
  for (size_t i = start_index; i < end_index; i++) {
    // Format of a single record:
    //   uint64    length
    //   uint32    masked crc of length
    //   byte      data[length]
    //   uint32    masked crc of data
    *start_offset += sizeof(uint64_t) + 2 * sizeof(uint32_t) + vec_length[i];
  }
}

void TFClient::InitStateInfo() {
  uint32_t start_idx = 0, record_num = 0, buffer_id = 0;
  uint64_t start_offset = 0;
  bool first_buffer = true;
  f_info_queue_.emplace_back(QFile());
  std::vector<std::pair<std::string, std::vector<uint64_t>>>::iterator itr = v_file_rows_.begin();
  uint32_t index = 0;
  while (itr != v_file_rows_.end()) {
    uint32_t file_start_index = file_start_end_offset_[index].first;
    uint32_t file_end_index = file_start_end_offset_[index].second;
    FileInfo f_info;
    f_info.fileName = itr->first;
    f_info.startRecordIdx = start_idx > file_start_index ? start_idx : file_start_index;
    if (first_buffer && f_info.startRecordIdx != 0) {
      CalculateStartOffset(0, f_info.startRecordIdx, itr->second, &start_offset);
      start_idx = static_cast<uint32_t>(f_info.startRecordIdx);
    }
    first_buffer = false;
    f_info.startOffset = start_offset;
    if (start_idx + rows_per_buffer_ - record_num < itr->second.size()) {
      uint64_t end_idx = start_idx + rows_per_buffer_ - record_num - 1;
      f_info.endRecordIdx = end_idx > (file_end_index - 1) ? (file_end_index - 1) : end_idx;
      f_info_queue_[buffer_id].push(f_info);
      CalculateStartOffset(start_idx, f_info.endRecordIdx + 1, itr->second, &start_offset);
      start_idx = start_idx + rows_per_buffer_ - record_num;
      record_num = 0;
      buffer_id++;
      f_info_queue_.emplace_back(QFile());
      if (end_idx >= file_end_index - 1) {
        start_idx = start_offset = 0;
        ++itr;
        ++index;
      }
    } else {
      f_info.endRecordIdx = itr->second.size() - 1 > file_end_index - 1 ? file_end_index - 1 : itr->second.size() - 1;
      f_info_queue_[buffer_id].push(f_info);
      if (start_idx + rows_per_buffer_ - record_num == itr->second.size()) {
        record_num = start_idx = start_offset = 0;
        buffer_id++;
        if (itr + 1 != v_file_rows_.end()) {
          f_info_queue_.emplace_back(QFile());
        }
      } else {
        record_num += static_cast<uint32_t>(itr->second.size()) - start_idx;
        start_idx = start_offset = 0;
      }
      ++itr;
      ++index;
    }
  }
}

// Name: Print()
// Description: A function that prints info about the TFClient
void TFClient::Print(std::ostream &out) const {  // In: The output stream to print to
  out << "TF client.";
}

std::vector<uint64_t> TFClient::ParseTfFileLines(const std::string &filename) {
  std::vector<uint64_t> recs_lengths;
  std::ifstream reader;
  reader.open(filename);
  while (true) {
    if (reader.peek() == EOF) {
      reader.close();
      break;
    }

    // read length
    uint64_t record_length = 0;
    (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(uint64_t)));
    recs_lengths.push_back(record_length);

    // ignore crc header
    (void)reader.ignore(static_cast<std::streamsize>(sizeof(uint32_t)));

    // ignore data length
    (void)reader.ignore(static_cast<std::streamsize>(record_length));

    // ignore crc footer
    (void)reader.ignore(static_cast<std::streamsize>(sizeof(uint32_t)));
  }
  return recs_lengths;
}

Status TFClient::ParseTfFileSchema(const std::string &filename) {
  std::ifstream reader;
  reader.open(filename);
  std::string serialized_example;
  // read length
  uint64_t record_length = 0;
  (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(uint64_t)));

  // ignore crc header
  (void)reader.ignore(static_cast<std::streamsize>(sizeof(uint32_t)));

  // read serialized Example
  serialized_example.resize(record_length);
  (void)reader.read(&serialized_example[0], static_cast<std::streamsize>(record_length));

  // ignore crc footer
  (void)reader.ignore(static_cast<std::streamsize>(sizeof(uint32_t)));

  reader.close();
  dataengine::Example tf_file;
  if (!tf_file.ParseFromString(serialized_example)) {
    std::string err_msg = "parse tf_file failed, file name is " + filename;
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  const dataengine::Features &example_features = tf_file.features();
  const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
  for (auto it = feature_map.begin(); it != feature_map.end(); ++it) {
    col_names_.push_back(it->first);
  }
  return Status::OK();
}

// Name: Reset()
// Description: Resets any state info inside the client back to its initialized
//              state.
Status TFClient::Reset() {
  v_file_rows_.clear();
  file_start_end_offset_.clear();

  uint32_t next_seed = random_seed_distribution_(random_seed_generator_);
  RETURN_IF_NOT_OK(ScatterFileRows(static_cast<uint32_t>(storage_op_->device_id()), storage_op_->shard_config(),
                                   next_seed, storage_op_->shuffle_config()));

  CalculateNumRows();
  uint32_t num_rows_in_file = 0;
  RETURN_IF_NOT_OK(this->numRowsFromFile(num_rows_in_file));
  if (num_rows_in_file < num_rows_in_dataset_) {
    num_rows_in_dataset_ = num_rows_in_file;
  }

  storage_op_->set_num_rows(static_cast<int32_t>(num_rows_in_dataset_));
  InitStateInfo();

  return Status::OK();
}

Status TFClient::NextFileInfo(uint32_t id, FileInfo *ptr) {
  if (f_info_queue_.empty() || id >= f_info_queue_.size() || f_info_queue_[id].empty()) {
    RETURN_STATUS_UNEXPECTED("cannot find next FileInfo in mFInfoQueue");
  }
  *ptr = f_info_queue_[id].front();
  f_info_queue_[id].pop();
  return Status::OK();
}

bool TFClient::IsMoreData(uint32_t id) { return (!f_info_queue_[id].empty()); }
}  // namespace dataset
}  // namespace mindspore
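One detail worth calling out in Reset() above: each epoch re-scatters the files with a fresh seed drawn from a distribution whose generator was seeded once at construction, so per-epoch shuffles differ while the whole sequence stays reproducible. A minimal hedged sketch of that seed-chaining pattern (not taken from the removed sources; the initial seed is an assumption):

// Sketch only: one generator seeded up front, a new shuffle seed drawn per epoch.
#include <iostream>
#include <limits>
#include <random>

int main() {
  std::default_random_engine generator(42);  // seeded once (the original uses the op's configured seed)
  std::uniform_int_distribution<uint32_t> distribution(0, std::numeric_limits<uint32_t>::max());

  // The per-epoch seeds depend only on the initial seed, so reruns reproduce
  // the same sequence of shuffles.
  for (int epoch = 0; epoch < 3; ++epoch) {
    uint32_t next_seed = distribution(generator);
    std::cout << "epoch " << epoch << " shuffle seed " << next_seed << "\n";
  }
  return 0;
}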
@ -1,111 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TF_CLIENT_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_SOURCE_TF_CLIENT_H_
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "proto/example.pb.h"
|
||||
#include "dataset/engine/datasetops/source/storage_client.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
struct FileInfo {
|
||||
std::string fileName;
|
||||
uint64_t startRecordIdx;
|
||||
uint64_t endRecordIdx;
|
||||
uint64_t startOffset;
|
||||
};
|
||||
|
||||
using QFile = std::queue<FileInfo>;
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// forward declares
|
||||
class DataSchema;
|
||||
class ParallelOp;
|
||||
|
||||
class TFClient : public StorageClient {
|
||||
public:
|
||||
// Name: Constructor
|
||||
// Description: Creates the TFClient.
|
||||
TFClient(std::unique_ptr<DataSchema> schema, // In: The schema for this storage client.
|
||||
StorageOp *so); // In: The StorageOp that's using this client
|
||||
|
||||
~TFClient() {}
|
||||
|
||||
Status Init() override;
|
||||
|
||||
// Name: Print()
|
||||
// Description: A function that prints info about the TFClient
|
||||
void Print(std::ostream &out) const override; // In: The output stream to print to
|
||||
|
||||
std::vector<uint64_t> ParseTfFileLines(const std::string &filename);
|
||||
|
||||
Status ParseTfFileSchema(const std::string &filename);
|
||||
|
||||
Status NextFileInfo(uint32_t id, FileInfo *);
|
||||
|
||||
bool IsMoreData(uint32_t id) override;
|
||||
|
||||
// Name: Reset()
|
||||
// Description: Resets any state info inside the client back to its initialized
|
||||
// state.
|
||||
Status Reset() override;
|
||||
|
||||
Status ScatterFileRows(uint32_t device_id, const std::string &shard_config, uint32_t seed, bool shuffle_config);
|
||||
|
||||
private:
|
||||
// hardcoded, put this in json schema
|
||||
// const static int32_t BERT_DATASET_TOTAL_ROWS = 43900;
|
||||
uint32_t rows_per_buffer_;
|
||||
std::default_random_engine random_seed_generator_;
|
||||
std::uniform_int_distribution<uint32_t> random_seed_distribution_;
|
||||
|
||||
std::vector<std::pair<std::string, std::vector<uint64_t>>> v_file_rows_;
|
||||
std::vector<std::pair<std::string, std::vector<uint64_t>>> v_total_file_rows_;
|
||||
std::vector<QFile> f_info_queue_;
|
||||
uint64_t rows_per_shard_;
|
||||
std::vector<std::pair<uint32_t, uint32_t>> file_start_end_offset_;
|
||||
|
||||
void InitStateInfo();
|
||||
|
||||
std::vector<std::pair<std::string, std::vector<uint64_t>>> ShuffleVector(
|
||||
std::vector<std::pair<std::string, std::vector<uint64_t>>> v, uint32_t seed);
|
||||
|
||||
Status CalculateRowsPerDevice();
|
||||
|
||||
bool ValidFileForShard(const uint64_t file_rows, uint64_t *start_offset, uint64_t *end_offset,
|
||||
const uint64_t &pre_count, uint32_t device_id) const;
|
||||
|
||||
void CalculateNumRows();
|
||||
|
||||
void GetValidFileForShard(const std::vector<std::pair<std::string, std::vector<uint64_t>>> &v_files,
|
||||
uint32_t device_id);
|
||||
|
||||
void CalculateStartOffset(const uint64_t start_index, const uint64_t end_index,
|
||||
const std::vector<uint64_t> &vec_length, uint64_t *start_offset) const;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TF_CLIENT_H_
|
|
@ -16,6 +16,7 @@
|
|||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <future>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
|
@ -32,8 +33,6 @@
|
|||
#include "dataset/engine/connector.h"
|
||||
#include "dataset/engine/data_schema.h"
|
||||
#include "dataset/engine/datasetops/source/io_block.h"
|
||||
#include "dataset/engine/datasetops/source/storage_client.h"
|
||||
#include "dataset/engine/datasetops/source/tf_client.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/engine/jagged_connector.h"
|
||||
|
|
|
@ -40,7 +40,7 @@ class TakeOp : public PipelineOp {
|
|||
~Builder() = default;
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @return shared_ptr to the new StorageOp object
|
||||
// @return shared_ptr to the new TakeOp object
|
||||
Status Build(std::shared_ptr<TakeOp> *);
|
||||
|
||||
private:
|
||||
|
|
|
@ -65,7 +65,7 @@ class ZipOp : public PipelineOp {
|
|||
}
|
||||
|
||||
// The builder "build" method creates the ZipOp dataset Operator.
|
||||
// @return shared_ptr to the new StorageOp object
|
||||
// @return shared_ptr to the new ZipOp object
|
||||
Status Build(std::shared_ptr<ZipOp> *);
|
||||
|
||||
private:
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "dataset/engine/datasetops/source/generator_op.h"
|
||||
#include "dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#include "dataset/engine/datasetops/source/storage_op.h"
|
||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "dataset/engine/datasetops/take_op.h"
|
||||
|
|
|
@ -33,169 +33,6 @@ valid_detype = [
|
|||
]
|
||||
|
||||
|
||||
def check(method):
|
||||
"""Check the function parameters and return the function ."""
|
||||
func_name = method.__name__
|
||||
# Required parameter
|
||||
req_param_int = []
|
||||
req_param_bool = []
|
||||
# Non-required parameter
|
||||
nreq_param_int = []
|
||||
nreq_param_bool = []
|
||||
|
||||
if func_name in 'repeat':
|
||||
nreq_param_int = ['count', 'prefetch_size']
|
||||
|
||||
if func_name in 'take':
|
||||
req_param_int = ['count']
|
||||
nreq_param_int = ['prefetch_size']
|
||||
|
||||
elif func_name in 'shuffle':
|
||||
req_param_int = ['buffer_size']
|
||||
nreq_param_bool = ['reshuffle_each_iteration']
|
||||
nreq_param_int = ['prefetch_size', 'seed']
|
||||
|
||||
elif func_name in 'batch':
|
||||
req_param_int = ['batch_size']
|
||||
nreq_param_int = ['num_parallel_workers', 'prefetch_size']
|
||||
nreq_param_bool = ['drop_remainder']
|
||||
|
||||
elif func_name in ('zip', 'filter', 'cache', 'rename', 'project'):
|
||||
nreq_param_int = ['prefetch_size']
|
||||
|
||||
elif func_name in ('map', '__init__'):
|
||||
nreq_param_int = ['num_parallel_workers', 'prefetch_size', 'seed']
|
||||
nreq_param_bool = ['block_reader']
|
||||
|
||||
@wraps(method)
|
||||
def wrapper(*args, **kwargs):
|
||||
|
||||
def _make_key():
|
||||
sig = ins.signature(method)
|
||||
params = sig.parameters
|
||||
keys = list(params.keys())
|
||||
param_dic = dict()
|
||||
for name, value in enumerate(args):
|
||||
param_dic[keys[name]] = value
|
||||
param_dic.update(zip(params.keys(), args))
|
||||
param_dic.update(kwargs)
|
||||
|
||||
for name, value in params.items():
|
||||
if name not in param_dic:
|
||||
param_dic[name] = value.default
|
||||
return param_dic
|
||||
|
||||
# check type
|
||||
def _check_param_type(arg, param_name, param_type=None):
|
||||
if param_type is not None and not isinstance(arg, param_type):
|
||||
raise ValueError(
|
||||
"The %s function %s type error!" % (func_name, param_name))
|
||||
|
||||
# check range
|
||||
def _check_param_range(arg, param_name):
|
||||
if isinstance(arg, int) and param_name == "seed" and (
|
||||
arg < 0 or arg > 2147483647):
|
||||
raise ValueError(
|
||||
"The %s function %s exceeds the boundary!" % (
|
||||
func_name, param_name))
|
||||
if isinstance(arg, int) and param_name == "count" and ((arg <= 0 and arg != -1) or arg > 2147483647):
|
||||
raise ValueError(
|
||||
"The %s function %s exceeds the boundary!" % (
|
||||
func_name, param_name))
|
||||
if isinstance(arg, int) and param_name == "prefetch_size" and (
|
||||
arg <= 0 or arg > 1024):
|
||||
raise ValueError(
|
||||
"The %s function %s exceeds the boundary!" % (
|
||||
func_name, param_name))
|
||||
if isinstance(arg, int) and param_name == "num_parallel_workers" and (
|
||||
arg < 1 or arg > cpu_count()):
|
||||
raise ValueError(
|
||||
"The %s function %s exceeds the boundary(%s)!" % (
|
||||
func_name, param_name, cpu_count()))
|
||||
if isinstance(arg, int) and param_name != "seed" \
|
||||
and param_name != "count" and param_name != "prefetch_size" \
|
||||
and param_name != "num_parallel_workers" and (arg < 1 or arg > 2147483647):
|
||||
raise ValueError(
|
||||
"The %s function %s exceeds the boundary!" % (
|
||||
func_name, param_name))
|
||||
|
||||
key = _make_key()
|
||||
# check integer
|
||||
for karg in req_param_int:
|
||||
_check_param_type(key[karg], karg, int)
|
||||
_check_param_range(key[karg], karg)
|
||||
for karg in nreq_param_int:
|
||||
if karg in key:
|
||||
if key[karg] is not None:
|
||||
_check_param_type(key[karg], karg, int)
|
||||
_check_param_range(key[karg], karg)
|
||||
# check bool
|
||||
for karg in req_param_bool:
|
||||
_check_param_type(key[karg], karg, bool)
|
||||
for karg in nreq_param_bool:
|
||||
if karg in key:
|
||||
if key[karg] is not None:
|
||||
_check_param_type(key[karg], karg, bool)
|
||||
|
||||
if func_name in '__init__':
|
||||
if 'columns_list' in key.keys():
|
||||
columns_list = key['columns_list']
|
||||
if columns_list is not None:
|
||||
_check_param_type(columns_list, 'columns_list', list)
|
||||
|
||||
if 'columns' in key.keys():
|
||||
columns = key['columns']
|
||||
if columns is not None:
|
||||
_check_param_type(columns, 'columns', list)
|
||||
|
||||
if 'partitions' in key.keys():
|
||||
partitions = key['partitions']
|
||||
if partitions is not None:
|
||||
_check_param_type(partitions, 'partitions', list)
|
||||
|
||||
if 'schema' in key.keys():
|
||||
schema = key['schema']
|
||||
if schema is not None:
|
||||
check_filename(schema)
|
||||
if not os.path.isfile(schema) or not os.access(schema, os.R_OK):
|
||||
raise ValueError(
|
||||
"The file %s does not exist or permission denied!" % schema)
|
||||
|
||||
if 'dataset_dir' in key.keys():
|
||||
dataset_dir = key['dataset_dir']
|
||||
if dataset_dir is not None:
|
||||
if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
|
||||
raise ValueError(
|
||||
"The folder %s does not exist or permission denied!" % dataset_dir)
|
||||
|
||||
if 'dataset_files' in key.keys():
|
||||
dataset_files = key['dataset_files']
|
||||
if not dataset_files:
|
||||
raise ValueError(
|
||||
"The dataset file does not exists!")
|
||||
if dataset_files is not None:
|
||||
_check_param_type(dataset_files, 'dataset_files', list)
|
||||
for file in dataset_files:
|
||||
if not os.path.isfile(file) or not os.access(file, os.R_OK):
|
||||
raise ValueError(
|
||||
"The file %s does not exist or permission denied!" % file)
|
||||
|
||||
if 'dataset_file' in key.keys():
|
||||
dataset_file = key['dataset_file']
|
||||
if not dataset_file:
|
||||
raise ValueError(
|
||||
"The dataset file does not exists!")
|
||||
check_filename(dataset_file)
|
||||
if dataset_file is not None:
|
||||
if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK):
|
||||
raise ValueError(
|
||||
"The file %s does not exist or permission denied!" % dataset_file)
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def check_valid_detype(type_):
|
||||
if type_ not in valid_detype:
|
||||
raise ValueError("Unknown column type")
|
||||
|
|
|
@ -48,7 +48,6 @@ SET(DE_UT_SRCS
|
|||
shuffle_op_test.cc
|
||||
stand_alone_samplers_test.cc
|
||||
status_test.cc
|
||||
storage_op_test.cc
|
||||
task_manager_test.cc
|
||||
tensor_test.cc
|
||||
tensor_string_test.cc
|
||||
|
|
|
@ -54,10 +54,10 @@ std::shared_ptr<de::RepeatOp> Repeat(int repeat_cnt = 1) {
|
|||
return op;
|
||||
}
|
||||
|
||||
std::shared_ptr<de::StorageOp> Storage(std::string schema, int rows_per_buf = 2, int num_works = 8) {
|
||||
std::shared_ptr<de::StorageOp> so;
|
||||
de::StorageOp::Builder builder;
|
||||
builder.SetDatasetFilesDir(schema).SetRowsPerBuffer(rows_per_buf).SetNumWorkers(num_works);
|
||||
std::shared_ptr<de::TFReaderOp> TFReader(std::string schema, int rows_per_buf = 2, int num_works = 8) {
|
||||
std::shared_ptr<de::TFReaderOp> so;
|
||||
de::TFReaderOp::Builder builder;
|
||||
builder.SetDatasetFilesList({schema}).SetRowsPerBuffer(rows_per_buf).SetNumWorkers(num_works);
|
||||
Status rc = builder.Build(&so);
|
||||
return so;
|
||||
}
|
||||
|
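The pattern these test updates converge on is the same throughout: build a TFReaderOp from an explicit file list, attach it to an ExecutionTree, and iterate. A minimal sketch, assuming only the builder, tree, and iterator interfaces visible in this diff; the data file path is illustrative:

#include "dataset/core/client.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"

Status RunMinimalTFReaderTree(const std::string &data_file) {
  // Build the leaf reader from a file list (replacing the old SetDatasetFilesDir).
  std::shared_ptr<TFReaderOp> reader;
  TFReaderOp::Builder builder;
  builder.SetDatasetFilesList({data_file}).SetRowsPerBuffer(2).SetNumWorkers(2);
  RETURN_IF_NOT_OK(builder.Build(&reader));

  // Wire the single-op tree and launch it.
  auto tree = std::make_shared<ExecutionTree>();
  RETURN_IF_NOT_OK(tree->AssociateNode(reader));
  RETURN_IF_NOT_OK(tree->AssignRoot(reader));
  RETURN_IF_NOT_OK(tree->Prepare());
  RETURN_IF_NOT_OK(tree->Launch());

  // Pull rows until the iterator returns an empty row (end of data).
  DatasetIterator di(tree);
  TensorRow row;
  RETURN_IF_NOT_OK(di.FetchNextTensorRow(&row));
  while (!row.empty()) {
    RETURN_IF_NOT_OK(di.FetchNextTensorRow(&row));
  }
  return Status::OK();
}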
@ -77,9 +77,9 @@ std::shared_ptr<de::ExecutionTree> Build(std::vector<std::shared_ptr<de::Dataset
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestBatchOp, TestSimpleBatch) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
bool success = false;
|
||||
auto tree = Build({Storage(schema_file), Batch(12)});
|
||||
auto tree = Build({TFReader(schema_file), Batch(12)});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -108,9 +108,9 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatch) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
bool success = false;
|
||||
auto tree = Build({Storage(schema_file), Repeat(2), Batch(7, true, 99)});
|
||||
auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, true, 99)});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -153,9 +153,9 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
bool success = false;
|
||||
auto tree = Build({Storage(schema_file), Repeat(2), Batch(7, false, 99)});
|
||||
auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, false, 99)});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -205,9 +205,9 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
bool success = false;
|
||||
auto tree = Build({Storage(schema_file), Batch(7, false, 99), Repeat(2)});
|
||||
auto tree = Build({TFReader(schema_file), Batch(7, false, 99), Repeat(2)});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -251,9 +251,9 @@ TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
bool success = false;
|
||||
auto tree = Build({Storage(schema_file), Batch(5, true, 99), Repeat(2)});
|
||||
auto tree = Build({TFReader(schema_file), Batch(5, true, 99), Repeat(2)});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -297,7 +297,7 @@ TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset";
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
std::shared_ptr<BatchOp> op;
|
||||
PadInfo m;
|
||||
std::shared_ptr<Tensor> pad_value;
|
||||
|
@ -305,7 +305,7 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) {
|
|||
pad_value->SetItemAt<float>({}, -1);
|
||||
m.insert({"col_1d", std::make_pair(TensorShape({4}), pad_value)});
|
||||
de::BatchOp::Builder(12).SetDrop(false).SetPaddingMap(m, true).Build(&op);
|
||||
auto tree = Build({Storage(schema_file), op});
|
||||
auto tree = Build({TFReader(schema_file), op});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
|
|
@ -88,17 +88,17 @@ TEST_F(MindDataTestClientConfig, TestClientConfig2) {
|
|||
// Dataset from testDataset1 has 10 rows, 2 columns.
|
||||
// RowsPerBuffer buffer setting of 2 divides evenly into total rows.
|
||||
std::string dataset_path;
|
||||
dataset_path = datasets_root_path_ + "/testDataset1";
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
StorageOp::Builder builder;
|
||||
builder.SetDatasetFilesDir(dataset_path);
|
||||
rc = builder.Build(&my_storage_op);
|
||||
dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data";
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
TFReaderOp::Builder builder;
|
||||
builder.SetDatasetFilesList({dataset_path});
|
||||
rc = builder.Build(&my_tfreader_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
ASSERT_EQ(my_storage_op->num_workers(),16);
|
||||
my_tree->AssociateNode(my_storage_op);
|
||||
ASSERT_EQ(my_tfreader_op->num_workers(),1);
|
||||
my_tree->AssociateNode(my_tfreader_op);
|
||||
|
||||
// Set children/root layout.
|
||||
my_tree->AssignRoot(my_storage_op);
|
||||
my_tree->AssignRoot(my_tfreader_op);
|
||||
|
||||
my_tree->Prepare();
|
||||
my_tree->Launch();
|
||||
|
@ -116,5 +116,5 @@ TEST_F(MindDataTestClientConfig, TestClientConfig2) {
|
|||
row_count++;
|
||||
}
|
||||
ASSERT_EQ(row_count, 10); // Should be 10 rows fetched
|
||||
ASSERT_EQ(my_storage_op->num_workers(),16);
|
||||
ASSERT_EQ(my_tfreader_op->num_workers(),1);
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include "dataset/core/client.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "dataset/engine/datasetops/source/storage_op.h"
|
||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "dataset/util/de_error.h"
|
||||
|
@ -103,17 +103,17 @@ TEST_F(MindDataTestExecutionTree, TestExecutionTree2) {
|
|||
Status rc;
|
||||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/testDataset1";
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
StorageOp::Builder()
|
||||
.SetDatasetFilesDir(dataset_path)
|
||||
std::string dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data";
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
TFReaderOp::Builder()
|
||||
.SetDatasetFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(2)
|
||||
.SetWorkerConnectorSize(2)
|
||||
.SetNumWorkers(2)
|
||||
.Build(&my_storage_op);
|
||||
.Build(&my_tfreader_op);
|
||||
|
||||
my_tree->AssociateNode(my_storage_op);
|
||||
my_tree->AssignRoot(my_storage_op);
|
||||
my_tree->AssociateNode(my_tfreader_op);
|
||||
my_tree->AssignRoot(my_tfreader_op);
|
||||
|
||||
// prepare the tree
|
||||
my_tree->Prepare();
|
||||
|
|
|
@ -91,7 +91,8 @@ class MindDataTestMapOp : public UT::DatasetOpTesting {
|
|||
public:
|
||||
void SetUp() override {
|
||||
DatasetOpTesting::SetUp();
|
||||
dataset_path_ = datasets_root_path_ + "" + "/testDataset2";
|
||||
dataset_path_ = datasets_root_path_ + "" + "/testDataset2/testDataset2.data";
|
||||
schema_path_ = datasets_root_path_ + "" + "/testDataset2/datasetSchema.json";
|
||||
|
||||
GlobalInit();
|
||||
|
||||
|
@ -99,22 +100,28 @@ class MindDataTestMapOp : public UT::DatasetOpTesting {
|
|||
my_tree_ = std::make_shared<ExecutionTree>();
|
||||
}
|
||||
|
||||
std::shared_ptr<StorageOp> CreateStorageOp() {
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
StorageOp::Builder builder;
|
||||
builder.SetDatasetFilesDir(dataset_path_)
|
||||
std::shared_ptr<TFReaderOp> CreateTFReaderOp() {
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
TFReaderOp::Builder builder;
|
||||
builder.SetDatasetFilesList({dataset_path_})
|
||||
.SetColumnsToLoad({"image", "label", "A", "B"})
|
||||
.SetRowsPerBuffer(2)
|
||||
.SetWorkerConnectorSize(2)
|
||||
.SetNumWorkers(2);
|
||||
Status rc = builder.Build(&my_storage_op);
|
||||
|
||||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
schema->LoadSchemaFile(schema_path_, {});
|
||||
builder.SetDataSchema(std::move(schema));
|
||||
|
||||
Status rc = builder.Build(&my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
return my_storage_op;
|
||||
return my_tfreader_op;
|
||||
}
|
||||
|
||||
std::shared_ptr<ExecutionTree> my_tree_;
|
||||
private:
|
||||
std::string dataset_path_;
|
||||
std::string schema_path_;
|
||||
};
|
||||
|
||||
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
|
||||
|
@ -124,7 +131,7 @@ std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int6
|
|||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
|
||||
// TestByPosition scenario:
|
||||
// StorageOp reads a dataset that has column ordering |image|label|A|B|.
|
||||
// TFReaderOp reads a dataset that has column ordering |image|label|A|B|.
|
||||
// A TensorOp that does nothing picks the label column and outputs a column also named label.
|
||||
// Thus, based on the new MapOp behaviour, the column ordering will be |image|label|A|B|.
|
||||
// Verify the column ordering by checking that the Tensor properties match those in the schema file.
|
||||
|
@ -132,10 +139,10 @@ TEST_F(MindDataTestMapOp, TestByPosition) {
|
|||
Status rc;
|
||||
MS_LOG(INFO) << "Doing TestByPosition.";
|
||||
|
||||
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total
|
||||
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total
|
||||
// of 10 rows.
|
||||
auto my_storage_op = this->CreateStorageOp();
|
||||
rc = my_tree_->AssociateNode(my_storage_op);
|
||||
auto my_tfreader_op = this->CreateTFReaderOp();
|
||||
rc = my_tree_->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
auto my_no_op = std::make_shared<mindspore::dataset::test::NoOp>();
|
||||
std::vector<std::shared_ptr<TensorOp>> my_func_list;
|
||||
|
@ -144,13 +151,14 @@ TEST_F(MindDataTestMapOp, TestByPosition) {
|
|||
MapOp::Builder builder;
|
||||
builder.SetInColNames({"label"})
|
||||
.SetOutColNames({})
|
||||
.SetColOrder({"image", "label", "A", "B"})
|
||||
.SetTensorFuncs(std::move(my_func_list))
|
||||
.SetNumWorkers(100);
|
||||
rc = builder.Build(&my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->AssociateNode(my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_map_op->AddChild(my_storage_op);
|
||||
rc = my_map_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->AssignRoot(my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -192,7 +200,7 @@ TEST_F(MindDataTestMapOp, TestByPosition) {
|
|||
}
|
||||
|
||||
// TestAsMap scenario:
|
||||
// StorageOp reads a dataset that has column ordering |image|label|A|B|.
|
||||
// TFReaderOp reads a dataset that has column ordering |image|label|A|B|.
|
||||
// A TensorOp that does nothing picks the "image" column and produces a column named "X".
|
||||
// Thus, based on the new MapOp behaviour, the column ordering will be |X|label|A|B|.
|
||||
// Verify that the "image" column is removed and "X" column is added.
|
||||
|
@ -200,9 +208,9 @@ TEST_F(MindDataTestMapOp, TestAsMap) {
|
|||
Status rc;
|
||||
MS_LOG(INFO) << "Doing TestAsMap.";
|
||||
|
||||
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total of 10 rows.
|
||||
auto my_storage_op = this->CreateStorageOp();
|
||||
rc = my_tree_->AssociateNode(my_storage_op);
|
||||
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total of 10 rows.
|
||||
auto my_tfreader_op = this->CreateTFReaderOp();
|
||||
rc = my_tree_->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
auto my_no_op = std::make_shared<mindspore::dataset::test::NoOp>();
|
||||
std::vector<std::shared_ptr<TensorOp>> my_func_list;
|
||||
|
@ -216,7 +224,7 @@ TEST_F(MindDataTestMapOp, TestAsMap) {
|
|||
rc = builder.Build(&my_map_op);
|
||||
rc = my_tree_->AssociateNode(my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_map_op->AddChild(my_storage_op);
|
||||
rc = my_map_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
// Assign the tree root
|
||||
|
@ -243,7 +251,7 @@ TEST_F(MindDataTestMapOp, TestAsMap) {
|
|||
}
|
||||
|
||||
// Test3to1 scenario:
|
||||
// StorageOp reads a dataset that has column ordering |image|label|A|B|.
|
||||
// TFReaderOp reads a dataset that has column ordering |image|label|A|B|.
|
||||
// A 3-to-1 TensorOp picks the columns [image, A, B] and produces a column named "X".
|
||||
// Thus, based on the new MapOp behaviour, the column ordering will be |X|label|.
|
||||
// Verify that the only columns "X" and "label" exist.
|
||||
|
@ -251,9 +259,9 @@ TEST_F(MindDataTestMapOp, Test3to1) {
|
|||
Status rc;
|
||||
MS_LOG(INFO) << "Doing Test3to1.";
|
||||
|
||||
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total of 10 rows.
|
||||
auto my_storage_op = this->CreateStorageOp();
|
||||
rc = my_tree_->AssociateNode(my_storage_op);
|
||||
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total of 10 rows.
|
||||
auto my_tfreader_op = this->CreateTFReaderOp();
|
||||
rc = my_tree_->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
auto my_op = std::make_shared<mindspore::dataset::test::ThreeToOneOp>();
|
||||
std::vector<std::shared_ptr<TensorOp>> my_func_list;
|
||||
|
@ -268,7 +276,7 @@ TEST_F(MindDataTestMapOp, Test3to1) {
|
|||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->AssociateNode(my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_map_op->AddChild(my_storage_op);
|
||||
rc = my_map_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->AssignRoot(my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -295,7 +303,7 @@ TEST_F(MindDataTestMapOp, Test3to1) {
|
|||
}
|
||||
|
||||
// Test1to3 scenario:
|
||||
// StorageOp reads a dataset that has column ordering |image|label|A|B|.
|
||||
// TFReaderOp reads a dataset that has column ordering |image|label|A|B|.
|
||||
// A 1-to-3 TensorOp picks the column [image] and produces the columns [X, Y, Z].
|
||||
// Thus, based on the new MapOp behaviour, the column ordering will be |X|Y|Z|label|A|B|.
|
||||
// Verify that the columns X, Y, Z are added (to the front), followed by the columns label, A, B.
|
||||
|
@ -303,9 +311,9 @@ TEST_F(MindDataTestMapOp, Test1to3) {
|
|||
Status rc;
|
||||
MS_LOG(INFO) << "Doing Test1to3.";
|
||||
|
||||
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total of 10 rows.
|
||||
auto my_storage_op = this->CreateStorageOp();
|
||||
rc = my_tree_->AssociateNode(my_storage_op);
|
||||
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total of 10 rows.
|
||||
auto my_tfreader_op = this->CreateTFReaderOp();
|
||||
rc = my_tree_->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
auto my_op = std::make_shared<mindspore::dataset::test::OneToThreeOp>();
|
||||
std::vector<std::shared_ptr<TensorOp>> my_func_list;
|
||||
|
@ -316,12 +324,25 @@ TEST_F(MindDataTestMapOp, Test1to3) {
|
|||
.SetOutColNames({"X", "Y", "Z"})
|
||||
.SetTensorFuncs(std::move(my_func_list))
|
||||
.SetNumWorkers(1);
|
||||
|
||||
|
||||
// ProjectOp
|
||||
std::vector<std::string> columns_to_project = {"X", "Y", "Z", "label", "A", "B"};
|
||||
std::shared_ptr<ProjectOp> my_project_op = std::make_shared<ProjectOp>(columns_to_project);
|
||||
rc = my_tree_->AssociateNode(my_project_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = my_tree_->AssignRoot(my_project_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = builder.Build(&my_map_op);
|
||||
rc = my_tree_->AssociateNode(my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_map_op->AddChild(my_storage_op);
|
||||
|
||||
rc = my_project_op->AddChild(my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->AssignRoot(my_map_op);
|
||||
|
||||
rc = my_map_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Prepare();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -371,7 +392,7 @@ TEST_F(MindDataTestMapOp, Test1to3) {
|
|||
}
|
||||
|
||||
// TestMultiTensorOp scenario:
|
||||
// StorageOp reads a dataset that has column ordering |image|label|A|B|.
|
||||
// TFReaderOp reads a dataset that has column ordering |image|label|A|B|.
|
||||
// A series of 3-to-1 and 1-to-3 TensorOps are applied to [image, A, B] and
|
||||
// produce final output columns [X, Y, Z].
|
||||
// Based on the new MapOp behaviour, the column ordering will be |X|Y|Z|label|.
|
||||
|
@ -379,9 +400,9 @@ TEST_F(MindDataTestMapOp, TestMultiTensorOp) {
|
|||
Status rc;
|
||||
MS_LOG(INFO) << "Doing TestMultiTensorOp.";
|
||||
|
||||
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total of 10 rows.
|
||||
auto my_storage_op = this->CreateStorageOp();
|
||||
rc = my_tree_->AssociateNode(my_storage_op);
|
||||
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total of 10 rows.
|
||||
auto my_tfreader_op = this->CreateTFReaderOp();
|
||||
rc = my_tree_->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
auto my_op1 = std::make_shared<mindspore::dataset::test::ThreeToOneOp>();
|
||||
auto my_op2 = std::make_shared<mindspore::dataset::test::OneToThreeOp>();
|
||||
|
@ -398,7 +419,7 @@ TEST_F(MindDataTestMapOp, TestMultiTensorOp) {
|
|||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->AssociateNode(my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_map_op->AddChild(my_storage_op);
|
||||
rc = my_map_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->AssignRoot(my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -431,15 +452,15 @@ TEST_F(MindDataTestMapOp, TestMultiTensorOp) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestMapOp, TestStorageRepeatMap) {
|
||||
TEST_F(MindDataTestMapOp, TestTFReaderRepeatMap) {
|
||||
Status rc;
|
||||
MS_LOG(INFO) << "Doing TestStorageRepeatMap.";
|
||||
MS_LOG(INFO) << "Doing TestTFReaderRepeatMap.";
|
||||
uint32_t num_repeats = 3;
|
||||
|
||||
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total
|
||||
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total
|
||||
// of 10 rows.
|
||||
auto my_storage_op = this->CreateStorageOp();
|
||||
rc = my_tree_->AssociateNode(my_storage_op);
|
||||
auto my_tfreader_op = this->CreateTFReaderOp();
|
||||
rc = my_tree_->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
auto my_no_op = std::make_shared<mindspore::dataset::test::NoOp>();
|
||||
std::vector<std::shared_ptr<TensorOp>> my_func_list;
|
||||
|
@ -465,7 +486,7 @@ TEST_F(MindDataTestMapOp, TestStorageRepeatMap) {
|
|||
rc = my_map_op->AddChild(my_repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
rc = my_repeat_op->AddChild(my_storage_op);
|
||||
rc = my_repeat_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
rc = my_tree_->AssignRoot(my_map_op);
|
||||
|
@ -493,15 +514,15 @@ TEST_F(MindDataTestMapOp, TestStorageRepeatMap) {
|
|||
ASSERT_EQ(row_count, 10 * num_repeats);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestMapOp, TestStorageMapRepeat) {
|
||||
TEST_F(MindDataTestMapOp, TestTFReaderMapRepeat) {
|
||||
Status rc;
|
||||
MS_LOG(INFO) << "Doing TestStorageMapRepeat.";
|
||||
MS_LOG(INFO) << "Doing TestTFReaderMapRepeat.";
|
||||
uint32_t num_repeats = 3;
|
||||
|
||||
// Note: The above storage config yields 5 buffers, each with 2 rows, for a total
|
||||
// Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total
|
||||
// of 10 rows.
|
||||
auto my_storage_op = this->CreateStorageOp();
|
||||
rc = my_tree_->AssociateNode(my_storage_op);
|
||||
auto my_tfreader_op = this->CreateTFReaderOp();
|
||||
rc = my_tree_->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
auto my_no_op = std::make_shared<mindspore::dataset::test::NoOp>();
|
||||
std::vector<std::shared_ptr<TensorOp>> my_func_list;
|
||||
|
@ -527,7 +548,7 @@ TEST_F(MindDataTestMapOp, TestStorageMapRepeat) {
|
|||
rc = my_repeat_op->AddChild(my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
rc = my_map_op->AddChild(my_storage_op);
|
||||
rc = my_map_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
rc = my_tree_->AssignRoot(my_repeat_op);
|
||||
|
@ -554,23 +575,23 @@ TEST_F(MindDataTestMapOp, TestStorageMapRepeat) {
|
|||
ASSERT_EQ(row_count, 10 * num_repeats);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestMapOp, Storage_Decode_Repeat_Resize) {
|
||||
TEST_F(MindDataTestMapOp, TFReader_Decode_Repeat_Resize) {
|
||||
Status rc;
|
||||
MS_LOG(INFO) << "Doing Storage_Decode_Repeat_Resize.";
|
||||
MS_LOG(INFO) << "Doing TFReader_Decode_Repeat_Resize.";
|
||||
uint32_t num_repeats = 2;
|
||||
|
||||
std::string dataset_path_ = datasets_root_path_ + "/" + "test_tf_file_3_images";
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
StorageOp::Builder sobuilder;
|
||||
sobuilder.SetDatasetFilesDir(dataset_path_)
|
||||
std::string dataset_path_ = datasets_root_path_ + "/" + "test_tf_file_3_images/train-0000-of-0001.data";
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
TFReaderOp::Builder sobuilder;
|
||||
sobuilder.SetDatasetFilesList({dataset_path_})
|
||||
.SetColumnsToLoad({"image", "label"})
|
||||
.SetRowsPerBuffer(2)
|
||||
.SetWorkerConnectorSize(2)
|
||||
.SetNumWorkers(2);
|
||||
rc = sobuilder.Build(&my_storage_op);
|
||||
rc = sobuilder.Build(&my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
rc = my_tree_->AssociateNode(my_storage_op);
|
||||
rc = my_tree_->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
auto decode_op = std::make_shared<DecodeOp>();
|
||||
std::vector<std::shared_ptr<TensorOp>> my_func_list;
|
||||
|
@ -608,7 +629,7 @@ TEST_F(MindDataTestMapOp, Storage_Decode_Repeat_Resize) {
|
|||
rc = my_tree_->AssociateNode(my_map_resize_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
rc = my_map_decode_op->AddChild(my_storage_op);
|
||||
rc = my_map_decode_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
rc = my_repeat_op->AddChild(my_map_decode_op);
|
||||
|
|
|
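The decode/repeat/resize test above keeps its map wiring unchanged; only the leaf op moves from StorageOp to TFReaderOp. A compact sketch of that wiring, assuming the MapOp, DecodeOp, and RepeatOp interfaces shown in this diff (variable names mirror the test; error checks shortened):

// Wrap DecodeOp in a MapOp that rewrites the "image" column in place.
std::vector<std::shared_ptr<TensorOp>> decode_funcs = {std::make_shared<DecodeOp>()};
std::shared_ptr<MapOp> my_map_decode_op;
MapOp::Builder map_builder;
map_builder.SetInColNames({"image"})
    .SetOutColNames({})
    .SetTensorFuncs(std::move(decode_funcs))
    .SetNumWorkers(4);
rc = map_builder.Build(&my_map_decode_op);
EXPECT_TRUE(rc.IsOk());

// The leaf is now the TFReaderOp; repeat re-runs the decoded stream before resize.
rc = my_map_decode_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_repeat_op->AddChild(my_map_decode_op);
EXPECT_TRUE(rc.IsOk());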
@ -44,23 +44,23 @@ TEST_F(MindDataTestRenameOp, TestRenameOpDefault) {
|
|||
//
|
||||
// OpId(2) RenameOp
|
||||
// |
|
||||
// OpId(0) StorageOp
|
||||
// OpId(0) TFReaderOp
|
||||
// Start with an empty execution tree
|
||||
Status rc;
|
||||
MS_LOG(INFO) << "UT test TestRenameBasic.";
|
||||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
// Creating StorageOp
|
||||
// Creating TFReaderOp
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1";
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
rc = StorageOp::Builder()
|
||||
.SetDatasetFilesDir(dataset_path)
|
||||
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
rc = TFReaderOp::Builder()
|
||||
.SetDatasetFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(2)
|
||||
.SetWorkerConnectorSize(16)
|
||||
.SetNumWorkers(1)
|
||||
.Build(&my_storage_op);
|
||||
.Build(&my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(my_storage_op);
|
||||
rc = my_tree->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
// Creating DatasetOp
|
||||
|
@ -76,7 +76,7 @@ TEST_F(MindDataTestRenameOp, TestRenameOpDefault) {
|
|||
|
||||
rc = my_tree->AssociateNode(rename_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = rename_op->AddChild(std::move(my_storage_op));
|
||||
rc = rename_op->AddChild(std::move(my_tfreader_op));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssignRoot(rename_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
|
|
@ -39,11 +39,11 @@ class MindDataTestShuffleOp : public UT::DatasetOpTesting {
|
|||
// - RowsPerBuffer buffer setting of 2 divides evenly into total rows.
|
||||
// - Shuffle size is multiple of rows per buffer.
|
||||
//
|
||||
// Tree: shuffle over storage
|
||||
// Tree: shuffle over TFReader
|
||||
//
|
||||
// ShuffleOp
|
||||
// |
|
||||
// StorageOp
|
||||
// TFReaderOp
|
||||
//
|
||||
TEST_F(MindDataTestShuffleOp, TestShuffleBasic1) {
|
||||
Status rc;
|
||||
|
@ -53,16 +53,16 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic1) {
|
|||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::string dataset_path;
|
||||
dataset_path = datasets_root_path_ + "/testDataset1";
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
rc = StorageOp::Builder()
|
||||
.SetDatasetFilesDir(dataset_path)
|
||||
dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data";
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
rc = TFReaderOp::Builder()
|
||||
.SetDatasetFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(2)
|
||||
.SetWorkerConnectorSize(16)
|
||||
.SetNumWorkers(1)
|
||||
.Build(&my_storage_op);
|
||||
.Build(&my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(my_storage_op);
|
||||
rc = my_tree->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
std::shared_ptr<ShuffleOp> my_shuffle_op;
|
||||
rc = ShuffleOp::Builder().SetRowsPerBuffer(2).SetShuffleSize(4).Build(&my_shuffle_op);
|
||||
|
@ -71,7 +71,7 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic1) {
|
|||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
// Set children/root layout.
|
||||
rc = my_shuffle_op->AddChild(my_storage_op);
|
||||
rc = my_shuffle_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssignRoot(my_shuffle_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -112,11 +112,11 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic1) {
|
|||
// - Shuffle size is not a multiple of rows per buffer.
|
||||
// - User has provided a non-default seed value.
|
||||
//
|
||||
// Tree: shuffle over storage
|
||||
// Tree: shuffle over TFReader
|
||||
//
|
||||
// ShuffleOp
|
||||
// |
|
||||
// StorageOp
|
||||
// TFReaderOp
|
||||
//
|
||||
TEST_F(MindDataTestShuffleOp, TestShuffleBasic2) {
|
||||
Status rc;
|
||||
|
@ -126,16 +126,16 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic2) {
|
|||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::string dataset_path;
|
||||
dataset_path = datasets_root_path_ + "/testDataset1";
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
rc = StorageOp::Builder()
|
||||
.SetDatasetFilesDir(dataset_path)
|
||||
dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data";
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
rc = TFReaderOp::Builder()
|
||||
.SetDatasetFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(3)
|
||||
.SetWorkerConnectorSize(16)
|
||||
.SetNumWorkers(2)
|
||||
.Build(&my_storage_op);
|
||||
.Build(&my_tfreader_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(my_storage_op);
|
||||
rc = my_tree->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
std::shared_ptr<ShuffleOp> my_shuffle_op;
|
||||
rc = ShuffleOp::Builder().SetShuffleSize(4).SetShuffleSeed(100).SetRowsPerBuffer(3).Build(&my_shuffle_op);
|
||||
|
@ -144,7 +144,7 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic2) {
|
|||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
// Set children/root layout.
|
||||
rc = my_shuffle_op->AddChild(my_storage_op);
|
||||
rc = my_shuffle_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssignRoot(my_shuffle_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -183,11 +183,11 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic2) {
|
|||
// - Shuffle size captures the entire dataset size (actually sets a value that is larger than the
|
||||
// number of rows in the dataset).
|
||||
//
|
||||
// Tree: shuffle over storage
|
||||
// Tree: shuffle over TFReader
|
||||
//
|
||||
// ShuffleOp
|
||||
// |
|
||||
// StorageOp
|
||||
// TFReaderOp
|
||||
//
|
||||
TEST_F(MindDataTestShuffleOp, TestShuffleBasic3) {
|
||||
Status rc;
|
||||
|
@ -197,16 +197,16 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic3) {
|
|||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::string dataset_path;
|
||||
dataset_path = datasets_root_path_ + "/testDataset1";
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
rc = StorageOp::Builder()
|
||||
.SetDatasetFilesDir(dataset_path)
|
||||
dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data";
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
rc = TFReaderOp::Builder()
|
||||
.SetDatasetFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(3)
|
||||
.SetWorkerConnectorSize(16)
|
||||
.SetNumWorkers(2)
|
||||
.Build(&my_storage_op);
|
||||
.Build(&my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
my_tree->AssociateNode(my_storage_op);
|
||||
my_tree->AssociateNode(my_tfreader_op);
|
||||
std::shared_ptr<ShuffleOp> my_shuffle_op;
|
||||
rc = ShuffleOp::Builder().SetShuffleSize(100).SetRowsPerBuffer(3).Build(&my_shuffle_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -214,7 +214,7 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic3) {
|
|||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
// Set children/root layout.
|
||||
rc = my_shuffle_op->AddChild(my_storage_op);
|
||||
rc = my_shuffle_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssignRoot(my_shuffle_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -255,13 +255,13 @@ TEST_F(MindDataTestShuffleOp, TestShuffleBasic3) {
|
|||
// - shuffle seed is given, and subsequent epochs will change the seed each time.
|
||||
// - Repeat count of 2
|
||||
//
|
||||
// Tree: Repeat over shuffle over storage
|
||||
// Tree: Repeat over shuffle over TFReader
|
||||
//
|
||||
// Repeat
|
||||
// |
|
||||
// shuffle
|
||||
// |
|
||||
// StorageOp
|
||||
// TFReaderOp
|
||||
//
|
||||
TEST_F(MindDataTestShuffleOp, TestRepeatShuffle) {
|
||||
Status rc;
|
||||
|
@ -271,16 +271,16 @@ TEST_F(MindDataTestShuffleOp, TestRepeatShuffle) {
|
|||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::string dataset_path;
|
||||
dataset_path = datasets_root_path_ + "/testDataset1";
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
rc = StorageOp::Builder()
|
||||
.SetDatasetFilesDir(dataset_path)
|
||||
dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data";
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
rc = TFReaderOp::Builder()
|
||||
.SetDatasetFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(3)
|
||||
.SetWorkerConnectorSize(16)
|
||||
.SetNumWorkers(2)
|
||||
.Build(&my_storage_op);
|
||||
.Build(&my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(my_storage_op);
|
||||
rc = my_tree->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
std::shared_ptr<ShuffleOp> my_shuffle_op;
|
||||
rc = ShuffleOp::Builder()
|
||||
|
@ -302,7 +302,7 @@ TEST_F(MindDataTestShuffleOp, TestRepeatShuffle) {
|
|||
// Set children/root layout.
|
||||
rc = my_repeat_op->AddChild(my_shuffle_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_shuffle_op->AddChild(my_storage_op);
|
||||
rc = my_shuffle_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssignRoot(my_repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
|
|
@ -1,165 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/core/client.h"
|
||||
#include "common/common.h"
|
||||
#include "common/utils.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
namespace common = mindspore::common;
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
|
||||
class MindDataTestStorageOp : public UT::DatasetOpTesting {
|
||||
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestStorageOp, TestStorageBasic1) {
|
||||
|
||||
// single storage op and nothing else
|
||||
//
|
||||
// StorageOp
|
||||
|
||||
MS_LOG(INFO) << "UT test TestStorageBasic1.";
|
||||
|
||||
Status rc;
|
||||
|
||||
// Start with an empty execution tree
|
||||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
// Test info:
|
||||
// Dataset from testDataset1 has 10 rows, 2 columns.
|
||||
// RowsPerBuffer buffer setting of 2 divides evenly into total rows.
|
||||
std::string dataset_path;
|
||||
dataset_path = datasets_root_path_ + "/testDataset1";
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
StorageOp::Builder builder;
|
||||
builder.SetDatasetFilesDir(dataset_path)
|
||||
.SetRowsPerBuffer(2)
|
||||
.SetWorkerConnectorSize(16)
|
||||
.SetNumWorkers(1);
|
||||
rc = builder.Build(&my_storage_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
my_tree->AssociateNode(my_storage_op);
|
||||
|
||||
// Set children/root layout.
|
||||
my_tree->AssignRoot(my_storage_op);
|
||||
|
||||
MS_LOG(INFO) << "Launching tree and begin iteration.";
|
||||
my_tree->Prepare();
|
||||
my_tree->Launch();
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
DatasetIterator di(my_tree);
|
||||
TensorRow tensor_list;
|
||||
rc = di.FetchNextTensorRow(&tensor_list);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
int row_count = 0;
|
||||
while (!tensor_list.empty()) {
|
||||
MS_LOG(INFO) << "Row display for row #: " << row_count << ".";
|
||||
|
||||
// Display the tensor by calling the printer on it
|
||||
for (int i = 0; i < tensor_list.size(); i++) {
|
||||
std::ostringstream ss;
|
||||
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
|
||||
MS_LOG(INFO) << "Tensor print: " << common::SafeCStr(ss.str()) << ".";
|
||||
}
|
||||
|
||||
rc = di.FetchNextTensorRow(&tensor_list);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
row_count++;
|
||||
}
|
||||
ASSERT_EQ(row_count, 10); // Should be 10 rows fetched
|
||||
|
||||
// debugging temp. what happens if we keep fetching..
|
||||
rc = di.FetchNextTensorRow(&tensor_list);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = di.FetchNextTensorRow(&tensor_list);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestStorageOp, TestStorageBasic2) {
|
||||
|
||||
// single storage op and nothing else
|
||||
//
|
||||
// StorageOp
|
||||
|
||||
MS_LOG(INFO) << "UT test TestStorageBasic1.";
|
||||
|
||||
Status rc;
|
||||
|
||||
// Start with an empty execution tree
|
||||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
// Test info:
|
||||
// Dataset from testDataset1 has 10 rows, 2 columns.
|
||||
// RowsPerBuffer buffer setting of 3 yields 4 buffers, with the last buffer having a single row
|
||||
// only. 2 workers.
|
||||
// Test a column selection instead of all columns as well.
|
||||
std::string dataset_path;
|
||||
dataset_path = datasets_root_path_ + "/testDataset1";
|
||||
std::vector<std::string> column_list;
|
||||
std::string label_colname("label");
|
||||
column_list.push_back(label_colname);
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
StorageOp::Builder builder;
|
||||
builder.SetDatasetFilesDir(dataset_path)
|
||||
.SetRowsPerBuffer(3)
|
||||
.SetWorkerConnectorSize(16)
|
||||
.SetNumWorkers(2)
|
||||
.SetColumnsToLoad(column_list);
|
||||
rc = builder.Build(&my_storage_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
my_tree->AssociateNode(my_storage_op);
|
||||
|
||||
// Set children/root layout.
|
||||
my_tree->AssignRoot(my_storage_op);
|
||||
|
||||
MS_LOG(INFO) << "Launching tree and begin iteration.";
|
||||
my_tree->Prepare();
|
||||
my_tree->Launch();
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
DatasetIterator di(my_tree);
|
||||
TensorRow tensor_list;
|
||||
rc = di.FetchNextTensorRow(&tensor_list);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
int row_count = 0;
|
||||
while (!tensor_list.empty()) {
|
||||
MS_LOG(INFO) << "Row display for row #: " << row_count << ".";
|
||||
|
||||
// Display the tensor by calling the printer on it
|
||||
for (int i = 0; i < tensor_list.size(); i++) {
|
||||
std::ostringstream ss;
|
||||
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
|
||||
MS_LOG(INFO) << "Tensor print: " << common::SafeCStr(ss.str()) << ".";
|
||||
}
|
||||
|
||||
rc = di.FetchNextTensorRow(&tensor_list);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
row_count++;
|
||||
}
|
||||
ASSERT_EQ(row_count, 10); // Should be 10 rows fetched
|
||||
}
|
|
@ -51,35 +51,35 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
|
|||
*
|
||||
* OpId(2) ZipOp
|
||||
* / \
|
||||
* OpId(0) StorageOp OpId(1) StorageOp
|
||||
* OpId(0) TFReaderOp OpId(1) TFReaderOp
|
||||
* Start with an empty execution tree
|
||||
*/
|
||||
Status rc;
|
||||
MS_LOG(INFO) << "UT test TestZipBasic.";
|
||||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
// Creating StorageOp
|
||||
// Creating TFReaderOp
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1";
|
||||
std::string dataset_path2 = datasets_root_path_ + "/test_tf_file_3_images_2";
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
rc = StorageOp::Builder()
|
||||
.SetDatasetFilesDir(dataset_path)
|
||||
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
|
||||
std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
rc = TFReaderOp::Builder()
|
||||
.SetDatasetFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(2)
|
||||
.SetWorkerConnectorSize(16)
|
||||
.SetNumWorkers(1)
|
||||
.Build(&my_storage_op);
|
||||
.Build(&my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(my_storage_op);
|
||||
rc = my_tree->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
std::shared_ptr<StorageOp> my_storage_op2;
|
||||
rc = StorageOp::Builder()
|
||||
.SetDatasetFilesDir(dataset_path2)
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op2;
|
||||
rc = TFReaderOp::Builder()
|
||||
.SetDatasetFilesList({dataset_path2})
|
||||
.SetRowsPerBuffer(2)
|
||||
.SetWorkerConnectorSize(1)
|
||||
.SetNumWorkers(1)
|
||||
.Build(&my_storage_op2);
|
||||
.Build(&my_tfreader_op2);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(my_storage_op2);
|
||||
rc = my_tree->AssociateNode(my_tfreader_op2);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
// Creating DatasetOp
|
||||
|
@ -89,9 +89,9 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
|
|||
|
||||
rc = my_tree->AssociateNode(zip_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = zip_op->AddChild(std::move(my_storage_op));
|
||||
rc = zip_op->AddChild(std::move(my_tfreader_op));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = zip_op->AddChild(std::move(my_storage_op2));
|
||||
rc = zip_op->AddChild(std::move(my_tfreader_op2));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssignRoot(zip_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -125,6 +125,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
|
|||
EXPECT_TRUE(rc.IsOk());
|
||||
row_count++;
|
||||
}
|
||||
MS_LOG(WARNING) <<"row count is: " << row_count;
|
||||
ASSERT_EQ(row_count, 3); // Should be 3 rows fetched
|
||||
}
|
||||
|
||||
|
@ -135,7 +136,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
|
|||
*
|
||||
* OpId(2) ZipOp
|
||||
* / \
|
||||
* OpId(0) StorageOp OpId(1) StorageOp
|
||||
* OpId(0) TFReaderOp OpId(1) TFReaderOp
|
||||
*
|
||||
* Start with an empty execution tree
|
||||
*/
|
||||
|
@ -143,27 +144,27 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
|
|||
MS_LOG(INFO) << "UT test TestZipRepeat.";
|
||||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1";
|
||||
std::string dataset_path2 = datasets_root_path_ + "/test_tf_file_3_images_2";
|
||||
std::shared_ptr<StorageOp> my_storage_op;
|
||||
rc = StorageOp::Builder()
|
||||
.SetDatasetFilesDir(dataset_path)
|
||||
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
|
||||
std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
rc = TFReaderOp::Builder()
|
||||
.SetDatasetFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(2)
|
||||
.SetWorkerConnectorSize(16)
|
||||
.SetNumWorkers(1)
|
||||
.Build(&my_storage_op);
|
||||
.Build(&my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(my_storage_op);
|
||||
rc = my_tree->AssociateNode(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
std::shared_ptr<StorageOp> my_storage_op2;
|
||||
rc = StorageOp::Builder()
|
||||
.SetDatasetFilesDir(dataset_path2)
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op2;
|
||||
rc = TFReaderOp::Builder()
|
||||
.SetDatasetFilesList({dataset_path2})
|
||||
.SetRowsPerBuffer(2)
|
||||
.SetWorkerConnectorSize(1)
|
||||
.SetNumWorkers(1)
|
||||
.Build(&my_storage_op2);
|
||||
.Build(&my_tfreader_op2);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(my_storage_op2);
|
||||
rc = my_tree->AssociateNode(my_tfreader_op2);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
// Creating DatasetOp
|
||||
std::shared_ptr<ZipOp> zip_op;
|
||||
|
@ -171,9 +172,9 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
|
|||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(zip_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = zip_op->AddChild(std::move(my_storage_op));
|
||||
rc = zip_op->AddChild(std::move(my_tfreader_op));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = zip_op->AddChild(std::move(my_storage_op2));
|
||||
rc = zip_op->AddChild(std::move(my_tfreader_op2));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
// Builder(num_of_repeats)
|
||||
|
|
Binary file not shown.
|
@ -1,8 +0,0 @@
|
|||
{
|
||||
"deviceNum":3,
|
||||
"deviceId":1,
|
||||
"shardConfig":"ALL",
|
||||
"shuffle":"ON",
|
||||
"seed": 0,
|
||||
"epoch": 2
|
||||
}
|
|
@ -1,8 +0,0 @@
|
|||
{
|
||||
"deviceNum":7,
|
||||
"deviceId":6,
|
||||
"shardConfig":"RANDOM",
|
||||
"shuffle":"ON",
|
||||
"seed": 0,
|
||||
"epoch": 1
|
||||
}
|
|
@ -1,8 +0,0 @@
|
|||
{
|
||||
"deviceNum":3,
|
||||
"deviceId":1,
|
||||
"shardConfig":"RANDOM",
|
||||
"shuffle":"ON",
|
||||
"seed": 0,
|
||||
"epoch": 1
|
||||
}
|
|
@ -1,8 +0,0 @@
|
|||
{
|
||||
"deviceNum":3,
|
||||
"deviceId":1,
|
||||
"shardConfig":"UNIQUE",
|
||||
"shuffle":"ON",
|
||||
"seed": 0,
|
||||
"epoch": 3
|
||||
}
|