!48637 fixed init kernel_mod bug

Merge pull request !48637 from huoxinyou/0209kernel_mod
This commit is contained in:
i-robot 2023-02-13 06:05:44 +00:00 committed by Gitee
commit 75d84137c3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 20 additions and 6 deletions

View File

@ -113,6 +113,8 @@ class COMMON_EXPORT AnfAlgo {
static size_t GetInputNum(const CNodePtr &cnode);
// get the num of inputs exclude monads for real_kernel (which can be build and run in device)
static size_t GetInputTensorNum(const AnfNodePtr &node);
// get prev node output width output index has tuplegetitem
static bool IsPrevNodeHasTupleGetItem(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node = false);
// get prev node output width output index
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node = false);
// get all the untuple real prev_nodes output

View File

@ -81,9 +81,9 @@ abstract::AbstractBasePtr GetChildAbstract(const abstract::AbstractBasePtr &cur_
KernelTensorPtr CreateKernelTensor(const abstract::AbstractBasePtr &cur_abstract, const TypeId &real_type, size_t idx,
const ShapeVector &device_shape_adaptively, const std::string &format_str,
bool is_real_tuple_input = false) {
bool is_real_tuple = false) {
abstract::AbstractBasePtr tag_abstract = nullptr;
if (is_real_tuple_input) {
if (is_real_tuple) {
tag_abstract = cur_abstract;
} else {
tag_abstract = GetChildAbstract(cur_abstract, idx);
@ -163,17 +163,15 @@ inline InOutKernelTensors AbstractInOutFromCNode(const CNodePtr &cnode) {
std::vector<KernelTensorPtr> input_tensors;
auto real_input_types = AnfAlgo::GetAllInputDeviceTypes(cnode);
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
auto input_obj_types = build_info->GetAllInputKernelObjectTypes();
for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
bool is_real_tuple_input = CheckRealTupleFromCNode(input_obj_types, input_idx);
const auto &[prev_node, output_idx] = common::AnfAlgo::GetPrevNodeOutput(cnode, input_idx);
bool prev_node_has_getitem = common::AnfAlgo::IsPrevNodeHasTupleGetItem(cnode, input_idx);
auto prev_abstract = prev_node->abstract();
auto real_input_type = real_input_types[input_idx];
auto device_shape_adaptively = AnfAlgo::GetInputDeviceShapeAdaptively(cnode, input_idx);
auto format_str = AnfAlgo::GetInputFormat(cnode, input_idx);
auto input_tensor = CreateKernelTensor(prev_abstract, real_input_type, output_idx, device_shape_adaptively,
format_str, is_real_tuple_input);
format_str, !prev_node_has_getitem);
input_tensors.push_back(input_tensor);
}
@ -183,6 +181,7 @@ inline InOutKernelTensors AbstractInOutFromCNode(const CNodePtr &cnode) {
auto cur_abstract = cnode->abstract();
MS_EXCEPTION_IF_NULL(cur_abstract);
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
auto output_obj_types = build_info->GetAllOutputKernelObjectTypes();
for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
bool is_real_tuple_output = CheckRealTupleFromCNode(output_obj_types, output_idx);

View File

@ -542,6 +542,19 @@ size_t AnfAlgo::GetInputTensorNum(const AnfNodePtr &node) {
return AnfUtils::GetInputTensorNum(node);
}
bool AnfAlgo::IsPrevNodeHasTupleGetItem(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node) {
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode." << trace::DumpSourceLines(anf_node);
}
auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
MS_EXCEPTION_IF_NULL(input_node);
auto res = VisitKernelWithReturnType(input_node, 0, skip_nop_node, {prim::kPrimTupleGetItem});
if (CheckPrimitiveType(res.first, prim::kPrimTupleGetItem)) {
return true;
}
return false;
}
KernelWithIndex AnfAlgo::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node) {
MS_EXCEPTION_IF_NULL(anf_node);
if (!anf_node->isa<CNode>()) {