forked from mindspore-Ecosystem/mindspore
parent f1ef84e1a6
commit 45292abdfd
@@ -25,6 +25,7 @@ endif ()

add_library(cpp-API OBJECT
    config.cc
    datasets.cc
    execute.cc
    iterator.cc
    transforms.cc
    samplers.cc
@@ -15,7 +15,9 @@
 */

#include "minddata/dataset/include/execute.h"
#ifdef ENABLE_ANDROID
#include "minddata/dataset/include/de_tensor.h"
#endif
#include "minddata/dataset/include/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#ifndef ENABLE_ANDROID
@@ -29,6 +31,7 @@ namespace dataset {

Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {}

#ifdef ENABLE_ANDROID
std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MSTensor> input) {
  // Build the op
  if (op_ == nullptr) {
@@ -52,6 +55,7 @@ std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MSTensor> input) {
  }
  return std::make_shared<tensor::DETensor>(std::move(de_output));
}
#endif

std::shared_ptr<dataset::Tensor> Execute::operator()(std::shared_ptr<dataset::Tensor> input) {
  // Build the op
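Taken together, the two execute.cc hunks above sketch an eager path: build the wrapped TensorOperation into a runnable kernel, run it on a single tensor, and hand back the result. A minimal stand-alone sketch of that flow; EagerRun is a hypothetical helper, not code from this commit, and it assumes the era's TensorOperation::Build() and TensorOp::Compute() signatures:

// EagerRun is an illustrative helper, not the committed implementation; it
// assumes TensorOperation::Build() yields a runnable TensorOp and that
// TensorOp::Compute(in, &out) returns a Status.
std::shared_ptr<Tensor> EagerRun(const std::shared_ptr<TensorOperation> &op,
                                 const std::shared_ptr<Tensor> &input) {
  if (op == nullptr || input == nullptr) {
    return nullptr;  // nothing to build, or nothing to transform
  }
  std::shared_ptr<TensorOp> kernel = op->Build();  // description -> executable kernel
  std::shared_ptr<Tensor> output;
  Status rc = kernel->Compute(input, &output);     // one tensor in, one tensor out
  return rc.IsOk() ? output : nullptr;             // nullptr signals failure, per the header docs
}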
@@ -298,30 +298,45 @@ Status DatasetOp::GetDatasetSize(int64_t *dataset_size) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
-  CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 1, "Can't get the dataset size for the current tree.");
-  return child_[0]->GetDatasetSize(dataset_size);
  if (child_.size() == 1) {
    return child_[0]->GetDatasetSize(dataset_size);
  } else if (child_.size() > 1) {
    // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
    // always be in front of the child_ structure, so we get the dataset size from the last child.
    return child_[child_.size() - 1]->GetDatasetSize(dataset_size);
  } else {
    RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
  }
}

// Gets the number of classes
Status DatasetOp::GetNumClasses(int64_t *num_classes) {
  if (num_classes_ > 0) {
    *num_classes = num_classes_;
    return Status::OK();
  }
  if (!child_.empty()) {
    if (child_.size() == 1) {
      return child_[0]->GetNumClasses(num_classes);
    } else if (child_.size() > 1) {
      // It is okay for dataset to have more than 1 child, GetNumClasses shouldn't fail in this case.
      // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
      // always be in front of the child_ structure, so we get num classes from the last child.
      return child_[child_.size() - 1]->GetNumClasses(num_classes);
    } else {
      // when num classes isn't found, the default behavior is to return -1
      *num_classes = -1;
-     RETURN_STATUS_UNEXPECTED("Can't get the number of classes for the current tree.");
      return Status::OK();
    }
  }
}

Status DatasetOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
  if (!child_.empty()) {
    if (child_.size() == 1) {
      return child_[0]->GetClassIndexing(output_class_indexing);
    } else if (child_.size() > 1) {
      // It is okay for dataset to have more than 1 child, GetClassIndexing shouldn't fail in this case.
      // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
      // always be in the front of the child_ structure, so we get data from the last child.
      return child_[child_.size() - 1]->GetClassIndexing(output_class_indexing);
    } else {
-     RETURN_STATUS_UNEXPECTED("Can't get the class index for the current tree.");
      *output_class_indexing = {};
      RETURN_STATUS_UNEXPECTED("Trying to get class index from leaf node, missing override");
    }
  }
}
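All three getters in this hunk share one traversal rule, which is easier to see in isolation. The Node type below is hypothetical, written only to show the rule on its own:

#include <cstdint>
#include <memory>
#include <vector>

// One child: recurse into it. Several children (the cache lookup/merge case,
// injected in front): recurse into the last child, which sits on the data
// path. Leaf: answer locally, or report the missing override.
struct Node {
  std::vector<std::shared_ptr<Node>> children;
  int64_t local_value = -1;  // what a leaf would answer

  int64_t GetTreeValue() const {
    if (children.size() == 1) return children[0]->GetTreeValue();
    if (children.size() > 1) return children.back()->GetTreeValue();
    return local_value;
  }
};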
@@ -478,17 +493,31 @@ void DatasetOp::UpdateRepeatAndEpochCounter() {
  if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++;
  MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_;
}

int64_t DatasetOp::GetTreeBatchSize() {
  if (!child_.empty()) {
    if (child_.size() == 1) {
      return child_[0]->GetTreeBatchSize();
    } else if (child_.size() > 1) {
      // It is okay for dataset to have more than 1 child, GetBatchSize shouldn't fail in this case.
      // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
      // always be in front of the child_ structure, so we get data from the last child.
      return child_[child_.size() - 1]->GetTreeBatchSize();
    } else {
      return 1;
    }
  }
  return 1;
}

int64_t DatasetOp::GetTreeRepeatCount() {
  if (!child_.empty()) {
    if (child_.size() == 1) {
      return child_[0]->GetTreeRepeatCount();
    } else if (child_.size() > 1) {
      // It is okay for dataset to have more than 1 child, GetRepeatCount shouldn't fail in this case.
      // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
      // always be in front of the child_ structure, so we get data from the last child.
      return child_[child_.size() - 1]->GetTreeRepeatCount();
    } else {
      return 1;
    }
  }
  return 1;
}
}  // namespace dataset
}  // namespace mindspore
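The counter update at the top of this hunk is easier to see with numbers. A toy function replaying the same modulo test, with an illustrative repeats-per-epoch value:

#include <cstdint>

// Replays the rule `if (repeats % repeats_per_epoch == 0) epochs++` step by
// step; EpochsAfter(7, 3) == 2, since the epoch counter rolls at repeats 3 and 6.
int64_t EpochsAfter(int64_t completed_repeats, int64_t repeats_per_epoch) {
  int64_t epochs = 0;
  for (int64_t r = 1; r <= completed_repeats; ++r) {
    if (r % repeats_per_epoch == 0) ++epochs;
  }
  return epochs;
}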
@@ -70,15 +70,6 @@ Status TFRecordNode::ValidateParams() {
    return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg);
  }

-  if (cache_ == nullptr && !shard_equal_rows_ && dataset_files_.size() < num_shards_) {
-    // This check only makes sense in a non-cache path. We should make sure there is at least one file per
-    // shard in file-based sharding
-    std::string err_msg =
-      "TFRecordNode: Invalid number of dataset files, should at least be " + std::to_string(num_shards_);
-    MS_LOG(ERROR) << err_msg;
-    return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg);
-  }
-
  std::vector<std::string> invalid_files(dataset_files_.size());
  auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(),
                         [](const std::string &filename) { return !TFReaderOp::ValidateFirstRowCrc(filename); });
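The surviving validation lines use a standard STL idiom: copy_if into a presized buffer, then use the returned iterator to measure how much was written. A generic version of the same idiom, with is_valid standing in for TFReaderOp::ValidateFirstRowCrc:

#include <algorithm>
#include <functional>
#include <iterator>
#include <string>
#include <vector>

// Collects every file that fails its check. The output vector is presized so
// copy_if can write in place; the iterator it returns marks the filled prefix.
std::vector<std::string> CollectInvalid(const std::vector<std::string> &files,
                                        const std::function<bool(const std::string &)> &is_valid) {
  std::vector<std::string> invalid(files.size());
  auto it = std::copy_if(files.begin(), files.end(), invalid.begin(),
                         [&is_valid](const std::string &f) { return !is_valid(f); });
  invalid.resize(std::distance(invalid.begin(), it));  // drop the unused tail
  return invalid;
}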
@@ -20,7 +20,9 @@
#include <vector>
#include <memory>
#include "minddata/dataset/core/constants.h"
#ifdef ENABLE_ANDROID
#include "minddata/dataset/include/de_tensor.h"
#endif
#include "minddata/dataset/include/tensor.h"
#include "minddata/dataset/include/transforms.h"
@@ -35,10 +37,16 @@ class Execute {
  /// \brief Constructor
  explicit Execute(std::shared_ptr<TensorOperation> op);

#ifdef ENABLE_ANDROID
  /// \brief Callable function to execute the TensorOperation in eager mode
  /// \param[inout] input - the tensor to be transformed
  /// \return - the output tensor, nullptr if Compute fails
  std::shared_ptr<tensor::MSTensor> operator()(std::shared_ptr<tensor::MSTensor> input);
#endif

  /// \brief Callable function to execute the TensorOperation in eager mode
  /// \param[inout] input - the tensor to be transformed
  /// \return - the output tensor, nullptr if Compute fails
  std::shared_ptr<dataset::Tensor> operator()(std::shared_ptr<dataset::Tensor> input);

 private:
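Read together with the test at the end of this diff, the header above defines a one-shot functor: construct an Execute around an op, then call it on a tensor. A hedged usage sketch; `decoded` is an assumed pre-decoded image tensor and the crop size is illustrative:

std::shared_ptr<TensorOperation> crop = vision::CenterCrop({224, 224});
std::shared_ptr<dataset::Tensor> cropped = Execute(std::move(crop))(decoded);
if (cropped == nullptr) {
  // operator() reports failure by returning nullptr rather than throwing
}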
@@ -57,6 +57,7 @@ SET(DE_UT_SRCS
    distributed_sampler_test.cc
    epoch_ctrl_op_test.cc
    equalize_op_test.cc
+   execute_test.cc
    execution_tree_test.cc
    fill_op_test.cc
    global_context_test.cc
@@ -0,0 +1,43 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common.h"
#include "common/cvop_common.h"
#include "minddata/dataset/include/execute.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/vision.h"
#include "utils/log_adapter.h"

using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;

class MindDataTestExecute : public UT::CVOP::CVOpCommon {
 protected:
  MindDataTestExecute() : CVOpCommon() {}

  std::shared_ptr<Tensor> output_tensor_;
};

TEST_F(MindDataTestExecute, TestOp1) {
  MS_LOG(INFO) << "Doing testCrop.";
  // Crop params
  std::shared_ptr<TensorOperation> center_crop = vision::CenterCrop({30});
  std::shared_ptr<Tensor> out_image = Execute(std::move(center_crop))(input_tensor_);
  EXPECT_NE(out_image, nullptr);
  EXPECT_EQ(30, out_image->shape()[0]);
  EXPECT_EQ(30, out_image->shape()[1]);
}