From 46d956d7d8fbacc4c99899c12320dd72923f1398 Mon Sep 17 00:00:00 2001
From: fangzehua
Date: Mon, 13 Dec 2021 15:01:47 +0800
Subject: [PATCH] Accelerate InferShape in runtime

---
 .../backend/session/anf_runtime_algorithm.cc        | 71 +++++++++++------
 .../backend/session/anf_runtime_algorithm.h         |  2 +
 mindspore/ccsrc/common/trans.cc                     | 29 +++----
 .../runtime/device/executor/dynamic_kernel.cc       | 79 ++++++++++---------
 .../runtime/device/executor/dynamic_kernel.h        |  5 +-
 mindspore/ccsrc/runtime/framework/graph_compiler.cc |  1 +
 mindspore/ccsrc/runtime/hardware/device_context.h   | 14 ++++
 .../runtime/hardware/gpu/gpu_device_context.cc      |  2 +-
 mindspore/core/abstract/primitive_infer_map.cc      | 47 +++++------
 mindspore/core/abstract/primitive_infer_map.h       |  3 +-
 mindspore/core/ir/anf.cc                            | 36 +++++++--
 mindspore/core/ir/kernel_info_dev.h                 | 44 +++++++++++
 mindspore/core/utils/anf_utils.cc                   | 56 +++++++++++--
 mindspore/core/utils/info.h                         |  2 +-
 14 files changed, 273 insertions(+), 118 deletions(-)

diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index aeef8a8ffb1..fd2b340294a 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -737,12 +737,33 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_nod
   if (!anf_node->isa<CNode>()) {
     MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode." << trace::DumpSourceLines(anf_node);
   }
-  if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
-    return VisitKernelWithReturnType(anf_node, 0, skip_nop_node);
+  auto kernel_info = anf_node->kernel_info();
+  if (kernel_info) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    MS_EXCEPTION_IF_NULL(runtime_cache);
+    if (runtime_cache->is_valid()) {
+      auto output = runtime_cache->get_prev_node_output(input_idx);
+      if (output.first != nullptr) {
+        return output;
+      }
+    }
   }
-  auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
-  MS_EXCEPTION_IF_NULL(input_node);
-  return VisitKernelWithReturnType(input_node, 0, skip_nop_node);
+  KernelWithIndex res;
+  if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
+    res = VisitKernelWithReturnType(anf_node, 0, skip_nop_node);
+  } else {
+    auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
+    MS_EXCEPTION_IF_NULL(input_node);
+    res = VisitKernelWithReturnType(input_node, 0, skip_nop_node);
+  }
+  if (kernel_info) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    MS_EXCEPTION_IF_NULL(runtime_cache);
+    if (runtime_cache->is_valid()) {
+      runtime_cache->set_prev_node_output(input_idx, res);
+    }
+  }
+  return res;
 }
 
 std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
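Note: the new fast path in GetPrevNodeOutput is a read-through cache — consult the node's RuntimeCache while it is valid, fall back to the original visitation logic on a miss, then backfill the result. A minimal sketch of the same pattern, with hypothetical Cache and Result types standing in for RuntimeCache and KernelWithIndex:

    #include <cstddef>
    #include <map>
    #include <optional>

    // Hypothetical stand-ins for KernelWithIndex and RuntimeCache.
    struct Result {
      void *node = nullptr;
      std::size_t index = 0;
    };

    class Cache {
     public:
      bool is_valid() const { return valid_; }
      void set_valid() { valid_ = true; }
      std::optional<Result> get(std::size_t i) const {
        auto it = map_.find(i);
        return it == map_.end() ? std::nullopt : std::optional<Result>(it->second);
      }
      void set(std::size_t i, const Result &r) { map_[i] = r; }

     private:
      bool valid_ = false;
      std::map<std::size_t, Result> map_;
    };

    // Read-through lookup: hit -> return; miss -> recompute, then backfill.
    Result GetPrevOutput(Cache *cache, std::size_t idx, Result (*recompute)(std::size_t)) {
      if (cache != nullptr && cache->is_valid()) {
        if (auto hit = cache->get(idx)) {
          return *hit;  // fast path: cached answer
        }
      }
      Result res = recompute(idx);  // slow path: original visitation logic
      if (cache != nullptr && cache->is_valid()) {
        cache->set(idx, res);
      }
      return res;
    }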
@@ -2180,9 +2201,9 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors) {
     auto cnode_input = node->input(i + 1);
     MS_EXCEPTION_IF_NULL(cnode_input);
-    MS_EXCEPTION_IF_NULL(real_input);
     if (depend_tensors != nullptr) {
       auto iter_tensor = depend_tensors->find(i);
       if (iter_tensor != depend_tensors->end()) {
@@ -2202,28 +2223,32 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors) {
-    if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
-      auto base_shape = real_input->Shape();
-      if (!base_shape->isa<abstract::TupleShape>()) {
-        MS_LOG(EXCEPTION) << "Node:" << node->DebugString()
-                          << " input is a tuple_get_item but real input node shape is not a TupleShape. trace: "
-                          << trace::DumpSourceLines(real_input);
-      }
-      auto abs = real_input->abstract()->cast<abstract::AbstractTuplePtr>();
-      MS_EXCEPTION_IF_NULL(abs);
-      auto tuple_get_item_indexk = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
-      auto abs_i = abs->elements()[tuple_get_item_indexk];
-      (void)args_spec_list.emplace_back(abs_i);
-    } else if (cnode_input->isa<CNode>() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) {
-      (void)args_spec_list.emplace_back(cnode_input->abstract());
-    } else {
-      (void)args_spec_list.emplace_back(real_input->abstract());
-    }
+    AddArgList(&args_spec_list, cnode_input, real_input, i);
   }
   auto eval_result = opt::CppInferShape(primitive, args_spec_list);
   node->set_abstract(eval_result);
 }
 
+void AnfRuntimeAlgorithm::AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
+                                     const AnfNodePtr &real_input, size_t index) {
+  if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
+    auto base_shape = real_input->Shape();
+    if (!base_shape->isa<abstract::TupleShape>()) {
+      MS_LOG(EXCEPTION) << "Node input is a tuple_get_item but real input node shape is not a TupleShape. trace: "
+                        << trace::DumpSourceLines(real_input);
+    }
+    auto abs = real_input->abstract()->cast<abstract::AbstractTuplePtr>();
+    MS_EXCEPTION_IF_NULL(abs);
+    auto tuple_get_item_indexk = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
+    auto abs_i = abs->elements()[tuple_get_item_indexk];
+    (void)args_spec_list->emplace_back(abs_i);
+  } else if (cnode_input->isa<CNode>() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) {
+    (void)args_spec_list->emplace_back(cnode_input->abstract());
+  } else {
+    (void)args_spec_list->emplace_back(real_input->abstract());
+  }
+}
+
 void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> &root_graph) {
   auto return_node = root_graph->get_return();
   MS_EXCEPTION_IF_NULL(return_node);
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
index 116600722f1..3eba247d6b7 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
@@ -296,6 +296,8 @@ class AnfRuntimeAlgorithm {
   static std::vector<int64_t> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index);
   static bool IsNodeDynamicShape(const AnfNodePtr &node);
   static void InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors = nullptr);
+  static void AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
+                         const AnfNodePtr &real_input, size_t index);
   static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
   static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
   // Find real input nodes.
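Note: AddArgList factors out the per-input abstract selection — a TupleGetItem input contributes the tuple element selected by its get-item index, a Reshape input contributes its own abstract, and everything else contributes the real input's abstract. A simplified sketch of the tuple case, with hypothetical string "abstracts" standing in for AbstractBasePtr:

    #include <cstddef>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Hypothetical: one "abstract" per input, modelled as a string.
    std::string SelectArg(bool is_tuple_get_item, std::size_t get_item_index,
                          const std::vector<std::string> &tuple_elements,
                          const std::string &plain_abstract) {
      if (is_tuple_get_item) {
        // A TupleGetItem input must be backed by a tuple-shaped real input.
        if (get_item_index >= tuple_elements.size()) {
          throw std::runtime_error("real input is not a matching tuple");
        }
        return tuple_elements[get_item_index];  // pick the selected element
      }
      return plain_abstract;  // ordinary inputs pass their abstract through
    }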
diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc
index 33e64ce4521..4edbc14ef2a 100644
--- a/mindspore/ccsrc/common/trans.cc
+++ b/mindspore/ccsrc/common/trans.cc
@@ -1001,20 +1001,21 @@ void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec) {
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
                                        const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
   using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
-  const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
-                                                                    {kOpFormat_NHWC, NhwcDeviceShape},
-                                                                    {kOpFormat_HWCN, HwchDeviceShape},
-                                                                    {kOpFormat_FRAC_Z, FracZDeviceShape},
-                                                                    {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
-                                                                    {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
-                                                                    {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
-                                                                    {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
-                                                                    {kOpFormat_NCDHW, NcdhwDeviceShape},
-                                                                    {kOpFormat_ChannelLast, ChannelLastDeviceShape},
-                                                                    {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
-                                                                    {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape},
-                                                                    {kOpFormat_FRAC_NZ, FracNZDeviceShape},
-                                                                    {kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}};
+  static const std::map<std::string, DeviceShapeTransfer> device_shape_map{
+    {kOpFormat_NCHW, NchwDeviceShape},
+    {kOpFormat_NHWC, NhwcDeviceShape},
+    {kOpFormat_HWCN, HwchDeviceShape},
+    {kOpFormat_FRAC_Z, FracZDeviceShape},
+    {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
+    {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
+    {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
+    {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
+    {kOpFormat_NCDHW, NcdhwDeviceShape},
+    {kOpFormat_ChannelLast, ChannelLastDeviceShape},
+    {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
+    {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape},
+    {kOpFormat_FRAC_NZ, FracNZDeviceShape},
+    {kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}};
 
   if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
     return shape;
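Note: making device_shape_map a function-local static means the 14-entry table is constructed once, on first call, instead of on every TransShapeToDevice invocation; since C++11 such initialization is also thread-safe ("magic statics"). The idiom in isolation:

    #include <map>
    #include <string>

    int Lookup(const std::string &key) {
      // Built exactly once, on first call (thread-safe since C++11);
      // later calls only pay for the find().
      static const std::map<std::string, int> table{{"NCHW", 0}, {"NHWC", 1}};
      auto it = table.find(key);
      return it == table.end() ? -1 : it->second;
    }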
diff --git a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc
index 75292f92a58..837b0090ca0 100644
--- a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc
+++ b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc
@@ -47,61 +47,66 @@ void DynamicKernel::Initialize() {
     return;
   }
   MS_LOG(INFO) << "Have depends";
-  (void)std::transform(ret.begin(), ret.end(), std::back_inserter(depend_list_),
+  (void)std::transform(ret.begin(), ret.end(), std::inserter(depend_list_, depend_list_.begin()),
                        [](const int64_t &value) { return static_cast<uint32_t>(value); });
   MS_LOG(INFO) << "Init End";
 }
 
 int DynamicKernel::GetKernelType() const { return AnfAlgo::GetKernelType(cnode_ptr_.lock()); }
 
-void DynamicKernel::RebuildDependTensor() {
-  depend_tensor_map_.clear();
-  auto cnode = cnode_ptr_.lock();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  for (auto depend : depend_list_) {
-    auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, depend);
-    bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
-    auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, depend, skip_nop_node);
-    std::vector<int64_t> shapes = trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
-    auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
-    auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
-    MS_EXCEPTION_IF_NULL(out_tensor);
-    // The second parameter must be false, otherwise the device address cannot be released and allocated, and the
-    // address size will be wrong in the dynamic shape scenario.
-    out_tensor->set_device_address(output_addr, false);
-    auto ret = depend_tensor_map_.try_emplace(depend, out_tensor);
-    if (!ret.second) {
-      MS_LOG(EXCEPTION) << "Insert map failed";
-    }
-  }
-}
-
 void DynamicKernel::InferShape() {
   auto cnode = cnode_ptr_.lock();
   MS_EXCEPTION_IF_NULL(cnode);
   MS_LOG(INFO) << "InferShape start, node:" << cnode->fullname_with_scope();
-  InferShapeRecursive();
+  depend_tensor_map_.clear();
   auto inputs = cnode->inputs();
   if (inputs.empty()) {
     MS_LOG(EXCEPTION) << "Invalid inputs";
   }
-  // rebuild depend tensor map for gpu dynamic memory allocation.
-  RebuildDependTensor();
-  AnfAlgo::InferShape(cnode, &depend_tensor_map_);
-}
-
-void DynamicKernel::InferShapeRecursive() {
-  auto cnode = cnode_ptr_.lock();
-  MS_EXCEPTION_IF_NULL(cnode);
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  AbstractBasePtrList args_spec_list;
+  auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
   auto input_size = AnfAlgo::GetInputTensorNum(cnode);
   for (size_t i = 0; i < input_size; i++) {
     auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
-    auto input_node = input_node_with_index.first;
-    MS_EXCEPTION_IF_NULL(input_node);
-    InferShapeForNopNode(&input_node);
+    auto real_input = input_node_with_index.first;
+    MS_EXCEPTION_IF_NULL(real_input);
+    auto cnode_input = cnode->input(i + 1);
+    MS_EXCEPTION_IF_NULL(cnode_input);
+    InferShapeForNopNode(&real_input);
+    if (depend_list_.find(i) != depend_list_.end()) {
+      auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
+      bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
+      auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, i, skip_nop_node);
+      std::vector<int64_t> shapes =
+        trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
+      auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
+      auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
+      MS_EXCEPTION_IF_NULL(out_tensor);
+      // The second parameter must be false, otherwise the device address cannot be released and allocated, and the
+      // address size will be wrong in the dynamic shape scenario.
+      out_tensor->set_device_address(output_addr, false);
+      auto ret = depend_tensor_map_.try_emplace(i, out_tensor);
+      if (!ret.second) {
+        MS_LOG(EXCEPTION) << "Insert map failed";
+      }
+      out_tensor->data_sync();
+      auto real_abs = real_input->abstract();
+      if (real_abs->isa<abstract::AbstractTensor>()) {
+        real_input->abstract()->set_value(out_tensor);
+      } else if (real_abs->isa<abstract::AbstractTuple>()) {
+        auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
+        auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
+        MS_EXCEPTION_IF_NULL(abstract_tuple);
+        auto tuple_elements = abstract_tuple->elements()[tuple_get_item_index];
+        tuple_elements->set_value(out_tensor);
+      }
+    }
+    AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i);
   }
+  auto eval_result = opt::CppInferShape(primitive, args_spec_list);
+  cnode->set_abstract(eval_result);
 }
 
 void DynamicKernel::InferShapeForNopNode(AnfNodePtr *input_node) {
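Note: InferShape now folds the old RebuildDependTensor and InferShapeRecursive passes into one walk over the inputs, keyed by membership in depend_list_; for value-depend inputs the device buffer is wrapped in a host tensor and synced before the infer function reads it. A schematic of the membership-driven sync, with a hypothetical HostTensor type:

    #include <cstdint>
    #include <set>
    #include <vector>

    // Hypothetical host-side tensor that can pull its value from device memory.
    struct HostTensor {
      bool synced = false;
      void Sync() { synced = true; }
    };

    // Only inputs whose index is in depend_list need their *value* on host
    // before shape inference runs; every other input only contributes a shape.
    void PrepareDependValues(const std::set<uint32_t> &depend_list,
                             std::vector<HostTensor> *inputs) {
      for (uint32_t i = 0; i < inputs->size(); ++i) {
        if (depend_list.count(i) != 0) {  // O(log n), duplicates impossible
          (*inputs)[i].Sync();
        }
      }
    }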
diff --git a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h
index 62c86ec1d84..ccc9bc1a9a6 100644
--- a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h
+++ b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h
@@ -21,6 +21,7 @@
 #include <map>
 #include <memory>
 #include <vector>
+#include <set>
 #include "ir/anf.h"
 #include "ir/tensor.h"
 #include "abstract/primitive_infer_map.h"
@@ -46,8 +47,6 @@ class DynamicKernel {
   [[nodiscard]] int GetKernelType() const;
 
  protected:
-  void RebuildDependTensor();
-  void InferShapeRecursive();
   static void InferShapeForNopNode(AnfNodePtr *input_node);
 
   void *stream_;
@@ -55,7 +54,7 @@ class DynamicKernel {
   bool is_dynamic_shape_;
   bool is_input_dynamic_shape_;
   bool is_output_dynamic_shape_;
-  std::vector<uint32_t> depend_list_;
+  std::set<uint32_t> depend_list_;
   std::map<uint32_t, tensor::TensorPtr> depend_tensor_map_;
 };
 using DynamicKernelPtr = std::shared_ptr<DynamicKernel>;
diff --git a/mindspore/ccsrc/runtime/framework/graph_compiler.cc b/mindspore/ccsrc/runtime/framework/graph_compiler.cc
index 60cf0fafe1e..05eb57b5241 100644
--- a/mindspore/ccsrc/runtime/framework/graph_compiler.cc
+++ b/mindspore/ccsrc/runtime/framework/graph_compiler.cc
@@ -477,6 +477,7 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
   (void)mindspore::RDR::RecordGraphExecOrder(SubModuleId::SM_SESSION, exec_order_name, kernels);
 #endif
 
+  device_context->EnableRuntimeCache(graph);
   session_->DumpGraph(graph);
   return graph->graph_id();
 }
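Note: EnableRuntimeCache runs only after CompileGraphImpl has stopped mutating the graph, so every RuntimeCache starts invalid and is switched on exactly once; readers must keep guarding on is_valid() before trusting cached data. A compact sketch of that life cycle, with a hypothetical NodeCache type:

    #include <string>

    // Hypothetical per-node cache with the same validity protocol.
    class NodeCache {
     public:
      bool is_valid() const { return valid_; }
      void set_valid() { valid_ = true; }  // flipped once, after compilation
      std::string cached_target;           // trusted only while valid
     private:
      bool valid_ = false;
    };

    std::string ReadTarget(NodeCache *cache, const std::string &computed) {
      // Before set_valid(): every call recomputes, nothing is stored.
      // After set_valid(): the first call stores, later calls hit the cache.
      if (cache->is_valid() && !cache->cached_target.empty()) {
        return cache->cached_target;
      }
      if (cache->is_valid()) {
        cache->cached_target = computed;
      }
      return computed;
    }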
diff --git a/mindspore/ccsrc/runtime/hardware/device_context.h b/mindspore/ccsrc/runtime/hardware/device_context.h
index 4f3682f6e72..e6608232a37 100644
--- a/mindspore/ccsrc/runtime/hardware/device_context.h
+++ b/mindspore/ccsrc/runtime/hardware/device_context.h
@@ -156,6 +156,20 @@ class DeviceContext {
   // Dump all graphs.
   virtual void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const {}
 
+  void EnableRuntimeCache(const KernelGraphPtr &graph) const {
+    auto node_list = graph->TopoSort(graph->get_return());
+    for (auto &node : node_list) {
+      auto kernel_info = node->kernel_info();
+      if (!kernel_info) {
+        continue;
+      }
+      auto runtime_cache = kernel_info->runtime_cache();
+      MS_EXCEPTION_IF_NULL(runtime_cache);
+      runtime_cache->set_valid();
+    }
+  }
+
  protected:
   DeviceContextKey device_context_key_;
diff --git a/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc b/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc
index 4780422ad4d..1fa99383ed3 100644
--- a/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc
+++ b/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc
@@ -370,7 +370,7 @@ void GPUDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const {
   MS_EXCEPTION_IF_NULL(ms_context);
   bool is_pynative_infer = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
   bool is_pynative_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
-  std::vector<int64_t> dynamic_shape_depends = abstract::GetDependsFormMap(kernel);
+  auto dynamic_shape_depends = abstract::GetDependsFormMap(kernel);
   if ((is_pynative_infer || is_pynative_mode) && dynamic_shape_depends.empty()) {
     return;
  }
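Note: returning std::set<int64_t> from GetDependsFormMap (next diff) gives callers ordered, deduplicated indices and a logarithmic membership test; callers that bind the result with auto, as above, need no change when the container switches. Typical consumption, as a hypothetical helper:

    #include <cstdint>
    #include <set>

    // Hypothetical consumer of the depend-index set.
    bool InputIsValueDepend(const std::set<int64_t> &depends, int64_t input_idx) {
      return depends.count(input_idx) != 0;  // ordered, deduplicated, O(log n)
    }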
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index 48668e80496..f96d58ef078 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -19,6 +19,7 @@
 #include "abstract/primitive_infer_map.h"
 #include <string>
 #include <vector>
+#include <set>
 #include "ops/exp.h"
 #include "ops/log.h"
 #include "ops/reciprocal.h"
 #include "ops/grad/slice_grad.h"
 namespace mindspore {
 namespace abstract {
-std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
-  using ShapeVec = std::vector<int64_t>;
-  using PrimShapeDependMap = mindspore::HashMap<std::string, ShapeVec>;
+std::set<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
+  using ShapeSet = std::set<int64_t>;
+  using PrimShapeDependMap = mindspore::HashMap<std::string, ShapeSet>;
   static const auto &kOneHot = prim::kPrimOneHot->name();
   static const auto &kDropoutGenMask = prim::kPrimDropoutGenMask->name();
   static const auto &kTranspose = prim::kPrimTranspose->name();
@@ -63,24 +64,24 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
   static const auto &kReshape = prim::kPrimReshape->name();
   static const auto &kDynamicReshape = prim::kPrimDynamicReshape->name();
   // Common dynamic shape depends.
-  static const PrimShapeDependMap dynamic_shape_depends{{kUnsortedSegmentSum, ShapeVec{2}},
-                                                        {kUnsortedSegmentMin, ShapeVec{2}},
-                                                        {kUnsortedSegmentMax, ShapeVec{2}},
-                                                        {kGather, ShapeVec{2}},
-                                                        {kGatherV2, ShapeVec{2}},
-                                                        {kRange, ShapeVec{0, 1, 2}},
-                                                        {kConv2DBackpropFilter, ShapeVec{2}},
-                                                        {kConv2DBackpropInput, ShapeVec{2}},
-                                                        {kOneHot, ShapeVec{1, 3}},
-                                                        {kDropoutGenMask, ShapeVec{0}},
-                                                        {kStridedSlice, ShapeVec{1, 2, 3}},
-                                                        {kStridedSliceGrad, ShapeVec{1, 2, 3, 4}},
-                                                        {kTile, ShapeVec{1}},
-                                                        {kReshape, ShapeVec{1}},
-                                                        {kDynamicReshape, ShapeVec{1}},
-                                                        {kSlice, ShapeVec{1, 2}},
-                                                        {kSliceGrad, ShapeVec{2, 3}},
-                                                        {kDynamicBroadcastTo, ShapeVec{1}}};
+  static const PrimShapeDependMap dynamic_shape_depends{{kUnsortedSegmentSum, ShapeSet{2}},
+                                                        {kUnsortedSegmentMin, ShapeSet{2}},
+                                                        {kUnsortedSegmentMax, ShapeSet{2}},
+                                                        {kGather, ShapeSet{2}},
+                                                        {kGatherV2, ShapeSet{2}},
+                                                        {kRange, ShapeSet{0, 1, 2}},
+                                                        {kConv2DBackpropFilter, ShapeSet{2}},
+                                                        {kConv2DBackpropInput, ShapeSet{2}},
+                                                        {kOneHot, ShapeSet{1, 3}},
+                                                        {kDropoutGenMask, ShapeSet{0}},
+                                                        {kStridedSlice, ShapeSet{1, 2, 3}},
+                                                        {kStridedSliceGrad, ShapeSet{1, 2, 3, 4}},
+                                                        {kTile, ShapeSet{1}},
+                                                        {kReshape, ShapeSet{1}},
+                                                        {kDynamicReshape, ShapeSet{1}},
+                                                        {kSlice, ShapeSet{1, 2}},
+                                                        {kSliceGrad, ShapeSet{2, 3}},
+                                                        {kDynamicBroadcastTo, ShapeSet{1}}};
 
   MS_EXCEPTION_IF_NULL(cnode);
   if (cnode->inputs().empty()) {
@@ -101,9 +102,9 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
   auto iter = dynamic_shape_depends.find(prim_name);
   if (iter != dynamic_shape_depends.end()) {
     int64_t cnode_input_size = SizeToLong(cnode->inputs().size());
-    std::vector<int64_t> res;
+    ShapeSet res;
     auto ori = iter->second;
-    (void)std::copy_if(ori.begin(), ori.end(), std::back_inserter(res),
+    (void)std::copy_if(ori.begin(), ori.end(), std::inserter(res, res.begin()),
                        [&](auto idx) { return idx < cnode_input_size - 1; });
     return res;
   }
diff --git a/mindspore/core/abstract/primitive_infer_map.h b/mindspore/core/abstract/primitive_infer_map.h
index bd92bde1996..136032c262a 100644
--- a/mindspore/core/abstract/primitive_infer_map.h
+++ b/mindspore/core/abstract/primitive_infer_map.h
@@ -19,6 +19,7 @@
 #define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
 
 #include <vector>
+#include <set>
 #include "utils/hash_map.h"
 #include "ir/primitive.h"
@@ -50,7 +51,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap();
 
 StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive);
 
-std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode);
+std::set<int64_t> GetDependsFormMap(const CNodePtr &cnode);
 
 void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg);
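Note: std::back_inserter requires push_back, which std::set does not provide, so the filtered copy above switches to std::inserter with a position hint. The filtering idiom in isolation:

    #include <algorithm>
    #include <cstdint>
    #include <iterator>
    #include <set>

    std::set<int64_t> FilterInRange(const std::set<int64_t> &src, int64_t limit) {
      std::set<int64_t> res;
      // std::set has no push_back, so back_inserter cannot be used here;
      // std::inserter calls insert() with a position hint instead.
      (void)std::copy_if(src.begin(), src.end(), std::inserter(res, res.begin()),
                         [limit](int64_t idx) { return idx < limit; });
      return res;
    }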
diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc
index 27ebf96c238..83d0f4967f3 100644
--- a/mindspore/core/ir/anf.cc
+++ b/mindspore/core/ir/anf.cc
@@ -622,14 +622,36 @@ std::string GetOriginNodeTarget(const AnfNodePtr &node) {
 }
 
 std::string GetCNodeTarget(const AnfNodePtr &node) {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-  auto target = GetOriginNodeTarget(node);
-  if (target != kTargetUnDefined) {
-    return target;
+  auto kernel_info = node->kernel_info();
+  if (kernel_info != nullptr) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    MS_EXCEPTION_IF_NULL(runtime_cache);
+    if (runtime_cache->is_valid()) {
+      auto tmp_target = runtime_cache->device_target();
+      if (!tmp_target.empty()) {
+        return tmp_target;
+      }
+    }
   }
-  return default_target;
+
+  std::string target;
+  auto ori_target = GetOriginNodeTarget(node);
+  if (ori_target != kTargetUnDefined) {
+    target = ori_target;
+  } else {
+    auto context_ptr = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context_ptr);
+    target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+  }
+
+  if (kernel_info != nullptr) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    MS_EXCEPTION_IF_NULL(runtime_cache);
+    if (runtime_cache->is_valid()) {
+      runtime_cache->set_device_target(target);
+    }
+  }
+  return target;
 }
 
 bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
diff --git a/mindspore/core/ir/kernel_info_dev.h b/mindspore/core/ir/kernel_info_dev.h
index 9d20a9f67e3..ad98a9cce15 100644
--- a/mindspore/core/ir/kernel_info_dev.h
+++ b/mindspore/core/ir/kernel_info_dev.h
@@ -18,6 +18,10 @@
 #define MINDSPORE_CORE_IR_KERNEL_INFO_DEV_H_
 
 #include <memory>
+#include <map>
+#include <string>
+#include <utility>
+#include "utils/info.h"
 
 namespace mindspore {
 enum Axis : int {
@@ -26,11 +30,51 @@ enum Axis : int {
   H,
   W,
 };
+
+// Cache some runtime information that will not change.
+class RuntimeCache {
+ public:
+  std::pair<AnfNodePtr, size_t> get_prev_node_output(size_t index) {
+    auto it = prev_node_output_map_.find(index);
+    if (it != prev_node_output_map_.end()) {
+      return it->second;
+    } else {
+      return std::pair<AnfNodePtr, size_t>();
+    }
+  }
+
+  void set_prev_node_output(size_t index, std::pair<AnfNodePtr, size_t> output) {
+    auto pr = std::make_pair(index, output);
+    (void)prev_node_output_map_.insert(pr);
+  }
+
+  std::string device_target() { return device_target_; }
+  void set_device_target(const std::string &target) { device_target_ = target; }
+  bool is_valid() { return is_valid_; }
+  void set_valid() { is_valid_ = true; }
+  void set_output_tensor_num(const ssize_t output_tensor_num) { output_tensor_num_ = output_tensor_num; }
+  ssize_t output_tensor_num() const { return output_tensor_num_; }
+  void set_real_kernel(enum CacheBool b) { is_real_kernel_ = b; }
+  enum CacheBool is_real_kernel() { return is_real_kernel_; }
+
+ private:
+  bool is_valid_{false};
+  std::map<size_t, std::pair<AnfNodePtr, size_t>> prev_node_output_map_;
+  std::string device_target_;
+  ssize_t output_tensor_num_ = -1;
+  enum CacheBool is_real_kernel_ = CacheBool::UNCACHED;
+};
+
 // Interface for device kernel program information.
 class KernelInfoDevice {
  public:
   // If kernel program was built and build info is set.
   virtual bool has_build_info() const = 0;
+
+  RuntimeCache *runtime_cache() { return &runtime_cache_; }
+
+ private:
+  RuntimeCache runtime_cache_;
 };
 using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>;
 }  // namespace mindspore
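Note: a plain bool cannot distinguish "cached false" from "never computed", hence the tri-state CacheBool that IsRealKernel consults in the next diff. A compact sketch of the idiom (a scoped enum is used here for brevity; the patch itself uses an unscoped typedef enum):

    // Tri-state cache slot: -1 marks "not computed yet", so a cached false
    // is distinguishable from an empty slot.
    enum class CacheBool : int { UNCACHED = -1, FALSE = 0, TRUE = 1 };

    bool CachedPredicate(CacheBool *slot, bool (*compute)()) {
      if (*slot != CacheBool::UNCACHED) {
        return *slot == CacheBool::TRUE;  // hit: either cached value
      }
      const bool res = compute();  // first call: compute once and remember
      *slot = res ? CacheBool::TRUE : CacheBool::FALSE;
      return res;
    }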
diff --git a/mindspore/core/utils/anf_utils.cc b/mindspore/core/utils/anf_utils.cc
index e1757c77a4b..2cebc60ce3a 100644
--- a/mindspore/core/utils/anf_utils.cc
+++ b/mindspore/core/utils/anf_utils.cc
@@ -108,7 +108,28 @@ bool AnfUtils::IsRealKernel(const AnfNodePtr &node) {
   if (cnode->size() == 0) {
     MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString() << trace::DumpSourceLines(node);
   }
-  return !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), virtual_prims);
+
+  auto kernel_info = cnode->kernel_info();
+  if (kernel_info) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    MS_EXCEPTION_IF_NULL(runtime_cache);
+    if (runtime_cache->is_real_kernel() != CacheBool::UNCACHED) {
+      return (runtime_cache->is_real_kernel() == CacheBool::TRUE);
+    }
+  }
+  bool res = !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), virtual_prims);
+
+  if (kernel_info) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    MS_EXCEPTION_IF_NULL(runtime_cache);
+    if (res) {
+      runtime_cache->set_real_kernel(CacheBool::TRUE);
+    } else {
+      runtime_cache->set_real_kernel(CacheBool::FALSE);
+    }
+  }
+
+  return res;
 }
 
 bool AnfUtils::IsRealCNodeKernel(const AnfNodePtr &node) {
@@ -183,19 +204,38 @@ size_t AnfUtils::GetInputTensorNum(const AnfNodePtr &node) {
 
 size_t AnfUtils::GetOutputTensorNum(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
+  auto kernel_info = node->kernel_info();
+  if (kernel_info) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    if (runtime_cache->is_valid()) {
+      ssize_t output_tensor_num = runtime_cache->output_tensor_num();
+      if (output_tensor_num >= 0) {
+        return static_cast<size_t>(output_tensor_num);
+      }
+    }
+  }
+
+  size_t res;
   TypePtr type = node->Type();
   if (type == nullptr) {
-    return 0;
-  }
-  if (type->isa<Tuple>()) {
+    res = 0;
+  } else if (type->isa<Tuple>()) {
     auto tuple_type = type->cast<TuplePtr>();
     MS_EXCEPTION_IF_NULL(tuple_type);
-    return tuple_type->size();
+    res = tuple_type->size();
+  } else if (type->isa<TypeNone>()) {
+    res = 0;
+  } else {
+    res = 1;
   }
-  if (type->isa<TypeNone>()) {
-    return 0;
+
+  if (kernel_info) {
+    auto runtime_cache = kernel_info->runtime_cache();
+    if (runtime_cache->is_valid()) {
+      runtime_cache->set_output_tensor_num(static_cast<ssize_t>(res));
+    }
   }
-  return 1;
+  return res;
 }
 
 std::pair<AnfNodePtr, size_t> AnfUtils::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
diff --git a/mindspore/core/utils/info.h b/mindspore/core/utils/info.h
index d994f77784a..99e46d2f7c7 100644
--- a/mindspore/core/utils/info.h
+++ b/mindspore/core/utils/info.h
@@ -29,7 +29,7 @@
 namespace mindspore {
 enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSourceLineTipInLine = 2 };
-
+typedef enum CacheBool { UNCACHED = -1, FALSE, TRUE } CacheBool;
 // Location class record the location in source code.
 class Location {
  public: