Decouple GraphKernelCluster from ME backend

Changed the callback function GetProcessorFromContext to GetTargetFromContext,
so that we can use it to filter the clusterable op list, and added a GetProcessorByTarget method to AkgKernelJsonGenerator.

Moved the functions IsKeepBasicNode, GetValidOps and OpListFilter from graph_kernel_helper
to graph_kernel_utils, and combined GetValidOps and OpListFilter into a single function.

Decoupled the pass getitem_tuple from "optimizer/common/helper.h" by deleting the input-size
check; cnode->input(i) already validates the input index.
This commit is contained in:
dayschan 2021-11-23 16:26:37 +08:00
parent 511441a27e
commit 2038295a25
15 changed files with 146 additions and 112 deletions

View File

@ -24,6 +24,7 @@
#include <algorithm>
#include "ir/func_graph.h"
#include "utils/anf_utils.h"
#include "utils/ms_context.h"
#include "backend/kernel_compiler/oplib/oplib.h"
namespace mindspore::graphkernel {
@ -581,6 +582,17 @@ OpInfoPtr AkgKernelJsonGenerator::ExtractOpInfo(const AnfNodePtr &anf_node) cons
}
}
// Maps the device target from context to the AKG processor string:
// GPU -> "cuda", Ascend -> "aicore", anything else -> "cpu".
std::string AkgKernelJsonGenerator::GetProcessorByTarget() const {
  const auto device_target = cb_->GetTargetFromContext();
  if (device_target == kAscendDevice) {
    return "aicore";
  }
  return device_target == kGPUDevice ? "cuda" : "cpu";
}
bool AkgKernelJsonGenerator::GenerateSingleKernelJson(const AnfNodePtr &anf_node, nlohmann::json *node_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(node_json);
@ -683,7 +695,7 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
(*kernel_json)[kJsonKeyId] = 0; // unused key
(*kernel_json)[kJsonKeyOp] = kernel_name_;
(*kernel_json)[kJsonKeyPlatform] = "AKG";
(*kernel_json)[kJsonKeyProcess] = this->cb_->GetProcessorFromContext();
(*kernel_json)[kJsonKeyProcess] = GetProcessorByTarget();
(*kernel_json)[kJsonKeyComposite] = false;
if (dump_option_.get_compute_capability) {
(*kernel_json)[kJsonKeyComputeCapability] = ComputeCapability::Get();
@ -782,7 +794,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
(*kernel_json)[kJsonKeyId] = 0; // unused key
(*kernel_json)[kJsonKeyOp] = kernel_name_;
(*kernel_json)[kJsonKeyPlatform] = "AKG";
(*kernel_json)[kJsonKeyProcess] = this->cb_->GetProcessorFromContext();
(*kernel_json)[kJsonKeyProcess] = GetProcessorByTarget();
(*kernel_json)[kJsonKeyComposite] = true;
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
if (dump_option_.get_compute_capability) {

View File

@ -141,6 +141,7 @@ class AkgKernelJsonGenerator {
const std::map<AnfNodePtr, nlohmann::json> &node_json_map, nlohmann::json *kernel_json);
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *node_json) const;
size_t GetTensorSize(const nlohmann::json &node_json) const;
std::string GetProcessorByTarget() const;
DumpOption dump_option_;
std::string kernel_name_;

View File

@ -20,6 +20,7 @@
#include <vector>
#include <utility>
#include <memory>
#include "utils/ms_context.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
@ -77,7 +78,11 @@ std::string CallbackImpl::GetOutputFormat(const AnfNodePtr &node, size_t i) {
std::string CallbackImpl::GetProcessor(const AnfNodePtr &node) { return kernel::GetProcessorStr(node); }
std::string CallbackImpl::GetProcessorFromContext() { return kernel::GetStrProcessorFromContext(); }
// Returns the device target string (e.g. "GPU", "Ascend", "CPU") read from the global MsContext.
std::string CallbackImpl::GetTargetFromContext() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  return ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
}
void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) {
std::vector<std::string> graph_input_format;

View File

@ -33,7 +33,7 @@ class CallbackImpl : public Callback {
std::string GetInputFormat(const AnfNodePtr &node, size_t i) override;
std::string GetOutputFormat(const AnfNodePtr &node, size_t i) override;
std::string GetProcessor(const AnfNodePtr &node) override;
std::string GetProcessorFromContext() override;
std::string GetTargetFromContext() override;
void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) override;
};
} // namespace mindspore::graphkernel

View File

@ -116,9 +116,9 @@ class Callback {
virtual std::string GetProcessor(const AnfNodePtr &node) = 0;
/**
* @brief Get the backend processor from context.
* @brief Get the backend target from context.
*/
virtual std::string GetProcessorFromContext() = 0;
virtual std::string GetTargetFromContext() = 0;
/**
* @brief Set KernelInfo for a GraphKernel node, the info is extract from its inputs/outputs.

View File

@ -14,9 +14,13 @@
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include <algorithm>
#include <memory>
#include <sstream>
#include "base/core_ops.h"
#include "utils/anf_utils.h"
#include "utils/utils.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h"
namespace mindspore::graphkernel {
std::string GkUtils::ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix,
@ -53,4 +57,64 @@ AnfNodePtrList GkUtils::SpreadTuples(const AnfNodePtrList &nodes, size_t begin_i
}
return result;
}
std::vector<PrimitivePtr> GkUtils::GetValidOps(const std::vector<OpWithLevel> &ops_with_level, unsigned int level,
const std::vector<std::string> &enable_ops_only,
const std::vector<std::string> &enable_ops,
const std::vector<std::string> &disable_ops) {
std::vector<PrimitivePtr> ops;
auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
if (!enable_ops_only.empty()) {
(void)std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::back_inserter(ops), new_prim);
return ops;
}
auto target = Callback::Instance()->GetTargetFromContext();
for (const auto &[op_target, op_level, op] : ops_with_level) {
if (op_target == kAllTarget || op_target == target) {
if (level >= op_level) {
(void)ops.emplace_back(op);
}
}
}
if (!enable_ops.empty()) {
(void)std::transform(enable_ops.begin(), enable_ops.end(), std::back_inserter(ops), new_prim);
}
if (!disable_ops.empty()) {
auto iter = std::remove_if(ops.begin(), ops.end(), [&disable_ops](const PrimitivePtr &p) {
return std::find(disable_ops.begin(), disable_ops.end(), p->name()) != disable_ops.end();
});
(void)ops.erase(iter, ops.end());
}
return ops;
}
bool GkUtils::IsKeepBasicNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto get_bool_attr = [](const PrimitivePtr &primitive, const std::string &attr_name) -> bool {
if (!primitive->HasAttr(attr_name)) {
return false;
}
return GetValue<bool>(primitive->GetAttr(attr_name));
};
auto prim = GetCNodePrimitive(node);
if (prim == nullptr) return false;
// Dynamic shape is unsupported yet.
if (get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape) ||
get_bool_attr(prim, kAttrIsDynamicShape)) {
return true;
}
if (get_bool_attr(prim, "skip")) {
return true;
}
// If node contain attribute in contagious_attrs, it have to keep basic no matter what the value is.
const std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index",
"aggregate", "aggregate_input_indexx"};
if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(),
[&prim](const std::string &attr_name) -> bool { return prim->HasAttr(attr_name); })) {
return true;
}
return false;
}
} // namespace mindspore::graphkernel

View File

@ -18,10 +18,17 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_
#include <string>
#include <tuple>
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"
namespace mindspore::graphkernel {
constexpr auto kGraphKernelDumpPath = "graph_kernel_dump";
constexpr auto kAllTarget = "ALL";
using OpWithLevel = std::tuple<std::string, unsigned int, PrimitivePtr>;
class GkUtils {
public:
/**
@ -47,6 +54,25 @@ class GkUtils {
* @return std::vector<AnfNodePtr>
*/
static AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);
/**
* @brief Filter operators by target, op level, and enable/disable flags.
* @param[in] ops_with_level the default operator list
* @param[in] level enabled op level
* @param[in] enable_ops_only the "enable_xxx_ops_only" flag
* @param[in] enable_ops the "enable_xxx_ops" flag
* @param[in] disable_ops the "disable_xxx_ops" flag
* @return Available primitive list
*/
static std::vector<PrimitivePtr> GetValidOps(const std::vector<OpWithLevel> &ops_with_level, unsigned int level,
const std::vector<std::string> &enable_ops_only,
const std::vector<std::string> &enable_ops,
const std::vector<std::string> &disable_ops);
/**
* @brief Check whether graphkernel supports the node
*/
static bool IsKeepBasicNode(const AnfNodePtr &node);
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_

View File

@ -16,28 +16,27 @@
#include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
#include <algorithm>
#include <map>
#include <unordered_map>
#include <set>
#include <vector>
#include <tuple>
#include <memory>
#include <utility>
#include <fstream>
#include <string>
#include "base/core_ops.h"
#include "ir/graph_utils.h"
#include "utils/anf_utils.h"
#include "utils/ms_context.h"
#include "utils/file_utils.h"
#include "utils/context/graph_kernel_flags.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
namespace mindspore::graphkernel {
using opt::GetitemTuple;
std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> clusterable_ops_with_level = {
std::vector<OpWithLevel> clusterable_ops_with_level = {
// all target
{kAllTarget, OpLevel_0, prim::kPrimAbs},
{kAllTarget, OpLevel_0, prim::kPrimAdd},
@ -100,19 +99,17 @@ std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
{kGPUDevice, OpLevel_0, prim::kPrimSign},
{kGPUDevice, OpLevel_0, prim::kPrimSin},
{kGPUDevice, OpLevel_0, prim::kPrimStridedSlice},
{kGPUDevice, OpLevel_0, prim::kPrimUserDefined},
};
const auto &flags = GraphKernelFlags::GetInstance();
std::vector<PrimitivePtr> clusterable_ops = GetValidOps(clusterable_ops_with_level, flags.fusion_ops_level);
OpListFilter(&clusterable_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops);
return clusterable_ops;
return GkUtils::GetValidOps(clusterable_ops_with_level, flags.fusion_ops_level, flags.enable_cluster_ops_only,
flags.enable_cluster_ops, flags.disable_cluster_ops);
}
bool GraphKernelCluster::IsClusterableOp(const AnfNodePtr &node) {
if (AnfAlgo::IsGraphKernel(node)) {
if (AnfUtils::IsGraphKernel(node)) {
return true;
}
if (IsKeepBasicNode(node)) {
if (GkUtils::IsKeepBasicNode(node)) {
return false;
}
bool node_in_oplist = std::any_of(op_list_.begin(), op_list_.end(),
@ -120,12 +117,15 @@ bool GraphKernelCluster::IsClusterableOp(const AnfNodePtr &node) {
if (!node_in_oplist) {
return false;
}
#if ENABLE_D
// For AICPU operators, only the Reshape can be clustered.
if (AnfAlgo::GetProcessor(node) != kernel::Processor::AICORE && !IsPrimitiveCNode(node, prim::kPrimReshape)) {
return false;
auto cb = Callback::Instance();
MS_EXCEPTION_IF_NULL(cb);
if (cb->GetTargetFromContext() == kAscendDevice) {
if (cb->GetProcessor(node) != "aicore" && !IsPrimitiveCNode(node, prim::kPrimReshape)) {
return false;
}
}
#endif
return true;
}
@ -393,7 +393,7 @@ bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) {
// Do not cluster a single GraphKernel again.
// Do not cluster a single Assign.
const auto &node = nodes_[clusters[i][0]];
if (AnfAlgo::IsGraphKernel(node) || IsPrimitiveCNode(node, prim::kPrimAssign) || !IsClusterableOp(node)) {
if (AnfUtils::IsGraphKernel(node) || IsPrimitiveCNode(node, prim::kPrimAssign) || !IsClusterableOp(node)) {
continue;
}
}
@ -408,8 +408,8 @@ void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const s
(void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
[this](size_t id) { return this->nodes_[id]; });
auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion");
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
(void)eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(new_node));
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<GetitemTuple>();
(void)eliminate_getitem_pass->Run(GetCNodeFuncGraph(new_node));
if (GraphKernelFlags::GetInstance().dump_as_text) {
DumpClusterInfo(old_nodes, new_node);
}

View File

@ -17,19 +17,13 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CLUSTER_H_
#include <vector>
#include <string>
#include <unordered_map>
#include <set>
#include <memory>
#include "ir/anf.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace prim {
inline const PrimitivePtr kPrimUserDefined = std::make_shared<Primitive>("UserDefined");
}
namespace graphkernel {
class Graph;
using GraphPtr = std::shared_ptr<Graph>;

View File

@ -29,6 +29,7 @@
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include "backend/optimizer/graph_kernel/split_umonad.h"
#include "backend/optimizer/graph_kernel/substitute_dropout.h"
#include "backend/session/anf_runtime_algorithm.h"
@ -47,7 +48,7 @@ constexpr size_t kLambWeightInputIdx = 4;
constexpr size_t kRandomInputIdx = 1;
std::vector<PrimitivePtr> GetExpandOps() {
std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> expand_ops_with_level = {
std::vector<OpWithLevel> expand_ops_with_level = {
{kAllTarget, OpLevel_0, prim::kPrimAddN},
{kAllTarget, OpLevel_0, prim::kPrimAssignAdd},
{kAllTarget, OpLevel_0, prim::kPrimErfc},
@ -95,9 +96,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
{kGPUDevice, OpLevel_0, prim::kPrimStandardNormal},
};
const auto &flags = GraphKernelFlags::GetInstance();
std::vector<PrimitivePtr> expand_ops = GetValidOps(expand_ops_with_level, flags.fusion_ops_level);
OpListFilter(&expand_ops, flags.enable_expand_ops_only, flags.enable_expand_ops, flags.disable_expand_ops);
return expand_ops;
return GkUtils::GetValidOps(expand_ops_with_level, flags.fusion_ops_level, flags.enable_expand_ops_only,
flags.enable_expand_ops, flags.disable_expand_ops);
}
} // namespace
@ -217,8 +217,8 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(mng);
for (const auto &n : todos) {
auto node = n->cast<CNodePtr>();
if (node == nullptr || AnfAlgo::IsGraphKernel(node) || IsKeepBasicNode(node) || !AnfUtils::IsRealKernel(node) ||
!CanExpand(node)) {
if (node == nullptr || AnfAlgo::IsGraphKernel(node) || GkUtils::IsKeepBasicNode(node) ||
!AnfUtils::IsRealKernel(node) || !CanExpand(node)) {
continue;
}

View File

@ -444,51 +444,6 @@ void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfN
AnfAlgo::SetNodeAttr(key, value, node);
}
// Checks whether the node must be kept as a basic (non-fused) node, i.e. it
// cannot be clustered or expanded into a GraphKernel subgraph.
bool IsKeepBasicNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Dynamic shape is unsupported yet.
  if (AnfAlgo::HasDynamicShapeFlag(AnfAlgo::GetCNodePrimitive(cnode))) {
    return true;
  }
  static const std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index",
                                                            "aggregate", "aggregate_input_indexx"};
  // If the node contains any attribute in contagious_attrs, it has to stay basic no matter what the value is.
  if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(),
                  [&cnode](const std::string &attr_name) -> bool { return AnfAlgo::HasNodeAttr(attr_name, cnode); })) {
    return true;
  }
  // Nodes explicitly marked with a true "skip" attribute stay basic as well.
  if (AnfAlgo::GetBooleanAttr(cnode, "skip")) {
    return true;
  }
  return false;
}
// Adjusts *ops in place according to the user flags:
// - enable_ops_only: when non-empty, it overrides everything else and the
//   result contains exactly these ops;
// - enable_ops: appended to the existing list;
// - disable_ops: removed from the list (applied after enable_ops).
void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
                  const std::vector<std::string> &enable_ops, const std::vector<std::string> &disable_ops) {
  // Builds a Primitive from an op name given by a flag.
  auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
  if (!enable_ops_only.empty()) {
    ops->clear();
    (void)std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::back_inserter(*ops), new_prim);
  } else {
    if (!enable_ops.empty()) {
      (void)std::transform(enable_ops.begin(), enable_ops.end(), std::back_inserter(*ops), new_prim);
    }
    if (!disable_ops.empty()) {
      // erase-remove idiom: drop every op whose name appears in disable_ops.
      auto iter = std::remove_if(ops->begin(), ops->end(), [&disable_ops](const PrimitivePtr &p) {
        return std::find(disable_ops.begin(), disable_ops.end(), p->name()) != disable_ops.end();
      });
      (void)ops->erase(iter, ops->end());
    }
  }
}
inner::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph) {
inner::LiteGraph::GraphBuilder gb(GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)));
std::map<AnfNodePtr, inner::NodePtr> node_map;
@ -632,22 +587,6 @@ void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList
*inputs = std::move(new_inputs);
}
// Collects the ops whose target matches the current device target (or "ALL")
// and whose op level does not exceed the given fusion level.
std::vector<PrimitivePtr> GetValidOps(
  const std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> &ops_with_level, unsigned int level) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  // Current device target read from context, e.g. "GPU", "Ascend" or "CPU".
  std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  std::vector<PrimitivePtr> valid_ops;
  for (const auto &[op_target, op_level, op] : ops_with_level) {
    if (op_target == kAllTarget || op_target == target) {
      if (level >= op_level) {
        (void)valid_ops.emplace_back(op);
      }
    }
  }
  return valid_ops;
}
FuncGraphManagerPtr GetFuncGraphManager(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphManagerPtr manager = func_graph->manager();

View File

@ -44,9 +44,7 @@ constexpr auto kGetGraphKernelOpExpander = "get_op_expander";
constexpr auto kJsonKeyMultiGraph = "multi_graph";
constexpr auto kJsonKeyGraphDesc = "graph_desc";
constexpr auto kJsonKeyGraphMode = "graph_mode";
constexpr auto kAllTarget = "ALL";
constexpr auto kGraphKernelDumpPath = "graph_kernel_dump";
struct DataInfo {
std::string format{kOpFormat_DEFAULT};
ShapeVector shape{1};
@ -79,9 +77,6 @@ std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info,
bool use_fake_abstract = false);
void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
bool IsKeepBasicNode(const AnfNodePtr &node);
void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
const std::vector<std::string> &enable_ops, const std::vector<std::string> &disable_ops);
template <typename T>
ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {
// Create tensor value.
@ -127,9 +122,6 @@ FuncGraphPtr LiteGraph2AnfGraph(const inner::LiteGraphPtr &lite_graph, AnfNodePt
// remove parameter which is not used
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
std::vector<PrimitivePtr> GetValidOps(
const std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> &ops_with_level, unsigned int level);
// return a func_graph's manager
FuncGraphManagerPtr GetFuncGraphManager(const FuncGraphPtr &func_graph);

View File

@ -119,5 +119,7 @@ std::string CallbackImpl::GetOutputFormat(const AnfNodePtr &node, size_t i) { re
// The lite adapter targets CPU only, so the processor/target queries return constants.
std::string CallbackImpl::GetProcessor(const AnfNodePtr &node) { return "cpu"; }
std::string CallbackImpl::GetProcessorFromContext() { return "cpu"; }
std::string CallbackImpl::GetTargetFromContext() { return "CPU"; }
// Kernel info is not used by the lite adapter; intentionally a no-op.
void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) {}
} // namespace mindspore::graphkernel

View File

@ -35,7 +35,8 @@ class CallbackImpl : public Callback {
std::string GetInputFormat(const AnfNodePtr &node, size_t i) override;
std::string GetOutputFormat(const AnfNodePtr &node, size_t i) override;
std::string GetProcessor(const AnfNodePtr &node) override;
std::string GetProcessorFromContext() override;
std::string GetTargetFromContext() override;
void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) override;
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_LITE_ADAPTER_CALLBACK_IMPL_H_

View File

@ -18,7 +18,6 @@
#include <memory>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
@ -45,7 +44,6 @@ const AnfNodePtr GetitemTuple::Process(const FuncGraphPtr &, const AnfNodePtr &n
MS_EXCEPTION_IF_NULL(node);
auto tuple_getitem = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem);
CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum);
AnfNodePtr make_tuple_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(make_tuple_anf);
AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);