forked from mindspore-Ecosystem/mindspore
!8235 Some enhancements to the RuntimeContext and TreeConsumer
Merge pull request !8235 from h.farahat/consumer_changes
This commit is contained in:
commit
27218ad9a3
|
@ -111,7 +111,7 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
|
||||||
Status rc;
|
Status rc;
|
||||||
|
|
||||||
// Build and launch tree
|
// Build and launch tree
|
||||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||||
rc = runtime_context->Init();
|
rc = runtime_context->Init();
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
|
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
|
||||||
|
@ -147,7 +147,7 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
|
||||||
Status rc;
|
Status rc;
|
||||||
// Build and launch tree
|
// Build and launch tree
|
||||||
auto ds = shared_from_this();
|
auto ds = shared_from_this();
|
||||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||||
rc = runtime_context->Init();
|
rc = runtime_context->Init();
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
MS_LOG(ERROR) << "CreateSaver failed." << rc;
|
MS_LOG(ERROR) << "CreateSaver failed." << rc;
|
||||||
|
@ -193,7 +193,7 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
|
||||||
int64_t Dataset::GetDatasetSize() {
|
int64_t Dataset::GetDatasetSize() {
|
||||||
int64_t dataset_size;
|
int64_t dataset_size;
|
||||||
Status rc;
|
Status rc;
|
||||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||||
rc = runtime_context->Init();
|
rc = runtime_context->Init();
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
|
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
|
||||||
|
@ -213,7 +213,7 @@ int64_t Dataset::GetDatasetSize() {
|
||||||
std::vector<DataType> Dataset::GetOutputTypes() {
|
std::vector<DataType> Dataset::GetOutputTypes() {
|
||||||
std::vector<DataType> types;
|
std::vector<DataType> types;
|
||||||
Status rc;
|
Status rc;
|
||||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||||
rc = runtime_context->Init();
|
rc = runtime_context->Init();
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
|
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
|
||||||
|
@ -240,7 +240,7 @@ std::vector<DataType> Dataset::GetOutputTypes() {
|
||||||
std::vector<TensorShape> Dataset::GetOutputShapes() {
|
std::vector<TensorShape> Dataset::GetOutputShapes() {
|
||||||
std::vector<TensorShape> shapes;
|
std::vector<TensorShape> shapes;
|
||||||
Status rc;
|
Status rc;
|
||||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||||
rc = runtime_context->Init();
|
rc = runtime_context->Init();
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
|
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
|
||||||
|
@ -268,7 +268,7 @@ int64_t Dataset::GetNumClasses() {
|
||||||
int64_t num_classes;
|
int64_t num_classes;
|
||||||
auto ds = shared_from_this();
|
auto ds = shared_from_this();
|
||||||
Status rc;
|
Status rc;
|
||||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||||
rc = runtime_context->Init();
|
rc = runtime_context->Init();
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
|
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
|
||||||
|
@ -562,7 +562,7 @@ int64_t Dataset::GetBatchSize() {
|
||||||
int64_t batch_size;
|
int64_t batch_size;
|
||||||
auto ds = shared_from_this();
|
auto ds = shared_from_this();
|
||||||
Status rc;
|
Status rc;
|
||||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||||
rc = runtime_context->Init();
|
rc = runtime_context->Init();
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
|
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
|
||||||
|
@ -583,7 +583,7 @@ int64_t Dataset::GetRepeatCount() {
|
||||||
int64_t repeat_count;
|
int64_t repeat_count;
|
||||||
auto ds = shared_from_this();
|
auto ds = shared_from_this();
|
||||||
Status rc;
|
Status rc;
|
||||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||||
rc = runtime_context->Init();
|
rc = runtime_context->Init();
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
|
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
|
||||||
|
@ -613,7 +613,7 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
|
||||||
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
|
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
|
||||||
model_type, params);
|
model_type, params);
|
||||||
|
|
||||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||||
Status rc = runtime_context->Init();
|
Status rc = runtime_context->Init();
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
|
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
|
||||||
|
@ -645,7 +645,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
|
||||||
auto ds =
|
auto ds =
|
||||||
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
|
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
|
||||||
|
|
||||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||||
Status rc = runtime_context->Init();
|
Status rc = runtime_context->Init();
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;
|
MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;
|
||||||
|
|
|
@ -48,7 +48,7 @@ void Iterator::Stop() { runtime_context_->Terminate(); }
|
||||||
|
|
||||||
// Function to build and launch the execution tree.
|
// Function to build and launch the execution tree.
|
||||||
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
|
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
|
||||||
runtime_context_ = std::make_unique<RuntimeContext>();
|
runtime_context_ = std::make_unique<NativeRuntimeContext>();
|
||||||
RETURN_IF_NOT_OK(runtime_context_->Init());
|
RETURN_IF_NOT_OK(runtime_context_->Init());
|
||||||
auto consumer = std::make_unique<IteratorConsumer>();
|
auto consumer = std::make_unique<IteratorConsumer>();
|
||||||
consumer_ = consumer.get();
|
consumer_ = consumer.get();
|
||||||
|
|
|
@ -19,9 +19,24 @@
|
||||||
|
|
||||||
namespace mindspore::dataset {
|
namespace mindspore::dataset {
|
||||||
|
|
||||||
Status PythonRuntimeContext::Terminate() {
|
Status PythonRuntimeContext::Terminate() { return TerminateImpl(); }
|
||||||
|
|
||||||
|
Status PythonRuntimeContext::TerminateImpl() {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
|
||||||
// Release GIL before joining all threads
|
// Release GIL before joining all threads
|
||||||
py::gil_scoped_release gil_release;
|
py::gil_scoped_release gil_release;
|
||||||
return tree_consumer_->Terminate();
|
return tree_consumer_->Terminate();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PythonRuntimeContext::~PythonRuntimeContext() {
|
||||||
|
TerminateImpl();
|
||||||
|
{
|
||||||
|
py::gil_scoped_acquire gil_acquire;
|
||||||
|
tree_consumer_.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PythonIteratorConsumer *PythonRuntimeContext::GetPythonConsumer() {
|
||||||
|
return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get());
|
||||||
|
}
|
||||||
} // namespace mindspore::dataset
|
} // namespace mindspore::dataset
|
||||||
|
|
|
@ -24,25 +24,24 @@
|
||||||
#include "minddata/dataset/engine/runtime_context.h"
|
#include "minddata/dataset/engine/runtime_context.h"
|
||||||
|
|
||||||
namespace mindspore::dataset {
|
namespace mindspore::dataset {
|
||||||
class RuntimeContext;
|
class NativeRuntimeContext;
|
||||||
|
|
||||||
/// Class that represents single runtime instance which can consume data from a data pipeline
|
/// Class that represents Python single runtime instance which can consume data from a data pipeline
|
||||||
class PythonRuntimeContext : public RuntimeContext {
|
class PythonRuntimeContext : public RuntimeContext {
|
||||||
public:
|
public:
|
||||||
/// Method to terminate the runtime, this will not release the resources
|
/// Method to terminate the runtime, this will not release the resources
|
||||||
/// \return Status error code
|
/// \return Status error code
|
||||||
Status Terminate() override;
|
Status Terminate() override;
|
||||||
|
|
||||||
// Safe destructing the tree that includes python objects
|
/// Safe destructing the tree that includes python objects
|
||||||
~PythonRuntimeContext() {
|
~PythonRuntimeContext() override;
|
||||||
Terminate();
|
|
||||||
{
|
|
||||||
py::gil_scoped_acquire gil_acquire;
|
|
||||||
tree_consumer_.reset();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
PythonIteratorConsumer *GetPythonConsumer() { return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get()); }
|
PythonIteratorConsumer *GetPythonConsumer();
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Internal function to perform the termination
|
||||||
|
/// \return Status error code
|
||||||
|
Status TerminateImpl();
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mindspore::dataset
|
} // namespace mindspore::dataset
|
||||||
|
|
|
@ -22,4 +22,17 @@ namespace mindspore::dataset {
|
||||||
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
|
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
|
||||||
tree_consumer_ = std::move(tree_consumer);
|
tree_consumer_ = std::move(tree_consumer);
|
||||||
}
|
}
|
||||||
|
Status NativeRuntimeContext::Terminate() { return TerminateImpl(); }
|
||||||
|
|
||||||
|
Status NativeRuntimeContext::TerminateImpl() {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
|
||||||
|
return tree_consumer_->Terminate();
|
||||||
|
}
|
||||||
|
|
||||||
|
NativeRuntimeContext::~NativeRuntimeContext() { TerminateImpl(); }
|
||||||
|
|
||||||
|
TreeConsumer *RuntimeContext::GetConsumer() { return tree_consumer_.get(); }
|
||||||
|
|
||||||
|
Status RuntimeContext::Init() { return GlobalInit(); }
|
||||||
|
|
||||||
} // namespace mindspore::dataset
|
} // namespace mindspore::dataset
|
||||||
|
|
|
@ -23,8 +23,7 @@
|
||||||
|
|
||||||
namespace mindspore::dataset {
|
namespace mindspore::dataset {
|
||||||
class TreeConsumer;
|
class TreeConsumer;
|
||||||
|
/// Class that represents single runtime instance which can consume data from a data pipeline
|
||||||
/// Class the represents single runtime instance which can consume data from a data pipeline
|
|
||||||
class RuntimeContext {
|
class RuntimeContext {
|
||||||
public:
|
public:
|
||||||
/// Default constructor
|
/// Default constructor
|
||||||
|
@ -32,11 +31,7 @@ class RuntimeContext {
|
||||||
|
|
||||||
/// Initialize the runtime, for now we just call the global init
|
/// Initialize the runtime, for now we just call the global init
|
||||||
/// \return Status error code
|
/// \return Status error code
|
||||||
Status Init() { return GlobalInit(); }
|
Status Init();
|
||||||
|
|
||||||
/// Method to terminate the runtime, this will not release the resources
|
|
||||||
/// \return Status error code
|
|
||||||
virtual Status Terminate() { return Status::OK(); }
|
|
||||||
|
|
||||||
/// Set the tree consumer
|
/// Set the tree consumer
|
||||||
/// \param tree_consumer to be assigned
|
/// \param tree_consumer to be assigned
|
||||||
|
@ -44,13 +39,32 @@ class RuntimeContext {
|
||||||
|
|
||||||
/// Get the tree consumer
|
/// Get the tree consumer
|
||||||
/// \return Raw pointer to the tree consumer.
|
/// \return Raw pointer to the tree consumer.
|
||||||
TreeConsumer *GetConsumer() { return tree_consumer_.get(); }
|
TreeConsumer *GetConsumer();
|
||||||
|
|
||||||
~RuntimeContext() { Terminate(); }
|
/// Method to terminate the runtime, this will not release the resources
|
||||||
|
/// \return Status error code
|
||||||
|
virtual Status Terminate() = 0;
|
||||||
|
|
||||||
|
virtual ~RuntimeContext() = default;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::shared_ptr<TreeConsumer> tree_consumer_;
|
std::shared_ptr<TreeConsumer> tree_consumer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Class that represents C++ single runtime instance which can consume data from a data pipeline
|
||||||
|
class NativeRuntimeContext : public RuntimeContext {
|
||||||
|
public:
|
||||||
|
/// Method to terminate the runtime, this will not release the resources
|
||||||
|
/// \return Status error code
|
||||||
|
Status Terminate() override;
|
||||||
|
|
||||||
|
~NativeRuntimeContext() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Internal function to perform the termination
|
||||||
|
/// \return Status error code
|
||||||
|
Status TerminateImpl();
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace mindspore::dataset
|
} // namespace mindspore::dataset
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_
|
||||||
|
|
|
@ -33,7 +33,7 @@ class DatasetIterator;
|
||||||
class DatasetOp;
|
class DatasetOp;
|
||||||
class Tensor;
|
class Tensor;
|
||||||
|
|
||||||
class RuntimeContext;
|
class NativeRuntimeContext;
|
||||||
class IteratorConsumer;
|
class IteratorConsumer;
|
||||||
|
|
||||||
class Dataset;
|
class Dataset;
|
||||||
|
@ -113,7 +113,7 @@ class Iterator {
|
||||||
_Iterator end() { return _Iterator(nullptr); }
|
_Iterator end() { return _Iterator(nullptr); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<RuntimeContext> runtime_context_;
|
std::unique_ptr<NativeRuntimeContext> runtime_context_;
|
||||||
IteratorConsumer *consumer_;
|
IteratorConsumer *consumer_;
|
||||||
};
|
};
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
|
|
Loading…
Reference in New Issue