python pass pattern renaming and interface tweaking
This commit is contained in:
parent
3a16925fa2
commit
641d12d6d9
|
@ -21,25 +21,23 @@ namespace opt {
|
|||
namespace python_pass {
|
||||
int Pattern::g_id_ = 0;
|
||||
|
||||
MatchResultPtr IsPrimTypeOf::match(const AnfNodePtr &node) {
|
||||
MatchResultPtr Prim::match(const AnfNodePtr &node) {
|
||||
if (!IsValueNode<Primitive>(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||
if (IsValueNode<Primitive>(node)) {
|
||||
// iterate over all primitives
|
||||
for (auto &iter : primitives_) {
|
||||
if (IsPrimitive(node, iter) || iter->name() == "*") {
|
||||
matched_prim_ = iter;
|
||||
res->add_entry(shared_from_base<IsPrimTypeOf>(), node);
|
||||
res->add_entry(shared_from_base<Prim>(), node);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MatchResultPtr CallWith::match(const AnfNodePtr &node) {
|
||||
MatchResultPtr Call::match(const AnfNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -71,7 +69,7 @@ MatchResultPtr CallWith::match(const AnfNodePtr &node) {
|
|||
}
|
||||
// If inputs is not specified, add node without looking into its inputs
|
||||
if (p_inputs_size == 0) {
|
||||
res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
|
||||
res->add_entry(shared_from_base<Call>(), cnode->input(0));
|
||||
return res;
|
||||
}
|
||||
bool failed = false;
|
||||
|
@ -86,24 +84,24 @@ MatchResultPtr CallWith::match(const AnfNodePtr &node) {
|
|||
res->merge(input_match_result);
|
||||
}
|
||||
if (!failed) {
|
||||
res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
|
||||
res->add_entry(shared_from_base<Call>(), cnode->input(0));
|
||||
return res;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MatchResultPtr IsIn::match(const AnfNodePtr &node) {
|
||||
MatchResultPtr OneOf::match(const AnfNodePtr &node) {
|
||||
for (auto &iter : patterns_) {
|
||||
auto res = iter->match(node);
|
||||
if (res != nullptr) {
|
||||
res->add_entry(shared_from_base<IsIn>(), node);
|
||||
res->add_entry(shared_from_base<OneOf>(), node);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MatchResultPtr IsNot::match(const AnfNodePtr &node) {
|
||||
MatchResultPtr NoneOf::match(const AnfNodePtr &node) {
|
||||
for (auto &iter : patterns_) {
|
||||
auto res = iter->match(node);
|
||||
if (res != nullptr) {
|
||||
|
@ -111,16 +109,33 @@ MatchResultPtr IsNot::match(const AnfNodePtr &node) {
|
|||
}
|
||||
}
|
||||
auto res = std::make_shared<MatchResult>();
|
||||
res->add_entry(shared_from_base<IsNot>(), node);
|
||||
res->add_entry(shared_from_base<NoneOf>(), node);
|
||||
return res;
|
||||
}
|
||||
|
||||
MatchResultPtr AnyPattern::match(const AnfNodePtr &node) {
|
||||
MatchResultPtr Any::match(const AnfNodePtr &node) {
|
||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||
res->add_entry(shared_from_base<AnyPattern>(), node);
|
||||
res->add_entry(shared_from_base<Any>(), node);
|
||||
return res;
|
||||
}
|
||||
|
||||
MatchResultPtr Imm::match(const AnfNodePtr &node) {
|
||||
if (!IsValueNode<Int32Imm>(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
// Check value
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto value_ptr = value_node->value()->cast<Int32ImmPtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
if ((int32_t)value_ptr->value() == value_) {
|
||||
MatchResultPtr res = std::make_shared<MatchResult>();
|
||||
res->add_entry(shared_from_base<Imm>(), node);
|
||||
return res;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
|
||||
auto entry = match_result_.find(pattern);
|
||||
if (entry == match_result_.end()) {
|
||||
|
@ -140,20 +155,20 @@ void MatchResult::merge(const MatchResultPtr &other_result) {
|
|||
REGISTER_PYBIND_DEFINE(
|
||||
Pattern, ([](const py::module *m) {
|
||||
(void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
|
||||
(void)py::class_<IsIn, std::shared_ptr<IsIn>, Pattern>(*m, "IsIn_").def(py::init<vector<PatternPtr>>());
|
||||
(void)py::class_<IsPrimTypeOf, std::shared_ptr<IsPrimTypeOf>, Pattern>(*m, "IsPrimTypeOf_", py::dynamic_attr())
|
||||
.def(py::init<vector<PrimitivePyPtr>, string, bool>())
|
||||
.def(py::init<vector<string>, string, bool>());
|
||||
(void)py::class_<CallWith, std::shared_ptr<CallWith>, Pattern>(*m, "CallWith_")
|
||||
.def(py::init<PatternPtr, vector<PatternPtr>, bool>())
|
||||
.def(py::init<PrimitivePyPtr, vector<PatternPtr>, bool>())
|
||||
.def(py::init<string, vector<PatternPtr>, bool>());
|
||||
(void)py::class_<IsNot, std::shared_ptr<IsNot>, Pattern>(*m, "IsNot_").def(py::init<vector<PatternPtr>>());
|
||||
(void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>());
|
||||
(void)py::class_<OneOf, std::shared_ptr<OneOf>, Pattern>(*m, "OneOf_").def(py::init<vector<PatternPtr>>());
|
||||
(void)py::class_<Prim, std::shared_ptr<Prim>, Pattern>(*m, "Prim_", py::dynamic_attr())
|
||||
.def(py::init<vector<PrimitivePyPtr>, string>())
|
||||
.def(py::init<vector<string>, string>());
|
||||
(void)py::class_<Call, std::shared_ptr<Call>, Pattern>(*m, "Call_")
|
||||
.def(py::init<PatternPtr, vector<PatternPtr>>())
|
||||
.def(py::init<PrimitivePyPtr, vector<PatternPtr>>())
|
||||
.def(py::init<string, vector<PatternPtr>>());
|
||||
(void)py::class_<NoneOf, std::shared_ptr<NoneOf>, Pattern>(*m, "NoneOf_").def(py::init<vector<PatternPtr>>());
|
||||
(void)py::class_<Any, std::shared_ptr<Any>, Pattern>(*m, "Any").def(py::init<>());
|
||||
(void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
|
||||
.def(py::init<tensor::TensorPtr>());
|
||||
(void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_")
|
||||
.def(py::init<string, tensor::TensorPtr, bool, bool, bool>());
|
||||
.def(py::init<string, tensor::TensorPtr, bool, bool>());
|
||||
(void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int>());
|
||||
}));
|
||||
} // namespace python_pass
|
||||
|
|
|
@ -36,10 +36,10 @@ class MatchResult;
|
|||
using MatchResultPtr = std::shared_ptr<MatchResult>;
|
||||
class Pattern;
|
||||
using PatternPtr = std::shared_ptr<Pattern>;
|
||||
class IsPrimTypeOf;
|
||||
using IsPrimTypeOfPtr = std::shared_ptr<IsPrimTypeOf>;
|
||||
class CallWith;
|
||||
using CallWithPtr = std::shared_ptr<CallWith>;
|
||||
class Prim;
|
||||
using PrimPtr = std::shared_ptr<Prim>;
|
||||
class Call;
|
||||
using CallPtr = std::shared_ptr<Call>;
|
||||
class NewTensor;
|
||||
using NewTensorPtr = std::shared_ptr<NewTensor>;
|
||||
class NewParameter;
|
||||
|
@ -58,8 +58,6 @@ class Pattern : public Base {
|
|||
virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; }
|
||||
string unique_name() const { return unique_name_; }
|
||||
vector<PatternPtr> inputs() { return inputs_; }
|
||||
bool should_replace() { return should_replace_; }
|
||||
void set_should_replace(bool should_replace) { should_replace_ = should_replace; }
|
||||
virtual void reset() {}
|
||||
|
||||
protected:
|
||||
|
@ -67,7 +65,6 @@ class Pattern : public Base {
|
|||
// NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed
|
||||
string unique_name_;
|
||||
vector<PatternPtr> inputs_;
|
||||
bool should_replace_ = true;
|
||||
};
|
||||
|
||||
struct PatternEqual {
|
||||
|
@ -85,70 +82,61 @@ struct PatternHasher {
|
|||
}
|
||||
};
|
||||
|
||||
class IsPrimTypeOf : public Pattern {
|
||||
class Prim : public Pattern {
|
||||
public:
|
||||
IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); }
|
||||
~IsPrimTypeOf() = default;
|
||||
IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace)
|
||||
: primitives_(prims), name_(name), matched_prim_(nullptr) {
|
||||
unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
|
||||
should_replace_ = should_replace;
|
||||
if (!should_replace) {
|
||||
matched_prim_ = prims[0];
|
||||
Prim() { unique_name_ = std::to_string(g_id_++); }
|
||||
~Prim() = default;
|
||||
Prim(vector<PrimitivePyPtr> prims, string name) : primitives_(prims), name_(name) {
|
||||
unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
|
||||
// Default using the first prim to build target
|
||||
matched_prim_ = primitives_[0];
|
||||
}
|
||||
}
|
||||
IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) {
|
||||
unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
|
||||
Prim(vector<string> types, string name) : types_(types), name_(name) {
|
||||
unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
|
||||
// Make primitives_
|
||||
for (auto &iter : types) {
|
||||
primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
|
||||
}
|
||||
should_replace_ = should_replace;
|
||||
if (!should_replace) {
|
||||
// Default using the first prim to build target
|
||||
matched_prim_ = primitives_[0];
|
||||
}
|
||||
}
|
||||
MS_DECLARE_PARENT(IsPrimTypeOf, Pattern);
|
||||
MS_DECLARE_PARENT(Prim, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
PrimitivePyPtr matched_primitive() { return matched_prim_; }
|
||||
void reset() override {
|
||||
if (should_replace_) {
|
||||
matched_prim_ = nullptr;
|
||||
}
|
||||
// Init before reset
|
||||
MS_EXCEPTION_IF_NULL(matched_prim_);
|
||||
matched_prim_ = primitives_[0];
|
||||
}
|
||||
|
||||
private:
|
||||
vector<string> types_;
|
||||
vector<PrimitivePyPtr> primitives_;
|
||||
string name_;
|
||||
PrimitivePyPtr matched_prim_;
|
||||
PrimitivePyPtr matched_prim_{nullptr};
|
||||
};
|
||||
|
||||
class CallWith : public Pattern {
|
||||
class Call : public Pattern {
|
||||
public:
|
||||
CallWith() { unique_name_ = std::to_string(g_id_++); }
|
||||
~CallWith() = default;
|
||||
CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) {
|
||||
Call() { unique_name_ = std::to_string(g_id_++); }
|
||||
~Call() = default;
|
||||
Call(PatternPtr prim_pattern, vector<PatternPtr> inputs) {
|
||||
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
|
||||
prim_pattern_ = prim_pattern;
|
||||
unique_name_ = std::to_string(g_id_++) + "CallWithPattern_" + prim_pattern->unique_name();
|
||||
unique_name_ = std::to_string(g_id_++) + "Call_" + prim_pattern->unique_name();
|
||||
inputs_ = inputs;
|
||||
// NOTE: should_replace_ is overrided by it prim_pattern(if exists) silently.
|
||||
should_replace_ = prim_pattern->should_replace();
|
||||
}
|
||||
CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) {
|
||||
Call(PrimitivePyPtr prim, vector<PatternPtr> inputs) {
|
||||
prim_ = prim;
|
||||
unique_name_ = std::to_string(g_id_++) + "CallWithPrim_" + prim_->ToString();
|
||||
unique_name_ = std::to_string(g_id_++) + "Call_" + prim_->ToString();
|
||||
inputs_ = inputs;
|
||||
should_replace_ = should_replace;
|
||||
}
|
||||
CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) {
|
||||
Call(string prim_str, vector<PatternPtr> inputs) {
|
||||
prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
|
||||
unique_name_ = std::to_string(g_id_++) + "CallWithStr_" + prim_->ToString();
|
||||
unique_name_ = std::to_string(g_id_++) + "CallStr_" + prim_->ToString();
|
||||
inputs_ = inputs;
|
||||
should_replace_ = should_replace;
|
||||
}
|
||||
MS_DECLARE_PARENT(CallWith, Pattern);
|
||||
MS_DECLARE_PARENT(Call, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
PrimitivePtr prim_value() { return prim_; }
|
||||
PatternPtr prim_pattern() { return prim_pattern_; }
|
||||
|
@ -160,45 +148,45 @@ class CallWith : public Pattern {
|
|||
string name_;
|
||||
};
|
||||
|
||||
class IsIn : public Pattern {
|
||||
class OneOf : public Pattern {
|
||||
public:
|
||||
IsIn() { unique_name_ = std::to_string(g_id_++); }
|
||||
~IsIn() = default;
|
||||
explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) {
|
||||
unique_name_ = std::to_string(g_id_++) + "IsIn";
|
||||
OneOf() { unique_name_ = std::to_string(g_id_++); }
|
||||
~OneOf() = default;
|
||||
explicit OneOf(vector<PatternPtr> patterns) : patterns_(patterns) {
|
||||
unique_name_ = std::to_string(g_id_++) + "OneOf";
|
||||
for (auto &iter : patterns) {
|
||||
unique_name_ = unique_name_ + "_" + iter->unique_name();
|
||||
}
|
||||
}
|
||||
MS_DECLARE_PARENT(IsIn, Pattern);
|
||||
MS_DECLARE_PARENT(OneOf, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
vector<PatternPtr> patterns_;
|
||||
};
|
||||
|
||||
class IsNot : public Pattern {
|
||||
class NoneOf : public Pattern {
|
||||
public:
|
||||
IsNot() { unique_name_ = std::to_string(g_id_++); }
|
||||
~IsNot() = default;
|
||||
explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) {
|
||||
unique_name_ = std::to_string(g_id_++) + "IsNot";
|
||||
NoneOf() { unique_name_ = std::to_string(g_id_++); }
|
||||
~NoneOf() = default;
|
||||
explicit NoneOf(vector<PatternPtr> patterns) : patterns_(patterns) {
|
||||
unique_name_ = std::to_string(g_id_++) + "NoneOf";
|
||||
for (auto &iter : patterns) {
|
||||
unique_name_ = unique_name_ + "_" + iter->unique_name();
|
||||
}
|
||||
}
|
||||
MS_DECLARE_PARENT(IsNot, Pattern);
|
||||
MS_DECLARE_PARENT(NoneOf, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
vector<PatternPtr> patterns_;
|
||||
};
|
||||
|
||||
class AnyPattern : public Pattern {
|
||||
class Any : public Pattern {
|
||||
public:
|
||||
AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; }
|
||||
~AnyPattern() = default;
|
||||
MS_DECLARE_PARENT(AnyPattern, Pattern);
|
||||
Any() { unique_name_ = std::to_string(g_id_++) + "_Any"; }
|
||||
~Any() = default;
|
||||
MS_DECLARE_PARENT(Any, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
|
@ -207,7 +195,6 @@ class NewTensor : public Pattern {
|
|||
NewTensor() { unique_name_ = std::to_string(g_id_++); }
|
||||
~NewTensor() = default;
|
||||
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) {
|
||||
should_replace_ = false;
|
||||
unique_name_ = std::to_string(g_id_++) + "NewTensor";
|
||||
}
|
||||
MS_DECLARE_PARENT(NewTensor, Pattern);
|
||||
|
@ -223,10 +210,8 @@ class NewTensor : public Pattern {
|
|||
class NewParameter : public Pattern {
|
||||
public:
|
||||
NewParameter() { unique_name_ = std::to_string(g_id_++); }
|
||||
explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel,
|
||||
bool should_replace)
|
||||
explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel)
|
||||
: para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
|
||||
should_replace_ = should_replace;
|
||||
unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
|
||||
// clone input tensor
|
||||
default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
|
||||
|
@ -243,11 +228,14 @@ class NewParameter : public Pattern {
|
|||
bool built() { return built_; }
|
||||
void set_built(bool built) { built_ = built; }
|
||||
void reset() override { built_ = false; }
|
||||
bool should_last() { return last_across_passes_; }
|
||||
void set_last(bool last) { last_across_passes_ = last; }
|
||||
|
||||
private:
|
||||
string para_name_;
|
||||
bool requires_grad_;
|
||||
bool layerwise_parallel_;
|
||||
bool last_across_passes_{false};
|
||||
bool built_;
|
||||
tensor::TensorPtr default_tensor_;
|
||||
};
|
||||
|
@ -255,13 +243,9 @@ class NewParameter : public Pattern {
|
|||
class Imm : public Pattern {
|
||||
public:
|
||||
Imm() { unique_name_ = std::to_string(g_id_++); }
|
||||
explicit Imm(int value) : value_(value) {
|
||||
should_replace_ = false;
|
||||
unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value);
|
||||
}
|
||||
explicit Imm(int value) : value_(value) { unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value); }
|
||||
MS_DECLARE_PARENT(Imm, Pattern);
|
||||
// NOTE: Doesn't support Imm in src pattern currently.
|
||||
MatchResultPtr match(const AnfNodePtr &node) override { return nullptr; }
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
int value() { return value_; }
|
||||
|
||||
private:
|
||||
|
|
|
@ -80,7 +80,7 @@ bool IsTraversable(const AnfNodePtr &node) {
|
|||
|
||||
AnfNodePtr BuildPrimitive(const PatternPtr &pattern, const MatchResultPtr &res) {
|
||||
// Build up AnfNode from primitive
|
||||
auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>();
|
||||
auto prim_pattern = pattern->cast<PrimPtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_pattern);
|
||||
PrimitivePyPtr prim = prim_pattern->matched_primitive();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
@ -98,13 +98,13 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
|
|||
}
|
||||
|
||||
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) {
|
||||
auto call_with_pattern = pattern->cast<CallWithPtr>();
|
||||
MS_EXCEPTION_IF_NULL(call_with_pattern);
|
||||
auto prim = call_with_pattern->prim_value();
|
||||
auto call_pattern = pattern->cast<CallPtr>();
|
||||
MS_EXCEPTION_IF_NULL(call_pattern);
|
||||
auto prim = call_pattern->prim_value();
|
||||
if (prim != nullptr) {
|
||||
return std::make_shared<ValueNode>(prim);
|
||||
}
|
||||
auto prim_pattern = call_with_pattern->prim_pattern();
|
||||
auto prim_pattern = call_pattern->prim_pattern();
|
||||
MS_EXCEPTION_IF_NULL(prim_pattern);
|
||||
return ProcessSinglePattern(prim_pattern, res, fg);
|
||||
}
|
||||
|
@ -152,45 +152,35 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
|
|||
}
|
||||
|
||||
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
|
||||
if (pattern->should_replace()) {
|
||||
// Find replacement in the MatchResult
|
||||
auto target_node = res->get_node(pattern);
|
||||
if (target_node == nullptr) {
|
||||
// If it's base pattern(in contrast to complex pattern like CallWith/IsIn/IsNot), raise runtime exception.
|
||||
if (pattern->isa<IsPrimTypeOf>() || pattern->isa<NewTensor>() || pattern->isa<NewParameter>()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find target node, pattern: " + pattern->unique_name() + "\n";
|
||||
return nullptr;
|
||||
}
|
||||
// Try to build this pattern and add to MatchResult, since this pattern is defined inside target
|
||||
auto new_node = BuildTarget(pattern, func_graph, res);
|
||||
if (new_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Try to build pattern node but FAILED. pattern: " + pattern->unique_name() + "\n";
|
||||
}
|
||||
return new_node;
|
||||
}
|
||||
if (pattern->isa<NewParameter>()) {
|
||||
if (target_node != nullptr) {
|
||||
// If pattern is NewParameter, check whether it shouldn't last and is not built
|
||||
auto new_para = pattern->cast<NewParameterPtr>();
|
||||
if (new_para == nullptr || new_para->should_last() || new_para->built()) {
|
||||
return target_node;
|
||||
}
|
||||
return target_node;
|
||||
}
|
||||
// Build up new node from pattern
|
||||
if (pattern->isa<IsPrimTypeOf>()) {
|
||||
if (pattern->isa<Prim>()) {
|
||||
return BuildPrimitive(pattern, res);
|
||||
} else if (pattern->isa<NewTensor>()) {
|
||||
return BuildNewTensor(pattern, res);
|
||||
} else if (pattern->isa<CallWith>()) {
|
||||
} else if (pattern->isa<Call>()) {
|
||||
return BuildPrimitiveValueNode(pattern, res, func_graph);
|
||||
} else if (pattern->isa<NewParameter>()) {
|
||||
return BuildNewParameter(pattern, res, func_graph);
|
||||
} else if (pattern->isa<Imm>()) {
|
||||
return BuildImmNode(pattern, res);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot find or build target node, pattern: " + pattern->unique_name() + "\n";
|
||||
return nullptr;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
|
||||
const FuncGraphPtr &func_graph) {
|
||||
if (pattern->isa<CallWith>()) {
|
||||
if (pattern->isa<Call>()) {
|
||||
return BuildPrimitiveValueNode(pattern, res, func_graph);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -269,16 +259,16 @@ void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor
|
|||
}
|
||||
|
||||
void Reset(PatternPtr pattern) {
|
||||
if (pattern->isa<IsPrimTypeOf>()) {
|
||||
auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>();
|
||||
if (pattern->isa<Prim>()) {
|
||||
auto prim_pattern = pattern->cast<PrimPtr>();
|
||||
prim_pattern->reset();
|
||||
return;
|
||||
} else if (pattern->isa<NewParameter>()) {
|
||||
auto new_param_pattern = pattern->cast<NewParameterPtr>();
|
||||
new_param_pattern->reset();
|
||||
return;
|
||||
} else if (pattern->isa<CallWith>()) {
|
||||
auto call_with_pattern = pattern->cast<CallWithPtr>();
|
||||
} else if (pattern->isa<Call>()) {
|
||||
auto call_with_pattern = pattern->cast<CallPtr>();
|
||||
for (auto sub_pattern : call_with_pattern->inputs()) {
|
||||
Reset(sub_pattern);
|
||||
}
|
||||
|
|
|
@ -49,8 +49,9 @@ PyPassManager::PyPassManager() {
|
|||
}
|
||||
|
||||
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
||||
Phase phase, bool run_only_once) {
|
||||
auto cur_pg = GetPassGroup(phase);
|
||||
bool run_only_once) {
|
||||
// NOTE: remove phase option to avoid unnecessary confusion.
|
||||
auto cur_pg = GetPassGroup(Phase::OPT);
|
||||
MS_EXCEPTION_IF_NULL(cur_pg);
|
||||
cur_pg->SetRunOnlyOnce(run_only_once);
|
||||
MS_EXCEPTION_IF_NULL(pattern);
|
||||
|
@ -60,8 +61,9 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
|
|||
cur_pg->AddPass(new_pass);
|
||||
}
|
||||
|
||||
void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
|
||||
auto cur_pm = GetPassGroup(phase);
|
||||
void PyPassManager::Unregiste(const std::string &pass_name) {
|
||||
// NOTE: remove phase option to avoid unnecessary confusion.
|
||||
auto cur_pm = GetPassGroup(Phase::OPT);
|
||||
MS_EXCEPTION_IF_NULL(cur_pm);
|
||||
if (!cur_pm->DeletePass(pass_name)) {
|
||||
MS_LOG(WARNING) << "No such pass : " + pass_name + "\n";
|
||||
|
@ -70,7 +72,6 @@ void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
|
|||
|
||||
void PyPassManager::GenNewParameter(const PatternPtr ¶meter) {
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
// Add new parameter after resolve
|
||||
// NOTE: Add NewParameter at early stage will cause CSE problems
|
||||
auto cur_pg = GetPassGroup(Phase::OPT);
|
||||
MS_EXCEPTION_IF_NULL(cur_pg);
|
||||
|
@ -78,7 +79,7 @@ void PyPassManager::GenNewParameter(const PatternPtr ¶meter) {
|
|||
auto new_para_pattern = parameter->cast<NewParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_para_pattern);
|
||||
auto pass_name = new_para_pattern->para_name();
|
||||
parameter->set_should_replace(false);
|
||||
new_para_pattern->set_last(true);
|
||||
auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true);
|
||||
cur_pg->AddPass(new_pass);
|
||||
}
|
||||
|
|
|
@ -53,16 +53,17 @@ class PyPassManager {
|
|||
static PyPassManagerPtr GetInstance();
|
||||
virtual ~PyPassManager() = default;
|
||||
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
|
||||
Phase phase = Phase::RESOLVE, bool run_only_once = false);
|
||||
void Unregiste(const std::string &pass_name, Phase phase);
|
||||
bool run_only_once = false);
|
||||
void Unregiste(const std::string &pass_name);
|
||||
void GenNewParameter(const PatternPtr ¶meter);
|
||||
PassGroupPtr GetPassGroup(Phase phase);
|
||||
void ClearRes();
|
||||
MatchResultPtr GetMatchResult() { return res_; }
|
||||
void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
|
||||
bool ShouldRenorm() { return should_renorm_; }
|
||||
void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
|
||||
pipeline::ResourcePtr GetResource() { return resource_; }
|
||||
void ClearRes();
|
||||
void ClearPipelineRes() { resource_ = nullptr; }
|
||||
|
||||
private:
|
||||
bool should_renorm_ = true;
|
||||
|
|
|
@ -477,6 +477,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
|
|||
// save the run graph func to MsPipeLine
|
||||
SaveCompiledGraph(phase_s);
|
||||
|
||||
opt::python_pass::PyPassManager::GetInstance()->ClearPipelineRes();
|
||||
resource->Clean();
|
||||
// Reclaim all resource used by optimizer;
|
||||
ReclaimOptimizer();
|
||||
|
|
|
@ -15,50 +15,43 @@
|
|||
"""Patterns for describing graphs"""
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_,\
|
||||
NewParameter_, Imm
|
||||
from mindspore._c_expression import Pattern, OneOf_, Prim_, Call_, NoneOf_, Any, NewTensor_, NewParameter_, Imm
|
||||
|
||||
__all__ = [
|
||||
"IsIn",
|
||||
"IsPrimTypeOf",
|
||||
"CallWith",
|
||||
"IsNot",
|
||||
"AnyPattern",
|
||||
"OneOf",
|
||||
"Prim",
|
||||
"Call",
|
||||
"NoneOf",
|
||||
"Any",
|
||||
"NewTensor",
|
||||
"NewParameter",
|
||||
"Imm"
|
||||
]
|
||||
|
||||
class IsIn(IsIn_):
|
||||
class OneOf(OneOf_):
|
||||
r"""
|
||||
Express a pattern which allows a list of patterns.
|
||||
"""
|
||||
def __init__(self, patterns=None, should_replace=True):
|
||||
def __init__(self, patterns=None):
|
||||
r"""
|
||||
Args:
|
||||
patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`],
|
||||
patterns(Union[:class:`mindspore.graph_utils.graph_pattern`,
|
||||
tuple[:class:`mindspore.graph_utils.graph_pattern`],
|
||||
list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns,
|
||||
each element should be one of the exposed Pattern instance.
|
||||
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
|
||||
|
||||
Raises:
|
||||
ValueError: raise if should_replace is False
|
||||
TypeError: raise type error for invalid inputs.
|
||||
"""
|
||||
if not should_replace:
|
||||
raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \
|
||||
its sub-pattern instead.")
|
||||
self.patterns = patterns
|
||||
if patterns is None:
|
||||
IsIn_.__init__(self, ())
|
||||
elif isinstance(patterns, Pattern):
|
||||
IsIn_.__init__(self, [patterns])
|
||||
if isinstance(patterns, Pattern):
|
||||
OneOf_.__init__(self, [patterns])
|
||||
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
|
||||
IsIn_.__init__(self, patterns)
|
||||
OneOf_.__init__(self, patterns)
|
||||
else:
|
||||
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
|
||||
|
||||
class IsPrimTypeOf(IsPrimTypeOf_):
|
||||
class Prim(Prim_):
|
||||
r"""
|
||||
Express a pattern of certain primitive type(s).
|
||||
|
||||
|
@ -66,7 +59,7 @@ class IsPrimTypeOf(IsPrimTypeOf_):
|
|||
This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
|
||||
please refer to CallWith pattern.
|
||||
"""
|
||||
def __init__(self, types, name=None, should_replace=True):
|
||||
def __init__(self, types, name=None):
|
||||
r"""
|
||||
Args:
|
||||
types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
|
||||
|
@ -77,9 +70,6 @@ class IsPrimTypeOf(IsPrimTypeOf_):
|
|||
2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
|
||||
It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
|
||||
name (str): name of the pattern, optional. Default: None.
|
||||
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
|
||||
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
|
||||
Default: True.
|
||||
|
||||
Raises:
|
||||
TypeError: raise type error for invalid argument.
|
||||
|
@ -103,13 +93,13 @@ class IsPrimTypeOf(IsPrimTypeOf_):
|
|||
self.types = types
|
||||
else:
|
||||
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
|
||||
IsPrimTypeOf_.__init__(self, self.types, self.name, should_replace)
|
||||
Prim_.__init__(self, self.types, self.name)
|
||||
|
||||
class CallWith(CallWith_):
|
||||
class Call(Call_):
|
||||
r"""
|
||||
Express a primitive CNode.
|
||||
"""
|
||||
def __init__(self, prim_pattern, inputs=None, should_replace=True):
|
||||
def __init__(self, prim_pattern, inputs=None):
|
||||
r"""
|
||||
Args:
|
||||
prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`,
|
||||
|
@ -118,9 +108,6 @@ class CallWith(CallWith_):
|
|||
tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
|
||||
Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input
|
||||
patterns should be of right order and each element should be one of the exposed Pattern instance.
|
||||
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
|
||||
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
|
||||
Default: True.
|
||||
|
||||
Raises:
|
||||
TypeError: raise type error for invalid argument.
|
||||
|
@ -135,36 +122,31 @@ class CallWith(CallWith_):
|
|||
self.inputs = inputs
|
||||
else:
|
||||
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
|
||||
CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace)
|
||||
Call_.__init__(self, self.prim_pattern, self.inputs)
|
||||
|
||||
class IsNot(IsNot_):
|
||||
class NoneOf(NoneOf_):
|
||||
r"""
|
||||
Express a pattern which forbids a list of patterns.
|
||||
|
||||
NOTE:
|
||||
IsNot pattern should not be the root pattern.
|
||||
NoneOf pattern should not be the root pattern.
|
||||
"""
|
||||
def __init__(self, patterns=None, should_replace=True):
|
||||
def __init__(self, patterns=None):
|
||||
r"""
|
||||
Args:
|
||||
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element
|
||||
should be one of the exposed Pattern instance.
|
||||
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
|
||||
|
||||
Raises:
|
||||
ValueError: raise if should_replace is False.
|
||||
TypeError: raise type error for invalid argument.
|
||||
"""
|
||||
if not should_replace:
|
||||
raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \
|
||||
its sub-pattern instead.")
|
||||
self.patterns = patterns
|
||||
if patterns is None:
|
||||
IsNot_.__init__(self, ())
|
||||
NoneOf_.__init__(self, ())
|
||||
elif isinstance(patterns, Pattern):
|
||||
IsNot_.__init__(self, [patterns])
|
||||
NoneOf_.__init__(self, [patterns])
|
||||
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
|
||||
IsNot_.__init__(self, patterns)
|
||||
NoneOf_.__init__(self, patterns)
|
||||
else:
|
||||
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
|
||||
|
||||
|
@ -172,18 +154,14 @@ class NewTensor(NewTensor_):
|
|||
r"""
|
||||
New Tensor to be used in the target.
|
||||
"""
|
||||
def __init__(self, input_tensor, should_replace=False):
|
||||
def __init__(self, input_tensor):
|
||||
r"""
|
||||
Args:
|
||||
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
|
||||
should_replace(bool): added this for interface consistency. NewTensor should only appear in the target.
|
||||
|
||||
Raises:
|
||||
ValueError: raise if should_replace is True
|
||||
TypeError: raise type error for invalid argument.
|
||||
"""
|
||||
if should_replace:
|
||||
raise ValueError("NewTensor should only appear in the target, thus should_replace can only be False.")
|
||||
self.input_tensor = input_tensor
|
||||
if isinstance(input_tensor, Tensor):
|
||||
NewTensor_.__init__(self, input_tensor)
|
||||
|
@ -194,15 +172,13 @@ class NewParameter(NewParameter_):
|
|||
r"""
|
||||
New Parameter to be used in the target.
|
||||
"""
|
||||
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False, should_replace=False):
|
||||
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False):
|
||||
r"""
|
||||
Args:
|
||||
para_name(str): name for the new Parameter
|
||||
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
|
||||
requires_grad(bool): True if the parameter requires gradient. Default: True
|
||||
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
|
||||
should_replace(bool): gen new parameter once and replace after if set to be true; otherwise build a new
|
||||
parameter everytime a pass target got built. Default: False
|
||||
|
||||
Raises:
|
||||
TypeError: raise type error for invalid argument.
|
||||
|
@ -211,12 +187,11 @@ class NewParameter(NewParameter_):
|
|||
self.default_tensor = default_tensor
|
||||
self.requires_grad = requires_grad
|
||||
self.layerwise_parallel = layerwise_parallel
|
||||
self.should_replace = should_replace
|
||||
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
|
||||
isinstance(layerwise_parallel, bool) and isinstance(should_replace, bool):
|
||||
isinstance(layerwise_parallel, bool):
|
||||
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
|
||||
self.layerwise_parallel, self.should_replace)
|
||||
self.layerwise_parallel)
|
||||
else:
|
||||
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
|
||||
layerwise_parallel(bool) should_replace(bool), got : {para_name}, {default_tensor}, \
|
||||
{requires_grad}, {layerwise_parallel}, {should_replace}")
|
||||
layerwise_parallel(bool), got : {para_name}, {default_tensor}, \
|
||||
{requires_grad}, {layerwise_parallel}")
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"""Python pass register"""
|
||||
from inspect import isfunction
|
||||
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
|
||||
from mindspore._c_expression import PyPassManager_, phase
|
||||
from mindspore._c_expression import PyPassManager_
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -30,21 +30,16 @@ class PyPassManager(PyPassManager_):
|
|||
Used to registe and unregiste python passes which can be used to alter graphs.
|
||||
|
||||
Args:
|
||||
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
|
||||
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
|
||||
multigraph (bool): Whether or not the pattern exists across graphs. Default: True.
|
||||
|
||||
Raises:
|
||||
TypeError: If argument has invalid type.
|
||||
"""
|
||||
def __init__(self, pipeline_phase=phase.opt, run_only_once=False):
|
||||
if not isinstance(pipeline_phase, phase):
|
||||
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
||||
def __init__(self, run_only_once=False):
|
||||
if not isinstance(run_only_once, bool):
|
||||
raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
|
||||
PyPassManager_.__init__(self)
|
||||
self.phase_ = pipeline_phase
|
||||
self.run_only_once_ = run_only_once
|
||||
PyPassManager_.__init__(self)
|
||||
|
||||
def registe(self, py_pass):
|
||||
if not isfunction(py_pass):
|
||||
|
@ -55,16 +50,14 @@ class PyPassManager(PyPassManager_):
|
|||
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
|
||||
if not isinstance(target, Pattern):
|
||||
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
|
||||
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_)
|
||||
super().registe(pass_name, pattern, target, self.run_only_once_)
|
||||
|
||||
def unregiste(self, py_pass, pipeline_phase=phase.opt):
|
||||
if not isinstance(pipeline_phase, phase):
|
||||
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
|
||||
def unregiste(self, py_pass):
|
||||
if isinstance(py_pass, str):
|
||||
super().unregiste(py_pass, pipeline_phase)
|
||||
super().unregiste(py_pass)
|
||||
return
|
||||
if isfunction(py_pass):
|
||||
super().unregiste(py_pass.__name__, pipeline_phase)
|
||||
super().unregiste(py_pass.__name__)
|
||||
return
|
||||
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
|
||||
|
||||
|
@ -82,13 +75,11 @@ class PyPassManager(PyPassManager_):
|
|||
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
|
||||
super().set_renorm(should_renorm)
|
||||
|
||||
def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
|
||||
def registe_pass(run_only_once=False):
|
||||
"""
|
||||
Registe python pass to specified pipeline phase which would be used in compilation.
|
||||
|
||||
Args:
|
||||
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
|
||||
registed. Support phase.resolve and phase.opt. Default: phase.opt.
|
||||
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
|
||||
|
||||
Returns:
|
||||
|
@ -102,19 +93,17 @@ def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
|
|||
>>> target = IsPrimTypeOf("ReLU6")
|
||||
>>> return pattern, target
|
||||
"""
|
||||
return PyPassManager(pipeline_phase, run_only_once)
|
||||
return PyPassManager(run_only_once)
|
||||
|
||||
def unregiste_pass(py_pass, pipeline_phase=phase.opt):
|
||||
def unregiste_pass(py_pass):
|
||||
"""
|
||||
Unregiste python pass.
|
||||
|
||||
Args:
|
||||
py_pass(Union(str, function)): target python pass to unregiste.
|
||||
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
|
||||
unregisted. Support phase.resolve and phase.opt. Default: phase.opt.
|
||||
"""
|
||||
ppm = PyPassManager()
|
||||
ppm.unregiste(py_pass, pipeline_phase)
|
||||
ppm.unregiste(py_pass)
|
||||
|
||||
def gen_new_parameter(pattern):
|
||||
"""
|
||||
|
@ -164,7 +153,14 @@ def cancel_new_parameter(pattern):
|
|||
|
||||
def set_renorm(should_renorm):
|
||||
"""
|
||||
Set whether or not to do renorm after modified graph in python pass(es).
|
||||
Set whether or not to do renormalization after modified graph in python pass(es).
|
||||
|
||||
Args:
|
||||
should_renorm(bool): whether or not to do renormalization after modified graph in python pass(es).
|
||||
|
||||
NOTE:
|
||||
This interface is mainly intended for testing modifying graph without worrying about its validity. Turn off
|
||||
renormalization may BREAK the network.
|
||||
"""
|
||||
ppm = PyPassManager()
|
||||
ppm.set_renorm(should_renorm)
|
||||
|
|
|
@ -23,8 +23,7 @@ from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_
|
|||
cancel_new_parameter
|
||||
from mindspore.common.api import _generate_pip_args
|
||||
from mindspore._c_expression import generate_key, Executor_
|
||||
from mindspore.graph_utils.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor,\
|
||||
NewParameter, Imm
|
||||
from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -50,11 +49,9 @@ def test_softmax_relu():
|
|||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
x = AnyPattern()
|
||||
softmax_pattern = IsPrimTypeOf(P.Softmax())
|
||||
pattern = CallWith(softmax_pattern, inputs=[x])
|
||||
relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
|
||||
target = CallWith(relu_pattern, inputs=[x])
|
||||
x = Any()
|
||||
pattern = Call(P.Softmax(), inputs=[x])
|
||||
target = Call(P.ReLU(), inputs=[x])
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
|
@ -74,13 +71,13 @@ def test_softmax_relu_sigmoid():
|
|||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
x = AnyPattern()
|
||||
softmax_pattern = IsPrimTypeOf(P.Softmax())
|
||||
pattern = CallWith(softmax_pattern, inputs=[x])
|
||||
sigmoid_pattern = IsPrimTypeOf(P.Sigmoid(), should_replace=False)
|
||||
call_sigmoid = CallWith(sigmoid_pattern, [x])
|
||||
relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
|
||||
target = CallWith(relu_pattern, inputs=[call_sigmoid])
|
||||
x = Any()
|
||||
softmax_pattern = Prim(P.Softmax())
|
||||
pattern = Call(softmax_pattern, inputs=[x])
|
||||
sigmoid_pattern = Prim(P.Sigmoid())
|
||||
call_sigmoid = Call(sigmoid_pattern, [x])
|
||||
relu_pattern = Prim(P.ReLU())
|
||||
target = Call(relu_pattern, inputs=[call_sigmoid])
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
|
||||
|
@ -99,15 +96,15 @@ def test_isin_pattern_0():
|
|||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_relu_pass():
|
||||
x = AnyPattern()
|
||||
softmax_pattern = IsPrimTypeOf(P.Softmax())
|
||||
call_softmax = CallWith(softmax_pattern, inputs=[x])
|
||||
relu_pattern = IsPrimTypeOf(P.ReLU())
|
||||
call_relu = CallWith(relu_pattern, inputs=[x])
|
||||
x = Any()
|
||||
softmax_pattern = Prim(P.Softmax())
|
||||
call_softmax = Call(softmax_pattern, inputs=[x])
|
||||
relu_pattern = Prim(P.ReLU())
|
||||
call_relu = Call(relu_pattern, inputs=[x])
|
||||
|
||||
pattern = IsIn([call_softmax, call_relu])
|
||||
relu6_pattern = IsPrimTypeOf(P.ReLU6(), should_replace=False)
|
||||
target = CallWith(relu6_pattern, inputs=[x])
|
||||
pattern = OneOf([call_softmax, call_relu])
|
||||
relu6_pattern = Prim(P.ReLU6())
|
||||
target = Call(relu6_pattern, inputs=[x])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
unregiste_pass(softmax_relu_pass)
|
||||
|
@ -123,18 +120,17 @@ def test_isin_pattern_1():
|
|||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_neg_pass():
|
||||
x = AnyPattern()
|
||||
softmax_pattern = IsPrimTypeOf(P.Softmax())
|
||||
call_softmax = CallWith(softmax_pattern, inputs=[x])
|
||||
relu_pattern = IsPrimTypeOf(P.ReLU())
|
||||
call_relu = CallWith(relu_pattern, inputs=[x])
|
||||
x = Any()
|
||||
softmax_pattern = Prim(P.Softmax())
|
||||
call_softmax = Call(softmax_pattern, inputs=[x])
|
||||
relu_pattern = Prim(P.ReLU())
|
||||
call_relu = Call(relu_pattern, inputs=[x])
|
||||
|
||||
pattern = IsIn([call_softmax, call_relu])
|
||||
neg_ops = IsPrimTypeOf(P.Neg(), should_replace=False)
|
||||
target = CallWith(neg_ops, inputs=[pattern])
|
||||
pattern = OneOf([call_softmax, call_relu])
|
||||
neg_ops = Prim(P.Neg())
|
||||
target = Call(neg_ops, inputs=[pattern])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
|
||||
print(transformed_repr)
|
||||
unregiste_pass(softmax_neg_pass)
|
||||
assert "Neg" in transformed_repr
|
||||
assert "Softmax" in transformed_repr
|
||||
|
@ -167,11 +163,11 @@ def test_isnot_pattern_0():
|
|||
"""
|
||||
Sub a BN which does NOT take Conv as inputs to ReLU6.
|
||||
"""
|
||||
conv2d_prim = IsPrimTypeOf("Conv2D")
|
||||
conv2d = CallWith(conv2d_prim)
|
||||
pattern_0 = IsNot(conv2d)
|
||||
pattern = CallWith(P.BatchNorm(), inputs=[pattern_0])
|
||||
target = CallWith(P.ReLU6(), inputs=[pattern_0])
|
||||
conv2d_prim = Prim("Conv2D")
|
||||
conv2d = Call(conv2d_prim)
|
||||
pattern_0 = NoneOf(conv2d)
|
||||
pattern = Call(P.BatchNorm(), inputs=[pattern_0])
|
||||
target = Call(P.ReLU6(), inputs=[pattern_0])
|
||||
return pattern, target
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
|
@ -179,10 +175,8 @@ def test_isnot_pattern_0():
|
|||
"""
|
||||
Sub a BN to Softmax.
|
||||
"""
|
||||
bn = P.BatchNorm()
|
||||
pattern = CallWith(bn)
|
||||
softmax = P.Softmax()
|
||||
target = CallWith(softmax, should_replace=False)
|
||||
pattern = Call(P.BatchNorm())
|
||||
target = Call(P.Softmax())
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
|
||||
|
@ -205,12 +199,12 @@ def test_isnot_pattern_1():
|
|||
"""
|
||||
Sub a BN which does NOT take MatMul as inputs to ReLU6.
|
||||
"""
|
||||
matmul = IsPrimTypeOf("MatMul")
|
||||
pattern_0 = IsNot(matmul)
|
||||
matmul = Prim("MatMul")
|
||||
pattern_0 = NoneOf(matmul)
|
||||
softmax = P.Softmax()
|
||||
pattern = CallWith(softmax, inputs=[pattern_0])
|
||||
pattern = Call(softmax, inputs=[pattern_0])
|
||||
relu6 = P.ReLU6()
|
||||
target = CallWith(relu6, inputs=[pattern_0], should_replace=False)
|
||||
target = Call(relu6, inputs=[pattern_0])
|
||||
return pattern, target
|
||||
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
|
@ -228,14 +222,12 @@ def test_newtensor_pattern():
|
|||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_addn_pass():
|
||||
x = AnyPattern()
|
||||
softmax = P.Softmax()
|
||||
pattern = CallWith(softmax, inputs=[x])
|
||||
x = Any()
|
||||
pattern = Call(P.Softmax(), inputs=[x])
|
||||
|
||||
weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
|
||||
new_weight = NewTensor(weight_tensor)
|
||||
addn_ops = P.AddN()
|
||||
target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False)
|
||||
target = Call(P.AddN(), inputs=[x, new_weight])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||
unregiste_pass(softmax_addn_pass)
|
||||
|
@ -252,25 +244,23 @@ def test_newparameter_pattern():
|
|||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_addn_pass():
|
||||
x = AnyPattern()
|
||||
softmax = P.Softmax()
|
||||
pattern = CallWith(softmax, inputs=[x])
|
||||
x = Any()
|
||||
pattern = Call(P.Softmax(), inputs=[x])
|
||||
|
||||
default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
|
||||
default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
|
||||
new_para_0 = NewParameter("Merlin", default_tensor0)
|
||||
new_para_1 = NewParameter("Arthur", default_tensor1)
|
||||
target_0 = CallWith(P.MatMul(), inputs=[new_para_0, new_para_1], should_replace=False)
|
||||
target = CallWith("make_tuple", inputs=[target_0], should_replace=False)
|
||||
target_0 = Call(P.MatMul(), inputs=[new_para_0, new_para_1])
|
||||
target = Call("make_tuple", inputs=[target_0])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
print(transformed_repr)
|
||||
unregiste_pass(softmax_addn_pass)
|
||||
assert "MatMul" in transformed_repr
|
||||
assert "make_tuple" in transformed_repr
|
||||
assert "Softmax" not in transformed_repr
|
||||
|
||||
def test_imm_pattern():
|
||||
def test_imm_target():
|
||||
"""
|
||||
Test NewParameter pattern in the target
|
||||
"""
|
||||
|
@ -278,17 +268,15 @@ def test_imm_pattern():
|
|||
softmax_model = nn.Softmax()
|
||||
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_addn_pass():
|
||||
x = AnyPattern()
|
||||
softmax = P.Softmax()
|
||||
pattern = CallWith(softmax, inputs=[x])
|
||||
def softmax_pass():
|
||||
x = Any()
|
||||
pattern = Call(P.Softmax(), inputs=[x])
|
||||
imm = Imm(0)
|
||||
target_0 = CallWith("make_tuple", inputs=[pattern], should_replace=False)
|
||||
target = CallWith("tuple_getitem", inputs=[target_0, imm], should_replace=False)
|
||||
target_0 = Call("make_tuple", inputs=[pattern])
|
||||
target = Call("tuple_getitem", inputs=[target_0, imm])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
print(transformed_repr)
|
||||
unregiste_pass(softmax_addn_pass)
|
||||
unregiste_pass(softmax_pass)
|
||||
assert "make_tuple" in transformed_repr
|
||||
assert "tuple_getitem" in transformed_repr
|
||||
assert "Softmax" in transformed_repr
|
||||
|
@ -301,21 +289,19 @@ def test_gen_new_parameter():
|
|||
softmax_model = nn.Softmax()
|
||||
|
||||
default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
|
||||
new_para = NewParameter("Merlin", default_tensor, should_replace=True)
|
||||
new_para = NewParameter("Merlin", default_tensor)
|
||||
gen_new_parameter(new_para)
|
||||
@registe_pass(run_only_once=True)
|
||||
def softmax_make_tuple_pass():
|
||||
x = AnyPattern()
|
||||
x = Any()
|
||||
softmax = P.Softmax()
|
||||
pattern = CallWith(softmax, inputs=[x])
|
||||
pattern = Call(softmax, inputs=[x])
|
||||
|
||||
target = CallWith("make_tuple", inputs=[pattern, new_para], should_replace=False)
|
||||
target = Call("make_tuple", inputs=[pattern, new_para])
|
||||
return pattern, target
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
print(transformed_repr)
|
||||
assert "Merlin" in transformed_repr
|
||||
unregiste_pass(softmax_make_tuple_pass)
|
||||
cancel_new_parameter(new_para)
|
||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||
print(transformed_repr)
|
||||
assert "Merlin" not in transformed_repr
|
||||
|
|
Loading…
Reference in New Issue