forked from mindspore-Ecosystem/mindspore
Added IR breakdown for remaining ops after PR7596
fix android build failure fix include in datasets.cc attempting fix android build fix complie err x86
This commit is contained in:
parent
d50736df2c
commit
cfc50e231a
File diff suppressed because it is too large
Load Diff
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -4,19 +4,17 @@ add_subdirectory(source)
|
|||
|
||||
set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
|
||||
batch_node.cc
|
||||
bucket_batch_by_length_node.cc
|
||||
build_vocab_node.cc
|
||||
concat_node.cc
|
||||
map_node.cc
|
||||
project_node.cc
|
||||
rename_node.cc
|
||||
repeat_node.cc
|
||||
shuffle_node.cc
|
||||
skip_node.cc
|
||||
take_node.cc
|
||||
zip_node.cc
|
||||
)
|
||||
|
||||
if (NOT ENABLE_ANDROID)
|
||||
set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
|
||||
${DATASET_ENGINE_IR_DATASETOPS_SRC_FILES}
|
||||
bucket_batch_by_length_node.cc
|
||||
build_vocab_node.cc)
|
||||
endif ()
|
||||
|
||||
add_library(engine-ir-datasetops OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SRC_FILES})
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
@ -61,4 +61,4 @@ class BucketBatchByLengthNode : public Dataset {
|
|||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -58,4 +58,4 @@ class BuildVocabNode : public Dataset {
|
|||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -50,4 +50,4 @@ class ConcatNode : public Dataset {
|
|||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
|
||||
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
|
||||
const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache)
|
||||
: operations_(operations),
|
||||
input_columns_(input_columns),
|
||||
output_columns_(output_columns),
|
||||
project_columns_(project_columns),
|
||||
Dataset(std::move(cache)) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::vector<std::shared_ptr<TensorOp>> tensor_ops;
|
||||
|
||||
// Build tensorOp from tensorOperation vector
|
||||
// This is to ensure each iterator hold its own copy of the tensorOp objects.
|
||||
(void)std::transform(
|
||||
operations_.begin(), operations_.end(), std::back_inserter(tensor_ops),
|
||||
[](std::shared_ptr<TensorOperation> operation) -> std::shared_ptr<TensorOp> { return operation->Build(); });
|
||||
|
||||
// This parameter will be removed with next rebase
|
||||
std::vector<std::string> col_orders;
|
||||
auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_);
|
||||
if (!project_columns_.empty()) {
|
||||
auto project_op = std::make_shared<ProjectOp>(project_columns_);
|
||||
node_ops.push_back(project_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
||||
node_ops.push_back(map_op);
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
Status MapNode::ValidateParams() {
|
||||
if (operations_.empty()) {
|
||||
std::string err_msg = "MapNode: No operation is specified.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (!input_columns_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "input_columns", input_columns_));
|
||||
}
|
||||
|
||||
if (!output_columns_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "output_columns", output_columns_));
|
||||
}
|
||||
|
||||
if (!project_columns_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "project_columns", project_columns_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
class MapNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
|
||||
std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {},
|
||||
const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr);
|
||||
|
||||
/// \brief Destructor
|
||||
~MapNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<TensorOperation>> operations_;
|
||||
std::vector<std::string> input_columns_;
|
||||
std::vector<std::string> output_columns_;
|
||||
std::vector<std::string> project_columns_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -51,4 +51,4 @@ class ProjectNode : public Dataset {
|
|||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -53,4 +53,4 @@ class RenameNode : public Dataset {
|
|||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
@ -53,4 +53,4 @@ class RepeatNode : public Dataset {
|
|||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
@ -49,4 +49,4 @@ class ShuffleNode : public Dataset {
|
|||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/skip_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
// Constructor for SkipNode
|
||||
SkipNode::SkipNode(std::shared_ptr<Dataset> child, int32_t count) : skip_count_(count) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Function to build the SkipOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to validate the parameters for SkipNode
|
||||
Status SkipNode::ValidateParams() {
|
||||
if (skip_count_ <= -1) {
|
||||
std::string err_msg = "SkipNode: skip_count should not be negative, skip_count: " + std::to_string(skip_count_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace api {
|
||||
class SkipNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit SkipNode(std::shared_ptr<Dataset> child, int32_t count);
|
||||
|
||||
/// \brief Destructor
|
||||
~SkipNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int32_t skip_count_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_
|
|
@ -2,7 +2,21 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
|
|||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
|
||||
set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
||||
album_node.cc
|
||||
celeba_node.cc
|
||||
cifar100_node.cc
|
||||
cifar10_node.cc
|
||||
clue_node.cc
|
||||
coco_node.cc
|
||||
csv_node.cc
|
||||
image_folder_node.cc
|
||||
manifest_node.cc
|
||||
minddata_node.cc
|
||||
mnist_node.cc
|
||||
random_node.cc
|
||||
text_file_node.cc
|
||||
tf_record_node.cc
|
||||
voc_node.cc
|
||||
)
|
||||
|
||||
if (ENABLE_PYTHON)
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/album_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// Constructor for AlbumNode
|
||||
AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
|
||||
const std::vector<std::string> &column_names, bool decode,
|
||||
const std::shared_ptr<SamplerObj> &sampler)
|
||||
: dataset_dir_(dataset_dir),
|
||||
schema_path_(data_schema),
|
||||
column_names_(column_names),
|
||||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
||||
Status AlbumNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumNode", {schema_path_}));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("AlbumNode", sampler_));
|
||||
|
||||
if (!column_names_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AlbumNode", "column_names", column_names_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build AlbumNode
|
||||
std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(schema->LoadSchemaFile(schema_path_, column_names_));
|
||||
|
||||
// Argument that is not exposed to user in the API.
|
||||
std::set<std::string> extensions = {};
|
||||
|
||||
node_ops.push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
decode_, extensions, std::move(schema), std::move(sampler_->Build())));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
class AlbumNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
|
||||
const std::vector<std::string> &column_names, bool decode, const std::shared_ptr<SamplerObj> &sampler);
|
||||
|
||||
/// \brief Destructor
|
||||
~AlbumNode() = default;
|
||||
|
||||
/// \brief a base class override function to create a runtime dataset op object from this class
|
||||
/// \return shared pointer to the newly created DatasetOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string schema_path_;
|
||||
std::vector<std::string> column_names_;
|
||||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_
|
|
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// Constructor for CelebANode
|
||||
CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
|
||||
const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache)
|
||||
: Dataset(cache),
|
||||
dataset_dir_(dataset_dir),
|
||||
usage_(usage),
|
||||
sampler_(sampler),
|
||||
decode_(decode),
|
||||
extensions_(extensions) {}
|
||||
|
||||
Status CelebANode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("CelebANode", usage_, {"all", "train", "valid", "test"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build CelebANode
|
||||
std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
// label is like this:0 1 0 0 1......
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
|
||||
node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
decode_, usage_, extensions_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CELEBA_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CELEBA_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
class CelebANode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
|
||||
const bool &decode, const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CelebANode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
bool decode_;
|
||||
std::set<std::string> extensions_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CELEBA_NODE_H_
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
// Constructor for Cifar100Node
|
||||
Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
Status Cifar100Node::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("Cifar100Node", usage_, {"train", "test", "all"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build CifarOp for Cifar100
|
||||
std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR100_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR100_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
class Cifar100Node : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~Cifar100Node() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR100_NODE_H_
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
// Constructor for Cifar10Node
|
||||
Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
Status Cifar10Node::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("Cifar10Node", usage_, {"train", "test", "all"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build CifarOp for Cifar10
|
||||
std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR10_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR10_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
class Cifar10Node : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~Cifar10Node() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR10_NODE_H_
|
|
@ -0,0 +1,218 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
// Constructor for CLUENode
|
||||
CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples,
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
dataset_files_(clue_files),
|
||||
task_(task),
|
||||
usage_(usage),
|
||||
num_samples_(num_samples),
|
||||
shuffle_(shuffle),
|
||||
num_shards_(num_shards),
|
||||
shard_id_(shard_id) {}
|
||||
|
||||
Status CLUENode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", usage_, {"train", "test", "eval"}));
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "CLUENode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("CLUENode", num_shards_, shard_id_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to split string based on a character delimiter
|
||||
std::vector<std::string> CLUENode::split(const std::string &s, char delim) {
|
||||
std::vector<std::string> res;
|
||||
std::stringstream ss(s);
|
||||
std::string item;
|
||||
|
||||
while (getline(ss, item, delim)) {
|
||||
res.push_back(item);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// Function to build CLUENode
|
||||
std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::map<std::string, std::string> key_map;
|
||||
if (task_ == "AFQMC") {
|
||||
if (usage_ == "train") {
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
key_map["label"] = "label";
|
||||
} else if (usage_ == "test") {
|
||||
key_map["id"] = "id";
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
} else if (usage_ == "eval") {
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
key_map["label"] = "label";
|
||||
}
|
||||
} else if (task_ == "CMNLI") {
|
||||
if (usage_ == "train") {
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
key_map["label"] = "label";
|
||||
} else if (usage_ == "test") {
|
||||
key_map["id"] = "id";
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
} else if (usage_ == "eval") {
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
key_map["label"] = "label";
|
||||
}
|
||||
} else if (task_ == "CSL") {
|
||||
if (usage_ == "train") {
|
||||
key_map["id"] = "id";
|
||||
key_map["abst"] = "abst";
|
||||
key_map["keyword"] = "keyword";
|
||||
key_map["label"] = "label";
|
||||
} else if (usage_ == "test") {
|
||||
key_map["id"] = "id";
|
||||
key_map["abst"] = "abst";
|
||||
key_map["keyword"] = "keyword";
|
||||
} else if (usage_ == "eval") {
|
||||
key_map["id"] = "id";
|
||||
key_map["abst"] = "abst";
|
||||
key_map["keyword"] = "keyword";
|
||||
key_map["label"] = "label";
|
||||
}
|
||||
} else if (task_ == "IFLYTEK") {
|
||||
if (usage_ == "train") {
|
||||
key_map["label"] = "label";
|
||||
key_map["label_des"] = "label_des";
|
||||
key_map["sentence"] = "sentence";
|
||||
} else if (usage_ == "test") {
|
||||
key_map["id"] = "id";
|
||||
key_map["sentence"] = "sentence";
|
||||
} else if (usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
key_map["label_des"] = "label_des";
|
||||
key_map["sentence"] = "sentence";
|
||||
}
|
||||
} else if (task_ == "TNEWS") {
|
||||
if (usage_ == "train") {
|
||||
key_map["label"] = "label";
|
||||
key_map["label_desc"] = "label_desc";
|
||||
key_map["sentence"] = "sentence";
|
||||
key_map["keywords"] = "keywords";
|
||||
} else if (usage_ == "test") {
|
||||
key_map["id"] = "id";
|
||||
key_map["sentence"] = "sentence";
|
||||
key_map["keywords"] = "keywords";
|
||||
} else if (usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
key_map["label_desc"] = "label_desc";
|
||||
key_map["sentence"] = "sentence";
|
||||
key_map["keywords"] = "keywords";
|
||||
}
|
||||
} else if (task_ == "WSC") {
|
||||
if (usage_ == "train") {
|
||||
key_map["span1_index"] = "target/span1_index";
|
||||
key_map["span2_index"] = "target/span2_index";
|
||||
key_map["span1_text"] = "target/span1_text";
|
||||
key_map["span2_text"] = "target/span2_text";
|
||||
key_map["idx"] = "idx";
|
||||
key_map["label"] = "label";
|
||||
key_map["text"] = "text";
|
||||
} else if (usage_ == "test") {
|
||||
key_map["span1_index"] = "target/span1_index";
|
||||
key_map["span2_index"] = "target/span2_index";
|
||||
key_map["span1_text"] = "target/span1_text";
|
||||
key_map["span2_text"] = "target/span2_text";
|
||||
key_map["idx"] = "idx";
|
||||
key_map["text"] = "text";
|
||||
} else if (usage_ == "eval") {
|
||||
key_map["span1_index"] = "target/span1_index";
|
||||
key_map["span2_index"] = "target/span2_index";
|
||||
key_map["span1_text"] = "target/span1_text";
|
||||
key_map["span2_text"] = "target/span2_text";
|
||||
key_map["idx"] = "idx";
|
||||
key_map["label"] = "label";
|
||||
key_map["text"] = "text";
|
||||
}
|
||||
}
|
||||
|
||||
ColKeyMap ck_map;
|
||||
for (auto &p : key_map) {
|
||||
ck_map.insert({p.first, split(p.second, '/')});
|
||||
}
|
||||
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
|
||||
// Sort the dataset files in a lexicographical order
|
||||
std::vector<std::string> sorted_dataset_files = dataset_files_;
|
||||
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
|
||||
|
||||
std::shared_ptr<ClueOp> clue_op =
|
||||
std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
|
||||
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
|
||||
RETURN_EMPTY_IF_ERROR(clue_op->Init());
|
||||
if (shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows));
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
||||
node_ops.push_back(clue_op);
|
||||
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
/// \class CLUENode
|
||||
/// \brief A Dataset derived class to represent CLUE dataset
|
||||
class CLUENode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CLUENode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
/// \brief Split string based on a character delimiter
|
||||
/// \return A string vector
|
||||
std::vector<std::string> split(const std::string &s, char delim);
|
||||
|
||||
std::vector<std::string> dataset_files_;
|
||||
std::string task_;
|
||||
std::string usage_;
|
||||
int64_t num_samples_;
|
||||
ShuffleMode shuffle_;
|
||||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_
|
|
@ -0,0 +1,122 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// Constructor for CocoNode
|
||||
CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
|
||||
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
annotation_file_(annotation_file),
|
||||
task_(task),
|
||||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
||||
Status CocoNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("CocoNode", sampler_));
|
||||
|
||||
Path annotation_file(annotation_file_);
|
||||
if (!annotation_file.Exists()) {
|
||||
std::string err_msg = "CocoNode: annotation_file is invalid or does not exist.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("CocoNode", task_, {"Detection", "Stuff", "Panoptic", "Keypoint"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build CocoNode
|
||||
std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
CocoOp::TaskType task_type;
|
||||
if (task_ == "Detection") {
|
||||
task_type = CocoOp::TaskType::Detection;
|
||||
} else if (task_ == "Stuff") {
|
||||
task_type = CocoOp::TaskType::Stuff;
|
||||
} else if (task_ == "Keypoint") {
|
||||
task_type = CocoOp::TaskType::Keypoint;
|
||||
} else if (task_ == "Panoptic") {
|
||||
task_type = CocoOp::TaskType::Panoptic;
|
||||
}
|
||||
|
||||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor(std::string("image"), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
switch (task_type) {
|
||||
case CocoOp::TaskType::Detection:
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string("bbox"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string("category_id"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
break;
|
||||
case CocoOp::TaskType::Stuff:
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string("segmentation"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
break;
|
||||
case CocoOp::TaskType::Keypoint:
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string("keypoints"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string("num_keypoints"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
break;
|
||||
case CocoOp::TaskType::Panoptic:
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string("bbox"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string("category_id"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor(std::string("area"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "CocoNode::Build : Invalid task type";
|
||||
return {};
|
||||
}
|
||||
std::shared_ptr<CocoOp> op =
|
||||
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
||||
node_ops.push_back(op);
|
||||
|
||||
return node_ops;
|
||||
}
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_COCO_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_COCO_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
class CocoNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
|
||||
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CocoNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string annotation_file_;
|
||||
std::string task_;
|
||||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_COCO_NODE_H_
|
|
@ -0,0 +1,127 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// Constructor for CSVNode
|
||||
CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
|
||||
const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
|
||||
const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
|
||||
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
dataset_files_(csv_files),
|
||||
field_delim_(field_delim),
|
||||
column_defaults_(column_defaults),
|
||||
column_names_(column_names),
|
||||
num_samples_(num_samples),
|
||||
shuffle_(shuffle),
|
||||
num_shards_(num_shards),
|
||||
shard_id_(shard_id) {}
|
||||
|
||||
Status CSVNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_));
|
||||
|
||||
if (field_delim_ == '"' || field_delim_ == '\r' || field_delim_ == '\n') {
|
||||
std::string err_msg = "CSVNode: The field delimiter should not be \", \\r, \\n";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "CSVNode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("CSVNode", num_shards_, shard_id_));
|
||||
|
||||
if (find(column_defaults_.begin(), column_defaults_.end(), nullptr) != column_defaults_.end()) {
|
||||
std::string err_msg = "CSVNode: column_default should not be null.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (!column_names_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("CSVNode", "column_names", column_names_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build CSVNode
|
||||
std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
|
||||
// Sort the dataset files in a lexicographical order
|
||||
std::vector<std::string> sorted_dataset_files = dataset_files_;
|
||||
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
|
||||
|
||||
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
|
||||
for (auto v : column_defaults_) {
|
||||
if (v->type == CsvType::INT) {
|
||||
column_default_list.push_back(
|
||||
std::make_shared<CsvOp::Record<int>>(CsvOp::INT, std::dynamic_pointer_cast<CsvRecord<int>>(v)->value));
|
||||
} else if (v->type == CsvType::FLOAT) {
|
||||
column_default_list.push_back(
|
||||
std::make_shared<CsvOp::Record<float>>(CsvOp::FLOAT, std::dynamic_pointer_cast<CsvRecord<float>>(v)->value));
|
||||
} else if (v->type == CsvType::STRING) {
|
||||
column_default_list.push_back(std::make_shared<CsvOp::Record<std::string>>(
|
||||
CsvOp::STRING, std::dynamic_pointer_cast<CsvRecord<std::string>>(v)->value));
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
|
||||
sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_,
|
||||
num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
|
||||
RETURN_EMPTY_IF_ERROR(csv_op->Init());
|
||||
if (shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows));
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
||||
node_ops.push_back(csv_op);
|
||||
|
||||
return node_ops;
|
||||
}
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,82 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api { /// \brief Base class of CSV Record
|
||||
/// \brief Record type for CSV
|
||||
enum CsvType : uint8_t { INT = 0, FLOAT, STRING };
|
||||
|
||||
class CsvBase {
|
||||
public:
|
||||
CsvBase() = default;
|
||||
explicit CsvBase(CsvType t) : type(t) {}
|
||||
virtual ~CsvBase() {}
|
||||
CsvType type;
|
||||
};
|
||||
|
||||
/// \brief CSV Record that can represent integer, float and string.
|
||||
template <typename T>
|
||||
class CsvRecord : public CsvBase {
|
||||
public:
|
||||
CsvRecord() = default;
|
||||
CsvRecord(CsvType t, T v) : CsvBase(t), value(v) {}
|
||||
~CsvRecord() {}
|
||||
T value;
|
||||
};
|
||||
|
||||
class CSVNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CSVNode(const std::vector<std::string> &dataset_files, char field_delim,
|
||||
const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::string> &column_names,
|
||||
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CSVNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
char field_delim_;
|
||||
std::vector<std::shared_ptr<CsvBase>> column_defaults_;
|
||||
std::vector<std::string> column_names_;
|
||||
int64_t num_samples_;
|
||||
ShuffleMode shuffle_;
|
||||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_
|
|
@ -0,0 +1,90 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &usage,
|
||||
const std::shared_ptr<SamplerObj> &sampler,
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
dataset_file_(dataset_file),
|
||||
usage_(usage),
|
||||
decode_(decode),
|
||||
class_index_(class_indexing),
|
||||
sampler_(sampler) {}
|
||||
|
||||
Status ManifestNode::ValidateParams() {
|
||||
std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
|
||||
for (char c : dataset_file_) {
|
||||
auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c);
|
||||
if (p != forbidden_symbols.end()) {
|
||||
std::string err_msg = "ManifestNode: filename should not contain :*?\"<>|`&;\'";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
Path manifest_file(dataset_file_);
|
||||
if (!manifest_file.Exists()) {
|
||||
std::string err_msg = "ManifestNode: dataset file: [" + dataset_file_ + "] is invalid or not exist";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("ManifestNode", sampler_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("ManifestNode", usage_, {"train", "eval", "inference"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
std::shared_ptr<ManifestOp> manifest_op;
|
||||
manifest_op =
|
||||
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
|
||||
class_index_, std::move(schema), std::move(sampler_->Build()), usage_);
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
||||
node_ops.push_back(manifest_op);
|
||||
|
||||
return node_ops;
|
||||
}
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MANIFEST_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MANIFEST_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
class ManifestNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~ManifestNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_file_;
|
||||
std::string usage_;
|
||||
bool decode_;
|
||||
std::map<std::string, int32_t> class_index_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MANIFEST_NODE_H_
|
|
@ -0,0 +1,165 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded)
|
||||
: dataset_file_(std::string()),
|
||||
dataset_files_(dataset_files),
|
||||
search_for_pattern_(false),
|
||||
columns_list_(columns_list),
|
||||
sampler_(sampler),
|
||||
padded_sample_(padded_sample),
|
||||
sample_bytes_({}),
|
||||
num_padded_(num_padded) {}
|
||||
|
||||
MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded)
|
||||
: dataset_file_(dataset_file),
|
||||
dataset_files_({}),
|
||||
search_for_pattern_(true),
|
||||
columns_list_(columns_list),
|
||||
sampler_(sampler),
|
||||
padded_sample_(padded_sample),
|
||||
sample_bytes_({}),
|
||||
num_padded_(num_padded) {}
|
||||
|
||||
Status MindDataNode::ValidateParams() {
|
||||
if (!search_for_pattern_ && dataset_files_.size() > 4096) {
|
||||
std::string err_msg =
|
||||
"MindDataNode: length of dataset_file must be less than or equal to 4096, dataset_file length: " +
|
||||
std::to_string(dataset_file_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
std::vector<std::string> dataset_file_vec =
|
||||
search_for_pattern_ ? std::vector<std::string>{dataset_file_} : dataset_files_;
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("MindDataNode", dataset_file_vec));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("MindDataNode", sampler_));
|
||||
|
||||
if (!columns_list_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MindDataNode", "columns_list", columns_list_));
|
||||
}
|
||||
|
||||
if (padded_sample_ != nullptr) {
|
||||
if (num_padded_ < 0) {
|
||||
std::string err_msg =
|
||||
"MindDataNode: num_padded must be greater than or equal to zero, num_padded: " + std::to_string(num_padded_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (columns_list_.empty()) {
|
||||
std::string err_msg = "MindDataNode: padded_sample is specified and requires columns_list as well";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (std::string &column : columns_list_) {
|
||||
if (padded_sample_.find(column) == padded_sample_.end()) {
|
||||
std::string err_msg = "MindDataNode: " + column + " in columns_list does not match any column in padded_sample";
|
||||
MS_LOG(ERROR) << err_msg << ", padded_sample: " << padded_sample_;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (num_padded_ > 0) {
|
||||
if (padded_sample_ == nullptr) {
|
||||
std::string err_msg = "MindDataNode: num_padded is specified but padded_sample is not";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helper function to create runtime sampler for minddata dataset
|
||||
Status MindDataNode::BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,
|
||||
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators_,
|
||||
int64_t num_padded) {
|
||||
std::shared_ptr<mindrecord::ShardOperator> op = sampler->BuildForMindDataset();
|
||||
if (op == nullptr) {
|
||||
std::string err_msg =
|
||||
"MindDataNode: Unsupported sampler is supplied for MindDataset. Supported sampler list: "
|
||||
"SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler and DistributedSampler";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
|
||||
while (op != nullptr) {
|
||||
auto sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
|
||||
if (sampler_op && num_padded > 0) {
|
||||
sampler_op->SetNumPaddedSamples(num_padded);
|
||||
stack_ops.push(sampler_op);
|
||||
} else {
|
||||
stack_ops.push(op);
|
||||
}
|
||||
op = op->GetChildOp();
|
||||
}
|
||||
while (!stack_ops.empty()) {
|
||||
operators_->push_back(stack_ops.top());
|
||||
stack_ops.pop();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helper function to set sample_bytes from py::byte type
|
||||
void MindDataNode::SetSampleBytes(std::map<std::string, std::string> *sample_bytes) { sample_bytes_ = *sample_bytes; }
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::vector<std::shared_ptr<ShardOperator>> operators_;
|
||||
RETURN_EMPTY_IF_ERROR(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_));
|
||||
|
||||
std::shared_ptr<MindRecordOp> mindrecord_op;
|
||||
// If pass a string to MindData(), it will be treated as a pattern to search for matched files,
|
||||
// else if pass a vector to MindData(), it will be treated as specified files to be read
|
||||
if (search_for_pattern_) {
|
||||
std::vector<std::string> dataset_file_vec_ = {dataset_file_};
|
||||
mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, rows_per_buffer_, dataset_file_vec_,
|
||||
search_for_pattern_, connector_que_size_, columns_list_, operators_,
|
||||
num_padded_, padded_sample_, sample_bytes_);
|
||||
} else {
|
||||
mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, rows_per_buffer_, dataset_files_, search_for_pattern_,
|
||||
connector_que_size_, columns_list_, operators_, num_padded_,
|
||||
padded_sample_, sample_bytes_);
|
||||
}
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(mindrecord_op->Init());
|
||||
node_ops.push_back(mindrecord_op);
|
||||
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MINDDATA_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MINDDATA_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
class MindDataNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded);
|
||||
|
||||
/// \brief Constructor
|
||||
MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded);
|
||||
|
||||
/// \brief Destructor
|
||||
~MindDataNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Build sampler chain for minddata dataset
|
||||
/// \return Status Status::OK() if input sampler is valid
|
||||
Status BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,
|
||||
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators_,
|
||||
int64_t num_padded);
|
||||
|
||||
/// \brief Set sample_bytes when padded_sample has py::byte value
|
||||
/// \note Pybind will use this function to set sample_bytes into MindDataNode
|
||||
void SetSampleBytes(std::map<std::string, std::string> *sample_bytes);
|
||||
|
||||
private:
|
||||
std::string dataset_file_; // search_for_pattern_ will be true in this mode
|
||||
std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode
|
||||
bool search_for_pattern_;
|
||||
std::vector<std::string> columns_list_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
nlohmann::json padded_sample_;
|
||||
std::map<std::string, std::string> sample_bytes_; // enable in python
|
||||
int64_t num_padded_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MINDDATA_NODE_H_
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
Status MnistNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistNode", sampler_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("MnistNode", usage_, {"train", "test", "all"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
||||
node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
|
||||
connector_que_size_, std::move(schema), std::move(sampler_->Build())));
|
||||
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MNIST_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MNIST_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
class MnistNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~MnistNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MNIST_NODE_H_
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// ValidateParams for RandomNode
|
||||
Status RandomNode::ValidateParams() {
|
||||
if (total_rows_ < 0) {
|
||||
std::string err_msg =
|
||||
"RandomNode: total_rows must be greater than or equal 0, now get " + std::to_string(total_rows_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("RandomNode", sampler_));
|
||||
|
||||
if (!columns_list_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RandomNode", "columns_list", columns_list_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int32_t RandomNode::GenRandomInt(int32_t min, int32_t max) {
|
||||
std::uniform_int_distribution<int32_t> uniDist(min, max);
|
||||
return uniDist(rand_gen_);
|
||||
}
|
||||
|
||||
// Build for RandomNode
|
||||
std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
rand_gen_.seed(GetSeed()); // seed the random generator
|
||||
// If total rows was not given, then randomly pick a number
|
||||
std::shared_ptr<SchemaObj> schema_obj;
|
||||
if (!schema_path_.empty()) {
|
||||
schema_obj = Schema(schema_path_);
|
||||
if (schema_obj == nullptr) {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
std::string schema_json_string, schema_file_path;
|
||||
if (schema_ != nullptr) {
|
||||
schema_->set_dataset_type("Random");
|
||||
if (total_rows_ != 0) {
|
||||
schema_->set_num_rows(total_rows_);
|
||||
}
|
||||
schema_json_string = schema_->to_json();
|
||||
} else {
|
||||
schema_file_path = schema_path_;
|
||||
}
|
||||
|
||||
std::unique_ptr<DataSchema> data_schema;
|
||||
std::vector<std::string> columns_to_load;
|
||||
if (columns_list_.size() > 0) {
|
||||
columns_to_load = columns_list_;
|
||||
}
|
||||
if (!schema_file_path.empty() || !schema_json_string.empty()) {
|
||||
data_schema = std::make_unique<DataSchema>();
|
||||
if (!schema_file_path.empty()) {
|
||||
data_schema->LoadSchemaFile(schema_file_path, columns_to_load);
|
||||
} else if (!schema_json_string.empty()) {
|
||||
data_schema->LoadSchemaString(schema_json_string, columns_to_load);
|
||||
}
|
||||
}
|
||||
std::shared_ptr<RandomDataOp> op;
|
||||
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
|
||||
std::move(data_schema), std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
||||
node_ops.push_back(op);
|
||||
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,86 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
class RandomNode : public Dataset {
|
||||
public:
|
||||
// Some constants to provide limits to random generation.
|
||||
static constexpr int32_t kMaxNumColumns = 4;
|
||||
static constexpr int32_t kMaxRank = 4;
|
||||
static constexpr int32_t kMaxDimValue = 32;
|
||||
|
||||
/// \brief Constructor
|
||||
RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
total_rows_(total_rows),
|
||||
schema_path_(""),
|
||||
schema_(std::move(schema)),
|
||||
columns_list_(columns_list),
|
||||
sampler_(std::move(sampler)) {}
|
||||
|
||||
/// \brief Constructor
|
||||
RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
total_rows_(total_rows),
|
||||
schema_path_(schema_path),
|
||||
columns_list_(columns_list),
|
||||
sampler_(std::move(sampler)) {}
|
||||
|
||||
/// \brief Destructor
|
||||
~RandomNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
/// \brief A quick inline for producing a random number between (and including) min/max
|
||||
/// \param[in] min minimum number that can be generated.
|
||||
/// \param[in] max maximum number that can be generated.
|
||||
/// \return The generated random number
|
||||
int32_t GenRandomInt(int32_t min, int32_t max);
|
||||
|
||||
int32_t total_rows_;
|
||||
std::string schema_path_;
|
||||
std::shared_ptr<SchemaObj> schema_;
|
||||
std::vector<std::string> columns_list_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
std::mt19937 rand_gen_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// Constructor for TextFileNode
|
||||
TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
|
||||
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
dataset_files_(dataset_files),
|
||||
num_samples_(num_samples),
|
||||
shuffle_(shuffle),
|
||||
num_shards_(num_shards),
|
||||
shard_id_(shard_id) {}
|
||||
|
||||
Status TextFileNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_));
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "TextFileNode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("TextFileNode", num_shards_, shard_id_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build TextFileNode
|
||||
std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
|
||||
// Sort the dataset files in a lexicographical order
|
||||
std::vector<std::string> sorted_dataset_files = dataset_files_;
|
||||
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
|
||||
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
|
||||
// Create and initalize TextFileOp
|
||||
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
|
||||
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
|
||||
RETURN_EMPTY_IF_ERROR(text_file_op->Init());
|
||||
|
||||
if (shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows));
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
||||
// Add TextFileOp
|
||||
node_ops.push_back(text_file_op);
|
||||
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEXT_FILE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEXT_FILE_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
/// \class TextFileNode
|
||||
/// \brief A Dataset derived class to represent TextFile dataset
|
||||
class TextFileNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
|
||||
int32_t shard_id, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~TextFileNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
int32_t num_samples_;
|
||||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
ShuffleMode shuffle_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEXT_FILE_NODE_H_
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/jagged_connector.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// Validator for TFRecordNode
|
||||
Status TFRecordNode::ValidateParams() { return Status::OK(); }
|
||||
|
||||
// Function to build TFRecordNode
|
||||
std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
// Sort the datasets file in a lexicographical order
|
||||
std::vector<std::string> sorted_dir_files = dataset_files_;
|
||||
std::sort(sorted_dir_files.begin(), sorted_dir_files.end());
|
||||
|
||||
// Create Schema Object
|
||||
std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();
|
||||
if (!schema_path_.empty()) {
|
||||
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaFile(schema_path_, columns_list_));
|
||||
} else if (schema_obj_ != nullptr) {
|
||||
std::string schema_json_string = schema_obj_->to_json();
|
||||
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, columns_list_));
|
||||
}
|
||||
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
|
||||
// Create and initialize TFReaderOp
|
||||
std::shared_ptr<TFReaderOp> tf_reader_op = std::make_shared<TFReaderOp>(
|
||||
num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::move(data_schema),
|
||||
connector_que_size_, columns_list_, shuffle_files, num_shards_, shard_id_, shard_equal_rows_, nullptr);
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(tf_reader_op->Init());
|
||||
|
||||
if (shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
||||
// Add TFReaderOp
|
||||
node_ops.push_back(tf_reader_op);
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,90 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TF_RECORD_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TF_RECORD_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
/// \class TFRecordNode
|
||||
/// \brief A Dataset derived class to represent TFRecord dataset
|
||||
class TFRecordNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \note Parameter 'schema' is the path to the schema file
|
||||
TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema,
|
||||
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
|
||||
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
dataset_files_(dataset_files),
|
||||
schema_path_(schema),
|
||||
columns_list_(columns_list),
|
||||
num_samples_(num_samples),
|
||||
shuffle_(shuffle),
|
||||
num_shards_(num_shards),
|
||||
shard_id_(shard_id),
|
||||
shard_equal_rows_(shard_equal_rows) {}
|
||||
|
||||
/// \brief Constructor
|
||||
/// \note Parameter 'schema' is shared pointer to Schema object
|
||||
TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
|
||||
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
|
||||
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
dataset_files_(dataset_files),
|
||||
schema_obj_(schema),
|
||||
columns_list_(columns_list),
|
||||
num_samples_(num_samples),
|
||||
shuffle_(shuffle),
|
||||
num_shards_(num_shards),
|
||||
shard_id_(shard_id),
|
||||
shard_equal_rows_(shard_equal_rows) {}
|
||||
|
||||
/// \brief Destructor
|
||||
~TFRecordNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string
|
||||
std::shared_ptr<SchemaObj> schema_obj_; // schema_obj_ schema object.
|
||||
std::vector<std::string> columns_list_;
|
||||
int64_t num_samples_;
|
||||
ShuffleMode shuffle_;
|
||||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
bool shard_equal_rows_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TF_RECORD_NODE_H_
|
|
@ -0,0 +1,117 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// Constructor for VOCNode
|
||||
VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
task_(task),
|
||||
usage_(usage),
|
||||
class_index_(class_indexing),
|
||||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
||||
Status VOCNode::ValidateParams() {
|
||||
Path dir(dataset_dir_);
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("VOCNode", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("VOCNode", sampler_));
|
||||
|
||||
if (task_ == "Segmentation") {
|
||||
if (!class_index_.empty()) {
|
||||
std::string err_msg = "VOCNode: class_indexing is invalid in Segmentation task.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
Path imagesets_file = dir / "ImageSets" / "Segmentation" / usage_ + ".txt";
|
||||
if (!imagesets_file.Exists()) {
|
||||
std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist";
|
||||
MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
} else if (task_ == "Detection") {
|
||||
Path imagesets_file = dir / "ImageSets" / "Main" / usage_ + ".txt";
|
||||
if (!imagesets_file.Exists()) {
|
||||
std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist";
|
||||
MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
} else {
|
||||
std::string err_msg = "VOCNode: Invalid task: " + task_;
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build VOCNode
|
||||
std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
VOCOp::TaskType task_type_;
|
||||
|
||||
if (task_ == "Segmentation") {
|
||||
task_type_ = VOCOp::TaskType::Segmentation;
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnTarget), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
} else if (task_ == "Detection") {
|
||||
task_type_ = VOCOp::TaskType::Detection;
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnLabel), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnDifficult), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
}
|
||||
|
||||
std::shared_ptr<VOCOp> voc_op;
|
||||
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
||||
node_ops.push_back(voc_op);
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_VOC_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_VOC_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
class VOCNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~VOCNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
const std::string kColumnImage = "image";
|
||||
const std::string kColumnTarget = "target";
|
||||
const std::string kColumnBbox = "bbox";
|
||||
const std::string kColumnLabel = "label";
|
||||
const std::string kColumnDifficult = "difficult";
|
||||
const std::string kColumnTruncate = "truncate";
|
||||
std::string dataset_dir_;
|
||||
std::string task_;
|
||||
std::string usage_;
|
||||
std::map<std::string, int32_t> class_index_;
|
||||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_VOC_NODE_H_
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TAKE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TAKE_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -51,4 +51,4 @@ class TakeNode : public Dataset {
|
|||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TAKE_NODE_H_
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/zip_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
ZipNode::ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
|
||||
for (auto dataset : datasets_) {
|
||||
this->children.push_back(dataset);
|
||||
}
|
||||
}
|
||||
|
||||
Status ZipNode::ValidateParams() {
|
||||
if (datasets_.empty()) {
|
||||
std::string err_msg = "ZipNode: datasets to zip are not specified.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
|
||||
std::string err_msg = "ZipNode: zip datasets should not be null.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ZipNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ZIP_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ZIP_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
class ZipNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets);
|
||||
|
||||
/// \brief Destructor
|
||||
~ZipNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<Dataset>> datasets_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ZIP_NODE_H_
|
|
@ -106,13 +106,20 @@ class ZipNode;
|
|||
} \
|
||||
} while (false)
|
||||
|
||||
Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
|
||||
int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op);
|
||||
|
||||
// Helper function to validate dataset files parameter
|
||||
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files);
|
||||
|
||||
// Helper function to validate dataset num_shards and shard_id parameters
|
||||
Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id);
|
||||
|
||||
// Helper function to validate dataset sampler parameter
|
||||
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler);
|
||||
|
||||
Status ValidateStringValue(const std::string &str, const std::unordered_set<std::string> &valid_strings);
|
||||
Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
|
||||
const std::unordered_set<std::string> &valid_strings);
|
||||
|
||||
// Helper function to validate dataset input/output column parameterCD -
|
||||
Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
|
||||
|
@ -799,551 +806,8 @@ class SchemaObj {
|
|||
|
||||
/* ####################################### Derived Dataset classes ################################# */
|
||||
|
||||
// DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
|
||||
// (In alphabetical order)
|
||||
|
||||
class AlbumNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
|
||||
const std::vector<std::string> &column_names, bool decode, const std::shared_ptr<SamplerObj> &sampler);
|
||||
|
||||
/// \brief Destructor
|
||||
~AlbumNode() = default;
|
||||
|
||||
/// \brief a base class override function to create a runtime dataset op object from this class
|
||||
/// \return shared pointer to the newly created DatasetOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string schema_path_;
|
||||
std::vector<std::string> column_names_;
|
||||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
class CelebANode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
|
||||
const bool &decode, const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CelebANode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
bool decode_;
|
||||
std::set<std::string> extensions_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
// DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
|
||||
// (In alphabetical order)
|
||||
|
||||
class Cifar10Node : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~Cifar10Node() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
class Cifar100Node : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~Cifar100Node() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
/// \class CLUENode
|
||||
/// \brief A Dataset derived class to represent CLUE dataset
|
||||
class CLUENode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CLUENode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
/// \brief Split string based on a character delimiter
|
||||
/// \return A string vector
|
||||
std::vector<std::string> split(const std::string &s, char delim);
|
||||
|
||||
std::vector<std::string> dataset_files_;
|
||||
std::string task_;
|
||||
std::string usage_;
|
||||
int64_t num_samples_;
|
||||
ShuffleMode shuffle_;
|
||||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
};
|
||||
|
||||
class CocoNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
|
||||
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CocoNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string annotation_file_;
|
||||
std::string task_;
|
||||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
/// \brief Record type for CSV
|
||||
enum CsvType : uint8_t { INT = 0, FLOAT, STRING };
|
||||
|
||||
/// \brief Base class of CSV Record
|
||||
class CsvBase {
|
||||
public:
|
||||
CsvBase() = default;
|
||||
explicit CsvBase(CsvType t) : type(t) {}
|
||||
virtual ~CsvBase() {}
|
||||
CsvType type;
|
||||
};
|
||||
|
||||
/// \brief CSV Record that can represent integer, float and string.
|
||||
template <typename T>
|
||||
class CsvRecord : public CsvBase {
|
||||
public:
|
||||
CsvRecord() = default;
|
||||
CsvRecord(CsvType t, T v) : CsvBase(t), value(v) {}
|
||||
~CsvRecord() {}
|
||||
T value;
|
||||
};
|
||||
|
||||
class CSVNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CSVNode(const std::vector<std::string> &dataset_files, char field_delim,
|
||||
const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::string> &column_names,
|
||||
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CSVNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
char field_delim_;
|
||||
std::vector<std::shared_ptr<CsvBase>> column_defaults_;
|
||||
std::vector<std::string> column_names_;
|
||||
int64_t num_samples_;
|
||||
ShuffleMode shuffle_;
|
||||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
};
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
class ManifestNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~ManifestNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_file_;
|
||||
std::string usage_;
|
||||
bool decode_;
|
||||
std::map<std::string, int32_t> class_index_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
class MindDataNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded);
|
||||
|
||||
/// \brief Constructor
|
||||
MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded);
|
||||
|
||||
/// \brief Destructor
|
||||
~MindDataNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Build sampler chain for minddata dataset
|
||||
/// \return Status Status::OK() if input sampler is valid
|
||||
Status BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,
|
||||
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators_,
|
||||
int64_t num_padded);
|
||||
|
||||
/// \brief Set sample_bytes when padded_sample has py::byte value
|
||||
/// \note Pybind will use this function to set sample_bytes into MindDataNode
|
||||
void SetSampleBytes(std::map<std::string, std::string> *sample_bytes);
|
||||
|
||||
private:
|
||||
std::string dataset_file_; // search_for_pattern_ will be true in this mode
|
||||
std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode
|
||||
bool search_for_pattern_;
|
||||
std::vector<std::string> columns_list_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
nlohmann::json padded_sample_;
|
||||
std::map<std::string, std::string> sample_bytes_; // enable in python
|
||||
int64_t num_padded_;
|
||||
};
|
||||
#endif
|
||||
|
||||
class MnistNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~MnistNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
class RandomNode : public Dataset {
|
||||
public:
|
||||
// Some constants to provide limits to random generation.
|
||||
static constexpr int32_t kMaxNumColumns = 4;
|
||||
static constexpr int32_t kMaxRank = 4;
|
||||
static constexpr int32_t kMaxDimValue = 32;
|
||||
|
||||
/// \brief Constructor
|
||||
RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
total_rows_(total_rows),
|
||||
schema_path_(""),
|
||||
schema_(std::move(schema)),
|
||||
columns_list_(columns_list),
|
||||
sampler_(std::move(sampler)) {}
|
||||
|
||||
/// \brief Constructor
|
||||
RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
total_rows_(total_rows),
|
||||
schema_path_(schema_path),
|
||||
columns_list_(columns_list),
|
||||
sampler_(std::move(sampler)) {}
|
||||
|
||||
/// \brief Destructor
|
||||
~RandomNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
/// \brief A quick inline for producing a random number between (and including) min/max
|
||||
/// \param[in] min minimum number that can be generated.
|
||||
/// \param[in] max maximum number that can be generated.
|
||||
/// \return The generated random number
|
||||
int32_t GenRandomInt(int32_t min, int32_t max);
|
||||
|
||||
int32_t total_rows_;
|
||||
std::string schema_path_;
|
||||
std::shared_ptr<SchemaObj> schema_;
|
||||
std::vector<std::string> columns_list_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
std::mt19937 rand_gen_;
|
||||
};
|
||||
|
||||
/// \class TextFileNode
|
||||
/// \brief A Dataset derived class to represent TextFile dataset
|
||||
class TextFileNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
|
||||
int32_t shard_id, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~TextFileNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
int32_t num_samples_;
|
||||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
ShuffleMode shuffle_;
|
||||
};
|
||||
|
||||
/// \class TFRecordNode
|
||||
/// \brief A Dataset derived class to represent TFRecord dataset
|
||||
class TFRecordNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \note Parameter 'schema' is the path to the schema file
|
||||
TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema,
|
||||
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
|
||||
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
dataset_files_(dataset_files),
|
||||
schema_path_(schema),
|
||||
columns_list_(columns_list),
|
||||
num_samples_(num_samples),
|
||||
shuffle_(shuffle),
|
||||
num_shards_(num_shards),
|
||||
shard_id_(shard_id),
|
||||
shard_equal_rows_(shard_equal_rows) {}
|
||||
|
||||
/// \brief Constructor
|
||||
/// \note Parameter 'schema' is shared pointer to Schema object
|
||||
TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
|
||||
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
|
||||
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
|
||||
: Dataset(std::move(cache)),
|
||||
dataset_files_(dataset_files),
|
||||
schema_obj_(schema),
|
||||
columns_list_(columns_list),
|
||||
num_samples_(num_samples),
|
||||
shuffle_(shuffle),
|
||||
num_shards_(num_shards),
|
||||
shard_id_(shard_id),
|
||||
shard_equal_rows_(shard_equal_rows) {}
|
||||
|
||||
/// \brief Destructor
|
||||
~TFRecordNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string
|
||||
std::shared_ptr<SchemaObj> schema_obj_; // schema_obj_ schema object.
|
||||
std::vector<std::string> columns_list_;
|
||||
int64_t num_samples_;
|
||||
ShuffleMode shuffle_;
|
||||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
bool shard_equal_rows_;
|
||||
};
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
class VOCNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~VOCNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
const std::string kColumnImage = "image";
|
||||
const std::string kColumnTarget = "target";
|
||||
const std::string kColumnBbox = "bbox";
|
||||
const std::string kColumnLabel = "label";
|
||||
const std::string kColumnDifficult = "difficult";
|
||||
const std::string kColumnTruncate = "truncate";
|
||||
std::string dataset_dir_;
|
||||
std::string task_;
|
||||
std::string usage_;
|
||||
std::map<std::string, int32_t> class_index_;
|
||||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
#endif
|
||||
|
||||
// DERIVED DATASET CLASSES FOR DATASET OPS
|
||||
// (In alphabetical order)
|
||||
|
||||
class MapNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
|
||||
std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {},
|
||||
const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr);
|
||||
|
||||
/// \brief Destructor
|
||||
~MapNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<TensorOperation>> operations_;
|
||||
std::vector<std::string> input_columns_;
|
||||
std::vector<std::string> output_columns_;
|
||||
std::vector<std::string> project_columns_;
|
||||
};
|
||||
|
||||
class SkipNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit SkipNode(std::shared_ptr<Dataset> child, int32_t count);
|
||||
|
||||
/// \brief Destructor
|
||||
~SkipNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int32_t skip_count_;
|
||||
};
|
||||
|
||||
class ZipNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets);
|
||||
|
||||
/// \brief Destructor
|
||||
~ZipNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<Dataset>> datasets_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_
|
||||
|
|
|
@ -131,6 +131,10 @@ if (BUILD_MINDDATA STREQUAL "full")
|
|||
|
||||
list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
||||
"${MINDDATA_DIR}/engine/ir/datasetops/source/generator_node.cc"
|
||||
"${MINDDATA_DIR}/engine/ir/datasetops/source/manifest_node.cc"
|
||||
"${MINDDATA_DIR}/engine/ir/datasetops/source/minddata_node.cc"
|
||||
"${MINDDATA_DIR}/engine/ir/datasetops/source/tf_record_node.cc"
|
||||
"${MINDDATA_DIR}/engine/ir/datasetops/source/voc_node.cc"
|
||||
)
|
||||
|
||||
list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SRC_FILES
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <iostream>
|
||||
#include <memory>
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#if defined(__ANDROID__) || defined(ANDROID)
|
||||
#include <android/log.h>
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <iostream>
|
||||
#include <memory>
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
|
||||
using Dataset = mindspore::dataset::api::Dataset;
|
||||
|
|
|
@ -16,6 +16,19 @@
|
|||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
|
|
|
@ -16,6 +16,22 @@
|
|||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
|
|
|
@ -18,10 +18,24 @@
|
|||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::GlobalContext;
|
||||
using mindspore::dataset::ShuffleMode;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::GlobalContext;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
|
@ -49,11 +63,8 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetAFQMC) {
|
|||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("sentence1"), row.end());
|
||||
std::vector<std::string> expected_result = {
|
||||
"蚂蚁借呗等额还款能否换成先息后本",
|
||||
"蚂蚁花呗说我违约了",
|
||||
"帮我看看本月花呗账单结清了没"
|
||||
};
|
||||
std::vector<std::string> expected_result = {"蚂蚁借呗等额还款能否换成先息后本", "蚂蚁花呗说我违约了",
|
||||
"帮我看看本月花呗账单结清了没"};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
@ -75,11 +86,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetAFQMC) {
|
|||
|
||||
// test
|
||||
usage = "test";
|
||||
expected_result = {
|
||||
"借呗取消的时间",
|
||||
"网商贷用什么方法转变成借呗",
|
||||
"我的借呗为什么开通不了"
|
||||
};
|
||||
expected_result = {"借呗取消的时间", "网商贷用什么方法转变成借呗", "我的借呗为什么开通不了"};
|
||||
ds = CLUE({test_file}, task, usage, 0, ShuffleMode::kFalse);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
|
@ -100,11 +107,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetAFQMC) {
|
|||
|
||||
// eval
|
||||
usage = "eval";
|
||||
expected_result = {
|
||||
"你有花呗吗",
|
||||
"吃饭能用花呗吗",
|
||||
"蚂蚁花呗支付金额有什么限制"
|
||||
};
|
||||
expected_result = {"你有花呗吗", "吃饭能用花呗吗", "蚂蚁花呗支付金额有什么限制"};
|
||||
ds = CLUE({eval_file}, task, usage, 0, ShuffleMode::kFalse);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
|
@ -179,11 +182,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetCMNLI) {
|
|||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("sentence1"), row.end());
|
||||
std::vector<std::string> expected_result = {
|
||||
"你应该给这件衣服定一个价格。",
|
||||
"我怎么知道他要说什么",
|
||||
"向左。"
|
||||
};
|
||||
std::vector<std::string> expected_result = {"你应该给这件衣服定一个价格。", "我怎么知道他要说什么", "向左。"};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
@ -224,11 +223,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetCSL) {
|
|||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("abst"), row.end());
|
||||
std::vector<std::string> expected_result = {
|
||||
"这是一段长文本",
|
||||
"这是一段长文本",
|
||||
"这是一段长文本"
|
||||
};
|
||||
std::vector<std::string> expected_result = {"这是一段长文本", "这是一段长文本", "这是一段长文本"};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
@ -337,11 +332,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetIFLYTEK) {
|
|||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("sentence"), row.end());
|
||||
std::vector<std::string> expected_result = {
|
||||
"第一个文本",
|
||||
"第二个文本",
|
||||
"第三个文本"
|
||||
};
|
||||
std::vector<std::string> expected_result = {"第一个文本", "第二个文本", "第三个文本"};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
@ -396,14 +387,12 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFilesA) {
|
|||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("sentence1"), row.end());
|
||||
std::vector<std::string> expected_result = {
|
||||
"你有花呗吗",
|
||||
"吃饭能用花呗吗",
|
||||
"蚂蚁花呗支付金额有什么限制",
|
||||
"蚂蚁借呗等额还款能否换成先息后本",
|
||||
"蚂蚁花呗说我违约了",
|
||||
"帮我看看本月花呗账单结清了没"
|
||||
};
|
||||
std::vector<std::string> expected_result = {"你有花呗吗",
|
||||
"吃饭能用花呗吗",
|
||||
"蚂蚁花呗支付金额有什么限制",
|
||||
"蚂蚁借呗等额还款能否换成先息后本",
|
||||
"蚂蚁花呗说我违约了",
|
||||
"帮我看看本月花呗账单结清了没"};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
@ -463,14 +452,12 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFilesB) {
|
|||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("sentence1"), row.end());
|
||||
std::vector<std::string> expected_result = {
|
||||
"你有花呗吗",
|
||||
"吃饭能用花呗吗",
|
||||
"蚂蚁花呗支付金额有什么限制",
|
||||
"蚂蚁借呗等额还款能否换成先息后本",
|
||||
"蚂蚁花呗说我违约了",
|
||||
"帮我看看本月花呗账单结清了没"
|
||||
};
|
||||
std::vector<std::string> expected_result = {"你有花呗吗",
|
||||
"吃饭能用花呗吗",
|
||||
"蚂蚁花呗支付金额有什么限制",
|
||||
"蚂蚁借呗等额还款能否换成先息后本",
|
||||
"蚂蚁花呗说我违约了",
|
||||
"帮我看看本月花呗账单结清了没"};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
@ -523,11 +510,8 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleGlobal) {
|
|||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("sentence1"), row.end());
|
||||
std::vector<std::string> expected_result = {
|
||||
"蚂蚁花呗说我违约了",
|
||||
"帮我看看本月花呗账单结清了没",
|
||||
"蚂蚁借呗等额还款能否换成先息后本"
|
||||
};
|
||||
std::vector<std::string> expected_result = {"蚂蚁花呗说我违约了", "帮我看看本月花呗账单结清了没",
|
||||
"蚂蚁借呗等额还款能否换成先息后本"};
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto text = row["sentence1"];
|
||||
|
@ -572,11 +556,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetTNEWS) {
|
|||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("sentence"), row.end());
|
||||
std::vector<std::string> expected_result = {
|
||||
"新闻1",
|
||||
"新闻2",
|
||||
"新闻3"
|
||||
};
|
||||
std::vector<std::string> expected_result = {"新闻1", "新闻2", "新闻3"};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
@ -617,11 +597,8 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetWSC) {
|
|||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("text"), row.end());
|
||||
std::vector<std::string> expected_result = {
|
||||
"小明呢,他在哪?",
|
||||
"小红刚刚看到小明,他在操场",
|
||||
"等小明回来,小张你叫他交作业"
|
||||
};
|
||||
std::vector<std::string> expected_result = {"小明呢,他在哪?", "小红刚刚看到小明,他在操场",
|
||||
"等小明回来,小张你叫他交作业"};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
|
|
@ -16,10 +16,40 @@
|
|||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::dsize_t;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
using mindspore::dataset::dsize_t;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
|
@ -79,12 +109,14 @@ TEST_F(MindDataTestPipeline, TestCocoDetection) {
|
|||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
std::string expect_file[] = {"000000391895", "000000318219", "000000554625", "000000574769", "000000060623",
|
||||
"000000309022"};
|
||||
std::string expect_file[] = {"000000391895", "000000318219", "000000554625",
|
||||
"000000574769", "000000060623", "000000309022"};
|
||||
std::vector<std::vector<float>> expect_bbox_vector = {{10.0, 10.0, 10.0, 10.0, 70.0, 70.0, 70.0, 70.0},
|
||||
{20.0, 20.0, 20.0, 20.0, 80.0, 80.0, 80.0, 80.0},
|
||||
{30.0, 30.0, 30.0, 30.0}, {40.0, 40.0, 40.0, 40.0},
|
||||
{50.0, 50.0, 50.0, 50.0}, {60.0, 60.0, 60.0, 60.0}};
|
||||
{30.0, 30.0, 30.0, 30.0},
|
||||
{40.0, 40.0, 40.0, 40.0},
|
||||
{50.0, 50.0, 50.0, 50.0},
|
||||
{60.0, 60.0, 60.0, 60.0}};
|
||||
std::vector<std::vector<uint32_t>> expect_catagoryid_list = {{1, 7}, {2, 8}, {3}, {4}, {5}, {6}};
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
@ -148,13 +180,13 @@ TEST_F(MindDataTestPipeline, TestCocoKeypoint) {
|
|||
iter->GetNextRow(&row);
|
||||
|
||||
std::string expect_file[] = {"000000391895", "000000318219"};
|
||||
std::vector<std::vector<float>> expect_keypoint_vector =
|
||||
{{368.0, 61.0, 1.0, 369.0, 52.0, 2.0, 0.0, 0.0, 0.0, 382.0, 48.0, 2.0, 0.0, 0.0, 0.0, 368.0, 84.0, 2.0, 435.0,
|
||||
81.0, 2.0, 362.0, 125.0, 2.0, 446.0, 125.0, 2.0, 360.0, 153.0, 2.0, 0.0, 0.0, 0.0, 397.0, 167.0, 1.0, 439.0,
|
||||
166.0, 1.0, 369.0, 193.0, 2.0, 461.0, 234.0, 2.0, 361.0, 246.0, 2.0, 474.0, 287.0, 2.0},
|
||||
{244.0, 139.0, 2.0, 0.0, 0.0, 0.0, 226.0, 118.0, 2.0, 0.0, 0.0, 0.0, 154.0, 159.0, 2.0, 143.0, 261.0, 2.0, 135.0,
|
||||
312.0, 2.0, 271.0, 423.0, 2.0, 184.0, 530.0, 2.0, 261.0, 280.0, 2.0, 347.0, 592.0, 2.0, 0.0, 0.0, 0.0, 123.0,
|
||||
596.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}};
|
||||
std::vector<std::vector<float>> expect_keypoint_vector = {
|
||||
{368.0, 61.0, 1.0, 369.0, 52.0, 2.0, 0.0, 0.0, 0.0, 382.0, 48.0, 2.0, 0.0, 0.0, 0.0, 368.0, 84.0, 2.0,
|
||||
435.0, 81.0, 2.0, 362.0, 125.0, 2.0, 446.0, 125.0, 2.0, 360.0, 153.0, 2.0, 0.0, 0.0, 0.0, 397.0, 167.0, 1.0,
|
||||
439.0, 166.0, 1.0, 369.0, 193.0, 2.0, 461.0, 234.0, 2.0, 361.0, 246.0, 2.0, 474.0, 287.0, 2.0},
|
||||
{244.0, 139.0, 2.0, 0.0, 0.0, 0.0, 226.0, 118.0, 2.0, 0.0, 0.0, 0.0, 154.0, 159.0, 2.0, 143.0, 261.0, 2.0,
|
||||
135.0, 312.0, 2.0, 271.0, 423.0, 2.0, 184.0, 530.0, 2.0, 261.0, 280.0, 2.0, 347.0, 592.0, 2.0, 0.0, 0.0, 0.0,
|
||||
123.0, 596.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}};
|
||||
std::vector<std::vector<dsize_t>> expect_size = {{1, 51}, {1, 51}};
|
||||
std::vector<std::vector<uint32_t>> expect_num_keypoints_list = {{14}, {10}};
|
||||
uint64_t i = 0;
|
||||
|
@ -258,17 +290,17 @@ TEST_F(MindDataTestPipeline, TestCocoStuff) {
|
|||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
std::string expect_file[] = {"000000391895", "000000318219", "000000554625", "000000574769", "000000060623",
|
||||
"000000309022"};
|
||||
std::vector<std::vector<float>> expect_segmentation_vector =
|
||||
{{10.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0,
|
||||
70.0, 72.0, 73.0, 74.0, 75.0, -1.0, -1.0, -1.0, -1.0, -1.0},
|
||||
{20.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0,
|
||||
10.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, -1.0},
|
||||
{40.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 40.0, 41.0, 42.0},
|
||||
{50.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0},
|
||||
{60.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0},
|
||||
{60.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0}};
|
||||
std::string expect_file[] = {"000000391895", "000000318219", "000000554625",
|
||||
"000000574769", "000000060623", "000000309022"};
|
||||
std::vector<std::vector<float>> expect_segmentation_vector = {
|
||||
{10.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0,
|
||||
70.0, 72.0, 73.0, 74.0, 75.0, -1.0, -1.0, -1.0, -1.0, -1.0},
|
||||
{20.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0,
|
||||
10.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, -1.0},
|
||||
{40.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 40.0, 41.0, 42.0},
|
||||
{50.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0},
|
||||
{60.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0},
|
||||
{60.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0}};
|
||||
std::vector<std::vector<dsize_t>> expect_size = {{2, 10}, {2, 11}, {1, 12}, {1, 13}, {1, 14}, {2, 7}};
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
|
|
@ -18,13 +18,17 @@
|
|||
#include "minddata/dataset/include/config.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::ShuffleMode;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
|
|
@ -18,10 +18,40 @@
|
|||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::GlobalContext;
|
||||
using mindspore::dataset::ShuffleMode;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::GlobalContext;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
|
@ -98,12 +128,8 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetMultiFiles) {
|
|||
iter->GetNextRow(&row);
|
||||
EXPECT_NE(row.find("col1"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"17", "18", "19", "20"},
|
||||
{"1", "2", "3", "4"},
|
||||
{"5", "6", "7", "8"},
|
||||
{"13", "14", "15", "16"},
|
||||
{"21", "22", "23", "24"},
|
||||
{"9", "10", "11", "12"},
|
||||
{"17", "18", "19", "20"}, {"1", "2", "3", "4"}, {"5", "6", "7", "8"},
|
||||
{"13", "14", "15", "16"}, {"21", "22", "23", "24"}, {"9", "10", "11", "12"},
|
||||
};
|
||||
|
||||
uint64_t i = 0;
|
||||
|
@ -148,10 +174,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetNumSamples) {
|
|||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
EXPECT_NE(row.find("col1"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"1", "2", "3", "4"},
|
||||
{"5", "6", "7", "8"}
|
||||
};
|
||||
std::vector<std::vector<std::string>> expected_result = {{"1", "2", "3", "4"}, {"5", "6", "7", "8"}};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
@ -191,10 +214,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetDistribution) {
|
|||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
EXPECT_NE(row.find("col1"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"1", "2", "3", "4"},
|
||||
{"5", "6", "7", "8"}
|
||||
};
|
||||
std::vector<std::vector<std::string>> expected_result = {{"1", "2", "3", "4"}, {"5", "6", "7", "8"}};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
@ -386,12 +406,8 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesA) {
|
|||
iter->GetNextRow(&row);
|
||||
EXPECT_NE(row.find("col1"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"13", "14", "15", "16"},
|
||||
{"1", "2", "3", "4"},
|
||||
{"17", "18", "19", "20"},
|
||||
{"5", "6", "7", "8"},
|
||||
{"21", "22", "23", "24"},
|
||||
{"9", "10", "11", "12"},
|
||||
{"13", "14", "15", "16"}, {"1", "2", "3", "4"}, {"17", "18", "19", "20"},
|
||||
{"5", "6", "7", "8"}, {"21", "22", "23", "24"}, {"9", "10", "11", "12"},
|
||||
};
|
||||
|
||||
uint64_t i = 0;
|
||||
|
@ -445,12 +461,8 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesB) {
|
|||
iter->GetNextRow(&row);
|
||||
EXPECT_NE(row.find("col1"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"13", "14", "15", "16"},
|
||||
{"1", "2", "3", "4"},
|
||||
{"17", "18", "19", "20"},
|
||||
{"5", "6", "7", "8"},
|
||||
{"21", "22", "23", "24"},
|
||||
{"9", "10", "11", "12"},
|
||||
{"13", "14", "15", "16"}, {"1", "2", "3", "4"}, {"17", "18", "19", "20"},
|
||||
{"5", "6", "7", "8"}, {"21", "22", "23", "24"}, {"9", "10", "11", "12"},
|
||||
};
|
||||
|
||||
uint64_t i = 0;
|
||||
|
@ -505,10 +517,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleGlobal) {
|
|||
iter->GetNextRow(&row);
|
||||
EXPECT_NE(row.find("col1"), row.end());
|
||||
std::vector<std::vector<std::string>> expected_result = {
|
||||
{"5", "6", "7", "8"},
|
||||
{"9", "10", "11", "12"},
|
||||
{"1", "2", "3", "4"}
|
||||
};
|
||||
{"5", "6", "7", "8"}, {"9", "10", "11", "12"}, {"1", "2", "3", "4"}};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
|
@ -16,12 +17,21 @@
|
|||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
|
|
@ -16,6 +16,19 @@
|
|||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
|
|
|
@ -16,6 +16,20 @@
|
|||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
|
@ -57,7 +71,6 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess1) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestMindDataSuccess2) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataSuccess2 with a vector of single mindrecord file.";
|
||||
|
||||
|
|
|
@ -18,16 +18,22 @@
|
|||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/include/vision.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
|
|
@ -20,11 +20,41 @@
|
|||
|
||||
#include "mindspore/core/ir/dtype/type_id.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
using mindspore::dataset::DataType;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
|
|
|
@ -14,10 +14,23 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::ShuffleMode;
|
||||
|
|
|
@ -18,12 +18,19 @@
|
|||
#include "minddata/dataset/include/vision.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using namespace mindspore::dataset::api;
|
||||
|
|
|
@ -16,10 +16,24 @@
|
|||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
using mindspore::dataset::DataType;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
|
|
|
@ -16,13 +16,22 @@
|
|||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
|
|
@ -25,6 +25,36 @@
|
|||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/include/text.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::ShuffleMode;
|
||||
|
@ -304,8 +334,8 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail4) {
|
|||
|
||||
// Create vocab from dataset
|
||||
// Expected failure: special tokens are already in the dataset
|
||||
std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()},
|
||||
std::numeric_limits<int64_t>::max(), {"world"});
|
||||
std::shared_ptr<Vocab> vocab =
|
||||
ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()}, std::numeric_limits<int64_t>::max(), {"world"});
|
||||
EXPECT_EQ(vocab, nullptr);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,12 +18,35 @@
|
|||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/include/vision.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::BorderType;
|
||||
|
|
|
@ -18,13 +18,21 @@
|
|||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/include/vision.h"
|
||||
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::BorderType;
|
||||
|
|
|
@ -20,8 +20,22 @@
|
|||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR leaf nodes
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
|
|
Loading…
Reference in New Issue