!8235 Some enhancements to the RuntimeContext and TreeConsumer

Merge pull request !8235 from h.farahat/consumer_changes
This commit is contained in:
mindspore-ci-bot 2020-11-05 06:29:31 +08:00 committed by Gitee
commit 27218ad9a3
7 changed files with 75 additions and 34 deletions

View File

@ -111,7 +111,7 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
Status rc; Status rc;
// Build and launch tree // Build and launch tree
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc; MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
@ -147,7 +147,7 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
Status rc; Status rc;
// Build and launch tree // Build and launch tree
auto ds = shared_from_this(); auto ds = shared_from_this();
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "CreateSaver failed." << rc; MS_LOG(ERROR) << "CreateSaver failed." << rc;
@ -193,7 +193,7 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
int64_t Dataset::GetDatasetSize() { int64_t Dataset::GetDatasetSize() {
int64_t dataset_size; int64_t dataset_size;
Status rc; Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
@ -213,7 +213,7 @@ int64_t Dataset::GetDatasetSize() {
std::vector<DataType> Dataset::GetOutputTypes() { std::vector<DataType> Dataset::GetOutputTypes() {
std::vector<DataType> types; std::vector<DataType> types;
Status rc; Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed."; MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
@ -240,7 +240,7 @@ std::vector<DataType> Dataset::GetOutputTypes() {
std::vector<TensorShape> Dataset::GetOutputShapes() { std::vector<TensorShape> Dataset::GetOutputShapes() {
std::vector<TensorShape> shapes; std::vector<TensorShape> shapes;
Status rc; Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed."; MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
@ -268,7 +268,7 @@ int64_t Dataset::GetNumClasses() {
int64_t num_classes; int64_t num_classes;
auto ds = shared_from_this(); auto ds = shared_from_this();
Status rc; Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed."; MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
@ -562,7 +562,7 @@ int64_t Dataset::GetBatchSize() {
int64_t batch_size; int64_t batch_size;
auto ds = shared_from_this(); auto ds = shared_from_this();
Status rc; Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed."; MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
@ -583,7 +583,7 @@ int64_t Dataset::GetRepeatCount() {
int64_t repeat_count; int64_t repeat_count;
auto ds = shared_from_this(); auto ds = shared_from_this();
Status rc; Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed."; MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
@ -613,7 +613,7 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage, auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
model_type, params); model_type, params);
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init(); Status rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc; MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
@ -645,7 +645,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
auto ds = auto ds =
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first); std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init(); Status rc = runtime_context->Init();
if (rc.IsError()) { if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc; MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;

View File

@ -48,7 +48,7 @@ void Iterator::Stop() { runtime_context_->Terminate(); }
// Function to build and launch the execution tree. // Function to build and launch the execution tree.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
runtime_context_ = std::make_unique<RuntimeContext>(); runtime_context_ = std::make_unique<NativeRuntimeContext>();
RETURN_IF_NOT_OK(runtime_context_->Init()); RETURN_IF_NOT_OK(runtime_context_->Init());
auto consumer = std::make_unique<IteratorConsumer>(); auto consumer = std::make_unique<IteratorConsumer>();
consumer_ = consumer.get(); consumer_ = consumer.get();

View File

@ -19,9 +19,24 @@
namespace mindspore::dataset { namespace mindspore::dataset {
Status PythonRuntimeContext::Terminate() { Status PythonRuntimeContext::Terminate() { return TerminateImpl(); }
Status PythonRuntimeContext::TerminateImpl() {
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
// Release GIL before joining all threads // Release GIL before joining all threads
py::gil_scoped_release gil_release; py::gil_scoped_release gil_release;
return tree_consumer_->Terminate(); return tree_consumer_->Terminate();
} }
PythonRuntimeContext::~PythonRuntimeContext() {
TerminateImpl();
{
py::gil_scoped_acquire gil_acquire;
tree_consumer_.reset();
}
}
PythonIteratorConsumer *PythonRuntimeContext::GetPythonConsumer() {
return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get());
}
} // namespace mindspore::dataset } // namespace mindspore::dataset

View File

@ -24,25 +24,24 @@
#include "minddata/dataset/engine/runtime_context.h" #include "minddata/dataset/engine/runtime_context.h"
namespace mindspore::dataset { namespace mindspore::dataset {
class RuntimeContext; class NativeRuntimeContext;
/// Class that represents single runtime instance which can consume data from a data pipeline /// Class that represents Python single runtime instance which can consume data from a data pipeline
class PythonRuntimeContext : public RuntimeContext { class PythonRuntimeContext : public RuntimeContext {
public: public:
/// Method to terminate the runtime, this will not release the resources /// Method to terminate the runtime, this will not release the resources
/// \return Status error code /// \return Status error code
Status Terminate() override; Status Terminate() override;
// Safe destructing the tree that includes python objects /// Safe destructing the tree that includes python objects
~PythonRuntimeContext() { ~PythonRuntimeContext() override;
Terminate();
{
py::gil_scoped_acquire gil_acquire;
tree_consumer_.reset();
}
}
PythonIteratorConsumer *GetPythonConsumer() { return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get()); } PythonIteratorConsumer *GetPythonConsumer();
private:
/// Internal function to perform the termination
/// \return Status error code
Status TerminateImpl();
}; };
} // namespace mindspore::dataset } // namespace mindspore::dataset

View File

@ -22,4 +22,17 @@ namespace mindspore::dataset {
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) { void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
tree_consumer_ = std::move(tree_consumer); tree_consumer_ = std::move(tree_consumer);
} }
Status NativeRuntimeContext::Terminate() { return TerminateImpl(); }
Status NativeRuntimeContext::TerminateImpl() {
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
return tree_consumer_->Terminate();
}
NativeRuntimeContext::~NativeRuntimeContext() { TerminateImpl(); }
TreeConsumer *RuntimeContext::GetConsumer() { return tree_consumer_.get(); }
Status RuntimeContext::Init() { return GlobalInit(); }
} // namespace mindspore::dataset } // namespace mindspore::dataset

View File

@ -23,8 +23,7 @@
namespace mindspore::dataset { namespace mindspore::dataset {
class TreeConsumer; class TreeConsumer;
/// Class that represents single runtime instance which can consume data from a data pipeline
/// Class the represents single runtime instance which can consume data from a data pipeline
class RuntimeContext { class RuntimeContext {
public: public:
/// Default constructor /// Default constructor
@ -32,11 +31,7 @@ class RuntimeContext {
/// Initialize the runtime, for now we just call the global init /// Initialize the runtime, for now we just call the global init
/// \return Status error code /// \return Status error code
Status Init() { return GlobalInit(); } Status Init();
/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
virtual Status Terminate() { return Status::OK(); }
/// Set the tree consumer /// Set the tree consumer
/// \param tree_consumer to be assigned /// \param tree_consumer to be assigned
@ -44,13 +39,32 @@ class RuntimeContext {
/// Get the tree consumer /// Get the tree consumer
/// \return Raw pointer to the tree consumer. /// \return Raw pointer to the tree consumer.
TreeConsumer *GetConsumer() { return tree_consumer_.get(); } TreeConsumer *GetConsumer();
~RuntimeContext() { Terminate(); } /// Method to terminate the runtime, this will not release the resources
/// \return Status error code
virtual Status Terminate() = 0;
virtual ~RuntimeContext() = default;
protected: protected:
std::shared_ptr<TreeConsumer> tree_consumer_; std::shared_ptr<TreeConsumer> tree_consumer_;
}; };
/// Class that represents C++ single runtime instance which can consume data from a data pipeline
class NativeRuntimeContext : public RuntimeContext {
public:
/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
Status Terminate() override;
~NativeRuntimeContext() override;
private:
/// Internal function to perform the termination
/// \return Status error code
Status TerminateImpl();
};
} // namespace mindspore::dataset } // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_

View File

@ -33,7 +33,7 @@ class DatasetIterator;
class DatasetOp; class DatasetOp;
class Tensor; class Tensor;
class RuntimeContext; class NativeRuntimeContext;
class IteratorConsumer; class IteratorConsumer;
class Dataset; class Dataset;
@ -113,7 +113,7 @@ class Iterator {
_Iterator end() { return _Iterator(nullptr); } _Iterator end() { return _Iterator(nullptr); }
private: private:
std::unique_ptr<RuntimeContext> runtime_context_; std::unique_ptr<NativeRuntimeContext> runtime_context_;
IteratorConsumer *consumer_; IteratorConsumer *consumer_;
}; };
} // namespace dataset } // namespace dataset