!7734 Redesigned GetNumClasses

Merge pull request !7734 from Alex Yuyue/IR_dataset_input
This commit is contained in:
mindspore-ci-bot 2020-10-29 01:44:36 +08:00 committed by Gitee
commit c39c3ccfbe
15 changed files with 200 additions and 30 deletions

View File

@ -257,32 +257,79 @@ int64_t Dataset::GetDatasetSize() {
std::vector<DataType> Dataset::GetOutputTypes() {
std::vector<DataType> types;
Status s;
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
types.clear();
return types;
}
if (!tree_getters_->isInitialized()) {
s = tree_getters_->Init(shared_from_this());
if (s.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
rc = tree_getters_->Init(shared_from_this());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
types.clear();
return types;
}
}
tree_getters_->GetOutputTypes(&types);
rc = tree_getters_->GetOutputTypes(&types);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Get Output Types failed.";
types.clear();
return types;
}
return types;
}
std::vector<TensorShape> Dataset::GetOutputShapes() {
std::vector<TensorShape> shapes;
Status s;
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
shapes.clear();
return shapes;
}
if (!tree_getters_->isInitialized()) {
s = tree_getters_->Init(shared_from_this());
if (s.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
rc = tree_getters_->Init(shared_from_this());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
shapes.clear();
return shapes;
}
}
tree_getters_->GetOutputShapes(&shapes);
rc = tree_getters_->GetOutputShapes(&shapes);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Get Output Shapes failed.";
shapes.clear();
return shapes;
}
return shapes;
}
int64_t Dataset::GetNumClasses() {
int64_t num_classes;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
return -1;
}
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
return -1;
}
}
rc = tree_getters_->GetNumClasses(&num_classes);
return rc.IsError() ? -1 : num_classes;
}
// Constructor to initialize the cache
Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; }
@ -656,6 +703,7 @@ Status Dataset::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
}
return Status::OK();
}
int64_t Dataset::GetBatchSize() {
int64_t batch_size;
auto ds = shared_from_this();
@ -666,14 +714,17 @@ int64_t Dataset::GetBatchSize() {
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
return -1;
}
rc = tree_getters_->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
return -1;
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
return -1;
}
}
rc = tree_getters_->GetBatchSize(&batch_size);
return rc.IsError() ? -1 : batch_size;
}
int64_t Dataset::GetRepeatCount() {
int64_t repeat_count;
auto ds = shared_from_this();
@ -684,10 +735,12 @@ int64_t Dataset::GetRepeatCount() {
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
return -1;
}
rc = tree_getters_->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
return -1;
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
return -1;
}
}
rc = tree_getters_->GetRepeatCount(&repeat_count);
return rc.IsError() ? 0 : repeat_count;

View File

@ -444,10 +444,18 @@ Status TreeGetters::GetBatchSize(int64_t *batch_size) {
CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size.");
return Status::OK();
}
Status TreeGetters::GetRepeatCount(int64_t *repeat_count) {
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
*repeat_count = root->GetTreeRepeatCount();
return Status::OK();
}
Status TreeGetters::GetNumClasses(int64_t *num_classes) {
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
return Status::OK();
}
} // namespace mindspore::dataset

View File

@ -164,6 +164,7 @@ class TreeGetters : public TreeConsumer {
Status GetOutputShapes(std::vector<TensorShape> *shapes);
Status GetBatchSize(int64_t *batch_size);
Status GetRepeatCount(int64_t *repeat_count);
Status GetNumClasses(int64_t *num_classes);
bool isInitialized();
std::string Name() override { return "TreeGetters"; }
Status GetRow(TensorRow *r);

View File

@ -51,7 +51,8 @@ DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler
op_current_repeats_(0),
op_current_epochs_(0),
out_connector_(nullptr),
dataset_size_(-1) {
dataset_size_(-1),
num_classes_(-1) {
// The operator starts out with an invalid operator id. The only way to
// get it out of invalid state is to assign the operator to an execution tree.
}
@ -302,6 +303,19 @@ Status DatasetOp::GetDatasetSize(int64_t *dataset_size) {
return child_[0]->GetDatasetSize(dataset_size);
}
// Gets the number of classes
Status DatasetOp::GetNumClasses(int64_t *num_classes) {
if (num_classes_ > 0) {
*num_classes = num_classes_;
return Status::OK();
}
if (!child_.empty()) {
return child_[0]->GetNumClasses(num_classes);
} else {
RETURN_STATUS_UNEXPECTED("Can't get the dataset size for the current tree.");
}
}
// Performs handling for when an eoe message is received.
// The base class implementation simply flows the eoe message to output. Derived classes
// may override if they need to perform special eoe handling.

View File

@ -191,6 +191,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status - The status code return
virtual int64_t GetTreeRepeatCount();
/// \brief Gets the number of classes
/// \return Status - The status code return
virtual Status GetNumClasses(int64_t *num_classes);
/// \brief Performs handling for when an eoe message is received.
/// The base class implementation simply flows the eoe message to output. Derived classes
/// may override if they need to perform special eoe handling.
@ -419,6 +423,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
std::mutex column_name_map_mutex_; // For protecting shared access to the column map
CallbackManager callback_manager_; // Manages callbacks associated with a DatasetOp
int64_t dataset_size_; // Size of the dataset
int64_t num_classes_; // Number of classes
private:
/// Sets the operator id.

View File

@ -468,5 +468,17 @@ Status ImageFolderOp::GetDatasetSize(int64_t *dataset_size) {
dataset_size_ = *dataset_size;
return Status::OK();
}
// Get number of classes
Status ImageFolderOp::GetNumClasses(int64_t *num_classes) {
if (num_classes_ > 0) {
*num_classes = num_classes_;
return Status::OK();
}
int64_t num_rows = num_rows_;
RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, extensions_, &num_rows, num_classes));
num_classes_ = *num_classes;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -222,6 +222,11 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
/// \brief Base-class override for GetNumClasses
/// \param[out] num_classes the number of classes
/// \return Status of the function
Status GetNumClasses(int64_t *num_classes) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return

View File

@ -18,6 +18,7 @@
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <set>
#include <nlohmann/json.hpp>
#include "utils/ms_utils.h"
@ -297,6 +298,7 @@ Status ManifestOp::ParseManifestFile() {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Manifest file: " + file_);
}
std::string line;
std::set<std::string> classes;
while (getline(file_handle, line)) {
try {
nlohmann::json js = nlohmann::json::parse(line);
@ -317,6 +319,7 @@ Status ManifestOp::ParseManifestFile() {
for (nlohmann::json::iterator it = annotations.begin(); it != annotations.end(); ++it) {
nlohmann::json annotation = it.value();
std::string label_name = annotation.value("name", "");
classes.insert(label_name);
if (label_name == "") {
file_handle.close();
RETURN_STATUS_UNEXPECTED("Invalid data, label name is not found in Manifest file: " + image_file_path);
@ -336,6 +339,7 @@ Status ManifestOp::ParseManifestFile() {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse manifest file: " + line);
}
}
num_classes_ = classes.size();
file_handle.close();
return Status::OK();
@ -471,5 +475,18 @@ Status ManifestOp::GetDatasetSize(int64_t *dataset_size) {
return Status::OK();
}
// Get number of classes
Status ManifestOp::GetNumClasses(int64_t *num_classes) {
if (num_classes_ > 0) {
*num_classes = num_classes_;
return Status::OK();
}
std::shared_ptr<ManifestOp> op;
RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
RETURN_IF_NOT_OK(op->ParseManifestFile());
*num_classes = num_classes_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -188,6 +188,11 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
/// \brief Base-class override for GetNumClasses
/// \param[out] num_classes the number of classes
/// \return Status of the function
Status GetNumClasses(int64_t *num_classes) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return

View File

@ -589,15 +589,15 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
}
/// \brief Gets the dataset size
/// \return int64_t
/// \return dataset size. If failed, return -1
int64_t GetDatasetSize();
/// \brief Gets the output type
/// \return vector of DataType
/// \return a vector of DataType. If failed, return an empty vector
std::vector<DataType> GetOutputTypes();
/// \brief Gets the output shape
/// \return vector of TensorShapes
/// \return a vector of TensorShape. If failed, return am empty vector
std::vector<TensorShape> GetOutputShapes();
/// \brief Gets the batch size
@ -608,6 +608,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return int64_t
int64_t GetRepeatCount();
/// \brief Gets the number of classes
/// \return number of classes. If failed, return -1
int64_t GetNumClasses();
/// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object

View File

@ -69,6 +69,26 @@ TEST_F(MindDataTestPipeline, TestAlbumBasic) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestAlbumgetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAlbumgetters.";
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
std::vector<std::string> column_names = {"image", "label", "id"};
// Create a Album Dataset
std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names);
EXPECT_NE(ds, nullptr);
int64_t dataset_size = ds->GetDatasetSize();
EXPECT_EQ(dataset_size, 7);
int64_t num_classes = ds->GetNumClasses();
EXPECT_EQ(num_classes, -1);
int64_t batch_size = ds->GetBatchSize();
EXPECT_EQ(batch_size, 1);
int64_t repeat_count = ds->GetRepeatCount();
EXPECT_EQ(repeat_count, 1);
}
TEST_F(MindDataTestPipeline, TestAlbumDecode) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAlbumDecode.";
std::string folder_path = datasets_root_path_ + "/testAlbum/images";

View File

@ -86,7 +86,7 @@ TEST_F(MindDataTestPipeline, TestCifar10GetDatasetSize) {
EXPECT_EQ(ds->GetDatasetSize(), 10000);
}
TEST_F(MindDataTestPipeline, TestCifar10MixGetter) {
TEST_F(MindDataTestPipeline, TestCifar10Getters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10MixGetter.";
// Create a Cifar10 Dataset
@ -97,19 +97,28 @@ TEST_F(MindDataTestPipeline, TestCifar10MixGetter) {
EXPECT_EQ(ds->GetDatasetSize(), 10000);
std::vector<DataType> types = ds->GetOutputTypes();
std::vector<TensorShape> shapes = ds->GetOutputShapes();
int64_t num_classes = ds->GetNumClasses();
EXPECT_EQ(types.size(), 2);
EXPECT_EQ(types[0].ToString(), "uint8");
EXPECT_EQ(types[1].ToString(), "uint32");
EXPECT_EQ(shapes.size(), 2);
EXPECT_EQ(shapes[0].ToString(), "<32,32,3>");
EXPECT_EQ(shapes[1].ToString(), "<>");
EXPECT_EQ(num_classes, -1);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetDatasetSize(), 10000);
EXPECT_EQ(ds->GetOutputTypes(), types);
EXPECT_EQ(ds->GetOutputShapes(), shapes);
EXPECT_EQ(ds->GetNumClasses(), -1);
EXPECT_EQ(ds->GetDatasetSize(), 10000);
EXPECT_EQ(ds->GetOutputTypes(), types);
EXPECT_EQ(ds->GetOutputShapes(), shapes);
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetNumClasses(), -1);
EXPECT_EQ(ds->GetDatasetSize(), 10000);
}

View File

@ -67,15 +67,22 @@ TEST_F(MindDataTestPipeline, TestManifestBasic) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestManifestGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestManifestGetDatasetSize.";
TEST_F(MindDataTestPipeline, TestManifestGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestManifestGetters.";
std::string file_path = datasets_root_path_ + "/testManifestData/cpp.json";
std::string file_path1 = datasets_root_path_ + "/testManifestData/cpp.json";
std::string file_path2 = datasets_root_path_ + "/testManifestData/cpp2.json";
// Create a Manifest Dataset
std::shared_ptr<Dataset> ds = Manifest(file_path);
EXPECT_NE(ds, nullptr);
std::shared_ptr<Dataset> ds1 = Manifest(file_path1);
std::shared_ptr<Dataset> ds2 = Manifest(file_path2);
EXPECT_EQ(ds->GetDatasetSize(), 2);
EXPECT_NE(ds1, nullptr);
EXPECT_EQ(ds1->GetDatasetSize(), 2);
EXPECT_EQ(ds1->GetNumClasses(), 2);
EXPECT_NE(ds2, nullptr);
EXPECT_EQ(ds2->GetDatasetSize(), 4);
EXPECT_EQ(ds2->GetNumClasses(), 3);
}
TEST_F(MindDataTestPipeline, TestManifestDecode) {

View File

@ -221,7 +221,7 @@ TEST_F(MindDataTestPipeline, TestImageFolderFailWithWrongExtension) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestImageFolderGetDatasetSize) {
TEST_F(MindDataTestPipeline, TestImageFolderGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestImageFolderGetDatasetSize.";
// Create an ImageFolder Dataset
@ -230,6 +230,10 @@ TEST_F(MindDataTestPipeline, TestImageFolderGetDatasetSize) {
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 44);
EXPECT_EQ(ds->GetNumClasses(), 4);
EXPECT_EQ(ds->GetNumClasses(), 4);
EXPECT_EQ(ds->GetDatasetSize(), 44);
EXPECT_EQ(ds->GetDatasetSize(), 44);
}
TEST_F(MindDataTestPipeline, TestImageFolderFailWithNullSampler) {

View File

@ -0,0 +1,6 @@
{"source":"./data/dataset/testManifestData/train/1.JPEG", "usage":"TRAIN","id":"0162005993f8065ef47eefb59d1e4970","annotation": [{"type": "modelarts/image_classification","name": "dog","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"}],"inference-loc":"/path/to/inference-output"}
{"source":"./data/dataset/testManifestData/train/1.JPEG", "usage":"TRAIN","id":"0162005993f8065ef47eefb59d1e4970","annotation": [{"type": "modelarts/image_classification","name": "cat","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"}],"inference-loc":"/path/to/inference-output"}
{"source":"./data/dataset/testManifestData/train/1.JPEG", "usage":"TRAIN","id":"0162005993f8065ef47eefb59d1e4970","annotation": [{"type": "modelarts/image_classification","name": "cat","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"}],"inference-loc":"/path/to/inference-output"}
{"source":"./data/dataset/testManifestData/train/1.JPEG", "usage":"TRAIN","id":"0162005993f8065ef47eefb59d1e4970","annotation": [{"type": "modelarts/image_classification","name": "cat","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"},{"type": "modelarts/image_classification","name": "flower","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"}],"inference-loc":"/path/to/inference-output"}
{"source":"./data/dataset/testManifestData/eval/1.JPEG", "usage":"EVAL","id":"0162005993f8065ef47eefb59d1e4970","annotation": [{"type": "modelarts/image_classification","name": "cat","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"}],"inference-loc":"/path/to/inference-output"}
{"source":"./data/dataset/testManifestData/eval/2.JPEG", "usage":"EVAL","id":"0162005993f8065ef47eefb59d1e4970","annotation": [{"type": "modelarts/image_classification","name": "dog","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"}],"inference-loc":"/path/to/inference-output"}