forked from mindspore-Ecosystem/mindspore
!26680 Decouple GraphKernelCluster from ME backend
Merge pull request !26680 from DeshiChen/1122_cluster
This commit is contained in:
commit
3d0b785241
|
@ -24,6 +24,7 @@
|
|||
#include <algorithm>
|
||||
#include "ir/func_graph.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
|
@ -581,6 +582,17 @@ OpInfoPtr AkgKernelJsonGenerator::ExtractOpInfo(const AnfNodePtr &anf_node) cons
|
|||
}
|
||||
}
|
||||
|
||||
std::string AkgKernelJsonGenerator::GetProcessorByTarget() const {
|
||||
auto target = cb_->GetTargetFromContext();
|
||||
if (target == kGPUDevice) {
|
||||
return "cuda";
|
||||
}
|
||||
if (target == kAscendDevice) {
|
||||
return "aicore";
|
||||
}
|
||||
return "cpu";
|
||||
}
|
||||
|
||||
bool AkgKernelJsonGenerator::GenerateSingleKernelJson(const AnfNodePtr &anf_node, nlohmann::json *node_json) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(node_json);
|
||||
|
@ -683,7 +695,7 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
|
|||
(*kernel_json)[kJsonKeyId] = 0; // unused key
|
||||
(*kernel_json)[kJsonKeyOp] = kernel_name_;
|
||||
(*kernel_json)[kJsonKeyPlatform] = "AKG";
|
||||
(*kernel_json)[kJsonKeyProcess] = this->cb_->GetProcessorFromContext();
|
||||
(*kernel_json)[kJsonKeyProcess] = GetProcessorByTarget();
|
||||
(*kernel_json)[kJsonKeyComposite] = false;
|
||||
if (dump_option_.get_compute_capability) {
|
||||
(*kernel_json)[kJsonKeyComputeCapability] = ComputeCapability::Get();
|
||||
|
@ -782,7 +794,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
|
|||
(*kernel_json)[kJsonKeyId] = 0; // unused key
|
||||
(*kernel_json)[kJsonKeyOp] = kernel_name_;
|
||||
(*kernel_json)[kJsonKeyPlatform] = "AKG";
|
||||
(*kernel_json)[kJsonKeyProcess] = this->cb_->GetProcessorFromContext();
|
||||
(*kernel_json)[kJsonKeyProcess] = GetProcessorByTarget();
|
||||
(*kernel_json)[kJsonKeyComposite] = true;
|
||||
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
|
||||
if (dump_option_.get_compute_capability) {
|
||||
|
|
|
@ -141,6 +141,7 @@ class AkgKernelJsonGenerator {
|
|||
const std::map<AnfNodePtr, nlohmann::json> &node_json_map, nlohmann::json *kernel_json);
|
||||
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *node_json) const;
|
||||
size_t GetTensorSize(const nlohmann::json &node_json) const;
|
||||
std::string GetProcessorByTarget() const;
|
||||
|
||||
DumpOption dump_option_;
|
||||
std::string kernel_name_;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <vector>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
|
||||
|
@ -77,7 +78,11 @@ std::string CallbackImpl::GetOutputFormat(const AnfNodePtr &node, size_t i) {
|
|||
|
||||
std::string CallbackImpl::GetProcessor(const AnfNodePtr &node) { return kernel::GetProcessorStr(node); }
|
||||
|
||||
std::string CallbackImpl::GetProcessorFromContext() { return kernel::GetStrProcessorFromContext(); }
|
||||
std::string CallbackImpl::GetTargetFromContext() {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
return context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
}
|
||||
|
||||
void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) {
|
||||
std::vector<std::string> graph_input_format;
|
||||
|
|
|
@ -33,7 +33,7 @@ class CallbackImpl : public Callback {
|
|||
std::string GetInputFormat(const AnfNodePtr &node, size_t i) override;
|
||||
std::string GetOutputFormat(const AnfNodePtr &node, size_t i) override;
|
||||
std::string GetProcessor(const AnfNodePtr &node) override;
|
||||
std::string GetProcessorFromContext() override;
|
||||
std::string GetTargetFromContext() override;
|
||||
void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) override;
|
||||
};
|
||||
} // namespace mindspore::graphkernel
|
||||
|
|
|
@ -116,9 +116,9 @@ class Callback {
|
|||
virtual std::string GetProcessor(const AnfNodePtr &node) = 0;
|
||||
|
||||
/**
|
||||
* @brief Get the backend processor from context.
|
||||
* @brief Get the backend target from context.
|
||||
*/
|
||||
virtual std::string GetProcessorFromContext() = 0;
|
||||
virtual std::string GetTargetFromContext() = 0;
|
||||
|
||||
/**
|
||||
* @brief Set KernelInfo for a GraphKernel node, the info is extract from its inputs/outputs.
|
||||
|
|
|
@ -14,9 +14,13 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
std::string GkUtils::ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix,
|
||||
|
@ -53,4 +57,64 @@ AnfNodePtrList GkUtils::SpreadTuples(const AnfNodePtrList &nodes, size_t begin_i
|
|||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<PrimitivePtr> GkUtils::GetValidOps(const std::vector<OpWithLevel> &ops_with_level, unsigned int level,
|
||||
const std::vector<std::string> &enable_ops_only,
|
||||
const std::vector<std::string> &enable_ops,
|
||||
const std::vector<std::string> &disable_ops) {
|
||||
std::vector<PrimitivePtr> ops;
|
||||
auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
|
||||
if (!enable_ops_only.empty()) {
|
||||
(void)std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::back_inserter(ops), new_prim);
|
||||
return ops;
|
||||
}
|
||||
auto target = Callback::Instance()->GetTargetFromContext();
|
||||
for (const auto &[op_target, op_level, op] : ops_with_level) {
|
||||
if (op_target == kAllTarget || op_target == target) {
|
||||
if (level >= op_level) {
|
||||
(void)ops.emplace_back(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!enable_ops.empty()) {
|
||||
(void)std::transform(enable_ops.begin(), enable_ops.end(), std::back_inserter(ops), new_prim);
|
||||
}
|
||||
if (!disable_ops.empty()) {
|
||||
auto iter = std::remove_if(ops.begin(), ops.end(), [&disable_ops](const PrimitivePtr &p) {
|
||||
return std::find(disable_ops.begin(), disable_ops.end(), p->name()) != disable_ops.end();
|
||||
});
|
||||
(void)ops.erase(iter, ops.end());
|
||||
}
|
||||
return ops;
|
||||
}
|
||||
|
||||
bool GkUtils::IsKeepBasicNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto get_bool_attr = [](const PrimitivePtr &primitive, const std::string &attr_name) -> bool {
|
||||
if (!primitive->HasAttr(attr_name)) {
|
||||
return false;
|
||||
}
|
||||
return GetValue<bool>(primitive->GetAttr(attr_name));
|
||||
};
|
||||
auto prim = GetCNodePrimitive(node);
|
||||
if (prim == nullptr) return false;
|
||||
|
||||
// Dynamic shape is unsupported yet.
|
||||
if (get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape) ||
|
||||
get_bool_attr(prim, kAttrIsDynamicShape)) {
|
||||
return true;
|
||||
}
|
||||
if (get_bool_attr(prim, "skip")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If node contain attribute in contagious_attrs, it have to keep basic no matter what the value is.
|
||||
const std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index",
|
||||
"aggregate", "aggregate_input_indexx"};
|
||||
if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(),
|
||||
[&prim](const std::string &attr_name) -> bool { return prim->HasAttr(attr_name); })) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace mindspore::graphkernel
|
||||
|
|
|
@ -18,10 +18,17 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
constexpr auto kGraphKernelDumpPath = "graph_kernel_dump";
|
||||
constexpr auto kAllTarget = "ALL";
|
||||
|
||||
using OpWithLevel = std::tuple<std::string, unsigned int, PrimitivePtr>;
|
||||
|
||||
class GkUtils {
|
||||
public:
|
||||
/**
|
||||
|
@ -47,6 +54,25 @@ class GkUtils {
|
|||
* @return std::vector<AnfNodePtr>
|
||||
*/
|
||||
static AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);
|
||||
|
||||
/**
|
||||
* @brief Filter operators by target, op level, and enable/disable flags.
|
||||
* @param[in] ops_with_level the default operator list
|
||||
* @param[in] level enabled op level
|
||||
* @param[in] enable_ops_only the "enable_xxx_ops_only" flag
|
||||
* @param[in] enable_ops the "enable_xxx_ops" flag
|
||||
* @param[in] disable_ops the "disable_xxx_ops" flag
|
||||
* @return Available primitive list
|
||||
*/
|
||||
static std::vector<PrimitivePtr> GetValidOps(const std::vector<OpWithLevel> &ops_with_level, unsigned int level,
|
||||
const std::vector<std::string> &enable_ops_only,
|
||||
const std::vector<std::string> &enable_ops,
|
||||
const std::vector<std::string> &disable_ops);
|
||||
|
||||
/**
|
||||
* @brief Check whether graphkernel supports the node
|
||||
*/
|
||||
static bool IsKeepBasicNode(const AnfNodePtr &node);
|
||||
};
|
||||
} // namespace mindspore::graphkernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_
|
||||
|
|
|
@ -16,28 +16,27 @@
|
|||
#include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
#include "base/core_ops.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/context/graph_kernel_flags.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
#include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h"
|
||||
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
|
||||
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
using opt::GetitemTuple;
|
||||
|
||||
std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
|
||||
std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> clusterable_ops_with_level = {
|
||||
std::vector<OpWithLevel> clusterable_ops_with_level = {
|
||||
// all target
|
||||
{kAllTarget, OpLevel_0, prim::kPrimAbs},
|
||||
{kAllTarget, OpLevel_0, prim::kPrimAdd},
|
||||
|
@ -100,19 +99,17 @@ std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
|
|||
{kGPUDevice, OpLevel_0, prim::kPrimSign},
|
||||
{kGPUDevice, OpLevel_0, prim::kPrimSin},
|
||||
{kGPUDevice, OpLevel_0, prim::kPrimStridedSlice},
|
||||
{kGPUDevice, OpLevel_0, prim::kPrimUserDefined},
|
||||
};
|
||||
const auto &flags = GraphKernelFlags::GetInstance();
|
||||
std::vector<PrimitivePtr> clusterable_ops = GetValidOps(clusterable_ops_with_level, flags.fusion_ops_level);
|
||||
OpListFilter(&clusterable_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops);
|
||||
return clusterable_ops;
|
||||
return GkUtils::GetValidOps(clusterable_ops_with_level, flags.fusion_ops_level, flags.enable_cluster_ops_only,
|
||||
flags.enable_cluster_ops, flags.disable_cluster_ops);
|
||||
}
|
||||
|
||||
bool GraphKernelCluster::IsClusterableOp(const AnfNodePtr &node) {
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
if (AnfUtils::IsGraphKernel(node)) {
|
||||
return true;
|
||||
}
|
||||
if (IsKeepBasicNode(node)) {
|
||||
if (GkUtils::IsKeepBasicNode(node)) {
|
||||
return false;
|
||||
}
|
||||
bool node_in_oplist = std::any_of(op_list_.begin(), op_list_.end(),
|
||||
|
@ -120,12 +117,15 @@ bool GraphKernelCluster::IsClusterableOp(const AnfNodePtr &node) {
|
|||
if (!node_in_oplist) {
|
||||
return false;
|
||||
}
|
||||
#if ENABLE_D
|
||||
|
||||
// For AICPU operators, only the Reshape can be clustered.
|
||||
if (AnfAlgo::GetProcessor(node) != kernel::Processor::AICORE && !IsPrimitiveCNode(node, prim::kPrimReshape)) {
|
||||
return false;
|
||||
auto cb = Callback::Instance();
|
||||
MS_EXCEPTION_IF_NULL(cb);
|
||||
if (cb->GetTargetFromContext() == kAscendDevice) {
|
||||
if (cb->GetProcessor(node) != "aicore" && !IsPrimitiveCNode(node, prim::kPrimReshape)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -393,7 +393,7 @@ bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) {
|
|||
// Do not cluster a single GraphKernel again.
|
||||
// Do not cluster a single Assign.
|
||||
const auto &node = nodes_[clusters[i][0]];
|
||||
if (AnfAlgo::IsGraphKernel(node) || IsPrimitiveCNode(node, prim::kPrimAssign) || !IsClusterableOp(node)) {
|
||||
if (AnfUtils::IsGraphKernel(node) || IsPrimitiveCNode(node, prim::kPrimAssign) || !IsClusterableOp(node)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
@ -408,8 +408,8 @@ void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const s
|
|||
(void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
|
||||
[this](size_t id) { return this->nodes_[id]; });
|
||||
auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion");
|
||||
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
|
||||
(void)eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(new_node));
|
||||
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<GetitemTuple>();
|
||||
(void)eliminate_getitem_pass->Run(GetCNodeFuncGraph(new_node));
|
||||
if (GraphKernelFlags::GetInstance().dump_as_text) {
|
||||
DumpClusterInfo(old_nodes, new_node);
|
||||
}
|
||||
|
|
|
@ -17,19 +17,13 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CLUSTER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace prim {
|
||||
inline const PrimitivePtr kPrimUserDefined = std::make_shared<Primitive>("UserDefined");
|
||||
}
|
||||
|
||||
namespace graphkernel {
|
||||
class Graph;
|
||||
using GraphPtr = std::shared_ptr<Graph>;
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
|
||||
#include "backend/optimizer/graph_kernel/split_umonad.h"
|
||||
#include "backend/optimizer/graph_kernel/substitute_dropout.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
@ -47,7 +48,7 @@ constexpr size_t kLambWeightInputIdx = 4;
|
|||
constexpr size_t kRandomInputIdx = 1;
|
||||
|
||||
std::vector<PrimitivePtr> GetExpandOps() {
|
||||
std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> expand_ops_with_level = {
|
||||
std::vector<OpWithLevel> expand_ops_with_level = {
|
||||
{kAllTarget, OpLevel_0, prim::kPrimAddN},
|
||||
{kAllTarget, OpLevel_0, prim::kPrimAssignAdd},
|
||||
{kAllTarget, OpLevel_0, prim::kPrimErfc},
|
||||
|
@ -95,9 +96,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
|
|||
{kGPUDevice, OpLevel_0, prim::kPrimStandardNormal},
|
||||
};
|
||||
const auto &flags = GraphKernelFlags::GetInstance();
|
||||
std::vector<PrimitivePtr> expand_ops = GetValidOps(expand_ops_with_level, flags.fusion_ops_level);
|
||||
OpListFilter(&expand_ops, flags.enable_expand_ops_only, flags.enable_expand_ops, flags.disable_expand_ops);
|
||||
return expand_ops;
|
||||
return GkUtils::GetValidOps(expand_ops_with_level, flags.fusion_ops_level, flags.enable_expand_ops_only,
|
||||
flags.enable_expand_ops, flags.disable_expand_ops);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -217,8 +217,8 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
|
|||
MS_EXCEPTION_IF_NULL(mng);
|
||||
for (const auto &n : todos) {
|
||||
auto node = n->cast<CNodePtr>();
|
||||
if (node == nullptr || AnfAlgo::IsGraphKernel(node) || IsKeepBasicNode(node) || !AnfUtils::IsRealKernel(node) ||
|
||||
!CanExpand(node)) {
|
||||
if (node == nullptr || AnfAlgo::IsGraphKernel(node) || GkUtils::IsKeepBasicNode(node) ||
|
||||
!AnfUtils::IsRealKernel(node) || !CanExpand(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -444,51 +444,6 @@ void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfN
|
|||
AnfAlgo::SetNodeAttr(key, value, node);
|
||||
}
|
||||
|
||||
bool IsKeepBasicNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
// Dynamic shape is unsupported yet.
|
||||
if (AnfAlgo::HasDynamicShapeFlag(AnfAlgo::GetCNodePrimitive(cnode))) {
|
||||
return true;
|
||||
}
|
||||
|
||||
static const std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index",
|
||||
"aggregate", "aggregate_input_indexx"};
|
||||
// If node contain attribute in contagious_attrs, it have to keep basic no matter what the value is.
|
||||
if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(),
|
||||
[&cnode](const std::string &attr_name) -> bool { return AnfAlgo::HasNodeAttr(attr_name, cnode); })) {
|
||||
return true;
|
||||
}
|
||||
if (AnfAlgo::GetBooleanAttr(cnode, "skip")) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
|
||||
const std::vector<std::string> &enable_ops, const std::vector<std::string> &disable_ops) {
|
||||
auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
|
||||
if (!enable_ops_only.empty()) {
|
||||
ops->clear();
|
||||
(void)std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::back_inserter(*ops), new_prim);
|
||||
} else {
|
||||
if (!enable_ops.empty()) {
|
||||
(void)std::transform(enable_ops.begin(), enable_ops.end(), std::back_inserter(*ops), new_prim);
|
||||
}
|
||||
if (!disable_ops.empty()) {
|
||||
auto iter = std::remove_if(ops->begin(), ops->end(), [&disable_ops](const PrimitivePtr &p) {
|
||||
return std::find(disable_ops.begin(), disable_ops.end(), p->name()) != disable_ops.end();
|
||||
});
|
||||
(void)ops->erase(iter, ops->end());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inner::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph) {
|
||||
inner::LiteGraph::GraphBuilder gb(GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)));
|
||||
std::map<AnfNodePtr, inner::NodePtr> node_map;
|
||||
|
@ -632,22 +587,6 @@ void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList
|
|||
*inputs = std::move(new_inputs);
|
||||
}
|
||||
|
||||
std::vector<PrimitivePtr> GetValidOps(
|
||||
const std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> &ops_with_level, unsigned int level) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
std::vector<PrimitivePtr> valid_ops;
|
||||
for (const auto &[op_target, op_level, op] : ops_with_level) {
|
||||
if (op_target == kAllTarget || op_target == target) {
|
||||
if (level >= op_level) {
|
||||
(void)valid_ops.emplace_back(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
return valid_ops;
|
||||
}
|
||||
|
||||
FuncGraphManagerPtr GetFuncGraphManager(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
|
|
|
@ -44,9 +44,7 @@ constexpr auto kGetGraphKernelOpExpander = "get_op_expander";
|
|||
constexpr auto kJsonKeyMultiGraph = "multi_graph";
|
||||
constexpr auto kJsonKeyGraphDesc = "graph_desc";
|
||||
constexpr auto kJsonKeyGraphMode = "graph_mode";
|
||||
constexpr auto kAllTarget = "ALL";
|
||||
|
||||
constexpr auto kGraphKernelDumpPath = "graph_kernel_dump";
|
||||
struct DataInfo {
|
||||
std::string format{kOpFormat_DEFAULT};
|
||||
ShapeVector shape{1};
|
||||
|
@ -79,9 +77,6 @@ std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
|
|||
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info,
|
||||
bool use_fake_abstract = false);
|
||||
void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
|
||||
bool IsKeepBasicNode(const AnfNodePtr &node);
|
||||
void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
|
||||
const std::vector<std::string> &enable_ops, const std::vector<std::string> &disable_ops);
|
||||
template <typename T>
|
||||
ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {
|
||||
// Create tensor value.
|
||||
|
@ -127,9 +122,6 @@ FuncGraphPtr LiteGraph2AnfGraph(const inner::LiteGraphPtr &lite_graph, AnfNodePt
|
|||
// remove parameter which is not used
|
||||
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
|
||||
|
||||
std::vector<PrimitivePtr> GetValidOps(
|
||||
const std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> &ops_with_level, unsigned int level);
|
||||
|
||||
// return a func_graph's manager
|
||||
FuncGraphManagerPtr GetFuncGraphManager(const FuncGraphPtr &func_graph);
|
||||
|
||||
|
|
|
@ -119,5 +119,7 @@ std::string CallbackImpl::GetOutputFormat(const AnfNodePtr &node, size_t i) { re
|
|||
|
||||
std::string CallbackImpl::GetProcessor(const AnfNodePtr &node) { return "cpu"; }
|
||||
|
||||
std::string CallbackImpl::GetProcessorFromContext() { return "cpu"; }
|
||||
std::string CallbackImpl::GetTargetFromContext() { return "CPU"; }
|
||||
|
||||
void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) {}
|
||||
} // namespace mindspore::graphkernel
|
||||
|
|
|
@ -35,7 +35,8 @@ class CallbackImpl : public Callback {
|
|||
std::string GetInputFormat(const AnfNodePtr &node, size_t i) override;
|
||||
std::string GetOutputFormat(const AnfNodePtr &node, size_t i) override;
|
||||
std::string GetProcessor(const AnfNodePtr &node) override;
|
||||
std::string GetProcessorFromContext() override;
|
||||
std::string GetTargetFromContext() override;
|
||||
void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) override;
|
||||
};
|
||||
} // namespace mindspore::graphkernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_LITE_ADAPTER_CALLBACK_IMPL_H_
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
#include <memory>
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -45,7 +44,6 @@ const AnfNodePtr GetitemTuple::Process(const FuncGraphPtr &, const AnfNodePtr &n
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto tuple_getitem = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum);
|
||||
AnfNodePtr make_tuple_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(make_tuple_anf);
|
||||
AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
|
|
Loading…
Reference in New Issue