!28528 Add lazy task callback to Tensor

Merge pull request !28528 from caifubi/master-pynative-tensor-sync-lazy-build
This commit is contained in:
i-robot 2022-01-28 07:08:16 +00:00 committed by Gitee
commit 4aa3173d9a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 40 additions and 22 deletions

View File

@ -230,18 +230,6 @@ TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index)
return tensor;
}
// Collect the output tensors of the given kernel nodes and append them to `outputs`.
// Nodes that produce no output tensors are skipped entirely.
void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs) {
MS_EXCEPTION_IF_NULL(outputs);
for (auto &item_with_index : output_nodes) {
MS_EXCEPTION_IF_NULL(item_with_index.first);
// If the graph returns nothing, skip the node so the function yields an empty list.
if (AnfAlgo::GetOutputTensorNum(item_with_index.first) == 0) {
continue;
}
outputs->emplace_back(CreateOutputTensor(item_with_index.first, item_with_index.second));
}
}
void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context, bool is_gradient_out) {
MS_EXCEPTION_IF_NULL(graph);
for (const auto &node : graph->execution_order()) {
@ -1390,9 +1378,8 @@ void MindRTBackend::RunOpInternal(bool single_op_cache_hit, GraphCompilerInfo *g
op_lazy_builder.PushOpBuildTask(std::make_shared<runtime::OpBuildTask>(run_op_context));
}
op_lazy_builder.PushOpRunTask(std::make_shared<runtime::OpRunTask>(run_op_context));
if (!op_lazy_builder.registered()) {
op_lazy_builder.Register([this]() { LazyExecuteTaskCallback(); });
}
// Callbacks need to be re-registered in heterogeneous scenarios.
op_lazy_builder.Register([this]() { LazyExecuteTaskCallback(); });
if (op_lazy_builder.QueueFull()) {
op_lazy_builder.ExecuteRemainingTasks();
}
@ -1450,5 +1437,19 @@ void MindRTBackend::CompileSingleOpGraph(const KernelGraphPtr &graph, const Devi
// So `Schedule` need to execute after `CreateKernelWorkspaceDeviceAddress`.
runtime::GraphScheduler::GetInstance().Schedule(actor_set);
}
// Create output tensors for `output_nodes` and append them to `outputs`.
// Each created tensor is given a lazy-build callback so that any op build/run
// tasks still queued in the OpLazyBuilder are flushed before the tensor's
// data is synchronized (see Tensor::data_sync -> ExecuteLazyTask).
void MindRTBackend::UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs) {
MS_EXCEPTION_IF_NULL(outputs);
// const-correct: the loop never mutates the node/index pairs.
for (const auto &item_with_index : output_nodes) {
MS_EXCEPTION_IF_NULL(item_with_index.first);
// A node with zero outputs (e.g. a graph that returns nothing) contributes
// nothing; the caller then receives an empty list.
if (AnfAlgo::GetOutputTensorNum(item_with_index.first) == 0) {
continue;
}
auto output_tensor = CreateOutputTensor(item_with_index.first, item_with_index.second);
MS_EXCEPTION_IF_NULL(output_tensor);
// Flush all remaining lazy tasks the first time this tensor's value is needed.
output_tensor->set_lazy_callback([]() { runtime::OpLazyBuilder::GetInstance().ExecuteRemainingTasks(); });
outputs->emplace_back(output_tensor);
}
}
} // namespace compile
} // namespace mindspore

View File

@ -179,6 +179,8 @@ class MindRTBackend : public Backend {
void RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);
void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs);
// When compiling FuncGraph, it is divided according to the control nodes, and obtain the control nodes and several
// node segments. Node segments will be compiled into kernelGraphs which are expressed as GraphId and bound to
// the corresponding device_context.

View File

@ -20,7 +20,6 @@
#include <functional>
#include <utility>
#include <algorithm>
#include "abstract/utils.h"
#include "abstract/abstract_value.h"
#include "base/complex_storage.h"
@ -497,7 +496,8 @@ Tensor::Tensor(const Tensor &tensor)
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()),
device_event_(tensor.device_event_) {}
device_event_(tensor.device_event_),
lazy_callback_(tensor.lazy_callback_) {}
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
: MetaTensor(data_type, tensor.shape_),
@ -513,7 +513,8 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()),
device_event_(tensor.device_event_) {}
device_event_(tensor.device_event_),
lazy_callback_(tensor.lazy_callback_) {}
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data)
: MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}
@ -571,9 +572,17 @@ bool Tensor::ValueEqual(const Tensor &tensor) const {
return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
}
// Run the registered lazy-build callback, if one has been set, so that all
// pending backend tasks complete before the tensor's data is accessed.
void Tensor::ExecuteLazyTask() const {
if (lazy_callback_ == nullptr) {
return;
}
lazy_callback_();
}
// assign value to this tensor
Tensor &Tensor::AssignValue(const Tensor &tensor) {
if (this != &tensor) {
lazy_callback_ = tensor.lazy_callback_;
ExecuteLazyTask();
MetaTensor::operator=(tensor);
device_sync_ = tensor.device_sync_;
need_release_device_mem_ = tensor.need_release_device_mem_;
@ -649,6 +658,8 @@ std::string Tensor::ToStringRepr() const {
}
void Tensor::data_sync(bool need_wait) const {
ExecuteLazyTask();
if (need_wait) {
Wait();
}

View File

@ -528,7 +528,14 @@ class MS_CORE_API Tensor final : public MetaTensor {
/// \brief Mark this Tensor as having been updated by the device.
void SetIsUpdateByDevice() { updated_by_device_ = true; }
/// \brief Set the lazy-build callback for this Tensor.
///
/// \param[in] lazy_callback The callback supplied by the backend when lazy build is
/// enabled; it is invoked (via ExecuteLazyTask) before data synchronization
/// so pending backend tasks are flushed first.
void set_lazy_callback(const std::function<void(void)> &lazy_callback) { lazy_callback_ = lazy_callback; }
private:
void ExecuteLazyTask() const;
bool init_flag_{false};
TensorDataPtr data_{nullptr};
std::string id_{""};
@ -546,6 +553,7 @@ class MS_CORE_API Tensor final : public MetaTensor {
std::string padding_type_{""};
TypePtr cast_dtype_{nullptr};
std::shared_ptr<DeviceEvent> device_event_{nullptr};
std::function<void(void)> lazy_callback_{nullptr};
};
using TensorPtr = std::shared_ptr<Tensor>;
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;

View File

@ -23,7 +23,6 @@ from ._register_for_tensor import tensor_operator_registry
from .._c_expression import Tensor as Tensor_
from .._c_expression import CSRTensor as CSRTensor_
from .._c_expression import COOTensor as COOTensor_
from .._c_expression import PynativeExecutor_
from .._checkparam import Validator as validator
from .._checkparam import Rel
@ -170,7 +169,6 @@ class Tensor(Tensor_):
return new_obj
# Return the string representation of the tensor. Pending lazy-build tasks are
# executed first, and initialized tensors are synced from device, so the
# printed data reflects completed computation.
def __repr__(self):
PynativeExecutor_.get_instance().execute_lazy_task()
if self.init_finished:
Tensor_.data_sync(self, False)
return Tensor_.__repr__(self)
@ -392,7 +390,6 @@ class Tensor(Tensor_):
return Tensor(Tensor_.from_numpy(array))
# Assign another tensor's value to this tensor in place and return self.
# Pending lazy-build tasks are flushed first so the assignment operates on
# up-to-date data.
def assign_value(self, value):
PynativeExecutor_.get_instance().execute_lazy_task()
self.assign_value_cpp(value)
return self
@ -486,7 +483,6 @@ class Tensor(Tensor_):
[11. 2.]
"""
self._init_check()
PynativeExecutor_.get_instance().execute_lazy_task()
return Tensor_.asnumpy(self)
def flush_from_cache(self):