!8235 Some enhancements to the RuntimeContext and TreeConsumer

Merge pull request !8235 from h.farahat/consumer_changes
This commit is contained in:
mindspore-ci-bot 2020-11-05 06:29:31 +08:00 committed by Gitee
commit 27218ad9a3
7 changed files with 75 additions and 34 deletions

View File

@ -111,7 +111,7 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
Status rc;
// Build and launch tree
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
@ -147,7 +147,7 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
Status rc;
// Build and launch tree
auto ds = shared_from_this();
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "CreateSaver failed." << rc;
@ -193,7 +193,7 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
int64_t Dataset::GetDatasetSize() {
int64_t dataset_size;
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
@ -213,7 +213,7 @@ int64_t Dataset::GetDatasetSize() {
std::vector<DataType> Dataset::GetOutputTypes() {
std::vector<DataType> types;
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
@ -240,7 +240,7 @@ std::vector<DataType> Dataset::GetOutputTypes() {
std::vector<TensorShape> Dataset::GetOutputShapes() {
std::vector<TensorShape> shapes;
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
@ -268,7 +268,7 @@ int64_t Dataset::GetNumClasses() {
int64_t num_classes;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
@ -562,7 +562,7 @@ int64_t Dataset::GetBatchSize() {
int64_t batch_size;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
@ -583,7 +583,7 @@ int64_t Dataset::GetRepeatCount() {
int64_t repeat_count;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
@ -613,7 +613,7 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
model_type, params);
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
@ -645,7 +645,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
auto ds =
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
Status rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;

View File

@ -48,7 +48,7 @@ void Iterator::Stop() { runtime_context_->Terminate(); }
// Function to build and launch the execution tree.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
runtime_context_ = std::make_unique<RuntimeContext>();
runtime_context_ = std::make_unique<NativeRuntimeContext>();
RETURN_IF_NOT_OK(runtime_context_->Init());
auto consumer = std::make_unique<IteratorConsumer>();
consumer_ = consumer.get();

View File

@ -19,9 +19,24 @@
namespace mindspore::dataset {
Status PythonRuntimeContext::Terminate() {
Status PythonRuntimeContext::Terminate() { return TerminateImpl(); }
/// Internal termination routine called by both Terminate() and the destructor.
/// \return Status of the consumer's Terminate() call when a consumer is attached
Status PythonRuntimeContext::TerminateImpl() {
// Guard against a null consumer before dereferencing it below.
// NOTE(review): CHECK_FAIL_RETURN_UNEXPECTED presumably returns an error Status carrying this message - confirm
// against the macro definition.
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
// Release GIL before joining all threads
// The release object lives until this function returns, so the GIL stays released for the whole Terminate() call.
py::gil_scoped_release gil_release;
return tree_consumer_->Terminate();
}
/// Terminates the consumer, then drops this context's reference to it while holding the GIL.
PythonRuntimeContext::~PythonRuntimeContext() {
// Call the private helper directly instead of the virtual Terminate().
// The Status it returns is discarded here.
TerminateImpl();
{
// Hold the GIL only for this inner scope: releasing the consumer may destroy a tree that includes Python objects.
py::gil_scoped_acquire gil_acquire;
tree_consumer_.reset();
}
}
/// Get the tree consumer viewed as a PythonIteratorConsumer
/// \return Raw pointer to the consumer, or nullptr when the dynamic_cast does not succeed (no consumer is held, or
///     the held consumer is not a PythonIteratorConsumer)
PythonIteratorConsumer *PythonRuntimeContext::GetPythonConsumer() {
return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get());
}
} // namespace mindspore::dataset

View File

@ -24,25 +24,24 @@
#include "minddata/dataset/engine/runtime_context.h"
namespace mindspore::dataset {
class RuntimeContext;
class NativeRuntimeContext;
/// Class that represents single runtime instance which can consume data from a data pipeline
/// Class that represents Python single runtime instance which can consume data from a data pipeline
class PythonRuntimeContext : public RuntimeContext {
public:
/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
Status Terminate() override;
// Safe destructing the tree that includes python objects
~PythonRuntimeContext() {
Terminate();
{
py::gil_scoped_acquire gil_acquire;
tree_consumer_.reset();
}
}
/// Safe destructing the tree that includes python objects
~PythonRuntimeContext() override;
PythonIteratorConsumer *GetPythonConsumer() { return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get()); }
PythonIteratorConsumer *GetPythonConsumer();
private:
/// Internal function to perform the termination
/// \return Status error code
Status TerminateImpl();
};
} // namespace mindspore::dataset

View File

@ -22,4 +22,17 @@ namespace mindspore::dataset {
/// Set the tree consumer, replacing whatever consumer this context held before
/// \param tree_consumer to be assigned; taken by value and moved into the member
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
tree_consumer_ = std::move(tree_consumer);
}
/// Method to terminate the runtime; simply forwards to the private TerminateImpl()
/// \return Status error code
Status NativeRuntimeContext::Terminate() { return TerminateImpl(); }
/// Internal termination routine called by both Terminate() and the destructor.
/// \return Status of the consumer's Terminate() call when a consumer is attached
Status NativeRuntimeContext::TerminateImpl() {
// Guard against a null consumer before dereferencing it below.
// NOTE(review): CHECK_FAIL_RETURN_UNEXPECTED presumably returns an error Status carrying this message - confirm
// against the macro definition.
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
return tree_consumer_->Terminate();
}
/// Terminates the consumer on destruction via the private TerminateImpl() helper (not the virtual Terminate());
/// the Status returned by TerminateImpl() is discarded.
NativeRuntimeContext::~NativeRuntimeContext() { TerminateImpl(); }
/// Get the tree consumer
/// \return Raw, non-owning pointer to the tree consumer (nullptr when the member is empty); the shared_ptr member
///     keeps its reference.
TreeConsumer *RuntimeContext::GetConsumer() { return tree_consumer_.get(); }
/// Initialize the runtime, for now we just call the global init
/// \return Status error code returned by GlobalInit()
Status RuntimeContext::Init() { return GlobalInit(); }
} // namespace mindspore::dataset

View File

@ -23,8 +23,7 @@
namespace mindspore::dataset {
class TreeConsumer;
/// Class the represents single runtime instance which can consume data from a data pipeline
/// Class that represents single runtime instance which can consume data from a data pipeline
class RuntimeContext {
public:
/// Default constructor
@ -32,11 +31,7 @@ class RuntimeContext {
/// Initialize the runtime, for now we just call the global init
/// \return Status error code
Status Init() { return GlobalInit(); }
/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
virtual Status Terminate() { return Status::OK(); }
Status Init();
/// Set the tree consumer
/// \param tree_consumer to be assigned
@ -44,13 +39,32 @@ class RuntimeContext {
/// Get the tree consumer
/// \return Raw pointer to the tree consumer.
TreeConsumer *GetConsumer() { return tree_consumer_.get(); }
TreeConsumer *GetConsumer();
~RuntimeContext() { Terminate(); }
/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
virtual Status Terminate() = 0;
virtual ~RuntimeContext() = default;
protected:
std::shared_ptr<TreeConsumer> tree_consumer_;
};
/// Class that represents a single C++ (native) runtime instance which can consume data from a data pipeline.
/// Concrete subclass of RuntimeContext: it supplies the Terminate() override that the base class declares as
/// pure virtual.
class NativeRuntimeContext : public RuntimeContext {
public:
/// Method to terminate the runtime, this will not release the resources
/// \return Status error code
Status Terminate() override;
/// Destructor; defined out of line in the implementation file
~NativeRuntimeContext() override;
private:
/// Internal function to perform the termination
/// Declared without `virtual`, so calls to it are resolved statically to this class.
/// \return Status error code
Status TerminateImpl();
};
} // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_

View File

@ -33,7 +33,7 @@ class DatasetIterator;
class DatasetOp;
class Tensor;
class RuntimeContext;
class NativeRuntimeContext;
class IteratorConsumer;
class Dataset;
@ -113,7 +113,7 @@ class Iterator {
_Iterator end() { return _Iterator(nullptr); }
private:
std::unique_ptr<RuntimeContext> runtime_context_;
std::unique_ptr<NativeRuntimeContext> runtime_context_;
IteratorConsumer *consumer_;
};
} // namespace dataset