!21235 Add pass TransformOpOptimizer
Merge pull request !21235 from DeshiChen/0719_elim_transform
This commit is contained in: 033b4706f2
@@ -624,13 +624,13 @@ void ReorganizeEmptyGraph(const graphkernel::LiteGraphPtr &litegraph) {
   auto &outputs = litegraph->GetOutputs();
   for (size_t i = 0; i < outputs.size(); i++) {
     if (outputs[i]->NodeType() == graphkernel::NType::Value) {
-      graphkernel::PrimOpPtr op_ptr = std::make_shared<graphkernel::BroadcastToOp>("BroadcastTo", "");
+      graphkernel::LiteGraph::GraphBuilder gb;
       std::vector<int64_t> new_shape = {1};
-      op_ptr->Infer({outputs[i]}, {{"shape", MakeValue(new_shape)}});
+      auto op_ptr = gb.Emit("BroadcastTo", {outputs[i]}, {{"shape", MakeValue(new_shape)}});
       litegraph->output()->SetInput(i, op_ptr);
     } else if (outputs[i]->NodeType() == graphkernel::NType::Parameter) {
-      graphkernel::PrimOpPtr op_ptr = std::make_shared<graphkernel::ReshapeOp>("Reshape", "");
-      op_ptr->Infer({outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}});
+      graphkernel::LiteGraph::GraphBuilder gb;
+      auto op_ptr = gb.Emit("Reshape", {outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}});
       litegraph->output()->SetInput(i, op_ptr);
     }
   }
@@ -48,6 +48,7 @@
 #include "backend/optimizer/graph_kernel/uss_atomic_add.h"
 #include "backend/optimizer/pass/getitem_tuple.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_pass_manager.h"
+#include "backend/optimizer/graph_kernel/transform_op_optimizer.h"
 #include "backend/optimizer/graph_kernel/rewrite_output_shape.h"
 
 namespace mindspore {
@@ -123,8 +124,12 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() const {
   // Common subexpression elimination
   pm->AddPass(std::make_shared<GraphKernelCSE>(), OptLevel_2);
 
-  // Elimate Redundant Complex op
+  // Eliminate Redundant Complex op
   pm->AddPass(std::make_shared<EliminateRedundantComplex>(), OptLevel_2, false);
+
+  // Eliminate unnecessary transform ops
+  auto level = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_trans_op_optimize);
+  pm->AddPass(std::make_shared<TransformOpOptimizer>(), level, is_gpu);
   return pm;
 }
@@ -102,12 +102,16 @@ NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList &
                                      std::string node_name) {
   if (node_name.empty()) node_name = NewName();
   PrimOpPtr op_ptr = CreateOp(op, node_name);
-  op_ptr->Infer(inputs, attrs);
+  auto baseinfo = op_ptr->Infer(inputs, attrs);
+  op_ptr->SetInputs(inputs);
+  op_ptr->SetAttrs(attrs);
+  op_ptr->SetBaseInfo(baseinfo);
   return graph_->Add(op_ptr);
 }
 
 NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs,
                                     const DAttrs &attrs, std::string node_name) {
   if (node_name.empty()) node_name = NewName();
   PrimOpPtr op_ptr = CreateOp(op, node_name);
   op_ptr->SetInputs(inputs);
   op_ptr->SetAttrs(attrs);
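The Emit change above turns Infer into a side-effect-free query: instead of writing shape/type/format into the node, Infer returns a NodeBase and the builder installs it explicitly. A minimal standalone sketch of that contract, using toy stand-in types rather than the real MindSpore classes:

#include <cstdint>
#include <string>
#include <vector>

// Toy stand-ins (not the MindSpore types). Infer computes a NodeBase from the
// inputs and returns it; the caller decides when to install it, mirroring the
// Infer -> SetInputs -> SetAttrs -> SetBaseInfo sequence in Emit above.
struct NodeBase {
  std::vector<int64_t> shape;
  int type = 0;
  std::string format = "DefaultFormat";
};

struct PrimOp : NodeBase {
  NodeBase Infer(const std::vector<NodeBase> &inputs) const {
    NodeBase nb;
    nb.shape = inputs[0].shape;    // default InferShape: follow input 0
    nb.type = inputs[0].type;      // default InferType
    nb.format = inputs[0].format;  // default InferFormat
    return nb;
  }
  void SetBaseInfo(const NodeBase &nb) { static_cast<NodeBase &>(*this) = nb; }
};

int main() {
  PrimOp op;
  NodeBase input{{2, 3}, 1, "NCHW"};
  op.SetBaseInfo(op.Infer({input}));  // infer first, then install the result
  return op.shape.size() == 2 ? 0 : 1;
}

Keeping inference free of side effects is what lets the optimizer re-infer a rebuilt graph in one sweep later on (see op->SetBaseInfo(op->As<PrimOp>()->Infer(...)) in TransformOpOptimizer::Process).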
@@ -76,9 +76,6 @@ void Node::SetInputs(const NodePtrList &inputs) {
 
 void Node::ReplaceWith(const NodePtr &other_node) {
   if (this->users_.empty()) return;
-  if (this->NodeType() != NType::Primitive) {
-    MS_LOG(EXCEPTION) << "Only Primitive node can be replaced, but the node type is " << NodeType();
-  }
   // copy the users before traversal
   auto users = this->users_;
   for (auto &user : users) {
@@ -59,7 +59,7 @@ struct NodeBase {
 class Node;
 using NodePtr = std::shared_ptr<Node>;
 using NodePtrList = std::vector<NodePtr>;
-class Node : public NodeBase {
+class Node : public NodeBase, public std::enable_shared_from_this<Node> {
  public:
   Node(const NodeBase &baseinfo, const std::string &name) : NodeBase(baseinfo), name_(name) {}
   virtual ~Node() {
@@ -90,16 +90,13 @@ class Node : public NodeBase {
   void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
 
   template <typename T>
-  T *As() {
-    return static_cast<T *>(this);
-  }
-  template <typename T>
-  const T *As() const {
-    return static_cast<const T *>(this);
+  std::shared_ptr<T> As() {
+    return std::static_pointer_cast<T>(shared_from_this());
   }
 
   const std::string &name() const { return name_; }
   const DAttrs &attrs() const { return attrs_; }
   const NodePtr &input(size_t i) const { return inputs_[i]; }
   const NodePtrList &inputs() const { return inputs_; }
   const std::unordered_map<Node *, std::set<size_t>> &users() const { return users_; }
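Nodes are always owned through NodePtr (std::shared_ptr<Node>), so the old raw-pointer As<T>() handed out pointers that did not participate in ownership. Deriving from std::enable_shared_from_this lets As<T>() return a properly shared pointer instead. A small self-contained sketch of the pattern (simplified names, not the real classes):

#include <cassert>
#include <memory>

struct Node : std::enable_shared_from_this<Node> {
  virtual ~Node() = default;
  template <typename T>
  std::shared_ptr<T> As() {
    // precondition: *this is already managed by a shared_ptr,
    // otherwise shared_from_this() is undefined behavior.
    return std::static_pointer_cast<T>(shared_from_this());
  }
};

struct PrimOp : Node {
  int compute_type = 0;
};

int main() {
  std::shared_ptr<Node> n = std::make_shared<PrimOp>();
  std::shared_ptr<PrimOp> op = n->As<PrimOp>();  // shares ownership with n
  assert(op.use_count() == 2);                   // n and op both own the node
  return 0;
}

The same mechanism is what allows the new Mutator to recover a NodePtr from the raw Node* keys of users() via user.first->shared_from_this().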
@@ -145,7 +142,7 @@ class ParamNode : public Node {
 
 class OutputNode : public Node {
  public:
-  OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, "") {}
+  OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, "Output") {}
   void Dump(std::ostringstream &os) const override { ; }
   NType NodeType() override { return NType::Output; }
 };
@@ -81,13 +81,14 @@ void PrimOp::CheckFormat(const NodePtrList &inputs, const DAttrs &attrs) {
     }
   }
 }
-void PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
+
+NodeBase PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
   Check(inputs, attrs);
-  this->shape = InferShape(inputs, attrs);
-  this->type = InferType(inputs, attrs);
-  this->format = InferFormat(inputs, attrs);
-  this->attrs_ = attrs;
-  SetInputs(inputs);
+  NodeBase nodebase;
+  nodebase.shape = InferShape(inputs, attrs);
+  nodebase.type = InferType(inputs, attrs);
+  nodebase.format = InferFormat(inputs, attrs);
+  return nodebase;
 }
 
 void PrimOp::Dump(std::ostringstream &os) const {
@@ -279,8 +280,8 @@ DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs)
   return it == inputs.end() ? kOpFormat_DEFAULT : (*it)->format;
 }
 
-void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
-  PrimOp::Infer(inputs, attrs);
+NodeBase ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
+  auto nodebase = PrimOp::Infer(inputs, attrs);
   auto IsBroadcast = [this](const NodePtrList &inputs) -> bool {
     for (auto &ref : inputs) {
       if (ref->shape.size() != this->shape.size()) return true;
@@ -291,6 +292,7 @@ void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
     return false;
   };
   compute_type_ = IsBroadcast(inputs) ? BROADCAST : ELEMWISE;
+  return nodebase;
 }
 
 TypeId CastOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
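The IsBroadcast lambda above reclassifies an elementwise op as BROADCAST when some input does not already match the inferred output shape. A self-contained sketch of that check, assuming the elided middle of the hunk compares individual dimensions as well as ranks:

#include <cstddef>
#include <vector>

using DShape = std::vector<long long>;

// true if any input shape differs from the output shape in rank or in a dim
bool IsBroadcast(const std::vector<DShape> &input_shapes, const DShape &out_shape) {
  for (const auto &s : input_shapes) {
    if (s.size() != out_shape.size()) return true;  // rank mismatch
    for (size_t i = 0; i < s.size(); ++i) {
      if (s[i] != out_shape[i]) return true;  // broadcast on dimension i
    }
  }
  return false;
}

int main() {
  bool b1 = IsBroadcast({{4, 8}, {1, 8}}, {4, 8});  // {1,8} broadcasts -> true
  bool b2 = IsBroadcast({{4, 8}, {4, 8}}, {4, 8});  // plain elementwise -> false
  return (b1 && !b2) ? 0 : 1;
}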
@@ -50,16 +50,8 @@ class PrimOp : public Node {
   PrimOp(const std::string &op, const std::string &node_name, ComputeType compute)
       : Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, node_name), op_(op), compute_type_(compute) {}
 
-  virtual void Check(const NodePtrList &inputs, const DAttrs &attrs);
-  virtual void CheckShape(const NodePtrList &inputs, const DAttrs &attrs) {}
-  virtual void CheckType(const NodePtrList &inputs, const DAttrs &attrs);
-  virtual void CheckFormat(const NodePtrList &inputs, const DAttrs &attrs);
-
-  virtual void Infer(const NodePtrList &inputs, const DAttrs &attrs);
+  virtual NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs);
   virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op);
-  virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; }
-  virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; }
-  virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; }
 
   void Dump(std::ostringstream &os) const override;
   NType NodeType() override { return NType::Primitive; }
@@ -68,6 +60,15 @@ class PrimOp : public Node {
   ComputeType compute_type() const { return compute_type_; }
 
  protected:
+  virtual void Check(const NodePtrList &inputs, const DAttrs &attrs);
+  virtual void CheckShape(const NodePtrList &inputs, const DAttrs &attrs) {}
+  virtual void CheckType(const NodePtrList &inputs, const DAttrs &attrs);
+  virtual void CheckFormat(const NodePtrList &inputs, const DAttrs &attrs);
+
+  virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; }
+  virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; }
+  virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; }
+
   std::string op_;
   ComputeType compute_type_;
 };
@@ -76,8 +77,9 @@ using PrimOpPtr = std::shared_ptr<PrimOp>;
 class ElemwiseOp : public PrimOp {
  public:
   ElemwiseOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, ELEMWISE) {}
+  NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs) override;
 
-  void Infer(const NodePtrList &inputs, const DAttrs &attrs) override;
  protected:
   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
 };
@@ -86,6 +88,7 @@ class CastOp : public ElemwiseOp {
  public:
   CastOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Cast", node_name) {}
 
+ protected:
   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
 };
 
@@ -93,6 +96,7 @@ class InplaceAssignOp : public ElemwiseOp {
  public:
   InplaceAssignOp(const std::string &op, const std::string &node_name) : ElemwiseOp("InplaceAssign", node_name) {}
 
+ protected:
   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->shape; }
   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->type; }
   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->format; }
@@ -102,6 +106,7 @@ class SelectOp : public ElemwiseOp {
  public:
   SelectOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Select", node_name) {}
 
+ protected:
   void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override;
   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[1]->type; }
 };
@@ -110,6 +115,7 @@ class CompareOp : public ElemwiseOp {
  public:
   CompareOp(const std::string &op, const std::string &node_name) : ElemwiseOp(op, node_name) {}
 
+ protected:
   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeBool; }
 };
 
@@ -142,6 +148,7 @@ class ReshapeOp : public PrimOp {
  public:
   ReshapeOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, RESHAPE) {}
 
+ protected:
   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override {
     return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT
@@ -153,6 +160,7 @@ class BroadcastToOp : public PrimOp {
  public:
   BroadcastToOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, BROADCAST) {}
 
+ protected:
   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
 };
 
@@ -160,8 +168,8 @@ class ReduceOp : public PrimOp {
  public:
   ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {}
 
+ protected:
   void Check(const NodePtrList &inputs, const DAttrs &attrs) override;
-
   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; };
 };
@@ -175,6 +183,7 @@ class Conv2dOp : public OpaqueOp {
  public:
   Conv2dOp(const std::string &op, const std::string &node_name) : OpaqueOp("Conv2D", node_name) {}
 
+ protected:
   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
 };
@@ -183,6 +192,7 @@ class TransposeOp : public OpaqueOp {
  public:
   TransposeOp(const std::string &op, const std::string &node_name) : OpaqueOp("Transpose", node_name) {}
 
+ protected:
   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
   DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
 };
@@ -191,6 +201,7 @@ class MatMulOp : public OpaqueOp {
  public:
   MatMulOp(const std::string &op, const std::string &node_name) : OpaqueOp("MatMul", node_name) {}
 
+ protected:
   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
 };
@@ -199,6 +210,7 @@ class PadAkgOp : public OpaqueOp {
  public:
   PadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("PadAkg", node_name) {}
 
+ protected:
   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
 };
 
@@ -206,6 +218,7 @@ class UnPadAkgOp : public OpaqueOp {
  public:
   UnPadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("UnPadAkg", node_name) {}
 
+ protected:
   DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
 };
 
@@ -213,6 +226,7 @@ class CImagOp : public ElemwiseOp {
  public:
   CImagOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CImag", node_name) {}
 
+ protected:
   void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
     if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
       throw GKException("CImag's input[0] should be complex64");
@@ -226,6 +240,7 @@ class CRealOp : public ElemwiseOp {
  public:
   CRealOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CReal", node_name) {}
 
+ protected:
   void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
     if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
       throw GKException("CReal's input[0] should be complex64");
@@ -239,8 +254,8 @@ class ComplexOp : public ElemwiseOp {
  public:
   ComplexOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Complex", node_name) {}
 
+ protected:
   void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override;
-
   TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeComplex64; }
 };
 } // namespace graphkernel
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,175 +13,141 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 
 #include "backend/optimizer/graph_kernel/transform_op_optimizer.h"
 #include <algorithm>
 #include <iostream>
 #include <vector>
 #include <queue>
 #include <memory>
 #include <set>
 #include <map>
 #include <utility>
 #include <string>
 #include "base/core_ops.h"
 #include "ir/graph_utils.h"
 #include "debug/common.h"
 #include "backend/kernel_compiler/common_utils.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
 #include "backend/optimizer/graph_kernel/model/lite_graph.h"
 #include "backend/optimizer/graph_kernel/model/op_register.h"
 
 namespace mindspore {
 namespace opt {
 namespace {
-enum Format { kFormatUnknown, kFormatA, kFormatB };
+enum FormatType { kFormatUnknown, kFormatA, kFormatB };
 enum TransOpType { kTransAB, kTransBA };
 
 struct Edge {
-  size_t from;
   size_t to;
-  size_t val;
-  size_t next;
-  Edge(size_t From, size_t To, size_t Val, size_t Next) {
-    from = From;
-    to = To;
-    val = Val;
-    next = Next;
-  }
+  size_t capacity;
 };
 
-struct Node {
-  int head;
-  int cur;
-  int depth;
-  size_t pre;
-  Format format;
-  Node() {
-    head = -1;
-    cur = -1;
-    depth = -1;
-    pre = 0;
-    format = kFormatB;
-  }
+struct Vertex {
+  FormatType format{kFormatB};
+  size_t depth{0};
+  std::vector<size_t> out_edges;
 };
 
 constexpr size_t INF = static_cast<size_t>(1) << 30;
 
 class MinCut {
- public:
-  // Connect the source_node to the node_with_a_certain_formatA
-  // or Connect the node_with_a_certain_formatB to the sink_node
-  void Add_1(size_t from, size_t to) {
-    edges_.emplace_back(from, to, INF, nodes_[from].head);
-    nodes_[from].head = edges_count_++;
-    edges_.emplace_back(to, from, 0, nodes_[to].head);
-    nodes_[to].head = edges_count_++;
-  }
-
-  // Split one origin_node into two new_nodes and connect them
-  void Add_2(size_t nodes_id) {
-    edges_.emplace_back(nodes_id, nodes_id + origin_nodes_num_, 1, nodes_[nodes_id].head);
-    nodes_[nodes_id].head = edges_count_++;
-    edges_.emplace_back(nodes_id + origin_nodes_num_, nodes_id, 1, nodes_[nodes_id + origin_nodes_num_].head);
-    nodes_[nodes_id + origin_nodes_num_].head = edges_count_++;
-  }
-
-  // After splitting the origin_nodes, construct the new_edges based on the original_edges
-  void Add_3(size_t from, size_t to) {
-    edges_.emplace_back(from + origin_nodes_num_, to, 1, nodes_[from + origin_nodes_num_].head);
-    nodes_[from + origin_nodes_num_].head = edges_count_++;
-    edges_.emplace_back(to, from + origin_nodes_num_, 1, nodes_[to].head);
-    nodes_[to].head = edges_count_++;
-  }
+ private:
+  // Add the bidirectional edges for the vertexes `from` and `to`.
+  // The two edge ids are adjacent in the vector, x and x+1 (x is 0,2,4,...),
+  // so we can use (i xor 1) to get the inverse edge of any edge i;
+  // e.g. edge_0 and edge_1 are a pair: 0^1=1, 1^1=0.
+  void AddEdge(size_t from, size_t to, size_t capacity, size_t inv_capacity) {
+    edges_.emplace_back(Edge{to, capacity});
+    nodes_[from].out_edges.emplace_back(edges_.size() - 1);
+    // inverse edge
+    edges_.emplace_back(Edge{from, inv_capacity});
+    nodes_[to].out_edges.emplace_back(edges_.size() - 1);
+  }
 
-  void BFS() {
+  bool BfsSetDepth() {
     std::queue<size_t> bfs_queue;
-    nodes_[sink_id_].depth = 0;
-    bfs_queue.push(sink_id_);
+    for (auto &node : nodes_) {
+      node.depth = 0;
+    }
+    nodes_[source_id_].depth = 1;
+    bfs_queue.push(source_id_);
     while (!bfs_queue.empty()) {
-      size_t temp_node = bfs_queue.front();
+      auto edge_from = bfs_queue.front();
       bfs_queue.pop();
-      depth_num_[nodes_[temp_node].depth]++;
-      for (size_t i = nodes_[temp_node].head; ~i; i = edges_[i].next) {
-        if (edges_[i ^ 1].val && nodes_[edges_[i].to].depth == -1) {
-          nodes_[edges_[i].to].depth = nodes_[temp_node].depth + 1;
-          bfs_queue.push(edges_[i].to);
+      for (auto e_id : nodes_[edge_from].out_edges) {
+        auto edge_to = edges_[e_id].to;
+        if (edges_[e_id].capacity > 0 && nodes_[edge_to].depth == 0) {
+          nodes_[edge_to].depth = nodes_[edge_from].depth + 1;
+          bfs_queue.push(edge_to);
         }
       }
     }
+    return nodes_[sink_id_].depth > 0;
   }
 
-  void EdgeValueUpdate() {
-    size_t k = sink_id_, flow = INF;
-    while (k != source_id_) {
-      if (edges_[nodes_[k].pre].val < flow) {
-        flow = edges_[nodes_[k].pre].val;
-      }
-      k = edges_[nodes_[k].pre].from;
-    }
-    k = sink_id_;
-    while (k != source_id_) {
-      edges_[nodes_[k].pre].val -= flow;
-      edges_[nodes_[k].pre ^ 1].val += flow;
-      k = edges_[nodes_[k].pre].from;
-    }
-  }
+  size_t DfsMaxFlow(size_t node, size_t flow) {
+    if (node == sink_id_) return flow;
+    size_t max_flow = 0;
+    for (size_t e_id : nodes_[node].out_edges) {
+      if ((edges_[e_id].capacity > 0) && (nodes_[node].depth + 1 == nodes_[edges_[e_id].to].depth)) {
+        auto tmp_flow = DfsMaxFlow(edges_[e_id].to, std::min(flow, edges_[e_id].capacity));
+        if (tmp_flow > 0) {
+          max_flow += tmp_flow;
+          flow -= tmp_flow;
+          edges_[e_id].capacity -= tmp_flow;
+          edges_[e_id ^ 1].capacity += tmp_flow;
+        }
+      }
+    }
+    return max_flow;
+  }
 
-  void ISAP() {
-    size_t node_id = source_id_;
-    int maxdep = 2 * origin_nodes_num_ + 2;
-    BFS();
-    for (size_t i = source_id_; i <= sink_id_; ++i) {
-      nodes_[i].cur = nodes_[i].head;
-    }
-    while (nodes_[source_id_].depth <= maxdep) {
-      if (node_id == sink_id_) {
-        EdgeValueUpdate();
-        node_id = source_id_;
-      }
-      bool can_arrive = false;
-      for (size_t i = nodes_[node_id].cur; ~i; i = edges_[i].next) {
-        if (edges_[i].val && nodes_[edges_[i].to].depth + 1 == nodes_[node_id].depth) {
-          can_arrive = true;
-          nodes_[edges_[i].to].pre = i;
-          nodes_[node_id].cur = i;
-          node_id = edges_[i].to;
-          break;
-        }
-      }
-      if (!can_arrive) {
-        int mindep = 2 * origin_nodes_num_ + 2;
-        for (size_t i = nodes_[node_id].head; ~i; i = edges_[i].next) {
-          if (nodes_[edges_[i].to].depth < mindep && edges_[i].val) {
-            mindep = nodes_[edges_[i].to].depth;
-          }
-        }
-        --depth_num_[nodes_[node_id].depth];
-        if (!depth_num_[nodes_[node_id].depth]) {
-          break;
-        }
-        nodes_[node_id].depth = mindep + 1;
-        depth_num_[nodes_[node_id].depth]++;
-        nodes_[node_id].cur = nodes_[node_id].head;
-        if (node_id != source_id_) {
-          node_id = edges_[nodes_[node_id].pre].from;
-        }
-      }
-    }
-  }
+  void Dinic() {
+    while (BfsSetDepth()) {
+      (void)DfsMaxFlow(source_id_, INF);
+    }
+  }
 
   void SetFormat(size_t node_id) {
     nodes_[node_id].format = kFormatA;
-    for (size_t i = nodes_[node_id].head; ~i; i = edges_[i].next) {
-      if (edges_[i].val && nodes_[edges_[i].to].format != kFormatA) {
+    for (size_t i : nodes_[node_id].out_edges) {
+      if (edges_[i].capacity > 0 && nodes_[edges_[i].to].format != kFormatA) {
         SetFormat(edges_[i].to);
       }
     }
   }
 
-  MinCut(std::vector<std::pair<size_t, Format>> original_nodes, std::vector<std::pair<size_t, size_t>> original_edges)
-      : origin_nodes_num_(original_nodes.size()),
-        sink_id_(2 * origin_nodes_num_ + 1),
-        depth_num_(std::vector<size_t>(2 * origin_nodes_num_ + 2, 0)),
-        nodes_(std::vector<Node>(2 * origin_nodes_num_ + 2, Node())),
-        original_edges_(std::move(original_edges)) {
+  void BuildGraph(const std::vector<std::pair<size_t, FormatType>> &original_nodes) {
     for (size_t i = 0; i < origin_nodes_num_; ++i) {
+      // link the source node to the nodes with FormatA,
+      // link the nodes with FormatB to the sink node.
       if (original_nodes[i].second == kFormatA) {
-        Add_1(source_id_, original_nodes[i].first);
+        AddEdge(source_id_, original_nodes[i].first, INF, 0);
       } else if (original_nodes[i].second == kFormatB) {
-        Add_1(original_nodes[i].first, sink_id_);
+        AddEdge(original_nodes[i].first, sink_id_, INF, 0);
       }
-      Add_2(original_nodes[i].first);
+      // each node is split into two parts, an input part and an output part.
+      // the input part's id is the original node's id; the output part's id is the input id + origin_nodes_num_.
+      AddEdge(original_nodes[i].first, original_nodes[i].first + origin_nodes_num_, 1, 1);
     }
-    for (auto i : original_edges_) {
-      Add_3(i.first, i.second);
+    for (auto e : original_edges_) {
+      auto from = e.first, to = e.second;
+      AddEdge(from + origin_nodes_num_, to, 1, 1);
     }
-    ISAP();
   }
 
+ public:
+  MinCut(const std::vector<std::pair<size_t, FormatType>> &original_nodes,
+         const std::vector<std::pair<size_t, size_t>> &original_edges)
+      : origin_nodes_num_(original_nodes.size()),
+        sink_id_(2 * original_nodes.size() + 1),  // source_id_ is 0
+        nodes_(2 * original_nodes.size() + 2),    // doubled nodes, plus source_node and sink_node
+        original_edges_(original_edges) {
+    BuildGraph(original_nodes);
+  }
+
+  void Run() {
+    Dinic();
+    SetFormat(source_id_);
+  }
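The rewrite above swaps the hand-rolled ISAP max-flow for the simpler Dinic scheme: BfsSetDepth levels the residual graph from the source over edges with remaining capacity, and DfsMaxFlow pushes blocking flow along level-increasing paths, using the i^1 pairing to credit inverse edges. A standalone toy version of the same scheme (illustrative only, not the class above) that compiles and runs:

#include <algorithm>
#include <cstdio>
#include <queue>
#include <vector>

struct Edge { size_t to; size_t capacity; };

struct MaxFlow {
  std::vector<std::vector<size_t>> out_edges;
  std::vector<size_t> depth;
  std::vector<Edge> edges;
  size_t source = 0;
  size_t sink;

  explicit MaxFlow(size_t n) : out_edges(n), depth(n), sink(n - 1) {}

  void AddEdge(size_t from, size_t to, size_t cap) {
    edges.push_back({to, cap});
    out_edges[from].push_back(edges.size() - 1);
    edges.push_back({from, 0});  // inverse edge, id = forward id ^ 1
    out_edges[to].push_back(edges.size() - 1);
  }

  bool BfsSetDepth() {  // level the residual graph; true if sink is reachable
    std::fill(depth.begin(), depth.end(), 0);
    depth[source] = 1;
    std::queue<size_t> q;
    q.push(source);
    while (!q.empty()) {
      size_t u = q.front();
      q.pop();
      for (size_t e : out_edges[u]) {
        if (edges[e].capacity > 0 && depth[edges[e].to] == 0) {
          depth[edges[e].to] = depth[u] + 1;
          q.push(edges[e].to);
        }
      }
    }
    return depth[sink] > 0;
  }

  size_t DfsMaxFlow(size_t u, size_t flow) {  // push blocking flow downhill
    if (u == sink) return flow;
    size_t pushed = 0;
    for (size_t e : out_edges[u]) {
      if (edges[e].capacity > 0 && depth[u] + 1 == depth[edges[e].to]) {
        size_t f = DfsMaxFlow(edges[e].to, std::min(flow, edges[e].capacity));
        if (f > 0) {
          pushed += f;
          flow -= f;
          edges[e].capacity -= f;
          edges[e ^ 1].capacity += f;  // credit the inverse edge
          if (flow == 0) break;
        }
      }
    }
    return pushed;
  }

  size_t Run() {
    size_t total = 0;
    while (BfsSetDepth()) total += DfsMaxFlow(source, static_cast<size_t>(1) << 30);
    return total;
  }
};

int main() {
  MaxFlow g(4);        // 0 = source, 3 = sink
  g.AddEdge(0, 1, 3);
  g.AddEdge(1, 3, 1);  // bottleneck of 1 on this path
  g.AddEdge(0, 2, 2);
  g.AddEdge(2, 3, 2);
  printf("max flow = %zu\n", g.Run());  // prints 3
  return 0;
}

By max-flow/min-cut duality, the capacity-1 edges saturated by this process mark the cheapest places to cut, i.e. the minimum number of transform ops that must remain in the graph.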
@@ -213,11 +179,285 @@ class MinCut {
   size_t origin_nodes_num_;
   size_t source_id_{0};
   size_t sink_id_;
-  int edges_count_{0};
-  std::vector<size_t> depth_num_;
-  std::vector<Node> nodes_;
+  std::vector<Vertex> nodes_;
   std::vector<Edge> edges_;
   std::vector<std::pair<size_t, size_t>> original_edges_;
 };
 } // namespace
 
+using graphkernel::LiteGraph;
+using graphkernel::LiteGraphPtr;
+using graphkernel::Node;
+using graphkernel::NodePtr;
+using graphkernel::NodePtrList;
+using graphkernel::NType;
+using graphkernel::PrimOp;
+using graphkernel::PrimOpPtr;
+
+class TransformOp {
+ public:
+  explicit TransformOp(const NodePtr &node)
+      : op_(node->As<PrimOp>()->op()), format_a_(node->input(0)->format), format_b_(node->format) {}
+
+  bool IsTransformOp(const NodePtr &node) {
+    if (node->NodeType() != NType::Primitive || node->As<PrimOp>()->op() != op_) {
+      return false;
+    }
+    if (node->input(0)->format == format_a_ && node->format == format_b_) {
+      return true;
+    } else if (node->input(0)->format == format_b_ && node->format == format_a_) {
+      return true;
+    }
+    return false;
+  }
+
+  FormatType GetFormatType(const std::string &fmt) {
+    return fmt == format_a_ ? FormatType::kFormatA : FormatType::kFormatB;
+  }
+
+  NodePtr GenTransformOp(TransOpType trans_type) {
+    // Only support Transpose now
+    static std::map<std::pair<std::string, std::string>, std::vector<int64_t>> perm_map = {
+      {{kOpFormat_DEFAULT, kOpFormat_NHWC}, {0, 2, 3, 1}},
+      {{kOpFormat_NCHW, kOpFormat_NHWC}, {0, 2, 3, 1}},
+      {{kOpFormat_NHWC, kOpFormat_NCHW}, {0, 3, 1, 2}},
+      {{kOpFormat_NHWC, kOpFormat_DEFAULT}, {0, 3, 1, 2}},
+    };
+    std::vector<int64_t> perm;
+    if (trans_type == TransOpType::kTransAB) {
+      perm = perm_map[{format_a_, format_b_}];
+    } else {
+      perm = perm_map[{format_b_, format_a_}];
+    }
+    if (perm.empty()) {
+      std::ostringstream oss;
+      oss << "unsupported format: " << format_a_ << " to " << format_b_;
+      throw graphkernel::GKException(oss.str());
+    }
+    auto op = graphkernel::OpRegistry::Instance().NewOp("Transpose", "new_trans");
+    op->SetAttr("perm", MakeValue(perm));
+    return op;
+  }
+
+ private:
+  std::string op_;
+  std::string format_a_;
+  std::string format_b_;
+};
+
+bool IsFlexibleOp(const NodePtr &node) {
+  static std::set<std::string> format_flexible_ops = {
+    "Abs", "Add", "Sub", "Mul", "Round", "Cast", "Neg", "Exp", "Log",
+    "Pow", "Minimum", "Maximum", "Rsqrt", "Sqrt", "Reciprocal", "Tanh", "Sin", "Cos",
+    "Asin", "ACos", "RealDiv", "Equal", "Greater", "GreaterEqual", "Less", "LessEqual", "Sign"};
+  if (node->NodeType() != NType::Primitive) {
+    return false;
+  }
+  if (format_flexible_ops.count(node->As<PrimOp>()->op()) == 0) {
+    return false;
+  }
+  // check that the input and output formats are all the same, except ConstValue.
+  for (auto &inp : node->inputs()) {
+    if (inp->NodeType() != NType::Value && inp->format != node->format) {
+      return false;
+    }
+  }
+  return true;
+}
+
+class Mutator {
+ public:
+  explicit Mutator(const NodePtr &node) : op_checker_(node), basenode_(node), ori_node_(1) {}
+
+  bool Run() {
+    VisitNode(basenode_);
+    if (flexible_ops_.empty()) return false;
+    // remove transform ops in litegraph
+    RemoveTransOp();
+    GenFormatGraph();
+    RebuildLiteGraph();
+    return true;
+  }
+
+ private:
+  // visit nodes bidirectionally
+  void VisitNode(const NodePtr &node) {
+    if (visited_.count(node) > 0) return;
+    visited_.insert(node);
+    if (op_checker_.IsTransformOp(node)) {
+      trans_ops_.insert(node);
+    } else if (!IsFlexibleOp(node)) {
+      if (node->NodeType() != NType::Output) {
+        fmt_type[{node, -1}] = op_checker_.GetFormatType(node->format);
+      }
+      if (node->NodeType() != NType::Parameter) {
+        for (size_t i = 0; i < node->inputs().size(); i++) {
+          if (node->input(i)->NodeType() == NType::Value) {
+            continue;
+          }
+          fmt_type[{node, i}] = op_checker_.GetFormatType(node->input(i)->format);
+        }
+      }
+      return;
+    } else {
+      flexible_ops_.insert(node);
+      fmt_type[{node, -1}] = FormatType::kFormatUnknown;
+    }
+
+    for (auto &input : node->inputs()) {
+      if (input->NodeType() != NType::Value) {
+        VisitNode(input);
+      }
+    }
+    for (auto &user : node->users()) {
+      VisitNode(user.first->shared_from_this());
+    }
+  }
+
+  void RemoveTransOp() {
+    for (auto &node : trans_ops_) {
+      visited_.erase(node);
+      node->ReplaceWith(node->input(0));
+      // clear inputs, so that the node will not be the basenode again.
+      node->SetInputs({});
+    }
+    trans_ops_.clear();
+  }
+
+  void GenFormatGraph() {
+    for (auto &node : visited_) {
+      if (node->NodeType() == NType::Parameter) continue;
+      bool is_flexible = (flexible_ops_.find(node) != flexible_ops_.end());
+      size_t cur_id = 0;
+      if (is_flexible) {
+        cur_id = GetId({node, -1});
+      }
+      for (size_t i = 0; i < node->inputs().size(); i++) {
+        if (visited_.count(node->input(i)) == 0) continue;
+        if (!is_flexible) {
+          cur_id = GetId({node, SizeToInt(i)});
+        }
+        auto input_id = GetId({node->input(i), -1});
+        graph_edges_.emplace_back(input_id, cur_id);
+      }
+    }
+  }
+
+  void RebuildLiteGraph() {
+    MinCut min_cut(graph_vertex_, graph_edges_);
+    min_cut.Run();
+    for (auto [node_id, trans_type] : min_cut.GetOneNodeOps()) {
+      if (ori_node_[node_id].second != -1) {
+        MS_LOG(EXCEPTION) << "OneNodeOp should be the output edge. node_id:" << node_id
+                          << " index:" << ori_node_[node_id].second;
+      }
+      auto trans_op = op_checker_.GenTransformOp(trans_type);
+      ori_node_[node_id].first->ReplaceWith(trans_op);
+      trans_op->SetInputs({ori_node_[node_id].first});
+    }
+
+    std::map<size_t, NodePtr> trans_op_cache;
+    for (auto [edge, trans_type] : min_cut.GetTwoNodeOps()) {
+      auto node_id_from = edge.first;
+      auto node_id_to = edge.second;
+      if (ori_node_[node_id_from].second != -1) {
+        MS_LOG(EXCEPTION) << "node_from should be the output edge. node_id:" << node_id_from
+                          << " index:" << ori_node_[node_id_from].second;
+      }
+      auto node_from = ori_node_[node_id_from].first;
+      auto node_to = ori_node_[node_id_to].first;
+      auto &trans_op = trans_op_cache[node_id_from];
+      if (trans_op == nullptr) {
+        trans_op = op_checker_.GenTransformOp(trans_type);
+        trans_op->SetInputs({node_from});
+      }
+      if (ori_node_[node_id_to].second >= 0) {
+        node_to->SetInput(ori_node_[node_id_to].second, trans_op);
+      } else {
+        for (size_t i = 0; i < node_to->inputs().size(); i++) {
+          if (node_to->input(i) == node_from) {
+            node_to->SetInput(i, trans_op);
+          }
+        }
+      }
+    }
+  }
+
+  size_t GetId(const std::pair<NodePtr, int> &node) {
+    // the nodes are indexed from 1 in the MinCut model.
+    auto &id = node_id_[node];
+    if (id == 0) {
+      id = node_id_.size();
+      ori_node_.push_back(node);
+      // set format_type for the new id.
+      graph_vertex_.emplace_back(id, fmt_type[node]);
+    }
+    return id;
+  }
+
+  TransformOp op_checker_;
+  NodePtr basenode_;
+  std::set<NodePtr> flexible_ops_;
+  std::set<NodePtr> trans_ops_;
+  std::set<NodePtr> visited_;
+
+  std::map<std::pair<NodePtr, int>, FormatType> fmt_type;
+  std::map<std::pair<NodePtr, int>, size_t> node_id_;
+  std::vector<std::pair<NodePtr, int>> ori_node_;
+  std::vector<std::pair<size_t, FormatType>> graph_vertex_;
+  std::vector<std::pair<size_t, size_t>> graph_edges_;
+};
+
+bool TransformOpOptimizer::Process(const LiteGraphPtr &litegraph, const std::string &trans_op_name) {
+  ori_trans_op_num_ = 0;
+  auto &ops = litegraph->ops();
+  std::set<NodePtr> visited;
+  bool changed = true;
+  auto check_is_trans_op = [&trans_op_name](const NodePtr &node) { return node->As<PrimOp>()->op() == trans_op_name; };
+  auto ori_trans_op_num = std::count_if(ops.begin(), ops.end(), check_is_trans_op);
+  for (auto &op : ops) {
+    if (check_is_trans_op(op) && !op->inputs().empty() && op->input(0)->format != op->format) {
+      auto mutator = Mutator(op);
+      changed = mutator.Run() || changed;
+    }
+  }
+  if (!changed) return false;
+  auto &new_ops = litegraph->GetOrderedNodes();
+  auto new_trans_op_num = std::count_if(new_ops.begin(), new_ops.end(), check_is_trans_op);
+  if (new_trans_op_num >= ori_trans_op_num) {
+    return false;
+  }
+  for (auto &op : new_ops) {
+    op->SetBaseInfo(op->As<PrimOp>()->Infer(op->inputs(), op->attrs()));
+  }
+  return true;
+}
+
+bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) {
+  auto mng = kernel_graph->manager();
+  MS_EXCEPTION_IF_NULL(mng);
+  auto todos = TopoSort(kernel_graph->get_return());
+  bool changed = false;
+  for (auto node : todos) {
+    if (!AnfAlgo::IsGraphKernel(node)) continue;
+    try {
+      auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+      auto litegraph = AnfGraph2LiteGraph(sub_func_graph);
+      if (Process(litegraph)) {
+        changed = true;
+        AnfNodePtrList outputs;
+        auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs);
+        new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
+        auto cnode = node->cast<CNodePtr>();
+        AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
+        auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs);
+        SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
+        mng->Replace(node, new_node);
+        mng->AddFuncGraph(new_funcgraph);
+      }
+    } catch (const graphkernel::GKException &e) {
+      MS_LOG(WARNING) << e.what() << ", so the transform op optimization is skipped for this graph";
+    }
+  }
+  return changed;
+}
+} // namespace opt
+} // namespace mindspore
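To see how BuildGraph's vertex splitting turns transform-op placement into a min-cut instance, here is a hedged continuation of the MaxFlow toy shown earlier, encoding the first doc example of the new header file below (Add over two transposed parameters). The ids are illustrative, since the real ones come from GetId's 1-based map, and the inverse capacities are simplified to 0:

// Vertices (1-based): 1 = Add (flexible), 2 = p0, 3 = p1, 4 = Add's consumer.
// p0/p1 keep format A (source side); the consumer wants format B (sink side).
// Each vertex i is split into input part i and output part i + n, joined by a
// capacity-1 edge; cutting a capacity-1 edge = inserting one transform op.
size_t MinCutDemo() {
  const size_t n = 4;
  const size_t kInf = static_cast<size_t>(1) << 30;
  MaxFlow g(2 * n + 2);                // split parts + source(0) + sink(2n+1)
  g.AddEdge(0, 2, kInf);               // source -> p0 (format A)
  g.AddEdge(0, 3, kInf);               // source -> p1 (format A)
  g.AddEdge(4, 2 * n + 1, kInf);       // consumer -> sink (format B)
  for (size_t i = 1; i <= n; ++i) {
    g.AddEdge(i, i + n, 1);            // split edge inside each vertex
  }
  g.AddEdge(2 + n, 1, 1);              // p0.out  -> Add.in
  g.AddEdge(3 + n, 1, 1);              // p1.out  -> Add.in
  g.AddEdge(1 + n, 4, 1);              // Add.out -> consumer.in
  return g.Run();                      // max flow = min cut = 1
}

One minimum cut of size 1 is Add's split edge: a single Transpose at Add's boundary replaces the two Transposes on its inputs, which is exactly the rewrite promised by the header's example.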
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TRANSFORM_OP_OPTIMIZER_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TRANSFORM_OP_OPTIMIZER_H_
+
+#include <string>
+#include <vector>
+#include "backend/optimizer/common/pass.h"
+#include "ir/func_graph.h"
+#include "backend/optimizer/graph_kernel/model/lite_graph.h"
+
+namespace mindspore {
+namespace opt {
+/**
+ * @brief Eliminate the unnecessary transformation ops when the other operators
+ *        are format flexible.
+ * @example
+ *   %1 = Transpose(p0) // NCHW to NHWC
+ *   %2 = Transpose(p1) // NCHW to NHWC
+ *   %3 = Add(%1, %2)
+ *   return %3
+ *   -->
+ *   %1 = Add(p0, p1)
+ *   %2 = Transpose(%1) // NCHW to NHWC
+ *   return %2
+ * @example
+ *   %1 = Transpose(p0) // NCHW to NHWC
+ *   %2 = Transpose(p1) // NCHW to NHWC
+ *   %3 = Add(%1, %2)
+ *   %4 = Transpose(%3) // NHWC to NCHW
+ *   return %4
+ *   -->
+ *   %1 = Add(p0, p1)
+ *   return %1
+ */
+class TransformOpOptimizer : public Pass {
+ public:
+  TransformOpOptimizer() : Pass("transform_op_optimizer") {}
+  ~TransformOpOptimizer() = default;
+  bool Run(const FuncGraphPtr &func_graph) override;
+
+ private:
+  bool Process(const graphkernel::LiteGraphPtr &litegraph, const std::string &trans_op_name = "Transpose");
+  bool IsFlexibleOp(const graphkernel::NodePtr &node);
+  size_t ori_trans_op_num_{0};
+};
+} // namespace opt
+} // namespace mindspore
+#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TRANSFORM_OP_OPTIMIZER_H_
@@ -165,7 +165,7 @@ void GraphKernelFlags::Refresh() {
     }
   }
   // Dump flags so that people can check the setting.
-  MS_LOG(INFO) << "GraphKernelFlags info: " << DumpAllFlags();
+  MS_LOG(INFO) << "graph_kernel_flags = \"" << flags_cache_ << "\", all flags: " << DumpAllFlags();
 }
 
 void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {
@@ -188,10 +188,11 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {
   reg.AddFlag("enable_recompute_fusion", &enable_recompute_fusion, opt_level >= OptLevel_2);
   reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion, opt_level == OptLevel_3);
   reg.AddFlag("enable_low_precision", &enable_low_precision);
-  reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_gpu ? OpLevel_MAX : OpLevel_0);
+  reg.AddFlag("enable_trans_op_optimize", &enable_trans_op_optimize);
 
   // Integer flags
   reg.AddFlag("online_tuning", &online_tuning);
+  reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_gpu ? OpLevel_MAX : OpLevel_0);
 
   // String flags
   reg.AddFlag("repository_path", &repository_path);
@@ -217,6 +218,7 @@ std::string GraphKernelFlags::DumpAllFlags() const {
   json["enable_recompute_fusion"] = enable_recompute_fusion;
   json["enable_parallel_fusion"] = enable_parallel_fusion;
   json["enable_low_precision"] = enable_low_precision;
+  json["enable_trans_op_optimize"] = enable_trans_op_optimize;
 
   json["opt_level"] = opt_level;
   json["fusion_ops_level"] = fusion_ops_level;
@@ -95,6 +95,13 @@ class GraphKernelFlags {
    */
   unsigned int fusion_ops_level{OpLevel_0};
 
+  /**
+   * Enable optimization for transform operators (Transpose/TransData)
+   *
+   * Experimental feature, enabled by default when opt_level=3.
+   */
+  bool enable_trans_op_optimize{false};
+
   /**
    * Optimization level, value from 0 to 3.
    * 0: Disable GraphKernel
@@ -63,7 +63,7 @@ def set_graph_kernel_context():
     if context.get_context("device_target") == "GPU":
         context.set_context(enable_graph_kernel=True)
         context.set_context(graph_kernel_flags="--enable_parallel_fusion "
                                                "--disable_expand_ops=BatchNorm,BatchNormGrad "
+                                               "--enable_trans_op_optimize "
                                                "--disable_cluster_ops=ReduceMax,Reshape "
                                                "--enable_expand_ops=Conv2D")