Update python pattern expression
This commit is contained in:
parent
c700fc5515
commit
6d4c07c886
|
@ -0,0 +1,158 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "frontend/optimizer/pattern.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace python_pass {
|
||||
int Pattern::g_id_ = 0;
|
||||
|
||||
MatchResultPtr IsPrimTypeOf::match(const AnfNodePtr &node) {
|
||||
if (!IsValueNode<Primitive>(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||
if (IsValueNode<Primitive>(node)) {
|
||||
// iterate over all primitives
|
||||
for (auto &iter : primitives_) {
|
||||
if (IsPrimitive(node, iter) || iter->name() == "*") {
|
||||
matched_prim_ = iter;
|
||||
res->add_entry(shared_from_base<IsPrimTypeOf>(), node);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MatchResultPtr CallWith::match(const AnfNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||
// IsPrimitiveCNode
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Check Primitive ValueNode
|
||||
if (prim_pattern_ != nullptr) {
|
||||
// Passed in prim_pattern
|
||||
auto prim_value_res = prim_pattern_->match(cnode->input(0));
|
||||
if (prim_value_res == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
res->merge(prim_value_res);
|
||||
} else if (prim_ != nullptr) {
|
||||
// Passed in primitive/primitive str
|
||||
if (!IsPrimitive(cnode->input(0), prim_)) {
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Uninitialized CallWith pattern.";
|
||||
}
|
||||
// Check inputs
|
||||
auto p_inputs_size = inputs_.size();
|
||||
auto node_inputs_size = cnode->size() - 1;
|
||||
if (p_inputs_size != 0 && p_inputs_size != node_inputs_size) {
|
||||
return nullptr;
|
||||
}
|
||||
// If inputs is not specified, add node without looking into its inputs
|
||||
if (p_inputs_size == 0) {
|
||||
res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
|
||||
return res;
|
||||
}
|
||||
bool failed = false;
|
||||
for (std::size_t i = 0; i < node_inputs_size; i++) {
|
||||
auto pattern = inputs_[i];
|
||||
auto input = cnode->input(i + 1);
|
||||
auto input_match_result = pattern->match(input);
|
||||
if (input_match_result == nullptr) {
|
||||
failed = true;
|
||||
break;
|
||||
}
|
||||
res->merge(input_match_result);
|
||||
}
|
||||
if (!failed) {
|
||||
res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
|
||||
return res;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MatchResultPtr IsIn::match(const AnfNodePtr &node) {
|
||||
for (auto &iter : patterns_) {
|
||||
auto res = iter->match(node);
|
||||
if (res != nullptr) {
|
||||
return res;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MatchResultPtr IsNot::match(const AnfNodePtr &node) {
|
||||
for (auto &iter : patterns_) {
|
||||
auto res = iter->match(node);
|
||||
if (res != nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
auto res = std::make_shared<MatchResult>();
|
||||
res->add_entry(shared_from_base<IsNot>(), node);
|
||||
return res;
|
||||
}
|
||||
|
||||
MatchResultPtr AnyPattern::match(const AnfNodePtr &node) {
|
||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||
res->add_entry(shared_from_base<AnyPattern>(), node);
|
||||
return res;
|
||||
}
|
||||
|
||||
AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
|
||||
auto entry = match_result_.find(pattern);
|
||||
if (entry == match_result_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return entry->second;
|
||||
}
|
||||
|
||||
void MatchResult::merge(const MatchResultPtr &other_result) {
|
||||
auto other_result_map = other_result->_result();
|
||||
// add/update entries in other_result
|
||||
for (auto &iter : other_result_map) {
|
||||
match_result_[iter.first] = iter.second;
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(
|
||||
Pattern, ([](const py::module *m) {
|
||||
(void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
|
||||
(void)py::class_<IsIn, std::shared_ptr<IsIn>, Pattern>(*m, "IsIn_").def(py::init<vector<PatternPtr>>());
|
||||
(void)py::class_<IsPrimTypeOf, std::shared_ptr<IsPrimTypeOf>, Pattern>(*m, "IsPrimTypeOf_", py::dynamic_attr())
|
||||
.def(py::init<vector<PrimitivePyPtr>, string, bool>())
|
||||
.def(py::init<vector<string>, string, bool>());
|
||||
(void)py::class_<CallWith, std::shared_ptr<CallWith>, Pattern>(*m, "CallWith_")
|
||||
.def(py::init<PatternPtr, vector<PatternPtr>, bool>())
|
||||
.def(py::init<PrimitivePyPtr, vector<PatternPtr>, bool>())
|
||||
.def(py::init<string, vector<PatternPtr>, bool>());
|
||||
(void)py::class_<IsNot, std::shared_ptr<IsNot>, Pattern>(*m, "IsNot_").def(py::init<vector<PatternPtr>>());
|
||||
(void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>());
|
||||
(void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
|
||||
.def(py::init<tensor::TensorPtr>());
|
||||
}));
|
||||
} // namespace python_pass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,228 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "base/base.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "utils/primitive_py.h"
|
||||
#include "utils/tensor_py.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace python_pass {
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
class MatchResult;
|
||||
using MatchResultPtr = std::shared_ptr<MatchResult>;
|
||||
class Pattern;
|
||||
using PatternPtr = std::shared_ptr<Pattern>;
|
||||
class IsPrimTypeOf;
|
||||
using IsPrimTypeOfPtr = std::shared_ptr<IsPrimTypeOf>;
|
||||
class CallWith;
|
||||
using CallWithPtr = std::shared_ptr<CallWith>;
|
||||
class NewTensor;
|
||||
using NewTensorPtr = std::shared_ptr<NewTensor>;
|
||||
struct PatternHasher;
|
||||
struct PatternEqual;
|
||||
using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, PatternEqual>;
|
||||
|
||||
class Pattern : public Base {
|
||||
public:
|
||||
Pattern() : unique_name_(std::to_string(g_id_++)) {}
|
||||
virtual MatchResultPtr match(const AnfNodePtr &node) { return nullptr; }
|
||||
virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; }
|
||||
string unique_name() const { return unique_name_; }
|
||||
vector<PatternPtr> inputs() { return inputs_; }
|
||||
bool should_replace() { return should_replace_; }
|
||||
virtual void reset() {}
|
||||
|
||||
protected:
|
||||
static int g_id_;
|
||||
// NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed
|
||||
string unique_name_;
|
||||
vector<PatternPtr> inputs_;
|
||||
bool should_replace_ = true;
|
||||
};
|
||||
|
||||
struct PatternEqual {
|
||||
bool operator()(PatternPtr const &p1, PatternPtr const &p2) const {
|
||||
MS_EXCEPTION_IF_NULL(p1);
|
||||
MS_EXCEPTION_IF_NULL(p2);
|
||||
return p1->unique_name() == p2->unique_name();
|
||||
}
|
||||
};
|
||||
|
||||
struct PatternHasher {
|
||||
std::size_t operator()(PatternPtr const &p) const {
|
||||
MS_EXCEPTION_IF_NULL(p);
|
||||
return std::hash<string>()(p->unique_name());
|
||||
}
|
||||
};
|
||||
|
||||
class IsPrimTypeOf : public Pattern {
|
||||
public:
|
||||
IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); }
|
||||
IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace)
|
||||
: primitives_(prims), name_(name), matched_prim_(nullptr) {
|
||||
unique_name_ = std::to_string(g_id_++) + "_" + name;
|
||||
should_replace_ = should_replace;
|
||||
if (!should_replace) {
|
||||
matched_prim_ = prims[0];
|
||||
}
|
||||
}
|
||||
IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) {
|
||||
unique_name_ = std::to_string(g_id_++) + "_" + name;
|
||||
// Make primitives_
|
||||
for (auto &iter : types) {
|
||||
primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
|
||||
}
|
||||
should_replace_ = should_replace;
|
||||
if (!should_replace) {
|
||||
matched_prim_ = primitives_[0];
|
||||
}
|
||||
}
|
||||
MS_DECLARE_PARENT(IsPrimTypeOf, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
PrimitivePyPtr matched_primitive() { return matched_prim_; }
|
||||
void reset() override {
|
||||
if (should_replace_) {
|
||||
matched_prim_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
vector<string> types_;
|
||||
vector<PrimitivePyPtr> primitives_;
|
||||
string name_;
|
||||
PrimitivePyPtr matched_prim_;
|
||||
};
|
||||
|
||||
class CallWith : public Pattern {
|
||||
public:
|
||||
CallWith() { unique_name_ = std::to_string(g_id_++); }
|
||||
CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) {
|
||||
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
|
||||
prim_pattern_ = prim_pattern;
|
||||
unique_name_ = std::to_string(g_id_++) + prim_pattern->unique_name();
|
||||
inputs_ = inputs;
|
||||
should_replace_ = should_replace;
|
||||
}
|
||||
CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) {
|
||||
prim_ = prim;
|
||||
unique_name_ = std::to_string(g_id_++) + prim_->ToString();
|
||||
inputs_ = inputs;
|
||||
should_replace_ = should_replace;
|
||||
}
|
||||
CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) {
|
||||
prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
|
||||
unique_name_ = std::to_string(g_id_++) + prim_->ToString();
|
||||
inputs_ = inputs;
|
||||
should_replace_ = should_replace;
|
||||
}
|
||||
MS_DECLARE_PARENT(CallWith, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
PrimitivePtr prim_value() { return prim_; }
|
||||
PatternPtr prim_pattern() { return prim_pattern_; }
|
||||
|
||||
private:
|
||||
PatternPtr prim_pattern_ = nullptr;
|
||||
PrimitivePtr prim_ = nullptr;
|
||||
vector<string> types_;
|
||||
string name_;
|
||||
};
|
||||
|
||||
class IsIn : public Pattern {
|
||||
public:
|
||||
IsIn() { unique_name_ = std::to_string(g_id_++); }
|
||||
explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) {
|
||||
unique_name_ = std::to_string(g_id_++);
|
||||
for (auto &iter : patterns) {
|
||||
unique_name_ = unique_name_ + "_" + iter->unique_name();
|
||||
}
|
||||
}
|
||||
MS_DECLARE_PARENT(IsIn, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
vector<PatternPtr> patterns_;
|
||||
};
|
||||
|
||||
class IsNot : public Pattern {
|
||||
public:
|
||||
IsNot() { unique_name_ = std::to_string(g_id_++); }
|
||||
explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) {
|
||||
unique_name_ = std::to_string(g_id_++);
|
||||
for (auto &iter : patterns) {
|
||||
unique_name_ = "IsNot_" + unique_name_ + "_" + iter->unique_name();
|
||||
}
|
||||
}
|
||||
MS_DECLARE_PARENT(IsNot, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
vector<PatternPtr> patterns_;
|
||||
};
|
||||
|
||||
class AnyPattern : public Pattern {
|
||||
public:
|
||||
AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; }
|
||||
MS_DECLARE_PARENT(AnyPattern, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
class NewTensor : public Pattern {
|
||||
public:
|
||||
NewTensor() { unique_name_ = std::to_string(g_id_++); }
|
||||
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; }
|
||||
MS_DECLARE_PARENT(NewTensor, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override {
|
||||
MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n";
|
||||
}
|
||||
tensor::TensorPtr input_tensor() { return input_tensor_; }
|
||||
|
||||
private:
|
||||
tensor::TensorPtr input_tensor_;
|
||||
};
|
||||
|
||||
class MatchResult {
|
||||
public:
|
||||
MatchResult() {}
|
||||
void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; }
|
||||
PatternNodeMap _result() { return match_result_; }
|
||||
AnfNodePtr get_node(const PatternPtr &pattern);
|
||||
void merge(const MatchResultPtr &other_result);
|
||||
void clear() { match_result_.clear(); }
|
||||
void dump() {
|
||||
MS_LOG(DEBUG) << "match_result_.size: " + std::to_string(match_result_.size()) + "\n";
|
||||
for (auto &iter : match_result_) {
|
||||
MS_LOG(DEBUG) << "Pattern : " + iter.first->unique_name() + " , node : " + iter.second->ToString() + "\n";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
PatternNodeMap match_result_;
|
||||
};
|
||||
} // namespace python_pass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/manager.h"
|
||||
#include "utils/primitive_py.h"
|
||||
#include "pipeline/jit/parse/parse_base.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
|
||||
|
@ -29,6 +30,8 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
namespace python_pass {
|
||||
namespace internal {
|
||||
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res);
|
||||
|
||||
std::string GetNodeRepr(AnfNodePtr node) {
|
||||
if (node != nullptr) {
|
||||
if (node->isa<CNode>()) {
|
||||
|
@ -50,84 +53,7 @@ std::string GetNodeRepr(AnfNodePtr node) {
|
|||
return "";
|
||||
}
|
||||
|
||||
void ResolveFuncGraph_(const FuncGraphPtr &fg) {
|
||||
auto manager = Manage(fg, false);
|
||||
auto use_sig = parse::python_adapter::UseSignatureInResolve();
|
||||
parse::python_adapter::set_use_signature_in_resolve(false);
|
||||
parse::ResolveAll(manager);
|
||||
parse::python_adapter::set_use_signature_in_resolve(use_sig);
|
||||
}
|
||||
|
||||
bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(pattern);
|
||||
if (pattern->isa<ValueNode>()) {
|
||||
if (!node->isa<ValueNode>()) {
|
||||
return false;
|
||||
}
|
||||
if (GetNodeRepr(pattern) == GetNodeRepr(node)) {
|
||||
// add to equiv_ptr
|
||||
equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} else if (pattern->isa<Parameter>()) {
|
||||
MS_LOG(DEBUG) << pattern->ToString() + "\n";
|
||||
// add to equiv_ptr
|
||||
equiv_ptr->insert(std::make_pair(pattern->ToString(), node));
|
||||
return true;
|
||||
} else if (pattern->isa<CNode>()) {
|
||||
// match every single sub ANode
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto pattern_inputs = pattern->cast<CNodePtr>()->inputs();
|
||||
auto node_inputs = node->cast<CNodePtr>()->inputs();
|
||||
if (pattern_inputs.size() != node_inputs.size()) {
|
||||
return false;
|
||||
}
|
||||
for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end();
|
||||
p_item++, node_item++) {
|
||||
auto res = Match(*p_item, *node_item, equiv_ptr);
|
||||
if (!res) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n";
|
||||
}
|
||||
|
||||
AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_,
|
||||
const NodeEquivPtr &equiv_ptr) {
|
||||
if (cur_raw_dst_node_->isa<Parameter>()) {
|
||||
auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString());
|
||||
if (sub_pair != equiv_ptr->end()) {
|
||||
return sub_pair->second;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n";
|
||||
} else if (cur_raw_dst_node_->isa<ValueNode>()) {
|
||||
// check primitive ValueNode
|
||||
auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast<ValueNodePtr>()->value()->ToString());
|
||||
if (sub_pair != equiv_ptr->end()) {
|
||||
return sub_pair->second;
|
||||
}
|
||||
return cur_raw_dst_node_;
|
||||
} else if (cur_raw_dst_node_->isa<CNode>()) {
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
auto inputs = cur_raw_dst_node_->cast<CNodePtr>()->inputs();
|
||||
for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) {
|
||||
auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr);
|
||||
new_inputs.push_back(subed);
|
||||
}
|
||||
return func_graph->NewCNode(new_inputs);
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_);
|
||||
}
|
||||
|
||||
bool isTraversable(const AnfNodePtr &node) {
|
||||
bool IsTraversable(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
@ -139,37 +65,92 @@ bool isTraversable(const AnfNodePtr &node) {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
AnfNodePtr BuildPrimitive(const PatternPtr &pattern, const MatchResultPtr &res) {
|
||||
// Build up AnfNode from primitive
|
||||
auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_pattern);
|
||||
PrimitivePyPtr prim = prim_pattern->matched_primitive();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
// Make value node out of primitives
|
||||
return std::make_shared<ValueNode>(prim);
|
||||
}
|
||||
|
||||
AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res) {
|
||||
// Build a ValueNode from TensorPtr
|
||||
auto new_tensor_pattern = pattern->cast<NewTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_tensor_pattern);
|
||||
auto input_tensor = new_tensor_pattern->input_tensor();
|
||||
MS_EXCEPTION_IF_NULL(input_tensor);
|
||||
return std::make_shared<ValueNode>(input_tensor);
|
||||
}
|
||||
|
||||
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res) {
|
||||
auto call_with_pattern = pattern->cast<CallWithPtr>();
|
||||
MS_EXCEPTION_IF_NULL(call_with_pattern);
|
||||
auto prim = call_with_pattern->prim_value();
|
||||
if (prim != nullptr) {
|
||||
return std::make_shared<ValueNode>(prim);
|
||||
}
|
||||
auto prim_pattern = call_with_pattern->prim_pattern();
|
||||
MS_EXCEPTION_IF_NULL(prim_pattern);
|
||||
return ProcessSinglePattern(prim_pattern, res);
|
||||
}
|
||||
|
||||
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res) {
|
||||
if (pattern->should_replace()) {
|
||||
// Find replacement in the MatchResult
|
||||
auto target_node = res->get_node(pattern);
|
||||
if (target_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find target node in pattern match result, pattern: " + pattern->unique_name() + "\n";
|
||||
}
|
||||
return target_node;
|
||||
}
|
||||
// Build up new node from pattern
|
||||
if (pattern->isa<IsPrimTypeOf>()) {
|
||||
return BuildPrimitive(pattern, res);
|
||||
} else if (pattern->isa<NewTensor>()) {
|
||||
return BuildNewTensor(pattern, res);
|
||||
} else if (pattern->isa<CallWith>()) {
|
||||
return BuildPrimitiveValueNode(pattern, res);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) {
|
||||
auto target_inputs = pattern->inputs();
|
||||
if (target_inputs.size() == 0) {
|
||||
return ProcessSinglePattern(pattern, res);
|
||||
}
|
||||
// Build up the AnfNode in a recursive manner
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
auto prim_value_node = ProcessSinglePattern(pattern, res);
|
||||
MS_EXCEPTION_IF_NULL(prim_value_node);
|
||||
new_inputs.push_back(prim_value_node);
|
||||
for (auto &iter : target_inputs) {
|
||||
if (iter == pattern) {
|
||||
MS_LOG(EXCEPTION) << "Circle references: Pattern takes itself as input. Got pattern: " + pattern->unique_name() +
|
||||
"\n";
|
||||
}
|
||||
new_inputs.push_back(BuildTarget(iter, func_graph, res));
|
||||
}
|
||||
return func_graph->NewCNode(new_inputs);
|
||||
}
|
||||
} // namespace internal
|
||||
|
||||
void PythonPass::Build(const py::function &src, const py::function &dst) {
|
||||
// 1. get FuncGraph from py::function
|
||||
auto src_fg_ = parse::ParsePythonCode(src);
|
||||
auto dst_fg_ = parse::ParsePythonCode(dst);
|
||||
if (src_fg_ == nullptr || dst_fg_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to parse python code.\n";
|
||||
}
|
||||
// 2. Resolve
|
||||
internal::ResolveFuncGraph_(src_fg_);
|
||||
internal::ResolveFuncGraph_(dst_fg_);
|
||||
// 3. from FuncGraphPtr to ValueNode
|
||||
src_node_ = src_fg_->output();
|
||||
dst_node_ = dst_fg_->output();
|
||||
}
|
||||
|
||||
PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once,
|
||||
bool multigraph)
|
||||
: name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {
|
||||
Build(src, dst);
|
||||
}
|
||||
|
||||
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
auto equiv_ptr = std::make_shared<NodeEquiv>();
|
||||
bool is_a_match = internal::Match(src_node_, node, equiv_ptr);
|
||||
if (is_a_match) {
|
||||
auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr);
|
||||
MS_EXCEPTION_IF_NULL(src_pattern_);
|
||||
MS_EXCEPTION_IF_NULL(dst_pattern_);
|
||||
auto res = src_pattern_->match(node);
|
||||
if (res != nullptr) {
|
||||
res->dump();
|
||||
MS_LOG(WARNING) << "Matched pattern: " + src_pattern_->unique_name();
|
||||
auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res);
|
||||
dst_pattern_->reset();
|
||||
MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n";
|
||||
return new_node;
|
||||
}
|
||||
src_pattern_->reset();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -188,14 +169,12 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph) {
|
|||
while (!todo.empty()) {
|
||||
AnfNodePtr node = todo.front();
|
||||
todo.pop_front();
|
||||
|
||||
// check whether this node has been matched.
|
||||
if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) {
|
||||
// Check whether this node has been matched.
|
||||
if (node == nullptr || node->seen_ == seen || !internal::IsTraversable(node) || !all_nodes.contains(node)) {
|
||||
continue;
|
||||
}
|
||||
node->seen_ = seen;
|
||||
|
||||
// select nodes that this transform can be applied.
|
||||
// Select nodes that this transform can be applied.
|
||||
AnfNodePtr new_node = Run(func_graph, node);
|
||||
bool change = (new_node != nullptr);
|
||||
if (new_node != nullptr && new_node != node) {
|
||||
|
@ -206,17 +185,14 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph) {
|
|||
if (run_only_once_) {
|
||||
return change;
|
||||
}
|
||||
|
||||
// find success, and add them to todo list
|
||||
// Find success, and add them to todo list
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
todo.push_back(GetValueNode<FuncGraphPtr>(node)->output());
|
||||
}
|
||||
|
||||
if (node->isa<CNode>()) {
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
|
||||
}
|
||||
|
||||
auto &node_users = manager->node_users();
|
||||
if (change && node_users.find(node) != node_users.end()) {
|
||||
for (auto &use : node_users[node]) {
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "frontend/optimizer/pattern.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pybind_api/export_flags.h"
|
||||
|
||||
|
@ -33,17 +34,17 @@ using NodeEquivPtr = std::shared_ptr<NodeEquiv>;
|
|||
|
||||
class PythonPass {
|
||||
public:
|
||||
explicit PythonPass(const std::string &name, const py::function &src, const py::function &dst,
|
||||
bool run_only_once = false, bool multigraph = true);
|
||||
explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false,
|
||||
bool multigraph = true)
|
||||
: src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {}
|
||||
~PythonPass() = default;
|
||||
bool Run(const FuncGraphPtr &func_graph);
|
||||
std::string name() const { return name_; }
|
||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node);
|
||||
|
||||
private:
|
||||
void Build(const py::function &src, const py::function &dst);
|
||||
AnfNodePtr src_node_ = nullptr;
|
||||
AnfNodePtr dst_node_ = nullptr;
|
||||
PatternPtr src_pattern_;
|
||||
PatternPtr dst_pattern_;
|
||||
const std::string name_;
|
||||
bool run_only_once_;
|
||||
bool multigraph_ = true;
|
||||
|
|
|
@ -49,7 +49,7 @@ PyPassManager::PyPassManager() {
|
|||
phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>();
|
||||
}
|
||||
|
||||
void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target,
|
||||
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
||||
Phase phase, bool run_only_once, bool multigraph) {
|
||||
auto cur_pm = GetPassGroup(phase);
|
||||
MS_EXCEPTION_IF_NULL(cur_pm);
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "common/utils.h"
|
||||
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "frontend/optimizer/pattern.h"
|
||||
#include "frontend/optimizer/py_pass.h"
|
||||
#include "frontend/optimizer/pass_group.h"
|
||||
|
||||
|
@ -51,7 +52,7 @@ class PyPassManager {
|
|||
// Access the only global instance
|
||||
static PyPassManagerPtr GetInstance();
|
||||
virtual ~PyPassManager() = default;
|
||||
void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target,
|
||||
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
||||
Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true);
|
||||
void Unregiste(const std::string &pass_name, Phase phase);
|
||||
PassGroupPtr GetPassGroup(Phase phase);
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Patterns for describing graphs"""
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_
|
||||
|
||||
__all__ = [
|
||||
"IsIn",
|
||||
"IsPrimTypeOf",
|
||||
"CallWith",
|
||||
"IsNot",
|
||||
"AnyPattern",
|
||||
"NewTensor",
|
||||
]
|
||||
|
||||
class IsIn(IsIn_):
|
||||
"""
|
||||
Express a pattern which allows a list of patterns.
|
||||
"""
|
||||
def __init__(self, patterns=None, should_replace=True):
|
||||
r"""
|
||||
Args:
|
||||
patterns(list/tuple): list of allowed patterns
|
||||
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
|
||||
"""
|
||||
if not should_replace:
|
||||
raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \
|
||||
its sub-pattern instead.")
|
||||
self.patterns = patterns
|
||||
if patterns is None:
|
||||
IsIn_.__init__(self, ())
|
||||
elif isinstance(patterns, Pattern):
|
||||
IsIn_.__init__(self, [patterns])
|
||||
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
|
||||
IsIn_.__init__(self, patterns)
|
||||
else:
|
||||
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
|
||||
|
||||
class IsPrimTypeOf(IsPrimTypeOf_):
|
||||
r"""
|
||||
Express a pattern of certain primitive type(s).
|
||||
NOTE: This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
|
||||
please refer to CallWith pattern.
|
||||
"""
|
||||
def __init__(self, types, name=None, should_replace=True):
|
||||
r"""
|
||||
Args:
|
||||
types (str/(list/tuple of Primitives)): Specify allowed types.
|
||||
If it is a string, the form could be
|
||||
1) a single primitive type, e.g. 'Conv2D'
|
||||
2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
|
||||
It can also be a list of Primitives, e.g. [ops.Conv2D(1, 6)]
|
||||
name (str): name of the pattern, optional
|
||||
should_replace
|
||||
"""
|
||||
if name is not None and not isinstance(name, str):
|
||||
raise TypeError(f"Expect string, got : {name}")
|
||||
self.name = name
|
||||
if isinstance(types, str):
|
||||
if self.name is None:
|
||||
self.name = types
|
||||
self.types = types.split('|')
|
||||
elif isinstance(types, Primitive):
|
||||
if self.name is None:
|
||||
self.name = types.name
|
||||
self.types = [types]
|
||||
elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types):
|
||||
if self.name is None:
|
||||
self.name = ""
|
||||
for prim in types:
|
||||
self.name += prim.name
|
||||
self.types = types
|
||||
else:
|
||||
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
|
||||
IsPrimTypeOf_.__init__(self, self.types, self.name, should_replace)
|
||||
|
||||
class CallWith(CallWith_):
|
||||
r"""
|
||||
Express a primitive CNode.
|
||||
"""
|
||||
def __init__(self, prim_pattern, inputs=None, should_replace=False):
|
||||
r"""
|
||||
Args:
|
||||
prim_pattern (Pattern/Primitive/str): Primitive ValueNode in the Primitive CNode.
|
||||
inputs (list/tuple): Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs;
|
||||
if specified, input patterns should be of right order.
|
||||
"""
|
||||
if not isinstance(prim_pattern, (Pattern, str, Primitive)):
|
||||
raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
|
||||
self.prim_pattern = prim_pattern
|
||||
self.inputs = []
|
||||
if inputs is None:
|
||||
pass
|
||||
elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs):
|
||||
self.inputs = inputs
|
||||
else:
|
||||
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
|
||||
CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace)
|
||||
|
||||
|
||||
class IsNot(IsNot_):
|
||||
r"""
|
||||
Express a pattern which forbids a list of patterns.
|
||||
NOTE: IsNot pattern should not be the root pattern.
|
||||
"""
|
||||
def __init__(self, patterns=None, should_replace=True):
|
||||
r"""
|
||||
Args:
|
||||
patterns(list/tuple): list of forbiden patterns
|
||||
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
|
||||
"""
|
||||
if not should_replace:
|
||||
raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \
|
||||
its sub-pattern instead.")
|
||||
self.patterns = patterns
|
||||
if patterns is None:
|
||||
IsNot_.__init__(self, ())
|
||||
elif isinstance(patterns, Pattern):
|
||||
IsNot_.__init__(self, [patterns])
|
||||
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
|
||||
IsNot_.__init__(self, patterns)
|
||||
else:
|
||||
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
|
||||
|
||||
class NewTensor(NewTensor_):
|
||||
r"""
|
||||
New Tensor to be used in the target.
|
||||
"""
|
||||
def __init__(self, input_tensor, should_replace=False):
|
||||
r"""
|
||||
Args:
|
||||
input_tensor(Tensor): new tensor to be used in the target
|
||||
should_replace(bool): added this for interface consistency. NewTensor should only appear in the target.
|
||||
"""
|
||||
if should_replace:
|
||||
raise ValueError("NewTensor should only appear in the target, thus should_replace can onlyu be False.")
|
||||
self.input_tensor = input_tensor
|
||||
if isinstance(input_tensor, Tensor):
|
||||
NewTensor_.__init__(self, input_tensor)
|
||||
else:
|
||||
raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}")
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
"""Python pass register"""
|
||||
from inspect import isfunction
|
||||
from mindspore.common.graph_pattern import Pattern
|
||||
from mindspore._c_expression import PyPassManager_
|
||||
from mindspore._c_expression import phase
|
||||
|
||||
|
@ -46,10 +47,10 @@ class PyPassManager(PyPassManager_):
|
|||
raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}")
|
||||
pattern, target = py_pass()
|
||||
pass_name = py_pass.__name__
|
||||
if not isfunction(pattern):
|
||||
raise TypeError(f"Expecting function pattern, got : ({type(pattern)}){pattern}")
|
||||
if not isfunction(target):
|
||||
raise TypeError(f"Expecting function target, got : ({type(target)}){target}")
|
||||
if not isinstance(pattern, Pattern):
|
||||
raise TypeError(f"Expecting pattern of Pattern type, got : ({type(pattern)}){pattern}")
|
||||
if not isinstance(target, Pattern):
|
||||
raise TypeError(f"Expecting target of Pattern type, got : ({type(target)}){target}")
|
||||
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_)
|
||||
|
||||
def unregiste(self, py_pass, pipeline_phase=phase.opt):
|
||||
|
|
|
@ -22,10 +22,11 @@ from mindspore.ops import operations as P
|
|||
from mindspore.common.python_pass_register import registe_pass, PyPassManager
|
||||
from mindspore.common.api import _generate_pip_args
|
||||
from mindspore._c_expression import generate_key, Executor_
|
||||
from mindspore.common.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
def get_func_graph(obj, *args, phase="predict"):
|
||||
def get_func_graph(obj, *args, phase="validate"):
|
||||
args_names, args_list = _generate_pip_args(obj, *args)
|
||||
dic = dict(zip(args_names, args_list))
|
||||
key = generate_key(phase, dic)
|
||||
|
@ -47,14 +48,11 @@ def test_softmax_relu():
|
|||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
softmax = P.Softmax()
|
||||
relu = P.ReLU()
|
||||
def pattern(x):
|
||||
x = softmax(x)
|
||||
return x
|
||||
def target(x):
|
||||
x = relu(x)
|
||||
return x
|
||||
x = AnyPattern()
|
||||
softmax_pattern = IsPrimTypeOf(P.Softmax())
|
||||
pattern = CallWith(softmax_pattern, inputs=[x])
|
||||
relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
|
||||
target = CallWith(relu_pattern, inputs=[x])
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
|
@ -62,3 +60,128 @@ def test_softmax_relu():
|
|||
ppm.unregiste(softmax_relu_pass)
|
||||
assert "ReLU" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
def test_isin_pattern():
|
||||
"""
|
||||
Test IsIn pattern which expresses the IsIn/OneOf semantics.
|
||||
"""
|
||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
x = AnyPattern()
|
||||
softmax_pattern = IsPrimTypeOf(P.Softmax())
|
||||
call_softmax = CallWith(softmax_pattern, inputs=[x])
|
||||
relu_pattern = IsPrimTypeOf(P.ReLU())
|
||||
call_relu = CallWith(relu_pattern, inputs=[x])
|
||||
|
||||
pattern = IsIn([call_softmax, call_relu])
|
||||
relu6_pattern = IsPrimTypeOf(P.ReLU6(), should_replace=False)
|
||||
target = CallWith(relu6_pattern, inputs=[x])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(softmax_relu_pass)
|
||||
assert "ReLU6" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
def test_isnot_pattern_0():
|
||||
"""
|
||||
Test IsNot pattern which expresses the IsNot semantics.
|
||||
Case: IsNot pass failed to match
|
||||
"""
|
||||
class ConvBN(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ConvBN, self).__init__()
|
||||
self.conv = P.Conv2D(32, 3)
|
||||
self.conv_weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
|
||||
self.scale = Tensor(np.ones([32]), mindspore.float32)
|
||||
self.bias = Tensor(np.ones([32]), mindspore.float32)
|
||||
self.mean = Tensor(np.ones([32]), mindspore.float32)
|
||||
self.variance = Tensor(np.ones([32]), mindspore.float32)
|
||||
self.bn = P.BatchNorm()
|
||||
def construct(self, x):
|
||||
x = self.conv(x, self.conv_weight)
|
||||
x = self.bn(x, self.scale, self.bias, self.mean, self.variance)
|
||||
return x
|
||||
inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
|
||||
conv_bn_model = ConvBN()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def single_bn_pass():
|
||||
"""
|
||||
Sub a BN which does NOT take Conv as inputs to ReLU6.
|
||||
"""
|
||||
conv2d_prim = IsPrimTypeOf("Conv2D")
|
||||
conv2d = CallWith(conv2d_prim)
|
||||
pattern_0 = IsNot(conv2d)
|
||||
pattern = CallWith(P.BatchNorm(), inputs=[pattern_0])
|
||||
target = CallWith(P.ReLU6(), inputs=[pattern_0])
|
||||
return pattern, target
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def bn_pass():
|
||||
"""
|
||||
Sub a BN to Softmax.
|
||||
"""
|
||||
bn = P.BatchNorm()
|
||||
pattern = CallWith(bn)
|
||||
softmax = P.Softmax()
|
||||
target = CallWith(softmax, should_replace=False)
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(single_bn_pass)
|
||||
ppm.unregiste(bn_pass)
|
||||
assert "ReLU6" not in transformed_repr
|
||||
assert "Softmax" in transformed_repr
|
||||
|
||||
def test_isnot_pattern_1():
|
||||
"""
|
||||
Test IsNot pattern which expresses the IsNot semantics.
|
||||
Case: IsNot pattern matches with the graph
|
||||
"""
|
||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def single_bn_pass():
|
||||
"""
|
||||
Sub a BN which does NOT take MatMul as inputs to ReLU6.
|
||||
"""
|
||||
matmul = IsPrimTypeOf("MatMul")
|
||||
pattern_0 = IsNot(matmul)
|
||||
softmax = P.Softmax()
|
||||
pattern = CallWith(softmax, inputs=[pattern_0])
|
||||
relu6 = P.ReLU6()
|
||||
target = CallWith(relu6, inputs=[pattern_0], should_replace=False)
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(single_bn_pass)
|
||||
assert "ReLU6" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
def test_newtensor_pattern():
|
||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_addn_pass():
|
||||
x = AnyPattern()
|
||||
softmax = P.Softmax()
|
||||
pattern = CallWith(softmax, inputs=[x])
|
||||
|
||||
weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
|
||||
new_weight = NewTensor(weight_tensor)
|
||||
addn_ops = P.AddN()
|
||||
target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False)
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(softmax_addn_pass)
|
||||
assert "AddN" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
|
Loading…
Reference in New Issue