forked from mindspore-Ecosystem/mindspore
!8235 Some enhancements to the RuntimeContext and TreeConsumer
Merge pull request !8235 from h.farahat/consumer_changes
This commit is contained in:
commit
27218ad9a3
|
@ -111,7 +111,7 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
|
|||
Status rc;
|
||||
|
||||
// Build and launch tree
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
|
||||
|
@ -147,7 +147,7 @@ bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string data
|
|||
Status rc;
|
||||
// Build and launch tree
|
||||
auto ds = shared_from_this();
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "CreateSaver failed." << rc;
|
||||
|
@ -193,7 +193,7 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
|
|||
int64_t Dataset::GetDatasetSize() {
|
||||
int64_t dataset_size;
|
||||
Status rc;
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
|
||||
|
@ -213,7 +213,7 @@ int64_t Dataset::GetDatasetSize() {
|
|||
std::vector<DataType> Dataset::GetOutputTypes() {
|
||||
std::vector<DataType> types;
|
||||
Status rc;
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
|
||||
|
@ -240,7 +240,7 @@ std::vector<DataType> Dataset::GetOutputTypes() {
|
|||
std::vector<TensorShape> Dataset::GetOutputShapes() {
|
||||
std::vector<TensorShape> shapes;
|
||||
Status rc;
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
|
||||
|
@ -268,7 +268,7 @@ int64_t Dataset::GetNumClasses() {
|
|||
int64_t num_classes;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
|
||||
|
@ -562,7 +562,7 @@ int64_t Dataset::GetBatchSize() {
|
|||
int64_t batch_size;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
|
||||
|
@ -583,7 +583,7 @@ int64_t Dataset::GetRepeatCount() {
|
|||
int64_t repeat_count;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
|
||||
|
@ -613,7 +613,7 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
|
|||
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
|
||||
model_type, params);
|
||||
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
Status rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
|
||||
|
@ -645,7 +645,7 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
|
|||
auto ds =
|
||||
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
|
||||
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
Status rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;
|
||||
|
|
|
@ -48,7 +48,7 @@ void Iterator::Stop() { runtime_context_->Terminate(); }
|
|||
|
||||
// Function to build and launch the execution tree.
|
||||
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
|
||||
runtime_context_ = std::make_unique<RuntimeContext>();
|
||||
runtime_context_ = std::make_unique<NativeRuntimeContext>();
|
||||
RETURN_IF_NOT_OK(runtime_context_->Init());
|
||||
auto consumer = std::make_unique<IteratorConsumer>();
|
||||
consumer_ = consumer.get();
|
||||
|
|
|
@ -19,9 +19,24 @@
|
|||
|
||||
namespace mindspore::dataset {
|
||||
|
||||
Status PythonRuntimeContext::Terminate() {
|
||||
Status PythonRuntimeContext::Terminate() { return TerminateImpl(); }
|
||||
|
||||
Status PythonRuntimeContext::TerminateImpl() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
|
||||
// Release GIL before joining all threads
|
||||
py::gil_scoped_release gil_release;
|
||||
return tree_consumer_->Terminate();
|
||||
}
|
||||
|
||||
PythonRuntimeContext::~PythonRuntimeContext() {
|
||||
TerminateImpl();
|
||||
{
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
tree_consumer_.reset();
|
||||
}
|
||||
}
|
||||
|
||||
PythonIteratorConsumer *PythonRuntimeContext::GetPythonConsumer() {
|
||||
return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get());
|
||||
}
|
||||
} // namespace mindspore::dataset
|
||||
|
|
|
@ -24,25 +24,24 @@
|
|||
#include "minddata/dataset/engine/runtime_context.h"
|
||||
|
||||
namespace mindspore::dataset {
|
||||
class RuntimeContext;
|
||||
class NativeRuntimeContext;
|
||||
|
||||
/// Class that represents single runtime instance which can consume data from a data pipeline
|
||||
/// Class that represents Python single runtime instance which can consume data from a data pipeline
|
||||
class PythonRuntimeContext : public RuntimeContext {
|
||||
public:
|
||||
/// Method to terminate the runtime, this will not release the resources
|
||||
/// \return Status error code
|
||||
Status Terminate() override;
|
||||
|
||||
// Safe destructing the tree that includes python objects
|
||||
~PythonRuntimeContext() {
|
||||
Terminate();
|
||||
{
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
tree_consumer_.reset();
|
||||
}
|
||||
}
|
||||
/// Safe destructing the tree that includes python objects
|
||||
~PythonRuntimeContext() override;
|
||||
|
||||
PythonIteratorConsumer *GetPythonConsumer() { return dynamic_cast<PythonIteratorConsumer *>(tree_consumer_.get()); }
|
||||
PythonIteratorConsumer *GetPythonConsumer();
|
||||
|
||||
private:
|
||||
/// Internal function to perform the termination
|
||||
/// \return Status error code
|
||||
Status TerminateImpl();
|
||||
};
|
||||
|
||||
} // namespace mindspore::dataset
|
||||
|
|
|
@ -22,4 +22,17 @@ namespace mindspore::dataset {
|
|||
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
|
||||
tree_consumer_ = std::move(tree_consumer);
|
||||
}
|
||||
Status NativeRuntimeContext::Terminate() { return TerminateImpl(); }
|
||||
|
||||
Status NativeRuntimeContext::TerminateImpl() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
|
||||
return tree_consumer_->Terminate();
|
||||
}
|
||||
|
||||
NativeRuntimeContext::~NativeRuntimeContext() { TerminateImpl(); }
|
||||
|
||||
TreeConsumer *RuntimeContext::GetConsumer() { return tree_consumer_.get(); }
|
||||
|
||||
Status RuntimeContext::Init() { return GlobalInit(); }
|
||||
|
||||
} // namespace mindspore::dataset
|
||||
|
|
|
@ -23,8 +23,7 @@
|
|||
|
||||
namespace mindspore::dataset {
|
||||
class TreeConsumer;
|
||||
|
||||
/// Class the represents single runtime instance which can consume data from a data pipeline
|
||||
/// Class that represents single runtime instance which can consume data from a data pipeline
|
||||
class RuntimeContext {
|
||||
public:
|
||||
/// Default constructor
|
||||
|
@ -32,11 +31,7 @@ class RuntimeContext {
|
|||
|
||||
/// Initialize the runtime, for now we just call the global init
|
||||
/// \return Status error code
|
||||
Status Init() { return GlobalInit(); }
|
||||
|
||||
/// Method to terminate the runtime, this will not release the resources
|
||||
/// \return Status error code
|
||||
virtual Status Terminate() { return Status::OK(); }
|
||||
Status Init();
|
||||
|
||||
/// Set the tree consumer
|
||||
/// \param tree_consumer to be assigned
|
||||
|
@ -44,13 +39,32 @@ class RuntimeContext {
|
|||
|
||||
/// Get the tree consumer
|
||||
/// \return Raw pointer to the tree consumer.
|
||||
TreeConsumer *GetConsumer() { return tree_consumer_.get(); }
|
||||
TreeConsumer *GetConsumer();
|
||||
|
||||
~RuntimeContext() { Terminate(); }
|
||||
/// Method to terminate the runtime, this will not release the resources
|
||||
/// \return Status error code
|
||||
virtual Status Terminate() = 0;
|
||||
|
||||
virtual ~RuntimeContext() = default;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<TreeConsumer> tree_consumer_;
|
||||
};
|
||||
|
||||
/// Class that represents C++ single runtime instance which can consume data from a data pipeline
|
||||
class NativeRuntimeContext : public RuntimeContext {
|
||||
public:
|
||||
/// Method to terminate the runtime, this will not release the resources
|
||||
/// \return Status error code
|
||||
Status Terminate() override;
|
||||
|
||||
~NativeRuntimeContext() override;
|
||||
|
||||
private:
|
||||
/// Internal function to perform the termination
|
||||
/// \return Status error code
|
||||
Status TerminateImpl();
|
||||
};
|
||||
|
||||
} // namespace mindspore::dataset
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_
|
||||
|
|
|
@ -33,7 +33,7 @@ class DatasetIterator;
|
|||
class DatasetOp;
|
||||
class Tensor;
|
||||
|
||||
class RuntimeContext;
|
||||
class NativeRuntimeContext;
|
||||
class IteratorConsumer;
|
||||
|
||||
class Dataset;
|
||||
|
@ -113,7 +113,7 @@ class Iterator {
|
|||
_Iterator end() { return _Iterator(nullptr); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<RuntimeContext> runtime_context_;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context_;
|
||||
IteratorConsumer *consumer_;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
Loading…
Reference in New Issue