forked from mindspore-Ecosystem/mindspore
!48637 fixed init kernel_mod bug
Merge pull request !48637 from huoxinyou/0209kernel_mod
This commit is contained in:
commit
75d84137c3
|
@ -113,6 +113,8 @@ class COMMON_EXPORT AnfAlgo {
|
|||
static size_t GetInputNum(const CNodePtr &cnode);
|
||||
// get the num of inputs exclude monads for real_kernel (which can be build and run in device)
|
||||
static size_t GetInputTensorNum(const AnfNodePtr &node);
|
||||
// get prev node output width output index has tuplegetitem
|
||||
static bool IsPrevNodeHasTupleGetItem(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node = false);
|
||||
// get prev node output width output index
|
||||
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node = false);
|
||||
// get all the untuple real prev_nodes output
|
||||
|
|
|
@ -81,9 +81,9 @@ abstract::AbstractBasePtr GetChildAbstract(const abstract::AbstractBasePtr &cur_
|
|||
|
||||
KernelTensorPtr CreateKernelTensor(const abstract::AbstractBasePtr &cur_abstract, const TypeId &real_type, size_t idx,
|
||||
const ShapeVector &device_shape_adaptively, const std::string &format_str,
|
||||
bool is_real_tuple_input = false) {
|
||||
bool is_real_tuple = false) {
|
||||
abstract::AbstractBasePtr tag_abstract = nullptr;
|
||||
if (is_real_tuple_input) {
|
||||
if (is_real_tuple) {
|
||||
tag_abstract = cur_abstract;
|
||||
} else {
|
||||
tag_abstract = GetChildAbstract(cur_abstract, idx);
|
||||
|
@ -163,17 +163,15 @@ inline InOutKernelTensors AbstractInOutFromCNode(const CNodePtr &cnode) {
|
|||
std::vector<KernelTensorPtr> input_tensors;
|
||||
auto real_input_types = AnfAlgo::GetAllInputDeviceTypes(cnode);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
|
||||
auto input_obj_types = build_info->GetAllInputKernelObjectTypes();
|
||||
for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
|
||||
bool is_real_tuple_input = CheckRealTupleFromCNode(input_obj_types, input_idx);
|
||||
const auto &[prev_node, output_idx] = common::AnfAlgo::GetPrevNodeOutput(cnode, input_idx);
|
||||
bool prev_node_has_getitem = common::AnfAlgo::IsPrevNodeHasTupleGetItem(cnode, input_idx);
|
||||
auto prev_abstract = prev_node->abstract();
|
||||
auto real_input_type = real_input_types[input_idx];
|
||||
auto device_shape_adaptively = AnfAlgo::GetInputDeviceShapeAdaptively(cnode, input_idx);
|
||||
auto format_str = AnfAlgo::GetInputFormat(cnode, input_idx);
|
||||
auto input_tensor = CreateKernelTensor(prev_abstract, real_input_type, output_idx, device_shape_adaptively,
|
||||
format_str, is_real_tuple_input);
|
||||
format_str, !prev_node_has_getitem);
|
||||
input_tensors.push_back(input_tensor);
|
||||
}
|
||||
|
||||
|
@ -183,6 +181,7 @@ inline InOutKernelTensors AbstractInOutFromCNode(const CNodePtr &cnode) {
|
|||
auto cur_abstract = cnode->abstract();
|
||||
MS_EXCEPTION_IF_NULL(cur_abstract);
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
|
||||
auto output_obj_types = build_info->GetAllOutputKernelObjectTypes();
|
||||
for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
|
||||
bool is_real_tuple_output = CheckRealTupleFromCNode(output_obj_types, output_idx);
|
||||
|
|
|
@ -542,6 +542,19 @@ size_t AnfAlgo::GetInputTensorNum(const AnfNodePtr &node) {
|
|||
return AnfUtils::GetInputTensorNum(node);
|
||||
}
|
||||
|
||||
bool AnfAlgo::IsPrevNodeHasTupleGetItem(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node) {
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode." << trace::DumpSourceLines(anf_node);
|
||||
}
|
||||
auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
auto res = VisitKernelWithReturnType(input_node, 0, skip_nop_node, {prim::kPrimTupleGetItem});
|
||||
if (CheckPrimitiveType(res.first, prim::kPrimTupleGetItem)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
KernelWithIndex AnfAlgo::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
|
|
Loading…
Reference in New Issue