!40456 fix type inference in dynamic shape case

Merge pull request !40456 from zichun_ye/dyn_dtype
i-robot 2022-08-24 07:03:12 +00:00 committed by Gitee
commit d317fdcf78
4 changed files with 199 additions and 2 deletions

@@ -40,6 +40,7 @@
#include "backend/common/optimizer/dynamic_shape/link_custom_op.h"
#include "backend/common/pass/convert_unused_tuple_para_to_make_tuple.h"
#include "backend/common/pass/convert_dynamic_broadcast_to.h"
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
#include "utils/ms_context.h"
#include "include/common/debug/anf_ir_dump.h"
@@ -192,6 +193,7 @@ void DynamicShapeConvertPass(const std::shared_ptr<session::KernelGraph> &kernel
auto dynamic_shape_convert_pm = std::make_shared<opt::PassManager>("dynamic_shape_convert_pm");
dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::ConvertCustomOp>());
dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::LinkCustomOp>());
dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::DynamicShapeDtypeRecord>());
optimizer->AddPassManager(dynamic_shape_convert_pm);
(void)optimizer->Optimize(kernel_graph);
#ifdef ENABLE_DUMP_IR

New file: backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.cc

@@ -0,0 +1,100 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "backend/common/optimizer/helper.h"
#include "utils/anf_utils.h"
namespace mindspore {
namespace opt::dynamic_shape {
DynamicShapeDtypeManager &DynamicShapeDtypeManager::GetInstance() {
static DynamicShapeDtypeManager instance{};
return instance;
}
void DynamicShapeDtypeManager::Register(const AnfNodePtr &node, const TypePtrList &device_types) {
if (device_type_recorder_.find(node) == device_type_recorder_.end()) {
(void)device_type_recorder_.emplace(node, device_types);
}
}
bool DynamicShapeDtypeManager::CheckDeviceType(const AnfNodePtr &node) const {
return (device_type_recorder_.find(node) != device_type_recorder_.end());
}
TypePtrList DynamicShapeDtypeManager::GetDeviceType(const AnfNodePtr &node) {
auto iter = device_type_recorder_.find(node);
if (iter != device_type_recorder_.end()) {
return iter->second;
}
return {};
}
bool DynamicShapeDtypeRecord::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto nodes = TopoSort(func_graph->get_return());
for (const auto &node : nodes) {
CNodePtr cnode = node->cast<CNodePtr>();
if (cnode == nullptr || !AnfUtils::IsRealKernel(cnode)) {
continue;
}
auto kernel_info = node->kernel_info();
if (kernel_info == nullptr || !kernel_info->has_build_info()) {
continue;
}
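// Compare each output's inferred data type with the type chosen at kernel
// selection; record the device-side types for the node when any of them differ.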
auto out_num = common::AnfAlgo::GetOutputTensorNum(node);
auto node_abs = node->abstract();
MS_EXCEPTION_IF_NULL(node_abs);
if (node_abs->isa<abstract::AbstractTensor>()) {
if (out_num != 1) {
continue;
}
auto infer_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
auto device_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
if (infer_type != device_type) {
TypePtrList device_types = {TypeIdToType(device_type)};
DynamicShapeDtypeManager::GetInstance().Register(node, device_types);
}
} else if (node_abs->isa<abstract::AbstractTuple>()) {
auto abstract_tuple = node_abs->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abstract_tuple);
TypePtrList device_type_list;
bool find_diff_element = false;
for (size_t output_index = 0; output_index < out_num; ++output_index) {
auto cur_element = abstract_tuple->elements()[output_index];
MS_EXCEPTION_IF_NULL(cur_element);
auto infer_type = common::AnfAlgo::GetOutputInferDataType(node, output_index);
auto device_type = AnfAlgo::GetOutputDeviceDataType(node, output_index);
if (infer_type != device_type) {
find_diff_element = true;
}
device_type_list.push_back(TypeIdToType(device_type));
}
if (find_diff_element) {
DynamicShapeDtypeManager::GetInstance().Register(node, device_type_list);
}
}
}
return true;
}
} // namespace opt::dynamic_shape
} // namespace mindspore
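
The recorder is a process-wide singleton keyed by node. As a rough sketch of its contract (the function below and the float16/float32 mismatch are hypothetical, for illustration only; they are not part of this commit):

// Hypothetical usage sketch of DynamicShapeDtypeManager (illustration only).
// Assumes `node` refers to a kernel whose output dtype was inferred as float32
// while the selected kernel implementation produces float16.
void ExampleQuery(const AnfNodePtr &node) {
  auto &mgr = opt::dynamic_shape::DynamicShapeDtypeManager::GetInstance();
  mgr.Register(node, {TypeIdToType(kNumberTypeFloat16)});  // first registration wins
  mgr.Register(node, {TypeIdToType(kNumberTypeFloat32)});  // ignored: node already recorded
  if (mgr.CheckDeviceType(node)) {
    TypePtrList device_types = mgr.GetDeviceType(node);  // yields {Float16}
    // A caller would rebuild the node's abstract from device_types[0] plus the
    // node's existing shape, as the InferShape change below does.
  }
}

Register deliberately keeps only the first recorded type list for a node, so repeated runs of the pass do not overwrite earlier results.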

New file: backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h

@@ -0,0 +1,49 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_
#define MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_
#include <map>
#include <vector>
#include "ir/anf.h"
#include "backend/common/optimizer/optimizer.h"
namespace mindspore::opt::dynamic_shape {
class DynamicShapeDtypeManager {
public:
static DynamicShapeDtypeManager &GetInstance();
void Register(const AnfNodePtr &node, const TypePtrList &device_types);
bool CheckDeviceType(const AnfNodePtr &node) const;
TypePtrList GetDeviceType(const AnfNodePtr &node);
private:
DynamicShapeDtypeManager() = default;
~DynamicShapeDtypeManager() = default;
DISABLE_COPY_AND_ASSIGN(DynamicShapeDtypeManager);
std::map<AnfNodePtr, TypePtrList> device_type_recorder_;
};
// If the data type of an abstract differs from the device data type, the abstract is replaced with one carrying the device data type.
class DynamicShapeDtypeRecord : public Pass {
public:
DynamicShapeDtypeRecord() : Pass("dynamic_shape_dtype_record") {}
~DynamicShapeDtypeRecord() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
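// Worked example (illustrative): if type inference yields float32 for a node's
// single output but kernel selection picks a float16 implementation, this pass
// records {Float16} for the node in DynamicShapeDtypeManager; the dynamic-shape
// InferShape step later rebuilds that input's abstract with the recorded type.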
} // namespace mindspore::opt::dynamic_shape
#endif // MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_

@@ -24,6 +24,7 @@
#include <map>
#include <utility>
#include "backend/common/session/anf_runtime_algorithm.h"
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/utils.h"
@@ -154,6 +155,34 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
auto real_input = input_node_with_index.first;
auto real_input_index = input_node_with_index.second;
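// If DynamicShapeDtypeRecord cached a device dtype for this input, build a
// device-typed copy of its abstract so shape inference sees the kernel's dtype.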
MS_EXCEPTION_IF_NULL(real_input);
bool abstract_in_cache = DynamicShapeDtypeManager::GetInstance().CheckDeviceType(real_input);
AbstractBasePtr cached_abstract;
AbstractBasePtr real_input_abs = real_input->abstract();
if (abstract_in_cache) {
auto cached_type_list = DynamicShapeDtypeManager::GetInstance().GetDeviceType(real_input);
if (real_input_abs->isa<abstract::AbstractTensor>()) {
auto shape_ptr = real_input_abs->BuildShape();
cached_abstract = std::make_shared<abstract::AbstractTensor>(cached_type_list[0], shape_ptr);
} else if (real_input_abs->isa<abstract::AbstractTuple>()) {
auto abstract_tuple = real_input_abs->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abstract_tuple);
AbstractBasePtrList abstract_list;
for (size_t output_index = 0; output_index < cached_type_list.size(); ++output_index) {
auto cur_element = abstract_tuple->elements()[output_index];
MS_EXCEPTION_IF_NULL(cur_element);
auto shape_ptr = cur_element->BuildShape();
auto new_abstract = std::make_shared<abstract::AbstractTensor>(cached_type_list[output_index], shape_ptr);
abstract_list.push_back(new_abstract);
}
cached_abstract = std::make_shared<abstract::AbstractTuple>(abstract_list);
} else {
MS_LOG(EXCEPTION) << "Output of " << real_input->fullname_with_scope()
<< " is neither a Tensor nor a Tuple of Tensor, but " << real_input_abs->ToString();
}
}
if (skip_nop_node) {
InferShapeForNopNode(real_input);
@@ -167,7 +196,12 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
// cppcheck-suppress unreadVariable
auto lock = AnfUtils::GetAbstractLock(real_input.get());
AbstractBasePtr real_abs;
if (abstract_in_cache) {
real_abs = cached_abstract;
} else {
real_abs = real_input->abstract();
}
if (real_abs->isa<abstract::AbstractTensor>()) {
real_abs->set_value(out_tensor);
} else if (real_abs->isa<abstract::AbstractTuple>()) {
@@ -178,7 +212,19 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
tuple_elements->set_value(out_tensor);
}
}
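// Pass the device-typed abstract (the matching element for tuple outputs) to
// the infer argument list instead of the node's own abstract.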
if (abstract_in_cache) {
if (cached_abstract->isa<abstract::AbstractTuple>()) {
auto abs_tuple = cached_abstract->Clone()->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abs_tuple);
MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abs_tuple->elements().size()), "Index is out of range.");
auto abs_index = abs_tuple->elements()[real_input_index];
(void)args_spec_list.emplace_back(abs_index);
} else {
(void)args_spec_list.emplace_back(cached_abstract->Clone());
}
} else {
common::AnfAlgo::AddArgList(&args_spec_list, real_input, real_input_index);
}
}
// PyNative mode relies on the original abstract of the cnode, so the abstract cannot be modified in place; clone from the old one.