!21235 Add pass TransformOpOptimizer

Merge pull request !21235 from DeshiChen/0719_elim_transform
i-robot 2021-08-31 08:36:25 +00:00 committed by Gitee
commit 033b4706f2
12 changed files with 500 additions and 169 deletions

View File

@ -624,13 +624,13 @@ void ReorganizeEmptyGraph(const graphkernel::LiteGraphPtr &litegraph) {
auto &outputs = litegraph->GetOutputs();
for (size_t i = 0; i < outputs.size(); i++) {
if (outputs[i]->NodeType() == graphkernel::NType::Value) {
graphkernel::PrimOpPtr op_ptr = std::make_shared<graphkernel::BroadcastToOp>("BroadcastTo", "");
graphkernel::LiteGraph::GraphBuilder gb;
std::vector<int64_t> new_shape = {1};
op_ptr->Infer({outputs[i]}, {{"shape", MakeValue(new_shape)}});
auto op_ptr = gb.Emit("BroadcastTo", {outputs[i]}, {{"shape", MakeValue(new_shape)}});
litegraph->output()->SetInput(i, op_ptr);
} else if (outputs[i]->NodeType() == graphkernel::NType::Parameter) {
graphkernel::PrimOpPtr op_ptr = std::make_shared<graphkernel::ReshapeOp>("Reshape", "");
op_ptr->Infer({outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}});
graphkernel::LiteGraph::GraphBuilder gb;
auto op_ptr = gb.Emit("Reshape", {outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}});
litegraph->output()->SetInput(i, op_ptr);
}
}

View File

@ -48,6 +48,7 @@
#include "backend/optimizer/graph_kernel/uss_atomic_add.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/graph_kernel/graph_kernel_pass_manager.h"
#include "backend/optimizer/graph_kernel/transform_op_optimizer.h"
#include "backend/optimizer/graph_kernel/rewrite_output_shape.h"
namespace mindspore {
@ -123,8 +124,12 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() const {
// Common subexpression elimination
pm->AddPass(std::make_shared<GraphKernelCSE>(), OptLevel_2);
// Elimate Redundant Complex op
// Eliminate Redundant Complex op
pm->AddPass(std::make_shared<EliminateRedundantComplex>(), OptLevel_2, false);
// Eliminate unnecessary transform ops
auto level = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_trans_op_optimize);
pm->AddPass(std::make_shared<TransformOpOptimizer>(), level, is_gpu);
return pm;
}

View File

@ -102,12 +102,16 @@ NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList &
std::string node_name) {
if (node_name.empty()) node_name = NewName();
PrimOpPtr op_ptr = CreateOp(op, node_name);
op_ptr->Infer(inputs, attrs);
auto baseinfo = op_ptr->Infer(inputs, attrs);
op_ptr->SetInputs(inputs);
op_ptr->SetAttrs(attrs);
op_ptr->SetBaseInfo(baseinfo);
return graph_->Add(op_ptr);
}
NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs,
const DAttrs &attrs, std::string node_name) {
if (node_name.empty()) node_name = NewName();
PrimOpPtr op_ptr = CreateOp(op, node_name);
op_ptr->SetInputs(inputs);
op_ptr->SetAttrs(attrs);

View File

@ -76,9 +76,6 @@ void Node::SetInputs(const NodePtrList &inputs) {
void Node::ReplaceWith(const NodePtr &other_node) {
if (this->users_.empty()) return;
if (this->NodeType() != NType::Primitive) {
MS_LOG(EXCEPTION) << "Only Primitive node can be replaced, but the node type is " << NodeType();
}
// copy the users before traversal
auto users = this->users_;
for (auto &user : users) {

View File

@ -59,7 +59,7 @@ struct NodeBase {
class Node;
using NodePtr = std::shared_ptr<Node>;
using NodePtrList = std::vector<NodePtr>;
class Node : public NodeBase {
class Node : public NodeBase, public std::enable_shared_from_this<Node> {
public:
Node(const NodeBase &baseinfo, const std::string &name) : NodeBase(baseinfo), name_(name) {}
virtual ~Node() {
@ -90,16 +90,13 @@ class Node : public NodeBase {
void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
template <typename T>
T *As() {
return static_cast<T *>(this);
}
template <typename T>
const T *As() const {
return static_cast<const T *>(this);
std::shared_ptr<T> As() {
return std::static_pointer_cast<T>(shared_from_this());
}
const std::string &name() const { return name_; }
const DAttrs &attrs() const { return attrs_; }
const NodePtr &input(size_t i) const { return inputs_[i]; }
const NodePtrList &inputs() const { return inputs_; }
const std::unordered_map<Node *, std::set<size_t>> &users() const { return users_; }
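The change above swaps the raw static_cast in As<T>() for a shared_ptr obtained through enable_shared_from_this, so the cast result now shares ownership with the node. The idiom can be reproduced outside MindSpore in a few lines; the following self-contained sketch (Base/Derived are illustrative names, not MindSpore types) shows the same pattern:

#include <iostream>
#include <memory>

struct Base : public std::enable_shared_from_this<Base> {
  virtual ~Base() = default;
  template <typename T>
  std::shared_ptr<T> As() {
    // shared_from_this() returns a shared_ptr that shares ownership with the
    // pointer already managing *this, so the cast result keeps the object alive.
    return std::static_pointer_cast<T>(shared_from_this());
  }
};

struct Derived : public Base {
  int value{42};
};

int main() {
  std::shared_ptr<Base> b = std::make_shared<Derived>();
  auto d = b->As<Derived>();           // previously: Derived *d = static_cast<Derived *>(b.get());
  std::cout << d->value << std::endl;  // prints 42
  return 0;
}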
@ -145,7 +142,7 @@ class ParamNode : public Node {
class OutputNode : public Node {
public:
OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, "") {}
OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, "Output") {}
void Dump(std::ostringstream &os) const override { ; }
NType NodeType() override { return NType::Output; }
};

View File

@ -81,13 +81,14 @@ void PrimOp::CheckFormat(const NodePtrList &inputs, const DAttrs &attrs) {
}
}
}
void PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
NodeBase PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
Check(inputs, attrs);
this->shape = InferShape(inputs, attrs);
this->type = InferType(inputs, attrs);
this->format = InferFormat(inputs, attrs);
this->attrs_ = attrs;
SetInputs(inputs);
NodeBase nodebase;
nodebase.shape = InferShape(inputs, attrs);
nodebase.type = InferType(inputs, attrs);
nodebase.format = InferFormat(inputs, attrs);
return nodebase;
}
void PrimOp::Dump(std::ostringstream &os) const {
@ -279,8 +280,8 @@ DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs)
return it == inputs.end() ? kOpFormat_DEFAULT : (*it)->format;
}
void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
PrimOp::Infer(inputs, attrs);
NodeBase ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
auto nodebase = PrimOp::Infer(inputs, attrs);
auto IsBroadcast = [this](const NodePtrList &inputs) -> bool {
for (auto &ref : inputs) {
if (ref->shape.size() != this->shape.size()) return true;
@ -291,6 +292,7 @@ void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
return false;
};
compute_type_ = IsBroadcast(inputs) ? BROADCAST : ELEMWISE;
return nodebase;
}
TypeId CastOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {

View File

@ -50,16 +50,8 @@ class PrimOp : public Node {
PrimOp(const std::string &op, const std::string &node_name, ComputeType compute)
: Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, node_name), op_(op), compute_type_(compute) {}
virtual void Check(const NodePtrList &inputs, const DAttrs &attrs);
virtual void CheckShape(const NodePtrList &inputs, const DAttrs &attrs) {}
virtual void CheckType(const NodePtrList &inputs, const DAttrs &attrs);
virtual void CheckFormat(const NodePtrList &inputs, const DAttrs &attrs);
virtual void Infer(const NodePtrList &inputs, const DAttrs &attrs);
virtual NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs);
virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op);
virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; }
virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; }
virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; }
void Dump(std::ostringstream &os) const override;
NType NodeType() override { return NType::Primitive; }
@ -68,6 +60,15 @@ class PrimOp : public Node {
ComputeType compute_type() const { return compute_type_; }
protected:
virtual void Check(const NodePtrList &inputs, const DAttrs &attrs);
virtual void CheckShape(const NodePtrList &inputs, const DAttrs &attrs) {}
virtual void CheckType(const NodePtrList &inputs, const DAttrs &attrs);
virtual void CheckFormat(const NodePtrList &inputs, const DAttrs &attrs);
virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; }
virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; }
virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; }
std::string op_;
ComputeType compute_type_;
};
@ -76,8 +77,9 @@ using PrimOpPtr = std::shared_ptr<PrimOp>;
class ElemwiseOp : public PrimOp {
public:
ElemwiseOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, ELEMWISE) {}
NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs) override;
void Infer(const NodePtrList &inputs, const DAttrs &attrs) override;
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
};
@ -86,6 +88,7 @@ class CastOp : public ElemwiseOp {
public:
CastOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Cast", node_name) {}
protected:
TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
};
@ -93,6 +96,7 @@ class InplaceAssignOp : public ElemwiseOp {
public:
InplaceAssignOp(const std::string &op, const std::string &node_name) : ElemwiseOp("InplaceAssign", node_name) {}
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->shape; }
TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->type; }
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->format; }
@ -102,6 +106,7 @@ class SelectOp : public ElemwiseOp {
public:
SelectOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Select", node_name) {}
protected:
void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override;
TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[1]->type; }
};
@ -110,6 +115,7 @@ class CompareOp : public ElemwiseOp {
public:
CompareOp(const std::string &op, const std::string &node_name) : ElemwiseOp(op, node_name) {}
protected:
TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeBool; }
};
@ -142,6 +148,7 @@ class ReshapeOp : public PrimOp {
public:
ReshapeOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, RESHAPE) {}
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override {
return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT
@ -153,6 +160,7 @@ class BroadcastToOp : public PrimOp {
public:
BroadcastToOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, BROADCAST) {}
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
};
@ -160,8 +168,8 @@ class ReduceOp : public PrimOp {
public:
ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {}
protected:
void Check(const NodePtrList &inputs, const DAttrs &attrs) override;
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; };
};
@ -175,6 +183,7 @@ class Conv2dOp : public OpaqueOp {
public:
Conv2dOp(const std::string &op, const std::string &node_name) : OpaqueOp("Conv2D", node_name) {}
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
};
@ -183,6 +192,7 @@ class TransposeOp : public OpaqueOp {
public:
TransposeOp(const std::string &op, const std::string &node_name) : OpaqueOp("Transpose", node_name) {}
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
};
@ -191,6 +201,7 @@ class MatMulOp : public OpaqueOp {
public:
MatMulOp(const std::string &op, const std::string &node_name) : OpaqueOp("MatMul", node_name) {}
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
};
@ -199,6 +210,7 @@ class PadAkgOp : public OpaqueOp {
public:
PadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("PadAkg", node_name) {}
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
};
@ -206,6 +218,7 @@ class UnPadAkgOp : public OpaqueOp {
public:
UnPadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("UnPadAkg", node_name) {}
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
};
@ -213,6 +226,7 @@ class CImagOp : public ElemwiseOp {
public:
CImagOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CImag", node_name) {}
protected:
void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
throw GKException("CImag's input[0] should be complex64");
@ -226,6 +240,7 @@ class CRealOp : public ElemwiseOp {
public:
CRealOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CReal", node_name) {}
protected:
void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
throw GKException("CReal's input[0] should be complex64");
@ -239,8 +254,8 @@ class ComplexOp : public ElemwiseOp {
public:
ComplexOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Complex", node_name) {}
protected:
void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override;
TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeComplex64; }
};
} // namespace graphkernel

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,175 +13,141 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/transform_op_optimizer.h"
#include <algorithm>
#include <iostream>
#include <vector>
#include <queue>
#include <memory>
#include <set>
#include <map>
#include <utility>
#include <string>
#include "base/core_ops.h"
#include "ir/graph_utils.h"
#include "debug/common.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/model/lite_graph.h"
#include "backend/optimizer/graph_kernel/model/op_register.h"
namespace mindspore {
namespace opt {
namespace {
enum Format { kFormatUnknown, kFormatA, kFormatB };
enum FormatType { kFormatUnknown, kFormatA, kFormatB };
enum TransOpType { kTransAB, kTransBA };
struct Edge {
size_t from;
size_t to;
size_t val;
size_t next;
Edge(size_t From, size_t To, size_t Val, size_t Next) {
from = From;
to = To;
val = Val;
next = Next;
}
size_t capacity;
};
struct Node {
int head;
int cur;
int depth;
size_t pre;
Format format;
Node() {
head = -1;
cur = -1;
depth = -1;
pre = 0;
format = kFormatB;
}
struct Vertex {
FormatType format{kFormatB};
size_t depth{0};
std::vector<size_t> out_edges;
};
constexpr size_t INF = static_cast<size_t>(1) << 30;
class MinCut {
public:
// Connect the source_node to the node_with_a_certain_formatA
// or Connect the node_with_a_certain_formatB to the sink_node
void Add_1(size_t from, size_t to) {
edges_.emplace_back(from, to, INF, nodes_[from].head);
nodes_[from].head = edges_count_++;
edges_.emplace_back(to, from, 0, nodes_[to].head);
nodes_[to].head = edges_count_++;
private:
// Add a pair of edges (forward and inverse) between the vertices `from` and `to`.
// The two edge ids are adjacent in the vector, x and x+1 (x = 0, 2, 4, ...),
// so (i xor 1) gives the inverse edge of any edge i,
// e.g. edge_0 and edge_1 form a pair: 0^1=1, 1^1=0.
void AddEdge(size_t from, size_t to, size_t capacity, size_t inv_capacity) {
edges_.emplace_back(Edge{to, capacity});
nodes_[from].out_edges.emplace_back(edges_.size() - 1);
// inverse edge
edges_.emplace_back(Edge{from, inv_capacity});
nodes_[to].out_edges.emplace_back(edges_.size() - 1);
}
// Split one origin_node into two new_nodes and connect them
void Add_2(size_t nodes_id) {
edges_.emplace_back(nodes_id, nodes_id + origin_nodes_num_, 1, nodes_[nodes_id].head);
nodes_[nodes_id].head = edges_count_++;
edges_.emplace_back(nodes_id + origin_nodes_num_, nodes_id, 1, nodes_[nodes_id + origin_nodes_num_].head);
nodes_[nodes_id + origin_nodes_num_].head = edges_count_++;
}
// After splitting the origin_nodes, construct the new_edges based on the original_edges
void Add_3(size_t from, size_t to) {
edges_.emplace_back(from + origin_nodes_num_, to, 1, nodes_[from + origin_nodes_num_].head);
nodes_[from + origin_nodes_num_].head = edges_count_++;
edges_.emplace_back(to, from + origin_nodes_num_, 1, nodes_[to].head);
nodes_[to].head = edges_count_++;
}
void BFS() {
bool BfsSetDepth() {
std::queue<size_t> bfs_queue;
nodes_[sink_id_].depth = 0;
bfs_queue.push(sink_id_);
for (auto &node : nodes_) {
node.depth = 0;
}
nodes_[source_id_].depth = 1;
bfs_queue.push(source_id_);
while (!bfs_queue.empty()) {
size_t temp_node = bfs_queue.front();
auto edge_from = bfs_queue.front();
bfs_queue.pop();
depth_num_[nodes_[temp_node].depth]++;
for (size_t i = nodes_[temp_node].head; ~i; i = edges_[i].next) {
if (edges_[i ^ 1].val && nodes_[edges_[i].to].depth == -1) {
nodes_[edges_[i].to].depth = nodes_[temp_node].depth + 1;
bfs_queue.push(edges_[i].to);
for (auto e_id : nodes_[edge_from].out_edges) {
auto edge_to = edges_[e_id].to;
if (edges_[e_id].capacity > 0 && nodes_[edge_to].depth == 0) {
nodes_[edge_to].depth = nodes_[edge_from].depth + 1;
bfs_queue.push(edge_to);
}
}
}
return nodes_[sink_id_].depth > 0;
}
void EdgeValueUpdate() {
size_t k = sink_id_, flow = INF;
while (k != source_id_) {
if (edges_[nodes_[k].pre].val < flow) {
flow = edges_[nodes_[k].pre].val;
size_t DfsMaxFlow(size_t node, size_t flow) {
if (node == sink_id_) return flow;
size_t max_flow = 0;
for (size_t e_id : nodes_[node].out_edges) {
if ((edges_[e_id].capacity > 0) && (nodes_[node].depth + 1 == nodes_[edges_[e_id].to].depth)) {
auto tmp_flow = DfsMaxFlow(edges_[e_id].to, std::min(flow, edges_[e_id].capacity));
if (tmp_flow > 0) {
max_flow += tmp_flow;
flow -= tmp_flow;
edges_[e_id].capacity -= tmp_flow;
edges_[e_id ^ 1].capacity += tmp_flow;
}
}
k = edges_[nodes_[k].pre].from;
}
k = sink_id_;
while (k != source_id_) {
edges_[nodes_[k].pre].val -= flow;
edges_[nodes_[k].pre ^ 1].val += flow;
k = edges_[nodes_[k].pre].from;
}
return max_flow;
}
void ISAP() {
size_t node_id = source_id_;
int maxdep = 2 * origin_nodes_num_ + 2;
BFS();
for (size_t i = source_id_; i <= sink_id_; ++i) {
nodes_[i].cur = nodes_[i].head;
}
while (nodes_[source_id_].depth <= maxdep) {
if (node_id == sink_id_) {
EdgeValueUpdate();
node_id = source_id_;
}
bool can_arrive = false;
for (size_t i = nodes_[node_id].cur; ~i; i = edges_[i].next) {
if (edges_[i].val && nodes_[edges_[i].to].depth + 1 == nodes_[node_id].depth) {
can_arrive = true;
nodes_[edges_[i].to].pre = i;
nodes_[node_id].cur = i;
node_id = edges_[i].to;
break;
}
}
if (!can_arrive) {
int mindep = 2 * origin_nodes_num_ + 2;
for (size_t i = nodes_[node_id].head; ~i; i = edges_[i].next) {
if (nodes_[edges_[i].to].depth < mindep && edges_[i].val) {
mindep = nodes_[edges_[i].to].depth;
}
}
--depth_num_[nodes_[node_id].depth];
if (!depth_num_[nodes_[node_id].depth]) {
break;
}
nodes_[node_id].depth = mindep + 1;
depth_num_[nodes_[node_id].depth]++;
nodes_[node_id].cur = nodes_[node_id].head;
if (node_id != source_id_) {
node_id = edges_[nodes_[node_id].pre].from;
}
}
void Dinic() {
while (BfsSetDepth()) {
(void)DfsMaxFlow(source_id_, INF);
}
}
void SetFormat(size_t node_id) {
nodes_[node_id].format = kFormatA;
for (size_t i = nodes_[node_id].head; ~i; i = edges_[i].next) {
if (edges_[i].val && nodes_[edges_[i].to].format != kFormatA) {
for (size_t i : nodes_[node_id].out_edges) {
if (edges_[i].capacity > 0 && nodes_[edges_[i].to].format != kFormatA) {
SetFormat(edges_[i].to);
}
}
}
MinCut(std::vector<std::pair<size_t, Format>> original_nodes, std::vector<std::pair<size_t, size_t>> original_edges)
: origin_nodes_num_(original_nodes.size()),
sink_id_(2 * origin_nodes_num_ + 1),
depth_num_(std::vector<size_t>(2 * origin_nodes_num_ + 2, 0)),
nodes_(std::vector<Node>(2 * origin_nodes_num_ + 2, Node())),
original_edges_(std::move(original_edges)) {
void BuildGraph(const std::vector<std::pair<size_t, FormatType>> &original_nodes) {
for (size_t i = 0; i < origin_nodes_num_; ++i) {
// link the source node to the nodes with FormatA,
// link the nodes with FormatB to the sink node.
if (original_nodes[i].second == kFormatA) {
Add_1(source_id_, original_nodes[i].first);
AddEdge(source_id_, original_nodes[i].first, INF, 0);
} else if (original_nodes[i].second == kFormatB) {
Add_1(original_nodes[i].first, sink_id_);
AddEdge(original_nodes[i].first, sink_id_, INF, 0);
}
Add_2(original_nodes[i].first);
// each node is split into two parts: an input part and an output part.
// the input part's id is the original node's id; the output part's id is the input id + origin_nodes_num_.
AddEdge(original_nodes[i].first, original_nodes[i].first + origin_nodes_num_, 1, 1);
}
for (auto i : original_edges_) {
Add_3(i.first, i.second);
for (auto e : original_edges_) {
auto from = e.first, to = e.second;
AddEdge(from + origin_nodes_num_, to, 1, 1);
}
ISAP();
}
public:
MinCut(const std::vector<std::pair<size_t, FormatType>> &original_nodes,
const std::vector<std::pair<size_t, size_t>> &original_edges)
: origin_nodes_num_(original_nodes.size()),
sink_id_(2 * original_nodes.size() + 1), // source_id_ is 0
nodes_(2 * original_nodes.size() + 2),  // doubled nodes, plus the source and sink nodes
original_edges_(original_edges) {
BuildGraph(original_nodes);
}
void Run() {
Dinic();
SetFormat(source_id_);
}
@ -213,11 +179,285 @@ class MinCut {
size_t origin_nodes_num_;
size_t source_id_{0};
size_t sink_id_;
int edges_count_{0};
std::vector<size_t> depth_num_;
std::vector<Node> nodes_;
std::vector<Vertex> nodes_;
std::vector<Edge> edges_;
std::vector<std::pair<size_t, size_t>> original_edges_;
};
} // namespace
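To make the min-cut model concrete, the following driver (an illustrative sketch, not part of this commit's diff) feeds a three-node chain into MinCut: node 1 is pinned to format A, node 3 to format B, and node 2 is format-flexible, so exactly one transform op is needed. It assumes the GetOneNodeOps()/GetTwoNodeOps() accessors that Mutator::RebuildLiteGraph calls further below.

// Illustrative sketch only: exercise the MinCut model on a tiny graph.
void MinCutToyExample() {
  // Node ids start from 1 (id 0 is the source). Node 1 produces format A,
  // node 3 requires format B, node 2 is flexible (kFormatUnknown).
  std::vector<std::pair<size_t, FormatType>> nodes = {{1, kFormatA}, {2, kFormatUnknown}, {3, kFormatB}};
  // Data flows 1 -> 2 -> 3.
  std::vector<std::pair<size_t, size_t>> edges = {{1, 2}, {2, 3}};
  MinCut min_cut(nodes, edges);
  min_cut.Run();
  for (const auto &[node_id, trans] : min_cut.GetOneNodeOps()) {
    std::cout << "insert " << (trans == kTransAB ? "A->B" : "B->A") << " transform on the output of node " << node_id
              << std::endl;
  }
  for (const auto &[edge, trans] : min_cut.GetTwoNodeOps()) {
    std::cout << "insert " << (trans == kTransAB ? "A->B" : "B->A") << " transform on edge " << edge.first << "->"
              << edge.second << std::endl;
  }
}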
using graphkernel::LiteGraph;
using graphkernel::LiteGraphPtr;
using graphkernel::Node;
using graphkernel::NodePtr;
using graphkernel::NodePtrList;
using graphkernel::NType;
using graphkernel::PrimOp;
using graphkernel::PrimOpPtr;
class TransformOp {
public:
explicit TransformOp(const NodePtr &node)
: op_(node->As<PrimOp>()->op()), format_a_(node->input(0)->format), format_b_(node->format) {}
bool IsTransformOp(const NodePtr &node) {
if (node->NodeType() != NType::Primitive || node->As<PrimOp>()->op() != op_) {
return false;
}
if (node->input(0)->format == format_a_ && node->format == format_b_) {
return true;
} else if (node->input(0)->format == format_b_ && node->format == format_a_) {
return true;
}
return false;
}
FormatType GetFormatType(const std::string &fmt) {
return fmt == format_a_ ? FormatType::kFormatA : FormatType::kFormatB;
}
NodePtr GenTransformOp(TransOpType trans_type) {
// Only support Transpose now
static std::map<std::pair<std::string, std::string>, std::vector<int64_t>> perm_map = {
{{kOpFormat_DEFAULT, kOpFormat_NHWC}, {0, 2, 3, 1}},
{{kOpFormat_NCHW, kOpFormat_NHWC}, {0, 2, 3, 1}},
{{kOpFormat_NHWC, kOpFormat_NCHW}, {0, 3, 1, 2}},
{{kOpFormat_NHWC, kOpFormat_DEFAULT}, {0, 3, 1, 2}},
};
std::vector<int64_t> perm;
if (trans_type == TransOpType::kTransAB) {
perm = perm_map[{format_a_, format_b_}];
} else {
perm = perm_map[{format_b_, format_a_}];
}
if (perm.empty()) {
std::ostringstream oss;
oss << "unsupported format: " << format_a_ << " to " << format_b_;
throw graphkernel::GKException(oss.str());
}
auto op = graphkernel::OpRegistry::Instance().NewOp("Transpose", "new_trans");
op->SetAttr("perm", MakeValue(perm));
return op;
}
private:
std::string op_;
std::string format_a_;
std::string format_b_;
};
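As a quick sanity check of the perm table above: a Transpose perm is applied as out_shape[i] = in_shape[perm[i]], so {0, 2, 3, 1} maps NCHW to NHWC and {0, 3, 1, 2} maps back. A small illustrative helper (ApplyPerm is hypothetical and not part of the diff):

// Illustrative only: demonstrate how the perm values reorder a shape.
inline std::vector<int64_t> ApplyPerm(const std::vector<int64_t> &shape, const std::vector<int64_t> &perm) {
  std::vector<int64_t> out;
  for (auto axis : perm) {
    out.push_back(shape[axis]);  // out_shape[i] = in_shape[perm[i]]
  }
  return out;
}
// ApplyPerm({8, 3, 224, 224}, {0, 2, 3, 1}) -> {8, 224, 224, 3}   (NCHW to NHWC)
// ApplyPerm({8, 224, 224, 3}, {0, 3, 1, 2}) -> {8, 3, 224, 224}   (NHWC to NCHW)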
bool IsFlexibleOp(const NodePtr &node) {
static std::set<std::string> format_flexible_ops = {
"Abs", "Add", "Sub", "Mul", "Round", "Cast", "Neg", "Exp", "Log",
"Pow", "Minimum", "Maximum", "Rsqrt", "Sqrt", "Reciprocal", "Tanh", "Sin", "Cos",
"Asin", "ACos", "RealDiv", "Equal", "Greater", "GreaterEqual", "Less", "LessEqual", "Sign"};
if (node->NodeType() != NType::Primitive) {
return false;
}
if (format_flexible_ops.count(node->As<PrimOp>()->op()) == 0) {
return false;
}
// Check that the input and output formats are all the same, ignoring constant (Value) inputs.
for (auto &inp : node->inputs()) {
if (inp->NodeType() != NType::Value && inp->format != node->format) {
return false;
}
}
return true;
}
class Mutator {
public:
explicit Mutator(const NodePtr &node) : op_checker_(node), basenode_(node), ori_node_(1) {}
bool Run() {
VisitNode(basenode_);
if (flexible_ops_.empty()) return false;
// remove transform ops in litegraph
RemoveTransOp();
GenFormatGraph();
RebuildLiteGraph();
return true;
}
private:
// visit nodes bidirectionally
void VisitNode(const NodePtr &node) {
if (visited_.count(node) > 0) return;
visited_.insert(node);
if (op_checker_.IsTransformOp(node)) {
trans_ops_.insert(node);
} else if (!IsFlexibleOp(node)) {
if (node->NodeType() != NType::Output) {
fmt_type[{node, -1}] = op_checker_.GetFormatType(node->format);
}
if (node->NodeType() != NType::Parameter) {
for (size_t i = 0; i < node->inputs().size(); i++) {
if (node->input(i)->NodeType() == NType::Value) {
continue;
}
fmt_type[{node, i}] = op_checker_.GetFormatType(node->input(i)->format);
}
}
return;
} else {
flexible_ops_.insert(node);
fmt_type[{node, -1}] = FormatType::kFormatUnknown;
}
for (auto &input : node->inputs()) {
if (input->NodeType() != NType::Value) {
VisitNode(input);
}
}
for (auto &user : node->users()) {
VisitNode(user.first->shared_from_this());
}
}
void RemoveTransOp() {
for (auto &node : trans_ops_) {
visited_.erase(node);
node->ReplaceWith(node->input(0));
// clear inputs, so that the node will not be the basenode again.
node->SetInputs({});
}
trans_ops_.clear();
}
void GenFormatGraph() {
for (auto &node : visited_) {
if (node->NodeType() == NType::Parameter) continue;
bool is_flexible = (flexible_ops_.find(node) != flexible_ops_.end());
size_t cur_id = 0;
if (is_flexible) {
cur_id = GetId({node, -1});
}
for (size_t i = 0; i < node->inputs().size(); i++) {
if (visited_.count(node->input(i)) == 0) continue;
if (!is_flexible) {
cur_id = GetId({node, SizeToInt(i)});
}
auto input_id = GetId({node->input(i), -1});
graph_edges_.emplace_back(input_id, cur_id);
}
}
}
void RebuildLiteGraph() {
MinCut min_cut(graph_vertex_, graph_edges_);
min_cut.Run();
for (auto [node_id, trans_type] : min_cut.GetOneNodeOps()) {
if (ori_node_[node_id].second != -1) {
MS_LOG(EXCEPTION) << "OneNodeOp should be the output edge. node_id:" << node_id
<< " index:" << ori_node_[node_id].second;
}
auto trans_op = op_checker_.GenTransformOp(trans_type);
ori_node_[node_id].first->ReplaceWith(trans_op);
trans_op->SetInputs({ori_node_[node_id].first});
}
std::map<size_t, NodePtr> trans_op_cache;
for (auto [edge, trans_type] : min_cut.GetTwoNodeOps()) {
auto node_id_from = edge.first;
auto node_id_to = edge.second;
if (ori_node_[node_id_from].second != -1) {
MS_LOG(EXCEPTION) << "node_from should be the output edge. node_id:" << node_id_from
<< " index:" << ori_node_[node_id_from].second;
}
auto node_from = ori_node_[node_id_from].first;
auto node_to = ori_node_[node_id_to].first;
auto &trans_op = trans_op_cache[node_id_from];
if (trans_op == nullptr) {
trans_op = op_checker_.GenTransformOp(trans_type);
trans_op->SetInputs({node_from});
}
if (ori_node_[node_id_to].second >= 0) {
node_to->SetInput(ori_node_[node_id_to].second, trans_op);
} else {
for (size_t i = 0; i < node_to->inputs().size(); i++) {
if (node_to->input(i) == node_from) {
node_to->SetInput(i, trans_op);
}
}
}
}
}
size_t GetId(const std::pair<NodePtr, int> &node) {
// the nodes are indexed from 1 in the MinCut model.
auto &id = node_id_[node];
if (id == 0) {
id = node_id_.size();
ori_node_.push_back(node);
// set format_type for new id.
graph_vertex_.emplace_back(id, fmt_type[node]);
}
return id;
}
TransformOp op_checker_;
NodePtr basenode_;
std::set<NodePtr> flexible_ops_;
std::set<NodePtr> trans_ops_;
std::set<NodePtr> visited_;
std::map<std::pair<NodePtr, int>, FormatType> fmt_type;
std::map<std::pair<NodePtr, int>, size_t> node_id_;
std::vector<std::pair<NodePtr, int>> ori_node_;
std::vector<std::pair<size_t, FormatType>> graph_vertex_;
std::vector<std::pair<size_t, size_t>> graph_edges_;
};
bool TransformOpOptimizer::Process(const LiteGraphPtr &litegraph, const std::string &trans_op_name) {
ori_trans_op_num_ = 0;
auto &ops = litegraph->ops();
std::set<NodePtr> visited;
bool changed = false;
auto check_is_trans_op = [&trans_op_name](const NodePtr &node) { return node->As<PrimOp>()->op() == trans_op_name; };
auto ori_trans_op_num = std::count_if(ops.begin(), ops.end(), check_is_trans_op);
for (auto &op : ops) {
if (check_is_trans_op(op) && !op->inputs().empty() && op->input(0)->format != op->format) {
auto mutator = Mutator(op);
changed = mutator.Run() || changed;
}
}
if (!changed) return false;
auto &new_ops = litegraph->GetOrderedNodes();
auto new_trans_op_num = std::count_if(new_ops.begin(), new_ops.end(), check_is_trans_op);
if (new_trans_op_num >= ori_trans_op_num) {
return false;
}
for (auto &op : new_ops) {
op->SetBaseInfo(op->As<PrimOp>()->Infer(op->inputs(), op->attrs()));
}
return true;
}
bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) {
auto mng = kernel_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto todos = TopoSort(kernel_graph->get_return());
bool changed = false;
for (auto node : todos) {
if (!AnfAlgo::IsGraphKernel(node)) continue;
try {
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto litegraph = AnfGraph2LiteGraph(sub_func_graph);
if (Process(litegraph)) {
changed = true;
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
}
} catch (const graphkernel::GKException &e) {
MS_LOG(WARNING) << e.what() << ", so the transform op optimization is skipped for this graph";
}
}
return changed;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,62 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TRANSFORM_OP_OPTIMIZER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TRANSFORM_OP_OPTIMIZER_H_
#include <string>
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "ir/func_graph.h"
#include "backend/optimizer/graph_kernel/model/lite_graph.h"
namespace mindspore {
namespace opt {
/**
* @brief Eliminate the unnecessary transformation ops when the other operators
* are format flexible.
* @example
* %1 = Transpose(p0) // NCHW to NHWC
* %2 = Transpose(p1) // NCHW to NHWC
* %3 = Add(%1, %2)
* return %3
* -->
* %1 = Add(p0, p1)
* %2 = Transpose(%1) // NCHW to NHWC
* return %2
* @example
* %1 = Transpose(p0) // NCHW to NHWC
* %2 = Transpose(p1) // NCHW to NHWC
* %3 = Add(%1, %2)
* %4 = Transpose(%3) // NHWC to NCHW
* return %4
* -->
* %1 = Add(p0, p1)
* return %1
*/
class TransformOpOptimizer : public Pass {
public:
TransformOpOptimizer() : Pass("transform_op_optimizer") {}
~TransformOpOptimizer() = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
bool Process(const graphkernel::LiteGraphPtr &litegraph, const std::string &trans_op_name = "Transpose");
bool IsFlexibleOp(const graphkernel::NodePtr &node);
size_t ori_trans_op_num_{0};
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TRANSFORM_OP_OPTIMIZER_H_

View File

@ -165,7 +165,7 @@ void GraphKernelFlags::Refresh() {
}
}
// Dump flags so that people can check the setting.
MS_LOG(INFO) << "GraphKernelFlags info: " << DumpAllFlags();
MS_LOG(INFO) << "graph_kernel_flags = \"" << flags_cache_ << "\", all flags: " << DumpAllFlags();
}
void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {
@ -188,10 +188,11 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
reg.AddFlag("enable_recompute_fusion", &enable_recompute_fusion, opt_level >= OptLevel_2);
reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion, opt_level == OptLevel_3);
reg.AddFlag("enable_low_precision", &enable_low_precision);
reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_gpu ? OpLevel_MAX : OpLevel_0);
reg.AddFlag("enable_trans_op_optimize", &enable_trans_op_optimize);
// Integer flags
reg.AddFlag("online_tuning", &online_tuning);
reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_gpu ? OpLevel_MAX : OpLevel_0);
// String flags
reg.AddFlag("repository_path", &repository_path);
@ -217,6 +218,7 @@ std::string GraphKernelFlags::DumpAllFlags() const {
json["enable_recompute_fusion"] = enable_recompute_fusion;
json["enable_parallel_fusion"] = enable_parallel_fusion;
json["enable_low_precision"] = enable_low_precision;
json["enable_trans_op_optimize"] = enable_trans_op_optimize;
json["opt_level"] = opt_level;
json["fusion_ops_level"] = fusion_ops_level;

View File

@ -95,6 +95,13 @@ class GraphKernelFlags {
*/
unsigned int fusion_ops_level{OpLevel_0};
/**
* Enable optimization for transform operators (Transpose/TransData)
*
* Experimental feature, enabled by default when opt_level=3.
*/
bool enable_trans_op_optimize{false};
/**
* Optimization level, value from 0 to 3.
* 0: Disable GraphKernel

View File

@ -63,7 +63,7 @@ def set_graph_kernel_context():
if context.get_context("device_target") == "GPU":
context.set_context(enable_graph_kernel=True)
context.set_context(graph_kernel_flags="--enable_parallel_fusion "
"--disable_expand_ops=BatchNorm,BatchNormGrad "
"--enable_trans_op_optimize "
"--disable_cluster_ops=ReduceMax,Reshape "
"--enable_expand_ops=Conv2D")