PyNative RunOp without Actor

caifubi 2022-03-10 11:15:55 +08:00
parent 98888f6cd8
commit d26d7b21ff
19 changed files with 755 additions and 119 deletions

View File

@ -250,7 +250,7 @@ set(SUB_COMP
runtime/device
runtime/graph_scheduler
runtime/hardware
runtime/op_builder
runtime/pynative
plugin/device/ascend/hal/device
plugin/device/ascend/hal/hardware
plugin/device/ascend/hal/hccl_adapter

View File

@ -1155,27 +1155,6 @@ void AnfRuntimeAlgorithm::CacheAddrForAtomicClean(const AnfNodePtr &node, kernel
kernel_mod->set_inputs_addr(kernel_inputs);
}
std::string OpRuntimeInfo::output_format(size_t index) const {
if (index >= output_format_.size()) {
MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_format:" << output_format_.size();
}
return output_format_[index];
}
TypeId OpRuntimeInfo::output_type(size_t index) const {
if (index >= output_type_.size()) {
MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_type:" << output_type_.size();
}
return output_type_[index];
}
size_t OpRuntimeInfo::output_tensor_size(size_t index) const {
if (index >= output_tensor_size_.size()) {
MS_LOG(EXCEPTION) << "Invalid index::" << index << " total output_tensor_size:" << output_tensor_size_.size();
}
return output_tensor_size_[index];
}
void AnfRuntimeAlgorithm::UpdateGraphValidRefPair(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
const auto &origin_ref_map = graph->GetRefMap();

View File

@ -45,28 +45,6 @@ using DeviceAddressPtr = device::DeviceAddressPtr;
using Address = kernel::Address;
using AddressPtr = kernel::AddressPtr;
class OpRuntimeInfo {
public:
OpRuntimeInfo(std::vector<std::string> output_format, std::vector<TypeId> output_type,
std::vector<size_t> output_tensor_size)
: output_format_(std::move(output_format)),
output_type_(std::move(output_type)),
output_tensor_size_(std::move(output_tensor_size)) {}
~OpRuntimeInfo() = default;
// Key for user data.
constexpr static char key[] = "OpRuntimeInfo";
std::string output_format(size_t index) const;
TypeId output_type(size_t index) const;
size_t output_tensor_size(size_t index) const;
private:
std::vector<std::string> output_format_;
std::vector<TypeId> output_type_;
std::vector<size_t> output_tensor_size_;
};
class AnfRuntimeAlgorithm {
public:
static AnfNodePtr MakeMonadValueNode(const KernelGraphPtr &kg);

View File

@ -44,6 +44,7 @@
#include "include/common/utils/context/graph_kernel_flags.h"
#include "backend/common/optimizer/helper.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/pynative/op_runtime_info.h"
#include "include/common/utils/config_manager.h"
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
@ -899,54 +900,10 @@ KernelGraphPtr AscendSession::PreBuildOp(const OpRunInfo &op_run_info,
opt::RunOpAscendBackendIRFusionOptimization(graph);
SelectKernel(graph);
RunOpHardwareOptimize(graph);
CacheCNodeOutputInfo(*graph);
runtime::OpRuntimeInfo::CacheGraphOpRuntimeInfo(graph);
return graph;
}
void AscendSession::CacheCNodeOutputInfo(const KernelGraph &graph) const {
auto &nodes = graph.execution_order();
for (auto const &node : nodes) {
std::vector<std::string> formats;
std::vector<TypeId> types;
std::vector<size_t> tensor_sizes;
auto output_num = common::AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < output_num; ++i) {
std::string output_format = AnfAlgo::GetOutputFormat(node, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(node, i);
formats.emplace_back(output_format);
types.emplace_back(output_type);
tensor_sizes.emplace_back(tensor_size);
}
MS_EXCEPTION_IF_NULL(node);
node->set_user_data<OpRuntimeInfo>(std::make_shared<OpRuntimeInfo>(formats, types, tensor_sizes));
}
auto &inputs = graph.inputs();
for (const auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input);
if (!input->isa<Parameter>()) {
continue;
}
std::vector<std::string> formats;
std::vector<TypeId> types;
std::vector<size_t> tensor_sizes;
auto output_size = common::AnfAlgo::GetOutputTensorNum(input);
for (size_t index = 0; index < output_size; index++) {
auto format = AnfAlgo::GetOutputFormat(input, index);
auto type_id = AnfAlgo::GetOutputDeviceDataType(input, index);
if (type_id == kTypeUnknown) {
type_id = common::AnfAlgo::GetOutputInferDataType(input, index);
}
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input, index);
formats.emplace_back(format);
types.emplace_back(type_id);
tensor_sizes.emplace_back(tensor_size);
}
input->set_user_data<OpRuntimeInfo>(std::make_shared<OpRuntimeInfo>(formats, types, tensor_sizes));
}
}
void AscendSession::GetOpInputStubTensors(const CNodePtr &cnode, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs,
const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,

View File

@ -140,7 +140,6 @@ class AscendSession : public SessionBasic {
#endif
void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
void CacheCNodeOutputInfo(const KernelGraph &graph) const;
KernelGraphPtr PreBuildOp(const OpRunInfo &op_run_info, const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask);
void GetOpInputStubTensors(const CNodePtr &cnode, const std::map<AnfNodePtr, size_t> &parameter_index,

View File

@ -55,7 +55,7 @@
#endif
#include "backend/common/session/session_factory.h"
#include "backend/common/session/pynative_task_manager.h"
#include "runtime/op_builder/op_lazy_builder.h"
#include "runtime/pynative/op_lazy_builder.h"
#ifdef ENABLE_DEBUGGER
#include "debug/tensor_load.h"
#include "debug/debugger/proto_exporter.h"

View File

@ -22,7 +22,7 @@
#include "include/common/utils/parallel_context.h"
#include "backend/graph_compiler/transform.h"
#include "backend/common/session/session_factory.h"
#include "runtime/op_builder/op_lazy_builder.h"
#include "runtime/pynative/op_lazy_builder.h"
#include "backend/common/optimizer/helper.h"
#include "pipeline/pynative/pynative_execute.h"
#include "pipeline/jit/action.h"
@ -36,6 +36,7 @@
#include "utils/ms_utils.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "runtime/pynative/run_op_helper.h"
#include "include/common/utils/scoped_long_running.h"
#ifdef ENABLE_D
#include "include/common/utils/callbacks_ge.h"
@ -135,6 +136,22 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g) {
}
namespace {
std::vector<tensor::TensorPtr> GetTensorWithoutValueMask(const OpRunInfo &op_run_info) {
std::vector<tensor::TensorPtr> tensors_without_value_node;
const auto &input_tensors = op_run_info.input_tensors;
const auto &tensors_mask = op_run_info.tensor_mask;
if (input_tensors.size() != tensors_mask.size()) {
MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
<< tensors_mask.size();
}
for (size_t index = 0; index < tensors_mask.size(); ++index) {
if (tensors_mask.at(index) != kValueNodeTensorMask) {
(void)tensors_without_value_node.emplace_back(input_tensors.at(index));
}
}
return tensors_without_value_node;
}
void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs) {
MS_EXCEPTION_IF_NULL(inputs);
if (utils::isa<tensor::TensorPtr>(arg)) {
@ -232,10 +249,24 @@ TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index)
return tensor;
}
device::DeviceAddressPtr CloneEmptyDeviceAddress(const device::DeviceAddressPtr &old_device_address,
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(old_device_address);
MS_EXCEPTION_IF_NULL(device_context);
auto new_device_address =
device_context->CreateDeviceAddress(nullptr, old_device_address->GetSize(), old_device_address->format(),
old_device_address->type_id(), old_device_address->host_shape());
MS_EXCEPTION_IF_NULL(new_device_address);
new_device_address->set_original_ref_count(old_device_address->original_ref_count());
new_device_address->ResetRefCount();
return new_device_address;
}
void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context, bool is_gradient_out) {
MS_EXCEPTION_IF_NULL(graph);
for (const auto &node : graph->execution_order()) {
auto output_address_num = AnfAlgo::GetOutputAddressNum(node);
// Clear old output device address of kernel
for (size_t i = 0; i < output_address_num; ++i) {
if (!AnfAlgo::OutputAddrExist(node, i, false)) {
continue;
@ -245,26 +276,40 @@ void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *d
continue;
}
MS_EXCEPTION_IF_NULL(device_context);
auto new_device_address =
device_context->CreateDeviceAddress(nullptr, device_address->GetSize(), device_address->format(),
device_address->type_id(), device_address->host_shape());
MS_EXCEPTION_IF_NULL(new_device_address);
new_device_address->set_original_ref_count(device_address->original_ref_count());
new_device_address->ResetRefCount();
auto new_device_address = CloneEmptyDeviceAddress(device_address, device_context);
if (is_gradient_out) {
new_device_address->set_from_persistent_mem(true);
}
AnfAlgo::SetOutputAddr(new_device_address, i, node.get());
}
// Clear old workspace device address of kernel
auto kernel_mod = AnfAlgo::GetKernelMod(node);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
for (size_t i = 0; i < workspace_lists.size(); ++i) {
if (!AnfAlgo::WorkspaceAddrExist(node, i)) {
continue;
}
const auto &device_address = AnfAlgo::GetMutableWorkspaceAddr(node, i);
auto new_device_address = CloneEmptyDeviceAddress(device_address, device_context);
AnfAlgo::SetWorkspaceAddr(new_device_address, i, node.get());
}
}
}
void UpdateInputDeviceAddress(const KernelGraphPtr &graph) {
void ClearInputDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(device_context);
for (const auto &node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<Parameter>() && (!common::AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
AnfAlgo::SetOutputAddr(nullptr, 0, node.get());
if (node->isa<Parameter>()) {
auto device_address = AnfAlgo::GetMutableOutputAddr(node, 0, false);
if (device_address == nullptr) {
continue;
}
auto new_device_address = CloneEmptyDeviceAddress(device_address, device_context);
AnfAlgo::SetOutputAddr(new_device_address, 0, node.get());
}
}
}
@ -1329,11 +1374,11 @@ void MindRTBackend::LazyExecuteTaskCallback() {
auto &op_run_task = op_run_tasks.front();
const auto &context = op_run_task->context();
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, context->is_pynative_infer());
RunSingleOpGraph(context->graph(), context->op_run_info(), context->graph_compiler_info());
auto tensor_without_value_mask = GetTensorWithoutValueMask(context->op_run_info());
runtime::RunSingleOpGraph(context->graph(), tensor_without_value_mask, context->device_context(),
context->op_run_info().is_dynamic_shape);
ClearGraphDeviceAddress(context->graph(), context->device_context(), context->op_run_info().is_gradient_out);
UpdateInputDeviceAddress(context->graph());
ClearInputDeviceAddress(context->graph(), context->device_context());
op_lazy_builder.PopOpRunTask();
}
@ -1377,6 +1422,9 @@ void MindRTBackend::RunOpInternal(bool single_op_cache_hit, GraphCompilerInfo *g
bool lazy_build_disabled = NeedDisableLazyBuild(graph_compiler_info->need_erase_,
(single_op_cache_hit && op_lazy_builder.QueueEmpty()), *op_run_info);
auto tensor_without_value_mask = GetTensorWithoutValueMask(*op_run_info);
if (lazy_build_disabled) {
if (!op_lazy_builder.QueueEmpty()) {
op_lazy_builder.ExecuteRemainingTasks();
@ -1384,10 +1432,12 @@ void MindRTBackend::RunOpInternal(bool single_op_cache_hit, GraphCompilerInfo *g
if (!single_op_cache_hit) {
CompileSingleOpGraph(graph, device_context, graph_compiler_info);
}
RunSingleOpGraph(graph, *op_run_info, graph_compiler_info);
runtime::UpdateDeviceAddress(graph, tensor_without_value_mask, device_context);
runtime::RunSingleOpGraph(graph, tensor_without_value_mask, device_context, op_run_info->is_dynamic_shape);
UpdateOutput(output_nodes, outputs);
ClearGraphDeviceAddress(graph, device_context, op_run_info->is_gradient_out);
UpdateInputDeviceAddress(graph);
ClearInputDeviceAddress(graph, device_context);
if (op_run_info->is_dynamic_shape) {
UpdateOutputAbstract(graph, op_run_info);
}
@ -1395,6 +1445,7 @@ void MindRTBackend::RunOpInternal(bool single_op_cache_hit, GraphCompilerInfo *g
EraseSingleOpCache(graph_compiler_info->name_, graph);
}
} else {
runtime::UpdateDeviceAddress(graph, tensor_without_value_mask, device_context);
UpdateOutput(output_nodes, outputs);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);

View File

@ -33,7 +33,7 @@
#include "backend/common/session/session_basic.h"
#include "runtime/hardware/device_context.h"
#include "runtime/graph_scheduler/graph_scheduler.h"
#include "runtime/op_builder/op_lazy_builder.h"
#include "runtime/pynative/op_lazy_builder.h"
namespace mindspore {
namespace compile {

View File

@ -24,6 +24,7 @@
#include "include/common/utils/anfalgo.h"
#include "backend/common/session/kernel_graph.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/pynative/op_runtime_info.h"
#include "debug/data_dump/dump_json_parser.h"
#include "frontend/operator/ops.h"
#include "ir/value.h"
@ -219,7 +220,7 @@ void KernelRuntime::RunOpMallocPre(const session::KernelGraph &graph,
auto output_num = common::AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < output_num; ++i) {
MS_EXCEPTION_IF_NULL(node);
auto runtime_info = node->user_data<session::OpRuntimeInfo>();
auto runtime_info = node->user_data<runtime::OpRuntimeInfo>();
MS_EXCEPTION_IF_NULL(runtime_info);
auto const &output_format = runtime_info->output_format(i);
auto output_type = runtime_info->output_type(i);
@ -252,7 +253,7 @@ void KernelRuntime::RunOpMallocPre(const session::KernelGraph &graph,
AnfAlgo::SetOutputAddr(output_address, index, item.get());
continue;
}
auto op_runtime_info = item->user_data<session::OpRuntimeInfo>();
auto op_runtime_info = item->user_data<runtime::OpRuntimeInfo>();
MS_EXCEPTION_IF_NULL(op_runtime_info);
TypeId output_type_id = op_runtime_info->output_type(index);
auto output_tensor_size = op_runtime_info->output_tensor_size(index);

View File

@ -399,9 +399,10 @@ void KernelActor::FetchOutputDeviceTensor(OpContext<DeviceTensor> *const context
auto output_address = output_addresses[i].get();
MS_EXCEPTION_IF_NULL(output_address);
if (output_size_list[i] != output_address->GetSize()) {
// The size of output address may be changed in dynamic shape scenario.
// If the format of the DeviceAddress is different, then the size is originally different.
// Such as NCHW(1,1,1,3) and NC1HWC0(1,1,1,1,16). So we don't need to update the size.
// 1. The size of the output address may be changed in dynamic shape scenarios.
// 2. If the format of the DeviceAddress is different, then the size is originally different.
// Such as NCHW(1,1,1,3) and NC1HWC0(1,1,1,1,16). So we don't need to update the size.
// 3. For example, we need to call cudnnGetRNNTrainingReserveSize to get real output size in LstmGpuKernelMod!
if (AnfAlgo::GetOutputFormat(kernel_, i) == output_address->format()) {
output_address->SetSize(output_size_list[i]);
}

View File

@ -20,9 +20,10 @@
#include <utility>
#include <algorithm>
#include "runtime/graph_scheduler/graph_scheduler.h"
#include "runtime/op_builder/op_lazy_builder.h"
#include "runtime/pynative/op_lazy_builder.h"
#include "runtime/device/device_address.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/pynative/op_runtime_info.h"
#include "include/common/utils/convert_utils.h"
#include "include/common/utils/context/graph_kernel_flags.h"
#include "utils/ms_context.h"
@ -556,7 +557,6 @@ GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, bool
UpdateRefCountForGraphOutput(outputs_with_index);
AnfAlgo::UpdateGraphValidRefPair(graph);
return graph->graph_id();
}
@ -574,6 +574,8 @@ void GraphCompiler::BuildSingleOpGraphs(const std::vector<KernelGraphPtr> &graph
for (const auto &graph : graphs) {
device_context->PreprocessBeforeRunSingleOpGraph(graph);
CreateKernelWorkspaceDeviceAddress(device_context, graph);
// Must be executed after PreprocessBeforeRunSingleOpGraph.
runtime::OpRuntimeInfo::CacheGraphOpRuntimeInfo(graph);
}
}

View File

@ -1,5 +1,5 @@
file(GLOB_RECURSE BUILDER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"op_lazy_builder.cc")
"*.cc")
set_property(SOURCE ${BUILDER_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(_mindspore_runtime_op_builder_obj OBJECT ${BUILDER_SRC_LIST})
add_library(_mindspore_runtime_pynative_obj OBJECT ${BUILDER_SRC_LIST})

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "runtime/op_builder/op_lazy_builder.h"
#include "runtime/pynative/op_lazy_builder.h"
namespace mindspore::runtime {
void OpLazyBuilder::Register(const std::function<void()> &callback) {

View File

@ -0,0 +1,147 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/pynative/op_runtime_info.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
namespace mindspore::runtime {
namespace {
void CacheForExecutionOrder(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
const auto &nodes = graph->execution_order();
for (auto const &node : nodes) {
std::vector<std::string> formats;
std::vector<TypeId> types;
std::vector<size_t> tensor_sizes;
auto output_num = common::AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < output_num; ++i) {
std::string output_format = AnfAlgo::GetOutputFormat(node, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(node, i);
formats.emplace_back(output_format);
types.emplace_back(output_type);
tensor_sizes.emplace_back(tensor_size);
}
// For input
std::vector<std::pair<device::KernelInfo *, size_t>> input_kernel_infos;
auto input_size = common::AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < input_size; ++i) {
session::KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i, true);
MS_EXCEPTION_IF_NULL(kernel_with_index.first);
input_kernel_infos.emplace_back(dynamic_cast<device::KernelInfo *>(kernel_with_index.first->kernel_info()),
kernel_with_index.second);
}
// For workspace and output
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
node->set_user_data<runtime::OpRuntimeInfo>(
std::make_shared<runtime::OpRuntimeInfo>(formats, types, tensor_sizes, kernel_info, input_kernel_infos));
}
}
void CacheForGraphInputs(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
const auto &inputs = graph->inputs();
for (const auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input);
if (!input->isa<Parameter>()) {
continue;
}
std::vector<std::string> formats;
std::vector<TypeId> types;
std::vector<size_t> tensor_sizes;
auto output_size = common::AnfAlgo::GetOutputTensorNum(input);
for (size_t index = 0; index < output_size; index++) {
auto format = AnfAlgo::GetOutputFormat(input, index);
auto type_id = AnfAlgo::GetOutputDeviceDataType(input, index);
if (type_id == kTypeUnknown) {
type_id = common::AnfAlgo::GetOutputInferDataType(input, index);
}
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input, index);
formats.emplace_back(format);
types.emplace_back(type_id);
tensor_sizes.emplace_back(tensor_size);
}
input->set_user_data<runtime::OpRuntimeInfo>(std::make_shared<runtime::OpRuntimeInfo>(
formats, types, tensor_sizes, nullptr, std::vector<std::pair<device::KernelInfo *, size_t>>()));
}
}
} // namespace
std::string OpRuntimeInfo::output_format(size_t index) const {
if (index >= output_format_.size()) {
MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_format:" << output_format_.size();
}
return output_format_[index];
}
TypeId OpRuntimeInfo::output_type(size_t index) const {
if (index >= output_type_.size()) {
MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_type:" << output_type_.size();
}
return output_type_[index];
}
size_t OpRuntimeInfo::output_tensor_size(size_t index) const {
if (index >= output_tensor_size_.size()) {
MS_LOG(EXCEPTION) << "Invalid index::" << index << " total output_tensor_size:" << output_tensor_size_.size();
}
return output_tensor_size_[index];
}
device::DeviceAddressPtr OpRuntimeInfo::GetOutputDeviceAddress(size_t index) const {
MS_EXCEPTION_IF_NULL(kernel_info_);
return kernel_info_->GetMutableOutputAddr(index);
}
device::DeviceAddressPtr OpRuntimeInfo::GetWorkspaceDeviceAddress(size_t index) const {
MS_EXCEPTION_IF_NULL(kernel_info_);
return kernel_info_->GetMutableWorkspaceAddr(index);
}
device::DeviceAddressPtr OpRuntimeInfo::GetInputDeviceAddress(size_t index) const {
if (index >= input_kernel_infos_.size()) {
MS_LOG(ERROR) << "Output range! index:" << index << " input size:" << input_kernel_infos_.size();
return nullptr;
}
auto kernel_info_pair = input_kernel_infos_[index];
MS_EXCEPTION_IF_NULL(kernel_info_pair.first);
return kernel_info_pair.first->GetMutableOutputAddr(kernel_info_pair.second);
}
size_t OpRuntimeInfo::GetInputSize() const { return input_kernel_infos_.size(); }
size_t OpRuntimeInfo::GetOutputSize() const {
MS_EXCEPTION_IF_NULL(kernel_info_);
return kernel_info_->output_address_list().size();
}
size_t OpRuntimeInfo::GetWorkspaceSize() const {
MS_EXCEPTION_IF_NULL(kernel_info_);
return kernel_info_->workspace_address_list().size();
}
void OpRuntimeInfo::CacheGraphOpRuntimeInfo(const KernelGraphPtr &graph) {
CacheForExecutionOrder(graph);
CacheForGraphInputs(graph);
}
} // namespace mindspore::runtime

View File

@ -0,0 +1,63 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_OP_RUNTIME_INFO_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_OP_RUNTIME_INFO_H_
#include <memory>
#include <utility>
#include <vector>
#include <string>
#include "runtime/device/device_address.h"
#include "runtime/device/kernel_info.h"
#include "backend/common/session/kernel_graph.h"
namespace mindspore::runtime {
class OpRuntimeInfo {
public:
OpRuntimeInfo(std::vector<std::string> output_format, std::vector<TypeId> output_type,
std::vector<size_t> output_tensor_size, device::KernelInfo *kernel_info,
std::vector<std::pair<device::KernelInfo *, size_t>> input_kernel_infos)
: output_format_(std::move(output_format)),
output_type_(std::move(output_type)),
output_tensor_size_(std::move(output_tensor_size)),
kernel_info_(kernel_info),
input_kernel_infos_(std::move(input_kernel_infos)) {}
~OpRuntimeInfo() = default;
// Key for user data.
constexpr static char key[] = "OpRuntimeInfo";
std::string output_format(size_t index) const;
TypeId output_type(size_t index) const;
size_t output_tensor_size(size_t index) const;
device::DeviceAddressPtr GetOutputDeviceAddress(size_t index) const;
device::DeviceAddressPtr GetWorkspaceDeviceAddress(size_t index) const;
device::DeviceAddressPtr GetInputDeviceAddress(size_t index) const;
size_t GetInputSize() const;
size_t GetOutputSize() const;
size_t GetWorkspaceSize() const;
static void CacheGraphOpRuntimeInfo(const KernelGraphPtr &graph);
private:
std::vector<std::string> output_format_;
std::vector<TypeId> output_type_;
std::vector<size_t> output_tensor_size_;
device::KernelInfo *kernel_info_;
std::vector<std::pair<device::KernelInfo *, size_t>> input_kernel_infos_;
};
} // namespace mindspore::runtime
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_OP_RUNTIME_INFO_H_
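For context, here is a minimal usage sketch of the cached runtime info. It assumes a compiled single-op KernelGraphPtr named graph whose kernels already carry OpRuntimeInfo user data (i.e. CacheGraphOpRuntimeInfo has run), plus the usual AnfAlgo and logging includes; the helper name DumpCachedRuntimeInfo is hypothetical.
void DumpCachedRuntimeInfo(const KernelGraphPtr &graph) {
  for (const auto &node : graph->execution_order()) {
    // The cache is attached to each kernel as user data under OpRuntimeInfo::key.
    const auto &runtime_info = node->user_data<runtime::OpRuntimeInfo>();
    if (runtime_info == nullptr) {
      continue;  // CacheGraphOpRuntimeInfo has not been called for this graph.
    }
    // Iterate with the same output count that was used when the cache was built.
    auto output_num = common::AnfAlgo::GetOutputTensorNum(node);
    for (size_t i = 0; i < output_num; ++i) {
      MS_LOG(DEBUG) << node->fullname_with_scope() << " output[" << i
                    << "] format:" << runtime_info->output_format(i)
                    << " type:" << static_cast<int>(runtime_info->output_type(i))
                    << " size:" << runtime_info->output_tensor_size(i);
    }
  }
}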

View File

@ -0,0 +1,415 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/pynative/run_op_helper.h"
#include <string>
#include <vector>
#include <memory>
#include "utils/log_adapter.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/convert_utils.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/pynative/op_runtime_info.h"
namespace mindspore::runtime {
namespace {
// The cached tensor device address cannot be reused when:
// 1. The device type differs in heterogeneous scenarios.
// 2. The device address format differs.
void UpdateInputTensorFromDevice(const std::vector<AnfNodePtr> &input_nodes,
const std::vector<tensor::TensorPtr> &input_tensors,
const device::DeviceContext *device_context) {
MS_LOG(DEBUG) << "Start";
auto input_size = input_nodes.size();
for (size_t i = 0; i < input_size; ++i) {
auto &tensor = input_tensors[i];
auto &input_node = input_nodes[i];
MS_EXCEPTION_IF_NULL(tensor);
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
auto node_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
// node_address can't be null
MS_EXCEPTION_IF_NULL(node_address);
if (tensor_address != nullptr) {
if (tensor_address->DeviceType() != device_context->GetDeviceAddressType() ||
tensor_address->format() != node_address->format()) {
// Need to wait for the OpExecutor task to finish before reading the data back.
tensor->data_sync();
// Clear the tensor address so that the Parameter's node address is assigned to the Tensor later.
tensor->set_device_address(nullptr);
}
}
}
MS_LOG(DEBUG) << "End";
}
void UpdateInputNodeDeviceAddress(const std::vector<AnfNodePtr> &input_nodes,
const std::vector<tensor::TensorPtr> &input_tensors,
const device::DeviceContext *device_context) {
MS_LOG(DEBUG) << "Start";
auto input_size = input_nodes.size();
auto tensor_size = input_tensors.size();
if (input_size != tensor_size) {
MS_LOG(EXCEPTION) << "input node size:" << input_size << " not equal to tensors size:" << tensor_size;
}
for (size_t i = 0; i < input_size; ++i) {
auto &input_node = input_nodes[i];
auto &input_tensor = input_tensors[i];
MS_EXCEPTION_IF_NULL(input_tensor);
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
auto node_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
MS_EXCEPTION_IF_NULL(node_address);
if (tensor_address == nullptr) {
input_tensor->set_device_address(node_address);
input_tensor->set_sync_status(kNeedSyncHostToDeviceImmediately);
}
// After UpdateInputTensorFromDevice, a non-null tensor address always matches the node address in DeviceType and format.
if (tensor_address != nullptr && tensor_address != node_address) {
AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
}
}
MS_LOG(DEBUG) << "End";
}
void UpdateRefNodeOutputDeviceAddress(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto ref_node_map = graph->GetRefMap();
for (const auto &iter : ref_node_map) {
auto &output_pair = iter.first;
auto &input_pair = iter.second;
auto &ref_node = output_pair.first;
auto output_index = output_pair.second;
auto &input_node = input_pair.first;
auto input_node_output_index = input_pair.second;
auto input_addr = AnfAlgo::GetMutableOutputAddr(input_node, input_node_output_index, false);
auto ref_node_output_addr = AnfAlgo::GetMutableOutputAddr(ref_node, output_index, false);
if (input_addr != ref_node_output_addr) {
AnfAlgo::SetOutputAddr(input_addr, output_index, ref_node.get());
}
}
}
void CopyTensorDataToDevice(const tensor::TensorPtr &tensor, const AnfNodePtr &node,
const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(tensor);
MS_EXCEPTION_IF_NULL(device_context);
auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
MS_EXCEPTION_IF_NULL(device_address);
if ((device_address->GetPtr() == nullptr) &&
(!device_context->AllocateMemory(device_address.get(), device_address->GetSize()))) {
MS_LOG(EXCEPTION) << "Allocate memory failed";
}
// Copy data from host tensor to device.
auto tensor_size = LongToSize(tensor->data().nbytes());
auto tensor_type = tensor->data_type();
MS_LOG(DEBUG) << "Copy to device, node:" << node->DebugString();
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), tensor_size, tensor_type,
tensor->data_c(), tensor->device_info().host_format_)) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
}
}
void CopyValueNodeTensorToDevice(const ValueNodePtr &node, const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(device_context);
auto &node_value = node->value();
MS_EXCEPTION_IF_NULL(node_value);
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(node_value, &tensors);
for (size_t i = 0; i < tensors.size(); i++) {
const auto &tensor = tensors[i];
MS_EXCEPTION_IF_NULL(tensor);
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i, false);
MS_EXCEPTION_IF_NULL(device_tensor);
if (device_tensor->GetPtr() != nullptr) {
return;
}
tensor->set_device_address(device_tensor);
CopyTensorDataToDevice(tensor, node, device_context);
}
}
void CopyValueNodeStringToDevice(const ValueNodePtr &node, const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(device_context);
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0, false);
MS_EXCEPTION_IF_NULL(device_tensor);
// If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
if (device_tensor->GetPtr() != nullptr) {
return;
}
if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
MS_LOG(EXCEPTION) << "Allocate memory failed";
}
auto &node_value = node->value();
MS_EXCEPTION_IF_NULL(node_value);
// Copy data to device.
auto value = GetValue<std::string>(node_value);
size_t tensor_size = value.size();
ShapeVector shape = {1, SizeToLong(tensor_size)};
if (!device_tensor->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
}
}
void CopyValueNodeDataToDevice(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(DEBUG) << "Start";
const auto &value_nodes = graph->graph_value_nodes();
for (const auto &value_node : value_nodes) {
MS_EXCEPTION_IF_NULL(value_node);
auto &node_value = value_node->value();
MS_EXCEPTION_IF_NULL(node_value);
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
CopyValueNodeTensorToDevice(value_node, device_context);
} else if (node_value->isa<StringImm>()) {
CopyValueNodeStringToDevice(value_node, device_context);
} else {
MS_LOG(WARNING) << "Unknown value node type:" << value_node->DebugString();
}
}
MS_LOG(DEBUG) << "End";
}
void CopyParameterDataToDevice(const std::vector<AnfNodePtr> &input_nodes,
const std::vector<tensor::TensorPtr> &input_tensors,
const device::DeviceContext *device_context) {
MS_LOG(DEBUG) << "Start";
auto input_size = input_nodes.size();
for (size_t i = 0; i < input_size; ++i) {
MS_EXCEPTION_IF_NULL(input_tensors[i]);
if (input_tensors[i]->NeedSyncHostToDeviceImmediately()) {
CopyTensorDataToDevice(input_tensors[i], input_nodes[i], device_context);
input_tensors[i]->set_sync_status(kNoNeedSync);
}
}
MS_LOG(DEBUG) << "End";
}
void UpdateOutputAddrSize(const AnfNodePtr &node, const std::shared_ptr<OpRuntimeInfo> &runtime_info) {
MS_EXCEPTION_IF_NULL(runtime_info);
auto output_size = runtime_info->GetOutputSize();
for (size_t i = 0; i < output_size; ++i) {
auto output_address = runtime_info->GetOutputDeviceAddress(i);
MS_EXCEPTION_IF_NULL(output_address);
auto output_addr_size = AnfAlgo::GetOutputTensorMemSize(node, i);
if (output_addr_size != output_address->GetSize()) {
output_address->SetSize(output_addr_size);
}
}
}
bool MallocForKernelInput(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(runtime_info);
MS_EXCEPTION_IF_NULL(device_context);
auto input_size = runtime_info->GetInputSize();
for (size_t i = 0; i < input_size; ++i) {
auto input_address = runtime_info->GetInputDeviceAddress(i);
MS_EXCEPTION_IF_NULL(input_address);
if (input_address->GetPtr() == nullptr &&
!device_context->AllocateMemory(input_address.get(), input_address->GetSize())) {
return false;
}
}
return true;
}
bool MallocForKernelOutput(const std::shared_ptr<OpRuntimeInfo> &runtime_info, const AnfNodePtr &node,
const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(runtime_info);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(device_context);
auto kernel_mod = AnfAlgo::GetKernelMod(node);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto output_size = runtime_info->GetOutputSize();
auto kernel_out_size_list = kernel_mod->GetOutputSizeList();
if (kernel_out_size_list.size() != output_size) {
MS_LOG(ERROR) << "Node " << node->fullname_with_scope() << " output num is:" << output_size
<< " but kernel_mod output num:" << kernel_out_size_list.size();
return false;
}
for (size_t i = 0; i < output_size; ++i) {
auto device_address = runtime_info->GetOutputDeviceAddress(i);
MS_EXCEPTION_IF_NULL(device_address);
// For example, we need to call cudnnGetRNNTrainingReserveSize to get real output size in LstmGpuKernelMod!
if (kernel_out_size_list[i] != device_address->GetSize()) {
// If the format of the DeviceAddress is different, then the size is originally different.
// Such as NCHW(1,1,1,3) and NC1HWC0(1,1,1,1,16). So we don't need to update the size.
if (AnfAlgo::GetOutputFormat(node, i) == device_address->format()) {
if (device_address->GetPtr() != nullptr) {
MS_LOG(ERROR) << "kernel mod output " << i << " size:" << kernel_out_size_list[i]
<< " not equal to device_address size:" << device_address->GetSize()
<< ", but the device address is already have ptr";
return false;
}
device_address->SetSize(kernel_out_size_list[i]);
}
}
if (device_address->GetPtr() == nullptr &&
!device_context->AllocateMemory(device_address.get(), device_address->GetSize())) {
MS_LOG(ERROR) << "Allocate output memory failed, node:" << node->fullname_with_scope();
return false;
}
}
return true;
}
bool MallocForKernelWorkspace(const std::shared_ptr<OpRuntimeInfo> &runtime_info,
const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(runtime_info);
MS_EXCEPTION_IF_NULL(device_context);
auto workspace_size = runtime_info->GetWorkspaceSize();
for (size_t i = 0; i < workspace_size; ++i) {
auto device_address = runtime_info->GetWorkspaceDeviceAddress(i);
MS_EXCEPTION_IF_NULL(device_address);
if (device_address->GetPtr() == nullptr &&
!device_context->AllocateMemory(device_address.get(), device_address->GetSize())) {
MS_LOG(ERROR) << "Allocate workspace memory failed";
return false;
}
}
return true;
}
kernel::AddressPtrList CreateKernelInputAddress(const std::shared_ptr<OpRuntimeInfo> &runtime_info) {
MS_EXCEPTION_IF_NULL(runtime_info);
auto input_size = runtime_info->GetInputSize();
kernel::AddressPtrList inputs;
for (size_t i = 0; i < input_size; ++i) {
auto device_address = runtime_info->GetInputDeviceAddress(i);
MS_EXCEPTION_IF_NULL(device_address);
inputs.emplace_back(std::make_shared<kernel::Address>(device_address->GetMutablePtr(), device_address->GetSize()));
MS_LOG(DEBUG) << "input[" << i << "]:" << inputs.back()->addr << " size:" << inputs.back()->size;
}
return inputs;
}
kernel::AddressPtrList CreateKernelWorkspaceAddress(const std::shared_ptr<OpRuntimeInfo> &runtime_info) {
MS_EXCEPTION_IF_NULL(runtime_info);
auto workspace_size = runtime_info->GetWorkspaceSize();
kernel::AddressPtrList workspaces;
for (size_t i = 0; i < workspace_size; ++i) {
auto device_address = runtime_info->GetWorkspaceDeviceAddress(i);
MS_EXCEPTION_IF_NULL(device_address);
workspaces.emplace_back(
std::make_shared<kernel::Address>(device_address->GetMutablePtr(), device_address->GetSize()));
MS_LOG(DEBUG) << "workspace[" << i << "]:" << workspaces.back()->addr << " size:" << workspaces.back()->size;
}
return workspaces;
}
kernel::AddressPtrList CreateKernelOutputAddress(const std::shared_ptr<OpRuntimeInfo> &runtime_info) {
auto output_size = runtime_info->GetOutputSize();
kernel::AddressPtrList outputs;
for (size_t i = 0; i < output_size; ++i) {
auto device_address = runtime_info->GetOutputDeviceAddress(i);
outputs.emplace_back(std::make_shared<kernel::Address>(device_address->GetMutablePtr(), device_address->GetSize()));
MS_LOG(DEBUG) << "output[" << i << "]:" << outputs.back()->addr << " size:" << outputs.back()->size;
}
return outputs;
}
void MallocForKernel(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(device_context);
CopyValueNodeDataToDevice(graph, device_context);
const auto &execution_order = graph->execution_order();
for (const auto &node : execution_order) {
MS_EXCEPTION_IF_NULL(node);
const auto &runtime_info = node->user_data<runtime::OpRuntimeInfo>();
MS_EXCEPTION_IF_NULL(runtime_info);
if (!MallocForKernelInput(runtime_info, device_context)) {
MS_LOG(EXCEPTION) << "Malloc for kernel input failed, node:" << node->fullname_with_scope();
}
if (!MallocForKernelWorkspace(runtime_info, device_context)) {
MS_LOG(EXCEPTION) << "Malloc for kernel workspace failed, node:" << node->fullname_with_scope();
}
if (!MallocForKernelOutput(runtime_info, node, device_context)) {
MS_LOG(EXCEPTION) << "Malloc for kernel output failed, node:" << node->fullname_with_scope();
}
}
}
// Copy host data to device for the parameter inputs that need an immediate sync.
void CopyDataToDevice(const KernelGraphPtr &graph, const std::vector<tensor::TensorPtr> &input_tensors,
const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
CopyParameterDataToDevice(graph->input_nodes(), input_tensors, device_context);
}
// Launch the kernels of the graph one by one.
void LaunchKernels(const KernelGraphPtr &graph, const device::DeviceContext *device_context, bool is_dynamic_shape) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(device_context);
MS_LOG(DEBUG) << "Start";
// Get device address from OpRuntimeInfo
const auto &execution_order = graph->execution_order();
for (auto const &node : execution_order) {
MS_EXCEPTION_IF_NULL(node);
auto runtime_info = node->user_data<runtime::OpRuntimeInfo>();
MS_EXCEPTION_IF_NULL(runtime_info);
auto inputs = CreateKernelInputAddress(runtime_info);
if (is_dynamic_shape) {
device_context->UpdateDynamicShape(node);
}
auto workspaces = CreateKernelWorkspaceAddress(runtime_info);
auto outputs = CreateKernelOutputAddress(runtime_info);
device_context->LaunchKernel(node, inputs, workspaces, outputs, is_dynamic_shape);
if (is_dynamic_shape) {
UpdateOutputAddrSize(node, runtime_info);
}
}
MS_LOG(DEBUG) << "End";
}
} // namespace
// Determine the device addresses of the graph here; they are not changed in subsequent executions.
void UpdateDeviceAddress(const KernelGraphPtr &graph, const std::vector<tensor::TensorPtr> &tensors_without_value_mask,
const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(DEBUG) << "Start";
const auto &input_nodes = graph->input_nodes();
UpdateInputTensorFromDevice(input_nodes, tensors_without_value_mask, device_context);
UpdateInputNodeDeviceAddress(input_nodes, tensors_without_value_mask, device_context);
UpdateRefNodeOutputDeviceAddress(graph);
MS_LOG(DEBUG) << "End";
}
void RunSingleOpGraph(const KernelGraphPtr &graph, const std::vector<tensor::TensorPtr> &input_tensors,
const device::DeviceContext *device_context, bool is_dynamic_shape) {
MallocForKernel(graph, device_context);
CopyDataToDevice(graph, input_tensors, device_context);
LaunchKernels(graph, device_context, is_dynamic_shape);
}
} // namespace mindspore::runtime

View File

@ -0,0 +1,32 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_RUN_OP_HELPER_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_RUN_OP_HELPER_H_
#include <vector>
#include "backend/common/session/kernel_graph.h"
#include "runtime/hardware/device_context.h"
namespace mindspore::runtime {
// Update the Tensor or input node DeviceAddress before PyNative asynchronous execution.
void UpdateDeviceAddress(const KernelGraphPtr &graph, const std::vector<tensor::TensorPtr> &tensors_without_value_mask,
const device::DeviceContext *device_context);
void RunSingleOpGraph(const KernelGraphPtr &graph, const std::vector<tensor::TensorPtr> &input_tensors,
const device::DeviceContext *device_context, bool is_dynamic_shape);
} // namespace mindspore::runtime
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_RUN_OP_HELPER_H_
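A minimal calling sketch, mirroring how MindRTBackend::RunOpInternal drives these two helpers elsewhere in this commit. The wrapper name RunOpOnce is hypothetical; graph, the value-mask-filtered tensors, and device_context are assumed to come from the single-op GraphCompiler path.
void RunOpOnce(const KernelGraphPtr &graph, const std::vector<tensor::TensorPtr> &tensors_without_value_mask,
               const device::DeviceContext *device_context, bool is_dynamic_shape) {
  // Bind the input tensors and parameters to device addresses first.
  runtime::UpdateDeviceAddress(graph, tensors_without_value_mask, device_context);
  // Allocate memory, copy host data to device, and launch every kernel of the graph.
  runtime::RunSingleOpGraph(graph, tensors_without_value_mask, device_context, is_dynamic_shape);
}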

View File

@ -38,7 +38,13 @@
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace mindspore {
// brief mindspore::tensor namespace
enum TensorSyncStatus { kNoNeedSync, kNeedSyncHostToDevice, kNeedSyncDeviceToHost, kNeedSyncDeviceToHostImmediately };
enum TensorSyncStatus {
kNoNeedSync,
kNeedSyncHostToDevice,
kNeedSyncHostToDeviceImmediately,
kNeedSyncDeviceToHost,
kNeedSyncDeviceToHostImmediately
};
// A sub namespace in ME to support tensor related definition.
namespace tensor {
// Tensor data interface.
@ -512,6 +518,11 @@ class MS_CORE_API Tensor final : public MetaTensor {
/// \return True if sync_status_ is kNeedSyncHostToDevice.
bool NeedSyncHostToDevice() const { return sync_status_ == kNeedSyncHostToDevice; }
/// \brief Check the value of sync_status_.
///
/// \return True if sync_status_ is kNeedSyncHostToDeviceImmediately.
bool NeedSyncHostToDeviceImmediately() const { return sync_status_ == kNeedSyncHostToDeviceImmediately; }
/// \brief Check if this Tensor is the output of graph.
///
/// \return Whether this Tensor is the output of graph
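To tie the new status value back to the run-op helpers above, a condensed sketch of how kNeedSyncHostToDeviceImmediately flows through them (the snippets follow UpdateInputNodeDeviceAddress and CopyParameterDataToDevice from run_op_helper.cc in this commit; surrounding loops and error handling are omitted):
// UpdateInputNodeDeviceAddress: a Parameter tensor without a device address borrows the
// node's address and is marked for an immediate host-to-device copy.
input_tensor->set_device_address(node_address);
input_tensor->set_sync_status(kNeedSyncHostToDeviceImmediately);
// CopyParameterDataToDevice: only tensors marked above are copied before the kernels launch.
if (input_tensors[i]->NeedSyncHostToDeviceImmediately()) {
  CopyTensorDataToDevice(input_tensors[i], input_nodes[i], device_context);
  input_tensors[i]->set_sync_status(kNoNeedSync);
}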