remove redundant root()

This commit is contained in:
Xiao Tianci 2020-10-27 10:14:18 +08:00
parent 9505f2adc5
commit 9649efb98a
3 changed files with 15 additions and 6 deletions

View File

@ -80,20 +80,26 @@ Status ToDevice::Init(std::shared_ptr<api::Dataset> d) {
Status ToDevice::Send() {
std::unique_ptr<DataBuffer> db;
RETURN_IF_NOT_OK(tree_adapter_->Launch());
RETURN_IF_NOT_OK(tree_adapter_->root()->GetNextBuffer(&db));
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
RETURN_IF_NOT_OK(root->GetNextBuffer(&db));
return Status::OK();
}
Status ToDevice::Continue() {
// tree_.root() must be DeviceQueueOp
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_adapter_->root().get());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "ContinueSend only supported by DeviceQueueOp");
op->ContinueSend();
return Status::OK();
}
Status ToDevice::Stop() {
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_adapter_->root().get());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
op->StopSend();
return Status::OK();

View File

@ -87,5 +87,10 @@ Status TreeAdapter::DFSBuildTree(std::shared_ptr<api::Dataset> ir, std::shared_p
return Status::OK();
}
Status TreeAdapter::Launch() const {
CHECK_FAIL_RETURN_UNEXPECTED(tree_ != nullptr, "Tree is a nullptr.");
return tree_->Launch();
}
} // namespace dataset
} // namespace mindspore

View File

@ -57,9 +57,7 @@ class TreeAdapter {
// to be able to launch a thread. BuildAndPrepare needs to be called before this function
TaskGroup *AllTasks() const { return tree_ != nullptr ? tree_->AllTasks() : nullptr; }
std::shared_ptr<DatasetOp> root() { return tree_->root(); }
Status Launch() const { return tree_->Launch(); }
Status Launch() const;
private:
// This RECURSIVE function converts IR nodes into DatasetOp in ExecutionTree. IR could build a vector of ops. In