forked from mindspore-Ecosystem/mindspore
!40456 fix type infer in dynamic shape case
Merge pull request !40456 from zichun_ye/dyn_dtype
commit d317fdcf78
@@ -40,6 +40,7 @@
#include "backend/common/optimizer/dynamic_shape/link_custom_op.h"
#include "backend/common/pass/convert_unused_tuple_para_to_make_tuple.h"
#include "backend/common/pass/convert_dynamic_broadcast_to.h"
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
#include "utils/ms_context.h"
#include "include/common/debug/anf_ir_dump.h"
@@ -192,6 +193,7 @@ void DynamicShapeConvertPass(const std::shared_ptr<session::KernelGraph> &kernel
  auto dynamic_shape_convert_pm = std::make_shared<opt::PassManager>("dynamic_shape_convert_pm");
  dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::ConvertCustomOp>());
  dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::LinkCustomOp>());
  dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::DynamicShapeDtypeRecord>());
  optimizer->AddPassManager(dynamic_shape_convert_pm);
  (void)optimizer->Optimize(kernel_graph);
#ifdef ENABLE_DUMP_IR
@@ -0,0 +1,100 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"

#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "backend/common/optimizer/helper.h"
#include "utils/anf_utils.h"

namespace mindspore {
namespace opt::dynamic_shape {
DynamicShapeDtypeManager &DynamicShapeDtypeManager::GetInstance() {
  static DynamicShapeDtypeManager instance{};
  return instance;
}

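// Cache the device output data types for a node; the first registration for a node is kept and
// later calls are ignored.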
void DynamicShapeDtypeManager::Register(const AnfNodePtr &node, const TypePtrList &device_abs) {
  if (device_type_recorder_.find(node) == device_type_recorder_.end()) {
    (void)device_type_recorder_.emplace(node, device_abs);
  }
}

bool DynamicShapeDtypeManager::CheckDeviceType(const AnfNodePtr &node) const {
  return (device_type_recorder_.find(node) != device_type_recorder_.end());
}

TypePtrList DynamicShapeDtypeManager::GetDeviceType(const AnfNodePtr &node) {
  auto iter = device_type_recorder_.find(node);
  if (iter != device_type_recorder_.end()) {
    return iter->second;
  }
  return {};
}

bool DynamicShapeDtypeRecord::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto nodes = TopoSort(func_graph->get_return());
  // auto manager = DynamicShapeDtypeManager::GetInstance();
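  // Visit every real kernel and compare each output's inferred data type with the data type
  // selected for it on the device.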
  for (const auto &node : nodes) {
    CNodePtr cnode = node->cast<CNodePtr>();
    if (cnode == nullptr || !AnfUtils::IsRealKernel(cnode)) {
      continue;
    }

    auto kernel_info = node->kernel_info();
    if (kernel_info == nullptr || !kernel_info->has_build_info()) {
      continue;
    }

    auto out_num = common::AnfAlgo::GetOutputTensorNum(node);
    auto node_abs = node->abstract();
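    // Single tensor output: record the device data type only when it differs from the inferred one.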
    if (node_abs->isa<abstract::AbstractTensor>()) {
      if (out_num != 1) {
        continue;
      }
      auto infer_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
      auto device_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
      if (infer_type != device_type) {
        TypePtrList new_abstract = {TypeIdToType(device_type)};
        DynamicShapeDtypeManager::GetInstance().Register(node, new_abstract);
      }
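    // Tuple output: collect the device data type of every element and register the whole list
    // if any element differs from its inferred type.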
    } else if (node_abs->isa<abstract::AbstractTuple>()) {
      auto abstract_tuple = node_abs->cast<abstract::AbstractTuplePtr>();
      MS_EXCEPTION_IF_NULL(abstract_tuple);
      TypePtrList abstract_list;
      bool find_diff_element = false;
      for (size_t output_index = 0; output_index < out_num; ++output_index) {
        auto cur_element = abstract_tuple->elements()[output_index];
        MS_EXCEPTION_IF_NULL(cur_element);
        auto infer_type = common::AnfAlgo::GetOutputInferDataType(node, output_index);
        auto device_type = AnfAlgo::GetOutputDeviceDataType(node, output_index);
        if (infer_type != device_type) {
          find_diff_element = true;
        }
        abstract_list.push_back(TypeIdToType(device_type));
      }
      if (find_diff_element) {
        DynamicShapeDtypeManager::GetInstance().Register(node, abstract_list);
      }
    }
  }

  return true;
}
} // namespace opt::dynamic_shape
} // namespace mindspore
@@ -0,0 +1,49 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_
#define MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_

#include <map>
#include <vector>
#include "ir/anf.h"
#include "backend/common/optimizer/optimizer.h"

namespace mindspore::opt::dynamic_shape {
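// Process-wide singleton that maps a node to the device data types of its outputs when they differ
// from the inferred data types.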
class DynamicShapeDtypeManager {
 public:
  static DynamicShapeDtypeManager &GetInstance();
  void Register(const AnfNodePtr &node, const std::vector<TypePtr> &device_abs);
  bool CheckDeviceType(const AnfNodePtr &node) const;
  TypePtrList GetDeviceType(const AnfNodePtr &node);

 private:
  DynamicShapeDtypeManager() = default;
  ~DynamicShapeDtypeManager() = default;
  DISABLE_COPY_AND_ASSIGN(DynamicShapeDtypeManager);

  std::map<AnfNodePtr, TypePtrList> device_type_recorder_;
};

// If the data type in a node's abstract is not the same as the data type on the device, it will be
// replaced with the device data type.
class DynamicShapeDtypeRecord : public Pass {
 public:
  DynamicShapeDtypeRecord() : Pass("dynamic_shape_dtype_record") {}
  ~DynamicShapeDtypeRecord() override = default;
  bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace mindspore::opt::dynamic_shape
#endif // MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_
@@ -24,6 +24,7 @@
#include <map>
#include <utility>
#include "backend/common/session/anf_runtime_algorithm.h"
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/utils.h"
@@ -154,6 +155,34 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
    auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
    auto real_input = input_node_with_index.first;
    auto real_input_index = input_node_with_index.second;

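    // If DynamicShapeDtypeRecord cached a different device data type for this input, rebuild its
    // abstract with the cached type so inference sees the type actually used on the device.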
    bool abstract_in_cache = DynamicShapeDtypeManager::GetInstance().CheckDeviceType(real_input);
    AbstractBasePtr cached_abstract;
    AbstractBasePtr real_input_abs = real_input->abstract();

    if (abstract_in_cache) {
      auto cached_type_list = DynamicShapeDtypeManager::GetInstance().GetDeviceType(real_input);
      if (real_input_abs->isa<abstract::AbstractTensor>()) {
        auto shape_ptr = real_input_abs->BuildShape();
        cached_abstract = std::make_shared<abstract::AbstractTensor>(cached_type_list[0], shape_ptr);
      } else if (real_input_abs->isa<abstract::AbstractTuple>()) {
        auto abstract_tuple = real_input_abs->cast<abstract::AbstractTuplePtr>();
        MS_EXCEPTION_IF_NULL(abstract_tuple);
        AbstractBasePtrList abstract_list;

        for (size_t output_index = 0; output_index < cached_type_list.size(); ++output_index) {
          auto cur_element = abstract_tuple->elements()[output_index];
          MS_EXCEPTION_IF_NULL(cur_element);
          auto shape_ptr = cur_element->BuildShape();
          auto new_abstract = std::make_shared<abstract::AbstractTensor>(cached_type_list[output_index], shape_ptr);
          abstract_list.push_back(new_abstract);
        }
        cached_abstract = std::make_shared<abstract::AbstractTuple>(abstract_list);
      } else {
        MS_LOG(EXCEPTION) << "Output of " << real_input->fullname_with_scope()
                          << " is neither a Tensor nor a Tuple of Tensor, but " << real_input_abs->ToString();
      }
    }
    MS_EXCEPTION_IF_NULL(real_input);
    if (skip_nop_node) {
      InferShapeForNopNode(real_input);
@@ -167,7 +196,12 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de

      // cppcheck-suppress unreadVariable
      auto lock = AnfUtils::GetAbstractLock(real_input.get());
-     auto real_abs = real_input->abstract();
      AbstractBasePtr real_abs;
      if (abstract_in_cache) {
        real_abs = cached_abstract;
      } else {
        real_abs = real_input->abstract();
      }
      if (real_abs->isa<abstract::AbstractTensor>()) {
        real_abs->set_value(out_tensor);
      } else if (real_abs->isa<abstract::AbstractTuple>()) {
@@ -178,7 +212,19 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
        tuple_elements->set_value(out_tensor);
      }
    }
-   common::AnfAlgo::AddArgList(&args_spec_list, real_input, real_input_index);
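    // When a device data type was cached for this input, push the type-corrected abstract into the
    // argument list instead of the node's own abstract.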
    if (abstract_in_cache) {
      if (cached_abstract->isa<abstract::AbstractTuple>()) {
        auto abs_tuple = cached_abstract->Clone()->cast<abstract::AbstractTuplePtr>();
        MS_EXCEPTION_IF_NULL(abs_tuple);
        MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abs_tuple->elements().size()), "Index is out of range.");
        auto abs_index = abs_tuple->elements()[real_input_index];
        (void)args_spec_list.emplace_back(abs_index);
      } else {
        (void)args_spec_list.emplace_back(cached_abstract->Clone());
      }
    } else {
      common::AnfAlgo::AddArgList(&args_spec_list, real_input, real_input_index);
    }
  }

  // PyNative mode relies on the origin abstract of the cnode, so the abstract cannot be modified in place; clone from the old one