python pass pattern renaming and interface tweaking
This commit is contained in:
parent
3a16925fa2
commit
641d12d6d9
|
@ -21,25 +21,23 @@ namespace opt {
|
||||||
namespace python_pass {
|
namespace python_pass {
|
||||||
int Pattern::g_id_ = 0;
|
int Pattern::g_id_ = 0;
|
||||||
|
|
||||||
MatchResultPtr IsPrimTypeOf::match(const AnfNodePtr &node) {
|
MatchResultPtr Prim::match(const AnfNodePtr &node) {
|
||||||
if (!IsValueNode<Primitive>(node)) {
|
if (!IsValueNode<Primitive>(node)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||||
if (IsValueNode<Primitive>(node)) {
|
|
||||||
// iterate over all primitives
|
// iterate over all primitives
|
||||||
for (auto &iter : primitives_) {
|
for (auto &iter : primitives_) {
|
||||||
if (IsPrimitive(node, iter) || iter->name() == "*") {
|
if (IsPrimitive(node, iter) || iter->name() == "*") {
|
||||||
matched_prim_ = iter;
|
matched_prim_ = iter;
|
||||||
res->add_entry(shared_from_base<IsPrimTypeOf>(), node);
|
res->add_entry(shared_from_base<Prim>(), node);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
MatchResultPtr CallWith::match(const AnfNodePtr &node) {
|
MatchResultPtr Call::match(const AnfNodePtr &node) {
|
||||||
if (!IsPrimitiveCNode(node)) {
|
if (!IsPrimitiveCNode(node)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -71,7 +69,7 @@ MatchResultPtr CallWith::match(const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
// If inputs is not specified, add node without looking into its inputs
|
// If inputs is not specified, add node without looking into its inputs
|
||||||
if (p_inputs_size == 0) {
|
if (p_inputs_size == 0) {
|
||||||
res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
|
res->add_entry(shared_from_base<Call>(), cnode->input(0));
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
bool failed = false;
|
bool failed = false;
|
||||||
|
@ -86,24 +84,24 @@ MatchResultPtr CallWith::match(const AnfNodePtr &node) {
|
||||||
res->merge(input_match_result);
|
res->merge(input_match_result);
|
||||||
}
|
}
|
||||||
if (!failed) {
|
if (!failed) {
|
||||||
res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
|
res->add_entry(shared_from_base<Call>(), cnode->input(0));
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
MatchResultPtr IsIn::match(const AnfNodePtr &node) {
|
MatchResultPtr OneOf::match(const AnfNodePtr &node) {
|
||||||
for (auto &iter : patterns_) {
|
for (auto &iter : patterns_) {
|
||||||
auto res = iter->match(node);
|
auto res = iter->match(node);
|
||||||
if (res != nullptr) {
|
if (res != nullptr) {
|
||||||
res->add_entry(shared_from_base<IsIn>(), node);
|
res->add_entry(shared_from_base<OneOf>(), node);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
MatchResultPtr IsNot::match(const AnfNodePtr &node) {
|
MatchResultPtr NoneOf::match(const AnfNodePtr &node) {
|
||||||
for (auto &iter : patterns_) {
|
for (auto &iter : patterns_) {
|
||||||
auto res = iter->match(node);
|
auto res = iter->match(node);
|
||||||
if (res != nullptr) {
|
if (res != nullptr) {
|
||||||
|
@ -111,16 +109,33 @@ MatchResultPtr IsNot::match(const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto res = std::make_shared<MatchResult>();
|
auto res = std::make_shared<MatchResult>();
|
||||||
res->add_entry(shared_from_base<IsNot>(), node);
|
res->add_entry(shared_from_base<NoneOf>(), node);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
MatchResultPtr AnyPattern::match(const AnfNodePtr &node) {
|
MatchResultPtr Any::match(const AnfNodePtr &node) {
|
||||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||||
res->add_entry(shared_from_base<AnyPattern>(), node);
|
res->add_entry(shared_from_base<Any>(), node);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MatchResultPtr Imm::match(const AnfNodePtr &node) {
|
||||||
|
if (!IsValueNode<Int32Imm>(node)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// Check value
|
||||||
|
auto value_node = node->cast<ValueNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
|
auto value_ptr = value_node->value()->cast<Int32ImmPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
|
if ((int32_t)value_ptr->value() == value_) {
|
||||||
|
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||||
|
res->add_entry(shared_from_base<Imm>(), node);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
|
AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
|
||||||
auto entry = match_result_.find(pattern);
|
auto entry = match_result_.find(pattern);
|
||||||
if (entry == match_result_.end()) {
|
if (entry == match_result_.end()) {
|
||||||
|
@ -140,20 +155,20 @@ void MatchResult::merge(const MatchResultPtr &other_result) {
|
||||||
REGISTER_PYBIND_DEFINE(
|
REGISTER_PYBIND_DEFINE(
|
||||||
Pattern, ([](const py::module *m) {
|
Pattern, ([](const py::module *m) {
|
||||||
(void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
|
(void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
|
||||||
(void)py::class_<IsIn, std::shared_ptr<IsIn>, Pattern>(*m, "IsIn_").def(py::init<vector<PatternPtr>>());
|
(void)py::class_<OneOf, std::shared_ptr<OneOf>, Pattern>(*m, "OneOf_").def(py::init<vector<PatternPtr>>());
|
||||||
(void)py::class_<IsPrimTypeOf, std::shared_ptr<IsPrimTypeOf>, Pattern>(*m, "IsPrimTypeOf_", py::dynamic_attr())
|
(void)py::class_<Prim, std::shared_ptr<Prim>, Pattern>(*m, "Prim_", py::dynamic_attr())
|
||||||
.def(py::init<vector<PrimitivePyPtr>, string, bool>())
|
.def(py::init<vector<PrimitivePyPtr>, string>())
|
||||||
.def(py::init<vector<string>, string, bool>());
|
.def(py::init<vector<string>, string>());
|
||||||
(void)py::class_<CallWith, std::shared_ptr<CallWith>, Pattern>(*m, "CallWith_")
|
(void)py::class_<Call, std::shared_ptr<Call>, Pattern>(*m, "Call_")
|
||||||
.def(py::init<PatternPtr, vector<PatternPtr>, bool>())
|
.def(py::init<PatternPtr, vector<PatternPtr>>())
|
||||||
.def(py::init<PrimitivePyPtr, vector<PatternPtr>, bool>())
|
.def(py::init<PrimitivePyPtr, vector<PatternPtr>>())
|
||||||
.def(py::init<string, vector<PatternPtr>, bool>());
|
.def(py::init<string, vector<PatternPtr>>());
|
||||||
(void)py::class_<IsNot, std::shared_ptr<IsNot>, Pattern>(*m, "IsNot_").def(py::init<vector<PatternPtr>>());
|
(void)py::class_<NoneOf, std::shared_ptr<NoneOf>, Pattern>(*m, "NoneOf_").def(py::init<vector<PatternPtr>>());
|
||||||
(void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>());
|
(void)py::class_<Any, std::shared_ptr<Any>, Pattern>(*m, "Any").def(py::init<>());
|
||||||
(void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
|
(void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
|
||||||
.def(py::init<tensor::TensorPtr>());
|
.def(py::init<tensor::TensorPtr>());
|
||||||
(void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_")
|
(void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_")
|
||||||
.def(py::init<string, tensor::TensorPtr, bool, bool, bool>());
|
.def(py::init<string, tensor::TensorPtr, bool, bool>());
|
||||||
(void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int>());
|
(void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int>());
|
||||||
}));
|
}));
|
||||||
} // namespace python_pass
|
} // namespace python_pass
|
||||||
|
|
|
@ -36,10 +36,10 @@ class MatchResult;
|
||||||
using MatchResultPtr = std::shared_ptr<MatchResult>;
|
using MatchResultPtr = std::shared_ptr<MatchResult>;
|
||||||
class Pattern;
|
class Pattern;
|
||||||
using PatternPtr = std::shared_ptr<Pattern>;
|
using PatternPtr = std::shared_ptr<Pattern>;
|
||||||
class IsPrimTypeOf;
|
class Prim;
|
||||||
using IsPrimTypeOfPtr = std::shared_ptr<IsPrimTypeOf>;
|
using PrimPtr = std::shared_ptr<Prim>;
|
||||||
class CallWith;
|
class Call;
|
||||||
using CallWithPtr = std::shared_ptr<CallWith>;
|
using CallPtr = std::shared_ptr<Call>;
|
||||||
class NewTensor;
|
class NewTensor;
|
||||||
using NewTensorPtr = std::shared_ptr<NewTensor>;
|
using NewTensorPtr = std::shared_ptr<NewTensor>;
|
||||||
class NewParameter;
|
class NewParameter;
|
||||||
|
@ -58,8 +58,6 @@ class Pattern : public Base {
|
||||||
virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; }
|
virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; }
|
||||||
string unique_name() const { return unique_name_; }
|
string unique_name() const { return unique_name_; }
|
||||||
vector<PatternPtr> inputs() { return inputs_; }
|
vector<PatternPtr> inputs() { return inputs_; }
|
||||||
bool should_replace() { return should_replace_; }
|
|
||||||
void set_should_replace(bool should_replace) { should_replace_ = should_replace; }
|
|
||||||
virtual void reset() {}
|
virtual void reset() {}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -67,7 +65,6 @@ class Pattern : public Base {
|
||||||
// NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed
|
// NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed
|
||||||
string unique_name_;
|
string unique_name_;
|
||||||
vector<PatternPtr> inputs_;
|
vector<PatternPtr> inputs_;
|
||||||
bool should_replace_ = true;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PatternEqual {
|
struct PatternEqual {
|
||||||
|
@ -85,70 +82,61 @@ struct PatternHasher {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class IsPrimTypeOf : public Pattern {
|
class Prim : public Pattern {
|
||||||
public:
|
public:
|
||||||
IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); }
|
Prim() { unique_name_ = std::to_string(g_id_++); }
|
||||||
~IsPrimTypeOf() = default;
|
~Prim() = default;
|
||||||
IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace)
|
Prim(vector<PrimitivePyPtr> prims, string name) : primitives_(prims), name_(name) {
|
||||||
: primitives_(prims), name_(name), matched_prim_(nullptr) {
|
unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
|
||||||
unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
|
// Default using the first prim to build target
|
||||||
should_replace_ = should_replace;
|
matched_prim_ = primitives_[0];
|
||||||
if (!should_replace) {
|
|
||||||
matched_prim_ = prims[0];
|
|
||||||
}
|
}
|
||||||
}
|
Prim(vector<string> types, string name) : types_(types), name_(name) {
|
||||||
IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) {
|
unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
|
||||||
unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
|
|
||||||
// Make primitives_
|
// Make primitives_
|
||||||
for (auto &iter : types) {
|
for (auto &iter : types) {
|
||||||
primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
|
primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
|
||||||
}
|
}
|
||||||
should_replace_ = should_replace;
|
// Default using the first prim to build target
|
||||||
if (!should_replace) {
|
|
||||||
matched_prim_ = primitives_[0];
|
matched_prim_ = primitives_[0];
|
||||||
}
|
}
|
||||||
}
|
MS_DECLARE_PARENT(Prim, Pattern);
|
||||||
MS_DECLARE_PARENT(IsPrimTypeOf, Pattern);
|
|
||||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||||
PrimitivePyPtr matched_primitive() { return matched_prim_; }
|
PrimitivePyPtr matched_primitive() { return matched_prim_; }
|
||||||
void reset() override {
|
void reset() override {
|
||||||
if (should_replace_) {
|
// Init before reset
|
||||||
matched_prim_ = nullptr;
|
MS_EXCEPTION_IF_NULL(matched_prim_);
|
||||||
}
|
matched_prim_ = primitives_[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<string> types_;
|
vector<string> types_;
|
||||||
vector<PrimitivePyPtr> primitives_;
|
vector<PrimitivePyPtr> primitives_;
|
||||||
string name_;
|
string name_;
|
||||||
PrimitivePyPtr matched_prim_;
|
PrimitivePyPtr matched_prim_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
class CallWith : public Pattern {
|
class Call : public Pattern {
|
||||||
public:
|
public:
|
||||||
CallWith() { unique_name_ = std::to_string(g_id_++); }
|
Call() { unique_name_ = std::to_string(g_id_++); }
|
||||||
~CallWith() = default;
|
~Call() = default;
|
||||||
CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) {
|
Call(PatternPtr prim_pattern, vector<PatternPtr> inputs) {
|
||||||
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
|
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
|
||||||
prim_pattern_ = prim_pattern;
|
prim_pattern_ = prim_pattern;
|
||||||
unique_name_ = std::to_string(g_id_++) + "CallWithPattern_" + prim_pattern->unique_name();
|
unique_name_ = std::to_string(g_id_++) + "Call_" + prim_pattern->unique_name();
|
||||||
inputs_ = inputs;
|
inputs_ = inputs;
|
||||||
// NOTE: should_replace_ is overrided by it prim_pattern(if exists) silently.
|
|
||||||
should_replace_ = prim_pattern->should_replace();
|
|
||||||
}
|
}
|
||||||
CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) {
|
Call(PrimitivePyPtr prim, vector<PatternPtr> inputs) {
|
||||||
prim_ = prim;
|
prim_ = prim;
|
||||||
unique_name_ = std::to_string(g_id_++) + "CallWithPrim_" + prim_->ToString();
|
unique_name_ = std::to_string(g_id_++) + "Call_" + prim_->ToString();
|
||||||
inputs_ = inputs;
|
inputs_ = inputs;
|
||||||
should_replace_ = should_replace;
|
|
||||||
}
|
}
|
||||||
CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) {
|
Call(string prim_str, vector<PatternPtr> inputs) {
|
||||||
prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
|
prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
|
||||||
unique_name_ = std::to_string(g_id_++) + "CallWithStr_" + prim_->ToString();
|
unique_name_ = std::to_string(g_id_++) + "CallStr_" + prim_->ToString();
|
||||||
inputs_ = inputs;
|
inputs_ = inputs;
|
||||||
should_replace_ = should_replace;
|
|
||||||
}
|
}
|
||||||
MS_DECLARE_PARENT(CallWith, Pattern);
|
MS_DECLARE_PARENT(Call, Pattern);
|
||||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||||
PrimitivePtr prim_value() { return prim_; }
|
PrimitivePtr prim_value() { return prim_; }
|
||||||
PatternPtr prim_pattern() { return prim_pattern_; }
|
PatternPtr prim_pattern() { return prim_pattern_; }
|
||||||
|
@ -160,45 +148,45 @@ class CallWith : public Pattern {
|
||||||
string name_;
|
string name_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class IsIn : public Pattern {
|
class OneOf : public Pattern {
|
||||||
public:
|
public:
|
||||||
IsIn() { unique_name_ = std::to_string(g_id_++); }
|
OneOf() { unique_name_ = std::to_string(g_id_++); }
|
||||||
~IsIn() = default;
|
~OneOf() = default;
|
||||||
explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) {
|
explicit OneOf(vector<PatternPtr> patterns) : patterns_(patterns) {
|
||||||
unique_name_ = std::to_string(g_id_++) + "IsIn";
|
unique_name_ = std::to_string(g_id_++) + "OneOf";
|
||||||
for (auto &iter : patterns) {
|
for (auto &iter : patterns) {
|
||||||
unique_name_ = unique_name_ + "_" + iter->unique_name();
|
unique_name_ = unique_name_ + "_" + iter->unique_name();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_DECLARE_PARENT(IsIn, Pattern);
|
MS_DECLARE_PARENT(OneOf, Pattern);
|
||||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<PatternPtr> patterns_;
|
vector<PatternPtr> patterns_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class IsNot : public Pattern {
|
class NoneOf : public Pattern {
|
||||||
public:
|
public:
|
||||||
IsNot() { unique_name_ = std::to_string(g_id_++); }
|
NoneOf() { unique_name_ = std::to_string(g_id_++); }
|
||||||
~IsNot() = default;
|
~NoneOf() = default;
|
||||||
explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) {
|
explicit NoneOf(vector<PatternPtr> patterns) : patterns_(patterns) {
|
||||||
unique_name_ = std::to_string(g_id_++) + "IsNot";
|
unique_name_ = std::to_string(g_id_++) + "NoneOf";
|
||||||
for (auto &iter : patterns) {
|
for (auto &iter : patterns) {
|
||||||
unique_name_ = unique_name_ + "_" + iter->unique_name();
|
unique_name_ = unique_name_ + "_" + iter->unique_name();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_DECLARE_PARENT(IsNot, Pattern);
|
MS_DECLARE_PARENT(NoneOf, Pattern);
|
||||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<PatternPtr> patterns_;
|
vector<PatternPtr> patterns_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class AnyPattern : public Pattern {
|
class Any : public Pattern {
|
||||||
public:
|
public:
|
||||||
AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; }
|
Any() { unique_name_ = std::to_string(g_id_++) + "_Any"; }
|
||||||
~AnyPattern() = default;
|
~Any() = default;
|
||||||
MS_DECLARE_PARENT(AnyPattern, Pattern);
|
MS_DECLARE_PARENT(Any, Pattern);
|
||||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -207,7 +195,6 @@ class NewTensor : public Pattern {
|
||||||
NewTensor() { unique_name_ = std::to_string(g_id_++); }
|
NewTensor() { unique_name_ = std::to_string(g_id_++); }
|
||||||
~NewTensor() = default;
|
~NewTensor() = default;
|
||||||
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) {
|
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) {
|
||||||
should_replace_ = false;
|
|
||||||
unique_name_ = std::to_string(g_id_++) + "NewTensor";
|
unique_name_ = std::to_string(g_id_++) + "NewTensor";
|
||||||
}
|
}
|
||||||
MS_DECLARE_PARENT(NewTensor, Pattern);
|
MS_DECLARE_PARENT(NewTensor, Pattern);
|
||||||
|
@ -223,10 +210,8 @@ class NewTensor : public Pattern {
|
||||||
class NewParameter : public Pattern {
|
class NewParameter : public Pattern {
|
||||||
public:
|
public:
|
||||||
NewParameter() { unique_name_ = std::to_string(g_id_++); }
|
NewParameter() { unique_name_ = std::to_string(g_id_++); }
|
||||||
explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel,
|
explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel)
|
||||||
bool should_replace)
|
|
||||||
: para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
|
: para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
|
||||||
should_replace_ = should_replace;
|
|
||||||
unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
|
unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
|
||||||
// clone input tensor
|
// clone input tensor
|
||||||
default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
|
default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
|
||||||
|
@ -243,11 +228,14 @@ class NewParameter : public Pattern {
|
||||||
bool built() { return built_; }
|
bool built() { return built_; }
|
||||||
void set_built(bool built) { built_ = built; }
|
void set_built(bool built) { built_ = built; }
|
||||||
void reset() override { built_ = false; }
|
void reset() override { built_ = false; }
|
||||||
|
bool should_last() { return last_across_passes_; }
|
||||||
|
void set_last(bool last) { last_across_passes_ = last; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
string para_name_;
|
string para_name_;
|
||||||
bool requires_grad_;
|
bool requires_grad_;
|
||||||
bool layerwise_parallel_;
|
bool layerwise_parallel_;
|
||||||
|
bool last_across_passes_{false};
|
||||||
bool built_;
|
bool built_;
|
||||||
tensor::TensorPtr default_tensor_;
|
tensor::TensorPtr default_tensor_;
|
||||||
};
|
};
|
||||||
|
@ -255,13 +243,9 @@ class NewParameter : public Pattern {
|
||||||
class Imm : public Pattern {
|
class Imm : public Pattern {
|
||||||
public:
|
public:
|
||||||
Imm() { unique_name_ = std::to_string(g_id_++); }
|
Imm() { unique_name_ = std::to_string(g_id_++); }
|
||||||
explicit Imm(int value) : value_(value) {
|
explicit Imm(int value) : value_(value) { unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value); }
|
||||||
should_replace_ = false;
|
|
||||||
unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value);
|
|
||||||
}
|
|
||||||
MS_DECLARE_PARENT(Imm, Pattern);
|
MS_DECLARE_PARENT(Imm, Pattern);
|
||||||
// NOTE: Doesn't support Imm in src pattern currently.
|
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||||
MatchResultPtr match(const AnfNodePtr &node) override { return nullptr; }
|
|
||||||
int value() { return value_; }
|
int value() { return value_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -80,7 +80,7 @@ bool IsTraversable(const AnfNodePtr &node) {
|
||||||
|
|
||||||
AnfNodePtr BuildPrimitive(const PatternPtr &pattern, const MatchResultPtr &res) {
|
AnfNodePtr BuildPrimitive(const PatternPtr &pattern, const MatchResultPtr &res) {
|
||||||
// Build up AnfNode from primitive
|
// Build up AnfNode from primitive
|
||||||
auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>();
|
auto prim_pattern = pattern->cast<PrimPtr>();
|
||||||
MS_EXCEPTION_IF_NULL(prim_pattern);
|
MS_EXCEPTION_IF_NULL(prim_pattern);
|
||||||
PrimitivePyPtr prim = prim_pattern->matched_primitive();
|
PrimitivePyPtr prim = prim_pattern->matched_primitive();
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
@ -98,13 +98,13 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) {
|
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) {
|
||||||
auto call_with_pattern = pattern->cast<CallWithPtr>();
|
auto call_pattern = pattern->cast<CallPtr>();
|
||||||
MS_EXCEPTION_IF_NULL(call_with_pattern);
|
MS_EXCEPTION_IF_NULL(call_pattern);
|
||||||
auto prim = call_with_pattern->prim_value();
|
auto prim = call_pattern->prim_value();
|
||||||
if (prim != nullptr) {
|
if (prim != nullptr) {
|
||||||
return std::make_shared<ValueNode>(prim);
|
return std::make_shared<ValueNode>(prim);
|
||||||
}
|
}
|
||||||
auto prim_pattern = call_with_pattern->prim_pattern();
|
auto prim_pattern = call_pattern->prim_pattern();
|
||||||
MS_EXCEPTION_IF_NULL(prim_pattern);
|
MS_EXCEPTION_IF_NULL(prim_pattern);
|
||||||
return ProcessSinglePattern(prim_pattern, res, fg);
|
return ProcessSinglePattern(prim_pattern, res, fg);
|
||||||
}
|
}
|
||||||
|
@ -152,45 +152,35 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
|
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
|
||||||
if (pattern->should_replace()) {
|
|
||||||
// Find replacement in the MatchResult
|
|
||||||
auto target_node = res->get_node(pattern);
|
auto target_node = res->get_node(pattern);
|
||||||
if (target_node == nullptr) {
|
if (target_node != nullptr) {
|
||||||
// If it's base pattern(in contrast to complex pattern like CallWith/IsIn/IsNot), raise runtime exception.
|
// If pattern is NewParameter, check whether it shouldn't last and is not built
|
||||||
if (pattern->isa<IsPrimTypeOf>() || pattern->isa<NewTensor>() || pattern->isa<NewParameter>()) {
|
auto new_para = pattern->cast<NewParameterPtr>();
|
||||||
MS_LOG(EXCEPTION) << "Cannot find target node, pattern: " + pattern->unique_name() + "\n";
|
if (new_para == nullptr || new_para->should_last() || new_para->built()) {
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
// Try to build this pattern and add to MatchResult, since this pattern is defined inside target
|
|
||||||
auto new_node = BuildTarget(pattern, func_graph, res);
|
|
||||||
if (new_node == nullptr) {
|
|
||||||
MS_LOG(EXCEPTION) << "Try to build pattern node but FAILED. pattern: " + pattern->unique_name() + "\n";
|
|
||||||
}
|
|
||||||
return new_node;
|
|
||||||
}
|
|
||||||
if (pattern->isa<NewParameter>()) {
|
|
||||||
return target_node;
|
return target_node;
|
||||||
}
|
}
|
||||||
return target_node;
|
|
||||||
}
|
}
|
||||||
// Build up new node from pattern
|
// Build up new node from pattern
|
||||||
if (pattern->isa<IsPrimTypeOf>()) {
|
if (pattern->isa<Prim>()) {
|
||||||
return BuildPrimitive(pattern, res);
|
return BuildPrimitive(pattern, res);
|
||||||
} else if (pattern->isa<NewTensor>()) {
|
} else if (pattern->isa<NewTensor>()) {
|
||||||
return BuildNewTensor(pattern, res);
|
return BuildNewTensor(pattern, res);
|
||||||
} else if (pattern->isa<CallWith>()) {
|
} else if (pattern->isa<Call>()) {
|
||||||
return BuildPrimitiveValueNode(pattern, res, func_graph);
|
return BuildPrimitiveValueNode(pattern, res, func_graph);
|
||||||
} else if (pattern->isa<NewParameter>()) {
|
} else if (pattern->isa<NewParameter>()) {
|
||||||
return BuildNewParameter(pattern, res, func_graph);
|
return BuildNewParameter(pattern, res, func_graph);
|
||||||
} else if (pattern->isa<Imm>()) {
|
} else if (pattern->isa<Imm>()) {
|
||||||
return BuildImmNode(pattern, res);
|
return BuildImmNode(pattern, res);
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Cannot find or build target node, pattern: " + pattern->unique_name() + "\n";
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
|
AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
|
||||||
const FuncGraphPtr &func_graph) {
|
const FuncGraphPtr &func_graph) {
|
||||||
if (pattern->isa<CallWith>()) {
|
if (pattern->isa<Call>()) {
|
||||||
return BuildPrimitiveValueNode(pattern, res, func_graph);
|
return BuildPrimitiveValueNode(pattern, res, func_graph);
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -269,16 +259,16 @@ void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
void Reset(PatternPtr pattern) {
|
void Reset(PatternPtr pattern) {
|
||||||
if (pattern->isa<IsPrimTypeOf>()) {
|
if (pattern->isa<Prim>()) {
|
||||||
auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>();
|
auto prim_pattern = pattern->cast<PrimPtr>();
|
||||||
prim_pattern->reset();
|
prim_pattern->reset();
|
||||||
return;
|
return;
|
||||||
} else if (pattern->isa<NewParameter>()) {
|
} else if (pattern->isa<NewParameter>()) {
|
||||||
auto new_param_pattern = pattern->cast<NewParameterPtr>();
|
auto new_param_pattern = pattern->cast<NewParameterPtr>();
|
||||||
new_param_pattern->reset();
|
new_param_pattern->reset();
|
||||||
return;
|
return;
|
||||||
} else if (pattern->isa<CallWith>()) {
|
} else if (pattern->isa<Call>()) {
|
||||||
auto call_with_pattern = pattern->cast<CallWithPtr>();
|
auto call_with_pattern = pattern->cast<CallPtr>();
|
||||||
for (auto sub_pattern : call_with_pattern->inputs()) {
|
for (auto sub_pattern : call_with_pattern->inputs()) {
|
||||||
Reset(sub_pattern);
|
Reset(sub_pattern);
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,8 +49,9 @@ PyPassManager::PyPassManager() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
||||||
Phase phase, bool run_only_once) {
|
bool run_only_once) {
|
||||||
auto cur_pg = GetPassGroup(phase);
|
// NOTE: remove phase option to avoid unnecessary confusion.
|
||||||
|
auto cur_pg = GetPassGroup(Phase::OPT);
|
||||||
MS_EXCEPTION_IF_NULL(cur_pg);
|
MS_EXCEPTION_IF_NULL(cur_pg);
|
||||||
cur_pg->SetRunOnlyOnce(run_only_once);
|
cur_pg->SetRunOnlyOnce(run_only_once);
|
||||||
MS_EXCEPTION_IF_NULL(pattern);
|
MS_EXCEPTION_IF_NULL(pattern);
|
||||||
|
@ -60,8 +61,9 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
|
||||||
cur_pg->AddPass(new_pass);
|
cur_pg->AddPass(new_pass);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
|
void PyPassManager::Unregiste(const std::string &pass_name) {
|
||||||
auto cur_pm = GetPassGroup(phase);
|
// NOTE: remove phase option to avoid unnecessary confusion.
|
||||||
|
auto cur_pm = GetPassGroup(Phase::OPT);
|
||||||
MS_EXCEPTION_IF_NULL(cur_pm);
|
MS_EXCEPTION_IF_NULL(cur_pm);
|
||||||
if (!cur_pm->DeletePass(pass_name)) {
|
if (!cur_pm->DeletePass(pass_name)) {
|
||||||
MS_LOG(WARNING) << "No such pass : " + pass_name + "\n";
|
MS_LOG(WARNING) << "No such pass : " + pass_name + "\n";
|
||||||
|
@ -70,7 +72,6 @@ void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
|
||||||
|
|
||||||
void PyPassManager::GenNewParameter(const PatternPtr ¶meter) {
|
void PyPassManager::GenNewParameter(const PatternPtr ¶meter) {
|
||||||
MS_EXCEPTION_IF_NULL(parameter);
|
MS_EXCEPTION_IF_NULL(parameter);
|
||||||
// Add new parameter after resolve
|
|
||||||
// NOTE: Add NewParameter at early stage will cause CSE problems
|
// NOTE: Add NewParameter at early stage will cause CSE problems
|
||||||
auto cur_pg = GetPassGroup(Phase::OPT);
|
auto cur_pg = GetPassGroup(Phase::OPT);
|
||||||
MS_EXCEPTION_IF_NULL(cur_pg);
|
MS_EXCEPTION_IF_NULL(cur_pg);
|
||||||
|
@ -78,7 +79,7 @@ void PyPassManager::GenNewParameter(const PatternPtr ¶meter) {
|
||||||
auto new_para_pattern = parameter->cast<NewParameterPtr>();
|
auto new_para_pattern = parameter->cast<NewParameterPtr>();
|
||||||
MS_EXCEPTION_IF_NULL(new_para_pattern);
|
MS_EXCEPTION_IF_NULL(new_para_pattern);
|
||||||
auto pass_name = new_para_pattern->para_name();
|
auto pass_name = new_para_pattern->para_name();
|
||||||
parameter->set_should_replace(false);
|
new_para_pattern->set_last(true);
|
||||||
auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true);
|
auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true);
|
||||||
cur_pg->AddPass(new_pass);
|
cur_pg->AddPass(new_pass);
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,16 +53,17 @@ class PyPassManager {
|
||||||
static PyPassManagerPtr GetInstance();
|
static PyPassManagerPtr GetInstance();
|
||||||
virtual ~PyPassManager() = default;
|
virtual ~PyPassManager() = default;
|
||||||
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
||||||
Phase phase = Phase::RESOLVE, bool run_only_once = false);
|
bool run_only_once = false);
|
||||||
void Unregiste(const std::string &pass_name, Phase phase);
|
void Unregiste(const std::string &pass_name);
|
||||||
void GenNewParameter(const PatternPtr ¶meter);
|
void GenNewParameter(const PatternPtr ¶meter);
|
||||||
PassGroupPtr GetPassGroup(Phase phase);
|
PassGroupPtr GetPassGroup(Phase phase);
|
||||||
void ClearRes();
|
|
||||||
MatchResultPtr GetMatchResult() { return res_; }
|
MatchResultPtr GetMatchResult() { return res_; }
|
||||||
void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
|
void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
|
||||||
bool ShouldRenorm() { return should_renorm_; }
|
bool ShouldRenorm() { return should_renorm_; }
|
||||||
void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
|
void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
|
||||||
pipeline::ResourcePtr GetResource() { return resource_; }
|
pipeline::ResourcePtr GetResource() { return resource_; }
|
||||||
|
void ClearRes();
|
||||||
|
void ClearPipelineRes() { resource_ = nullptr; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool should_renorm_ = true;
|
bool should_renorm_ = true;
|
||||||
|
|
|
@ -477,6 +477,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
|
||||||
// save the run graph func to MsPipeLine
|
// save the run graph func to MsPipeLine
|
||||||
SaveCompiledGraph(phase_s);
|
SaveCompiledGraph(phase_s);
|
||||||
|
|
||||||
|
opt::python_pass::PyPassManager::GetInstance()->ClearPipelineRes();
|
||||||
resource->Clean();
|
resource->Clean();
|
||||||
// Reclaim all resource used by optimizer;
|
// Reclaim all resource used by optimizer;
|
||||||
ReclaimOptimizer();
|
ReclaimOptimizer();
|
||||||
|
|
|
@ -15,50 +15,43 @@
|
||||||
"""Patterns for describing graphs"""
|
"""Patterns for describing graphs"""
|
||||||
from mindspore.ops import Primitive
|
from mindspore.ops import Primitive
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_,\
|
from mindspore._c_expression import Pattern, OneOf_, Prim_, Call_, NoneOf_, Any, NewTensor_, NewParameter_, Imm
|
||||||
NewParameter_, Imm
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"IsIn",
|
"OneOf",
|
||||||
"IsPrimTypeOf",
|
"Prim",
|
||||||
"CallWith",
|
"Call",
|
||||||
"IsNot",
|
"NoneOf",
|
||||||
"AnyPattern",
|
"Any",
|
||||||
"NewTensor",
|
"NewTensor",
|
||||||
"NewParameter",
|
"NewParameter",
|
||||||
"Imm"
|
"Imm"
|
||||||
]
|
]
|
||||||
|
|
||||||
class IsIn(IsIn_):
|
class OneOf(OneOf_):
|
||||||
r"""
|
r"""
|
||||||
Express a pattern which allows a list of patterns.
|
Express a pattern which allows a list of patterns.
|
||||||
"""
|
"""
|
||||||
def __init__(self, patterns=None, should_replace=True):
|
def __init__(self, patterns=None):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`],
|
patterns(Union[:class:`mindspore.graph_utils.graph_pattern`,
|
||||||
|
tuple[:class:`mindspore.graph_utils.graph_pattern`],
|
||||||
list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns,
|
list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns,
|
||||||
each element should be one of the exposed Pattern instance.
|
each element should be one of the exposed Pattern instance.
|
||||||
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: raise if should_replace is False
|
|
||||||
TypeError: raise type error for invalid inputs.
|
TypeError: raise type error for invalid inputs.
|
||||||
"""
|
"""
|
||||||
if not should_replace:
|
|
||||||
raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \
|
|
||||||
its sub-pattern instead.")
|
|
||||||
self.patterns = patterns
|
self.patterns = patterns
|
||||||
if patterns is None:
|
if isinstance(patterns, Pattern):
|
||||||
IsIn_.__init__(self, ())
|
OneOf_.__init__(self, [patterns])
|
||||||
elif isinstance(patterns, Pattern):
|
|
||||||
IsIn_.__init__(self, [patterns])
|
|
||||||
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
|
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
|
||||||
IsIn_.__init__(self, patterns)
|
OneOf_.__init__(self, patterns)
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
|
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
|
||||||
|
|
||||||
class IsPrimTypeOf(IsPrimTypeOf_):
|
class Prim(Prim_):
|
||||||
r"""
|
r"""
|
||||||
Express a pattern of certain primitive type(s).
|
Express a pattern of certain primitive type(s).
|
||||||
|
|
||||||
|
@ -66,7 +59,7 @@ class IsPrimTypeOf(IsPrimTypeOf_):
|
||||||
This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
|
This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
|
||||||
please refer to CallWith pattern.
|
please refer to CallWith pattern.
|
||||||
"""
|
"""
|
||||||
def __init__(self, types, name=None, should_replace=True):
|
def __init__(self, types, name=None):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
|
types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
|
||||||
|
@ -77,9 +70,6 @@ class IsPrimTypeOf(IsPrimTypeOf_):
|
||||||
2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
|
2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
|
||||||
It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
|
It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
|
||||||
name (str): name of the pattern, optional. Default: None.
|
name (str): name of the pattern, optional. Default: None.
|
||||||
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
|
|
||||||
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
|
|
||||||
Default: True.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: raise type error for invalid argument.
|
TypeError: raise type error for invalid argument.
|
||||||
|
@ -103,13 +93,13 @@ class IsPrimTypeOf(IsPrimTypeOf_):
|
||||||
self.types = types
|
self.types = types
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
|
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
|
||||||
IsPrimTypeOf_.__init__(self, self.types, self.name, should_replace)
|
Prim_.__init__(self, self.types, self.name)
|
||||||
|
|
||||||
class CallWith(CallWith_):
|
class Call(Call_):
|
||||||
r"""
|
r"""
|
||||||
Express a primitive CNode.
|
Express a primitive CNode.
|
||||||
"""
|
"""
|
||||||
def __init__(self, prim_pattern, inputs=None, should_replace=True):
|
def __init__(self, prim_pattern, inputs=None):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`,
|
prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`,
|
||||||
|
@ -118,9 +108,6 @@ class CallWith(CallWith_):
|
||||||
tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
|
tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
|
||||||
Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input
|
Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input
|
||||||
patterns should be of right order and each element should be one of the exposed Pattern instance.
|
patterns should be of right order and each element should be one of the exposed Pattern instance.
|
||||||
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
|
|
||||||
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
|
|
||||||
Default: True.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: raise type error for invalid argument.
|
TypeError: raise type error for invalid argument.
|
||||||
|
@ -135,36 +122,31 @@ class CallWith(CallWith_):
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
|
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
|
||||||
CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace)
|
Call_.__init__(self, self.prim_pattern, self.inputs)
|
||||||
|
|
||||||
class IsNot(IsNot_):
|
class NoneOf(NoneOf_):
|
||||||
r"""
|
r"""
|
||||||
Express a pattern which forbids a list of patterns.
|
Express a pattern which forbids a list of patterns.
|
||||||
|
|
||||||
NOTE:
|
NOTE:
|
||||||
IsNot pattern should not be the root pattern.
|
NoneOf pattern should not be the root pattern.
|
||||||
"""
|
"""
|
||||||
def __init__(self, patterns=None, should_replace=True):
|
def __init__(self, patterns=None):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element
|
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element
|
||||||
should be one of the exposed Pattern instance.
|
should be one of the exposed Pattern instance.
|
||||||
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: raise if should_replace is False.
|
|
||||||
TypeError: raise type error for invalid argument.
|
TypeError: raise type error for invalid argument.
|
||||||
"""
|
"""
|
||||||
if not should_replace:
|
|
||||||
raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \
|
|
||||||
its sub-pattern instead.")
|
|
||||||
self.patterns = patterns
|
self.patterns = patterns
|
||||||
if patterns is None:
|
if patterns is None:
|
||||||
IsNot_.__init__(self, ())
|
NoneOf_.__init__(self, ())
|
||||||
elif isinstance(patterns, Pattern):
|
elif isinstance(patterns, Pattern):
|
||||||
IsNot_.__init__(self, [patterns])
|
NoneOf_.__init__(self, [patterns])
|
||||||
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
|
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
|
||||||
IsNot_.__init__(self, patterns)
|
NoneOf_.__init__(self, patterns)
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
|
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
|
||||||
|
|
||||||
|
@ -172,18 +154,14 @@ class NewTensor(NewTensor_):
|
||||||
r"""
|
r"""
|
||||||
New Tensor to be used in the target.
|
New Tensor to be used in the target.
|
||||||
"""
|
"""
|
||||||
def __init__(self, input_tensor, should_replace=False):
|
def __init__(self, input_tensor):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
|
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
|
||||||
should_replace(bool): added this for interface consistency. NewTensor should only appear in the target.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: raise if should_replace is True
|
|
||||||
TypeError: raise type error for invalid argument.
|
TypeError: raise type error for invalid argument.
|
||||||
"""
|
"""
|
||||||
if should_replace:
|
|
||||||
raise ValueError("NewTensor should only appear in the target, thus should_replace can only be False.")
|
|
||||||
self.input_tensor = input_tensor
|
self.input_tensor = input_tensor
|
||||||
if isinstance(input_tensor, Tensor):
|
if isinstance(input_tensor, Tensor):
|
||||||
NewTensor_.__init__(self, input_tensor)
|
NewTensor_.__init__(self, input_tensor)
|
||||||
|
@ -194,15 +172,13 @@ class NewParameter(NewParameter_):
|
||||||
r"""
|
r"""
|
||||||
New Parameter to be used in the target.
|
New Parameter to be used in the target.
|
||||||
"""
|
"""
|
||||||
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False, should_replace=False):
|
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
para_name(str): name for the new Parameter
|
para_name(str): name for the new Parameter
|
||||||
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
|
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
|
||||||
requires_grad(bool): True if the parameter requires gradient. Default: True
|
requires_grad(bool): True if the parameter requires gradient. Default: True
|
||||||
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
|
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
|
||||||
should_replace(bool): gen new parameter once and replace after if set to be true; otherwise build a new
|
|
||||||
parameter everytime a pass target got built. Default: False
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: raise type error for invalid argument.
|
TypeError: raise type error for invalid argument.
|
||||||
|
@ -211,12 +187,11 @@ class NewParameter(NewParameter_):
|
||||||
self.default_tensor = default_tensor
|
self.default_tensor = default_tensor
|
||||||
self.requires_grad = requires_grad
|
self.requires_grad = requires_grad
|
||||||
self.layerwise_parallel = layerwise_parallel
|
self.layerwise_parallel = layerwise_parallel
|
||||||
self.should_replace = should_replace
|
|
||||||
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
|
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
|
||||||
isinstance(layerwise_parallel, bool) and isinstance(should_replace, bool):
|
isinstance(layerwise_parallel, bool):
|
||||||
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
|
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
|
||||||
self.layerwise_parallel, self.should_replace)
|
self.layerwise_parallel)
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
|
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
|
||||||
layerwise_parallel(bool) should_replace(bool), got : {para_name}, {default_tensor}, \
|
layerwise_parallel(bool), got : {para_name}, {default_tensor}, \
|
||||||
{requires_grad}, {layerwise_parallel}, {should_replace}")
|
{requires_grad}, {layerwise_parallel}")
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"""Python pass register"""
|
"""Python pass register"""
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
|
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
|
||||||
from mindspore._c_expression import PyPassManager_, phase
|
from mindspore._c_expression import PyPassManager_
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -30,21 +30,16 @@ class PyPassManager(PyPassManager_):
|
||||||
Used to registe and unregiste python passes which can be used to alter graphs.
|
Used to registe and unregiste python passes which can be used to alter graphs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
|
|
||||||
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
|
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
|
||||||
multigraph (bool): Whether or not the pattern exists across graphs. Default: True.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If argument has invalid type.
|
TypeError: If argument has invalid type.
|
||||||
"""
|
"""
|
||||||
def __init__(self, pipeline_phase=phase.opt, run_only_once=False):
|
def __init__(self, run_only_once=False):
|
||||||
if not isinstance(pipeline_phase, phase):
|
|
||||||
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
|
||||||
if not isinstance(run_only_once, bool):
|
if not isinstance(run_only_once, bool):
|
||||||
raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
|
raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
|
||||||
PyPassManager_.__init__(self)
|
|
||||||
self.phase_ = pipeline_phase
|
|
||||||
self.run_only_once_ = run_only_once
|
self.run_only_once_ = run_only_once
|
||||||
|
PyPassManager_.__init__(self)
|
||||||
|
|
||||||
def registe(self, py_pass):
|
def registe(self, py_pass):
|
||||||
if not isfunction(py_pass):
|
if not isfunction(py_pass):
|
||||||
|
@ -55,16 +50,14 @@ class PyPassManager(PyPassManager_):
|
||||||
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
|
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
|
||||||
if not isinstance(target, Pattern):
|
if not isinstance(target, Pattern):
|
||||||
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
|
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
|
||||||
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_)
|
super().registe(pass_name, pattern, target, self.run_only_once_)
|
||||||
|
|
||||||
def unregiste(self, py_pass, pipeline_phase=phase.opt):
|
def unregiste(self, py_pass):
|
||||||
if not isinstance(pipeline_phase, phase):
|
|
||||||
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
|
||||||
if isinstance(py_pass, str):
|
if isinstance(py_pass, str):
|
||||||
super().unregiste(py_pass, pipeline_phase)
|
super().unregiste(py_pass)
|
||||||
return
|
return
|
||||||
if isfunction(py_pass):
|
if isfunction(py_pass):
|
||||||
super().unregiste(py_pass.__name__, pipeline_phase)
|
super().unregiste(py_pass.__name__)
|
||||||
return
|
return
|
||||||
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
|
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
|
||||||
|
|
||||||
|
@ -82,13 +75,11 @@ class PyPassManager(PyPassManager_):
|
||||||
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
|
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
|
||||||
super().set_renorm(should_renorm)
|
super().set_renorm(should_renorm)
|
||||||
|
|
||||||
def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
|
def registe_pass(run_only_once=False):
|
||||||
"""
|
"""
|
||||||
Registe python pass to specified pipeline phase which would be used in compilation.
|
Registe python pass to specified pipeline phase which would be used in compilation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
|
|
||||||
registed. Support phase.resolve and phase.opt. Default: phase.opt.
|
|
||||||
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
|
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -102,19 +93,17 @@ def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
|
||||||
>>> target = IsPrimTypeOf("ReLU6")
|
>>> target = IsPrimTypeOf("ReLU6")
|
||||||
>>> return pattern, target
|
>>> return pattern, target
|
||||||
"""
|
"""
|
||||||
return PyPassManager(pipeline_phase, run_only_once)
|
return PyPassManager(run_only_once)
|
||||||
|
|
||||||
def unregiste_pass(py_pass, pipeline_phase=phase.opt):
|
def unregiste_pass(py_pass):
|
||||||
"""
|
"""
|
||||||
Unregiste python pass.
|
Unregiste python pass.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
py_pass(Union(str, function)): target python pass to unregiste.
|
py_pass(Union(str, function)): target python pass to unregiste.
|
||||||
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
|
|
||||||
unregisted. Support phase.resolve and phase.opt. Default: phase.opt.
|
|
||||||
"""
|
"""
|
||||||
ppm = PyPassManager()
|
ppm = PyPassManager()
|
||||||
ppm.unregiste(py_pass, pipeline_phase)
|
ppm.unregiste(py_pass)
|
||||||
|
|
||||||
def gen_new_parameter(pattern):
|
def gen_new_parameter(pattern):
|
||||||
"""
|
"""
|
||||||
|
@ -164,7 +153,14 @@ def cancel_new_parameter(pattern):
|
||||||
|
|
||||||
def set_renorm(should_renorm):
|
def set_renorm(should_renorm):
|
||||||
"""
|
"""
|
||||||
Set whether or not to do renorm after modified graph in python pass(es).
|
Set whether or not to do renormalization after modified graph in python pass(es).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
should_renorm(bool): whether or not to do renormalization after modified graph in python pass(es).
|
||||||
|
|
||||||
|
NOTE:
|
||||||
|
This interface is mainly intended for testing modifying graph without worrying about its validity. Turn off
|
||||||
|
renormalization may BREAK the network.
|
||||||
"""
|
"""
|
||||||
ppm = PyPassManager()
|
ppm = PyPassManager()
|
||||||
ppm.set_renorm(should_renorm)
|
ppm.set_renorm(should_renorm)
|
||||||
|
|
|
@ -23,8 +23,7 @@ from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_
|
||||||
cancel_new_parameter
|
cancel_new_parameter
|
||||||
from mindspore.common.api import _generate_pip_args
|
from mindspore.common.api import _generate_pip_args
|
||||||
from mindspore._c_expression import generate_key, Executor_
|
from mindspore._c_expression import generate_key, Executor_
|
||||||
from mindspore.graph_utils.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor,\
|
from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
|
||||||
NewParameter, Imm
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
@ -50,11 +49,9 @@ def test_softmax_relu():
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@registe_pass(run_only_once=True)
|
||||||
def softmax_relu_pass():
|
def softmax_relu_pass():
|
||||||
x = AnyPattern()
|
x = Any()
|
||||||
softmax_pattern = IsPrimTypeOf(P.Softmax())
|
pattern = Call(P.Softmax(), inputs=[x])
|
||||||
pattern = CallWith(softmax_pattern, inputs=[x])
|
target = Call(P.ReLU(), inputs=[x])
|
||||||
relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
|
|
||||||
target = CallWith(relu_pattern, inputs=[x])
|
|
||||||
return pattern, target
|
return pattern, target
|
||||||
|
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||||
|
@ -74,13 +71,13 @@ def test_softmax_relu_sigmoid():
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@registe_pass(run_only_once=True)
|
||||||
def softmax_relu_pass():
|
def softmax_relu_pass():
|
||||||
x = AnyPattern()
|
x = Any()
|
||||||
softmax_pattern = IsPrimTypeOf(P.Softmax())
|
softmax_pattern = Prim(P.Softmax())
|
||||||
pattern = CallWith(softmax_pattern, inputs=[x])
|
pattern = Call(softmax_pattern, inputs=[x])
|
||||||
sigmoid_pattern = IsPrimTypeOf(P.Sigmoid(), should_replace=False)
|
sigmoid_pattern = Prim(P.Sigmoid())
|
||||||
call_sigmoid = CallWith(sigmoid_pattern, [x])
|
call_sigmoid = Call(sigmoid_pattern, [x])
|
||||||
relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
|
relu_pattern = Prim(P.ReLU())
|
||||||
target = CallWith(relu_pattern, inputs=[call_sigmoid])
|
target = Call(relu_pattern, inputs=[call_sigmoid])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
|
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
|
||||||
|
@ -99,15 +96,15 @@ def test_isin_pattern_0():
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@registe_pass(run_only_once=True)
|
||||||
def softmax_relu_pass():
|
def softmax_relu_pass():
|
||||||
x = AnyPattern()
|
x = Any()
|
||||||
softmax_pattern = IsPrimTypeOf(P.Softmax())
|
softmax_pattern = Prim(P.Softmax())
|
||||||
call_softmax = CallWith(softmax_pattern, inputs=[x])
|
call_softmax = Call(softmax_pattern, inputs=[x])
|
||||||
relu_pattern = IsPrimTypeOf(P.ReLU())
|
relu_pattern = Prim(P.ReLU())
|
||||||
call_relu = CallWith(relu_pattern, inputs=[x])
|
call_relu = Call(relu_pattern, inputs=[x])
|
||||||
|
|
||||||
pattern = IsIn([call_softmax, call_relu])
|
pattern = OneOf([call_softmax, call_relu])
|
||||||
relu6_pattern = IsPrimTypeOf(P.ReLU6(), should_replace=False)
|
relu6_pattern = Prim(P.ReLU6())
|
||||||
target = CallWith(relu6_pattern, inputs=[x])
|
target = Call(relu6_pattern, inputs=[x])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||||
unregiste_pass(softmax_relu_pass)
|
unregiste_pass(softmax_relu_pass)
|
||||||
|
@ -123,18 +120,17 @@ def test_isin_pattern_1():
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@registe_pass(run_only_once=True)
|
||||||
def softmax_neg_pass():
|
def softmax_neg_pass():
|
||||||
x = AnyPattern()
|
x = Any()
|
||||||
softmax_pattern = IsPrimTypeOf(P.Softmax())
|
softmax_pattern = Prim(P.Softmax())
|
||||||
call_softmax = CallWith(softmax_pattern, inputs=[x])
|
call_softmax = Call(softmax_pattern, inputs=[x])
|
||||||
relu_pattern = IsPrimTypeOf(P.ReLU())
|
relu_pattern = Prim(P.ReLU())
|
||||||
call_relu = CallWith(relu_pattern, inputs=[x])
|
call_relu = Call(relu_pattern, inputs=[x])
|
||||||
|
|
||||||
pattern = IsIn([call_softmax, call_relu])
|
pattern = OneOf([call_softmax, call_relu])
|
||||||
neg_ops = IsPrimTypeOf(P.Neg(), should_replace=False)
|
neg_ops = Prim(P.Neg())
|
||||||
target = CallWith(neg_ops, inputs=[pattern])
|
target = Call(neg_ops, inputs=[pattern])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
|
||||||
print(transformed_repr)
|
|
||||||
unregiste_pass(softmax_neg_pass)
|
unregiste_pass(softmax_neg_pass)
|
||||||
assert "Neg" in transformed_repr
|
assert "Neg" in transformed_repr
|
||||||
assert "Softmax" in transformed_repr
|
assert "Softmax" in transformed_repr
|
||||||
|
@ -167,11 +163,11 @@ def test_isnot_pattern_0():
|
||||||
"""
|
"""
|
||||||
Sub a BN which does NOT take Conv as inputs to ReLU6.
|
Sub a BN which does NOT take Conv as inputs to ReLU6.
|
||||||
"""
|
"""
|
||||||
conv2d_prim = IsPrimTypeOf("Conv2D")
|
conv2d_prim = Prim("Conv2D")
|
||||||
conv2d = CallWith(conv2d_prim)
|
conv2d = Call(conv2d_prim)
|
||||||
pattern_0 = IsNot(conv2d)
|
pattern_0 = NoneOf(conv2d)
|
||||||
pattern = CallWith(P.BatchNorm(), inputs=[pattern_0])
|
pattern = Call(P.BatchNorm(), inputs=[pattern_0])
|
||||||
target = CallWith(P.ReLU6(), inputs=[pattern_0])
|
target = Call(P.ReLU6(), inputs=[pattern_0])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@registe_pass(run_only_once=True)
|
||||||
|
@ -179,10 +175,8 @@ def test_isnot_pattern_0():
|
||||||
"""
|
"""
|
||||||
Sub a BN to Softmax.
|
Sub a BN to Softmax.
|
||||||
"""
|
"""
|
||||||
bn = P.BatchNorm()
|
pattern = Call(P.BatchNorm())
|
||||||
pattern = CallWith(bn)
|
target = Call(P.Softmax())
|
||||||
softmax = P.Softmax()
|
|
||||||
target = CallWith(softmax, should_replace=False)
|
|
||||||
return pattern, target
|
return pattern, target
|
||||||
|
|
||||||
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
|
||||||
|
@ -205,12 +199,12 @@ def test_isnot_pattern_1():
|
||||||
"""
|
"""
|
||||||
Sub a BN which does NOT take MatMul as inputs to ReLU6.
|
Sub a BN which does NOT take MatMul as inputs to ReLU6.
|
||||||
"""
|
"""
|
||||||
matmul = IsPrimTypeOf("MatMul")
|
matmul = Prim("MatMul")
|
||||||
pattern_0 = IsNot(matmul)
|
pattern_0 = NoneOf(matmul)
|
||||||
softmax = P.Softmax()
|
softmax = P.Softmax()
|
||||||
pattern = CallWith(softmax, inputs=[pattern_0])
|
pattern = Call(softmax, inputs=[pattern_0])
|
||||||
relu6 = P.ReLU6()
|
relu6 = P.ReLU6()
|
||||||
target = CallWith(relu6, inputs=[pattern_0], should_replace=False)
|
target = Call(relu6, inputs=[pattern_0])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
|
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||||
|
@ -228,14 +222,12 @@ def test_newtensor_pattern():
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@registe_pass(run_only_once=True)
|
||||||
def softmax_addn_pass():
|
def softmax_addn_pass():
|
||||||
x = AnyPattern()
|
x = Any()
|
||||||
softmax = P.Softmax()
|
pattern = Call(P.Softmax(), inputs=[x])
|
||||||
pattern = CallWith(softmax, inputs=[x])
|
|
||||||
|
|
||||||
weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
|
weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
|
||||||
new_weight = NewTensor(weight_tensor)
|
new_weight = NewTensor(weight_tensor)
|
||||||
addn_ops = P.AddN()
|
target = Call(P.AddN(), inputs=[x, new_weight])
|
||||||
target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False)
|
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||||
unregiste_pass(softmax_addn_pass)
|
unregiste_pass(softmax_addn_pass)
|
||||||
|
@ -252,25 +244,23 @@ def test_newparameter_pattern():
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@registe_pass(run_only_once=True)
|
||||||
def softmax_addn_pass():
|
def softmax_addn_pass():
|
||||||
x = AnyPattern()
|
x = Any()
|
||||||
softmax = P.Softmax()
|
pattern = Call(P.Softmax(), inputs=[x])
|
||||||
pattern = CallWith(softmax, inputs=[x])
|
|
||||||
|
|
||||||
default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
|
default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
|
||||||
default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
|
default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
|
||||||
new_para_0 = NewParameter("Merlin", default_tensor0)
|
new_para_0 = NewParameter("Merlin", default_tensor0)
|
||||||
new_para_1 = NewParameter("Arthur", default_tensor1)
|
new_para_1 = NewParameter("Arthur", default_tensor1)
|
||||||
target_0 = CallWith(P.MatMul(), inputs=[new_para_0, new_para_1], should_replace=False)
|
target_0 = Call(P.MatMul(), inputs=[new_para_0, new_para_1])
|
||||||
target = CallWith("make_tuple", inputs=[target_0], should_replace=False)
|
target = Call("make_tuple", inputs=[target_0])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||||
print(transformed_repr)
|
|
||||||
unregiste_pass(softmax_addn_pass)
|
unregiste_pass(softmax_addn_pass)
|
||||||
assert "MatMul" in transformed_repr
|
assert "MatMul" in transformed_repr
|
||||||
assert "make_tuple" in transformed_repr
|
assert "make_tuple" in transformed_repr
|
||||||
assert "Softmax" not in transformed_repr
|
assert "Softmax" not in transformed_repr
|
||||||
|
|
||||||
def test_imm_pattern():
|
def test_imm_target():
|
||||||
"""
|
"""
|
||||||
Test NewParameter pattern in the target
|
Test NewParameter pattern in the target
|
||||||
"""
|
"""
|
||||||
|
@ -278,17 +268,15 @@ def test_imm_pattern():
|
||||||
softmax_model = nn.Softmax()
|
softmax_model = nn.Softmax()
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@registe_pass(run_only_once=True)
|
||||||
def softmax_addn_pass():
|
def softmax_pass():
|
||||||
x = AnyPattern()
|
x = Any()
|
||||||
softmax = P.Softmax()
|
pattern = Call(P.Softmax(), inputs=[x])
|
||||||
pattern = CallWith(softmax, inputs=[x])
|
|
||||||
imm = Imm(0)
|
imm = Imm(0)
|
||||||
target_0 = CallWith("make_tuple", inputs=[pattern], should_replace=False)
|
target_0 = Call("make_tuple", inputs=[pattern])
|
||||||
target = CallWith("tuple_getitem", inputs=[target_0, imm], should_replace=False)
|
target = Call("tuple_getitem", inputs=[target_0, imm])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||||
print(transformed_repr)
|
unregiste_pass(softmax_pass)
|
||||||
unregiste_pass(softmax_addn_pass)
|
|
||||||
assert "make_tuple" in transformed_repr
|
assert "make_tuple" in transformed_repr
|
||||||
assert "tuple_getitem" in transformed_repr
|
assert "tuple_getitem" in transformed_repr
|
||||||
assert "Softmax" in transformed_repr
|
assert "Softmax" in transformed_repr
|
||||||
|
@ -301,21 +289,19 @@ def test_gen_new_parameter():
|
||||||
softmax_model = nn.Softmax()
|
softmax_model = nn.Softmax()
|
||||||
|
|
||||||
default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
|
default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
|
||||||
new_para = NewParameter("Merlin", default_tensor, should_replace=True)
|
new_para = NewParameter("Merlin", default_tensor)
|
||||||
gen_new_parameter(new_para)
|
gen_new_parameter(new_para)
|
||||||
@registe_pass(run_only_once=True)
|
@registe_pass(run_only_once=True)
|
||||||
def softmax_make_tuple_pass():
|
def softmax_make_tuple_pass():
|
||||||
x = AnyPattern()
|
x = Any()
|
||||||
softmax = P.Softmax()
|
softmax = P.Softmax()
|
||||||
pattern = CallWith(softmax, inputs=[x])
|
pattern = Call(softmax, inputs=[x])
|
||||||
|
|
||||||
target = CallWith("make_tuple", inputs=[pattern, new_para], should_replace=False)
|
target = Call("make_tuple", inputs=[pattern, new_para])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||||
print(transformed_repr)
|
|
||||||
assert "Merlin" in transformed_repr
|
assert "Merlin" in transformed_repr
|
||||||
unregiste_pass(softmax_make_tuple_pass)
|
unregiste_pass(softmax_make_tuple_pass)
|
||||||
cancel_new_parameter(new_para)
|
cancel_new_parameter(new_para)
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||||
print(transformed_repr)
|
|
||||||
assert "Merlin" not in transformed_repr
|
assert "Merlin" not in transformed_repr
|
||||||
|
|
Loading…
Reference in New Issue