diff --git a/mindspore/ccsrc/common/graph_kernel/adapter/callback_impl.h b/mindspore/ccsrc/common/graph_kernel/adapter/callback_impl.h index 9dc2f8ce2ab..b5da34e6486 100644 --- a/mindspore/ccsrc/common/graph_kernel/adapter/callback_impl.h +++ b/mindspore/ccsrc/common/graph_kernel/adapter/callback_impl.h @@ -40,6 +40,8 @@ class COMMON_EXPORT CallbackImpl : public Callback { void SetBasicNodeKernelInfo(const AnfNodePtr &node, const std::vector &outputs_info) override; void SetEmptyKernelInfo(const AnfNodePtr &node) override; void ResetKernelInfo(const AnfNodePtr &node) override; + + private: void CollectInputTypesAndFormats(const AnfNodePtr &node, std::vector *input_types, std::vector *input_formats); }; diff --git a/mindspore/ccsrc/common/graph_kernel/reshape_reduce_for_cse.cc b/mindspore/ccsrc/common/graph_kernel/reshape_reduce_for_cse.cc index 9ddcb8f2fe2..b9d4046a0fd 100644 --- a/mindspore/ccsrc/common/graph_kernel/reshape_reduce_for_cse.cc +++ b/mindspore/ccsrc/common/graph_kernel/reshape_reduce_for_cse.cc @@ -57,6 +57,7 @@ void InsertReshape(const FuncGraphPtr &graph, const AnfNodePtr &node, const Type MS_EXCEPTION_IF_NULL(reshape); common::AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(infer_shape), reshape); common::AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {infer_shape}, reshape.get()); + reshape->set_kernel_info(std::make_shared()); auto graph_sel_info = BuildSelectKernelBuildInfo({kOpFormat_DEFAULT}, {device_type}, {kOpFormat_DEFAULT}, {device_type}); AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, reshape.get());