diff --git a/mindspore/lite/src/gllo/common/utils.cc b/mindspore/lite/src/gllo/common/gllo_utils.cc
similarity index 75%
rename from mindspore/lite/src/gllo/common/utils.cc
rename to mindspore/lite/src/gllo/common/gllo_utils.cc
index 17b781c1c68..e1d8beee5ab 100644
--- a/mindspore/lite/src/gllo/common/utils.cc
+++ b/mindspore/lite/src/gllo/common/gllo_utils.cc
@@ -13,15 +13,107 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include "src/gllo/common/gllo_utils.h"
 #include <vector>
-#include <memory>
-#include "src/gllo/common/utils.h"
 #include "src/ir/primitive_t_value.h"
 #include "frontend/operator/ops.h"
-using PrimitiveTValuePtr = std::shared_ptr<lite::PrimitiveTValue>;
 
 namespace mindspore {
 namespace opt {
+namespace {
+constexpr auto kAnfPrimitiveIndex = 0;
+bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
+}
+
+bool IsRealKernel(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  // parameters and value nodes are not real kernels either
+  if (!node->isa<CNode>()) {
+    return true;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->inputs().empty()) {
+    MS_LOG(EXCEPTION) << "Illegal null input of cnode " << node->DebugString();
+  }
+  auto input = cnode->inputs()[0];
+  bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
+                         IsPrimitive(input, prim::kPrimTensorSummary) ||
+                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
+                         IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
+                         IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
+                         IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
+  return !is_virtual_node;
+}
+
+ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
+  if (utils::isa<int>(sexp)) {
+    return NewValueNode(utils::cast<int>(sexp));
+  }
+  if (utils::isa<float>(sexp)) {
+    return NewValueNode(utils::cast<float>(sexp));
+  }
+  if (utils::isa<bool>(sexp)) {
+    return NewValueNode(utils::cast<bool>(sexp));
+  }
+  if (utils::isa<ValuePtr>(sexp)) {
+    return NewValueNode(utils::cast<ValuePtr>(sexp));
+  }
+  return nullptr;
+}
+
+CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
+  if (utils::isa<FuncGraphPtr>(graph)) {
+    return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
+  }
+  if (utils::isa<VarPtr>(graph)) {
+    return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
+  }
+  return nullptr;
+}
+
+VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
+  if (utils::isa<VarPtr>(graph)) {
+    MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
+    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
+  }
+  if (utils::isa<FuncGraphPtr>(graph)) {
+    MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
+    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
+  }
+  MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
+  return nullptr;
+}
+
+AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
+                            bool multigraph) {
+  MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
+  std::vector<AnfNodePtr> input_nodes;
+  const auto &tuple = utils::cast<VectorRef>(sexp);
+  if (multigraph && utils::isa<VarPtr>(graph)) {
+    for (auto &x : tuple) {
+      AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
+      input_nodes.push_back(node);
+    }
+    VarPtr var_ptr = utils::cast<VarPtr>(graph);
+    return std::make_shared<CNode>(input_nodes, var_ptr);
+  }
+
+  for (auto &x : tuple) {
+    AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
+    input_nodes.push_back(node);
+  }
+  return CreateCNodeWithGraph(input_nodes, graph);
+}
+}  // namespace
+
 bool AnfEqual(const BaseRef &a, const BaseRef &b) {
   if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
     auto a_node = utils::cast<AnfNodePtr>(a);
@@ -64,15 +156,15 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
     }
     if (utils::isa<lite::PrimitiveTValue>(a_value_ptr) && utils::isa<lite::PrimitiveTValue>(b_value_ptr)) {
-      auto a_obj = (lite::PrimitiveTValue *)(a_value_ptr.get());
-      auto b_obj = (lite::PrimitiveTValue *)(b_value_ptr.get());
+      auto a_obj = (lite::PrimitiveTValue *) (a_value_ptr.get());
+      auto b_obj = (lite::PrimitiveTValue *) (b_value_ptr.get());
       return (*a_obj) == (*b_obj);
      } else {
        return (*a_value_ptr) == (*b_value_ptr);
      }
    }
  }
-  if (a.m_ptr->isa<lite::PrimitiveTValue>()) {
+  if (a.m_ptr->isa<lite::PrimitiveTValue>() && b.m_ptr->isa<lite::PrimitiveTValue>()) {
     auto a_value_node_ptr = a.m_ptr->cast<PrimitiveTValuePtr>();
     auto b_value_node_ptr = b.m_ptr->cast<PrimitiveTValuePtr>();
     return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type;
@@ -89,70 +181,8 @@ bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
   return a.type() == b.type();
 }
 
-namespace {
-ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
-  if (utils::isa<int>(sexp)) {
-    return NewValueNode(utils::cast<int>(sexp));
-  }
-  if (utils::isa<float>(sexp)) {
-    return NewValueNode(utils::cast<float>(sexp));
-  }
-  if (utils::isa<bool>(sexp)) {
-    return NewValueNode(utils::cast<bool>(sexp));
-  }
-  if (utils::isa<ValuePtr>(sexp)) {
-    return NewValueNode(utils::cast<ValuePtr>(sexp));
-  }
-  return nullptr;
-}
-
-CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
-  if (utils::isa<FuncGraphPtr>(graph)) {
-    return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
-  }
-  if (utils::isa<VarPtr>(graph)) {
-    return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
-  }
-  return nullptr;
-}
-
-VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
-  if (utils::isa<VarPtr>(graph)) {
-    // MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
-    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
-  }
-  if (utils::isa<FuncGraphPtr>(graph)) {
-    // MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
-    return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
-  }
-  MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
-  return nullptr;
-}
-
-AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
-                            bool multigraph) {
-  // MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
-  std::vector<AnfNodePtr> input_nodes;
-  const auto &tuple = utils::cast<VectorRef>(sexp);
-  if (multigraph && utils::isa<VarPtr>(graph)) {
-    for (auto &x : tuple) {
-      AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
-      input_nodes.push_back(node);
-    }
-    VarPtr var_ptr = utils::cast<VarPtr>(graph);
-    return std::make_shared<CNode>(input_nodes, var_ptr);
-  }
-
-  for (auto &x : tuple) {
-    AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
-    input_nodes.push_back(node);
-  }
-  return CreateCNodeWithGraph(input_nodes, graph);
-}
-}  // namespace
-
 AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
-  // MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
+  MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
   MS_EXCEPTION_IF_NULL(primitive_vars);
   if (utils::isa<VectorRef>(sexp)) {
     return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
@@ -176,6 +206,38 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap
   return value_node;
 }
 
+bool IsRealCNodeKernel(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  // parameters and value nodes are not real cnode kernels
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  // Return is considered a real node
+  if (CheckPrimitiveType(node, prim::kPrimReturn)) {
+    return true;
+  }
+  return IsRealKernel(node);
+}
+bool IsGraphKernel(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  // a graph kernel should be a real cnode kernel.
+  if (!IsRealCNodeKernel(node)) {
+    return false;
+  }
+
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto input = cnode->input(kAnfPrimitiveIndex);
+  // a graph kernel should have a func_graph as its first input.
+  if (!IsValueNode<FuncGraph>(input)) {
+    return false;
+  }
+
+  auto func_graph = GetValueNode<FuncGraphPtr>(input);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
+}
+
 void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) {
   if (graph == nullptr) {
     MS_LOG(EXCEPTION) << "The graph is null.";
It's " + graph.ToString(); - return nullptr; -} - -AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, - bool multigraph) { - // MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); - std::vector input_nodes; - const auto &tuple = utils::cast(sexp); - if (multigraph && utils::isa(graph)) { - for (auto &x : tuple) { - AnfNodePtr node = SexpToNode(x, std::make_shared("G"), primitive_vars, true); - input_nodes.push_back(node); - } - VarPtr var_ptr = utils::cast(graph); - return std::make_shared(input_nodes, var_ptr); - } - - for (auto &x : tuple) { - AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); - input_nodes.push_back(node); - } - return CreateCNodeWithGraph(input_nodes, graph); -} -} // namespace - AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { - // MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); + MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); MS_EXCEPTION_IF_NULL(primitive_vars); if (utils::isa(sexp)) { return HandleSexpVector(sexp, graph, primitive_vars, multigraph); @@ -176,6 +206,38 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap return value_node; } +bool IsRealCNodeKernel(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + // parameter and value node is not a real cnode kernel + if (!node->isa()) { + return false; + } + // return considered as a real node + if (CheckPrimitiveType(node, prim::kPrimReturn)) { + return true; + } + return IsRealKernel(node); +} +bool IsGraphKernel(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + // graph kernel should be a real cnode kernel. + if (!IsRealCNodeKernel(node)) { + return false; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input = cnode->input(kAnfPrimitiveIndex); + // graph kernel should has func_graph as first input. + if (!IsValueNode(input)) { + return false; + } + + auto func_graph = GetValueNode(input); + MS_EXCEPTION_IF_NULL(func_graph); + return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); +} + void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) { if (graph == nullptr) { MS_LOG(EXCEPTION) << "The graph is null."; diff --git a/mindspore/lite/src/gllo/common/utils.h b/mindspore/lite/src/gllo/common/gllo_utils.h similarity index 84% rename from mindspore/lite/src/gllo/common/utils.h rename to mindspore/lite/src/gllo/common/gllo_utils.h index 91b73c0d31f..d96190e7500 100644 --- a/mindspore/lite/src/gllo/common/utils.h +++ b/mindspore/lite/src/gllo/common/gllo_utils.h @@ -14,22 +14,21 @@ * limitations under the License. 
diff --git a/mindspore/lite/src/gllo/common/node_pass.cc b/mindspore/lite/src/gllo/common/node_pass.cc
index badd0fb434e..6bc848cc237 100644
--- a/mindspore/lite/src/gllo/common/node_pass.cc
+++ b/mindspore/lite/src/gllo/common/node_pass.cc
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "src/gllo/common/node_pass.h"
+#include "backend/optimizer/common/node_pass.h"
 
 #include <unordered_set>
 #include <deque>
@@ -22,6 +22,7 @@
 #include "ir/anf.h"
 #include "ir/func_graph.h"
 #include "ir/manager.h"
+#include "src/gllo/common/gllo_utils.h"
 
 namespace mindspore {
 namespace opt {
@@ -54,6 +55,9 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
       MS_EXCEPTION_IF_NULL(const_func_graph);
       todo.push_back(const_func_graph->output());
     } else if (new_node && new_node->isa<CNode>()) {
+      if (IsGraphKernel(new_node)) {
+        todo.push_back(new_node);
+      }
       auto cnode = new_node->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(cnode);
       auto inputs = cnode->inputs();
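The node_pass.cc hunk makes the generic worklist traversal descend into graph kernels: when a visited CNode is a graph kernel, it is re-queued so the pass also walks the FuncGraph it wraps. For context, a hedged sketch of the node shape IsGraphKernel detects; the helper name is invented, while set_attr/NewValueNode/NewCNode are the ordinary MindSpore FuncGraph API:

    // Sketch under stated assumptions, not part of the patch: building the
    // kind of CNode that IsGraphKernel recognises.
    CNodePtr MakeGraphKernelCall(const FuncGraphPtr &outer, const FuncGraphPtr &sub,
                                 const std::vector<AnfNodePtr> &args) {
      // mark the wrapped graph; any non-null value satisfies has_attr()
      sub->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(true));
      std::vector<AnfNodePtr> inputs = {NewValueNode(sub)};  // first input holds the FuncGraph
      inputs.insert(inputs.end(), args.begin(), args.end());
      return outer->NewCNode(inputs);
    }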
diff --git a/mindspore/lite/src/gllo/common/node_pass.h b/mindspore/lite/src/gllo/common/node_pass.h
deleted file mode 100644
index 039c09bb8c2..00000000000
--- a/mindspore/lite/src/gllo/common/node_pass.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_
-#define MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_
-#include <string>
-#include <memory>
-
-#include "src/gllo/common/pass.h"
-
-namespace mindspore {
-namespace opt {
-// @brief ANF Node level optimization base pass
-class NodePass : public Pass {
- public:
-  explicit NodePass(const std::string &name) : Pass(name) {}
-  ~NodePass() override = default;
-  bool Run(const FuncGraphPtr &func_graph) final;
-  virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0;
-};
-using NodePassPtr = std::shared_ptr<NodePass>;
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_
diff --git a/mindspore/lite/src/gllo/common/optimizer.cc b/mindspore/lite/src/gllo/common/optimizer.cc
index 925e02f847c..0a8cddfe893 100644
--- a/mindspore/lite/src/gllo/common/optimizer.cc
+++ b/mindspore/lite/src/gllo/common/optimizer.cc
@@ -23,8 +23,7 @@
 #include <memory>
 #include <utility>
 
-#include "src/gllo/common/pass_manager.h"
-#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/common/pass_manager.h"
 #include "ir/manager.h"
 
 namespace mindspore {
diff --git a/mindspore/lite/src/gllo/common/optimizer.h b/mindspore/lite/src/gllo/common/optimizer.h
index c715f511b54..fa6fac69384 100644
--- a/mindspore/lite/src/gllo/common/optimizer.h
+++ b/mindspore/lite/src/gllo/common/optimizer.h
@@ -26,9 +26,9 @@
 #include "ir/graph_utils.h"
 #include "src/common/utils.h"
-#include "src/gllo/common/pass_manager.h"
-#include "src/gllo/common/pattern_engine.h"
-#include "src/gllo/common/utils.h"
+#include "backend/optimizer/common/pass_manager.h"
+#include "backend/optimizer/common/pattern_engine.h"
+#include "src/gllo/common/gllo_utils.h"
 
 namespace mindspore {
 namespace opt {
diff --git a/mindspore/lite/src/gllo/common/pass_manager.cc b/mindspore/lite/src/gllo/common/pass_manager.cc
index 763228e369b..5907700b6d9 100644
--- a/mindspore/lite/src/gllo/common/pass_manager.cc
+++ b/mindspore/lite/src/gllo/common/pass_manager.cc
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "src/gllo/common/pass_manager.h"
+#include "backend/optimizer/common/pass_manager.h"
 
 #include <sys/time.h>
 #include <deque>
diff --git a/mindspore/lite/src/gllo/common/pass_manager.h b/mindspore/lite/src/gllo/common/pass_manager.h
deleted file mode 100644
index d9cbd3a5671..00000000000
--- a/mindspore/lite/src/gllo/common/pass_manager.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_
-#define MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_
-
-#include <utility>
-#include <vector>
-#include <string>
-#include <memory>
-
-#include "src/gllo/common/pass.h"
-#include "src/gllo/common/node_pass.h"
-
-namespace mindspore {
-namespace opt {
-// @brief For optimization passes management
-class PassManager {
- public:
-  explicit PassManager(const std::string &name = "pm", bool run_only_once = true)
-      : name_(name), passes_{}, run_only_once_(run_only_once) {}
-  virtual ~PassManager() = default;
-  // Get all the passes added by AddPass
-  const std::vector<PassPtr> &Passes() const;
-  // Add graph pass, the pass object will be freed when pass manager freed.
-  void AddPass(const PassPtr &pass);
-  // Run passes added in pass manager on the input graph
-  // @param [inout] graph The graph to be optimized
-  // @return true, graph changed
-  // @return false, graph not changed
-  bool Run(const FuncGraphPtr &func_graph) const;
-  // Run the given graph passes on the input graph
-  // @param [inout] graph The graph to be optimized
-  // @param [in] passes The given graph passes
-  // @return true, graph changed
-  // @return false, graph not changed
-  bool Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const;
-  std::string name() const { return name_; }
-
- private:
-  const std::string name_;
-  std::vector<PassPtr> passes_;
-  bool run_only_once_;
-};
-using PassManagerPtr = std::shared_ptr<PassManager>;
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_
diff --git a/mindspore/lite/src/gllo/common/pattern_engine.cc b/mindspore/lite/src/gllo/common/pattern_engine.cc
deleted file mode 100644
index 5615fd75df0..00000000000
--- a/mindspore/lite/src/gllo/common/pattern_engine.cc
+++ /dev/null
@@ -1,365 +0,0 @@
-/**
- * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
- *
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "src/gllo/common/pattern_engine.h"
-
-#include <exception>
-#include <iostream>
-#include <functional>
-#include <algorithm>
-
-#include "ir/func_graph.h"
-#include "mindspore/core/ir/primitive.h"
-#include "utils/info.h"
-#include "ir/anf.h"
-#include "utils/convert_utils_base.h"
-#include "utils/overload.h"
-
-
-namespace mindspore {
-static int GetNextTag() {
-  static int kID = 0;
-  return kID++;
-}
-
-void Var::EnsureTag() {
-  if (tag_.length() == 0) {
-    std::ostringstream buffer;
-    buffer << "_" << GetNextTag();
-    tag_ = buffer.str();
-  }
-}
-
-bool operator==(const VarPtr &lhs, const VarPtr &rhs) {
-  if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) {
-    CondVarPtr v1 = dyn_cast<CondVar>(lhs);
-    CondVarPtr v2 = dyn_cast<CondVar>(rhs);
-    return *v1 == *v2;
-  }
-
-  if (lhs->isa<SeqVar>() && rhs->isa<SeqVar>()) {
-    SVarPtr v1 = dyn_cast<SeqVar>(lhs);
-    SVarPtr v2 = dyn_cast<SeqVar>(rhs);
-    return *v1 == *v2;
-  }
-  return (*lhs == *rhs);
-}
-
-std::string SeqVar::ToString() const {
-  std::ostringstream buffer;
-  buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")";
-  return buffer.str();
-}
-
-std::ostream &operator<<(std::ostream &os, const VarPtr &var) {
-  if (var == nullptr) {
-    os << "";
-  } else {
-    os << var->ToString();
-  }
-  return os;
-}
-
-template <>
-std::ostream &operator<<<VarPtr, BaseRef>(std::ostream &os, const Equiv &equiv) {
-  os << "[Equiv]"
-     << "\n";
-  for (auto &equiv_item : equiv) {
-    auto k = equiv_item.first;
-    os << k << ":";
-    BaseRef x = equiv_item.second;
-    if (utils::isa<AnfNodePtr>(x)) {
-      auto node = utils::cast<AnfNodePtr>(x);
-      os << "TypeString[" << node->type_name() << "]";
-      if (IsValueNode<FuncGraph>(node)) {
-        os << "IsValueNodeGraph ";
-      }
-      os << "type " << node->type_name();
-      if (node->isa<ValueNode>()) {
-        os << " value " << GetValueNode(node);
-      }
-      os << " addr: " << node;
-    } else if (utils::isa<Named>(x)) {
-      os << "Named " << x.ToString().c_str();
-    } else if (utils::isa<VarPtr>(x)) {
-      os << "TypeString[Var]";
-      os << utils::cast<VarPtr>(x);
-    } else if (utils::isa<FuncGraphPtr>(x)) {
-      os << "TypeString[Graph]";
-    }
-    os << "\n";
-  }
-  return os;
-}
-
-
-static BaseRef GetVar(const BaseRef &x) {
-  // MS_LOG(DEBUG) << "getVar start :%s" + x.ToString();
-  if (utils::isa<AnfNodePtr>(x)) {
-    auto node = utils::cast<AnfNodePtr>(x);
-    // MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
-    if (node->isa<VarNode>()) {
-      // MS_LOG(DEBUG) << "IsVarNode " + node->cast<VarNodePtr>()->var_->ToString();
-      return node->cast<VarNodePtr>()->var_;
-    }
-//    if (node->isa<ValueNode>()) {
-//      MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString();
-//    } else {
-//      MS_LOG(DEBUG) << "type " + node->type_name();
-//    }
-//  } else if (utils::isa<Named>(x)) {
-//    MS_LOG(DEBUG) << "Named " + x.ToString();
-//  } else if (utils::isa<VectorRef>(x)) {
-//    MS_LOG(DEBUG) << "VectorRef";
-//  } else if (utils::isa<VarPtr>(x)) {
-//    MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString();
-  }
-//  MS_LOG(DEBUG) << "GetVar end: " + x.ToString();
-  return x;
-}
-
-EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
-  MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
-  MS_EXCEPTION_IF_NULL(equiv);
-  if (utils::isa<VarPtr>(pattern)) {
-    VarPtr var = utils::cast<VarPtr>(pattern);
-    if (var->matches(expr)) {
-      (*equiv)[var] = expr;
-      MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString();
-      return equiv;
-    }
-  }
-
-  return nullptr;
-}
-
-bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
-                             VectorRef *const values_expr) const {
-  MS_EXCEPTION_IF_NULL(values_expr);
-  if (utils::isa<SeqPtr>(pattern_ref)) {
-    *values_pattern = pattern_ref;
-    *values_expr = expr_ref;
-    return true;
-  }
-  return false;
-}
-
-bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
-                             VectorRef *const values_expr) const {
-  MS_EXCEPTION_IF_NULL(values_expr);
-  // visitor to visite the list
-  auto appender_pattern = [](VectorRef &values) {
-    std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
-      values.push_back(GetVar(u));
-      return u;
-    };
-    return fn;
-  };
-
-  visitor_->SetFn(appender_pattern(*values_pattern));
-  // MS_LOG(DEBUG) << "visit pattern_ref";
-  bool success = visitor_->Visit(pattern_ref, nullptr);
-  if (!success) {
-    return false;
-  }
-
-  auto appender_expr = [](VectorRef &values) {
-    std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
-      values.push_back(u);
-      return u;
-    };
-    return fn;
-  };
-
-  visitor_->SetFn(appender_expr(*values_expr));
-  // MS_LOG(DEBUG) << "visit expr_ref";
-  return visitor_->Visit(expr_ref, nullptr);
-}
-
-static int GetSVarStartIndex(const VectorRef &values) {
-  int index = -1;
-  int count = 0;
-  for (auto &value : values) {
-    if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) {
-      if (index != -1) {
-        // MS_LOG(DEBUG) << "Multiple SVars in sequence";
-        return kInvalidVarIndex;
-      }
-      index = count;
-    }
-    count++;
-  }
-  return index;
-}
-
-void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars,
-                    EquivPtr equiv) {
-  if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) ||
-      !utils::isa<AnfNodePtr>(expr_ref)) {
-    return;
-  }
-  auto real_node = utils::cast<AnfNodePtr>(expr_ref);
-  MS_EXCEPTION_IF_NULL(real_node);
-  if (!real_node->isa<CNode>()) {
-    return;
-  }
-  auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]);
-  MS_EXCEPTION_IF_NULL(prim_node);
-  if (!IsValueNode<Primitive>(prim_node)) {
-    return;
-  }
-  ValuePtr value = GetValueNode(prim_node);
-  MS_EXCEPTION_IF_NULL(value);
-  auto prim = value->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(prim);
-  auto iter = primitive_vars.find(prim);
-  if (iter == primitive_vars.end()) {
-    return;
-  }
-  (*equiv)[iter->second] = real_node;
-}
-
-EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
-                                  const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const {
-  int svar_index = GetSVarStartIndex(values_pattern);
-  if (svar_index == kInvalidVarIndex) {
-    return nullptr;
-  }
-
-  size_t values_pattern_len = values_pattern.size();
-  size_t values_expr_len = values_expr.size();
-
-  if (svar_index == -1) {
-    if (values_pattern_len != values_expr_len) {
-      // MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ",
-      // expr len " << values_expr_len;
-      return nullptr;
-    }
-  }
-  if (values_expr_len < values_pattern_len - 1) {
-    MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
-    return nullptr;
-  }
-  size_t diff = values_expr_len - values_pattern_len + 1;
-  for (size_t i = 0; i < values_pattern_len; i++) {
-    size_t expr_i = i;
-    if (svar_index != -1 && i == IntToSize(svar_index)) {
-      auto seq =
-        std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff));
-      equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv);
-    } else {
-      if (svar_index != -1 && i > IntToSize(svar_index)) {
-        expr_i = i + diff - 1;
-      }
-      equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv);
-    }
-    if (equiv == nullptr) {
-      return nullptr;
-    }
-  }
-  return equiv;
-}
-
-EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
-                              EquivPtr equiv) const {
-  MS_LOG(DEBUG) << "-----[in Match]";
-  // MS_LOG(DEBUG) << "GetVar w";
-  BaseRef pattern_ref = GetVar(pattern);
-  // MS_LOG(DEBUG) << "GetVar v";
-  BaseRef expr_ref = expr;
-
-  if (equiv == nullptr) {
-    MS_LOG(EXCEPTION) << "Equiv pointer is null";
-  }
-
-  MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString();
-  // 1. if pattern_ref is var and already in equiv, replace it.
-  if (utils::isa<VarPtr>(pattern_ref)) {
-    VarPtr var = utils::cast<VarPtr>(pattern_ref);
-    auto iter = equiv->find(var);
-    if (iter != equiv->end()) {
-      pattern_ref = iter->second;
-    }
-  }
-
-  // 2. check equal
-  if (eq_(pattern_ref, expr_ref)) {
-    return equiv;
-  }
-
-  // 3. match var
-  EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv);
-  if (ret_equiv) {
-    return ret_equiv;
-  }
-
-  // 4. here the type can be std:vector, std:list,
-  // or cnode.
-  if (!type_eq_(pattern_ref, expr_ref)) {
-    MS_LOG(DEBUG) << "Type mismatch";
-    return nullptr;
-  }
-
-  // 5. transfer the Containers by visitor to std::vector
-  VectorRef values_pattern;
-  VectorRef values_expr;
-  if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) {
-    return nullptr;
-  }
-
-  // 6. if any svar in both side, find the SeqVar index,
-  // try to pack the Var s in std::vector to a Seq and match elements one by one.
-  // check svar
-  equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv);
-  UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv);
-  return equiv;
-}
-
-BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const {
-  MS_EXCEPTION_IF_NULL(equiv);
-  MS_LOG(DEBUG) << "-----[in Replace]";
-  BaseRef ref = GetVar(pattern);
-  BaseRef out;
-  bool is_match = false;
-
-  // w is var
-  if (utils::isa<VarPtr>(ref)) {
-    const VarPtr &var = utils::cast<VarPtr>(ref);
-    auto iter = equiv->find(var);
-    if (iter != equiv->end()) {
-      out = iter->second;
-      is_match = true;
-    }
-  }
-  if (is_match) {
-    return out;
-  }
-
-  // visitor to visit the list
-  std::function<BaseRef(const BaseRef &)> fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); };
-
-  visitor_->SetFn(fn);
-  BaseRef visit_out;
-  if (!visitor_->Visit(pattern, &visit_out)) {
-    return pattern;
-  }
-  return visit_out;
-}
-}  // namespace mindspore
-
diff --git a/mindspore/lite/src/gllo/common/pattern_engine.h b/mindspore/lite/src/gllo/common/pattern_engine.h
deleted file mode 100644
index ff1502db5f4..00000000000
--- a/mindspore/lite/src/gllo/common/pattern_engine.h
+++ /dev/null
@@ -1,203 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_
-#define MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_
-
-#include <string>
-#include <sstream>
-#include <memory>
-#include <vector>
-#include <unordered_map>
-#include <unordered_set>
-#include <utility>
-#include <initializer_list>
-#include <iostream>
-#include <algorithm>
-#include <map>
-#include <stdexcept>
-#include <list>
-
-#include "src/gllo/common/visit.h"
-#include "mindspore/core/base/base.h"
-#include "utils/log_adapter.h"
-#include "base/base_ref.h"
-
-namespace mindspore {
-class CondVar;
-class SeqVar;
-using CondVarPtr = std::shared_ptr<CondVar>;
-using SVarPtr = std::shared_ptr<SeqVar>;
-const int kInvalidVarIndex = -2;
-
-using ConditionFunc = std::function<bool(const BaseRef &)>;
-
-// Base wildcard variable which could match any anf node.
-class Var : public Base {
-  friend class VarHasher;
-
- public:
-  explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); }
-  explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) {
-    EnsureTag();
-  }
-  Var(const Var &other) : Base(other), tag_(other.tag_) {}
-  virtual Var &operator=(const Var &other) {
-    if (&other == this) {
-      return *this;
-    }
-    this->tag_ = other.tag_;
-    return *this;
-  }
-  ~Var() override = default;
-  MS_DECLARE_PARENT(Var, Base);
-
-  virtual bool matches(const BaseRef &) { return true; }
-
-  virtual bool operator==(const Var &other) const { return tag_ == other.tag_; }
-  bool operator!=(const Var &other) const { return !(&other == this); }
-
-  std::string tag() const { return tag_; }
-  PrimitivePtr primitive() const { return primitive_; }
-  std::string ToString() const override {
-    std::ostringstream buffer;
-    buffer << "Var(" << tag_ << ")";
-    return buffer.str();
-  }
-  std::size_t hash() const override { return std::hash<std::string>()(tag_); }
-
- protected:
-  void EnsureTag();
-
-  std::string tag_;
-  PrimitivePtr primitive_;
-};
-
-// VarNode means variable node, a subclass of AnfNode
-class VarNode : public AnfNode {
- public:
-  VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {}
-  ~VarNode() override = default;
-  MS_DECLARE_PARENT(VarNode, AnfNode);
-
-  const VarPtr var_;
-};
-using VarNodePtr = std::shared_ptr<VarNode>;
-
-class VarHasher {
- public:
-  std::size_t operator()(const Var &var) const { return var.hash(); }
-};
-
-// Condition Var, match an anf node when condition function return true.
-class CondVar : public Var {
- public:
-  explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {}
-  ~CondVar() override = default;
-  MS_DECLARE_PARENT(CondVar, Var);
-  bool matches(const BaseRef &value) override {
-    // MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString();
-    if (utils::isa<Var>(value)) {
-      return false;
-    }
-    return cond_fn_(value);
-  }
-  ConditionFunc cond_fn_;
-};
-
-using Seq = VectorRef;
-using SeqPtr = std::shared_ptr<Seq>;
-
-// Sequence Var which could match multiple consecutive input nodes of a CNode.
-class SeqVar : public Var {
- public:
-  SeqVar() : subvar_(std::make_shared<Var>()) {}
-  ~SeqVar() override = default;
-  MS_DECLARE_PARENT(SeqVar, Var);
-  explicit SeqVar(const VarPtr subvar) : subvar_(subvar) {}
-  bool matches(const BaseRef &value) override {
-    // match Seq.
-    if (utils::isa<Seq>(value)) {
-      const Seq &seq = utils::cast<Seq>(value);
-      return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) {
-        auto eq = subvar_->matches(v);
-        return eq;
-      });
-    }
-    return false;
-  }
-  bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; }
-  std::string ToString() const override;
-
- private:
-  VarPtr subvar_;
-};
-
-bool operator==(const VarPtr &lhs, const VarPtr &rhs);
-
-inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); }
-
-std::ostream &operator<<(std::ostream &os, const VarPtr &var);
-
-using Equiv = std::map<VarPtr, BaseRef>;
-using EquivPtr = std::shared_ptr<Equiv>;
-using PrimitiveVarMap = std::unordered_map<PrimitivePtr, VarPtr>;
-using PrimitiveVarMapPtr = std::shared_ptr<PrimitiveVarMap>;
-
-inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); }
-
-class PatternEngine {
- public:
-  PatternEngine(const std::shared_ptr<Visitor> &visitor,
-                const std::function<bool(const BaseRef &, const BaseRef &)> &eq,
-                const std::function<bool(const BaseRef &, const BaseRef &)> &type_eq = DefaultTypeEq)
-      : visitor_(visitor), eq_(eq), type_eq_(type_eq) {}
-  ~PatternEngine() = default;
-
-  EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
-                 EquivPtr equiv) const;
-  // Replace pattern with equivalent
-  BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const;
-
- private:
-  EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
-                     const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const;
-  bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern,
-                VectorRef *const values_expr) const;
-  bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
-                VectorRef *const values_expr) const;
-  std::shared_ptr<Visitor> visitor_;
-  std::function<bool(const BaseRef &, const BaseRef &)> eq_;
-  std::function<bool(const BaseRef &, const BaseRef &)> type_eq_;
-};
-}  // namespace mindspore
-namespace std {
-using mindspore::ERROR;
-using mindspore::LogStream;
-using mindspore::NoExceptionType;
-template <>
-struct hash<mindspore::VarPtr> {
-  std::size_t operator()(const mindspore::VarPtr var) const {
-    if (var == nullptr) {
-      MS_LOG(ERROR) << "Invalid var ptr";
-      return 0;
-    }
-    return std::hash<std::string>{}(var->tag());
-  }
-};
-}  // namespace std
-#endif  // MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_
-
diff --git a/mindspore/lite/src/gllo/common/visit.cc b/mindspore/lite/src/gllo/common/visit.cc
deleted file mode 100644
index d00744e6563..00000000000
--- a/mindspore/lite/src/gllo/common/visit.cc
+++ /dev/null
@@ -1,165 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-#include <vector>
-#include <memory>
-#include <algorithm>
-#include <utility>
-
-#include "src/gllo/common/visit.h"
-#include "src/gllo/common/pattern_engine.h"
-#include "utils/any.h"
-#include "ir/anf.h"
-#include "ir/func_graph.h"
-#include "utils/log_adapter.h"
-
-
-namespace mindspore {
-bool CheckIfNeedExpand(const std::vector<BaseRef> &list) {
-  return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa<Seq>(any); });
-}
-
-std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list) {
-  std::shared_ptr<VectorRef> new_list = std::make_shared<VectorRef>();
-  for (auto &item : list) {
-    if (utils::isa<Seq>(item)) {
-      const Seq &seq = utils::cast<Seq>(item);
-      new_list->insert(new_list->end(), seq.begin(), seq.end());
-    } else {
-      new_list->push_back(item);
-    }
-  }
-  return new_list;
-}
-
-bool DefaultVisitor::Visit(const VectorRef &v_any, BaseRef *const visit_out) const {
-  std::vector<BaseRef> out;
-  (void)std::transform(v_any.begin(), v_any.end(), std::back_inserter(out),
-                       [this](const BaseRef &item) { return fn_(item); });
-  if (visit_out != nullptr) {
-    *visit_out = ExpandList(out);
-  }
-  return true;
-}
-
-bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const {
-  if (utils::isa<Seq>(any)) {
-    return Visit(utils::cast<Seq>(any), visit_out);
-  } else if (utils::isa<AnfNodePtr>(any)) {
-    auto nodeptr = utils::cast<AnfNodePtr>(any);
-    AnfNodePtr output;
-    AnfNodePtr *p_output = &output;
-    if (visit_out == nullptr) {
-      p_output = nullptr;
-    }
-    Visit(nodeptr, fn_, p_output);
-    if (visit_out != nullptr) {
-      *visit_out = output;
-    }
-    return true;
-  }
-  MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString();
-  return false;
-}
-
-void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const {
-  if (node->isa<CNode>()) {
-    Visit(node->cast<CNodePtr>(), fn, output);
-    return;
-  }
-
-  if (node->isa<ValueNode>()) {
-    Visit(node->cast<ValueNodePtr>(), fn, output);
-    return;
-  }
-
-  if (output != nullptr) {
-    *output = node;
-  }
-}
-
-void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const {
-  // if output is nullptr, it's not required to make the new CNode node.
-  if (output == nullptr) {
-    for (auto &inp : cnode->inputs()) {
-      (void)fn(inp);
-    }
-
-    if (cnode->func_graph() != nullptr) {
-      (void)fn(cnode->func_graph());
-    } else {
-      (void)fn(cnode->func_graph_as_var());
-    }
-    return;
-  }
-
-  std::vector<AnfNodePtr> new_inputs;
-  std::vector<BaseRef> after_cnode_fn;
-  std::shared_ptr<VectorRef> out;
-  (void)std::transform(cnode->inputs().begin(), cnode->inputs().end(), std::back_inserter(after_cnode_fn), fn);
-  if (CheckIfNeedExpand(after_cnode_fn)) {
-    out = ExpandList(after_cnode_fn);
-  }
-
-  std::vector<BaseRef> &outs = after_cnode_fn;
-  if (out != nullptr) {
-    outs = out->elements();
-  }
-
-  for (auto &any_item : outs) {
-    if (!utils::isa<AnfNodePtr>(any_item)) {
-      MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr";
-    }
-    new_inputs.push_back(utils::cast<AnfNodePtr>(any_item));
-  }
-
-  BaseRef any_fg;
-  AnfNodePtr new_cnode = nullptr;
-  if (cnode->func_graph() != nullptr) {
-    any_fg = fn(cnode->func_graph());
-    if (!utils::isa<FuncGraphPtr>(any_fg)) {
-      MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr";
-    }
-    new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
-  } else {
-    any_fg = fn(cnode->func_graph_as_var());
-    if (utils::isa<VarPtr>(any_fg)) {
-      new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<VarPtr>(any_fg));
-    } else if (utils::isa<FuncGraphPtr>(any_fg)) {
-      new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
-    } else {
-      MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr";
-    }
-  }
-  new_cnode->set_abstract(cnode->abstract());
-  *output = new_cnode;
-}
-
-void DefaultVisitor::Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const {
-  const BaseRef &value = utils::cast<ValuePtr>(fn(vnode->value()));
-  if (utils::isa<ValuePtr>(value)) {
-    if (output != nullptr) {
-      auto ct = NewValueNode(utils::cast<ValuePtr>(value));
-      ct->set_abstract(vnode->abstract());
-      *output = ct;
-    }
-    return;
-  }
-  MS_LOG(EXCEPTION) << "Visit result is not ValuePtr.";
-}
-}  // namespace mindspore
-
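The deleted CheckIfNeedExpand/ExpandList pair implements a one-level flatten: any element that is itself a sequence is spliced into the output in place. A std-only sketch of that behaviour, with a variant standing in for BaseRef (illustrative only):

    #include <memory>
    #include <variant>
    #include <vector>

    // Item stands in for BaseRef; a vector<int> stands in for a Seq.
    using Item = std::variant<int, std::vector<int>>;

    std::shared_ptr<std::vector<int>> ExpandListSketch(const std::vector<Item> &list) {
      auto out = std::make_shared<std::vector<int>>();
      for (const auto &item : list) {
        if (const auto *seq = std::get_if<std::vector<int>>(&item)) {
          out->insert(out->end(), seq->begin(), seq->end());  // splice the nested sequence
        } else {
          out->push_back(std::get<int>(item));
        }
      }
      return out;
    }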
diff --git a/mindspore/lite/src/gllo/common/visit.h b/mindspore/lite/src/gllo/common/visit.h
deleted file mode 100644
index 548e5e033d8..00000000000
--- a/mindspore/lite/src/gllo/common/visit.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_
-#define MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_
-
-#include <unordered_map>
-#include <iostream>
-#include <functional>
-#include <memory>
-#include <vector>
-#include <utility>
-
-#include "mindspore/core/base/base.h"
-#include "base/base_ref.h"
-
-namespace mindspore {
-using VisitFn = std::function<BaseRef(const BaseRef &)>;
-
-class Visitor {
- public:
-  virtual void SetFn(VisitFn fn) = 0;
-  virtual bool Visit(const BaseRef &e, BaseRef *out) const = 0;
-  virtual bool Visit(const VectorRef &e, BaseRef *out) const = 0;
-  virtual ~Visitor() = default;
-};
-
-class DefaultVisitor : public Visitor {
- public:
-  DefaultVisitor() : fn_(nullptr) {}
-  ~DefaultVisitor() override = default;
-  void SetFn(VisitFn fn) override { fn_ = fn; };
-  bool Visit(const VectorRef &e, BaseRef *out) const override;
-  bool Visit(const BaseRef &e, BaseRef *out) const override;
-  void Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const;
-  void Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const;
-  void Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const;
-
-  VisitFn fn_;
-};
-
-std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list);
-bool CheckIfNeedExpand(const std::vector<BaseRef> &list);
-}  // namespace mindspore
-#endif  // MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_
-
diff --git a/mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc b/mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc
index 784b607036c..2acf051b9eb 100644
--- a/mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc
+++ b/mindspore/lite/src/gllo/fusion/conv_activation_fusion.cc
@@ -14,12 +14,12 @@
  * limitations under the License.
  */
 
-#include "mindspore/lite/src/gllo/fusion/conv_activation_fusion.h"
+#include "src/gllo/fusion/conv_activation_fusion.h"
 #include <memory>
-#include "mindspore/lite/schema/inner/model_generated.h"
-#include "mindspore/lite/src/ir/primitive_t_value.h"
-#include "mindspore/ccsrc/utils/utils.h"
-#include "mindspore/lite/src/gllo/common/utils.h"
+#include "schema/inner/model_generated.h"
+#include "src/ir/primitive_t_value.h"
+#include "utils/utils.h"
+#include "src/gllo/common/gllo_utils.h"
 
 namespace mindspore::opt {
 namespace {
diff --git a/mindspore/lite/src/gllo/fusion/conv_activation_fusion.h b/mindspore/lite/src/gllo/fusion/conv_activation_fusion.h
index 70aa2622433..760419fbe12 100644
--- a/mindspore/lite/src/gllo/fusion/conv_activation_fusion.h
+++ b/mindspore/lite/src/gllo/fusion/conv_activation_fusion.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_
 
 #include <string>
-#include "mindspore/lite/src/gllo/common/optimizer.h"
+#include "src/gllo/common/optimizer.h"
 
 namespace mindspore {
 namespace opt {
diff --git a/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc b/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc
index dabcf95f708..16cb3b8b365 100644
--- a/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc
+++ b/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc
@@ -13,13 +13,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h"
-#include <vector>
+#include "src/gllo/fusion/conv_biasadd_fusion.h"
 #include <memory>
-#include "mindspore/lite/schema/inner/model_generated.h"
-#include "mindspore/lite/src/ir/primitive_t_value.h"
-#include "mindspore/ccsrc/utils/utils.h"
-#include "mindspore/lite/src/gllo/common/utils.h"
+#include "src/param_value_lite.h"
+#include "schema/inner/model_generated.h"
+#include "src/ir/primitive_t_value.h"
+#include "utils/utils.h"
+#include "src/gllo/common/gllo_utils.h"
 #include "securec/include/securec.h"
 
 namespace mindspore::opt {
@@ -142,7 +142,7 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
   CheckIfCNodeIsNull(conv_node);
   GenConvNewBias(func_graph, conv_node, add_node);
   auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0));
-  MS_ASSERT(primitiveT_value);
+  MS_ASSERT(primitiveT_value != nullptr);
   auto type = primitiveT_value->GetPrimitiveT()->value.type;
   if (type == schema::PrimitiveType_Conv2D) {
     primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true;
diff --git a/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h b/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h
index 72d903feb41..df0f393ad51 100644
--- a/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h
+++ b/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
 #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
 
-#include "mindspore/lite/src/gllo/common/optimizer.h"
+#include "src/gllo/common/optimizer.h"
 
 namespace mindspore {
 namespace opt {
diff --git a/mindspore/lite/src/gllo/fusion/conv_bn_fusion.cc b/mindspore/lite/src/gllo/fusion/conv_bn_fusion.cc
index 340cfc73b14..5f110a419da 100644
--- a/mindspore/lite/src/gllo/fusion/conv_bn_fusion.cc
+++ b/mindspore/lite/src/gllo/fusion/conv_bn_fusion.cc
@@ -14,13 +14,13 @@
  * limitations under the License.
  */
 
-#include "mindspore/lite/src/gllo/fusion/conv_bn_fusion.h"
-#include <vector>
+#include "src/gllo/fusion/conv_bn_fusion.h"
 #include <memory>
-#include "mindspore/lite/schema/inner/model_generated.h"
-#include "mindspore/lite/src/ir/primitive_t_value.h"
-#include "mindspore/ccsrc/utils/utils.h"
-#include "mindspore/lite/src/gllo/common/utils.h"
+#include "src/param_value_lite.h"
+#include "schema/inner/model_generated.h"
+#include "src/ir/primitive_t_value.h"
+#include "utils/utils.h"
+#include "src/gllo/common/gllo_utils.h"
 #include "securec/include/securec.h"
 
 namespace mindspore::opt {
diff --git a/mindspore/lite/src/gllo/fusion/conv_bn_fusion.h b/mindspore/lite/src/gllo/fusion/conv_bn_fusion.h
index c5a9dbc5ee5..56dc12a2f97 100644
--- a/mindspore/lite/src/gllo/fusion/conv_bn_fusion.h
+++ b/mindspore/lite/src/gllo/fusion/conv_bn_fusion.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
 #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
 
-#include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h"
+#include "src/gllo/fusion/conv_transform_fusion.h"
 
 namespace mindspore::opt {
 class ConvBatchNormFusion : public ConvTransformFusion {
diff --git a/mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc b/mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc
index accc310c2d6..d4d535c2ce1 100644
--- a/mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc
+++ b/mindspore/lite/src/gllo/fusion/conv_scale_fusion.cc
@@ -14,13 +14,13 @@
  * limitations under the License.
  */
 
-#include "mindspore/lite/src/gllo/fusion/conv_scale_fusion.h"
+#include "src/gllo/fusion/conv_scale_fusion.h"
 #include <memory>
-#include "mindspore/lite/src/param_value_lite.h"
-#include "mindspore/lite/schema/inner/model_generated.h"
-#include "mindspore/lite/src/ir/primitive_t_value.h"
-#include "mindspore/ccsrc/utils/utils.h"
-#include "mindspore/lite/src/gllo/common/utils.h"
+#include "src/param_value_lite.h"
+#include "schema/inner/model_generated.h"
+#include "src/ir/primitive_t_value.h"
+#include "utils/utils.h"
+#include "src/gllo/common/gllo_utils.h"
 #include "include/errorcode.h"
 #include "securec/include/securec.h"
 
diff --git a/mindspore/lite/src/gllo/fusion/conv_scale_fusion.h b/mindspore/lite/src/gllo/fusion/conv_scale_fusion.h
index 35b5d794417..5472ebad464 100644
--- a/mindspore/lite/src/gllo/fusion/conv_scale_fusion.h
+++ b/mindspore/lite/src/gllo/fusion/conv_scale_fusion.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
 #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
 
-#include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h"
+#include "src/gllo/fusion/conv_transform_fusion.h"
 
 namespace mindspore::opt {
 class ConvScaleFusion : public ConvTransformFusion {
diff --git a/mindspore/lite/src/gllo/fusion/conv_transform_fusion.cc b/mindspore/lite/src/gllo/fusion/conv_transform_fusion.cc
index 82971c947eb..515d54ec645 100644
--- a/mindspore/lite/src/gllo/fusion/conv_transform_fusion.cc
+++ b/mindspore/lite/src/gllo/fusion/conv_transform_fusion.cc
@@ -14,13 +14,13 @@
  * limitations under the License.
  */
 
-#include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h"
+#include "src/gllo/fusion/conv_transform_fusion.h"
 #include <memory>
-#include "mindspore/lite/src/param_value_lite.h"
-#include "mindspore/lite/schema/inner/model_generated.h"
-#include "mindspore/lite/src/ir/primitive_t_value.h"
-#include "mindspore/ccsrc/utils/utils.h"
-#include "mindspore/lite/src/gllo/common/utils.h"
+#include "src/param_value_lite.h"
+#include "schema/inner/model_generated.h"
+#include "src/ir/primitive_t_value.h"
+#include "utils/utils.h"
+#include "src/gllo/common/gllo_utils.h"
 #include "include/errorcode.h"
 #include "securec/include/securec.h"
 
@@ -78,6 +78,16 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
   GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias);
   delete[] trans_bias;
   delete[] trans_scale;
+  auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0));
+  MS_ASSERT(primitiveT_value != nullptr);
+  auto type = primitiveT_value->GetPrimitiveT()->value.type;
+  if (type == schema::PrimitiveType_Conv2D) {
+    primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true;
+  } else if (type == schema::PrimitiveType_DepthwiseConv2D) {
+    primitiveT_value->GetPrimitiveT()->value.AsDepthwiseConv2D()->hasBias = true;
+  } else {
+    MS_LOG(EXCEPTION) << "Unsupported opType, " << type;
+  }
   return pre_node;
 }
 
diff --git a/mindspore/lite/src/gllo/fusion/conv_transform_fusion.h b/mindspore/lite/src/gllo/fusion/conv_transform_fusion.h
index 98bb3618a7d..83f66dc7a11 100644
--- a/mindspore/lite/src/gllo/fusion/conv_transform_fusion.h
+++ b/mindspore/lite/src/gllo/fusion/conv_transform_fusion.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
 
 #include <string>
-#include "mindspore/lite/src/gllo/common/optimizer.h"
+#include "src/gllo/common/optimizer.h"
 
 namespace mindspore::opt {
 class ConvTransformFusion : public PatternProcessPass {
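The new block in ConvTransformFusion::Process mirrors the earlier ConvBiasaddFusion change: once scale/BN parameters are folded into the conv weights, the exported flatbuffer op must advertise a bias. A condensed view of that dispatch, pulled into a hypothetical helper for clarity (the AsConv2D/AsDepthwiseConv2D accessors are the generated flatbuffer union getters shown in the hunks above):

    // Hypothetical helper, not in the patch; same shape as the inline code.
    void MarkConvHasBias(schema::PrimitiveT *prim) {
      switch (prim->value.type) {
        case schema::PrimitiveType_Conv2D:
          prim->value.AsConv2D()->hasBias = true;
          break;
        case schema::PrimitiveType_DepthwiseConv2D:
          prim->value.AsDepthwiseConv2D()->hasBias = true;
          break;
        default:
          MS_LOG(EXCEPTION) << "Unsupported opType, " << prim->value.type;
      }
    }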
b/mindspore/lite/test/CMakeLists.txt index df077e8ca72..f972c81b3be 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -63,6 +63,8 @@ if(BUILD_CONVERTER) ${CCSRC_DIR}/pybind_api/export_flags.cc ${CCSRC_DIR}/utils/context/context_extends.cc ${CCSRC_DIR}/frontend/parallel/costmodel_context.cc + ${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc + ${CCSRC_DIR}/backend/optimizer/common/visit.cc ${CMAKE_CURRENT_SOURCE_DIR}/../src/common/graph_utils_extends.cc ) else() @@ -202,12 +204,14 @@ if(BUILD_CONVERTER) ${LITE_DIR}/tools/converter/converter.cc ${LITE_DIR}/tools/converter/parser/onnx/onnx.pb.cc ${LITE_DIR}/test/st/converter_test.cc + ${LITE_DIR}/test/ut/src/gllo/fusion/conv_activation_fusion_test.cc + ${LITE_DIR}/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc + ${LITE_DIR}/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc + ${LITE_DIR}/test/ut/src/gllo/fusion/conv_scale_fusion_test.cc ${LITE_DIR}/src/gllo/common/node_pass.cc ${LITE_DIR}/src/gllo/common/optimizer.cc ${LITE_DIR}/src/gllo/common/pass_manager.cc - ${LITE_DIR}/src/gllo/common/pattern_engine.cc - ${LITE_DIR}/src/gllo/common/visit.cc - ${LITE_DIR}/src/gllo/common/utils.cc + ${LITE_DIR}/src/gllo/common/gllo_utils.cc ${LITE_DIR}/src/gllo/fusion/conv_biasadd_fusion.cc ${LITE_DIR}/src/gllo/fusion/conv_activation_fusion.cc ${LITE_DIR}/src/gllo/fusion/conv_transform_fusion.cc diff --git a/mindspore/lite/test/ut/src/gllo/fusion/conv_activation_fusion_test.cc b/mindspore/lite/test/ut/src/gllo/fusion/conv_activation_fusion_test.cc new file mode 100644 index 00000000000..6fac6dedc4b --- /dev/null +++ b/mindspore/lite/test/ut/src/gllo/fusion/conv_activation_fusion_test.cc @@ -0,0 +1,184 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "schema/inner/model_generated.h" +#include "include/model.h" +#include "common/common_test.h" +#include "include/lite_session.h" +#include "include/context.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "tools/converter/model_parser.h" +#include "tools/converter/anf_transform.h" +#include "src/common/anf_exporter/anf_exporter.h" + +namespace mindspore { +class ConvActivationFusionTest : public mindspore::Common { + public: + ConvActivationFusionTest() = default; +}; +using MetaGraphTptr = std::shared_ptr; +using CNodeTptr = std::unique_ptr; + +namespace { +CNodeTptr BuildConv2D() { + auto convNode = std::make_unique(); + convNode->inputIndex = {0, 1}; + convNode->outputIndex = {2}; + convNode->primitive = std::make_unique(); + convNode->primitive->value.type = schema::PrimitiveType_Conv2D; + auto prim1 = new schema::Conv2DT; + prim1->padMode = schema::PadMode_SAME; + prim1->format = schema::Format_NHWC; + prim1->strideH = 1; + prim1->strideW = 1; + prim1->kernelH = 3; + prim1->kernelW = 3; + prim1->dilateH = 1; + prim1->dilateW = 1; + prim1->channelOut = 3; + convNode->primitive->value.value = prim1; + convNode->name = "Conv2D"; + return convNode; +} +CNodeTptr BuildDepthwiseConv2D() { + auto convNode = std::make_unique(); + convNode->inputIndex = {0, 1}; + convNode->outputIndex = {2}; + convNode->primitive = std::make_unique(); + convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + auto prim1 = new schema::DepthwiseConv2DT; + prim1->padMode = schema::PadMode_SAME; + prim1->format = schema::Format_NHWC; + prim1->strideH = 1; + prim1->strideW = 1; + prim1->kernelH = 3; + prim1->kernelW = 3; + prim1->dilateH = 1; + prim1->dilateW = 1; + prim1->channelIn = 1; + prim1->channelMultiplier = 3; + convNode->primitive->value.value = prim1; + convNode->name = "Conv2D"; + return convNode; +} + +MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, + schema::ActivationType activation_type) { + auto meta_graph = std::make_shared(); + meta_graph->name = "graph"; + // conv node + CNodeTptr convNode; + if (conv_type == schema::PrimitiveType_Conv2D) { + convNode = BuildConv2D(); + } else { + convNode = BuildDepthwiseConv2D(); + } + meta_graph->nodes.emplace_back(std::move(convNode)); + + // relu node + auto next_node = std::make_unique(); + next_node->inputIndex = {2}; + next_node->outputIndex = {3}; + next_node->primitive = std::make_unique(); + next_node->primitive->value.type = schema::PrimitiveType_Activation; + auto prim2 = new schema::ActivationT; + prim2->type = activation_type; + next_node->primitive->value.value = prim2; + next_node->name = "activation"; + meta_graph->nodes.emplace_back(std::move(next_node)); + + meta_graph->inputIndex = {0}; + meta_graph->outputIndex = {3}; + + // input 0: data + auto input0 = std::make_unique(); + input0->nodeType = schema::NodeType::NodeType_ValueNode; + input0->format = schema::Format_NHWC; + input0->dataType = TypeId::kNumberTypeFloat32; + input0->dims = {1, 5, 5, 3}; + input0->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input0)); + + // input 1: weight + auto input1 = std::make_unique(); + input1->nodeType = schema::NodeType::NodeType_ValueNode; + input1->format = schema::Format_KHWC; + input1->dataType = TypeId::kNumberTypeFloat32; + input1->dims = {8, 3, 3, 3}; + input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3); + meta_graph->allTensors.emplace_back(std::move(input1)); + + // conv output + auto conv_output = std::make_unique(); + conv_output->nodeType = 
schema::NodeType::NodeType_Parameter; + conv_output->format = schema::Format_NHWC; + conv_output->dataType = TypeId::kNumberTypeFloat32; + conv_output->dims = {1, 5, 5, 8}; + meta_graph->allTensors.emplace_back(std::move(conv_output)); + + // final output + auto output = std::make_unique(); + output->nodeType = schema::NodeType::NodeType_Parameter; + output->format = schema::Format_NHWC; + output->dataType = TypeId::kNumberTypeFloat32; + output->dims = {1, 5, 5, 8}; + meta_graph->allTensors.emplace_back(std::move(output)); + return meta_graph; +} +} // namespace +TEST_F(ConvActivationFusionTest, TestConvReluNode) { + auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::ActivationType_RELU); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 1); + for (auto &cnode : new_meta_graph->nodes) { + ASSERT_EQ(cnode->primitive->value.AsConv2D()->activationType, schema::ActivationType_RELU); + } +} + +TEST_F(ConvActivationFusionTest, TestConvRelu6Node) { + auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::ActivationType_RELU6); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 1); + for (auto &cnode : new_meta_graph->nodes) { + ASSERT_EQ(cnode->primitive->value.AsConv2D()->activationType, schema::ActivationType_RELU6); + } +} + +TEST_F(ConvActivationFusionTest, TestBadCase_ConvRelu) { + auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::ActivationType_LEAKY_RELU); + auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get()); + auto anf_transform = new lite::AnfTransform(); + auto new_graph = anf_transform->Transform(func_graph); + ASSERT_NE(nullptr, new_graph); + auto new_meta_graph = lite::Export(new_graph); + ASSERT_EQ(new_meta_graph->nodes.size(), 2); + for (auto &cnode : new_meta_graph->nodes) { + if (cnode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->activationType, schema::ActivationType_NO_ACTIVATION); + } + } +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc b/mindspore/lite/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc new file mode 100644 index 00000000000..ef9fd871150 --- /dev/null +++ b/mindspore/lite/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc @@ -0,0 +1,194 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+TEST_F(ConvActivationFusionTest, TestBadCase_ConvRelu) {
+  auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::ActivationType_LEAKY_RELU);
+  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
+  auto anf_transform = new lite::AnfTransform();
+  auto new_graph = anf_transform->Transform(func_graph);
+  ASSERT_NE(nullptr, new_graph);
+  auto new_meta_graph = lite::Export(new_graph);
+  ASSERT_EQ(new_meta_graph->nodes.size(), 2);
+  for (auto &cnode : new_meta_graph->nodes) {
+    if (cnode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
+      ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->activationType, schema::ActivationType_NO_ACTIVATION);
+    }
+  }
+}
+}  // namespace mindspore
diff --git a/mindspore/lite/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc b/mindspore/lite/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc
new file mode 100644
index 00000000000..ef9fd871150
--- /dev/null
+++ b/mindspore/lite/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc
@@ -0,0 +1,194 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include "schema/inner/model_generated.h"
+#include "include/model.h"
+#include "common/common_test.h"
+#include "include/lite_session.h"
+#include "include/context.h"
+#include "include/errorcode.h"
+#include "utils/log_adapter.h"
+#include "tools/converter/model_parser.h"
+#include "tools/converter/anf_transform.h"
+#include "src/common/anf_exporter/anf_exporter.h"
+
+namespace mindspore {
+class ConvBiasAddFusionTest : public mindspore::Common {
+ public:
+  ConvBiasAddFusionTest() = default;
+};
+using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
+using CNodeTptr = std::unique_ptr<schema::CNodeT>;
+
+namespace {
+CNodeTptr BuildConv2D() {
+  auto convNode = std::make_unique<schema::CNodeT>();
+  convNode->inputIndex = {0, 1};
+  convNode->outputIndex = {2};
+  convNode->primitive = std::make_unique<schema::PrimitiveT>();
+  convNode->primitive->value.type = schema::PrimitiveType_Conv2D;
+  auto prim1 = new schema::Conv2DT;
+  prim1->padMode = schema::PadMode_SAME;
+  prim1->format = schema::Format_NHWC;
+  prim1->strideH = 1;
+  prim1->strideW = 1;
+  prim1->kernelH = 3;
+  prim1->kernelW = 3;
+  prim1->dilateH = 1;
+  prim1->dilateW = 1;
+  prim1->channelOut = 3;
+  convNode->primitive->value.value = prim1;
+  convNode->name = "Conv2D";
+  return convNode;
+}
+CNodeTptr BuildDepthwiseConv2D() {
+  auto convNode = std::make_unique<schema::CNodeT>();
+  convNode->inputIndex = {0, 1};
+  convNode->outputIndex = {2};
+  convNode->primitive = std::make_unique<schema::PrimitiveT>();
+  convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
+  auto prim1 = new schema::DepthwiseConv2DT;
+  prim1->padMode = schema::PadMode_SAME;
+  prim1->format = schema::Format_NHWC;
+  prim1->strideH = 1;
+  prim1->strideW = 1;
+  prim1->kernelH = 3;
+  prim1->kernelW = 3;
+  prim1->dilateH = 1;
+  prim1->dilateW = 1;
+  prim1->channelIn = 1;
+  prim1->channelMultiplier = 3;
+  convNode->primitive->value.value = prim1;
+  convNode->name = "DepthwiseConv2D";
+  return convNode;
+}
+
+MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type,
+                         schema::PrimitiveType add_type) {
+  auto meta_graph = std::make_shared<schema::MetaGraphT>();
+  meta_graph->name = "graph";
+  // conv node
+  CNodeTptr convNode;
+  if (conv_type == schema::PrimitiveType_Conv2D) {
+    convNode = BuildConv2D();
+  } else {
+    convNode = BuildDepthwiseConv2D();
+  }
+
+  meta_graph->nodes.emplace_back(std::move(convNode));
+
+  // biasadd node
+  auto biasadd_node = std::make_unique<schema::CNodeT>();
+  biasadd_node->inputIndex = {2, 3};
+  biasadd_node->outputIndex = {4};
+  biasadd_node->primitive = std::make_unique<schema::PrimitiveT>();
+  biasadd_node->primitive->value.type = add_type;
+  auto prim2 = new schema::BiasAddT;
+  biasadd_node->primitive->value.value = prim2;
+  biasadd_node->name = "BiasAdd";
+  meta_graph->nodes.emplace_back(std::move(biasadd_node));
+
+  meta_graph->inputIndex = {0};
+  meta_graph->outputIndex = {4};
+
+  // input 0: data
+  auto input0 = std::make_unique<schema::TensorT>();
+  input0->nodeType = schema::NodeType::NodeType_ValueNode;
+  input0->format = schema::Format_NHWC;
+  input0->dataType = TypeId::kNumberTypeFloat32;
+  input0->dims = {1, 5, 5, 3};
+  input0->offset = -1;
+  meta_graph->allTensors.emplace_back(std::move(input0));
+
+  // input 1: weight
+  auto input1 = std::make_unique<schema::TensorT>();
+  input1->nodeType = schema::NodeType::NodeType_ValueNode;
+  input1->format = schema::Format_KHWC;
+  input1->dataType = TypeId::kNumberTypeFloat32;
+  input1->dims = {8, 3, 3, 3};
+  input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
+  meta_graph->allTensors.emplace_back(std::move(input1));
+
+  // conv output
+  auto conv_output = std::make_unique<schema::TensorT>();
+  conv_output->nodeType = schema::NodeType::NodeType_Parameter;
+  conv_output->format = schema::Format_NHWC;
+  conv_output->dataType = TypeId::kNumberTypeFloat32;
+  conv_output->dims = {1, 5, 5, 8};
+  meta_graph->allTensors.emplace_back(std::move(conv_output));
+
+  // input 2: bias
+  auto input2 = std::make_unique<schema::TensorT>();
+  input2->nodeType = schema::NodeType::NodeType_ValueNode;
+  input2->format = schema::Format_NHWC;
+  input2->dataType = TypeId::kNumberTypeFloat32;
+  input2->dims = {1, 5, 5, 8};
+  input2->data.resize(sizeof(float) * 8 * 5 * 5);
+  meta_graph->allTensors.emplace_back(std::move(input2));
+
+  // final output
+  auto output = std::make_unique<schema::TensorT>();
+  output->nodeType = schema::NodeType::NodeType_Parameter;
+  output->format = schema::Format_NHWC;
+  output->dataType = TypeId::kNumberTypeFloat32;
+  output->dims = {1, 5, 5, 8};
+  meta_graph->allTensors.emplace_back(std::move(output));
+  return meta_graph;
+}
+}  // namespace
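+// Sketch of the expected rewrite (the pass itself is
+// src/gllo/fusion/conv_biasadd_fusion.cc): a BiasAdd/Add behind the conv is
+// folded into a conv bias tensor, per output channel:
+//   bias'[c] = conv_bias[c] + add_bias[c];  // conv_bias taken as 0 if absent
+//   conv->hasBias = true;                   // the Add node disappears
+// Ops other than BiasAdd/Add (the bad case below uses MatMul) must not match.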
+TEST_F(ConvBiasAddFusionTest, TestConvAddNode) {
+  auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::PrimitiveType_BiasAdd);
+  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
+  auto anf_transform = new lite::AnfTransform();
+  auto new_graph = anf_transform->Transform(func_graph);
+  ASSERT_NE(nullptr, new_graph);
+  auto new_meta_graph = lite::Export(new_graph);
+  ASSERT_EQ(new_meta_graph->nodes.size(), 1);
+  for (auto &cnode : new_meta_graph->nodes) {
+    ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true);
+  }
+  MS_LOG(INFO) << "Passed";
+}
+
+TEST_F(ConvBiasAddFusionTest, TestDepthwiseConvAddNode) {
+  auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add);
+  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
+  auto anf_transform = new lite::AnfTransform();
+  auto new_graph = anf_transform->Transform(func_graph);
+  ASSERT_NE(nullptr, new_graph);
+  auto new_meta_graph = lite::Export(new_graph);
+  ASSERT_EQ(new_meta_graph->nodes.size(), 1);
+  for (auto &cnode : new_meta_graph->nodes) {
+    ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true);
+  }
+}
+
+TEST_F(ConvBiasAddFusionTest, TestBadCase_ConvAdd) {
+  auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_MatMul);
+  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
+  auto anf_transform = new lite::AnfTransform();
+  auto new_graph = anf_transform->Transform(func_graph);
+  ASSERT_NE(nullptr, new_graph);
+  auto new_meta_graph = lite::Export(new_graph);
+  ASSERT_EQ(new_meta_graph->nodes.size(), 2);
+  for (auto &cnode : new_meta_graph->nodes) {
+    if (cnode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
+      ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, false);
+    }
+  }
+}
+}  // namespace mindspore
diff --git a/mindspore/lite/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc b/mindspore/lite/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc
new file mode 100644
index 00000000000..1906e71b149
--- /dev/null
+++ b/mindspore/lite/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc
@@ -0,0 +1,296 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include "schema/inner/model_generated.h"
+#include "include/model.h"
+#include "common/common_test.h"
+#include "include/lite_session.h"
+#include "include/context.h"
+#include "include/errorcode.h"
+#include "mindspore/core/utils/log_adapter.h"
+#include "tools/converter/model_parser.h"
+#include "tools/converter/anf_transform.h"
+#include "src/common/anf_exporter/anf_exporter.h"
+
+namespace mindspore {
+class ConvBNFusionTest : public mindspore::Common {
+ public:
+  ConvBNFusionTest() = default;
+};
+using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
+using CNodeTptr = std::unique_ptr<schema::CNodeT>;
+
+namespace {
+CNodeTptr BuildConv2D() {
+  auto convNode = std::make_unique<schema::CNodeT>();
+  convNode->inputIndex = {0, 1};
+  convNode->outputIndex = {2};
+  convNode->primitive = std::make_unique<schema::PrimitiveT>();
+  convNode->primitive->value.type = schema::PrimitiveType_Conv2D;
+  auto prim1 = new schema::Conv2DT;
+  prim1->padMode = schema::PadMode_SAME;
+  prim1->format = schema::Format_NHWC;
+  prim1->strideH = 1;
+  prim1->strideW = 1;
+  prim1->kernelH = 3;
+  prim1->kernelW = 3;
+  prim1->dilateH = 1;
+  prim1->dilateW = 1;
+  prim1->channelOut = 3;
+  convNode->primitive->value.value = prim1;
+  convNode->name = "Conv2D";
+  return convNode;
+}
+CNodeTptr BuildDepthwiseConv2D() {
+  auto convNode = std::make_unique<schema::CNodeT>();
+  convNode->inputIndex = {0, 1, 2};
+  convNode->outputIndex = {3};
+  convNode->primitive = std::make_unique<schema::PrimitiveT>();
+  convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
+  auto prim1 = new schema::DepthwiseConv2DT;
+  prim1->padMode = schema::PadMode_SAME;
+  prim1->format = schema::Format_NHWC;
+  prim1->strideH = 1;
+  prim1->strideW = 1;
+  prim1->kernelH = 3;
+  prim1->kernelW = 3;
+  prim1->dilateH = 1;
+  prim1->dilateW = 1;
+  prim1->channelIn = 1;
+  prim1->channelMultiplier = 3;
+
+  convNode->primitive->value.value = prim1;
+  convNode->name = "DepthwiseConv2D";
+  return convNode;
+}
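+// Two bn flavors are modelled in this file: Caffe's BatchNorm carries only
+// mean and variance (scale/shift live in a separate Scale layer), while the
+// TF-style FusedBatchNorm also carries scale and offset tensors; the two
+// graph builders below differ only in that wiring.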
+// caffe bn op has 3 inputs
+MetaGraphTptr BuildCaffeGraph(schema::PrimitiveType conv_type) {
+  auto meta_graph = std::make_shared<schema::MetaGraphT>();
+  meta_graph->name = "graph";
+  // conv node
+  CNodeTptr convNode;
+  if (conv_type == schema::PrimitiveType_Conv2D) {
+    convNode = BuildConv2D();
+  } else {
+    convNode = BuildDepthwiseConv2D();
+  }
+
+  meta_graph->nodes.emplace_back(std::move(convNode));
+
+  // bn_node
+  auto bn_node = std::make_unique<schema::CNodeT>();
+  bn_node->inputIndex = {2, 3, 4};
+  bn_node->outputIndex = {5};
+  bn_node->primitive = std::make_unique<schema::PrimitiveT>();
+  bn_node->primitive->value.type = schema::PrimitiveType_CaffeBatchNorm;
+  auto prim2 = new schema::CaffeBatchNormT;
+  bn_node->primitive->value.value = prim2;
+  bn_node->name = "bn";
+  meta_graph->nodes.emplace_back(std::move(bn_node));
+
+  // input 0: data
+  auto input0 = std::make_unique<schema::TensorT>();
+  input0->nodeType = schema::NodeType::NodeType_ValueNode;
+  input0->format = schema::Format_NHWC;
+  input0->dataType = TypeId::kNumberTypeFloat32;
+  input0->dims = {1, 5, 5, 3};
+  input0->offset = -1;
+  meta_graph->allTensors.emplace_back(std::move(input0));
+
+  // input 1: weight
+  auto input1 = std::make_unique<schema::TensorT>();
+  input1->nodeType = schema::NodeType::NodeType_ValueNode;
+  input1->format = schema::Format_KHWC;
+  input1->dataType = TypeId::kNumberTypeFloat32;
+  input1->dims = {8, 3, 3, 3};
+  input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
+  meta_graph->allTensors.emplace_back(std::move(input1));
+
+  // conv output
+  auto conv_output = std::make_unique<schema::TensorT>();
+  conv_output->nodeType = schema::NodeType::NodeType_Parameter;
+  conv_output->format = schema::Format_NHWC;
+  conv_output->dataType = TypeId::kNumberTypeFloat32;
+  conv_output->dims = {1, 5, 5, 8};
+  meta_graph->allTensors.emplace_back(std::move(conv_output));
+
+  // caffe bn : mean
+  auto input2 = std::make_unique<schema::TensorT>();
+  input2->nodeType = schema::NodeType::NodeType_ValueNode;
+  input2->format = schema::Format_NHWC;
+  input2->dataType = TypeId::kNumberTypeFloat32;
+  input2->dims = {1, 5, 5, 8};
+  input2->data.resize(sizeof(float) * 8 * 5 * 5);
+  meta_graph->allTensors.emplace_back(std::move(input2));
+
+  // caffe bn : var
+  auto input3 = std::make_unique<schema::TensorT>();
+  input3->nodeType = schema::NodeType::NodeType_ValueNode;
+  input3->format = schema::Format_NHWC;
+  input3->dataType = TypeId::kNumberTypeFloat32;
+  input3->dims = {1, 5, 5, 8};
+  input3->data.resize(sizeof(float) * 8 * 5 * 5);
+  meta_graph->allTensors.emplace_back(std::move(input3));
+
+  // final bn output
+  auto output = std::make_unique<schema::TensorT>();
+  output->nodeType = schema::NodeType::NodeType_Parameter;
+  output->format = schema::Format_NHWC;
+  output->dataType = TypeId::kNumberTypeFloat32;
+  output->dims = {1, 5, 5, 8};
+  meta_graph->allTensors.emplace_back(std::move(output));
+  meta_graph->inputIndex = {0};
+  meta_graph->outputIndex = {5};
+  return meta_graph;
+}
+
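+// BuildTFGraph below wires the five-input fused form {x, scale, offset,
+// mean, variance} (indices {3, 4, 5, 6, 7}) and, unlike the Caffe graph,
+// also threads a conv bias tensor, matching BuildDepthwiseConv2D's three
+// conv inputs.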
+// tf bn op has 4 parameter inputs (scale, offset, mean, variance) besides x
+MetaGraphTptr BuildTFGraph(schema::PrimitiveType conv_type) {
+  auto meta_graph = std::make_shared<schema::MetaGraphT>();
+  meta_graph->name = "graph";
+  // conv node
+  CNodeTptr convNode;
+  if (conv_type == schema::PrimitiveType_Conv2D) {
+    convNode = BuildConv2D();
+  } else {
+    convNode = BuildDepthwiseConv2D();
+  }
+
+  meta_graph->nodes.emplace_back(std::move(convNode));
+
+  // bn_node
+  auto bn_node = std::make_unique<schema::CNodeT>();
+  bn_node->inputIndex = {3, 4, 5, 6, 7};
+  bn_node->outputIndex = {8};
+  bn_node->primitive = std::make_unique<schema::PrimitiveT>();
+  bn_node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
+  auto prim2 = new schema::FusedBatchNormT;
+  bn_node->primitive->value.value = prim2;
+  bn_node->name = "bn";
+  meta_graph->nodes.emplace_back(std::move(bn_node));
+
+  // input 0: data
+  auto input0 = std::make_unique<schema::TensorT>();
+  input0->nodeType = schema::NodeType::NodeType_ValueNode;
+  input0->format = schema::Format_NHWC;
+  input0->dataType = TypeId::kNumberTypeFloat32;
+  input0->dims = {1, 5, 5, 3};
+  input0->offset = -1;
+  meta_graph->allTensors.emplace_back(std::move(input0));
+
+  // input 1: conv bias
+  auto input11 = std::make_unique<schema::TensorT>();
+  input11->nodeType = schema::NodeType::NodeType_ValueNode;
+  input11->format = schema::Format_KHWC;
+  input11->dataType = TypeId::kNumberTypeFloat32;
+  input11->dims = {8, 3, 3, 3};
+  input11->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
+  meta_graph->allTensors.emplace_back(std::move(input11));
+
+  // input 2: weight
+  auto input1 = std::make_unique<schema::TensorT>();
+  input1->nodeType = schema::NodeType::NodeType_ValueNode;
+  input1->format = schema::Format_KHWC;
+  input1->dataType = TypeId::kNumberTypeFloat32;
+  input1->dims = {8, 3, 3, 3};
+  input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
+  meta_graph->allTensors.emplace_back(std::move(input1));
+
+  // conv output
+  auto conv_output = std::make_unique<schema::TensorT>();
+  conv_output->nodeType = schema::NodeType::NodeType_Parameter;
+  conv_output->format = schema::Format_NHWC;
+  conv_output->dataType = TypeId::kNumberTypeFloat32;
+  conv_output->dims = {1, 5, 5, 8};
+  meta_graph->allTensors.emplace_back(std::move(conv_output));
+
+  // tflite bn : scale
+  auto input2 = std::make_unique<schema::TensorT>();
+  input2->nodeType = schema::NodeType::NodeType_ValueNode;
+  input2->format = schema::Format_NHWC;
+  input2->dataType = TypeId::kNumberTypeFloat32;
+  input2->dims = {1, 5, 5, 8};
+  input2->data.resize(sizeof(float) * 8 * 5 * 5);
+  meta_graph->allTensors.emplace_back(std::move(input2));
+
+  // tflite bn : bias
+  auto input3 = std::make_unique<schema::TensorT>();
+  input3->nodeType = schema::NodeType::NodeType_ValueNode;
+  input3->format = schema::Format_NHWC;
+  input3->dataType = TypeId::kNumberTypeFloat32;
+  input3->dims = {1, 5, 5, 8};
+  input3->data.resize(sizeof(float) * 8 * 5 * 5);
+  meta_graph->allTensors.emplace_back(std::move(input3));
+
+  // tflite bn : mean
+  auto input4 = std::make_unique<schema::TensorT>();
+  input4->nodeType = schema::NodeType::NodeType_ValueNode;
+  input4->format = schema::Format_NHWC;
+  input4->dataType = TypeId::kNumberTypeFloat32;
+  input4->dims = {1, 5, 5, 8};
+  input4->data.resize(sizeof(float) * 8 * 5 * 5);
+  meta_graph->allTensors.emplace_back(std::move(input4));
+
+  // tflite bn : var
+  auto input5 = std::make_unique<schema::TensorT>();
+  input5->nodeType = schema::NodeType::NodeType_ValueNode;
+  input5->format = schema::Format_NHWC;
+  input5->dataType = TypeId::kNumberTypeFloat32;
+  input5->dims = {1, 5, 5, 8};
+  input5->data.resize(sizeof(float) * 8 * 5 * 5);
+  meta_graph->allTensors.emplace_back(std::move(input5));
+
+  // final output
+  auto output = std::make_unique<schema::TensorT>();
+  output->nodeType = schema::NodeType::NodeType_Parameter;
+  output->format = schema::Format_NHWC;
+  output->dataType = TypeId::kNumberTypeFloat32;
+  output->dims = {1, 5, 5, 8};
+  meta_graph->allTensors.emplace_back(std::move(output));
+  meta_graph->inputIndex = {0};
+  meta_graph->outputIndex = {8};
+  return meta_graph;
+}
+}  // namespace
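+// For reference, the standard bn-folding arithmetic the fusion is expected
+// to bake into the conv weights and bias, per output channel c:
+//   w'[c] = w[c] * scale[c] / sqrt(var[c] + eps)
+//   b'[c] = (b[c] - mean[c]) * scale[c] / sqrt(var[c] + eps) + offset[c]
+// For the Caffe form (no scale/offset tensors), scale = 1 and offset = 0.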
+TEST_F(ConvBNFusionTest, TestConvAddNode) {
+  auto meta_graph = BuildCaffeGraph(schema::PrimitiveType_Conv2D);
+  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
+  auto anf_transform = new lite::AnfTransform();
+  auto new_graph = anf_transform->Transform(func_graph);
+  ASSERT_NE(nullptr, new_graph);
+  auto new_meta_graph = lite::Export(new_graph);
+  ASSERT_EQ(new_meta_graph->nodes.size(), 1);
+  for (auto &cnode : new_meta_graph->nodes) {
+    ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true);
+  }
+}
+
+TEST_F(ConvBNFusionTest, TestDepthwiseConvAddNode) {
+  auto meta_graph = BuildTFGraph(schema::PrimitiveType_DepthwiseConv2D);
+  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
+  auto anf_transform = new lite::AnfTransform();
+  auto new_graph = anf_transform->Transform(func_graph);
+  ASSERT_NE(nullptr, new_graph);
+  auto new_meta_graph = lite::Export(new_graph);
+  ASSERT_EQ(new_meta_graph->nodes.size(), 1);
+  for (auto &cnode : new_meta_graph->nodes) {
+    ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true);
+  }
+}
+}  // namespace mindspore
diff --git a/mindspore/lite/test/ut/src/gllo/fusion/conv_scale_fusion_test.cc b/mindspore/lite/test/ut/src/gllo/fusion/conv_scale_fusion_test.cc
new file mode 100644
index 00000000000..1eca3a469cf
--- /dev/null
+++ b/mindspore/lite/test/ut/src/gllo/fusion/conv_scale_fusion_test.cc
@@ -0,0 +1,221 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include "schema/inner/model_generated.h"
+#include "include/model.h"
+#include "common/common_test.h"
+#include "include/lite_session.h"
+#include "include/context.h"
+#include "include/errorcode.h"
+#include "utils/log_adapter.h"
+#include "tools/converter/model_parser.h"
+#include "tools/converter/anf_transform.h"
+#include "src/common/anf_exporter/anf_exporter.h"
+
+namespace mindspore {
+class ConvScaleFusionTest : public mindspore::Common {
+ public:
+  ConvScaleFusionTest() = default;
+};
+using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
+using CNodeTptr = std::unique_ptr<schema::CNodeT>;
+
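+// The builders below take a with_bias flag so the fusion is exercised both
+// when the conv already owns a bias tensor and when it does not; in the
+// latter case the scale's bias is expected to become the conv bias.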
+namespace {
+// conv2d: 2 inputs, or 3 when a bias tensor is attached
+CNodeTptr BuildConv2D(bool with_bias_flag) {
+  auto convNode = std::make_unique<schema::CNodeT>();
+  if (with_bias_flag) {
+    convNode->inputIndex = {0, 1, 2};
+    convNode->outputIndex = {3};
+  } else {
+    convNode->inputIndex = {0, 1};
+    convNode->outputIndex = {2};
+  }
+  convNode->primitive = std::make_unique<schema::PrimitiveT>();
+  convNode->primitive->value.type = schema::PrimitiveType_Conv2D;
+  auto prim1 = new schema::Conv2DT;
+  prim1->padMode = schema::PadMode_SAME;
+  prim1->format = schema::Format_NHWC;
+  prim1->strideH = 1;
+  prim1->strideW = 1;
+  prim1->kernelH = 3;
+  prim1->kernelW = 3;
+  prim1->dilateH = 1;
+  prim1->dilateW = 1;
+  prim1->channelOut = 3;
+  convNode->primitive->value.value = prim1;
+  convNode->name = "Conv2D";
+  return convNode;
+}
+// depthwise conv2d: 2 inputs, or 3 when a bias tensor is attached
+CNodeTptr BuildDepthwiseConv2D(bool with_bias_flag) {
+  auto convNode = std::make_unique<schema::CNodeT>();
+  if (with_bias_flag) {
+    convNode->inputIndex = {0, 1, 2};
+    convNode->outputIndex = {3};
+  } else {
+    convNode->inputIndex = {0, 1};
+    convNode->outputIndex = {2};
+  }
+  convNode->primitive = std::make_unique<schema::PrimitiveT>();
+  convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
+  auto prim1 = new schema::DepthwiseConv2DT;
+  prim1->padMode = schema::PadMode_SAME;
+  prim1->format = schema::Format_NHWC;
+  prim1->strideH = 1;
+  prim1->strideW = 1;
+  prim1->kernelH = 3;
+  prim1->kernelW = 3;
+  prim1->dilateH = 1;
+  prim1->dilateW = 1;
+  prim1->channelIn = 1;
+  prim1->channelMultiplier = 3;
+
+  convNode->primitive->value.value = prim1;
+  convNode->name = "DepthwiseConv2D";
+  return convNode;
+}
+
+MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, bool conv_with_bias) {
+  auto meta_graph = std::make_shared<schema::MetaGraphT>();
+  meta_graph->name = "graph";
+  // conv node
+  CNodeTptr convNode;
+  if (conv_type == schema::PrimitiveType_Conv2D) {
+    convNode = BuildConv2D(conv_with_bias);
+  } else {
+    convNode = BuildDepthwiseConv2D(conv_with_bias);
+  }
+
+  meta_graph->nodes.emplace_back(std::move(convNode));
+
+  // scale node: input, weight, bias
+  auto scale_node = std::make_unique<schema::CNodeT>();
+  if (conv_with_bias) {
+    scale_node->inputIndex = {3, 4, 5};
+    scale_node->outputIndex = {6};
+  } else {
+    scale_node->inputIndex = {2, 3, 4};
+    scale_node->outputIndex = {5};
+  }
+
+  scale_node->primitive = std::make_unique<schema::PrimitiveT>();
+  scale_node->primitive->value.type = schema::PrimitiveType_Scale;
+  auto prim2 = new schema::ScaleT;
+  scale_node->primitive->value.value = prim2;
+  scale_node->name = "scale";
+  meta_graph->nodes.emplace_back(std::move(scale_node));
+
+  // input 0: data
+  auto input0 = std::make_unique<schema::TensorT>();
+  input0->nodeType = schema::NodeType::NodeType_ValueNode;
+  input0->format = schema::Format_NHWC;
+  input0->dataType = TypeId::kNumberTypeFloat32;
+  input0->dims = {1, 5, 5, 3};
+  input0->offset = -1;
+  meta_graph->allTensors.emplace_back(std::move(input0));
+
+  // input 1: weight
+  auto input1 = std::make_unique<schema::TensorT>();
+  input1->nodeType = schema::NodeType::NodeType_ValueNode;
+  input1->format = schema::Format_KHWC;
+  input1->dataType = TypeId::kNumberTypeFloat32;
+  input1->dims = {8, 3, 3, 3};
+  input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
+  meta_graph->allTensors.emplace_back(std::move(input1));
+
+  if (conv_with_bias) {
+    // input 2: conv bias
+    auto input00 = std::make_unique<schema::TensorT>();
+    input00->nodeType = schema::NodeType::NodeType_ValueNode;
+    input00->format = schema::Format_NHWC;
+    input00->dataType = TypeId::kNumberTypeFloat32;
+    input00->dims = {1, 5, 5, 3};
+    input00->offset = -1;
+    meta_graph->allTensors.emplace_back(std::move(input00));
+  }
+
+  // conv output
+  auto conv_output = std::make_unique<schema::TensorT>();
+  conv_output->nodeType = schema::NodeType::NodeType_Parameter;
+  conv_output->format = schema::Format_NHWC;
+  conv_output->dataType = TypeId::kNumberTypeFloat32;
+  conv_output->dims = {1, 5, 5, 8};
+  meta_graph->allTensors.emplace_back(std::move(conv_output));
+
+  // scale weight input
+  auto input2 = std::make_unique<schema::TensorT>();
+  input2->nodeType = schema::NodeType::NodeType_ValueNode;
+  input2->format = schema::Format_NHWC;
+  input2->dataType = TypeId::kNumberTypeFloat32;
+  input2->dims = {1, 5, 5, 8};
+  input2->data.resize(sizeof(float) * 8 * 5 * 5);
+  meta_graph->allTensors.emplace_back(std::move(input2));
+
+  // scale bias input
+  auto input3 = std::make_unique<schema::TensorT>();
+  input3->nodeType = schema::NodeType::NodeType_ValueNode;
+  input3->format = schema::Format_NHWC;
+  input3->dataType = TypeId::kNumberTypeFloat32;
+  input3->dims = {1, 5, 5, 8};
+  input3->data.resize(sizeof(float) * 8 * 5 * 5);
+  meta_graph->allTensors.emplace_back(std::move(input3));
+
+  // final scale output
+  auto output = std::make_unique<schema::TensorT>();
+  output->nodeType = schema::NodeType::NodeType_Parameter;
+  output->format = schema::Format_NHWC;
+  output->dataType = TypeId::kNumberTypeFloat32;
+  output->dims = {1, 5, 5, 8};
+  meta_graph->allTensors.emplace_back(std::move(output));
+  if (conv_with_bias) {
+    meta_graph->inputIndex = {0};
+    meta_graph->outputIndex = {6};
+  } else {
+    meta_graph->inputIndex = {0};
+    meta_graph->outputIndex = {5};
+  }
+  return meta_graph;
+}
+}  // namespace
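+// Expected folding for conv followed by scale (a sketch; see
+// src/gllo/fusion/conv_transform_fusion.cc for the actual pass), per output
+// channel c:
+//   w'[c] = w[c] * scale[c]
+//   b'[c] = b[c] * scale[c] + scale_bias[c]   // b taken as 0 without a bias
+// which is why both tests below assert hasBias == true afterwards.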
+TEST_F(ConvScaleFusionTest, TestConvScaleNode) {
+  auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, true);
+  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
+  auto anf_transform = new lite::AnfTransform();
+  auto new_graph = anf_transform->Transform(func_graph);
+  ASSERT_NE(nullptr, new_graph);
+  auto new_meta_graph = lite::Export(new_graph);
+  ASSERT_EQ(new_meta_graph->nodes.size(), 1);
+  for (auto &cnode : new_meta_graph->nodes) {
+    ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true);
+  }
+}
+
+TEST_F(ConvScaleFusionTest, TestDepthwiseConvScaleNode) {
+  auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, false);
+  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
+  auto anf_transform = new lite::AnfTransform();
+  auto new_graph = anf_transform->Transform(func_graph);
+  ASSERT_NE(nullptr, new_graph);
+  auto new_meta_graph = lite::Export(new_graph);
+  ASSERT_EQ(new_meta_graph->nodes.size(), 1);
+  for (auto &cnode : new_meta_graph->nodes) {
+    ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true);
+    ASSERT_EQ(cnode->inputIndex.size(), 3);
+  }
+}
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt
index 9dec33473a4..a40fcb1a18b 100644
--- a/mindspore/lite/tools/converter/CMakeLists.txt
+++ b/mindspore/lite/tools/converter/CMakeLists.txt
@@ -49,6 +49,8 @@ set(ANF_SRC
         ${CCSRC_DIR}/pybind_api/export_flags.cc
         ${CCSRC_DIR}/utils/context/context_extends.cc
         ${CCSRC_DIR}/frontend/parallel/costmodel_context.cc
+        ${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
+        ${CCSRC_DIR}/backend/optimizer/common/visit.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/graph_utils_extends.cc
         )
@@ -75,9 +77,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/node_pass.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/optimizer.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/pass_manager.cc
-        ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/pattern_engine.cc
-        ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/visit.cc
-        ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/utils.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/gllo_utils.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_biasadd_fusion.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_activation_fusion.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_transform_fusion.cc
diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
index c816e567987..87ab8a9910f 100644
--- a/mindspore/lite/tools/converter/converter.cc
+++ b/mindspore/lite/tools/converter/converter.cc
@@ -90,7 +90,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
     return nullptr;
   }
 
-//  auto newGraph = anfTransform->Transform(graph);
+  graph = anfTransform->Transform(graph);
 
   CreateQuantizer(graph, flag);
   if (mQuantizer != nullptr) {
diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc
index a4e445737fc..0ff3814ed37 100644
--- a/mindspore/lite/tools/converter/graphdef_transform.cc
+++ b/mindspore/lite/tools/converter/graphdef_transform.cc
@@ -100,20 +100,20 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
   // }
 
   // fusion
-  {
-    Optimizer fusionOptimizer;
-    fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass());
-    fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass());
-    fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass());
-    fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass());
-    fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass());
-    fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
-    status = fusionOptimizer.Run(graphDefT);
-    if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
-      return status;
-    }
-  }
+  // {
+  //   Optimizer fusionOptimizer;
+  //   fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass());
+  //   fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass());
+  //   fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass());
+  //   fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass());
+  //   fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass());
+  //   fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
+  //   status = fusionOptimizer.Run(graphDefT);
+  //   if (status != RET_OK && status != RET_NO_CHANGE) {
+  //     MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
+  //     return status;
+  //   }
+  // }
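+  // The same fusions (conv + biasadd/bn/scale/activation) are now applied as
+  // anf passes in Converter::Convert via AnfTransform::Transform, so this
+  // GraphDef-level block is disabled here.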
 
   // weight format trans
   if (ctx.formatTrans) {