forked from mindspore-Ecosystem/mindspore
Add python pass support
This commit is contained in:
parent
1ea38eb60c
commit
b26f6b6b67
|
@ -346,10 +346,6 @@ class TensorAddByZero : public AnfVisitor {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Visit(const AnfNodePtr &node) override {
|
void Visit(const AnfNodePtr &node) override {
|
||||||
if (IsPrimitive(node, prim::kPrimZerosLike)) {
|
|
||||||
is_zero_ = true;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) {
|
if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) {
|
||||||
is_zero_ = true;
|
is_zero_ = true;
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "optimizer/pass_group.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace python_pass {
|
||||||
|
void PassGroup::AddPass(const PythonPassPtr &pass) {
|
||||||
|
if (pass != nullptr) {
|
||||||
|
passes_.push_back(pass);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PassGroup::DeletePass(const std::string &pass_name) {
|
||||||
|
for (auto iter = passes_.begin(); iter != passes_.end(); iter++) {
|
||||||
|
if ((*iter)->name() == pass_name) {
|
||||||
|
*iter = nullptr;
|
||||||
|
passes_.erase(iter);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const {
|
||||||
|
if (func_graph == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool changed = false;
|
||||||
|
for (const auto &pass : passes) {
|
||||||
|
if (pass != nullptr) {
|
||||||
|
if (pass->Run(func_graph)) {
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PassGroup::Run(const FuncGraphPtr &func_graph) const {
|
||||||
|
bool changed = false;
|
||||||
|
// run all passes
|
||||||
|
bool change = true;
|
||||||
|
while (change) {
|
||||||
|
change = Run(func_graph, passes_);
|
||||||
|
changed = change || changed;
|
||||||
|
if (run_only_once_) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace python_pass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_
|
||||||
|
#define MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "optimizer/py_pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace python_pass {
|
||||||
|
class PassGroup {
|
||||||
|
public:
|
||||||
|
explicit PassGroup(const std::string &name = "pass_group", bool run_only_once = false)
|
||||||
|
: name_(name), passes_{}, run_only_once_(run_only_once) {}
|
||||||
|
virtual ~PassGroup() = default;
|
||||||
|
// Add graph pass, the pass object will be freed when pass manager freed.
|
||||||
|
void AddPass(const PythonPassPtr &pass);
|
||||||
|
// Delete graph pass before the pass manager is freed.
|
||||||
|
bool DeletePass(const std::string &pass_name);
|
||||||
|
// Run passes added in pass manager on the input graph
|
||||||
|
// @param [inout] graph The graph to be optimized
|
||||||
|
// @return true, graph changed
|
||||||
|
// @return false, graph not changed
|
||||||
|
bool Run(const FuncGraphPtr &func_graph) const;
|
||||||
|
// Run the given graph passes on the input graph
|
||||||
|
// @param [inout] graph The graph to be optimized
|
||||||
|
// @param [in] passes The given graph passes
|
||||||
|
// @return true, graph changed
|
||||||
|
// @return false, graph not changed
|
||||||
|
bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const;
|
||||||
|
std::string name() const { return name_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::string name_;
|
||||||
|
std::vector<PythonPassPtr> passes_;
|
||||||
|
bool run_only_once_;
|
||||||
|
};
|
||||||
|
using PassGroupPtr = std::shared_ptr<PassGroup>;
|
||||||
|
} // namespace python_pass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_
|
|
@ -0,0 +1,236 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "optimizer/py_pass.h"
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <deque>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
#include "ir/manager.h"
|
||||||
|
#include "pipeline/parse/parse_base.h"
|
||||||
|
#include "pipeline/resource.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace python_pass {
|
||||||
|
namespace internal {
|
||||||
|
std::string GetNodeRepr(AnfNodePtr node) {
|
||||||
|
if (node != nullptr) {
|
||||||
|
if (node->isa<CNode>()) {
|
||||||
|
std::string repr = "(";
|
||||||
|
auto const &inputs = node->cast<CNodePtr>()->inputs();
|
||||||
|
for (auto &input : inputs) {
|
||||||
|
repr += " ";
|
||||||
|
repr += GetNodeRepr(input);
|
||||||
|
repr += " ";
|
||||||
|
}
|
||||||
|
repr += ")";
|
||||||
|
return repr;
|
||||||
|
}
|
||||||
|
if (node->isa<ValueNode>()) {
|
||||||
|
return GetValueNode(node)->ToString();
|
||||||
|
}
|
||||||
|
return node->ToString();
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
void ResolveFuncGraph_(const FuncGraphPtr &fg) {
|
||||||
|
auto manager = Manage(fg, false);
|
||||||
|
parse::python_adapter::set_use_signature_in_resolve(false);
|
||||||
|
parse::ResolveAll(manager);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) {
|
||||||
|
if (node == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
MS_EXCEPTION_IF_NULL(pattern);
|
||||||
|
if (pattern->isa<ValueNode>()) {
|
||||||
|
if (!node->isa<ValueNode>()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (GetNodeRepr(pattern) == GetNodeRepr(node)) {
|
||||||
|
// add to equiv_ptr
|
||||||
|
equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
} else if (pattern->isa<Parameter>()) {
|
||||||
|
MS_LOG(DEBUG) << pattern->ToString() + "\n";
|
||||||
|
// add to equiv_ptr
|
||||||
|
equiv_ptr->insert(std::make_pair(pattern->ToString(), node));
|
||||||
|
return true;
|
||||||
|
} else if (pattern->isa<CNode>()) {
|
||||||
|
// match every single sub ANode
|
||||||
|
if (!node->isa<CNode>()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto pattern_inputs = pattern->cast<CNodePtr>()->inputs();
|
||||||
|
auto node_inputs = node->cast<CNodePtr>()->inputs();
|
||||||
|
if (pattern_inputs.size() != node_inputs.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end();
|
||||||
|
p_item++, node_item++) {
|
||||||
|
auto res = Match(*p_item, *node_item, equiv_ptr);
|
||||||
|
if (!res) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_,
|
||||||
|
const NodeEquivPtr &equiv_ptr) {
|
||||||
|
if (cur_raw_dst_node_->isa<Parameter>()) {
|
||||||
|
auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString());
|
||||||
|
if (sub_pair != equiv_ptr->end()) {
|
||||||
|
return sub_pair->second;
|
||||||
|
}
|
||||||
|
MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n";
|
||||||
|
} else if (cur_raw_dst_node_->isa<ValueNode>()) {
|
||||||
|
// check primitive ValueNode
|
||||||
|
auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast<ValueNodePtr>()->value()->ToString());
|
||||||
|
if (sub_pair != equiv_ptr->end()) {
|
||||||
|
return sub_pair->second;
|
||||||
|
}
|
||||||
|
return cur_raw_dst_node_;
|
||||||
|
} else if (cur_raw_dst_node_->isa<CNode>()) {
|
||||||
|
std::vector<AnfNodePtr> new_inputs;
|
||||||
|
auto inputs = cur_raw_dst_node_->cast<CNodePtr>()->inputs();
|
||||||
|
for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) {
|
||||||
|
auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr);
|
||||||
|
new_inputs.push_back(subed);
|
||||||
|
}
|
||||||
|
return func_graph->NewCNode(new_inputs);
|
||||||
|
}
|
||||||
|
MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isTraversable(const AnfNodePtr &node) {
|
||||||
|
if (node == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (node->isa<CNode>() || node->isa<Parameter>()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} // namespace internal
|
||||||
|
|
||||||
|
void PythonPass::Build(const py::function &src, const py::function &dst) {
|
||||||
|
// 1. get FuncGraph from py::function
|
||||||
|
auto src_fg_ = parse::ParsePythonCode(src);
|
||||||
|
auto dst_fg_ = parse::ParsePythonCode(dst);
|
||||||
|
if (src_fg_ == nullptr || dst_fg_ == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failed to parse python code.\n";
|
||||||
|
}
|
||||||
|
// 2. Resolve
|
||||||
|
internal::ResolveFuncGraph_(src_fg_);
|
||||||
|
internal::ResolveFuncGraph_(dst_fg_);
|
||||||
|
// 3. from FuncGraphPtr to ValueNode
|
||||||
|
src_node_ = src_fg_->output();
|
||||||
|
dst_node_ = dst_fg_->output();
|
||||||
|
}
|
||||||
|
|
||||||
|
PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once,
|
||||||
|
bool multigraph)
|
||||||
|
: name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {
|
||||||
|
Build(src, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||||
|
auto equiv_ptr = std::make_shared<NodeEquiv>();
|
||||||
|
bool is_a_match = internal::Match(src_node_, node, equiv_ptr);
|
||||||
|
if (is_a_match) {
|
||||||
|
auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr);
|
||||||
|
MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n";
|
||||||
|
return new_node;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PythonPass::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
FuncGraphManagerPtr manager = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
manager->AddFuncGraph(func_graph);
|
||||||
|
auto seen = NewSeenGeneration();
|
||||||
|
// 1024 is for the initial capacity of deque
|
||||||
|
std::deque<AnfNodePtr> todo(1024);
|
||||||
|
todo.push_back(func_graph->output());
|
||||||
|
bool changes = false;
|
||||||
|
|
||||||
|
auto &all_nodes = manager->all_nodes();
|
||||||
|
while (!todo.empty()) {
|
||||||
|
AnfNodePtr node = todo.front();
|
||||||
|
todo.pop_front();
|
||||||
|
|
||||||
|
// check whether this node has been matched.
|
||||||
|
if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
node->seen_ = seen;
|
||||||
|
|
||||||
|
// select nodes that this transform can be applied.
|
||||||
|
AnfNodePtr new_node = Run(func_graph, node);
|
||||||
|
bool change = (new_node != nullptr);
|
||||||
|
if (new_node != nullptr && new_node != node) {
|
||||||
|
(void)manager->Replace(node, new_node);
|
||||||
|
} else if (new_node == nullptr) {
|
||||||
|
new_node = node;
|
||||||
|
}
|
||||||
|
if (run_only_once_) {
|
||||||
|
return change;
|
||||||
|
}
|
||||||
|
|
||||||
|
// find success, and add them to todo list
|
||||||
|
if (IsValueNode<FuncGraph>(node)) {
|
||||||
|
todo.push_back(GetValueNode<FuncGraphPtr>(node)->output());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->isa<CNode>()) {
|
||||||
|
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||||
|
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto &node_users = manager->node_users();
|
||||||
|
if (change && node_users.find(node) != node_users.end()) {
|
||||||
|
for (auto &use : node_users[node]) {
|
||||||
|
auto use_node = use.first;
|
||||||
|
if (use_node == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
todo.push_back(use_node);
|
||||||
|
if (use_node->seen_ == seen) {
|
||||||
|
use_node->seen_--;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return changes;
|
||||||
|
}
|
||||||
|
} // namespace python_pass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_H_
|
||||||
|
#define MINDSPORE_CCSRC_OPTIMIZER_PASS_H_
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "pybind_api/api_register.h"
|
||||||
|
#include "pybind_api/export_flags.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace python_pass {
|
||||||
|
class PythonPass;
|
||||||
|
using PythonPassPtr = std::shared_ptr<PythonPass>;
|
||||||
|
using NodeEquiv = std::unordered_map<std::string, AnfNodePtr>;
|
||||||
|
using NodeEquivPtr = std::shared_ptr<NodeEquiv>;
|
||||||
|
|
||||||
|
class PythonPass {
|
||||||
|
public:
|
||||||
|
explicit PythonPass(const std::string &name, const py::function &src, const py::function &dst,
|
||||||
|
bool run_only_once = false, bool multigraph = true);
|
||||||
|
~PythonPass() = default;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph);
|
||||||
|
std::string name() const { return name_; }
|
||||||
|
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Build(const py::function &src, const py::function &dst);
|
||||||
|
AnfNodePtr src_node_ = nullptr;
|
||||||
|
AnfNodePtr dst_node_ = nullptr;
|
||||||
|
const std::string name_;
|
||||||
|
bool run_only_once_;
|
||||||
|
bool multigraph_ = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
using PythonPassPtr = std::shared_ptr<PythonPass>;
|
||||||
|
} // namespace python_pass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_H_
|
|
@ -0,0 +1,84 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "optimizer/py_pass_manager.h"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
|
#include <initializer_list>
|
||||||
|
|
||||||
|
#include "ir/manager.h"
|
||||||
|
#include "optimizer/pass_group.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace python_pass {
|
||||||
|
PyPassManagerPtr PyPassManager::global_instance = nullptr;
|
||||||
|
std::unordered_map<Phase, PassGroupPtr> PyPassManager::phase_to_group_;
|
||||||
|
|
||||||
|
PassGroupPtr PyPassManager::GetPassGroup(Phase phase) {
|
||||||
|
auto pm = phase_to_group_.find(phase);
|
||||||
|
if (pm == phase_to_group_.end()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return pm->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
PyPassManagerPtr PyPassManager::GetInstance() {
|
||||||
|
if (global_instance == nullptr) {
|
||||||
|
global_instance = std::shared_ptr<PyPassManager>(new (std::nothrow) PyPassManager());
|
||||||
|
}
|
||||||
|
return global_instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
PyPassManager::PyPassManager() {
|
||||||
|
phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>();
|
||||||
|
phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target,
|
||||||
|
Phase phase, bool run_only_once, bool multigraph) {
|
||||||
|
auto cur_pm = GetPassGroup(phase);
|
||||||
|
MS_EXCEPTION_IF_NULL(cur_pm);
|
||||||
|
PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once, multigraph);
|
||||||
|
cur_pm->AddPass(new_pass);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
|
||||||
|
auto cur_pm = GetPassGroup(phase);
|
||||||
|
MS_EXCEPTION_IF_NULL(cur_pm);
|
||||||
|
if (!cur_pm->DeletePass(pass_name)) {
|
||||||
|
MS_LOG(WARNING) << "No such pass : " + pass_name + "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void PyPassManager::ClearRes() {
|
||||||
|
MS_LOG(INFO) << "Clear PyPassManager resources!";
|
||||||
|
global_instance = nullptr;
|
||||||
|
phase_to_group_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_PYBIND_DEFINE(
|
||||||
|
PyPassManager_, ([](const py::module *m) {
|
||||||
|
(void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT);
|
||||||
|
(void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
|
||||||
|
.def(py::init([]() { return PyPassManager::GetInstance(); }))
|
||||||
|
.def("registe", &PyPassManager::Registe, "Registe python pass")
|
||||||
|
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass");
|
||||||
|
}));
|
||||||
|
} // namespace python_pass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,66 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_
|
||||||
|
#define MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
#include "ir/primitive.h"
|
||||||
|
#include "utils/graph_utils.h"
|
||||||
|
#include "common/utils.h"
|
||||||
|
|
||||||
|
#include "pipeline/parse/resolve.h"
|
||||||
|
#include "optimizer/py_pass.h"
|
||||||
|
#include "optimizer/pass_group.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace python_pass {
|
||||||
|
class PyPassManager;
|
||||||
|
using PyPassManagerPtr = std::shared_ptr<PyPassManager>;
|
||||||
|
|
||||||
|
enum Phase { RESOLVE, OPT };
|
||||||
|
|
||||||
|
class PyPassManager {
|
||||||
|
protected:
|
||||||
|
PyPassManager();
|
||||||
|
static PyPassManagerPtr global_instance;
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Singletons should not be cloneable and assignable
|
||||||
|
PyPassManager(const PyPassManager &other) = delete;
|
||||||
|
void operator=(const PyPassManager &) = delete;
|
||||||
|
// Access the only global instance
|
||||||
|
static PyPassManagerPtr GetInstance();
|
||||||
|
virtual ~PyPassManager() = default;
|
||||||
|
void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target,
|
||||||
|
Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true);
|
||||||
|
void Unregiste(const std::string &pass_name, Phase phase);
|
||||||
|
PassGroupPtr GetPassGroup(Phase phase);
|
||||||
|
void ClearRes();
|
||||||
|
|
||||||
|
private:
|
||||||
|
static std::unordered_map<Phase, PassGroupPtr> phase_to_group_;
|
||||||
|
};
|
||||||
|
} // namespace python_pass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_
|
|
@ -39,6 +39,7 @@
|
||||||
#include "optimizer/optimizer.h"
|
#include "optimizer/optimizer.h"
|
||||||
#include "vm/transform.h"
|
#include "vm/transform.h"
|
||||||
#include "parse/python_adapter.h"
|
#include "parse/python_adapter.h"
|
||||||
|
#include "optimizer/py_pass_manager.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace pipeline {
|
namespace pipeline {
|
||||||
|
@ -420,6 +421,25 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
|
||||||
|
|
||||||
bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }
|
bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }
|
||||||
|
|
||||||
|
void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
|
||||||
|
MS_EXCEPTION_IF_NULL(res->manager());
|
||||||
|
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||||
|
auto ppm = opt::python_pass::PyPassManager::GetInstance();
|
||||||
|
if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) {
|
||||||
|
MS_LOG(DEBUG) << "No match.\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ResolveActionPyStub(const ResourcePtr &res) {
|
||||||
|
ActionPyStub(res, opt::python_pass::Phase::RESOLVE);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OptActionPyStub(const ResourcePtr &res) {
|
||||||
|
ActionPyStub(res, opt::python_pass::Phase::RESOLVE);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static std::vector<ActionItem> CommonPipeline() {
|
static std::vector<ActionItem> CommonPipeline() {
|
||||||
std::vector<ActionItem> actions;
|
std::vector<ActionItem> actions;
|
||||||
|
|
||||||
|
@ -432,6 +452,8 @@ static std::vector<ActionItem> CommonPipeline() {
|
||||||
if (!multi_graphs) {
|
if (!multi_graphs) {
|
||||||
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
|
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
|
||||||
}
|
}
|
||||||
|
// Add resolve-stage python pass stub
|
||||||
|
actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub));
|
||||||
actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
|
actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
|
||||||
// Evaluate type and shape, and specialize
|
// Evaluate type and shape, and specialize
|
||||||
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
|
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
|
||||||
|
@ -443,6 +465,8 @@ std::vector<ActionItem> GePipeline() {
|
||||||
auto actions = CommonPipeline();
|
auto actions = CommonPipeline();
|
||||||
// optimize
|
// optimize
|
||||||
actions.emplace_back(std::make_pair("optimize", GeOptimizeAction));
|
actions.emplace_back(std::make_pair("optimize", GeOptimizeAction));
|
||||||
|
// Add opt-stage python pass stub
|
||||||
|
actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
|
||||||
actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
|
actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
|
||||||
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||||
return actions;
|
return actions;
|
||||||
|
@ -454,6 +478,9 @@ std::vector<ActionItem> VmPipeline() {
|
||||||
// optimize
|
// optimize
|
||||||
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
|
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
|
||||||
|
|
||||||
|
// Add opt-stage python pass stub
|
||||||
|
actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
|
||||||
|
|
||||||
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||||
|
|
||||||
// compile the ANF graph
|
// compile the ANF graph
|
||||||
|
|
|
@ -39,6 +39,7 @@
|
||||||
#include "device/kernel_runtime_manager.h"
|
#include "device/kernel_runtime_manager.h"
|
||||||
#include "debug/trace.h"
|
#include "debug/trace.h"
|
||||||
#include "pynative/pynative_execute.h"
|
#include "pynative/pynative_execute.h"
|
||||||
|
#include "optimizer/py_pass_manager.h"
|
||||||
|
|
||||||
#if (ENABLE_GE || ENABLE_D)
|
#if (ENABLE_GE || ENABLE_D)
|
||||||
#include "pipeline/pipeline_ge.h"
|
#include "pipeline/pipeline_ge.h"
|
||||||
|
@ -964,6 +965,7 @@ void ClearResAtexit() {
|
||||||
pipeline::ExecutorPy::ClearRes();
|
pipeline::ExecutorPy::ClearRes();
|
||||||
pipeline::ReclaimOptimizer();
|
pipeline::ReclaimOptimizer();
|
||||||
pynative::PynativeExecutor::GetInstance()->ClearRes();
|
pynative::PynativeExecutor::GetInstance()->ClearRes();
|
||||||
|
opt::python_pass::PyPassManager::GetInstance()->ClearRes();
|
||||||
#ifdef ENABLE_GE
|
#ifdef ENABLE_GE
|
||||||
transform::DfGraphManager::GetInstance().ClearGraph();
|
transform::DfGraphManager::GetInstance().ClearGraph();
|
||||||
transform::DfGraphConvertor::get_adpt_map().clear();
|
transform::DfGraphConvertor::get_adpt_map().clear();
|
||||||
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Python pass register"""
|
||||||
|
from inspect import isfunction
|
||||||
|
from mindspore._c_expression import PyPassManager_
|
||||||
|
from mindspore._c_expression import phase
|
||||||
|
|
||||||
|
class PyPassManager(PyPassManager_):
|
||||||
|
r"""
|
||||||
|
Used to registe and unregiste python passes which can be used to alter graphs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
|
||||||
|
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
|
||||||
|
multigraph (bool): Whether or not the pattern exists across graphs. Default: True.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If argument has invalid type.
|
||||||
|
"""
|
||||||
|
def __init__(self, pipeline_phase=phase.opt, run_only_once=False, multi_graph=True):
|
||||||
|
if not isinstance(pipeline_phase, phase):
|
||||||
|
raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
||||||
|
if not isinstance(run_only_once, bool):
|
||||||
|
raise TypeError(f"Expecting bool, got : ({type(run_only_once)}){run_only_once}")
|
||||||
|
if not isinstance(multi_graph, bool):
|
||||||
|
raise TypeError(f"Expecting bool, got : ({type(multi_graph)}){multi_graph}")
|
||||||
|
PyPassManager_.__init__(self)
|
||||||
|
self.phase_ = pipeline_phase
|
||||||
|
self.run_only_once_ = run_only_once
|
||||||
|
self.multi_graph_ = multi_graph
|
||||||
|
|
||||||
|
def registe(self, py_pass):
|
||||||
|
if not isfunction(py_pass):
|
||||||
|
raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}")
|
||||||
|
pattern, target = py_pass()
|
||||||
|
pass_name = py_pass.__name__
|
||||||
|
if not isfunction(pattern):
|
||||||
|
raise TypeError(f"Expecting function pattern, got : ({type(pattern)}){pattern}")
|
||||||
|
if not isfunction(target):
|
||||||
|
raise TypeError(f"Expecting function target, got : ({type(target)}){target}")
|
||||||
|
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_)
|
||||||
|
|
||||||
|
def unregiste(self, py_pass, pipeline_phase=phase.opt):
|
||||||
|
if not isinstance(pipeline_phase, phase):
|
||||||
|
raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
||||||
|
if isinstance(py_pass, str):
|
||||||
|
super().unregiste(py_pass, pipeline_phase)
|
||||||
|
return
|
||||||
|
if isfunction(py_pass):
|
||||||
|
super().unregiste(py_pass.__name__, pipeline_phase)
|
||||||
|
return
|
||||||
|
raise TypeError(f"Expecting py_pass to be string or function, got ({type(py_pass)}){py_pass}")
|
||||||
|
|
||||||
|
def __call__(self, py_pass):
|
||||||
|
self.registe(py_pass)
|
||||||
|
return py_pass
|
||||||
|
|
||||||
|
def registe_pass(pipeline_phase=phase.opt, run_only_once=False, multi_graph=True):
|
||||||
|
"""
|
||||||
|
Examples:
|
||||||
|
>>> @registe_pass()
|
||||||
|
>>> def toy_pass():
|
||||||
|
>>> def pattern():
|
||||||
|
>>> pass
|
||||||
|
>>> def target():
|
||||||
|
>>> pass
|
||||||
|
"""
|
||||||
|
return PyPassManager(pipeline_phase, run_only_once, multi_graph)
|
|
@ -170,7 +170,8 @@ class Dense(Cell):
|
||||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
||||||
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||||
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
|
activation (str): activate function applied to the output of the fully connected layer, eg. 'relu'.
|
||||||
|
Default: None.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If weight_init or bias_init shape is incorrect.
|
ValueError: If weight_init or bias_init shape is incorrect.
|
||||||
|
|
Loading…
Reference in New Issue