From e7c6b7e66add6ae1f7bf2373ebbe826f4fd1b599 Mon Sep 17 00:00:00 2001 From: BowenK Date: Wed, 19 Aug 2020 15:00:41 +0800 Subject: [PATCH] Add NewParameter and Imm patterns --- .../ccsrc/frontend/optimizer/pass_group.cc | 9 +- .../ccsrc/frontend/optimizer/pass_group.h | 6 +- mindspore/ccsrc/frontend/optimizer/pattern.cc | 4 + mindspore/ccsrc/frontend/optimizer/pattern.h | 77 ++++- mindspore/ccsrc/frontend/optimizer/py_pass.cc | 279 +++++++++++++----- mindspore/ccsrc/frontend/optimizer/py_pass.h | 12 +- .../frontend/optimizer/py_pass_manager.cc | 34 ++- .../frontend/optimizer/py_pass_manager.h | 13 +- mindspore/ccsrc/pipeline/jit/action.cc | 14 + mindspore/common/python_pass_register.py | 81 ----- mindspore/graph_utils/__init__.py | 15 + .../{common => graph_utils}/graph_pattern.py | 104 +++++-- mindspore/graph_utils/python_pass/__init__.py | 24 ++ .../python_pass/python_pass_register.py | 170 +++++++++++ mindspore/ops/primitive.py | 2 +- tests/ut/python/optimizer/test_python_pass.py | 162 +++++++++- 16 files changed, 796 insertions(+), 210 deletions(-) delete mode 100644 mindspore/common/python_pass_register.py create mode 100644 mindspore/graph_utils/__init__.py rename mindspore/{common => graph_utils}/graph_pattern.py (53%) create mode 100644 mindspore/graph_utils/python_pass/__init__.py create mode 100644 mindspore/graph_utils/python_pass/python_pass_register.py diff --git a/mindspore/ccsrc/frontend/optimizer/pass_group.cc b/mindspore/ccsrc/frontend/optimizer/pass_group.cc index 7f7544c41e..56ff81475e 100644 --- a/mindspore/ccsrc/frontend/optimizer/pass_group.cc +++ b/mindspore/ccsrc/frontend/optimizer/pass_group.cc @@ -14,6 +14,7 @@ * limitations under the License. */ #include "frontend/optimizer/pass_group.h" +#include "frontend/optimizer/py_pass_manager.h" namespace mindspore { namespace opt { @@ -35,14 +36,15 @@ bool PassGroup::DeletePass(const std::string &pass_name) { return false; } -bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector &passes) const { +bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector &passes, + const MatchResultPtr &res) const { if (func_graph == nullptr) { return false; } bool changed = false; for (const auto &pass : passes) { if (pass != nullptr) { - if (pass->Run(func_graph)) { + if (pass->Run(func_graph, res)) { changed = true; } } @@ -54,8 +56,9 @@ bool PassGroup::Run(const FuncGraphPtr &func_graph) const { bool changed = false; // run all passes bool change = true; + auto res = PyPassManager::GetInstance()->GetMatchResult(); while (change) { - change = Run(func_graph, passes_); + change = Run(func_graph, passes_, res); changed = change || changed; if (run_only_once_) { break; diff --git a/mindspore/ccsrc/frontend/optimizer/pass_group.h b/mindspore/ccsrc/frontend/optimizer/pass_group.h index 22b17b81b1..20cf66649f 100644 --- a/mindspore/ccsrc/frontend/optimizer/pass_group.h +++ b/mindspore/ccsrc/frontend/optimizer/pass_group.h @@ -41,12 +41,14 @@ class PassGroup { // @return false, graph not changed bool Run(const FuncGraphPtr &func_graph) const; // Run the given graph passes on the input graph - // @param [inout] graph The graph to be optimized + // @param [inout] func_graph The graph to be optimized // @param [in] passes The given graph passes + // @param [inout] res MatchResult used to collect all matched patterns and nodes // @return true, graph changed // @return false, graph not changed - bool Run(const FuncGraphPtr &func_graph, const std::vector &passes) const; + bool Run(const FuncGraphPtr &func_graph, const std::vector &passes, const MatchResultPtr &res) const; std::string name() const { return name_; } + void SetRunOnlyOnce(bool run_only_once) { run_only_once_ = run_only_once; } private: const std::string name_; diff --git a/mindspore/ccsrc/frontend/optimizer/pattern.cc b/mindspore/ccsrc/frontend/optimizer/pattern.cc index 9011cde7c8..9b323640ab 100644 --- a/mindspore/ccsrc/frontend/optimizer/pattern.cc +++ b/mindspore/ccsrc/frontend/optimizer/pattern.cc @@ -96,6 +96,7 @@ MatchResultPtr IsIn::match(const AnfNodePtr &node) { for (auto &iter : patterns_) { auto res = iter->match(node); if (res != nullptr) { + res->add_entry(shared_from_base(), node); return res; } } @@ -151,6 +152,9 @@ REGISTER_PYBIND_DEFINE( (void)py::class_, Pattern>(*m, "AnyPattern").def(py::init<>()); (void)py::class_, Pattern>(*m, "NewTensor_") .def(py::init()); + (void)py::class_, Pattern>(*m, "NewParameter_") + .def(py::init()); + (void)py::class_, Pattern>(*m, "Imm").def(py::init()); })); } // namespace python_pass } // namespace opt diff --git a/mindspore/ccsrc/frontend/optimizer/pattern.h b/mindspore/ccsrc/frontend/optimizer/pattern.h index 4dd94bfeba..fb15a626e5 100644 --- a/mindspore/ccsrc/frontend/optimizer/pattern.h +++ b/mindspore/ccsrc/frontend/optimizer/pattern.h @@ -42,6 +42,10 @@ class CallWith; using CallWithPtr = std::shared_ptr; class NewTensor; using NewTensorPtr = std::shared_ptr; +class NewParameter; +using NewParameterPtr = std::shared_ptr; +class Imm; +using ImmPtr = std::shared_ptr; struct PatternHasher; struct PatternEqual; using PatternNodeMap = std::unordered_map; @@ -55,6 +59,7 @@ class Pattern : public Base { string unique_name() const { return unique_name_; } vector inputs() { return inputs_; } bool should_replace() { return should_replace_; } + void set_should_replace(bool should_replace) { should_replace_ = should_replace; } virtual void reset() {} protected: @@ -86,14 +91,14 @@ class IsPrimTypeOf : public Pattern { ~IsPrimTypeOf() = default; IsPrimTypeOf(vector prims, string name, bool should_replace) : primitives_(prims), name_(name), matched_prim_(nullptr) { - unique_name_ = std::to_string(g_id_++) + "_" + name; + unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name; should_replace_ = should_replace; if (!should_replace) { matched_prim_ = prims[0]; } } IsPrimTypeOf(vector types, string name, bool should_replace) : types_(types), name_(name) { - unique_name_ = std::to_string(g_id_++) + "_" + name; + unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name; // Make primitives_ for (auto &iter : types) { primitives_.push_back(std::make_shared(iter, py::cast(nullptr))); @@ -126,19 +131,20 @@ class CallWith : public Pattern { CallWith(PatternPtr prim_pattern, vector inputs, bool should_replace) { // NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting prim_pattern_ = prim_pattern; - unique_name_ = std::to_string(g_id_++) + prim_pattern->unique_name(); + unique_name_ = std::to_string(g_id_++) + "CallWithPattern_" + prim_pattern->unique_name(); inputs_ = inputs; - should_replace_ = should_replace; + // NOTE: should_replace_ is overrided by it prim_pattern(if exists) silently. + should_replace_ = prim_pattern->should_replace(); } CallWith(PrimitivePyPtr prim, vector inputs, bool should_replace) { prim_ = prim; - unique_name_ = std::to_string(g_id_++) + prim_->ToString(); + unique_name_ = std::to_string(g_id_++) + "CallWithPrim_" + prim_->ToString(); inputs_ = inputs; should_replace_ = should_replace; } CallWith(string prim_str, vector inputs, bool should_replace) { prim_ = std::make_shared(prim_str, py::cast(nullptr)); - unique_name_ = std::to_string(g_id_++) + prim_->ToString(); + unique_name_ = std::to_string(g_id_++) + "CallWithStr_" + prim_->ToString(); inputs_ = inputs; should_replace_ = should_replace; } @@ -159,7 +165,7 @@ class IsIn : public Pattern { IsIn() { unique_name_ = std::to_string(g_id_++); } ~IsIn() = default; explicit IsIn(vector patterns) : patterns_(patterns) { - unique_name_ = std::to_string(g_id_++); + unique_name_ = std::to_string(g_id_++) + "IsIn"; for (auto &iter : patterns) { unique_name_ = unique_name_ + "_" + iter->unique_name(); } @@ -176,9 +182,9 @@ class IsNot : public Pattern { IsNot() { unique_name_ = std::to_string(g_id_++); } ~IsNot() = default; explicit IsNot(vector patterns) : patterns_(patterns) { - unique_name_ = std::to_string(g_id_++); + unique_name_ = std::to_string(g_id_++) + "IsNot"; for (auto &iter : patterns) { - unique_name_ = "IsNot_" + unique_name_ + "_" + iter->unique_name(); + unique_name_ = unique_name_ + "_" + iter->unique_name(); } } MS_DECLARE_PARENT(IsNot, Pattern); @@ -200,7 +206,10 @@ class NewTensor : public Pattern { public: NewTensor() { unique_name_ = std::to_string(g_id_++); } ~NewTensor() = default; - explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; } + explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { + should_replace_ = false; + unique_name_ = std::to_string(g_id_++) + "NewTensor"; + } MS_DECLARE_PARENT(NewTensor, Pattern); MatchResultPtr match(const AnfNodePtr &node) override { MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n"; @@ -211,6 +220,54 @@ class NewTensor : public Pattern { tensor::TensorPtr input_tensor_; }; +class NewParameter : public Pattern { + public: + NewParameter() { unique_name_ = std::to_string(g_id_++); } + explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel, + bool should_replace) + : para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) { + should_replace_ = should_replace; + unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name; + // clone input tensor + default_tensor_ = std::make_shared(*default_tensor.get()); + built_ = false; + } + MS_DECLARE_PARENT(NewParameter, Pattern); + MatchResultPtr match(const AnfNodePtr &node) override { + MS_LOG(EXCEPTION) << "Find NewParameter in pattern, NewParameter should only appear in the target.\n"; + } + string para_name() { return para_name_; } + tensor::TensorPtr default_tensor() { return default_tensor_; } + bool requires_grad() { return requires_grad_; } + bool layerwise_parallel() { return layerwise_parallel_; } + bool built() { return built_; } + void set_built(bool built) { built_ = built; } + void reset() override { built_ = false; } + + private: + string para_name_; + bool requires_grad_; + bool layerwise_parallel_; + bool built_; + tensor::TensorPtr default_tensor_; +}; + +class Imm : public Pattern { + public: + Imm() { unique_name_ = std::to_string(g_id_++); } + explicit Imm(int value) : value_(value) { + should_replace_ = false; + unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value); + } + MS_DECLARE_PARENT(Imm, Pattern); + // NOTE: Doesn't support Imm in src pattern currently. + MatchResultPtr match(const AnfNodePtr &node) override { return nullptr; } + int value() { return value_; } + + private: + int value_; +}; + class MatchResult { public: MatchResult() {} diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.cc b/mindspore/ccsrc/frontend/optimizer/py_pass.cc index dc51842bc7..af27a9069f 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.cc @@ -21,13 +21,26 @@ #include "ir/func_graph.h" #include "ir/manager.h" #include "pybind_api/ir/primitive_py.h" +#include "ir/scalar.h" +#include "ir/graph_utils.h" +#include "pipeline/jit/parse/parse_base.h" #include "pipeline/jit/resource.h" +#include "frontend/optimizer/py_pass_manager.h" +#include "utils/info.h" +#include "debug/anf_ir_dump.h" +#include "debug/draw.h" namespace mindspore { namespace opt { namespace python_pass { namespace internal { -AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res); +const char PARAMETER_MODULE[] = "mindspore.common.parameter"; +const char PARAMETER_CLASS[] = "Parameter"; +const char SET_PARAM[] = "__setattr__"; +AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph); +AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res); +void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, + bool requires_grad, bool layerwise_parallel); std::string GetNodeRepr(AnfNodePtr node) { if (node != nullptr) { @@ -42,8 +55,10 @@ std::string GetNodeRepr(AnfNodePtr node) { repr += ")"; return repr; } - if (node->isa()) { - return GetValueNode(node)->ToString(); + if (node->isa()) { + return "[Parameter]" + node->ToString(); + } else if (node->isa()) { + return "[Value]" + GetValueNode(node)->ToString(); } return node->ToString(); } @@ -82,7 +97,7 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res) return std::make_shared(input_tensor); } -AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res) { +AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) { auto call_with_pattern = pattern->cast(); MS_EXCEPTION_IF_NULL(call_with_pattern); auto prim = call_with_pattern->prim_value(); @@ -91,15 +106,70 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP } auto prim_pattern = call_with_pattern->prim_pattern(); MS_EXCEPTION_IF_NULL(prim_pattern); - return ProcessSinglePattern(prim_pattern, res); + return ProcessSinglePattern(prim_pattern, res, fg); } -AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res) { +AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) { + auto new_para_pattern = pattern->cast(); + MS_EXCEPTION_IF_NULL(new_para_pattern); + if (!new_para_pattern->built()) { + static int parameter_id = 0; + auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name() + std::to_string(parameter_id++); + auto para_node = std::make_shared(func_graph); + MS_EXCEPTION_IF_NULL(para_node); + para_node->set_name(para_name); + // Set function graph + para_node->set_func_graph(func_graph); + // Set Debug Info + auto debug_info = std::make_shared(para_name); + para_node->set_debug_info(debug_info); + // Set abstract + auto default_value = new_para_pattern->default_tensor(); + MS_EXCEPTION_IF_NULL(default_value); + para_node->set_abstract(default_value->ToAbstract()->Broaden()); + res->add_entry(pattern, para_node); + func_graph->add_parameter(para_node); + // Reflect back to Cell._params + internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(), + new_para_pattern->layerwise_parallel()); + MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name(); + new_para_pattern->set_built(true); + return para_node; + } else { + // Built, fetch the node + auto para_node = res->get_node(pattern); + MS_EXCEPTION_IF_NULL(para_node); + return para_node; + } +} + +AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) { + auto imm_pattern = pattern->cast(); + MS_EXCEPTION_IF_NULL(imm_pattern); + auto value = imm_pattern->value(); + auto scalar_value_ptr = std::make_shared(value); + return std::make_shared(scalar_value_ptr); +} + +AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) { if (pattern->should_replace()) { // Find replacement in the MatchResult auto target_node = res->get_node(pattern); if (target_node == nullptr) { - MS_LOG(EXCEPTION) << "Cannot find target node in pattern match result, pattern: " + pattern->unique_name() + "\n"; + // If it's base pattern(in contrast to complex pattern like CallWith/IsIn/IsNot), raise runtime exception. + if (pattern->isa() || pattern->isa() || pattern->isa()) { + MS_LOG(EXCEPTION) << "Cannot find target node, pattern: " + pattern->unique_name() + "\n"; + return nullptr; + } + // Try to build this pattern and add to MatchResult, since this pattern is defined inside target + auto new_node = BuildTarget(pattern, func_graph, res); + if (new_node == nullptr) { + MS_LOG(EXCEPTION) << "Try to build pattern node but FAILED. pattern: " + pattern->unique_name() + "\n"; + } + return new_node; + } + if (pattern->isa()) { + return target_node; } return target_node; } @@ -109,7 +179,19 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr } else if (pattern->isa()) { return BuildNewTensor(pattern, res); } else if (pattern->isa()) { - return BuildPrimitiveValueNode(pattern, res); + return BuildPrimitiveValueNode(pattern, res, func_graph); + } else if (pattern->isa()) { + return BuildNewParameter(pattern, res, func_graph); + } else if (pattern->isa()) { + return BuildImmNode(pattern, res); + } + return nullptr; +} + +AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res, + const FuncGraphPtr &func_graph) { + if (pattern->isa()) { + return BuildPrimitiveValueNode(pattern, res, func_graph); } return nullptr; } @@ -117,91 +199,154 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) { auto target_inputs = pattern->inputs(); if (target_inputs.size() == 0) { - return ProcessSinglePattern(pattern, res); + auto new_node = ProcessSinglePattern(pattern, res, func_graph); + if (new_node != nullptr) { + res->add_entry(pattern, new_node); + } + return new_node; } // Build up the AnfNode in a recursive manner std::vector new_inputs; - auto prim_value_node = ProcessSinglePattern(pattern, res); + auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph); MS_EXCEPTION_IF_NULL(prim_value_node); new_inputs.push_back(prim_value_node); for (auto &iter : target_inputs) { if (iter == pattern) { - MS_LOG(EXCEPTION) << "Circle references: Pattern takes itself as input. Got pattern: " + pattern->unique_name() + - "\n"; + MS_LOG(EXCEPTION) << "Circle references. Got pattern: " + pattern->unique_name() + "\n"; } - new_inputs.push_back(BuildTarget(iter, func_graph, res)); + auto input_node = BuildTarget(iter, func_graph, res); + if (input_node == nullptr) { + MS_LOG(EXCEPTION) << "Failed to build input node for pattern : " + iter->unique_name() + "\n"; + } + new_inputs.push_back(input_node); } - return func_graph->NewCNode(new_inputs); + auto new_node = func_graph->NewCNode(new_inputs); + res->add_entry(pattern, new_node); + return new_node; } + +void DrawNode(string name, AnfNodePtr node) { + auto context_ptr = MsContext::GetInstance(); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + auto new_func_graph = std::make_shared(); + new_func_graph->set_output(node, true); + if (save_graphs) { + auto ir_dump_path = save_graphs_path + "/" + name + ".ir"; + auto dot_dump_path = save_graphs_path + "/" + name + ".dot"; + DumpIR(ir_dump_path, new_func_graph); + draw::Draw(dot_dump_path, new_func_graph); + } +} + +void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, + bool requires_grad, bool layerwise_parallel) { + // 1. Get current cell object + auto ppm = opt::python_pass::PyPassManager::GetInstance(); + auto resource = ppm->GetResource(); + py::object top_cell = resource->input(); + if (py::isinstance(top_cell)) { + MS_LOG(EXCEPTION) << "Failed to get top cell from resource."; + } + // 2. New a Parameter object with the above-specified args + py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS); + py::object new_parameter = parameter_class(default_input, param_name, requires_grad, layerwise_parallel); + // 3. Add the new python Parameter object to Cell's _params atttributes + top_cell.attr(SET_PARAM)(param_name, new_parameter); + // 4. Set default_param for param_node + ValuePtr param_value = nullptr; + bool converted = parse::ConvertData(new_parameter, ¶m_value, false); + if (!converted) { + MS_LOG(EXCEPTION) << "Failed to convert new parameter to ValuePtr."; + } + MS_EXCEPTION_IF_NULL(param); + auto param_node = param->cast(); + MS_EXCEPTION_IF_NULL(param_node); + param_node->set_default_param(param_value); +} + +void Reset(PatternPtr pattern) { + if (pattern->isa()) { + auto prim_pattern = pattern->cast(); + prim_pattern->reset(); + return; + } else if (pattern->isa()) { + auto new_param_pattern = pattern->cast(); + new_param_pattern->reset(); + return; + } else if (pattern->isa()) { + auto call_with_pattern = pattern->cast(); + for (auto sub_pattern : call_with_pattern->inputs()) { + Reset(sub_pattern); + } + return; + } + return; +} + } // namespace internal -AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(src_pattern_); - MS_EXCEPTION_IF_NULL(dst_pattern_); - auto res = src_pattern_->match(node); - if (res != nullptr) { - res->dump(); - MS_LOG(WARNING) << "Matched pattern: " + src_pattern_->unique_name(); +AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) { + auto match_res = src_pattern_->match(node); + if (match_res != nullptr) { + MS_LOG(DEBUG) << "Matched pattern: " + src_pattern_->unique_name() + " node : " + internal::GetNodeRepr(node); + res->merge(match_res); auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res); - dst_pattern_->reset(); - MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; + internal::Reset(dst_pattern()); + MS_LOG(WARNING) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; return new_node; } - src_pattern_->reset(); + internal::Reset(src_pattern()); return nullptr; } -bool PythonPass::Run(const FuncGraphPtr &func_graph) { +bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) { MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(dst_pattern_); + if (src_pattern_ == nullptr) { + // Add NewParameter + auto new_para_pattern = dst_pattern_->cast(); + if (new_para_pattern == nullptr) { + MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null."; + } + auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name(); + MS_LOG(DEBUG) << "Adding New parameter : " + para_name; + auto para_node = std::make_shared(func_graph); + MS_EXCEPTION_IF_NULL(para_node); + para_node->set_name(para_name); + // Set function graph + para_node->set_func_graph(func_graph); + // Set Debug Info + auto debug_info = std::make_shared(para_name); + para_node->set_debug_info(debug_info); + // Set abstract + auto default_value = new_para_pattern->default_tensor(); + MS_EXCEPTION_IF_NULL(default_value); + para_node->set_abstract(default_value->ToAbstract()->Broaden()); + res->add_entry(dst_pattern_, para_node); + func_graph->add_parameter(para_node); + // Reflect back to Cell._params + internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(), + new_para_pattern->layerwise_parallel()); + MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name(); + return true; + } FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); manager->AddFuncGraph(func_graph); - auto seen = NewSeenGeneration(); - // 1024 is for the initial capacity of deque - std::deque todo(1024); - todo.push_back(func_graph->output()); + auto graph_nodes_sorted = TopoSort(func_graph->output()); bool changes = false; - auto &all_nodes = manager->all_nodes(); - while (!todo.empty()) { - AnfNodePtr node = todo.front(); - todo.pop_front(); - // Check whether this node has been matched. - if (node == nullptr || node->seen_ == seen || !internal::IsTraversable(node) || !all_nodes.contains(node)) { - continue; - } - node->seen_ = seen; - // Select nodes that this transform can be applied. - AnfNodePtr new_node = Run(func_graph, node); - bool change = (new_node != nullptr); + // Traverse once + for (auto &node : graph_nodes_sorted) { + AnfNodePtr new_node = Run(func_graph, node, res); if (new_node != nullptr && new_node != node) { + internal::DrawNode(dst_pattern_->unique_name(), new_node); (void)manager->Replace(node, new_node); - } else if (new_node == nullptr) { - new_node = node; - } - if (run_only_once_) { - return change; - } - // Find success, and add them to todo list - if (IsValueNode(node)) { - todo.push_back(GetValueNode(node)->output()); - } - if (node->isa()) { - auto &inputs = node->cast()->inputs(); - (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); - } - auto &node_users = manager->node_users(); - if (change && node_users.find(node) != node_users.end()) { - for (auto &use : node_users[node]) { - auto use_node = use.first; - if (use_node == nullptr) { - continue; - } - todo.push_back(use_node); - if (use_node->seen_ == seen) { - use_node->seen_--; - } - } + changes = true; } } return changes; diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.h b/mindspore/ccsrc/frontend/optimizer/py_pass.h index 022c16a686..6e693c0e40 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass.h +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.h @@ -34,20 +34,20 @@ using NodeEquivPtr = std::shared_ptr; class PythonPass { public: - explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false, - bool multigraph = true) - : src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {} + explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false) + : src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once) {} ~PythonPass() = default; - bool Run(const FuncGraphPtr &func_graph); + bool Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res); std::string name() const { return name_; } - AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node); + AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res); + PatternPtr src_pattern() { return src_pattern_; } + PatternPtr dst_pattern() { return dst_pattern_; } private: PatternPtr src_pattern_; PatternPtr dst_pattern_; const std::string name_; bool run_only_once_; - bool multigraph_ = true; }; using PythonPassPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc index ac80136f7e..cb5c1dba93 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc @@ -45,14 +45,19 @@ PyPassManagerPtr PyPassManager::GetInstance() { PyPassManager::PyPassManager() { phase_to_group_[Phase::RESOLVE] = std::make_shared(); phase_to_group_[Phase::OPT] = std::make_shared(); + res_ = std::make_shared(); } void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, - Phase phase, bool run_only_once, bool multigraph) { - auto cur_pm = GetPassGroup(phase); - MS_EXCEPTION_IF_NULL(cur_pm); - PythonPassPtr new_pass = std::make_shared(pass_name, pattern, target, run_only_once, multigraph); - cur_pm->AddPass(new_pass); + Phase phase, bool run_only_once) { + auto cur_pg = GetPassGroup(phase); + MS_EXCEPTION_IF_NULL(cur_pg); + cur_pg->SetRunOnlyOnce(run_only_once); + MS_EXCEPTION_IF_NULL(pattern); + MS_EXCEPTION_IF_NULL(target); + MS_EXCEPTION_IF_NULL(cur_pg); + PythonPassPtr new_pass = std::make_shared(pass_name, pattern, target, run_only_once); + cur_pg->AddPass(new_pass); } void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) { @@ -63,6 +68,21 @@ void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) { } } +void PyPassManager::GenNewParameter(const PatternPtr ¶meter) { + MS_EXCEPTION_IF_NULL(parameter); + // Add new parameter after resolve + // NOTE: Add NewParameter at early stage will cause CSE problems + auto cur_pg = GetPassGroup(Phase::OPT); + MS_EXCEPTION_IF_NULL(cur_pg); + cur_pg->SetRunOnlyOnce(true); + auto new_para_pattern = parameter->cast(); + MS_EXCEPTION_IF_NULL(new_para_pattern); + auto pass_name = new_para_pattern->para_name(); + parameter->set_should_replace(false); + auto new_pass = std::make_shared(pass_name, nullptr, parameter, true); + cur_pg->AddPass(new_pass); +} + void PyPassManager::ClearRes() { MS_LOG(INFO) << "Clear PyPassManager resources!"; global_instance = nullptr; @@ -75,7 +95,9 @@ REGISTER_PYBIND_DEFINE( (void)py::class_>(*m, "PyPassManager_") .def(py::init([]() { return PyPassManager::GetInstance(); })) .def("registe", &PyPassManager::Registe, "Registe python pass") - .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass"); + .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass") + .def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter") + .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph"); })); } // namespace python_pass } // namespace opt diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h index 8c78cee163..38159437f3 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h @@ -27,7 +27,7 @@ #include "ir/graph_utils.h" #include "utils/ms_utils.h" -#include "pipeline/jit/parse/resolve.h" +#include "pipeline/jit/resource.h" #include "frontend/optimizer/pattern.h" #include "frontend/optimizer/py_pass.h" #include "frontend/optimizer/pass_group.h" @@ -53,12 +53,21 @@ class PyPassManager { static PyPassManagerPtr GetInstance(); virtual ~PyPassManager() = default; void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, - Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true); + Phase phase = Phase::RESOLVE, bool run_only_once = false); void Unregiste(const std::string &pass_name, Phase phase); + void GenNewParameter(const PatternPtr ¶meter); PassGroupPtr GetPassGroup(Phase phase); void ClearRes(); + MatchResultPtr GetMatchResult() { return res_; } + void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; } + bool ShouldRenorm() { return should_renorm_; } + void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; } + pipeline::ResourcePtr GetResource() { return resource_; } private: + bool should_renorm_ = true; + MatchResultPtr res_; + pipeline::ResourcePtr resource_; static std::unordered_map phase_to_group_; }; } // namespace python_pass diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 1c43170f0f..15222a284c 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -448,8 +448,21 @@ void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { MS_EXCEPTION_IF_NULL(res->manager()); MS_EXCEPTION_IF_NULL(res->func_graph()); auto ppm = opt::python_pass::PyPassManager::GetInstance(); + ppm->SetResource(res); if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) { MS_LOG(DEBUG) << "No match.\n"; + } else if (phase == opt::python_pass::Phase::OPT && opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) { + MS_LOG(DEBUG) << "Entered PyStub Renorm"; + // Renomalize + MS_EXCEPTION_IF_NULL(res->func_graph()); + FuncGraphPtr func_graph = res->func_graph(); + abstract::AbstractBasePtrList args_spec; + auto parameters = func_graph->parameters(); + (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), + [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); + FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); + res->set_func_graph(new_fg); + res->set_args_spec(args_spec); } } @@ -477,6 +490,7 @@ static std::vector CommonPipeline() { } // Add resolve-stage python pass stub actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub)); + actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); // Evaluate type and shape, and specialize actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); diff --git a/mindspore/common/python_pass_register.py b/mindspore/common/python_pass_register.py deleted file mode 100644 index ee4f0f0bc8..0000000000 --- a/mindspore/common/python_pass_register.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Python pass register""" -from inspect import isfunction -from mindspore.common.graph_pattern import Pattern -from mindspore._c_expression import PyPassManager_ -from mindspore._c_expression import phase - -class PyPassManager(PyPassManager_): - r""" - Used to registe and unregiste python passes which can be used to alter graphs. - - Args: - pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt. - run_only_once (bool): Specify whether or not to run pass only once. Default: False. - multigraph (bool): Whether or not the pattern exists across graphs. Default: True. - - Raises: - TypeError: If argument has invalid type. - """ - def __init__(self, pipeline_phase=phase.opt, run_only_once=False, multi_graph=True): - if not isinstance(pipeline_phase, phase): - raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}") - if not isinstance(run_only_once, bool): - raise TypeError(f"Expecting bool, got : ({type(run_only_once)}){run_only_once}") - if not isinstance(multi_graph, bool): - raise TypeError(f"Expecting bool, got : ({type(multi_graph)}){multi_graph}") - PyPassManager_.__init__(self) - self.phase_ = pipeline_phase - self.run_only_once_ = run_only_once - self.multi_graph_ = multi_graph - - def registe(self, py_pass): - if not isfunction(py_pass): - raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}") - pattern, target = py_pass() - pass_name = py_pass.__name__ - if not isinstance(pattern, Pattern): - raise TypeError(f"Expecting pattern of Pattern type, got : ({type(pattern)}){pattern}") - if not isinstance(target, Pattern): - raise TypeError(f"Expecting target of Pattern type, got : ({type(target)}){target}") - super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_) - - def unregiste(self, py_pass, pipeline_phase=phase.opt): - if not isinstance(pipeline_phase, phase): - raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}") - if isinstance(py_pass, str): - super().unregiste(py_pass, pipeline_phase) - return - if isfunction(py_pass): - super().unregiste(py_pass.__name__, pipeline_phase) - return - raise TypeError(f"Expecting py_pass to be string or function, got ({type(py_pass)}){py_pass}") - - def __call__(self, py_pass): - self.registe(py_pass) - return py_pass - -def registe_pass(pipeline_phase=phase.opt, run_only_once=False, multi_graph=True): - """ - Examples: - >>> @registe_pass() - >>> def toy_pass(): - >>> def pattern(): - >>> pass - >>> def target(): - >>> pass - """ - return PyPassManager(pipeline_phase, run_only_once, multi_graph) diff --git a/mindspore/graph_utils/__init__.py b/mindspore/graph_utils/__init__.py new file mode 100644 index 0000000000..e38ac9e3d0 --- /dev/null +++ b/mindspore/graph_utils/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Top-level reference to python pass.""" diff --git a/mindspore/common/graph_pattern.py b/mindspore/graph_utils/graph_pattern.py similarity index 53% rename from mindspore/common/graph_pattern.py rename to mindspore/graph_utils/graph_pattern.py index 487db572f6..08ece93534 100644 --- a/mindspore/common/graph_pattern.py +++ b/mindspore/graph_utils/graph_pattern.py @@ -15,7 +15,8 @@ """Patterns for describing graphs""" from mindspore.ops import Primitive from mindspore.common.tensor import Tensor -from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_ +from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_,\ + NewParameter_, Imm __all__ = [ "IsIn", @@ -24,17 +25,25 @@ __all__ = [ "IsNot", "AnyPattern", "NewTensor", + "NewParameter", + "Imm" ] class IsIn(IsIn_): - """ + r""" Express a pattern which allows a list of patterns. """ def __init__(self, patterns=None, should_replace=True): r""" Args: - patterns(list/tuple): list of allowed patterns + patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`], + list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns, + each element should be one of the exposed Pattern instance. should_replace(bool): added this for interface consistency. Should only set this in sub-patterns. + + Raises: + ValueError: raise if should_replace is False + TypeError: raise type error for invalid inputs. """ if not should_replace: raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \ @@ -52,19 +61,28 @@ class IsIn(IsIn_): class IsPrimTypeOf(IsPrimTypeOf_): r""" Express a pattern of certain primitive type(s). - NOTE: This pattern will match and only match the primitive value node. If matching primitive CNode is needed, - please refer to CallWith pattern. + + NOTE: + This pattern will match and only match the primitive value node. If matching primitive CNode is needed, + please refer to CallWith pattern. """ def __init__(self, types, name=None, should_replace=True): r""" Args: - types (str/(list/tuple of Primitives)): Specify allowed types. + types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`], + tuple[:class:`mindspore.ops.Primitive`]): + Specify allowed types. If it is a string, the form could be 1) a single primitive type, e.g. 'Conv2D' 2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D' - It can also be a list of Primitives, e.g. [ops.Conv2D(1, 6)] - name (str): name of the pattern, optional - should_replace + It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)] + name (str): name of the pattern, optional. Default: None. + should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is + used when building the replacement target node. Use captured node if True, build from scratch otherwise. + Default: True. + + Raises: + TypeError: raise type error for invalid argument. """ if name is not None and not isinstance(name, str): raise TypeError(f"Expect string, got : {name}") @@ -91,12 +109,21 @@ class CallWith(CallWith_): r""" Express a primitive CNode. """ - def __init__(self, prim_pattern, inputs=None, should_replace=False): + def __init__(self, prim_pattern, inputs=None, should_replace=True): r""" Args: - prim_pattern (Pattern/Primitive/str): Primitive ValueNode in the Primitive CNode. - inputs (list/tuple): Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; - if specified, input patterns should be of right order. + prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`, + :class:`mindspore.ops.Primitive`]): Primitive ValueNode in the Primitive CNode. + inputs (Union[list[:class:`mindspore.graph_utils.graph_pattern`], + tuple[:class:`mindspore.graph_utils.graph_pattern`]]): + Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input + patterns should be of right order and each element should be one of the exposed Pattern instance. + should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is + used when building the replacement target node. Use captured node if True, build from scratch otherwise. + Default: True. + + Raises: + TypeError: raise type error for invalid argument. """ if not isinstance(prim_pattern, (Pattern, str, Primitive)): raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}") @@ -110,17 +137,23 @@ class CallWith(CallWith_): raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}") CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace) - class IsNot(IsNot_): r""" Express a pattern which forbids a list of patterns. - NOTE: IsNot pattern should not be the root pattern. + + NOTE: + IsNot pattern should not be the root pattern. """ def __init__(self, patterns=None, should_replace=True): r""" Args: - patterns(list/tuple): list of forbiden patterns + patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element + should be one of the exposed Pattern instance. should_replace(bool): added this for interface consistency. Should only set this in sub-patterns. + + Raises: + ValueError: raise if should_replace is False. + TypeError: raise type error for invalid argument. """ if not should_replace: raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \ @@ -142,13 +175,48 @@ class NewTensor(NewTensor_): def __init__(self, input_tensor, should_replace=False): r""" Args: - input_tensor(Tensor): new tensor to be used in the target + input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target should_replace(bool): added this for interface consistency. NewTensor should only appear in the target. + + Raises: + ValueError: raise if should_replace is True + TypeError: raise type error for invalid argument. """ if should_replace: - raise ValueError("NewTensor should only appear in the target, thus should_replace can onlyu be False.") + raise ValueError("NewTensor should only appear in the target, thus should_replace can only be False.") self.input_tensor = input_tensor if isinstance(input_tensor, Tensor): NewTensor_.__init__(self, input_tensor) else: raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}") + +class NewParameter(NewParameter_): + r""" + New Parameter to be used in the target. + """ + def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False, should_replace=False): + r""" + Args: + para_name(str): name for the new Parameter + default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter + requires_grad(bool): True if the parameter requires gradient. Default: True + layerwise_parallel(bool): switch for layerwise parallel mode. Default: False + should_replace(bool): gen new parameter once and replace after if set to be true; otherwise build a new + parameter everytime a pass target got built. Default: False + + Raises: + TypeError: raise type error for invalid argument. + """ + self.para_name = para_name + self.default_tensor = default_tensor + self.requires_grad = requires_grad + self.layerwise_parallel = layerwise_parallel + self.should_replace = should_replace + if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\ + isinstance(layerwise_parallel, bool) and isinstance(should_replace, bool): + NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad, + self.layerwise_parallel, self.should_replace) + else: + raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \ + layerwise_parallel(bool) should_replace(bool), got : {para_name}, {default_tensor}, \ + {requires_grad}, {layerwise_parallel}, {should_replace}") diff --git a/mindspore/graph_utils/python_pass/__init__.py b/mindspore/graph_utils/python_pass/__init__.py new file mode 100644 index 0000000000..26a8216295 --- /dev/null +++ b/mindspore/graph_utils/python_pass/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Top-level reference to python pass.""" +from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm + +__all__ = [ + "registe_pass", + "unregiste_pass", + "gen_new_parameter", + "cancel_new_parameter", + "set_renorm" + ] diff --git a/mindspore/graph_utils/python_pass/python_pass_register.py b/mindspore/graph_utils/python_pass/python_pass_register.py new file mode 100644 index 0000000000..55d70b6581 --- /dev/null +++ b/mindspore/graph_utils/python_pass/python_pass_register.py @@ -0,0 +1,170 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Python pass register""" +from inspect import isfunction +from mindspore.graph_utils.graph_pattern import Pattern, NewParameter +from mindspore._c_expression import PyPassManager_, phase + + +__all__ = [ + "registe_pass", + "unregiste_pass", + "gen_new_parameter", + "cancel_new_parameter", + "set_renorm" +] +class PyPassManager(PyPassManager_): + r""" + Used to registe and unregiste python passes which can be used to alter graphs. + + Args: + pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt. + run_only_once (bool): Specify whether or not to run pass only once. Default: False. + multigraph (bool): Whether or not the pattern exists across graphs. Default: True. + + Raises: + TypeError: If argument has invalid type. + """ + def __init__(self, pipeline_phase=phase.opt, run_only_once=False): + if not isinstance(pipeline_phase, phase): + raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}") + if not isinstance(run_only_once, bool): + raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}") + PyPassManager_.__init__(self) + self.phase_ = pipeline_phase + self.run_only_once_ = run_only_once + + def registe(self, py_pass): + if not isfunction(py_pass): + raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}") + pattern, target = py_pass() + pass_name = py_pass.__name__ + if not isinstance(pattern, Pattern): + raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}") + if not isinstance(target, Pattern): + raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}") + super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_) + + def unregiste(self, py_pass, pipeline_phase=phase.opt): + if not isinstance(pipeline_phase, phase): + raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}") + if isinstance(py_pass, str): + super().unregiste(py_pass, pipeline_phase) + return + if isfunction(py_pass): + super().unregiste(py_pass.__name__, pipeline_phase) + return + raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}") + + def __call__(self, py_pass): + self.registe(py_pass) + return py_pass + + def gen_new_parameter(self, pattern): + if not isinstance(pattern, NewParameter): + raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") + super().gen_new_parameter(pattern) + + def set_renorm(self, should_renorm): + if not isinstance(should_renorm, bool): + raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}") + super().set_renorm(should_renorm) + +def registe_pass(pipeline_phase=phase.opt, run_only_once=False): + """ + Registe python pass to specified pipeline phase which would be used in compilation. + + Args: + pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is + registed. Support phase.resolve and phase.opt. Default: phase.opt. + run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False. + + Returns: + This function should be used as a decorator, return the decoratorated pass function. + + Examples: + >>> from mindspore.graph_utils.graph_pattern import IsPrimTypeOf + >>> @registe_pass() + >>> def toy_pass(): + >>> pattern = IsPrimTypeOf("ReLU") + >>> target = IsPrimTypeOf("ReLU6") + >>> return pattern, target + """ + return PyPassManager(pipeline_phase, run_only_once) + +def unregiste_pass(py_pass, pipeline_phase=phase.opt): + """ + Unregiste python pass. + + Args: + py_pass(Union(str, function)): target python pass to unregiste. + pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is + unregisted. Support phase.resolve and phase.opt. Default: phase.opt. + """ + ppm = PyPassManager() + ppm.unregiste(py_pass, pipeline_phase) + +def gen_new_parameter(pattern): + """ + Generate specified parameter every time a network gets compiled. + + NOTE: + In this way, every pass uses this pattern would be using the same Parameter. If use NewParameter without + gen_new_parameter, every pass match would build a new Parameter. + This would registe a pass to add new parameter in the compilation pipeline, so later compilation would + ALSO add this parameter unless the pass is unregisted. To unregiste this pass, call + cancel_new_parameter(pattern) + + Args: + pattern (NewParameter): NewParameter type, could be used to build nested patterns across multiple passes + after gen_new_parameter. + + Raises: + TypeError: If argument has invalid type. + + Examples: + >>> from mindspore.graph_utils.graph_pattern import NewParameter + >>> abc = NewParameter("abc") + >>> gen_new_parameter(abc) + """ + ppm = PyPassManager() + ppm.gen_new_parameter(pattern) + +def cancel_new_parameter(pattern): + """ + Use with gen_new_parameter to unregiste gen_new_parameter pass. + + Args: + pattern (NewParameter): NewParameter type, cancel the pass which would add new parameter as this pattern + describes. + + Examples: + >>> from mindspore.graph_utils.graph_pattern import NewParameter + >>> abc = NewParameter("abc") + >>> gen_new_parameter(abs) + >>> # some compilations + >>> cancel_new_parameter(abc) + """ + if not isinstance(pattern, NewParameter): + raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") + ppm = PyPassManager() + ppm.unregiste(pattern.para_name) + +def set_renorm(should_renorm): + """ + Set whether or not to do renorm after modified graph in python pass(es). + """ + ppm = PyPassManager() + ppm.set_renorm(should_renorm) diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 21924bb5a3..b371ccb0df 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -152,7 +152,7 @@ class Primitive(Primitive_): Check if certain inputs should go to the backend. Subclass in need should override this method. Args: - *args(Primitive args): Same as arguments of current Primitive. + args(Primitive args): Same as arguments of current Primitive. Returns: A tuple consisting of two elements. The first element indicates whether we should filter out current diff --git a/tests/ut/python/optimizer/test_python_pass.py b/tests/ut/python/optimizer/test_python_pass.py index 8728120ff1..9d91b928f2 100644 --- a/tests/ut/python/optimizer/test_python_pass.py +++ b/tests/ut/python/optimizer/test_python_pass.py @@ -19,10 +19,12 @@ import mindspore.nn as nn from mindspore import context from mindspore.common.tensor import Tensor from mindspore.ops import operations as P -from mindspore.common.python_pass_register import registe_pass, PyPassManager +from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\ + cancel_new_parameter from mindspore.common.api import _generate_pip_args from mindspore._c_expression import generate_key, Executor_ -from mindspore.common.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor +from mindspore.graph_utils.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor,\ + NewParameter, Imm context.set_context(mode=context.GRAPH_MODE) @@ -56,12 +58,39 @@ def test_softmax_relu(): return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) - ppm = PyPassManager() - ppm.unregiste(softmax_relu_pass) + unregiste_pass(softmax_relu_pass) assert "ReLU" in transformed_repr assert "Softmax" not in transformed_repr -def test_isin_pattern(): +def test_softmax_relu_sigmoid(): + """ + Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)). + + NOTE: + Sigmoid pattern only exists in the target. + """ + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def softmax_relu_pass(): + x = AnyPattern() + softmax_pattern = IsPrimTypeOf(P.Softmax()) + pattern = CallWith(softmax_pattern, inputs=[x]) + sigmoid_pattern = IsPrimTypeOf(P.Sigmoid(), should_replace=False) + call_sigmoid = CallWith(sigmoid_pattern, [x]) + relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False) + target = CallWith(relu_pattern, inputs=[call_sigmoid]) + return pattern, target + + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3) + unregiste_pass(softmax_relu_pass) + assert "ReLU" in transformed_repr + assert "Sigmoid" in transformed_repr + assert "Softmax" not in transformed_repr + + +def test_isin_pattern_0(): """ Test IsIn pattern which expresses the IsIn/OneOf semantics. """ @@ -81,16 +110,41 @@ def test_isin_pattern(): target = CallWith(relu6_pattern, inputs=[x]) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) - ppm = PyPassManager() - ppm.unregiste(softmax_relu_pass) + unregiste_pass(softmax_relu_pass) assert "ReLU6" in transformed_repr assert "Softmax" not in transformed_repr +def test_isin_pattern_1(): + """ + Test IsIn. IsIn is used as nested inputs for the target in this case. + """ + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def softmax_neg_pass(): + x = AnyPattern() + softmax_pattern = IsPrimTypeOf(P.Softmax()) + call_softmax = CallWith(softmax_pattern, inputs=[x]) + relu_pattern = IsPrimTypeOf(P.ReLU()) + call_relu = CallWith(relu_pattern, inputs=[x]) + + pattern = IsIn([call_softmax, call_relu]) + neg_ops = IsPrimTypeOf(P.Neg(), should_replace=False) + target = CallWith(neg_ops, inputs=[pattern]) + return pattern, target + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4) + print(transformed_repr) + unregiste_pass(softmax_neg_pass) + assert "Neg" in transformed_repr + assert "Softmax" in transformed_repr + def test_isnot_pattern_0(): """ Test IsNot pattern which expresses the IsNot semantics. Case: IsNot pass failed to match """ + set_renorm(False) class ConvBN(nn.Cell): def __init__(self): super(ConvBN, self).__init__() @@ -132,11 +186,11 @@ def test_isnot_pattern_0(): return pattern, target transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5) - ppm = PyPassManager() - ppm.unregiste(single_bn_pass) - ppm.unregiste(bn_pass) + unregiste_pass(single_bn_pass) + unregiste_pass(bn_pass) assert "ReLU6" not in transformed_repr assert "Softmax" in transformed_repr + set_renorm(True) def test_isnot_pattern_1(): """ @@ -160,12 +214,15 @@ def test_isnot_pattern_1(): return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) - ppm = PyPassManager() - ppm.unregiste(single_bn_pass) + unregiste_pass(single_bn_pass) assert "ReLU6" in transformed_repr assert "Softmax" not in transformed_repr def test_newtensor_pattern(): + """ + Test NewTensor pattern in the target + """ + set_renorm(False) inputs = Tensor(np.ones([42]), mindspore.float16) softmax_model = nn.Softmax() @@ -181,7 +238,84 @@ def test_newtensor_pattern(): target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False) return pattern, target transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) - ppm = PyPassManager() - ppm.unregiste(softmax_addn_pass) + unregiste_pass(softmax_addn_pass) assert "AddN" in transformed_repr assert "Softmax" not in transformed_repr + set_renorm(True) + +def test_newparameter_pattern(): + """ + Test NewParameter pattern in the target + """ + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def softmax_addn_pass(): + x = AnyPattern() + softmax = P.Softmax() + pattern = CallWith(softmax, inputs=[x]) + + default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32) + default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32) + new_para_0 = NewParameter("Merlin", default_tensor0) + new_para_1 = NewParameter("Arthur", default_tensor1) + target_0 = CallWith(P.MatMul(), inputs=[new_para_0, new_para_1], should_replace=False) + target = CallWith("make_tuple", inputs=[target_0], should_replace=False) + return pattern, target + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) + print(transformed_repr) + unregiste_pass(softmax_addn_pass) + assert "MatMul" in transformed_repr + assert "make_tuple" in transformed_repr + assert "Softmax" not in transformed_repr + +def test_imm_pattern(): + """ + Test NewParameter pattern in the target + """ + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def softmax_addn_pass(): + x = AnyPattern() + softmax = P.Softmax() + pattern = CallWith(softmax, inputs=[x]) + imm = Imm(0) + target_0 = CallWith("make_tuple", inputs=[pattern], should_replace=False) + target = CallWith("tuple_getitem", inputs=[target_0, imm], should_replace=False) + return pattern, target + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) + print(transformed_repr) + unregiste_pass(softmax_addn_pass) + assert "make_tuple" in transformed_repr + assert "tuple_getitem" in transformed_repr + assert "Softmax" in transformed_repr + +def test_gen_new_parameter(): + """ + Test gen_new_parameter + """ + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + default_tensor = Tensor(np.ones((4, 4)), mindspore.float32) + new_para = NewParameter("Merlin", default_tensor, should_replace=True) + gen_new_parameter(new_para) + @registe_pass(run_only_once=True) + def softmax_make_tuple_pass(): + x = AnyPattern() + softmax = P.Softmax() + pattern = CallWith(softmax, inputs=[x]) + + target = CallWith("make_tuple", inputs=[pattern, new_para], should_replace=False) + return pattern, target + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) + print(transformed_repr) + assert "Merlin" in transformed_repr + unregiste_pass(softmax_make_tuple_pass) + cancel_new_parameter(new_para) + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) + print(transformed_repr) + assert "Merlin" not in transformed_repr