From 2c8d65049af4a25592d30895a58891b500427dbd Mon Sep 17 00:00:00 2001 From: nomindcarry Date: Wed, 16 Nov 2022 17:15:35 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A7=A3=E8=80=A6SetArgs=E5=92=8CInferOp?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit flag fix --- .../dynamic_shape/dynamic_shape_helper.cc | 112 ++++++++++++++++++ .../dynamic_shape/dynamic_shape_helper.h | 2 + .../ccsrc/runtime/pynative/run_op_helper.cc | 3 +- 3 files changed, 116 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.cc b/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.cc index 3ffe9a406d6..6462f1e72c5 100644 --- a/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.cc +++ b/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.cc @@ -244,6 +244,65 @@ void InferShape(const CNodePtr &cnode, std::map *de cnode->set_abstract(new_abs); } +void InferShapeDynamic(const CNodePtr &cnode, std::map *depend_tensor_map, void *args) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(depend_tensor_map); + MS_LOG(DEBUG) << "InferShape start, node:" << cnode->fullname_with_scope(); + std::set depend_list = abstract::GetValueDependArgIndices(cnode); + auto ret = InferShapeForDefiniteOutputNode(cnode); + if (ret) { + return; + } + + depend_tensor_map->clear(); + auto &inputs = cnode->inputs(); + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "Invalid inputs."; + } + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + AbstractBasePtrList args_spec_list; + auto primitive = GetValueNode(inputs[0]); + auto input_size = common::AnfAlgo::GetInputTensorNum(cnode); + for (size_t i = 0; i < input_size; i++) { + auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false); + auto real_input = input_node_with_index.first; + auto real_input_index = input_node_with_index.second; + + AbstractBasePtr cached_abstract; + AbstractBasePtr real_input_abs = real_input->abstract(); + + MS_EXCEPTION_IF_NULL(real_input); + if (depend_list.find(i) != depend_list.end()) { + auto out_tensor = GetDependValueTensor(cnode, i, input_node_with_index, false, args, false); + + // cppcheck-suppress unreadVariable + auto lock = AnfUtils::GetAbstractLock(real_input.get()); + AbstractBasePtr real_abs = real_input->abstract(); + if (real_abs->isa()) { + real_abs->set_value(out_tensor); + } else if (real_abs->isa()) { + auto abstract_tuple = real_abs->cast(); + MS_EXCEPTION_IF_NULL(abstract_tuple); + MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abstract_tuple->elements().size()), "Index is out of range."); + auto tuple_elements = abstract_tuple->elements()[real_input_index]; + tuple_elements->set_value(out_tensor); + } + } + common::AnfAlgo::AddArgList(&args_spec_list, real_input, real_input_index); + } + + // Pynative mode is rely on the origin abstract of cnode, so cannot modify the abstract inplace, clone from old + // abstract instead. + auto old_abs = cnode->abstract(); + MS_EXCEPTION_IF_NULL(old_abs); + auto new_abs = old_abs->Clone(); + opt::CppInferShape(primitive, args_spec_list, new_abs); + MS_LOG(DEBUG) << "The abstract of " << cnode->fullname_with_scope() << " changes from " << old_abs << " to " + << new_abs; + cnode->set_abstract(new_abs); +} + inline bool IsDeprecatedCpuOrGpuKernelMod(kernel::KernelModType kernel_mod_type) { return kernel_mod_type == kernel::KernelModType::DeprecatedNativeGpuKernelMod || kernel_mod_type == kernel::KernelModType::DeprecatedNativeCpuKernelMod; @@ -324,6 +383,59 @@ void InferOp(const CNodePtr &cnode, void *args) { } } +void InferOpDynamic(const CNodePtr &cnode, void *args) { + MS_EXCEPTION_IF_NULL(cnode); + auto kernel_mod = AnfAlgo::GetKernelMod(cnode); + MS_EXCEPTION_IF_NULL(kernel_mod); + kernel::KernelArgs kernel_args; + if (AnfAlgo::IsDynamicShapeSkipExecute(cnode)) { + std::vector dtypes{common::AnfAlgo::GetOutputInferDataType(cnode, 0)}; + common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetInputDeviceShape(cnode, 0)}, cnode.get()); + } else { + InferShapeDynamic(cnode, &kernel_args.depend_tensor_map, args); + } +} + +void SetOpArgs(const CNodePtr &cnode, void *args) { + MS_EXCEPTION_IF_NULL(cnode); + auto kernel_mod = AnfAlgo::GetKernelMod(cnode); + MS_EXCEPTION_IF_NULL(kernel_mod); + kernel::KernelArgs kernel_args; + + std::set depend_list = abstract::GetValueDependArgIndices(cnode); + auto *depend_tensor_map = &kernel_args.depend_tensor_map; + + auto &inputs = cnode->inputs(); + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "Invalid inputs."; + } + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + auto input_size = common::AnfAlgo::GetInputTensorNum(cnode); + bool skip_nop_node = !context->get_param(MS_CTX_ENABLE_MINDRT); + for (size_t i = 0; i < input_size; i++) { + auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false); + auto real_input = input_node_with_index.first; + bool abstract_in_cache = DynamicShapeDtypeManager::GetInstance().CheckDeviceType(real_input); + if (depend_list.find(i) != depend_list.end()) { + auto out_tensor = GetDependValueTensor(cnode, i, input_node_with_index, skip_nop_node, args, abstract_in_cache); + auto ret2 = depend_tensor_map->try_emplace(i, out_tensor); + if (!ret2.second) { + MS_LOG(EXCEPTION) << "Insert map failed."; + } + } + } + + if (auto kernel_mod_type = kernel_mod->GetKernelModType(); IsCpuGpuKernelMod(kernel_mod_type)) { + auto update = kernel::AbstractArgsFromCNode(cnode, IsDeprecatedCpuOrGpuKernelMod(kernel_mod_type)); + update.depend_tensor_map = std::move(kernel_args.depend_tensor_map); + kernel::SetInputsByDependMap(update.depend_tensor_map, &update.inputs, IsCpuKernelMod(kernel_mod_type)); + kernel::SetArgsToCNode(cnode, update); + } else { + kernel::SetArgsToCNode(cnode, kernel_args); + } +} + CustomActorNodeManager &CustomActorNodeManager::Instance() { static CustomActorNodeManager instance{}; return instance; diff --git a/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h b/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h index 03716036ad8..63a22d32ab2 100644 --- a/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h +++ b/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h @@ -24,6 +24,8 @@ namespace mindspore::opt::dynamic_shape { bool IsRealCNode(const BaseRef &n); BACKEND_EXPORT void InferOp(const CNodePtr &node, void *args = nullptr); +BACKEND_EXPORT void InferOpDynamic(const CNodePtr &node, void *args = nullptr); +BACKEND_EXPORT void SetOpArgs(const CNodePtr &node, void *args = nullptr); AnfNodePtr GenInferNode(const AnfNodePtr &node); AnfNodePtr GenInitNode(const AnfNodePtr &node); diff --git a/mindspore/ccsrc/runtime/pynative/run_op_helper.cc b/mindspore/ccsrc/runtime/pynative/run_op_helper.cc index 30780b16a06..5eb27c0636c 100644 --- a/mindspore/ccsrc/runtime/pynative/run_op_helper.cc +++ b/mindspore/ccsrc/runtime/pynative/run_op_helper.cc @@ -486,7 +486,8 @@ void LaunchKernelsDynamic(const KernelGraphPtr &graph, const device::DeviceConte } auto inputs = CreateKernelInputAddress(runtime_info); - InferNodeRealShape(node); + opt::dynamic_shape::InferOpDynamic(node); + opt::dynamic_shape::SetOpArgs(node); runtime::DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out); runtime::DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);