format some executor func name

This commit is contained in:
kswang 2021-03-30 21:51:36 +08:00
parent eec1b4441d
commit 54ef8520ab
9 changed files with 120 additions and 175 deletions

View File

@ -1517,11 +1517,6 @@ void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph,
}
}
GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const vector<tensor::TensorPtr> &inputs) {
RunInfer(func_graph, inputs);
return CompileGraphImpl(func_graph);
}
void AscendSession::SyncStream() {
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);

View File

@ -49,7 +49,6 @@ class AscendSession : public SessionBasic {
void UnifyMindIR(const KernelGraphPtr &graph) override;
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override;
bool IsSupportSummary() override;
void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
void BuildGraphImpl(GraphId) override;

View File

@ -32,7 +32,7 @@ namespace {
void UpdateOutputTensors(const VectorRef *outputs,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
MS_EXCEPTION_IF_NULL(outputs);
for (auto item : *outputs) {
for (auto &item : *outputs) {
if (utils::isa<VectorRefPtr>(item)) {
auto vector_ref = utils::cast<VectorRef>(item);
UpdateOutputTensors(&vector_ref, tensor_to_node);
@ -45,7 +45,6 @@ void UpdateOutputTensors(const VectorRef *outputs,
auto &output_index = iter->second.second;
auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
tensor->set_device_address(address);
if (AnfAlgo::IsDynamicShape(node)) {
auto updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
ShapeVector int_shape;
@ -62,12 +61,12 @@ void UpdateOutputTensors(const VectorRef *outputs,
}
}
void NotifyOutputTensors(const VectorRef *outputs) {
void SetOutputTensorsWaitStatus(const VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
for (auto item : *outputs) {
for (auto &item : *outputs) {
if (utils::isa<VectorRefPtr>(item)) {
auto vector_ref = utils::cast<VectorRef>(item);
NotifyOutputTensors(&vector_ref);
SetOutputTensorsWaitStatus(&vector_ref);
} else if (utils::isa<tensor::TensorPtr>(item)) {
auto tensor = utils::cast<tensor::TensorPtr>(item);
MS_EXCEPTION_IF_NULL(tensor);
@ -78,7 +77,7 @@ void NotifyOutputTensors(const VectorRef *outputs) {
bool TensorInVector(const VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
for (auto item : *outputs) {
for (auto &item : *outputs) {
if (utils::isa<VectorRefPtr>(item)) {
auto vector_ref = utils::cast<VectorRef>(item);
if (TensorInVector(&vector_ref)) {
@ -90,6 +89,50 @@ bool TensorInVector(const VectorRef *outputs) {
}
return false;
}
bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
MS_EXCEPTION_IF_NULL(task);
for (auto &input : task->input_need_wait_tensors_) {
MS_EXCEPTION_IF_NULL(input);
if (input->NeedWait()) {
return false;
}
}
auto session = task->session_;
MS_EXCEPTION_IF_NULL(session);
auto graph = session->GetGraph(task->graph_id_);
if (graph != nullptr) {
return graph->IsPreGraphFinished();
}
return true;
}
void WaitLockedInputs(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task) {
bool need_lock = false;
for (auto &tensor : task->input_tensors_) {
if (tensor->NeedWait()) {
if (tensor->IsGraphOutput()) {
task->input_need_wait_tensors_.emplace_back(tensor);
} else {
need_lock = true;
}
}
}
if (need_lock) {
mindspore::ScopedLongRunning long_running;
for (auto &tensor : task->input_tensors_) {
if (tensor->NeedWait() && !tensor->IsGraphOutput()) {
MsException::Instance().CheckException();
tensor->Wait();
}
}
MsException::Instance().CheckException();
}
// need lock input parameters for optimizer
for (auto &tensor : task->input_need_lock_tensors_) {
tensor->SetNeedWait(true);
}
}
} // namespace
void CompileNodesTask::Run() {
@ -129,7 +172,7 @@ void RunGraphTask::Run() {
for (auto &tensor : input_need_lock_tensors_) {
tensor->SetNeedWait(false);
}
NotifyOutputTensors(&outputs_);
SetOutputTensorsWaitStatus(&outputs_);
ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
}
@ -198,7 +241,7 @@ void Executor::WorkerLoop() {
}
}
std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
std::vector<std::shared_ptr<RunGraphTask>> Executor::GetReadyTasksFromPendingList() {
std::vector<std::shared_ptr<RunGraphTask>> new_ready_tasks;
std::lock_guard<std::mutex> lock(pending_task_mutex_);
for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
@ -249,7 +292,7 @@ void Executor::OnException() {
}
void Executor::OnRunGraphFinished() {
auto new_ready_tasks = GetNewReadyTasks();
auto new_ready_tasks = GetReadyTasksFromPendingList();
std::lock_guard<std::mutex> lock(task_mutex_);
for (auto &task : new_ready_tasks) {
ready_tasks_.push(task);
@ -260,23 +303,6 @@ void Executor::OnRunGraphFinished() {
reenter_cond_var_.notify_all();
}
bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
MS_EXCEPTION_IF_NULL(task);
for (auto &input : task->input_need_wait_tensors_) {
MS_EXCEPTION_IF_NULL(input);
if (input->NeedWait()) {
return false;
}
}
auto session = task->session_;
MS_EXCEPTION_IF_NULL(session);
auto graph = session->GetGraph(task->graph_id_);
if (graph != nullptr) {
return graph->IsPreGraphFinished();
}
return true;
}
void Executor::ClearDoneTasks() {
std::lock_guard<std::mutex> lock(done_task_mutex_);
done_tasks_.clear();
@ -341,33 +367,6 @@ void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
RunTask(task, true, true);
}
void Executor::WaitLockedInputs(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task) {
bool need_lock = false;
for (auto &tensor : task->input_tensors_) {
if (tensor->NeedWait()) {
if (tensor->IsGraphOutput()) {
task->input_need_wait_tensors_.emplace_back(tensor);
} else {
need_lock = true;
}
}
}
if (need_lock) {
mindspore::ScopedLongRunning long_running;
for (auto &tensor : task->input_tensors_) {
if (tensor->NeedWait() && !tensor->IsGraphOutput()) {
MsException::Instance().CheckException();
tensor->Wait();
}
}
MsException::Instance().CheckException();
}
// need lock input parameters for optimizer
for (auto &tensor : task->input_need_lock_tensors_) {
tensor->SetNeedWait(true);
}
}
void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(session);

View File

@ -172,9 +172,7 @@ class Executor {
private:
void RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run = false);
std::vector<std::shared_ptr<RunGraphTask>> GetNewReadyTasks();
bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task);
void WaitLockedInputs(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task);
std::vector<std::shared_ptr<RunGraphTask>> GetReadyTasksFromPendingList();
void OnWorkerExit();
void OnClear();
void OnRunGraphFinished();

View File

@ -118,68 +118,16 @@ ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
return parameter->param_info();
}
tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
const KernelGraphPtr &graph) {
auto &node = node_output_pair.first;
auto &output_index = node_output_pair.second;
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
if (type_id == kTypeUnknown) {
type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
}
tensor::TensorPtr tensor = nullptr;
std::vector<int64_t> temp_shape;
if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
temp_shape.emplace_back(1);
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
tensor->set_sync_status(kNoNeedSync);
tensor->SetNeedWait(true);
tensor->SetIsGraphOutput();
return tensor;
}
tensor = graph->GetInternalOutputTensor(node, output_index);
if (tensor == nullptr) {
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
bool is_internal_output = graph->IsInternalOutput(node, output_index);
if (is_internal_output) {
graph->AddInternalOutputTensor(node, output_index, tensor);
}
}
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
// if in pynative mode,data only copied to host when user want to print data
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
} else {
tensor->set_sync_status(kNeedSyncDeviceToHost);
}
tensor->SetNeedWait(true);
tensor->SetIsGraphOutput();
return tensor;
}
static bool IsPynativeMode() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
}
BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
BaseRef GetNodeOutputTensorFromInputs(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
auto &node = node_output_pair.first;
auto &output_index = node_output_pair.second;
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(tensor_to_node);
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << node_output_pair.second << "]";
if (HasAbstractMonad(node)) {
return std::make_shared<tensor::Tensor>(int64_t(0), kBool);
}
@ -189,7 +137,8 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
MS_EXCEPTION_IF_NULL(value_node);
return value_node->value();
}
bool output_addr_exist = AnfAlgo::OutputAddrExist(node, output_index);
MS_EXCEPTION_IF_NULL(graph);
bool output_addr_exist = AnfAlgo::OutputAddrExist(node, node_output_pair.second);
if (!output_addr_exist || (CheckIfNeedCreateOutputTensor(node) && !IsPynativeMode())) {
if (node->isa<Parameter>()) {
for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
@ -205,7 +154,56 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
}
}
}
auto tensor = CreateCNodeOutputTensor(node_output_pair, graph);
return nullptr;
}
BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
auto &node = node_output_pair.first;
auto &output_index = node_output_pair.second;
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
auto tensor_from_input = GetNodeOutputTensorFromInputs(node_output_pair, graph, input_tensors);
if (tensor_from_input != nullptr) {
return tensor_from_input;
}
TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
if (type_id == kTypeUnknown) {
type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
}
tensor::TensorPtr tensor = nullptr;
std::vector<int64_t> temp_shape;
if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
temp_shape.emplace_back(1);
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
tensor->set_sync_status(kNoNeedSync);
} else {
tensor = graph->GetInternalOutputTensor(node, output_index);
if (tensor == nullptr) {
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
bool is_internal_output = graph->IsInternalOutput(node, output_index);
if (is_internal_output) {
graph->AddInternalOutputTensor(node, output_index, tensor);
}
}
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
// if in pynative mode,data only copied to host when user want to print data
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
} else {
tensor->set_sync_status(kNeedSyncDeviceToHost);
}
}
tensor->SetNeedWait(true);
tensor->SetIsGraphOutput();
(*tensor_to_node)[tensor] = node_output_pair;
return tensor;
}
@ -1778,43 +1776,6 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
summary_callback_ = callback;
}
void SessionBasic::RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) {
auto node_list = TopoSort(func_graph->get_return());
size_t tensor_index = 0;
for (const auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
AbstractBasePtrList input_abstracts;
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t index = 0; index < input_num; ++index) {
auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), index);
MS_EXCEPTION_IF_NULL(input_node);
auto abstract = input_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
input_abstracts.emplace_back(abstract);
}
auto prim = AnfAlgo::GetCNodePrimitive(node);
if (prim->isa<ops::PrimitiveC>()) {
auto prim_c = prim->cast<std::shared_ptr<ops::PrimitiveC>>();
MS_EXCEPTION_IF_NULL(prim_c);
auto abstract = prim_c->Infer(input_abstracts);
node->set_abstract(abstract);
}
} else if (node->isa<Parameter>()) {
if (tensor_index > inputs.size()) {
MS_EXCEPTION(IndexError) << "Index " << tensor_index << "is out of " << inputs.size() << "tensor's size";
}
node->set_abstract(inputs[tensor_index++]->ToAbstract());
} else {
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
value_node->set_abstract(value->ToAbstract());
}
}
}
void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
MS_LOG(DEBUG) << "Update summary Start";
MS_EXCEPTION_IF_NULL(graph);

View File

@ -166,9 +166,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void UnifyMindIR(const KernelGraphPtr &graph) = 0;
virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) {
MS_EXCEPTION(NotExistsError) << "Call an empty function";
}
virtual void BuildGraphImpl(GraphId) {}
virtual void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) = 0;
@ -182,8 +179,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs,
const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
virtual void SetSummaryNodes(KernelGraph *graph);
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,

View File

@ -131,7 +131,7 @@ void AscendKernelRuntime::SetContext() {
}
}
void AscendKernelRuntime::InnerSetContext() {
void AscendKernelRuntime::SetCurrentContext() {
if (rt_context_ == nullptr) {
return;
}
@ -142,7 +142,7 @@ void AscendKernelRuntime::InnerSetContext() {
}
void AscendKernelRuntime::ClearGraphModelMap() {
InnerSetContext();
SetCurrentContext();
for (auto &iter : graph_data_dumper_) {
MS_LOG(INFO) << "[DataDump] Unload data dumper:" << iter.first;
auto &data_dumper = iter.second;
@ -168,7 +168,7 @@ void AscendKernelRuntime::ClearGraphModelMap() {
void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &,
const std::unordered_set<ValueNodePtr> &,
const std::vector<CNodePtr> &) {
InnerSetContext();
SetCurrentContext();
MS_LOG(DEBUG) << "Clear graph:" << graph_id << " data dumper";
if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) {
MS_LOG(DEBUG) << "Unload dump info " << graph_id;
@ -247,7 +247,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
if (!initialized_) {
return;
}
InnerSetContext();
SetCurrentContext();
ReportProfilingData();
// release ge runtime
ClearGraphModelMap();
@ -284,7 +284,7 @@ void AscendKernelRuntime::PreInit() {
bool AscendKernelRuntime::Init() {
if (initialized_) {
InnerSetContext();
SetCurrentContext();
return true;
}
OpTilingCalculater::GetInstance().Init();
@ -437,7 +437,7 @@ bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) {
bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
InnerSetContext();
SetCurrentContext();
if (graph->is_dynamic_shape()) {
if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE && (ConfigManager::GetInstance().iter_num() > 1)) {
MS_LOG(EXCEPTION) << "Dynamic shape is not supported with sink mode.";
@ -498,7 +498,7 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
InnerSetContext();
SetCurrentContext();
if (graph->is_dynamic_shape()) {
MS_LOG(INFO) << "Dynamic Shape Graph Skip Load Task Step";
return true;
@ -716,7 +716,7 @@ bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *grap
bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
current_graph_ = graph;
InnerSetContext();
SetCurrentContext();
MS_EXCEPTION_IF_NULL(graph);
if (graph->is_dynamic_shape()) {
MS_LOG(INFO) << "Dynamic Shape Graph Run Task Async";
@ -761,7 +761,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
}
bool AscendKernelRuntime::SyncStream() {
InnerSetContext();
SetCurrentContext();
if (stream_ == nullptr) {
MS_LOG(ERROR) << "SyncStream failed. stream_ is nullptr";
return false;
@ -779,7 +779,7 @@ bool AscendKernelRuntime::SyncStream() {
}
bool AscendKernelRuntime::MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) {
InnerSetContext();
SetCurrentContext();
if (stream_ == nullptr) {
MS_LOG(ERROR) << "MemcpyAsync failed. stream_ is nullptr";
return false;
@ -803,7 +803,7 @@ void AscendKernelRuntime::CreateContext() {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
}
}
InnerSetContext();
SetCurrentContext();
}
bool AscendKernelRuntime::InitDevice() {
@ -850,7 +850,7 @@ bool AscendKernelRuntime::InitDevice() {
}
bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
InnerSetContext();
SetCurrentContext();
if (stream_ != nullptr) {
auto ret = rtStreamDestroy(stream_);
if (ret != RT_ERROR_NONE) {

View File

@ -76,7 +76,7 @@ class AscendKernelRuntime : public KernelRuntime {
static bool NeedDestroyHccl();
static bool DestroyHccl();
static bool DestroySingleOpHccl();
void InnerSetContext();
void SetCurrentContext();
void ClearGraphModelMap();
void ReleaseDeviceRes() override;

View File

@ -121,10 +121,8 @@ class KernelRuntime {
void AssignStaticMemory(session::KernelGraph *graph);
void AssignDynamicMemory(session::KernelGraph *graph);
void ReuseAssignDynamicMemory(session::KernelGraph *graph);
void AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index);
void AssignWorkSpaceMem(MemType type, const AnfNodePtr &node);
void AssignReuseWorkSpaceMem(const AnfNodePtr &node);
void UpdateRefNodeOutputMem(const session::KernelGraph *graph);