From e6b1855174fe00df9b64c23987c276692576834f Mon Sep 17 00:00:00 2001 From: dayschan Date: Tue, 17 Aug 2021 19:09:43 +0800 Subject: [PATCH] Completed pass transform_op_optimizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Modifications for pass transform_op_optimizer: 1. Changed the max-flow/min-cut algorithm to Dinic's algorithm, since a bug exists in the original ISAP code. If the algorithm turns out to be slow, further optimizations (e.g. the current-arc optimization) can be applied later. 2. Added the pass TransformOpOptimizer at OptLevel_3. This pass collects the nodes around a specific transform operator (only Transpose for now), uses the min-cut algorithm to get a transformation plan, then re-links the original graph and re-infers the shape and format of the graph (see the illustrative sketch appended after the patch). Modifications for litegraph: 1. The class Node now inherits from std::enable_shared_from_this, so a shared_ptr can be obtained from a raw pointer. 2. Modified the Infer interface: it no longer modifies the node; it only infers the base info and returns it. --- .../graph_kernel/arithmetic_simplify.cc | 8 +- .../graph_kernel/graph_kernel_optimization.cc | 7 +- .../graph_kernel/model/lite_graph.cc | 6 +- .../optimizer/graph_kernel/model/node.cc | 3 - .../optimizer/graph_kernel/model/node.h | 13 +- .../optimizer/graph_kernel/model/op_node.cc | 18 +- .../optimizer/graph_kernel/model/op_node.h | 39 +- .../graph_kernel/transform_op_optimizer.cc | 498 +++++++++++++----- .../graph_kernel/transform_op_optimizer.h | 62 +++ .../ccsrc/utils/context/graph_kernel_flags.cc | 6 +- .../ccsrc/utils/context/graph_kernel_flags.h | 7 + .../official/cv/yolov3_darknet53/train.py | 2 +- 12 files changed, 500 insertions(+), 169 deletions(-) create mode 100644 mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.h diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc index 4ea6b813056..dd747dbb430 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc @@ -624,13 +624,13 @@ void ReorganizeEmptyGraph(const graphkernel::LiteGraphPtr &litegraph) { auto &outputs = litegraph->GetOutputs(); for (size_t i = 0; i < outputs.size(); i++) { if (outputs[i]->NodeType() == graphkernel::NType::Value) { - graphkernel::PrimOpPtr op_ptr = std::make_shared("BroadcastTo", ""); + graphkernel::LiteGraph::GraphBuilder gb; std::vector new_shape = {1}; - op_ptr->Infer({outputs[i]}, {{"shape", MakeValue(new_shape)}}); + auto op_ptr = gb.Emit("BroadcastTo", {outputs[i]}, {{"shape", MakeValue(new_shape)}}); litegraph->output()->SetInput(i, op_ptr); } else if (outputs[i]->NodeType() == graphkernel::NType::Parameter) { - graphkernel::PrimOpPtr op_ptr = std::make_shared("Reshape", ""); - op_ptr->Infer({outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}}); + graphkernel::LiteGraph::GraphBuilder gb; + auto op_ptr = gb.Emit("Reshape", {outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}}); litegraph->output()->SetInput(i, op_ptr); } } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc index af6d6246dfe..a768c532cdf 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc @@ -48,6 +48,7 @@ #include "backend/optimizer/graph_kernel/uss_atomic_add.h" 
#include "backend/optimizer/pass/getitem_tuple.h" #include "backend/optimizer/graph_kernel/graph_kernel_pass_manager.h" +#include "backend/optimizer/graph_kernel/transform_op_optimizer.h" #include "backend/optimizer/graph_kernel/rewrite_output_shape.h" namespace mindspore { @@ -123,8 +124,12 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() const { // Common subexpression elimination pm->AddPass(std::make_shared(), OptLevel_2); - // Elimate Redundant Complex op + // Eliminate Redundant Complex op pm->AddPass(std::make_shared(), OptLevel_2, false); + + // Eliminate unnecessary transform ops + auto level = GetPassLevelByFlag(context::GraphKernelFlags::GetInstance().enable_trans_op_optimize); + pm->AddPass(std::make_shared(), level, is_gpu); return pm; } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc index d113064a337..f11e333ac20 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc @@ -102,12 +102,16 @@ NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList & std::string node_name) { if (node_name.empty()) node_name = NewName(); PrimOpPtr op_ptr = CreateOp(op, node_name); - op_ptr->Infer(inputs, attrs); + auto baseinfo = op_ptr->Infer(inputs, attrs); + op_ptr->SetInputs(inputs); + op_ptr->SetAttrs(attrs); + op_ptr->SetBaseInfo(baseinfo); return graph_->Add(op_ptr); } NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, const DAttrs &attrs, std::string node_name) { + if (node_name.empty()) node_name = NewName(); PrimOpPtr op_ptr = CreateOp(op, node_name); op_ptr->SetInputs(inputs); op_ptr->SetAttrs(attrs); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.cc index 7595c974391..d49699bde2f 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.cc @@ -76,9 +76,6 @@ void Node::SetInputs(const NodePtrList &inputs) { void Node::ReplaceWith(const NodePtr &other_node) { if (this->users_.empty()) return; - if (this->NodeType() != NType::Primitive) { - MS_LOG(EXCEPTION) << "Only Primitive node can be replaced, but the node type is " << NodeType(); - } // copy the users before traversal auto users = this->users_; for (auto &user : users) { diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.h index 7c34218f14e..f34149f530f 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.h @@ -59,7 +59,7 @@ struct NodeBase { class Node; using NodePtr = std::shared_ptr; using NodePtrList = std::vector; -class Node : public NodeBase { +class Node : public NodeBase, public std::enable_shared_from_this { public: Node(const NodeBase &baseinfo, const std::string &name) : NodeBase(baseinfo), name_(name) {} virtual ~Node() { @@ -90,16 +90,13 @@ class Node : public NodeBase { void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; } template - T *As() { - return static_cast(this); - } - template - const T *As() const { - return static_cast(this); + std::shared_ptr As() { + return std::static_pointer_cast(shared_from_this()); } const std::string &name() const { return name_; } const 
DAttrs &attrs() const { return attrs_; } + const NodePtr &input(size_t i) const { return inputs_[i]; } const NodePtrList &inputs() const { return inputs_; } const std::unordered_map> &users() const { return users_; } @@ -145,7 +142,7 @@ class ParamNode : public Node { class OutputNode : public Node { public: - OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, "") {} + OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, "Output") {} void Dump(std::ostringstream &os) const override { ; } NType NodeType() override { return NType::Output; } }; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc index 1ec5b2f2b3c..16e8a0b231c 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc @@ -81,13 +81,14 @@ void PrimOp::CheckFormat(const NodePtrList &inputs, const DAttrs &attrs) { } } } -void PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { + +NodeBase PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { Check(inputs, attrs); - this->shape = InferShape(inputs, attrs); - this->type = InferType(inputs, attrs); - this->format = InferFormat(inputs, attrs); - this->attrs_ = attrs; - SetInputs(inputs); + NodeBase nodebase; + nodebase.shape = InferShape(inputs, attrs); + nodebase.type = InferType(inputs, attrs); + nodebase.format = InferFormat(inputs, attrs); + return nodebase; } void PrimOp::Dump(std::ostringstream &os) const { @@ -279,8 +280,8 @@ DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) return it == inputs.end() ? kOpFormat_DEFAULT : (*it)->format; } -void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { - PrimOp::Infer(inputs, attrs); +NodeBase ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { + auto nodebase = PrimOp::Infer(inputs, attrs); auto IsBroadcast = [this](const NodePtrList &inputs) -> bool { for (auto &ref : inputs) { if (ref->shape.size() != this->shape.size()) return true; @@ -291,6 +292,7 @@ void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { return false; }; compute_type_ = IsBroadcast(inputs) ? 
BROADCAST : ELEMWISE; + return nodebase; } TypeId CastOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) { diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h index fd59c677ce8..b563bf3cc52 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h @@ -50,16 +50,8 @@ class PrimOp : public Node { PrimOp(const std::string &op, const std::string &node_name, ComputeType compute) : Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, node_name), op_(op), compute_type_(compute) {} - virtual void Check(const NodePtrList &inputs, const DAttrs &attrs); - virtual void CheckShape(const NodePtrList &inputs, const DAttrs &attrs) {} - virtual void CheckType(const NodePtrList &inputs, const DAttrs &attrs); - virtual void CheckFormat(const NodePtrList &inputs, const DAttrs &attrs); - - virtual void Infer(const NodePtrList &inputs, const DAttrs &attrs); + virtual NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs); virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op); - virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; } - virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; } - virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; } void Dump(std::ostringstream &os) const override; NType NodeType() override { return NType::Primitive; } @@ -68,6 +60,15 @@ class PrimOp : public Node { ComputeType compute_type() const { return compute_type_; } protected: + virtual void Check(const NodePtrList &inputs, const DAttrs &attrs); + virtual void CheckShape(const NodePtrList &inputs, const DAttrs &attrs) {} + virtual void CheckType(const NodePtrList &inputs, const DAttrs &attrs); + virtual void CheckFormat(const NodePtrList &inputs, const DAttrs &attrs); + + virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; } + virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; } + virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; } + std::string op_; ComputeType compute_type_; }; @@ -76,8 +77,9 @@ using PrimOpPtr = std::shared_ptr; class ElemwiseOp : public PrimOp { public: ElemwiseOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, ELEMWISE) {} + NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs) override; - void Infer(const NodePtrList &inputs, const DAttrs &attrs) override; + protected: DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; }; @@ -86,6 +88,7 @@ class CastOp : public ElemwiseOp { public: CastOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Cast", node_name) {} + protected: TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override; }; @@ -93,6 +96,7 @@ class InplaceAssignOp : public ElemwiseOp { public: InplaceAssignOp(const std::string &op, const std::string &node_name) : ElemwiseOp("InplaceAssign", node_name) {} + protected: DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->shape; } TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { 
return inputs[2]->type; } DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->format; } @@ -102,6 +106,7 @@ class SelectOp : public ElemwiseOp { public: SelectOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Select", node_name) {} + protected: void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override; TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[1]->type; } }; @@ -110,6 +115,7 @@ class CompareOp : public ElemwiseOp { public: CompareOp(const std::string &op, const std::string &node_name) : ElemwiseOp(op, node_name) {} + protected: TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeBool; } }; @@ -142,6 +148,7 @@ class ReshapeOp : public PrimOp { public: ReshapeOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, RESHAPE) {} + protected: DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT @@ -153,6 +160,7 @@ class BroadcastToOp : public PrimOp { public: BroadcastToOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, BROADCAST) {} + protected: DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; }; @@ -160,8 +168,8 @@ class ReduceOp : public PrimOp { public: ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {} + protected: void Check(const NodePtrList &inputs, const DAttrs &attrs) override; - DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; }; }; @@ -175,6 +183,7 @@ class Conv2dOp : public OpaqueOp { public: Conv2dOp(const std::string &op, const std::string &node_name) : OpaqueOp("Conv2D", node_name) {} + protected: DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override; }; @@ -183,6 +192,7 @@ class TransposeOp : public OpaqueOp { public: TransposeOp(const std::string &op, const std::string &node_name) : OpaqueOp("Transpose", node_name) {} + protected: DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; }; @@ -191,6 +201,7 @@ class MatMulOp : public OpaqueOp { public: MatMulOp(const std::string &op, const std::string &node_name) : OpaqueOp("MatMul", node_name) {} + protected: DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override; }; @@ -199,6 +210,7 @@ class PadAkgOp : public OpaqueOp { public: PadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("PadAkg", node_name) {} + protected: DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; }; @@ -206,6 +218,7 @@ class UnPadAkgOp : public OpaqueOp { public: UnPadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("UnPadAkg", node_name) {} + protected: DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; }; @@ -213,6 +226,7 @@ class CImagOp : public ElemwiseOp { public: CImagOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CImag", node_name) {} + protected: void 
CheckType(const NodePtrList &inputs, const DAttrs &attrs) override { if (inputs[0]->type != TypeId::kNumberTypeComplex64) { throw GKException("CImag's input[0] should be complex64"); @@ -226,6 +240,7 @@ class CRealOp : public ElemwiseOp { public: CRealOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CReal", node_name) {} + protected: void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override { if (inputs[0]->type != TypeId::kNumberTypeComplex64) { throw GKException("CReal's input[0] should be complex64"); @@ -239,8 +254,8 @@ class ComplexOp : public ElemwiseOp { public: ComplexOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Complex", node_name) {} + protected: void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override; - TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeComplex64; } }; } // namespace graphkernel diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc index 26ab54e46b3..ce8f2308aea 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,175 +13,141 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#include "backend/optimizer/graph_kernel/transform_op_optimizer.h" +#include #include #include #include +#include +#include +#include +#include +#include +#include "base/core_ops.h" +#include "ir/graph_utils.h" +#include "debug/common.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" +#include "backend/optimizer/graph_kernel/model/lite_graph.h" +#include "backend/optimizer/graph_kernel/model/op_register.h" +namespace mindspore { +namespace opt { namespace { -enum Format { kFormatUnknown, kFormatA, kFormatB }; +enum FormatType { kFormatUnknown, kFormatA, kFormatB }; enum TransOpType { kTransAB, kTransBA }; - struct Edge { - size_t from; size_t to; - size_t val; - size_t next; - Edge(size_t From, size_t To, size_t Val, size_t Next) { - from = From; - to = To; - val = Val; - next = Next; - } + size_t capacity; }; -struct Node { - int head; - int cur; - int depth; - size_t pre; - Format format; - Node() { - head = -1; - cur = -1; - depth = -1; - pre = 0; - format = kFormatB; - } +struct Vertex { + FormatType format{kFormatB}; + size_t depth{0}; + std::vector out_edges; }; constexpr size_t INF = static_cast(1) << 30; class MinCut { - public: - // Connect the source_node to the node_with_a_certain_formatA - // or Connect the node_with_a_certain_formatB to the sink_node - void Add_1(size_t from, size_t to) { - edges_.emplace_back(from, to, INF, nodes_[from].head); - nodes_[from].head = edges_count_++; - edges_.emplace_back(to, from, 0, nodes_[to].head); - nodes_[to].head = edges_count_++; + private: + // Add the bidirectional edges for the vertex `from` and `to`. + // the two edge ids are adjacent in vector, x and x+1 (x are 0,2,4,...) + // we can use (i xor 1) to get the inverse edge for any edge i. + // e.g. 
edge_0 and edge_1 are a couple, 0^1=1, 1^1=0. + void AddEdge(size_t from, size_t to, size_t capacity, size_t inv_capacity) { + edges_.emplace_back(Edge{to, capacity}); + nodes_[from].out_edges.emplace_back(edges_.size() - 1); + // inverse edge + edges_.emplace_back(Edge{from, inv_capacity}); + nodes_[to].out_edges.emplace_back(edges_.size() - 1); } - // Split one origin_node into two new_nodes and connect them - void Add_2(size_t nodes_id) { - edges_.emplace_back(nodes_id, nodes_id + origin_nodes_num_, 1, nodes_[nodes_id].head); - nodes_[nodes_id].head = edges_count_++; - edges_.emplace_back(nodes_id + origin_nodes_num_, nodes_id, 1, nodes_[nodes_id + origin_nodes_num_].head); - nodes_[nodes_id + origin_nodes_num_].head = edges_count_++; - } - - // After splitting the origin_nodes, construct the new_edges based on the original_edges - void Add_3(size_t from, size_t to) { - edges_.emplace_back(from + origin_nodes_num_, to, 1, nodes_[from + origin_nodes_num_].head); - nodes_[from + origin_nodes_num_].head = edges_count_++; - edges_.emplace_back(to, from + origin_nodes_num_, 1, nodes_[to].head); - nodes_[to].head = edges_count_++; - } - - void BFS() { + bool BfsSetDepth() { std::queue bfs_queue; - nodes_[sink_id_].depth = 0; - bfs_queue.push(sink_id_); + for (auto &node : nodes_) { + node.depth = 0; + } + nodes_[source_id_].depth = 1; + bfs_queue.push(source_id_); while (!bfs_queue.empty()) { - size_t temp_node = bfs_queue.front(); + auto edge_from = bfs_queue.front(); bfs_queue.pop(); - depth_num_[nodes_[temp_node].depth]++; - for (size_t i = nodes_[temp_node].head; ~i; i = edges_[i].next) { - if (edges_[i ^ 1].val && nodes_[edges_[i].to].depth == -1) { - nodes_[edges_[i].to].depth = nodes_[temp_node].depth + 1; - bfs_queue.push(edges_[i].to); + for (auto e_id : nodes_[edge_from].out_edges) { + auto edge_to = edges_[e_id].to; + if (edges_[e_id].capacity > 0 && nodes_[edge_to].depth == 0) { + nodes_[edge_to].depth = nodes_[edge_from].depth + 1; + bfs_queue.push(edge_to); } } } + return nodes_[sink_id_].depth > 0; } - void EdgeValueUpdate() { - size_t k = sink_id_, flow = INF; - while (k != source_id_) { - if (edges_[nodes_[k].pre].val < flow) { - flow = edges_[nodes_[k].pre].val; + size_t DfsMaxFlow(size_t node, size_t flow) { + if (node == sink_id_) return flow; + size_t max_flow = 0; + for (size_t e_id : nodes_[node].out_edges) { + if ((edges_[e_id].capacity > 0) && (nodes_[node].depth + 1 == nodes_[edges_[e_id].to].depth)) { + auto tmp_flow = DfsMaxFlow(edges_[e_id].to, std::min(flow, edges_[e_id].capacity)); + if (tmp_flow > 0) { + max_flow += tmp_flow; + flow -= tmp_flow; + edges_[e_id].capacity -= tmp_flow; + edges_[e_id ^ 1].capacity += tmp_flow; + } } - k = edges_[nodes_[k].pre].from; - } - k = sink_id_; - while (k != source_id_) { - edges_[nodes_[k].pre].val -= flow; - edges_[nodes_[k].pre ^ 1].val += flow; - k = edges_[nodes_[k].pre].from; } + return max_flow; } - void ISAP() { - size_t node_id = source_id_; - int maxdep = 2 * origin_nodes_num_ + 2; - BFS(); - for (size_t i = source_id_; i <= sink_id_; ++i) { - nodes_[i].cur = nodes_[i].head; - } - while (nodes_[source_id_].depth <= maxdep) { - if (node_id == sink_id_) { - EdgeValueUpdate(); - node_id = source_id_; - } - bool can_arrive = false; - for (size_t i = nodes_[node_id].cur; ~i; i = edges_[i].next) { - if (edges_[i].val && nodes_[edges_[i].to].depth + 1 == nodes_[node_id].depth) { - can_arrive = true; - nodes_[edges_[i].to].pre = i; - nodes_[node_id].cur = i; - node_id = edges_[i].to; - break; - } - } - if (!can_arrive) { - 
int mindep = 2 * origin_nodes_num_ + 2; - for (size_t i = nodes_[node_id].head; ~i; i = edges_[i].next) { - if (nodes_[edges_[i].to].depth < mindep && edges_[i].val) { - mindep = nodes_[edges_[i].to].depth; - } - } - --depth_num_[nodes_[node_id].depth]; - if (!depth_num_[nodes_[node_id].depth]) { - break; - } - nodes_[node_id].depth = mindep + 1; - depth_num_[nodes_[node_id].depth]++; - nodes_[node_id].cur = nodes_[node_id].head; - if (node_id != source_id_) { - node_id = edges_[nodes_[node_id].pre].from; - } - } - } + void Dinic() { + while (BfsSetDepth()) { + (void)DfsMaxFlow(source_id_, INF); } } void SetFormat(size_t node_id) { nodes_[node_id].format = kFormatA; - for (size_t i = nodes_[node_id].head; ~i; i = edges_[i].next) { - if (edges_[i].val && nodes_[edges_[i].to].format != kFormatA) { + for (size_t i : nodes_[node_id].out_edges) { + if (edges_[i].capacity > 0 && nodes_[edges_[i].to].format != kFormatA) { SetFormat(edges_[i].to); } } } - MinCut(std::vector> original_nodes, std::vector> original_edges) - : origin_nodes_num_(original_nodes.size()), - sink_id_(2 * origin_nodes_num_ + 1), - depth_num_(std::vector(2 * origin_nodes_num_ + 2, 0)), - nodes_(std::vector(2 * origin_nodes_num_ + 2, Node())), - original_edges_(std::move(original_edges)) { + void BuildGraph(const std::vector> &original_nodes) { for (size_t i = 0; i < origin_nodes_num_; ++i) { + // link the source node to the nodes with FormatA, + // link the nodes with FormatB to the sink node. if (original_nodes[i].second == kFormatA) { - Add_1(source_id_, original_nodes[i].first); + AddEdge(source_id_, original_nodes[i].first, INF, 0); } else if (original_nodes[i].second == kFormatB) { - Add_1(original_nodes[i].first, sink_id_); + AddEdge(original_nodes[i].first, sink_id_, INF, 0); } - Add_2(original_nodes[i].first); + // each node is split into two parts, an input part and an output part. + // the input part's id is the original node's id, the output part's id is the input id + origin_nodes_num_. 
+ AddEdge(original_nodes[i].first, original_nodes[i].first + origin_nodes_num_, 1, 1); } - for (auto i : original_edges_) { - Add_3(i.first, i.second); + for (auto e : original_edges_) { + auto from = e.first, to = e.second; + AddEdge(from + origin_nodes_num_, to, 1, 1); } - ISAP(); + } + + public: + MinCut(const std::vector> &original_nodes, + const std::vector> &original_edges) + : origin_nodes_num_(original_nodes.size()), + sink_id_(2 * original_nodes.size() + 1), // source_id_ is 0 + nodes_(2 * original_nodes.size() + 2), // double nodes, and source_node/sink_node + original_edges_(original_edges) { + BuildGraph(original_nodes); + } + + void Run() { + Dinic(); SetFormat(source_id_); } @@ -213,11 +179,285 @@ class MinCut { size_t origin_nodes_num_; size_t source_id_{0}; size_t sink_id_; - int edges_count_{0}; - std::vector depth_num_; - std::vector nodes_; + std::vector nodes_; std::vector edges_; std::vector> original_edges_; }; - } // namespace + +using graphkernel::LiteGraph; +using graphkernel::LiteGraphPtr; +using graphkernel::Node; +using graphkernel::NodePtr; +using graphkernel::NodePtrList; +using graphkernel::NType; +using graphkernel::PrimOp; +using graphkernel::PrimOpPtr; + +class TransformOp { + public: + explicit TransformOp(const NodePtr &node) + : op_(node->As()->op()), format_a_(node->input(0)->format), format_b_(node->format) {} + bool IsTransformOp(const NodePtr &node) { + if (node->NodeType() != NType::Primitive || node->As()->op() != op_) { + return false; + } + if (node->input(0)->format == format_a_ && node->format == format_b_) { + return true; + } else if (node->input(0)->format == format_b_ && node->format == format_a_) { + return true; + } + return false; + } + + FormatType GetFormatType(const std::string &fmt) { + return fmt == format_a_ ? FormatType::kFormatA : FormatType::kFormatB; + } + + NodePtr GenTransformOp(TransOpType trans_type) { + // Only support Transpose now + static std::map, std::vector> perm_map = { + {{kOpFormat_DEFAULT, kOpFormat_NHWC}, {0, 2, 3, 1}}, + {{kOpFormat_NCHW, kOpFormat_NHWC}, {0, 2, 3, 1}}, + {{kOpFormat_NHWC, kOpFormat_NCHW}, {0, 3, 1, 2}}, + {{kOpFormat_NHWC, kOpFormat_DEFAULT}, {0, 3, 1, 2}}, + }; + std::vector perm; + if (trans_type == TransOpType::kTransAB) { + perm = perm_map[{format_a_, format_b_}]; + } else { + perm = perm_map[{format_b_, format_a_}]; + } + if (perm.empty()) { + std::ostringstream oss; + oss << "unsupported format: " << format_a_ << " to " << format_b_; + throw graphkernel::GKException(oss.str()); + } + auto op = graphkernel::OpRegistry::Instance().NewOp("Transpose", "new_trans"); + op->SetAttr("perm", MakeValue(perm)); + return op; + } + + private: + std::string op_; + std::string format_a_; + std::string format_b_; +}; + +bool IsFlexibleOp(const NodePtr &node) { + static std::set format_flexible_ops = { + "Abs", "Add", "Sub", "Mul", "Round", "Cast", "Neg", "Exp", "Log", + "Pow", "Minimum", "Maximum", "Rsqrt", "Sqrt", "Reciprocal", "Tanh", "Sin", "Cos", + "Asin", "ACos", "RealDiv", "Equal", "Greater", "GreaterEqual", "Less", "LessEqual", "Sign"}; + if (node->NodeType() != NType::Primitive) { + return false; + } + if (format_flexible_ops.count(node->As()->op()) == 0) { + return false; + } + // check the input and output formats are all the same, except ConstValue. 
+ for (auto &inp : node->inputs()) { + if (inp->NodeType() != NType::Value && inp->format != node->format) { + return false; + } + } + return true; +} + +class Mutator { + public: + explicit Mutator(const NodePtr &node) : op_checker_(node), basenode_(node), ori_node_(1) {} + bool Run() { + VisitNode(basenode_); + if (flexible_ops_.empty()) return false; + // remove transform ops in litegraph + RemoveTransOp(); + GenFormatGraph(); + RebuildLiteGraph(); + return true; + } + + private: + // visit nodes bidirectionally + void VisitNode(const NodePtr &node) { + if (visited_.count(node) > 0) return; + visited_.insert(node); + if (op_checker_.IsTransformOp(node)) { + trans_ops_.insert(node); + } else if (!IsFlexibleOp(node)) { + if (node->NodeType() != NType::Output) { + fmt_type[{node, -1}] = op_checker_.GetFormatType(node->format); + } + if (node->NodeType() != NType::Parameter) { + for (size_t i = 0; i < node->inputs().size(); i++) { + if (node->input(i)->NodeType() == NType::Value) { + continue; + } + fmt_type[{node, i}] = op_checker_.GetFormatType(node->input(i)->format); + } + } + return; + } else { + flexible_ops_.insert(node); + fmt_type[{node, -1}] = FormatType::kFormatUnknown; + } + + for (auto &input : node->inputs()) { + if (input->NodeType() != NType::Value) { + VisitNode(input); + } + } + for (auto &user : node->users()) { + VisitNode(user.first->shared_from_this()); + } + } + + void RemoveTransOp() { + for (auto &node : trans_ops_) { + visited_.erase(node); + node->ReplaceWith(node->input(0)); + // clear inputs, so that the node will not be the basenode again. + node->SetInputs({}); + } + trans_ops_.clear(); + } + + void GenFormatGraph() { + for (auto &node : visited_) { + if (node->NodeType() == NType::Parameter) continue; + bool is_flexible = (flexible_ops_.find(node) != flexible_ops_.end()); + size_t cur_id = 0; + if (is_flexible) { + cur_id = GetId({node, -1}); + } + for (size_t i = 0; i < node->inputs().size(); i++) { + if (visited_.count(node->input(i)) == 0) continue; + if (!is_flexible) { + cur_id = GetId({node, SizeToInt(i)}); + } + auto input_id = GetId({node->input(i), -1}); + graph_edges_.emplace_back(input_id, cur_id); + } + } + } + + void RebuildLiteGraph() { + MinCut min_cut(graph_vertex_, graph_edges_); + min_cut.Run(); + for (auto [node_id, trans_type] : min_cut.GetOneNodeOps()) { + if (ori_node_[node_id].second != -1) { + MS_LOG(EXCEPTION) << "OneNodeOp should be the output edge. node_id:" << node_id + << " index:" << ori_node_[node_id].second; + } + auto trans_op = op_checker_.GenTransformOp(trans_type); + ori_node_[node_id].first->ReplaceWith(trans_op); + trans_op->SetInputs({ori_node_[node_id].first}); + } + + std::map trans_op_cache; + for (auto [edge, trans_type] : min_cut.GetTwoNodeOps()) { + auto node_id_from = edge.first; + auto node_id_to = edge.second; + if (ori_node_[node_id_from].second != -1) { + MS_LOG(EXCEPTION) << "node_from should be the output edge. 
node_id:" << node_id_from + << " index:" << ori_node_[node_id_from].second; + } + auto node_from = ori_node_[node_id_from].first; + auto node_to = ori_node_[node_id_to].first; + auto &trans_op = trans_op_cache[node_id_from]; + if (trans_op == nullptr) { + trans_op = op_checker_.GenTransformOp(trans_type); + trans_op->SetInputs({node_from}); + } + if (ori_node_[node_id_to].second >= 0) { + node_to->SetInput(ori_node_[node_id_to].second, trans_op); + } else { + for (size_t i = 0; i < node_to->inputs().size(); i++) { + if (node_to->input(i) == node_from) { + node_to->SetInput(i, trans_op); + } + } + } + } + } + + size_t GetId(const std::pair &node) { + // the nodes are indexed from 1 in the MinCut model. + auto &id = node_id_[node]; + if (id == 0) { + id = node_id_.size(); + ori_node_.push_back(node); + // set format_type for new id. + graph_vertex_.emplace_back(id, fmt_type[node]); + } + return id; + } + + TransformOp op_checker_; + NodePtr basenode_; + std::set flexible_ops_; + std::set trans_ops_; + std::set visited_; + + std::map, FormatType> fmt_type; + std::map, size_t> node_id_; + std::vector> ori_node_; + std::vector> graph_vertex_; + std::vector> graph_edges_; +}; + +bool TransformOpOptimizer::Process(const LiteGraphPtr &litegraph, const std::string &trans_op_name) { + ori_trans_op_num_ = 0; + auto &ops = litegraph->ops(); + std::set visited; + bool changed = true; + auto check_is_trans_op = [&trans_op_name](const NodePtr &node) { return node->As()->op() == trans_op_name; }; + auto ori_trans_op_num = std::count_if(ops.begin(), ops.end(), check_is_trans_op); + for (auto &op : ops) { + if (check_is_trans_op(op) && !op->inputs().empty() && op->input(0)->format != op->format) { + auto mutator = Mutator(op); + changed = mutator.Run() || changed; + } + } + if (!changed) return false; + auto &new_ops = litegraph->GetOrderedNodes(); + auto new_trans_op_num = std::count_if(new_ops.begin(), new_ops.end(), check_is_trans_op); + if (new_trans_op_num >= ori_trans_op_num) { + return false; + } + for (auto &op : new_ops) { + op->SetBaseInfo(op->As()->Infer(op->inputs(), op->attrs())); + } + return true; +} + +bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) { + auto mng = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + auto todos = TopoSort(kernel_graph->get_return()); + bool changed = false; + for (auto node : todos) { + if (!AnfAlgo::IsGraphKernel(node)) continue; + try { + auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + auto litegraph = AnfGraph2LiteGraph(sub_func_graph); + if (Process(litegraph)) { + changed = true; + AnfNodePtrList outputs; + auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs); + new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); + auto cnode = node->cast(); + AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end()); + auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs); + SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs); + mng->Replace(node, new_node); + mng->AddFuncGraph(new_funcgraph); + } + } catch (const graphkernel::GKException &e) { + MS_LOG(WARNING) << e.what() << ", so we undo airthmetic simplify for this graph"; + } + } + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.h new file mode 100644 index 00000000000..f7bdd6ffdbe --- /dev/null 
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.h @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TRANSFORM_OP_OPTIMIZER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TRANSFORM_OP_OPTIMIZER_H_ + +#include +#include +#include "backend/optimizer/common/pass.h" +#include "ir/func_graph.h" +#include "backend/optimizer/graph_kernel/model/lite_graph.h" + +namespace mindspore { +namespace opt { +/** + * @brief Eliminate the unnecessary transformation ops when the other operators + * are format flexible. + * @example + * %1 = Transpose(p0) // NCHW to NHWC + * %2 = Transpose(p1) // NCHW to NHWC + * %3 = Add(%1, %2) + * return %3 + * --> + * %1 = Add(p0, p1) + * %2 = Transpose(%1) // NCHW to NHWC + * return %2 + * @example + * %1 = Transpose(p0) // NCHW to NHWC + * %2 = Transpose(p1) // NCHW to NHWC + * %3 = Add(%1, %2) + * %4 = Transpose(%3) // NHWC to NCHW + * return %4 + * --> + * %1 = Add(p0, p1) + * return %1 + */ +class TransformOpOptimizer : public Pass { + public: + TransformOpOptimizer() : Pass("transform_op_optimizer") {} + ~TransformOpOptimizer() = default; + bool Run(const FuncGraphPtr &func_graph) override; + + private: + bool Process(const graphkernel::LiteGraphPtr &litegraph, const std::string &trans_op_name = "Transpose"); + bool IsFlexibleOp(const graphkernel::NodePtr &node); + size_t ori_trans_op_num_{0}; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TRANSFORM_OP_OPTIMIZER_H_ diff --git a/mindspore/ccsrc/utils/context/graph_kernel_flags.cc b/mindspore/ccsrc/utils/context/graph_kernel_flags.cc index 8f226424ebd..96e24aaf800 100644 --- a/mindspore/ccsrc/utils/context/graph_kernel_flags.cc +++ b/mindspore/ccsrc/utils/context/graph_kernel_flags.cc @@ -165,7 +165,7 @@ void GraphKernelFlags::Refresh() { } } // Dump flags so that people can check the setting. - MS_LOG(INFO) << "GraphKernelFlags info: " << DumpAllFlags(); + MS_LOG(INFO) << "graph_kernel_flags = \"" << flags_cache_ << "\", all flags: " << DumpAllFlags(); } void GraphKernelFlags::RegisterFlags(std::map *flag_map) { @@ -188,10 +188,11 @@ void GraphKernelFlags::RegisterFlags(std::map *flag_ma reg.AddFlag("enable_recompute_fusion", &enable_recompute_fusion, opt_level >= OptLevel_2); reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion, opt_level == OptLevel_3); reg.AddFlag("enable_low_precision", &enable_low_precision); - reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_gpu ? OpLevel_MAX : OpLevel_0); + reg.AddFlag("enable_trans_op_optimize", &enable_trans_op_optimize); // Integer flags reg.AddFlag("online_tuning", &online_tuning); + reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_gpu ? 
OpLevel_MAX : OpLevel_0); // String flags reg.AddFlag("repository_path", &repository_path); @@ -217,6 +218,7 @@ std::string GraphKernelFlags::DumpAllFlags() const { json["enable_recompute_fusion"] = enable_recompute_fusion; json["enable_parallel_fusion"] = enable_parallel_fusion; json["enable_low_precision"] = enable_low_precision; + json["enable_trans_op_optimize"] = enable_trans_op_optimize; json["opt_level"] = opt_level; json["fusion_ops_level"] = fusion_ops_level; diff --git a/mindspore/ccsrc/utils/context/graph_kernel_flags.h b/mindspore/ccsrc/utils/context/graph_kernel_flags.h index e7d2338ed34..34c6613257c 100644 --- a/mindspore/ccsrc/utils/context/graph_kernel_flags.h +++ b/mindspore/ccsrc/utils/context/graph_kernel_flags.h @@ -95,6 +95,13 @@ class GraphKernelFlags { */ unsigned int fusion_ops_level{OpLevel_0}; + /** + * Enable optimization for transform operators (Transpose/TransData) + * + * Experimental feature, enabled by default when opt_level=3. + */ + bool enable_trans_op_optimize{false}; + /** * Optimization level, value from 0 to 3. * 0: Disable GraphKernel diff --git a/model_zoo/official/cv/yolov3_darknet53/train.py b/model_zoo/official/cv/yolov3_darknet53/train.py index 6b11a9e4f40..e97e27a919f 100644 --- a/model_zoo/official/cv/yolov3_darknet53/train.py +++ b/model_zoo/official/cv/yolov3_darknet53/train.py @@ -63,7 +63,7 @@ def set_graph_kernel_context(): if context.get_context("device_target") == "GPU": context.set_context(enable_graph_kernel=True) context.set_context(graph_kernel_flags="--enable_parallel_fusion " - "--disable_expand_ops=BatchNorm,BatchNormGrad " + "--enable_trans_op_optimize " "--disable_cluster_ops=ReduceMax,Reshape " "--enable_expand_ops=Conv2D")
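
Appendix (not part of the patch): a minimal standalone sketch of the min-cut formulation referenced in the commit message, assuming the same node-splitting model as the MinCut class above — the source is linked to nodes anchored to format A, nodes anchored to format B are linked to the sink, each original node is split into an input/output pair joined by a capacity-1 edge, and Dinic's algorithm computes the max flow, whose value equals the minimum number of transform ops. The vertex ids and the tiny two-node graph in main() are hypothetical and exist only for illustration.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <queue>
#include <vector>

namespace {
constexpr size_t INF = static_cast<size_t>(1) << 30;

struct FlowEdge {
  size_t to;
  size_t capacity;
};

class Dinic {
 public:
  explicit Dinic(size_t vertex_num) : out_edges_(vertex_num), depth_(vertex_num) {}

  // Paired edges: edge i and edge (i ^ 1) are each other's reverse edge.
  void AddEdge(size_t from, size_t to, size_t cap, size_t inv_cap) {
    edges_.push_back({to, cap});
    out_edges_[from].push_back(edges_.size() - 1);
    edges_.push_back({from, inv_cap});
    out_edges_[to].push_back(edges_.size() - 1);
  }

  size_t MaxFlow(size_t source, size_t sink) {
    size_t flow = 0;
    while (BfsSetDepth(source, sink)) {
      flow += DfsAugment(source, sink, INF);
    }
    return flow;
  }

 private:
  // Layered BFS from the source; returns true while the sink is still reachable.
  bool BfsSetDepth(size_t source, size_t sink) {
    std::fill(depth_.begin(), depth_.end(), 0);
    depth_[source] = 1;
    std::queue<size_t> q;
    q.push(source);
    while (!q.empty()) {
      auto u = q.front();
      q.pop();
      for (auto e : out_edges_[u]) {
        if (edges_[e].capacity > 0 && depth_[edges_[e].to] == 0) {
          depth_[edges_[e].to] = depth_[u] + 1;
          q.push(edges_[e].to);
        }
      }
    }
    return depth_[sink] > 0;
  }

  // Push flow along the layered graph, updating residual capacities.
  size_t DfsAugment(size_t u, size_t sink, size_t flow) {
    if (u == sink) return flow;
    size_t pushed = 0;
    for (auto e : out_edges_[u]) {
      if (edges_[e].capacity > 0 && depth_[edges_[e].to] == depth_[u] + 1) {
        auto d = DfsAugment(edges_[e].to, sink, std::min(flow, edges_[e].capacity));
        if (d > 0) {
          edges_[e].capacity -= d;
          edges_[e ^ 1].capacity += d;
          pushed += d;
          flow -= d;
          if (flow == 0) break;
        }
      }
    }
    return pushed;
  }

  std::vector<FlowEdge> edges_;
  std::vector<std::vector<size_t>> out_edges_;
  std::vector<size_t> depth_;
};
}  // namespace

int main() {
  // Hypothetical two-node graph: node 1 must stay in format A, node 2 must stay
  // in format B, and node 1 feeds node 2. With n = 2 original nodes, the vertices
  // are {0 = source, 1..2 = input halves, 3..4 = output halves, 5 = sink}.
  const size_t n = 2, source = 0, sink = 2 * n + 1;
  Dinic graph(2 * n + 2);
  graph.AddEdge(source, 1, INF, 0);  // node 1 is anchored to format A
  graph.AddEdge(2, sink, INF, 0);    // node 2 is anchored to format B
  graph.AddEdge(1, 1 + n, 1, 1);     // split edge of node 1 (cutting it costs one transform op)
  graph.AddEdge(2, 2 + n, 1, 1);     // split edge of node 2
  graph.AddEdge(1 + n, 2, 1, 1);     // original data edge: node 1 -> node 2
  // The min cut equals the max flow: exactly one transform op is needed here.
  std::cout << "transform ops needed: " << graph.MaxFlow(source, sink) << std::endl;
  return 0;
}

In the pass itself, SetFormat then walks the residual graph from the source to decide which vertices keep format A, and the results of GetOneNodeOps/GetTwoNodeOps tell Mutator::RebuildLiteGraph where the remaining transform ops must be inserted.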