python pass pattern renaming and interface tweaking

BowenK 2020-08-25 21:34:16 +08:00
parent 3a16925fa2
commit 641d12d6d9
9 changed files with 234 additions and 285 deletions
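In short: the Python-facing pattern classes are renamed (IsIn -> OneOf, IsPrimTypeOf -> Prim, CallWith -> Call, IsNot -> NoneOf, AnyPattern -> Any), the should_replace flag is dropped from the whole Pattern hierarchy, the pipeline-phase option disappears from pass registration, and Imm gains a working match(). A minimal pass written against the new interface, adapted from the updated tests below (the model under rewrite is illustrative):

    from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass
    from mindspore.graph_utils.graph_pattern import Any, Call
    from mindspore.ops import operations as P

    @registe_pass(run_only_once=True)            # no phase argument any more
    def softmax_relu_pass():
        x = Any()                                # was AnyPattern()
        pattern = Call(P.Softmax(), inputs=[x])  # was CallWith(...)
        target = Call(P.ReLU(), inputs=[x])      # no should_replace flag
        return pattern, target

    unregiste_pass(softmax_relu_pass)            # phase argument removed here too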

View File

@@ -21,25 +21,23 @@ namespace opt {
 namespace python_pass {
 int Pattern::g_id_ = 0;
-MatchResultPtr IsPrimTypeOf::match(const AnfNodePtr &node) {
+MatchResultPtr Prim::match(const AnfNodePtr &node) {
   if (!IsValueNode<Primitive>(node)) {
     return nullptr;
   }
   MatchResultPtr res = std::make_shared<MatchResult>();
-  if (IsValueNode<Primitive>(node)) {
   // iterate over all primitives
   for (auto &iter : primitives_) {
     if (IsPrimitive(node, iter) || iter->name() == "*") {
       matched_prim_ = iter;
-      res->add_entry(shared_from_base<IsPrimTypeOf>(), node);
+      res->add_entry(shared_from_base<Prim>(), node);
       return res;
     }
   }
-  }
   return nullptr;
 }
-MatchResultPtr CallWith::match(const AnfNodePtr &node) {
+MatchResultPtr Call::match(const AnfNodePtr &node) {
   if (!IsPrimitiveCNode(node)) {
     return nullptr;
   }
@@ -71,7 +69,7 @@ MatchResultPtr CallWith::match(const AnfNodePtr &node) {
   }
   // If inputs is not specified, add node without looking into its inputs
   if (p_inputs_size == 0) {
-    res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
+    res->add_entry(shared_from_base<Call>(), cnode->input(0));
     return res;
   }
   bool failed = false;
@@ -86,24 +84,24 @@ MatchResultPtr CallWith::match(const AnfNodePtr &node) {
     res->merge(input_match_result);
   }
   if (!failed) {
-    res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
+    res->add_entry(shared_from_base<Call>(), cnode->input(0));
     return res;
   }
   return nullptr;
 }
-MatchResultPtr IsIn::match(const AnfNodePtr &node) {
+MatchResultPtr OneOf::match(const AnfNodePtr &node) {
   for (auto &iter : patterns_) {
     auto res = iter->match(node);
     if (res != nullptr) {
-      res->add_entry(shared_from_base<IsIn>(), node);
+      res->add_entry(shared_from_base<OneOf>(), node);
       return res;
     }
   }
   return nullptr;
 }
-MatchResultPtr IsNot::match(const AnfNodePtr &node) {
+MatchResultPtr NoneOf::match(const AnfNodePtr &node) {
   for (auto &iter : patterns_) {
     auto res = iter->match(node);
     if (res != nullptr) {
@@ -111,16 +109,33 @@ MatchResultPtr IsNot::match(const AnfNodePtr &node) {
     }
   }
   auto res = std::make_shared<MatchResult>();
-  res->add_entry(shared_from_base<IsNot>(), node);
+  res->add_entry(shared_from_base<NoneOf>(), node);
   return res;
 }
-MatchResultPtr AnyPattern::match(const AnfNodePtr &node) {
+MatchResultPtr Any::match(const AnfNodePtr &node) {
   MatchResultPtr res = std::make_shared<MatchResult>();
-  res->add_entry(shared_from_base<AnyPattern>(), node);
+  res->add_entry(shared_from_base<Any>(), node);
   return res;
 }
+MatchResultPtr Imm::match(const AnfNodePtr &node) {
+  if (!IsValueNode<Int32Imm>(node)) {
+    return nullptr;
+  }
+  // Check value
+  auto value_node = node->cast<ValueNodePtr>();
+  MS_EXCEPTION_IF_NULL(value_node);
+  auto value_ptr = value_node->value()->cast<Int32ImmPtr>();
+  MS_EXCEPTION_IF_NULL(value_ptr);
+  if ((int32_t)value_ptr->value() == value_) {
+    MatchResultPtr res = std::make_shared<MatchResult>();
+    res->add_entry(shared_from_base<Imm>(), node);
+    return res;
+  }
+  return nullptr;
+}
 AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
   auto entry = match_result_.find(pattern);
   if (entry == match_result_.end()) {
@@ -140,20 +155,20 @@ void MatchResult::merge(const MatchResultPtr &other_result) {
 REGISTER_PYBIND_DEFINE(
   Pattern, ([](const py::module *m) {
     (void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
-    (void)py::class_<IsIn, std::shared_ptr<IsIn>, Pattern>(*m, "IsIn_").def(py::init<vector<PatternPtr>>());
-    (void)py::class_<IsPrimTypeOf, std::shared_ptr<IsPrimTypeOf>, Pattern>(*m, "IsPrimTypeOf_", py::dynamic_attr())
-      .def(py::init<vector<PrimitivePyPtr>, string, bool>())
-      .def(py::init<vector<string>, string, bool>());
-    (void)py::class_<CallWith, std::shared_ptr<CallWith>, Pattern>(*m, "CallWith_")
-      .def(py::init<PatternPtr, vector<PatternPtr>, bool>())
-      .def(py::init<PrimitivePyPtr, vector<PatternPtr>, bool>())
-      .def(py::init<string, vector<PatternPtr>, bool>());
-    (void)py::class_<IsNot, std::shared_ptr<IsNot>, Pattern>(*m, "IsNot_").def(py::init<vector<PatternPtr>>());
-    (void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>());
+    (void)py::class_<OneOf, std::shared_ptr<OneOf>, Pattern>(*m, "OneOf_").def(py::init<vector<PatternPtr>>());
+    (void)py::class_<Prim, std::shared_ptr<Prim>, Pattern>(*m, "Prim_", py::dynamic_attr())
+      .def(py::init<vector<PrimitivePyPtr>, string>())
+      .def(py::init<vector<string>, string>());
+    (void)py::class_<Call, std::shared_ptr<Call>, Pattern>(*m, "Call_")
+      .def(py::init<PatternPtr, vector<PatternPtr>>())
+      .def(py::init<PrimitivePyPtr, vector<PatternPtr>>())
+      .def(py::init<string, vector<PatternPtr>>());
+    (void)py::class_<NoneOf, std::shared_ptr<NoneOf>, Pattern>(*m, "NoneOf_").def(py::init<vector<PatternPtr>>());
+    (void)py::class_<Any, std::shared_ptr<Any>, Pattern>(*m, "Any").def(py::init<>());
     (void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
      .def(py::init<tensor::TensorPtr>());
     (void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_")
-      .def(py::init<string, tensor::TensorPtr, bool, bool, bool>());
+      .def(py::init<string, tensor::TensorPtr, bool, bool>());
     (void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int>());
   }));
 }  // namespace python_pass
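With Imm::match implemented above, an integer immediate can now participate in matching instead of unconditionally returning nullptr: it matches a value node carrying an Int32Imm equal to its own value. A hypothetical Python-side sketch of what this enables (the tests in this commit still only use Imm inside targets):

    # Hypothetical: match tuple_getitem(x, 0) specifically.
    x = Any()
    idx = Imm(0)                                 # matches only the Int32Imm 0
    pattern = Call("tuple_getitem", inputs=[x, idx])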

View File

@@ -36,10 +36,10 @@ class MatchResult;
 using MatchResultPtr = std::shared_ptr<MatchResult>;
 class Pattern;
 using PatternPtr = std::shared_ptr<Pattern>;
-class IsPrimTypeOf;
-using IsPrimTypeOfPtr = std::shared_ptr<IsPrimTypeOf>;
-class CallWith;
-using CallWithPtr = std::shared_ptr<CallWith>;
+class Prim;
+using PrimPtr = std::shared_ptr<Prim>;
+class Call;
+using CallPtr = std::shared_ptr<Call>;
 class NewTensor;
 using NewTensorPtr = std::shared_ptr<NewTensor>;
 class NewParameter;
@@ -58,8 +58,6 @@ class Pattern : public Base {
   virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; }
   string unique_name() const { return unique_name_; }
   vector<PatternPtr> inputs() { return inputs_; }
-  bool should_replace() { return should_replace_; }
-  void set_should_replace(bool should_replace) { should_replace_ = should_replace; }
   virtual void reset() {}

  protected:
@@ -67,7 +65,6 @@ class Pattern : public Base {
   // NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed
   string unique_name_;
   vector<PatternPtr> inputs_;
-  bool should_replace_ = true;
 };

 struct PatternEqual {
@@ -85,70 +82,61 @@ struct PatternHasher {
   }
 };

-class IsPrimTypeOf : public Pattern {
+class Prim : public Pattern {
  public:
-  IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); }
-  ~IsPrimTypeOf() = default;
-  IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace)
-      : primitives_(prims), name_(name), matched_prim_(nullptr) {
-    unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
-    should_replace_ = should_replace;
-    if (!should_replace) {
-      matched_prim_ = prims[0];
-    }
+  Prim() { unique_name_ = std::to_string(g_id_++); }
+  ~Prim() = default;
+  Prim(vector<PrimitivePyPtr> prims, string name) : primitives_(prims), name_(name) {
+    unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
+    // Default using the first prim to build target
+    matched_prim_ = primitives_[0];
   }
-  IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) {
-    unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
+  Prim(vector<string> types, string name) : types_(types), name_(name) {
+    unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
     // Make primitives_
     for (auto &iter : types) {
       primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
     }
-    should_replace_ = should_replace;
-    if (!should_replace) {
-      matched_prim_ = primitives_[0];
-    }
+    // Default using the first prim to build target
+    matched_prim_ = primitives_[0];
   }
-  MS_DECLARE_PARENT(IsPrimTypeOf, Pattern);
+  MS_DECLARE_PARENT(Prim, Pattern);
   MatchResultPtr match(const AnfNodePtr &node) override;
   PrimitivePyPtr matched_primitive() { return matched_prim_; }
   void reset() override {
-    if (should_replace_) {
-      matched_prim_ = nullptr;
-    }
+    // Init before reset
+    MS_EXCEPTION_IF_NULL(matched_prim_);
+    matched_prim_ = primitives_[0];
   }

  private:
   vector<string> types_;
   vector<PrimitivePyPtr> primitives_;
   string name_;
-  PrimitivePyPtr matched_prim_;
+  PrimitivePyPtr matched_prim_{nullptr};
 };

-class CallWith : public Pattern {
+class Call : public Pattern {
  public:
-  CallWith() { unique_name_ = std::to_string(g_id_++); }
-  ~CallWith() = default;
-  CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) {
+  Call() { unique_name_ = std::to_string(g_id_++); }
+  ~Call() = default;
+  Call(PatternPtr prim_pattern, vector<PatternPtr> inputs) {
     // NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
     prim_pattern_ = prim_pattern;
-    unique_name_ = std::to_string(g_id_++) + "CallWithPattern_" + prim_pattern->unique_name();
+    unique_name_ = std::to_string(g_id_++) + "Call_" + prim_pattern->unique_name();
     inputs_ = inputs;
-    // NOTE: should_replace_ is overrided by it prim_pattern(if exists) silently.
-    should_replace_ = prim_pattern->should_replace();
   }
-  CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) {
+  Call(PrimitivePyPtr prim, vector<PatternPtr> inputs) {
     prim_ = prim;
-    unique_name_ = std::to_string(g_id_++) + "CallWithPrim_" + prim_->ToString();
+    unique_name_ = std::to_string(g_id_++) + "Call_" + prim_->ToString();
     inputs_ = inputs;
-    should_replace_ = should_replace;
   }
-  CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) {
+  Call(string prim_str, vector<PatternPtr> inputs) {
     prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
-    unique_name_ = std::to_string(g_id_++) + "CallWithStr_" + prim_->ToString();
+    unique_name_ = std::to_string(g_id_++) + "CallStr_" + prim_->ToString();
     inputs_ = inputs;
-    should_replace_ = should_replace;
   }
-  MS_DECLARE_PARENT(CallWith, Pattern);
+  MS_DECLARE_PARENT(Call, Pattern);
   MatchResultPtr match(const AnfNodePtr &node) override;
   PrimitivePtr prim_value() { return prim_; }
   PatternPtr prim_pattern() { return prim_pattern_; }
@@ -160,45 +148,45 @@ class CallWith : public Pattern {
   string name_;
 };

-class IsIn : public Pattern {
+class OneOf : public Pattern {
  public:
-  IsIn() { unique_name_ = std::to_string(g_id_++); }
-  ~IsIn() = default;
-  explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) {
-    unique_name_ = std::to_string(g_id_++) + "IsIn";
+  OneOf() { unique_name_ = std::to_string(g_id_++); }
+  ~OneOf() = default;
+  explicit OneOf(vector<PatternPtr> patterns) : patterns_(patterns) {
+    unique_name_ = std::to_string(g_id_++) + "OneOf";
     for (auto &iter : patterns) {
       unique_name_ = unique_name_ + "_" + iter->unique_name();
     }
   }
-  MS_DECLARE_PARENT(IsIn, Pattern);
+  MS_DECLARE_PARENT(OneOf, Pattern);
   MatchResultPtr match(const AnfNodePtr &node) override;

  private:
   vector<PatternPtr> patterns_;
 };

-class IsNot : public Pattern {
+class NoneOf : public Pattern {
  public:
-  IsNot() { unique_name_ = std::to_string(g_id_++); }
-  ~IsNot() = default;
-  explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) {
-    unique_name_ = std::to_string(g_id_++) + "IsNot";
+  NoneOf() { unique_name_ = std::to_string(g_id_++); }
+  ~NoneOf() = default;
+  explicit NoneOf(vector<PatternPtr> patterns) : patterns_(patterns) {
+    unique_name_ = std::to_string(g_id_++) + "NoneOf";
     for (auto &iter : patterns) {
       unique_name_ = unique_name_ + "_" + iter->unique_name();
     }
   }
-  MS_DECLARE_PARENT(IsNot, Pattern);
+  MS_DECLARE_PARENT(NoneOf, Pattern);
   MatchResultPtr match(const AnfNodePtr &node) override;

  private:
   vector<PatternPtr> patterns_;
 };

-class AnyPattern : public Pattern {
+class Any : public Pattern {
  public:
-  AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; }
-  ~AnyPattern() = default;
-  MS_DECLARE_PARENT(AnyPattern, Pattern);
+  Any() { unique_name_ = std::to_string(g_id_++) + "_Any"; }
+  ~Any() = default;
+  MS_DECLARE_PARENT(Any, Pattern);
   MatchResultPtr match(const AnfNodePtr &node) override;
 };
@@ -207,7 +195,6 @@ class NewTensor : public Pattern {
   NewTensor() { unique_name_ = std::to_string(g_id_++); }
   ~NewTensor() = default;
   explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) {
-    should_replace_ = false;
     unique_name_ = std::to_string(g_id_++) + "NewTensor";
   }
   MS_DECLARE_PARENT(NewTensor, Pattern);
@@ -223,10 +210,8 @@
 class NewParameter : public Pattern {
  public:
   NewParameter() { unique_name_ = std::to_string(g_id_++); }
-  explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel,
-                        bool should_replace)
+  explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel)
       : para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
-    should_replace_ = should_replace;
     unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
     // clone input tensor
     default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
@@ -243,11 +228,14 @@ class NewParameter : public Pattern {
   bool built() { return built_; }
   void set_built(bool built) { built_ = built; }
   void reset() override { built_ = false; }
+  bool should_last() { return last_across_passes_; }
+  void set_last(bool last) { last_across_passes_ = last; }

  private:
   string para_name_;
   bool requires_grad_;
   bool layerwise_parallel_;
+  bool last_across_passes_{false};
   bool built_;
   tensor::TensorPtr default_tensor_;
 };
@@ -255,13 +243,9 @@ class NewParameter : public Pattern {
 class Imm : public Pattern {
  public:
   Imm() { unique_name_ = std::to_string(g_id_++); }
-  explicit Imm(int value) : value_(value) {
-    should_replace_ = false;
-    unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value);
-  }
+  explicit Imm(int value) : value_(value) { unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value); }
   MS_DECLARE_PARENT(Imm, Pattern);
-  // NOTE: Doesn't support Imm in src pattern currently.
-  MatchResultPtr match(const AnfNodePtr &node) override { return nullptr; }
+  MatchResultPtr match(const AnfNodePtr &node) override;
   int value() { return value_; }

  private:
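With should_replace gone, Prim now always seeds matched_prim_ with the first primitive in its list, and reset() restores that default rather than clearing it. The practical effect, sketched below under that reading of the header: a Prim that matched in the source pattern rebuilds from the captured primitive, while a target-only Prim falls back to the first one listed (the '|' syntax is from the Python docstring):

    act = Prim("ReLU|ReLU6")              # first prim is the default
    pattern = Call(act, inputs=[Any()])   # match() records which one was hit
    target = Call(act, inputs=[Any()])    # rebuilds the matched primitive,
                                          # or ReLU if it never matched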

View File

@@ -80,7 +80,7 @@ bool IsTraversable(const AnfNodePtr &node) {
 AnfNodePtr BuildPrimitive(const PatternPtr &pattern, const MatchResultPtr &res) {
   // Build up AnfNode from primitive
-  auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>();
+  auto prim_pattern = pattern->cast<PrimPtr>();
   MS_EXCEPTION_IF_NULL(prim_pattern);
   PrimitivePyPtr prim = prim_pattern->matched_primitive();
   MS_EXCEPTION_IF_NULL(prim);
@@ -98,13 +98,13 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
 }

 AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) {
-  auto call_with_pattern = pattern->cast<CallWithPtr>();
-  MS_EXCEPTION_IF_NULL(call_with_pattern);
-  auto prim = call_with_pattern->prim_value();
+  auto call_pattern = pattern->cast<CallPtr>();
+  MS_EXCEPTION_IF_NULL(call_pattern);
+  auto prim = call_pattern->prim_value();
   if (prim != nullptr) {
     return std::make_shared<ValueNode>(prim);
   }
-  auto prim_pattern = call_with_pattern->prim_pattern();
+  auto prim_pattern = call_pattern->prim_pattern();
   MS_EXCEPTION_IF_NULL(prim_pattern);
   return ProcessSinglePattern(prim_pattern, res, fg);
 }
@@ -152,45 +152,35 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
 }

 AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
-  if (pattern->should_replace()) {
-    // Find replacement in the MatchResult
-    auto target_node = res->get_node(pattern);
-    if (target_node == nullptr) {
-      // If it's base pattern(in contrast to complex pattern like CallWith/IsIn/IsNot), raise runtime exception.
-      if (pattern->isa<IsPrimTypeOf>() || pattern->isa<NewTensor>() || pattern->isa<NewParameter>()) {
-        MS_LOG(EXCEPTION) << "Cannot find target node, pattern: " + pattern->unique_name() + "\n";
-        return nullptr;
-      }
-      // Try to build this pattern and add to MatchResult, since this pattern is defined inside target
-      auto new_node = BuildTarget(pattern, func_graph, res);
-      if (new_node == nullptr) {
-        MS_LOG(EXCEPTION) << "Try to build pattern node but FAILED. pattern: " + pattern->unique_name() + "\n";
-      }
-      return new_node;
-    }
-    if (pattern->isa<NewParameter>()) {
-      return target_node;
-    }
-    return target_node;
-  }
+  auto target_node = res->get_node(pattern);
+  if (target_node != nullptr) {
+    // If pattern is NewParameter, check whether it shouldn't last and is not built
+    auto new_para = pattern->cast<NewParameterPtr>();
+    if (new_para == nullptr || new_para->should_last() || new_para->built()) {
+      return target_node;
+    }
+  }
   // Build up new node from pattern
-  if (pattern->isa<IsPrimTypeOf>()) {
+  if (pattern->isa<Prim>()) {
     return BuildPrimitive(pattern, res);
   } else if (pattern->isa<NewTensor>()) {
     return BuildNewTensor(pattern, res);
-  } else if (pattern->isa<CallWith>()) {
+  } else if (pattern->isa<Call>()) {
     return BuildPrimitiveValueNode(pattern, res, func_graph);
   } else if (pattern->isa<NewParameter>()) {
     return BuildNewParameter(pattern, res, func_graph);
   } else if (pattern->isa<Imm>()) {
     return BuildImmNode(pattern, res);
-  } else {
-    MS_LOG(EXCEPTION) << "Cannot find or build target node, pattern: " + pattern->unique_name() + "\n";
-    return nullptr;
   }
   return nullptr;
 }

 AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
                                            const FuncGraphPtr &func_graph) {
-  if (pattern->isa<CallWith>()) {
+  if (pattern->isa<Call>()) {
     return BuildPrimitiveValueNode(pattern, res, func_graph);
   }
   return nullptr;
@@ -269,16 +259,16 @@ void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor
 }

 void Reset(PatternPtr pattern) {
-  if (pattern->isa<IsPrimTypeOf>()) {
-    auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>();
+  if (pattern->isa<Prim>()) {
+    auto prim_pattern = pattern->cast<PrimPtr>();
     prim_pattern->reset();
     return;
   } else if (pattern->isa<NewParameter>()) {
     auto new_param_pattern = pattern->cast<NewParameterPtr>();
     new_param_pattern->reset();
     return;
-  } else if (pattern->isa<CallWith>()) {
-    auto call_with_pattern = pattern->cast<CallWithPtr>();
+  } else if (pattern->isa<Call>()) {
+    auto call_with_pattern = pattern->cast<CallPtr>();
     for (auto sub_pattern : call_with_pattern->inputs()) {
       Reset(sub_pattern);
     }
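The rewritten ProcessSinglePattern reuses a node captured in the MatchResult whenever one exists, except for a NewParameter that is neither marked should_last nor already built; that case falls through to the build branch and a fresh Parameter is constructed. Roughly, the two NewParameter behaviors from the Python side (a sketch using names from the tests in this commit):

    # Built fresh each time the target is constructed (should_last stays False):
    new_para = NewParameter("Merlin", default_tensor)
    target = Call("make_tuple", inputs=[pattern, new_para])

    # Built once and reused across passes: GenNewParameter calls set_last(true).
    gen_new_parameter(new_para)
    # ... later, undo it:
    cancel_new_parameter(new_para)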

View File

@@ -49,8 +49,9 @@ PyPassManager::PyPassManager() {
 }

 void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
-                            Phase phase, bool run_only_once) {
-  auto cur_pg = GetPassGroup(phase);
+                            bool run_only_once) {
+  // NOTE: remove phase option to avoid unnecessary confusion.
+  auto cur_pg = GetPassGroup(Phase::OPT);
   MS_EXCEPTION_IF_NULL(cur_pg);
   cur_pg->SetRunOnlyOnce(run_only_once);
   MS_EXCEPTION_IF_NULL(pattern);
@@ -60,8 +61,9 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
   cur_pg->AddPass(new_pass);
 }

-void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
-  auto cur_pm = GetPassGroup(phase);
+void PyPassManager::Unregiste(const std::string &pass_name) {
+  // NOTE: remove phase option to avoid unnecessary confusion.
+  auto cur_pm = GetPassGroup(Phase::OPT);
   MS_EXCEPTION_IF_NULL(cur_pm);
   if (!cur_pm->DeletePass(pass_name)) {
     MS_LOG(WARNING) << "No such pass : " + pass_name + "\n";
@@ -70,7 +72,6 @@ void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
 void PyPassManager::GenNewParameter(const PatternPtr &parameter) {
   MS_EXCEPTION_IF_NULL(parameter);
-  // Add new parameter after resolve
   // NOTE: Add NewParameter at early stage will cause CSE problems
   auto cur_pg = GetPassGroup(Phase::OPT);
   MS_EXCEPTION_IF_NULL(cur_pg);
@@ -78,7 +79,7 @@ void PyPassManager::GenNewParameter(const PatternPtr &parameter) {
   auto new_para_pattern = parameter->cast<NewParameterPtr>();
   MS_EXCEPTION_IF_NULL(new_para_pattern);
   auto pass_name = new_para_pattern->para_name();
-  parameter->set_should_replace(false);
+  new_para_pattern->set_last(true);
   auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true);
   cur_pg->AddPass(new_pass);
 }

View File

@@ -53,16 +53,17 @@ class PyPassManager {
   static PyPassManagerPtr GetInstance();
   virtual ~PyPassManager() = default;
   void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
-               Phase phase = Phase::RESOLVE, bool run_only_once = false);
-  void Unregiste(const std::string &pass_name, Phase phase);
+               bool run_only_once = false);
+  void Unregiste(const std::string &pass_name);
   void GenNewParameter(const PatternPtr &parameter);
   PassGroupPtr GetPassGroup(Phase phase);
-  void ClearRes();
   MatchResultPtr GetMatchResult() { return res_; }
   void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
   bool ShouldRenorm() { return should_renorm_; }
   void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
   pipeline::ResourcePtr GetResource() { return resource_; }
+  void ClearRes();
+  void ClearPipelineRes() { resource_ = nullptr; }

  private:
   bool should_renorm_ = true;

View File

@@ -477,6 +477,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
   // save the run graph func to MsPipeLine
   SaveCompiledGraph(phase_s);

+  opt::python_pass::PyPassManager::GetInstance()->ClearPipelineRes();
   resource->Clean();
   // Reclaim all resource used by optimizer;
   ReclaimOptimizer();

View File

@@ -15,50 +15,43 @@
 """Patterns for describing graphs"""
 from mindspore.ops import Primitive
 from mindspore.common.tensor import Tensor
-from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_,\
-    NewParameter_, Imm
+from mindspore._c_expression import Pattern, OneOf_, Prim_, Call_, NoneOf_, Any, NewTensor_, NewParameter_, Imm

 __all__ = [
-    "IsIn",
-    "IsPrimTypeOf",
-    "CallWith",
-    "IsNot",
-    "AnyPattern",
+    "OneOf",
+    "Prim",
+    "Call",
+    "NoneOf",
+    "Any",
     "NewTensor",
     "NewParameter",
     "Imm"
 ]

-class IsIn(IsIn_):
+class OneOf(OneOf_):
     r"""
     Express a pattern which allows a list of patterns.
     """
-    def __init__(self, patterns=None, should_replace=True):
+    def __init__(self, patterns=None):
         r"""
         Args:
-            patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`],
+            patterns(Union[:class:`mindspore.graph_utils.graph_pattern`,
+                     tuple[:class:`mindspore.graph_utils.graph_pattern`],
                      list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns,
                      each element should be one of the exposed Pattern instance.
-            should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.

         Raises:
-            ValueError: raise if should_replace is False
             TypeError: raise type error for invalid inputs.
         """
-        if not should_replace:
-            raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \
-                              its sub-pattern instead.")
         self.patterns = patterns
-        if patterns is None:
-            IsIn_.__init__(self, ())
-        elif isinstance(patterns, Pattern):
-            IsIn_.__init__(self, [patterns])
+        if isinstance(patterns, Pattern):
+            OneOf_.__init__(self, [patterns])
         elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
-            IsIn_.__init__(self, patterns)
+            OneOf_.__init__(self, patterns)
         else:
             raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")

-class IsPrimTypeOf(IsPrimTypeOf_):
+class Prim(Prim_):
     r"""
     Express a pattern of certain primitive type(s).
@@ -66,7 +59,7 @@ class IsPrimTypeOf(IsPrimTypeOf_):
     This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
     please refer to CallWith pattern.
     """
-    def __init__(self, types, name=None, should_replace=True):
+    def __init__(self, types, name=None):
         r"""
         Args:
             types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
@@ -77,9 +70,6 @@ class IsPrimTypeOf(IsPrimTypeOf_):
                 2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
                 It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
             name (str): name of the pattern, optional. Default: None.
-            should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
-                used when building the replacement target node. Use captured node if True, build from scratch otherwise.
-                Default: True.

         Raises:
             TypeError: raise type error for invalid argument.
@@ -103,13 +93,13 @@ class IsPrimTypeOf(IsPrimTypeOf_):
             self.types = types
         else:
             raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
-        IsPrimTypeOf_.__init__(self, self.types, self.name, should_replace)
+        Prim_.__init__(self, self.types, self.name)

-class CallWith(CallWith_):
+class Call(Call_):
     r"""
     Express a primitive CNode.
     """
-    def __init__(self, prim_pattern, inputs=None, should_replace=True):
+    def __init__(self, prim_pattern, inputs=None):
         r"""
         Args:
             prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`,
@@ -118,9 +108,6 @@ class CallWith(CallWith_):
                    tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
                 Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input
                 patterns should be of right order and each element should be one of the exposed Pattern instance.
-            should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
-                used when building the replacement target node. Use captured node if True, build from scratch otherwise.
-                Default: True.

         Raises:
             TypeError: raise type error for invalid argument.
@@ -135,36 +122,31 @@ class CallWith(CallWith_):
             self.inputs = inputs
         else:
             raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
-        CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace)
+        Call_.__init__(self, self.prim_pattern, self.inputs)

-class IsNot(IsNot_):
+class NoneOf(NoneOf_):
     r"""
     Express a pattern which forbids a list of patterns.

     NOTE:
-        IsNot pattern should not be the root pattern.
+        NoneOf pattern should not be the root pattern.
     """
-    def __init__(self, patterns=None, should_replace=True):
+    def __init__(self, patterns=None):
         r"""
         Args:
             patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element
                 should be one of the exposed Pattern instance.
-            should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.

         Raises:
-            ValueError: raise if should_replace is False.
             TypeError: raise type error for invalid argument.
         """
-        if not should_replace:
-            raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \
-                              its sub-pattern instead.")
         self.patterns = patterns
         if patterns is None:
-            IsNot_.__init__(self, ())
+            NoneOf_.__init__(self, ())
         elif isinstance(patterns, Pattern):
-            IsNot_.__init__(self, [patterns])
+            NoneOf_.__init__(self, [patterns])
         elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
-            IsNot_.__init__(self, patterns)
+            NoneOf_.__init__(self, patterns)
         else:
             raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
@@ -172,18 +154,14 @@ class NewTensor(NewTensor_):
     r"""
     New Tensor to be used in the target.
     """
-    def __init__(self, input_tensor, should_replace=False):
+    def __init__(self, input_tensor):
         r"""
         Args:
             input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
-            should_replace(bool): added this for interface consistency. NewTensor should only appear in the target.

         Raises:
-            ValueError: raise if should_replace is True
             TypeError: raise type error for invalid argument.
         """
-        if should_replace:
-            raise ValueError("NewTensor should only appear in the target, thus should_replace can only be False.")
         self.input_tensor = input_tensor
         if isinstance(input_tensor, Tensor):
             NewTensor_.__init__(self, input_tensor)
@@ -194,15 +172,13 @@ class NewParameter(NewParameter_):
     r"""
     New Parameter to be used in the target.
     """
-    def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False, should_replace=False):
+    def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False):
         r"""
         Args:
             para_name(str): name for the new Parameter
             default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
             requires_grad(bool): True if the parameter requires gradient. Default: True
             layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
-            should_replace(bool): gen new parameter once and replace after if set to be true; otherwise build a new
-                parameter everytime a pass target got built. Default: False

         Raises:
             TypeError: raise type error for invalid argument.
@@ -211,12 +187,11 @@ class NewParameter(NewParameter_):
         self.default_tensor = default_tensor
         self.requires_grad = requires_grad
         self.layerwise_parallel = layerwise_parallel
-        self.should_replace = should_replace
         if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
-           isinstance(layerwise_parallel, bool) and isinstance(should_replace, bool):
+           isinstance(layerwise_parallel, bool):
             NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
-                                   self.layerwise_parallel, self.should_replace)
+                                   self.layerwise_parallel)
         else:
             raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
-                              layerwise_parallel(bool) should_replace(bool) got : {para_name}, {default_tensor}, \
-                              {requires_grad}, {layerwise_parallel}, {should_replace}")
+                              layerwise_parallel(bool) got : {para_name}, {default_tensor}, \
+                              {requires_grad}, {layerwise_parallel}")

View File

@@ -15,7 +15,7 @@
 """Python pass register"""
 from inspect import isfunction
 from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
-from mindspore._c_expression import PyPassManager_, phase
+from mindspore._c_expression import PyPassManager_

 __all__ = [
@@ -30,21 +30,16 @@ class PyPassManager(PyPassManager_):
     Used to registe and unregiste python passes which can be used to alter graphs.

     Args:
-        pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
         run_only_once (bool): Specify whether or not to run pass only once. Default: False.
-        multigraph (bool): Whether or not the pattern exists across graphs. Default: True.

     Raises:
         TypeError: If argument has invalid type.
     """
-    def __init__(self, pipeline_phase=phase.opt, run_only_once=False):
-        if not isinstance(pipeline_phase, phase):
-            raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
+    def __init__(self, run_only_once=False):
         if not isinstance(run_only_once, bool):
             raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
-        PyPassManager_.__init__(self)
-        self.phase_ = pipeline_phase
         self.run_only_once_ = run_only_once
+        PyPassManager_.__init__(self)

     def registe(self, py_pass):
         if not isfunction(py_pass):
@@ -55,16 +50,14 @@ class PyPassManager(PyPassManager_):
             raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
         if not isinstance(target, Pattern):
             raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
-        super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_)
+        super().registe(pass_name, pattern, target, self.run_only_once_)

-    def unregiste(self, py_pass, pipeline_phase=phase.opt):
-        if not isinstance(pipeline_phase, phase):
-            raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
+    def unregiste(self, py_pass):
         if isinstance(py_pass, str):
-            super().unregiste(py_pass, pipeline_phase)
+            super().unregiste(py_pass)
             return
         if isfunction(py_pass):
-            super().unregiste(py_pass.__name__, pipeline_phase)
+            super().unregiste(py_pass.__name__)
             return
         raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
@@ -82,13 +75,11 @@ class PyPassManager(PyPassManager_):
             raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
         super().set_renorm(should_renorm)

-def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
+def registe_pass(run_only_once=False):
     """
     Registe python pass to specified pipeline phase which would be used in compilation.

     Args:
-        pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
-            registed. Support phase.resolve and phase.opt. Default: phase.opt.
         run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.

     Returns:
@@ -102,19 +93,17 @@ def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
     >>> target = IsPrimTypeOf("ReLU6")
     >>> return pattern, target
     """
-    return PyPassManager(pipeline_phase, run_only_once)
+    return PyPassManager(run_only_once)

-def unregiste_pass(py_pass, pipeline_phase=phase.opt):
+def unregiste_pass(py_pass):
     """
     Unregiste python pass.

     Args:
         py_pass(Union(str, function)): target python pass to unregiste.
-        pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
-            unregisted. Support phase.resolve and phase.opt. Default: phase.opt.
     """
     ppm = PyPassManager()
-    ppm.unregiste(py_pass, pipeline_phase)
+    ppm.unregiste(py_pass)
@@ -164,7 +153,14 @@ def cancel_new_parameter(pattern):
 def set_renorm(should_renorm):
     """
-    Set whether or not to do renorm after modified graph in python pass(es).
+    Set whether or not to do renormalization after modified graph in python pass(es).
+
+    Args:
+        should_renorm(bool): whether or not to do renormalization after modified graph in python pass(es).
+
+    NOTE:
+        This interface is mainly intended for testing modifying graph without worrying about its validity. Turn off
+        renormalization may BREAK the network.
     """
     ppm = PyPassManager()
     ppm.set_renorm(should_renorm)
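The expanded set_renorm docstring now spells out the argument and the risk; a minimal usage sketch:

    set_renorm(False)   # skip renormalization after pass rewrites; per the new
                        # NOTE this is for testing and may break the network
    # ... run the experimental pass ...
    set_renorm(True)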

View File

@@ -23,8 +23,7 @@ from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_
     cancel_new_parameter
 from mindspore.common.api import _generate_pip_args
 from mindspore._c_expression import generate_key, Executor_
-from mindspore.graph_utils.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor,\
-    NewParameter, Imm
+from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm

 context.set_context(mode=context.GRAPH_MODE)
@@ -50,11 +49,9 @@ def test_softmax_relu():
     @registe_pass(run_only_once=True)
     def softmax_relu_pass():
-        x = AnyPattern()
-        softmax_pattern = IsPrimTypeOf(P.Softmax())
-        pattern = CallWith(softmax_pattern, inputs=[x])
-        relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
-        target = CallWith(relu_pattern, inputs=[x])
+        x = Any()
+        pattern = Call(P.Softmax(), inputs=[x])
+        target = Call(P.ReLU(), inputs=[x])
         return pattern, target

     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
@@ -74,13 +71,13 @@ def test_softmax_relu_sigmoid():
     @registe_pass(run_only_once=True)
     def softmax_relu_pass():
-        x = AnyPattern()
-        softmax_pattern = IsPrimTypeOf(P.Softmax())
-        pattern = CallWith(softmax_pattern, inputs=[x])
-        sigmoid_pattern = IsPrimTypeOf(P.Sigmoid(), should_replace=False)
-        call_sigmoid = CallWith(sigmoid_pattern, [x])
-        relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
-        target = CallWith(relu_pattern, inputs=[call_sigmoid])
+        x = Any()
+        softmax_pattern = Prim(P.Softmax())
+        pattern = Call(softmax_pattern, inputs=[x])
+        sigmoid_pattern = Prim(P.Sigmoid())
+        call_sigmoid = Call(sigmoid_pattern, [x])
+        relu_pattern = Prim(P.ReLU())
+        target = Call(relu_pattern, inputs=[call_sigmoid])
         return pattern, target

     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
@@ -99,15 +96,15 @@ def test_isin_pattern_0():
     @registe_pass(run_only_once=True)
     def softmax_relu_pass():
-        x = AnyPattern()
-        softmax_pattern = IsPrimTypeOf(P.Softmax())
-        call_softmax = CallWith(softmax_pattern, inputs=[x])
-        relu_pattern = IsPrimTypeOf(P.ReLU())
-        call_relu = CallWith(relu_pattern, inputs=[x])
-        pattern = IsIn([call_softmax, call_relu])
-        relu6_pattern = IsPrimTypeOf(P.ReLU6(), should_replace=False)
-        target = CallWith(relu6_pattern, inputs=[x])
+        x = Any()
+        softmax_pattern = Prim(P.Softmax())
+        call_softmax = Call(softmax_pattern, inputs=[x])
+        relu_pattern = Prim(P.ReLU())
+        call_relu = Call(relu_pattern, inputs=[x])
+        pattern = OneOf([call_softmax, call_relu])
+        relu6_pattern = Prim(P.ReLU6())
+        target = Call(relu6_pattern, inputs=[x])
         return pattern, target

     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
     unregiste_pass(softmax_relu_pass)
@@ -123,18 +120,17 @@ def test_isin_pattern_1():
     @registe_pass(run_only_once=True)
     def softmax_neg_pass():
-        x = AnyPattern()
-        softmax_pattern = IsPrimTypeOf(P.Softmax())
-        call_softmax = CallWith(softmax_pattern, inputs=[x])
-        relu_pattern = IsPrimTypeOf(P.ReLU())
-        call_relu = CallWith(relu_pattern, inputs=[x])
-        pattern = IsIn([call_softmax, call_relu])
-        neg_ops = IsPrimTypeOf(P.Neg(), should_replace=False)
-        target = CallWith(neg_ops, inputs=[pattern])
+        x = Any()
+        softmax_pattern = Prim(P.Softmax())
+        call_softmax = Call(softmax_pattern, inputs=[x])
+        relu_pattern = Prim(P.ReLU())
+        call_relu = Call(relu_pattern, inputs=[x])
+        pattern = OneOf([call_softmax, call_relu])
+        neg_ops = Prim(P.Neg())
+        target = Call(neg_ops, inputs=[pattern])
         return pattern, target

     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
-    print(transformed_repr)
     unregiste_pass(softmax_neg_pass)
     assert "Neg" in transformed_repr
     assert "Softmax" in transformed_repr
@@ -167,11 +163,11 @@ def test_isnot_pattern_0():
         """
         Sub a BN which does NOT take Conv as inputs to ReLU6.
         """
-        conv2d_prim = IsPrimTypeOf("Conv2D")
-        conv2d = CallWith(conv2d_prim)
-        pattern_0 = IsNot(conv2d)
-        pattern = CallWith(P.BatchNorm(), inputs=[pattern_0])
-        target = CallWith(P.ReLU6(), inputs=[pattern_0])
+        conv2d_prim = Prim("Conv2D")
+        conv2d = Call(conv2d_prim)
+        pattern_0 = NoneOf(conv2d)
+        pattern = Call(P.BatchNorm(), inputs=[pattern_0])
+        target = Call(P.ReLU6(), inputs=[pattern_0])
         return pattern, target

     @registe_pass(run_only_once=True)
@@ -179,10 +175,8 @@ def test_isnot_pattern_0():
         """
         Sub a BN to Softmax.
        """
-        bn = P.BatchNorm()
-        pattern = CallWith(bn)
-        softmax = P.Softmax()
-        target = CallWith(softmax, should_replace=False)
+        pattern = Call(P.BatchNorm())
+        target = Call(P.Softmax())
         return pattern, target

     transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
@@ -205,12 +199,12 @@ def test_isnot_pattern_1():
         """
         Sub a BN which does NOT take MatMul as inputs to ReLU6.
         """
-        matmul = IsPrimTypeOf("MatMul")
-        pattern_0 = IsNot(matmul)
+        matmul = Prim("MatMul")
+        pattern_0 = NoneOf(matmul)
         softmax = P.Softmax()
-        pattern = CallWith(softmax, inputs=[pattern_0])
+        pattern = Call(softmax, inputs=[pattern_0])
         relu6 = P.ReLU6()
-        target = CallWith(relu6, inputs=[pattern_0], should_replace=False)
+        target = Call(relu6, inputs=[pattern_0])
         return pattern, target

     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
@@ -228,14 +222,12 @@ def test_newtensor_pattern():
     @registe_pass(run_only_once=True)
     def softmax_addn_pass():
-        x = AnyPattern()
-        softmax = P.Softmax()
-        pattern = CallWith(softmax, inputs=[x])
+        x = Any()
+        pattern = Call(P.Softmax(), inputs=[x])
         weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
         new_weight = NewTensor(weight_tensor)
-        addn_ops = P.AddN()
-        target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False)
+        target = Call(P.AddN(), inputs=[x, new_weight])
         return pattern, target

     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
     unregiste_pass(softmax_addn_pass)
@@ -252,25 +244,23 @@ def test_newparameter_pattern():
     @registe_pass(run_only_once=True)
     def softmax_addn_pass():
-        x = AnyPattern()
-        softmax = P.Softmax()
-        pattern = CallWith(softmax, inputs=[x])
+        x = Any()
+        pattern = Call(P.Softmax(), inputs=[x])
         default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
         default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
         new_para_0 = NewParameter("Merlin", default_tensor0)
         new_para_1 = NewParameter("Arthur", default_tensor1)
-        target_0 = CallWith(P.MatMul(), inputs=[new_para_0, new_para_1], should_replace=False)
-        target = CallWith("make_tuple", inputs=[target_0], should_replace=False)
+        target_0 = Call(P.MatMul(), inputs=[new_para_0, new_para_1])
+        target = Call("make_tuple", inputs=[target_0])
         return pattern, target

     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
-    print(transformed_repr)
     unregiste_pass(softmax_addn_pass)
     assert "MatMul" in transformed_repr
     assert "make_tuple" in transformed_repr
     assert "Softmax" not in transformed_repr

-def test_imm_pattern():
+def test_imm_target():
     """
     Test NewParameter pattern in the target
     """
@@ -278,17 +268,15 @@ def test_imm_pattern():
     softmax_model = nn.Softmax()

     @registe_pass(run_only_once=True)
-    def softmax_addn_pass():
-        x = AnyPattern()
-        softmax = P.Softmax()
-        pattern = CallWith(softmax, inputs=[x])
+    def softmax_pass():
+        x = Any()
+        pattern = Call(P.Softmax(), inputs=[x])
         imm = Imm(0)
-        target_0 = CallWith("make_tuple", inputs=[pattern], should_replace=False)
-        target = CallWith("tuple_getitem", inputs=[target_0, imm], should_replace=False)
+        target_0 = Call("make_tuple", inputs=[pattern])
+        target = Call("tuple_getitem", inputs=[target_0, imm])
         return pattern, target

     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
-    print(transformed_repr)
-    unregiste_pass(softmax_addn_pass)
+    unregiste_pass(softmax_pass)
     assert "make_tuple" in transformed_repr
     assert "tuple_getitem" in transformed_repr
     assert "Softmax" in transformed_repr
@@ -301,21 +289,19 @@ def test_gen_new_parameter():
     softmax_model = nn.Softmax()

     default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
-    new_para = NewParameter("Merlin", default_tensor, should_replace=True)
+    new_para = NewParameter("Merlin", default_tensor)
     gen_new_parameter(new_para)

     @registe_pass(run_only_once=True)
     def softmax_make_tuple_pass():
-        x = AnyPattern()
+        x = Any()
         softmax = P.Softmax()
-        pattern = CallWith(softmax, inputs=[x])
-        target = CallWith("make_tuple", inputs=[pattern, new_para], should_replace=False)
+        pattern = Call(softmax, inputs=[x])
+        target = Call("make_tuple", inputs=[pattern, new_para])
         return pattern, target

     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
-    print(transformed_repr)
     assert "Merlin" in transformed_repr
     unregiste_pass(softmax_make_tuple_pass)
     cancel_new_parameter(new_para)
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
-    print(transformed_repr)
     assert "Merlin" not in transformed_repr