!46931 remove cache for dtype in abstract

Merge pull request !46931 from chengbin/infer_shape_only
i-robot 2022-12-28 07:06:56 +00:00 committed by Gitee
commit 6ef1661f88
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 13 additions and 258 deletions
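In short: the DynamicShapeDtypeRecord pass and its DynamicShapeDtypeManager singleton, which cached per-node device dtypes so that InferShape could patch abstracts, are deleted, and GetDependValueTensor now picks the host dtype directly: the inferred dtype for ValueNode inputs (their kernel build info reports kTypeUnknown) and the device dtype for everything else. The stand-alone C++ sketch below illustrates only that selection rule; TypeId, Node and HostTypeForDependValue are simplified stand-ins for this illustration, not MindSpore APIs.

#include <iostream>

// Simplified stand-ins for illustration only; not the real MindSpore types.
enum class TypeId { kTypeUnknown, kFloat16, kFloat32 };

struct Node {
  bool is_value_node;   // true for ValueNode inputs
  TypeId infer_type;    // dtype recorded on the abstract by front-end inference
  TypeId device_type;   // dtype chosen during kernel selection
};

// Mirrors the control flow of the new GetDependValueTensor: no cache lookup,
// just a direct choice between inferred and device dtype.
TypeId HostTypeForDependValue(const Node &input) {
  if (input.is_value_node) {
    // A ValueNode's kernel build info reports kTypeUnknown, so use the inferred dtype.
    return input.infer_type;
  }
  // For real kernels, trust the dtype selected for the device.
  return input.device_type;
}

int main() {
  Node value_node{true, TypeId::kFloat32, TypeId::kTypeUnknown};
  Node kernel_node{false, TypeId::kFloat32, TypeId::kFloat16};
  std::cout << (HostTypeForDependValue(value_node) == TypeId::kFloat32) << "\n";   // prints 1
  std::cout << (HostTypeForDependValue(kernel_node) == TypeId::kFloat16) << "\n";  // prints 1
  return 0;
}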

View File

@@ -38,7 +38,6 @@
 #include "backend/common/optimizer/dynamic_shape/link_custom_op.h"
 #include "backend/common/pass/convert_unused_tuple_para_to_make_tuple.h"
 #include "backend/common/pass/convert_dynamic_broadcast_to.h"
-#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
 #include "utils/ms_context.h"
 #include "include/common/debug/anf_ir_dump.h"
@@ -207,7 +206,6 @@ void DynamicShapeConvertPass(const std::shared_ptr<session::KernelGraph> &kernel
   auto dynamic_shape_convert_pm = std::make_shared<opt::PassManager>("dynamic_shape_convert_pm");
   dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::ConvertCustomOp>());
   dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::LinkCustomOp>());
-  dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::DynamicShapeDtypeRecord>());
   optimizer->AddPassManager(dynamic_shape_convert_pm);
   (void)optimizer->Optimize(kernel_graph);
 #ifdef ENABLE_DUMP_IR

View File

@@ -1,98 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "backend/common/optimizer/helper.h"
#include "utils/anf_utils.h"
namespace mindspore {
namespace opt::dynamic_shape {
DynamicShapeDtypeManager &DynamicShapeDtypeManager::GetInstance() {
static DynamicShapeDtypeManager instance{};
return instance;
}
void DynamicShapeDtypeManager::Register(const AnfNodePtr &node, const TypePtrList &device_abs) {
if (device_type_recorder_.find(node) == device_type_recorder_.end()) {
(void)device_type_recorder_.emplace(node, device_abs);
}
}
bool DynamicShapeDtypeManager::CheckDeviceType(const AnfNodePtr &node) const {
return (device_type_recorder_.find(node) != device_type_recorder_.end());
}
TypePtrList DynamicShapeDtypeManager::GetDeviceType(const AnfNodePtr &node) {
auto iter = device_type_recorder_.find(node);
if (iter != device_type_recorder_.end()) {
return iter->second;
}
return {};
}
bool DynamicShapeDtypeRecord::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto nodes = TopoSort(func_graph->get_return());
for (const auto &node : nodes) {
CNodePtr cnode = node->cast<CNodePtr>();
if (cnode == nullptr || !AnfUtils::IsRealKernel(cnode)) {
continue;
}
auto kernel_info = node->kernel_info();
if (kernel_info == nullptr || !kernel_info->has_build_info()) {
continue;
}
auto out_num = AnfAlgo::GetOutputTensorNum(node);
auto node_abs = node->abstract();
if (node_abs->isa<abstract::AbstractTensor>()) {
if (out_num != 1) {
continue;
}
auto infer_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
auto device_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
if (infer_type != device_type) {
TypePtrList new_abstract = {TypeIdToType(device_type)};
DynamicShapeDtypeManager::GetInstance().Register(node, new_abstract);
}
} else if (node_abs->isa<abstract::AbstractTuple>()) {
auto abstract_tuple = node_abs->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abstract_tuple);
TypePtrList abstract_list;
bool find_diff_element = false;
for (size_t output_index = 0; output_index < out_num; ++output_index) {
auto cur_element = abstract_tuple->elements()[output_index];
MS_EXCEPTION_IF_NULL(cur_element);
auto infer_type = common::AnfAlgo::GetOutputInferDataType(node, output_index);
auto device_type = AnfAlgo::GetOutputDeviceDataType(node, output_index);
if (infer_type != device_type) {
find_diff_element = true;
}
abstract_list.push_back(TypeIdToType(device_type));
}
if (find_diff_element) {
DynamicShapeDtypeManager::GetInstance().Register(node, abstract_list);
}
}
}
return true;
}
} // namespace opt::dynamic_shape
} // namespace mindspore

View File

@@ -1,49 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_
#define MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_
#include <map>
#include <vector>
#include "ir/anf.h"
#include "backend/common/optimizer/optimizer.h"
namespace mindspore::opt::dynamic_shape {
class DynamicShapeDtypeManager {
public:
static DynamicShapeDtypeManager &GetInstance();
void Register(const AnfNodePtr &node, const std::vector<TypePtr> &device_abs);
bool CheckDeviceType(const AnfNodePtr &node) const;
TypePtrList GetDeviceType(const AnfNodePtr &node);
private:
DynamicShapeDtypeManager() = default;
~DynamicShapeDtypeManager() = default;
DISABLE_COPY_AND_ASSIGN(DynamicShapeDtypeManager);
std::map<AnfNodePtr, TypePtrList> device_type_recorder_;
};
// If the data type of abstract is not same with the one of device, it will replace with device data type.
class DynamicShapeDtypeRecord : public Pass {
public:
DynamicShapeDtypeRecord() : Pass("dynamic_shape_dtype_record") {}
~DynamicShapeDtypeRecord() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace mindspore::opt::dynamic_shape
#endif // MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_

View File

@@ -24,7 +24,6 @@
 #include <map>
 #include <utility>
 #include "backend/common/session/anf_runtime_algorithm.h"
-#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
 #include "runtime/device/ms_device_shape_transfer.h"
 #include "include/common/utils/anfalgo.h"
 #include "include/common/utils/utils.h"
@@ -98,18 +97,17 @@ bool InferShapeForDefiniteOutputNode(const CNodePtr &cnode) {
 tensor::TensorPtr GetDependValueTensor(const AnfNodePtr &node, size_t i,
                                        const std::pair<AnfNodePtr, size_t> &input_node_with_index, bool skip_nop_node,
-                                       void *args, bool abstract_in_cache) {
+                                       void *args) {
   auto real_input = input_node_with_index.first;
   MS_EXCEPTION_IF_NULL(real_input);
   auto real_input_index = input_node_with_index.second;
   auto shapes = trans::GetRuntimePaddingShape(real_input, real_input_index);
   TypeId host_type;
-  if (abstract_in_cache) {
-    // for cnode in the cache, we use device type as there is a mismatch
-    host_type = AnfAlgo::GetOutputDeviceDataType(real_input, real_input_index);
-  } else {
-    // for cnode not in the cache, valuenodes and other nodes, we use inferred type
+  if (real_input->isa<ValueNode>()) {
+    // the type of ValueNode in KernelInfo is kTypeUnknown
     host_type = common::AnfAlgo::GetOutputInferDataType(real_input, real_input_index);
+  } else {
+    host_type = AnfAlgo::GetOutputDeviceDataType(real_input, real_input_index);
   }
   auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
@@ -167,40 +165,14 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
     auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
     auto real_input = input_node_with_index.first;
     auto real_input_index = input_node_with_index.second;
-    bool abstract_in_cache = DynamicShapeDtypeManager::GetInstance().CheckDeviceType(real_input);
-    AbstractBasePtr cached_abstract;
     AbstractBasePtr real_input_abs = real_input->abstract();
-    if (abstract_in_cache) {
-      auto cached_type_list = DynamicShapeDtypeManager::GetInstance().GetDeviceType(real_input);
-      if (real_input_abs->isa<abstract::AbstractTensor>()) {
-        auto shape_ptr = real_input_abs->BuildShape();
-        cached_abstract = std::make_shared<abstract::AbstractTensor>(cached_type_list[0], shape_ptr);
-      } else if (real_input_abs->isa<abstract::AbstractTuple>()) {
-        auto abstract_tuple = real_input_abs->cast<abstract::AbstractTuplePtr>();
-        MS_EXCEPTION_IF_NULL(abstract_tuple);
-        AbstractBasePtrList abstract_list;
-        for (size_t output_index = 0; output_index < cached_type_list.size(); ++output_index) {
-          auto cur_element = abstract_tuple->elements()[output_index];
-          MS_EXCEPTION_IF_NULL(cur_element);
-          auto shape_ptr = cur_element->BuildShape();
-          auto new_abstract = std::make_shared<abstract::AbstractTensor>(cached_type_list[output_index], shape_ptr);
-          abstract_list.push_back(new_abstract);
-        }
-        cached_abstract = std::make_shared<abstract::AbstractTuple>(abstract_list);
-      } else {
-        MS_LOG(EXCEPTION) << "Output of " << real_input->fullname_with_scope()
-                          << " is neither a Tensor nor a Tuple of Tensor, but " << real_input_abs->ToString();
-      }
-    }
     MS_EXCEPTION_IF_NULL(real_input);
     if (skip_nop_node) {
       InferShapeForNopNode(real_input);
     }
     if (depend_list.find(i) != depend_list.end()) {
-      auto out_tensor = GetDependValueTensor(cnode, i, input_node_with_index, skip_nop_node, args, abstract_in_cache);
+      auto out_tensor = GetDependValueTensor(cnode, i, input_node_with_index, skip_nop_node, args);
       auto ret2 = depend_tensor_map->try_emplace(i, out_tensor);
       if (!ret2.second) {
         MS_LOG(EXCEPTION) << "Insert map failed.";
@@ -208,12 +180,7 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
       // cppcheck-suppress unreadVariable
       auto lock = AnfUtils::GetAbstractLock(real_input.get());
-      AbstractBasePtr real_abs;
-      if (abstract_in_cache) {
-        real_abs = cached_abstract;
-      } else {
-        real_abs = real_input->abstract();
-      }
+      auto real_abs = real_input->abstract();
       if (real_abs->isa<abstract::AbstractTensor>()) {
         real_abs->set_value(out_tensor);
       } else if (real_abs->isa<abstract::AbstractTuple>()) {
@@ -224,19 +191,8 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
         tuple_elements->set_value(out_tensor);
       }
     }
-    if (abstract_in_cache) {
-      if (cached_abstract->isa<abstract::AbstractTuple>()) {
-        auto abs_tuple = cached_abstract->Clone()->cast<abstract::AbstractTuplePtr>();
-        MS_EXCEPTION_IF_NULL(abs_tuple);
-        MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abs_tuple->elements().size()), "Index is out of range.");
-        auto abs_index = abs_tuple->elements()[real_input_index];
-        (void)args_spec_list.emplace_back(abs_index);
-      } else {
-        (void)args_spec_list.emplace_back(cached_abstract->Clone());
-      }
-    } else {
-      common::AnfAlgo::AddArgList(&args_spec_list, real_input, real_input_index);
-    }
+    common::AnfAlgo::AddArgList(&args_spec_list, real_input, real_input_index);
   }
   // Pynative mode is rely on the origin abstract of cnode, so cannot modify the abstract inplace, clone from old

View File

@@ -796,58 +796,6 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
   return CreateCNodeWithGraph(input_nodes, graph);
 }
-// rectify absttract if the input has been converted to the attr
-AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
-                                               const AbstractBasePtrList &input_abstract) {
-  MS_EXCEPTION_IF_NULL(primitive);
-  opt::ConstInputToAttrInfoRegister reg;
-  if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), &reg)) {
-    return input_abstract;
-  }
-  if (common::AnfAlgo::HasDynamicShapeFlag(primitive)) {
-    return input_abstract;
-  }
-  auto ms_context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(ms_context);
-  auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-  if (device == kGPUDevice) {
-    if (IsOneOfDynamicShapeConstInputToAttrGPU(primitive->name())) {
-      return input_abstract;
-    }
-  }
-  auto convert_input_list = reg.GetConstInputAttrInfo();
-  auto input_names = primitive->GetAttr(kAttrInputNames);
-  if (input_names == nullptr) {
-    return input_abstract;
-  }
-  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
-  AbstractBasePtrList rectify_abs_list;
-  size_t ori_index = 0;
-  rectify_abs_list.resize(input_names_vec.size());
-  for (size_t index = 0; index < rectify_abs_list.size(); ++index) {
-    // if convert input list find the index it means the input has been converted to the attr
-    if (convert_input_list.find(index) != convert_input_list.end()) {
-      AbstractBasePtr rectify_abs = nullptr;
-      auto input_name = input_names_vec[index];
-      auto attr = primitive->GetAttr(input_name);
-      if (attr != nullptr) {
-        rectify_abs = attr->ToAbstract();
-      } else {
-        MS_LOG(DEBUG) << "the node prim name :" << primitive->name() << "input index :" << index
-                      << " input name :" << input_name << "has not been converted to the attr";
-        rectify_abs = input_abstract[ori_index++];
-      }
-      rectify_abs_list[index] = rectify_abs;
-      continue;
-    }
-    if (ori_index > input_abstract.size()) {
-      MS_LOG(EXCEPTION) << "Index " << ori_index << " is out of range in input abstract size " << input_abstract.size();
-    }
-    rectify_abs_list[index] = input_abstract[ori_index++];
-  }
-  return rectify_abs_list;
-}
 AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &prim,
                                                     const AbstractBasePtrList &input_abstract) {
   MS_EXCEPTION_IF_NULL(prim);
@@ -889,8 +837,7 @@ AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &prim,
 }
 AbstractBasePtrList RectifyAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_abstract) {
-  auto rectify_abs_list = RectifyAbstractFromRegAttr(primitive, input_abstract);
-  return RectifyAbstractFromDynamicInput(primitive, rectify_abs_list);
+  return RectifyAbstractFromDynamicInput(primitive, input_abstract);
 }
 }  // namespace
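The unit-test change in the last file follows directly from deleting RectifyAbstractFromRegAttr above: an input that was folded into an attribute is no longer re-expanded into the abstract list before C++ infer, so InferImplAttrTest now receives 2 abstracts instead of 3. The sketch below is a hypothetical, simplified rendering of what the removed rectification used to do; all names and the example op are invented for illustration and are not MindSpore APIs.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for an abstract value; just a descriptive tag for the demo.
using Abstract = std::string;

// Hypothetical, simplified version of the removed RectifyAbstractFromRegAttr:
// for every input index that was converted to an attr, re-insert the attr's
// abstract so the infer function sees the full-length argument list again.
std::vector<Abstract> RectifyFromRegisteredAttrs(const std::vector<std::string> &input_names,
                                                 const std::map<size_t, Abstract> &folded_attrs,
                                                 const std::vector<Abstract> &remaining_inputs) {
  std::vector<Abstract> rectified(input_names.size());
  size_t ori_index = 0;
  for (size_t index = 0; index < rectified.size(); ++index) {
    auto it = folded_attrs.find(index);
    rectified[index] = (it != folded_attrs.end()) ? it->second : remaining_inputs[ori_index++];
  }
  return rectified;
}

int main() {
  // Invented example: an op whose second input was converted to a tuple attr,
  // leaving two real tensor abstracts as inputs.
  std::vector<Abstract> remaining = {"Tensor(float32)", "Tensor(float32)"};
  auto old_view = RectifyFromRegisteredAttrs({"x", "shape", "y"}, {{1, "Tuple(int64)"}}, remaining);
  // old_view = {Tensor, Tuple, Tensor}: with the old pass the infer impl saw 3 args and a
  // tuple at index 1; after this commit it simply receives the 2 remaining tensor abstracts.
  std::cout << "old: " << old_view.size() << " args, new: " << remaining.size() << " args\n";
  return 0;
}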

View File

@@ -45,9 +45,10 @@ inline const PrimitivePtr kPrimAttrConvertTest = std::make_shared<Primitive>(kAt
 inline const PrimitivePtr kPrimDynamicInputTest = std::make_shared<Primitive>("dynamic_input_test");
 AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
-  EXPECT_EQ(args_spec_list.size(), 3);
+  // CppInferShapeAndType does not convert attr to input
+  EXPECT_EQ(args_spec_list.size(), 2);
   EXPECT_NE(args_spec_list[1], nullptr);
-  EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true);
+  EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTensor>(), true);
   return args_spec_list[0];
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr, kPrimAttrConvertTest, InferImplAttrTest, nullptr, true);