From a3e5f2ab8b3eaf39ed644d3cb51ca20ce756bd55 Mon Sep 17 00:00:00 2001
From: Zichun Ye
Date: Fri, 19 Aug 2022 10:16:20 +0800
Subject: [PATCH] update dynamic shape related dtype infer logic

fix code check problem

update dict check logic

revert cpu/gpu related file
---
 .../optimizer/common_backend_optimization.cc |   2 +
 .../dynamic_shape_dtype_record.cc            | 100 ++++++++++++++++++
 .../dynamic_shape_dtype_record.h             |  49 +++++++++
 .../dynamic_shape/dynamic_shape_helper.cc    |  50 ++++++++-
 4 files changed, 199 insertions(+), 2 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.cc
 create mode 100644 mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h

diff --git a/mindspore/ccsrc/backend/common/optimizer/common_backend_optimization.cc b/mindspore/ccsrc/backend/common/optimizer/common_backend_optimization.cc
index f32c4c94f91..0721c0627d6 100644
--- a/mindspore/ccsrc/backend/common/optimizer/common_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/common/optimizer/common_backend_optimization.cc
@@ -40,6 +40,7 @@
 #include "backend/common/optimizer/dynamic_shape/link_custom_op.h"
 #include "backend/common/pass/convert_unused_tuple_para_to_make_tuple.h"
 #include "backend/common/pass/convert_dynamic_broadcast_to.h"
+#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
 #include "utils/ms_context.h"
 #include "include/common/debug/anf_ir_dump.h"
@@ -192,6 +193,7 @@ void DynamicShapeConvertPass(const std::shared_ptr<session::KernelGraph> &kernel
   auto dynamic_shape_convert_pm = std::make_shared<opt::PassManager>("dynamic_shape_convert_pm");
   dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::ConvertCustomOp>());
   dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::LinkCustomOp>());
+  dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::DynamicShapeDtypeRecord>());
   optimizer->AddPassManager(dynamic_shape_convert_pm);
   (void)optimizer->Optimize(kernel_graph);
 #ifdef ENABLE_DUMP_IR
diff --git a/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.cc b/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.cc
new file mode 100644
index 00000000000..9dd7ea49360
--- /dev/null
+++ b/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.cc
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
+
+#include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
+#include "backend/common/optimizer/helper.h"
+#include "utils/anf_utils.h"
+
+namespace mindspore {
+namespace opt::dynamic_shape {
+DynamicShapeDtypeManager &DynamicShapeDtypeManager::GetInstance() {
+  static DynamicShapeDtypeManager instance{};
+  return instance;
+}
+
+void DynamicShapeDtypeManager::Register(const AnfNodePtr &node, const TypePtrList &device_abs) {
+  if (device_type_recorder_.find(node) == device_type_recorder_.end()) {
+    (void)device_type_recorder_.emplace(node, device_abs);
+  }
+}
+
+bool DynamicShapeDtypeManager::CheckDeviceType(const AnfNodePtr &node) const {
+  return (device_type_recorder_.find(node) != device_type_recorder_.end());
+}
+
+TypePtrList DynamicShapeDtypeManager::GetDeviceType(const AnfNodePtr &node) {
+  auto iter = device_type_recorder_.find(node);
+  if (iter != device_type_recorder_.end()) {
+    return iter->second;
+  }
+  return {};
+}
+
+bool DynamicShapeDtypeRecord::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto nodes = TopoSort(func_graph->get_return());
+  for (const auto &node : nodes) {
+    CNodePtr cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr || !AnfUtils::IsRealKernel(cnode)) {
+      continue;
+    }
+
+    auto kernel_info = node->kernel_info();
+    if (kernel_info == nullptr || !kernel_info->has_build_info()) {
+      continue;
+    }
+
+    auto out_num = common::AnfAlgo::GetOutputTensorNum(node);
+    auto node_abs = node->abstract();
+    MS_EXCEPTION_IF_NULL(node_abs);
+    if (node_abs->isa<abstract::AbstractTensor>()) {
+      if (out_num != 1) {
+        continue;
+      }
+      auto infer_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
+      auto device_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
+      if (infer_type != device_type) {
+        TypePtrList new_abstract = {TypeIdToType(device_type)};
+        DynamicShapeDtypeManager::GetInstance().Register(node, new_abstract);
+      }
+    } else if (node_abs->isa<abstract::AbstractTuple>()) {
+      auto abstract_tuple = node_abs->cast<abstract::AbstractTuplePtr>();
+      MS_EXCEPTION_IF_NULL(abstract_tuple);
+      TypePtrList abstract_list;
+      bool find_diff_element = false;
+      for (size_t output_index = 0; output_index < out_num; ++output_index) {
+        auto cur_element = abstract_tuple->elements()[output_index];
+        MS_EXCEPTION_IF_NULL(cur_element);
+        auto infer_type = common::AnfAlgo::GetOutputInferDataType(node, output_index);
+        auto device_type = AnfAlgo::GetOutputDeviceDataType(node, output_index);
+        if (infer_type != device_type) {
+          find_diff_element = true;
+        }
+        abstract_list.push_back(TypeIdToType(device_type));
+      }
+      if (find_diff_element) {
+        DynamicShapeDtypeManager::GetInstance().Register(node, abstract_list);
+      }
+    }
+  }
+
+  return true;
+}
+}  // namespace opt::dynamic_shape
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h b/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h
new file mode 100644
index 00000000000..2e4258ca5c2
--- /dev/null
+++ b/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_
+#define MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_
+
+#include <map>
+#include <vector>
+#include "ir/anf.h"
+#include "backend/common/optimizer/optimizer.h"
+
+namespace mindspore::opt::dynamic_shape {
+class DynamicShapeDtypeManager {
+ public:
+  static DynamicShapeDtypeManager &GetInstance();
+  void Register(const AnfNodePtr &node, const std::vector<TypePtr> &device_abs);
+  bool CheckDeviceType(const AnfNodePtr &node) const;
+  TypePtrList GetDeviceType(const AnfNodePtr &node);
+
+ private:
+  DynamicShapeDtypeManager() = default;
+  ~DynamicShapeDtypeManager() = default;
+  DISABLE_COPY_AND_ASSIGN(DynamicShapeDtypeManager);
+
+  std::map<AnfNodePtr, TypePtrList> device_type_recorder_;
+};
+
+// If the data type in a node's abstract is not the same as the data type selected on the device, record the
+// device data type so that later inference can replace the inferred one with it.
+class DynamicShapeDtypeRecord : public Pass {
+ public:
+  DynamicShapeDtypeRecord() : Pass("dynamic_shape_dtype_record") {}
+  ~DynamicShapeDtypeRecord() override = default;
+  bool Run(const FuncGraphPtr &func_graph) override;
+};
+}  // namespace mindspore::opt::dynamic_shape
+#endif  // MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_
diff --git a/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.cc b/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.cc
index 9312e726ee2..373086253e2 100644
--- a/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.cc
+++ b/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.cc
@@ -24,6 +24,7 @@
 #include
 #include
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
 #include "runtime/device/ms_device_shape_transfer.h"
 #include "include/common/utils/anfalgo.h"
 #include "include/common/utils/utils.h"
@@ -154,6 +155,34 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
     auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
     auto real_input = input_node_with_index.first;
     auto real_input_index = input_node_with_index.second;
+
+    bool abstract_in_cache = DynamicShapeDtypeManager::GetInstance().CheckDeviceType(real_input);
+    AbstractBasePtr cached_abstract;
+    AbstractBasePtr real_input_abs = real_input->abstract();
+
+    if (abstract_in_cache) {
+      // Rebuild the abstract with the device data type recorded by the DynamicShapeDtypeRecord pass.
+      auto cached_type_list = DynamicShapeDtypeManager::GetInstance().GetDeviceType(real_input);
+      if (real_input_abs->isa<abstract::AbstractTensor>()) {
+        auto shape_ptr = real_input_abs->BuildShape();
+        cached_abstract = std::make_shared<abstract::AbstractTensor>(cached_type_list[0], shape_ptr);
+      } else if (real_input_abs->isa<abstract::AbstractTuple>()) {
+        auto abstract_tuple = real_input_abs->cast<abstract::AbstractTuplePtr>();
+        MS_EXCEPTION_IF_NULL(abstract_tuple);
+        AbstractBasePtrList abstract_list;
+
+        for (size_t output_index = 0; output_index < cached_type_list.size(); ++output_index) {
+          auto cur_element = abstract_tuple->elements()[output_index];
+          MS_EXCEPTION_IF_NULL(cur_element);
+          auto shape_ptr = cur_element->BuildShape();
+          auto new_abstract = std::make_shared<abstract::AbstractTensor>(cached_type_list[output_index], shape_ptr);
+          abstract_list.push_back(new_abstract);
+        }
+        cached_abstract = std::make_shared<abstract::AbstractTuple>(abstract_list);
+      } else {
+        MS_LOG(EXCEPTION) << "Output of " << real_input->fullname_with_scope()
+                          << " is neither a Tensor nor a Tuple of Tensor, but " << real_input_abs->ToString();
+      }
+    }
     MS_EXCEPTION_IF_NULL(real_input);
     if (skip_nop_node) {
       InferShapeForNopNode(real_input);
@@ -167,7 +196,12 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
       // cppcheck-suppress unreadVariable
       auto lock = AnfUtils::GetAbstractLock(real_input.get());
-      auto real_abs = real_input->abstract();
+      // Prefer the cached abstract carrying the device data type when one was recorded for this input.
+      AbstractBasePtr real_abs;
+      if (abstract_in_cache) {
+        real_abs = cached_abstract;
+      } else {
+        real_abs = real_input->abstract();
+      }
       if (real_abs->isa<abstract::AbstractTensor>()) {
         real_abs->set_value(out_tensor);
       } else if (real_abs->isa<abstract::AbstractTuple>()) {
@@ -178,7 +212,19 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
         tuple_elements->set_value(out_tensor);
       }
     }
-    common::AnfAlgo::AddArgList(&args_spec_list, real_input, real_input_index);
+    if (abstract_in_cache) {
+      if (cached_abstract->isa<abstract::AbstractTuple>()) {
+        auto abs_tuple = cached_abstract->Clone()->cast<abstract::AbstractTuplePtr>();
+        MS_EXCEPTION_IF_NULL(abs_tuple);
+        MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abs_tuple->elements().size()), "Index is out of range.");
+        auto abs_index = abs_tuple->elements()[real_input_index];
+        (void)args_spec_list.emplace_back(abs_index);
+      } else {
+        (void)args_spec_list.emplace_back(cached_abstract->Clone());
+      }
+    } else {
+      common::AnfAlgo::AddArgList(&args_spec_list, real_input, real_input_index);
+    }
   }

 // Pynative mode relies on the origin abstract of the cnode, so the abstract cannot be modified in place; clone from the old one
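
---

Note (illustrative, not part of the patch): a minimal sketch of the contract the new cache is meant to satisfy, assuming a real-kernel node `n` whose inferred dtype (float32) differs from the dtype its selected kernel produces (float16). The node `n` and the concrete type ids are hypothetical stand-ins for illustration only.

    // Pass time: DynamicShapeDtypeRecord sees infer dtype != device dtype and
    // records the device-side dtype list for the node (first Register wins).
    DynamicShapeDtypeManager::GetInstance().Register(n, {TypeIdToType(kNumberTypeFloat16)});

    // Infer time: InferShape consults the cache before building the argument
    // abstract, so downstream shape/dtype inference sees the device dtype
    // (float16) rather than the originally inferred one (float32).
    if (DynamicShapeDtypeManager::GetInstance().CheckDeviceType(n)) {
      TypePtrList device_types = DynamicShapeDtypeManager::GetInstance().GetDeviceType(n);
      auto abs = std::make_shared<abstract::AbstractTensor>(device_types[0], n->abstract()->BuildShape());
      // `abs` replaces n->abstract() in args_spec_list, as in the hunks above.
    }

Keeping the recorded dtypes in a singleton map, instead of rewriting each node's abstract in place, avoids mutating abstracts that PyNative mode still depends on; the trade-off is that every consumer must remember to check the cache, as the InferShape changes above do.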