python pass pattern renaming and interface tweaking

BowenK 2020-08-25 21:34:16 +08:00
parent 3a16925fa2
commit 641d12d6d9
9 changed files with 234 additions and 285 deletions
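In brief, this commit renames the Python pass pattern classes (IsPrimTypeOf → Prim, CallWith → Call, IsIn → OneOf, IsNot → NoneOf, AnyPattern → Any), drops the per-pattern should_replace flag and the pipeline-phase option from the public interface, gives NewParameter an internal last_across_passes_ flag, and gives Imm a real match() so it can appear in source patterns. A minimal before/after sketch of the renamed Python API, adapted from the updated tests at the end of this commit (the imports and the softmax model setup are assumed from the test file):

```python
from mindspore.ops import operations as P
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass
from mindspore.graph_utils.graph_pattern import Any, Call

@registe_pass(run_only_once=True)              # the pipeline_phase argument is gone
def softmax_relu_pass():
    x = Any()                                  # was: AnyPattern()
    pattern = Call(P.Softmax(), inputs=[x])    # was: CallWith(..., should_replace=True)
    target = Call(P.ReLU(), inputs=[x])        # was: CallWith(..., should_replace=False)
    return pattern, target

unregiste_pass(softmax_relu_pass)              # was: unregiste_pass(pass, pipeline_phase=phase.opt)
```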

View File

@@ -21,25 +21,23 @@ namespace opt {
namespace python_pass {
int Pattern::g_id_ = 0;
MatchResultPtr IsPrimTypeOf::match(const AnfNodePtr &node) {
MatchResultPtr Prim::match(const AnfNodePtr &node) {
if (!IsValueNode<Primitive>(node)) {
return nullptr;
}
MatchResultPtr res = std::make_shared<MatchResult>();
if (IsValueNode<Primitive>(node)) {
// iterate over all primitives
for (auto &iter : primitives_) {
if (IsPrimitive(node, iter) || iter->name() == "*") {
matched_prim_ = iter;
res->add_entry(shared_from_base<IsPrimTypeOf>(), node);
return res;
}
// iterate over all primitives
for (auto &iter : primitives_) {
if (IsPrimitive(node, iter) || iter->name() == "*") {
matched_prim_ = iter;
res->add_entry(shared_from_base<Prim>(), node);
return res;
}
}
return nullptr;
}
MatchResultPtr CallWith::match(const AnfNodePtr &node) {
MatchResultPtr Call::match(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node)) {
return nullptr;
}
@@ -71,7 +69,7 @@ MatchResultPtr CallWith::match(const AnfNodePtr &node) {
}
// If inputs is not specified, add node without looking into its inputs
if (p_inputs_size == 0) {
res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
res->add_entry(shared_from_base<Call>(), cnode->input(0));
return res;
}
bool failed = false;
@@ -86,24 +84,24 @@ MatchResultPtr CallWith::match(const AnfNodePtr &node) {
res->merge(input_match_result);
}
if (!failed) {
res->add_entry(shared_from_base<CallWith>(), cnode->input(0));
res->add_entry(shared_from_base<Call>(), cnode->input(0));
return res;
}
return nullptr;
}
MatchResultPtr IsIn::match(const AnfNodePtr &node) {
MatchResultPtr OneOf::match(const AnfNodePtr &node) {
for (auto &iter : patterns_) {
auto res = iter->match(node);
if (res != nullptr) {
res->add_entry(shared_from_base<IsIn>(), node);
res->add_entry(shared_from_base<OneOf>(), node);
return res;
}
}
return nullptr;
}
MatchResultPtr IsNot::match(const AnfNodePtr &node) {
MatchResultPtr NoneOf::match(const AnfNodePtr &node) {
for (auto &iter : patterns_) {
auto res = iter->match(node);
if (res != nullptr) {
@@ -111,16 +109,33 @@ MatchResultPtr IsNot::match(const AnfNodePtr &node) {
}
}
auto res = std::make_shared<MatchResult>();
res->add_entry(shared_from_base<IsNot>(), node);
res->add_entry(shared_from_base<NoneOf>(), node);
return res;
}
MatchResultPtr AnyPattern::match(const AnfNodePtr &node) {
MatchResultPtr Any::match(const AnfNodePtr &node) {
MatchResultPtr res = std::make_shared<MatchResult>();
res->add_entry(shared_from_base<AnyPattern>(), node);
res->add_entry(shared_from_base<Any>(), node);
return res;
}
MatchResultPtr Imm::match(const AnfNodePtr &node) {
if (!IsValueNode<Int32Imm>(node)) {
return nullptr;
}
// Check value
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value_ptr = value_node->value()->cast<Int32ImmPtr>();
MS_EXCEPTION_IF_NULL(value_ptr);
if ((int32_t)value_ptr->value() == value_) {
MatchResultPtr res = std::make_shared<MatchResult>();
res->add_entry(shared_from_base<Imm>(), node);
return res;
}
return nullptr;
}
AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) {
auto entry = match_result_.find(pattern);
if (entry == match_result_.end()) {
@@ -140,20 +155,20 @@ void MatchResult::merge(const MatchResultPtr &other_result) {
REGISTER_PYBIND_DEFINE(
Pattern, ([](const py::module *m) {
(void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
(void)py::class_<IsIn, std::shared_ptr<IsIn>, Pattern>(*m, "IsIn_").def(py::init<vector<PatternPtr>>());
(void)py::class_<IsPrimTypeOf, std::shared_ptr<IsPrimTypeOf>, Pattern>(*m, "IsPrimTypeOf_", py::dynamic_attr())
.def(py::init<vector<PrimitivePyPtr>, string, bool>())
.def(py::init<vector<string>, string, bool>());
(void)py::class_<CallWith, std::shared_ptr<CallWith>, Pattern>(*m, "CallWith_")
.def(py::init<PatternPtr, vector<PatternPtr>, bool>())
.def(py::init<PrimitivePyPtr, vector<PatternPtr>, bool>())
.def(py::init<string, vector<PatternPtr>, bool>());
(void)py::class_<IsNot, std::shared_ptr<IsNot>, Pattern>(*m, "IsNot_").def(py::init<vector<PatternPtr>>());
(void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>());
(void)py::class_<OneOf, std::shared_ptr<OneOf>, Pattern>(*m, "OneOf_").def(py::init<vector<PatternPtr>>());
(void)py::class_<Prim, std::shared_ptr<Prim>, Pattern>(*m, "Prim_", py::dynamic_attr())
.def(py::init<vector<PrimitivePyPtr>, string>())
.def(py::init<vector<string>, string>());
(void)py::class_<Call, std::shared_ptr<Call>, Pattern>(*m, "Call_")
.def(py::init<PatternPtr, vector<PatternPtr>>())
.def(py::init<PrimitivePyPtr, vector<PatternPtr>>())
.def(py::init<string, vector<PatternPtr>>());
(void)py::class_<NoneOf, std::shared_ptr<NoneOf>, Pattern>(*m, "NoneOf_").def(py::init<vector<PatternPtr>>());
(void)py::class_<Any, std::shared_ptr<Any>, Pattern>(*m, "Any").def(py::init<>());
(void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
.def(py::init<tensor::TensorPtr>());
(void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_")
.def(py::init<string, tensor::TensorPtr, bool, bool, bool>());
.def(py::init<string, tensor::TensorPtr, bool, bool>());
(void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int>());
}));
} // namespace python_pass
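The Imm::match added above replaces a stub that always returned nullptr, so an Imm pattern can now match an existing Int32Imm value node in the source graph rather than only being buildable in the target. A hedged sketch of what this enables (a hypothetical source pattern; the tests in this commit still only use Imm in targets):

```python
from mindspore.graph_utils.graph_pattern import Any, Call, Imm

# Hypothetical source pattern: match tuple_getitem(make_tuple(x), 0),
# where Imm(0) now matches the constant-index value node (Int32Imm(0)).
x = Any()
pattern = Call("tuple_getitem", inputs=[Call("make_tuple", inputs=[x]), Imm(0)])
```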

View File

@@ -36,10 +36,10 @@ class MatchResult;
using MatchResultPtr = std::shared_ptr<MatchResult>;
class Pattern;
using PatternPtr = std::shared_ptr<Pattern>;
class IsPrimTypeOf;
using IsPrimTypeOfPtr = std::shared_ptr<IsPrimTypeOf>;
class CallWith;
using CallWithPtr = std::shared_ptr<CallWith>;
class Prim;
using PrimPtr = std::shared_ptr<Prim>;
class Call;
using CallPtr = std::shared_ptr<Call>;
class NewTensor;
using NewTensorPtr = std::shared_ptr<NewTensor>;
class NewParameter;
@@ -58,8 +58,6 @@ class Pattern : public Base {
virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; }
string unique_name() const { return unique_name_; }
vector<PatternPtr> inputs() { return inputs_; }
bool should_replace() { return should_replace_; }
void set_should_replace(bool should_replace) { should_replace_ = should_replace; }
virtual void reset() {}
protected:
@@ -67,7 +65,6 @@ class Pattern : public Base {
// NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern is constructed
string unique_name_;
vector<PatternPtr> inputs_;
bool should_replace_ = true;
};
struct PatternEqual {
@@ -85,70 +82,61 @@ struct PatternHasher {
}
};
class IsPrimTypeOf : public Pattern {
class Prim : public Pattern {
public:
IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); }
~IsPrimTypeOf() = default;
IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace)
: primitives_(prims), name_(name), matched_prim_(nullptr) {
unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
should_replace_ = should_replace;
if (!should_replace) {
matched_prim_ = prims[0];
}
Prim() { unique_name_ = std::to_string(g_id_++); }
~Prim() = default;
Prim(vector<PrimitivePyPtr> prims, string name) : primitives_(prims), name_(name) {
unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
// Default using the first prim to build target
matched_prim_ = primitives_[0];
}
IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) {
unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
Prim(vector<string> types, string name) : types_(types), name_(name) {
unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
// Make primitives_
for (auto &iter : types) {
primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
}
should_replace_ = should_replace;
if (!should_replace) {
matched_prim_ = primitives_[0];
}
// Default using the first prim to build target
matched_prim_ = primitives_[0];
}
MS_DECLARE_PARENT(IsPrimTypeOf, Pattern);
MS_DECLARE_PARENT(Prim, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override;
PrimitivePyPtr matched_primitive() { return matched_prim_; }
void reset() override {
if (should_replace_) {
matched_prim_ = nullptr;
}
// Init before reset
MS_EXCEPTION_IF_NULL(matched_prim_);
matched_prim_ = primitives_[0];
}
private:
vector<string> types_;
vector<PrimitivePyPtr> primitives_;
string name_;
PrimitivePyPtr matched_prim_;
PrimitivePyPtr matched_prim_{nullptr};
};
class CallWith : public Pattern {
class Call : public Pattern {
public:
CallWith() { unique_name_ = std::to_string(g_id_++); }
~CallWith() = default;
CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) {
Call() { unique_name_ = std::to_string(g_id_++); }
~Call() = default;
Call(PatternPtr prim_pattern, vector<PatternPtr> inputs) {
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
prim_pattern_ = prim_pattern;
unique_name_ = std::to_string(g_id_++) + "CallWithPattern_" + prim_pattern->unique_name();
unique_name_ = std::to_string(g_id_++) + "Call_" + prim_pattern->unique_name();
inputs_ = inputs;
// NOTE: should_replace_ is silently overridden by its prim_pattern (if one exists).
should_replace_ = prim_pattern->should_replace();
}
CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) {
Call(PrimitivePyPtr prim, vector<PatternPtr> inputs) {
prim_ = prim;
unique_name_ = std::to_string(g_id_++) + "CallWithPrim_" + prim_->ToString();
unique_name_ = std::to_string(g_id_++) + "Call_" + prim_->ToString();
inputs_ = inputs;
should_replace_ = should_replace;
}
CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) {
Call(string prim_str, vector<PatternPtr> inputs) {
prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
unique_name_ = std::to_string(g_id_++) + "CallWithStr_" + prim_->ToString();
unique_name_ = std::to_string(g_id_++) + "CallStr_" + prim_->ToString();
inputs_ = inputs;
should_replace_ = should_replace;
}
MS_DECLARE_PARENT(CallWith, Pattern);
MS_DECLARE_PARENT(Call, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override;
PrimitivePtr prim_value() { return prim_; }
PatternPtr prim_pattern() { return prim_pattern_; }
@@ -160,45 +148,45 @@ class CallWith : public Pattern {
string name_;
};
class IsIn : public Pattern {
class OneOf : public Pattern {
public:
IsIn() { unique_name_ = std::to_string(g_id_++); }
~IsIn() = default;
explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) {
unique_name_ = std::to_string(g_id_++) + "IsIn";
OneOf() { unique_name_ = std::to_string(g_id_++); }
~OneOf() = default;
explicit OneOf(vector<PatternPtr> patterns) : patterns_(patterns) {
unique_name_ = std::to_string(g_id_++) + "OneOf";
for (auto &iter : patterns) {
unique_name_ = unique_name_ + "_" + iter->unique_name();
}
}
MS_DECLARE_PARENT(IsIn, Pattern);
MS_DECLARE_PARENT(OneOf, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override;
private:
vector<PatternPtr> patterns_;
};
class IsNot : public Pattern {
class NoneOf : public Pattern {
public:
IsNot() { unique_name_ = std::to_string(g_id_++); }
~IsNot() = default;
explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) {
unique_name_ = std::to_string(g_id_++) + "IsNot";
NoneOf() { unique_name_ = std::to_string(g_id_++); }
~NoneOf() = default;
explicit NoneOf(vector<PatternPtr> patterns) : patterns_(patterns) {
unique_name_ = std::to_string(g_id_++) + "NoneOf";
for (auto &iter : patterns) {
unique_name_ = unique_name_ + "_" + iter->unique_name();
}
}
MS_DECLARE_PARENT(IsNot, Pattern);
MS_DECLARE_PARENT(NoneOf, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override;
private:
vector<PatternPtr> patterns_;
};
class AnyPattern : public Pattern {
class Any : public Pattern {
public:
AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; }
~AnyPattern() = default;
MS_DECLARE_PARENT(AnyPattern, Pattern);
Any() { unique_name_ = std::to_string(g_id_++) + "_Any"; }
~Any() = default;
MS_DECLARE_PARENT(Any, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override;
};
@@ -207,7 +195,6 @@ class NewTensor : public Pattern {
NewTensor() { unique_name_ = std::to_string(g_id_++); }
~NewTensor() = default;
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) {
should_replace_ = false;
unique_name_ = std::to_string(g_id_++) + "NewTensor";
}
MS_DECLARE_PARENT(NewTensor, Pattern);
@@ -223,10 +210,8 @@ class NewTensor : public Pattern {
class NewParameter : public Pattern {
public:
NewParameter() { unique_name_ = std::to_string(g_id_++); }
explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel,
bool should_replace)
explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel)
: para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
should_replace_ = should_replace;
unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
// clone input tensor
default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
@@ -243,11 +228,14 @@ class NewParameter : public Pattern {
bool built() { return built_; }
void set_built(bool built) { built_ = built; }
void reset() override { built_ = false; }
bool should_last() { return last_across_passes_; }
void set_last(bool last) { last_across_passes_ = last; }
private:
string para_name_;
bool requires_grad_;
bool layerwise_parallel_;
bool last_across_passes_{false};
bool built_;
tensor::TensorPtr default_tensor_;
};
@@ -255,13 +243,9 @@ class NewParameter : public Pattern {
class Imm : public Pattern {
public:
Imm() { unique_name_ = std::to_string(g_id_++); }
explicit Imm(int value) : value_(value) {
should_replace_ = false;
unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value);
}
explicit Imm(int value) : value_(value) { unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value); }
MS_DECLARE_PARENT(Imm, Pattern);
// NOTE: Doesn't support Imm in src pattern currently.
MatchResultPtr match(const AnfNodePtr &node) override { return nullptr; }
MatchResultPtr match(const AnfNodePtr &node) override;
int value() { return value_; }
private:

View File

@@ -80,7 +80,7 @@ bool IsTraversable(const AnfNodePtr &node) {
AnfNodePtr BuildPrimitive(const PatternPtr &pattern, const MatchResultPtr &res) {
// Build up AnfNode from primitive
auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>();
auto prim_pattern = pattern->cast<PrimPtr>();
MS_EXCEPTION_IF_NULL(prim_pattern);
PrimitivePyPtr prim = prim_pattern->matched_primitive();
MS_EXCEPTION_IF_NULL(prim);
@@ -98,13 +98,13 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
}
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) {
auto call_with_pattern = pattern->cast<CallWithPtr>();
MS_EXCEPTION_IF_NULL(call_with_pattern);
auto prim = call_with_pattern->prim_value();
auto call_pattern = pattern->cast<CallPtr>();
MS_EXCEPTION_IF_NULL(call_pattern);
auto prim = call_pattern->prim_value();
if (prim != nullptr) {
return std::make_shared<ValueNode>(prim);
}
auto prim_pattern = call_with_pattern->prim_pattern();
auto prim_pattern = call_pattern->prim_pattern();
MS_EXCEPTION_IF_NULL(prim_pattern);
return ProcessSinglePattern(prim_pattern, res, fg);
}
@@ -152,45 +152,35 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
}
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
if (pattern->should_replace()) {
// Find replacement in the MatchResult
auto target_node = res->get_node(pattern);
if (target_node == nullptr) {
// If it's base pattern(in contrast to complex pattern like CallWith/IsIn/IsNot), raise runtime exception.
if (pattern->isa<IsPrimTypeOf>() || pattern->isa<NewTensor>() || pattern->isa<NewParameter>()) {
MS_LOG(EXCEPTION) << "Cannot find target node, pattern: " + pattern->unique_name() + "\n";
return nullptr;
}
// Try to build this pattern and add to MatchResult, since this pattern is defined inside target
auto new_node = BuildTarget(pattern, func_graph, res);
if (new_node == nullptr) {
MS_LOG(EXCEPTION) << "Try to build pattern node but FAILED. pattern: " + pattern->unique_name() + "\n";
}
return new_node;
}
if (pattern->isa<NewParameter>()) {
auto target_node = res->get_node(pattern);
if (target_node != nullptr) {
// If pattern is NewParameter, check whether it shouldn't last and is not built
auto new_para = pattern->cast<NewParameterPtr>();
if (new_para == nullptr || new_para->should_last() || new_para->built()) {
return target_node;
}
return target_node;
}
// Build up new node from pattern
if (pattern->isa<IsPrimTypeOf>()) {
if (pattern->isa<Prim>()) {
return BuildPrimitive(pattern, res);
} else if (pattern->isa<NewTensor>()) {
return BuildNewTensor(pattern, res);
} else if (pattern->isa<CallWith>()) {
} else if (pattern->isa<Call>()) {
return BuildPrimitiveValueNode(pattern, res, func_graph);
} else if (pattern->isa<NewParameter>()) {
return BuildNewParameter(pattern, res, func_graph);
} else if (pattern->isa<Imm>()) {
return BuildImmNode(pattern, res);
} else {
MS_LOG(EXCEPTION) << "Cannot find or build target node, pattern: " + pattern->unique_name() + "\n";
return nullptr;
}
return nullptr;
}
AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
const FuncGraphPtr &func_graph) {
if (pattern->isa<CallWith>()) {
if (pattern->isa<Call>()) {
return BuildPrimitiveValueNode(pattern, res, func_graph);
}
return nullptr;
@@ -269,16 +259,16 @@ void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor
}
void Reset(PatternPtr pattern) {
if (pattern->isa<IsPrimTypeOf>()) {
auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>();
if (pattern->isa<Prim>()) {
auto prim_pattern = pattern->cast<PrimPtr>();
prim_pattern->reset();
return;
} else if (pattern->isa<NewParameter>()) {
auto new_param_pattern = pattern->cast<NewParameterPtr>();
new_param_pattern->reset();
return;
} else if (pattern->isa<CallWith>()) {
auto call_with_pattern = pattern->cast<CallWithPtr>();
} else if (pattern->isa<Call>()) {
auto call_with_pattern = pattern->cast<CallPtr>();
for (auto sub_pattern : call_with_pattern->inputs()) {
Reset(sub_pattern);
}

View File

@@ -49,8 +49,9 @@ PyPassManager::PyPassManager() {
}
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
Phase phase, bool run_only_once) {
auto cur_pg = GetPassGroup(phase);
bool run_only_once) {
// NOTE: remove phase option to avoid unnecessary confusion.
auto cur_pg = GetPassGroup(Phase::OPT);
MS_EXCEPTION_IF_NULL(cur_pg);
cur_pg->SetRunOnlyOnce(run_only_once);
MS_EXCEPTION_IF_NULL(pattern);
@@ -60,8 +61,9 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
cur_pg->AddPass(new_pass);
}
void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
auto cur_pm = GetPassGroup(phase);
void PyPassManager::Unregiste(const std::string &pass_name) {
// NOTE: remove phase option to avoid unnecessary confusion.
auto cur_pm = GetPassGroup(Phase::OPT);
MS_EXCEPTION_IF_NULL(cur_pm);
if (!cur_pm->DeletePass(pass_name)) {
MS_LOG(WARNING) << "No such pass : " + pass_name + "\n";
@@ -70,7 +72,6 @@ void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
void PyPassManager::GenNewParameter(const PatternPtr &parameter) {
MS_EXCEPTION_IF_NULL(parameter);
// Add new parameter after resolve
// NOTE: Adding NewParameter at an early stage will cause CSE problems
auto cur_pg = GetPassGroup(Phase::OPT);
MS_EXCEPTION_IF_NULL(cur_pg);
@@ -78,7 +79,7 @@ void PyPassManager::GenNewParameter(const PatternPtr &parameter) {
auto new_para_pattern = parameter->cast<NewParameterPtr>();
MS_EXCEPTION_IF_NULL(new_para_pattern);
auto pass_name = new_para_pattern->para_name();
parameter->set_should_replace(false);
new_para_pattern->set_last(true);
auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true);
cur_pg->AddPass(new_pass);
}
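With should_replace gone, GenNewParameter now marks the pattern with set_last(true) so the parameter is built once and kept across passes. The Python-facing flow, sketched from test_gen_new_parameter at the end of this commit (imports assumed from the test file):

```python
import numpy as np

import mindspore
from mindspore.common.tensor import Tensor
from mindspore.graph_utils.graph_pattern import NewParameter
from mindspore.graph_utils.python_pass import gen_new_parameter, cancel_new_parameter

default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
new_para = NewParameter("Merlin", default_tensor)  # should_replace argument removed
gen_new_parameter(new_para)     # internally calls set_last(true) on the pattern
# ... register passes, compile ...
cancel_new_parameter(new_para)  # undoes gen_new_parameter
```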

View File

@@ -53,16 +53,17 @@ class PyPassManager {
static PyPassManagerPtr GetInstance();
virtual ~PyPassManager() = default;
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
Phase phase = Phase::RESOLVE, bool run_only_once = false);
void Unregiste(const std::string &pass_name, Phase phase);
bool run_only_once = false);
void Unregiste(const std::string &pass_name);
void GenNewParameter(const PatternPtr &parameter);
PassGroupPtr GetPassGroup(Phase phase);
void ClearRes();
MatchResultPtr GetMatchResult() { return res_; }
void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
bool ShouldRenorm() { return should_renorm_; }
void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
pipeline::ResourcePtr GetResource() { return resource_; }
void ClearRes();
void ClearPipelineRes() { resource_ = nullptr; }
private:
bool should_renorm_ = true;

View File

@@ -477,6 +477,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
// save the run graph func to MsPipeLine
SaveCompiledGraph(phase_s);
opt::python_pass::PyPassManager::GetInstance()->ClearPipelineRes();
resource->Clean();
// Reclaim all resource used by optimizer;
ReclaimOptimizer();

View File

@@ -15,50 +15,43 @@
"""Patterns for describing graphs"""
from mindspore.ops import Primitive
from mindspore.common.tensor import Tensor
from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_,\
NewParameter_, Imm
from mindspore._c_expression import Pattern, OneOf_, Prim_, Call_, NoneOf_, Any, NewTensor_, NewParameter_, Imm
__all__ = [
"IsIn",
"IsPrimTypeOf",
"CallWith",
"IsNot",
"AnyPattern",
"OneOf",
"Prim",
"Call",
"NoneOf",
"Any",
"NewTensor",
"NewParameter",
"Imm"
]
class IsIn(IsIn_):
class OneOf(OneOf_):
r"""
Express a pattern which matches if any one of a list of patterns matches.
"""
def __init__(self, patterns=None, should_replace=True):
def __init__(self, patterns=None):
r"""
Args:
patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`],
patterns(Union[:class:`mindspore.graph_utils.graph_pattern`,
tuple[:class:`mindspore.graph_utils.graph_pattern`],
list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns,
each element should be one of the exposed Pattern instances.
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
Raises:
ValueError: raise if should_replace is False
TypeError: raise type error for invalid inputs.
"""
if not should_replace:
raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \
its sub-pattern instead.")
self.patterns = patterns
if patterns is None:
IsIn_.__init__(self, ())
elif isinstance(patterns, Pattern):
IsIn_.__init__(self, [patterns])
if isinstance(patterns, Pattern):
OneOf_.__init__(self, [patterns])
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
IsIn_.__init__(self, patterns)
OneOf_.__init__(self, patterns)
else:
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
class IsPrimTypeOf(IsPrimTypeOf_):
class Prim(Prim_):
r"""
Express a pattern of certain primitive type(s).
@@ -66,7 +59,7 @@ class IsPrimTypeOf(IsPrimTypeOf_):
This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
please refer to the Call pattern.
"""
def __init__(self, types, name=None, should_replace=True):
def __init__(self, types, name=None):
r"""
Args:
types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
@@ -77,9 +70,6 @@ class IsPrimTypeOf(IsPrimTypeOf_):
2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
name (str): name of the pattern, optional. Default: None.
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
Default: True.
Raises:
TypeError: raise type error for invalid argument.
@@ -103,13 +93,13 @@ class IsPrimTypeOf(IsPrimTypeOf_):
self.types = types
else:
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
IsPrimTypeOf_.__init__(self, self.types, self.name, should_replace)
Prim_.__init__(self, self.types, self.name)
class CallWith(CallWith_):
class Call(Call_):
r"""
Express a primitive CNode.
"""
def __init__(self, prim_pattern, inputs=None, should_replace=True):
def __init__(self, prim_pattern, inputs=None):
r"""
Args:
prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.Prim`,
@@ -118,9 +108,6 @@ class CallWith(CallWith_):
tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input
patterns should be in the right order and each element should be one of the exposed Pattern instances.
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
Default: True.
Raises:
TypeError: raise type error for invalid argument.
@@ -135,36 +122,31 @@ class CallWith(CallWith_):
self.inputs = inputs
else:
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace)
Call_.__init__(self, self.prim_pattern, self.inputs)
class IsNot(IsNot_):
class NoneOf(NoneOf_):
r"""
Express a pattern which matches only if none of the given patterns match.
NOTE:
IsNot pattern should not be the root pattern.
NoneOf pattern should not be the root pattern.
"""
def __init__(self, patterns=None, should_replace=True):
def __init__(self, patterns=None):
r"""
Args:
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]): list of forbidden patterns, each element
should be one of the exposed Pattern instances.
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
Raises:
ValueError: raise if should_replace is False.
TypeError: raise type error for invalid argument.
"""
if not should_replace:
raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \
its sub-pattern instead.")
self.patterns = patterns
if patterns is None:
IsNot_.__init__(self, ())
NoneOf_.__init__(self, ())
elif isinstance(patterns, Pattern):
IsNot_.__init__(self, [patterns])
NoneOf_.__init__(self, [patterns])
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
IsNot_.__init__(self, patterns)
NoneOf_.__init__(self, patterns)
else:
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
@@ -172,18 +154,14 @@ class NewTensor(NewTensor_):
r"""
New Tensor to be used in the target.
"""
def __init__(self, input_tensor, should_replace=False):
def __init__(self, input_tensor):
r"""
Args:
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
should_replace(bool): added this for interface consistency. NewTensor should only appear in the target.
Raises:
ValueError: raise if should_replace is True
TypeError: raise type error for invalid argument.
"""
if should_replace:
raise ValueError("NewTensor should only appear in the target, thus should_replace can only be False.")
self.input_tensor = input_tensor
if isinstance(input_tensor, Tensor):
NewTensor_.__init__(self, input_tensor)
@@ -194,15 +172,13 @@ class NewParameter(NewParameter_):
r"""
New Parameter to be used in the target.
"""
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False, should_replace=False):
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False):
r"""
Args:
para_name(str): name for the new Parameter
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
requires_grad(bool): True if the parameter requires gradient. Default: False
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
should_replace(bool): gen new parameter once and replace after if set to be true; otherwise build a new
parameter everytime a pass target got built. Default: False
Raises:
TypeError: raise type error for invalid argument.
@@ -211,12 +187,11 @@ class NewParameter(NewParameter_):
self.default_tensor = default_tensor
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
self.should_replace = should_replace
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
isinstance(layerwise_parallel, bool) and isinstance(should_replace, bool):
isinstance(layerwise_parallel, bool):
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
self.layerwise_parallel, self.should_replace)
self.layerwise_parallel)
else:
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
layerwise_parallel(bool) should_replace(bool) got : {para_name}, {default_tensor}, \
{requires_grad}, {layerwise_parallel}, {should_replace}")
layerwise_parallel(bool) got : {para_name}, {default_tensor}, \
{requires_grad}, {layerwise_parallel}")

View File

@@ -15,7 +15,7 @@
"""Python pass register"""
from inspect import isfunction
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
from mindspore._c_expression import PyPassManager_, phase
from mindspore._c_expression import PyPassManager_
__all__ = [
@@ -30,21 +30,16 @@ class PyPassManager(PyPassManager_):
Used to register and unregister python passes which can be used to alter graphs.
Args:
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
multigraph (bool): Whether or not the pattern exists across graphs. Default: True.
Raises:
TypeError: If argument has invalid type.
"""
def __init__(self, pipeline_phase=phase.opt, run_only_once=False):
if not isinstance(pipeline_phase, phase):
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
def __init__(self, run_only_once=False):
if not isinstance(run_only_once, bool):
raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
PyPassManager_.__init__(self)
self.phase_ = pipeline_phase
self.run_only_once_ = run_only_once
PyPassManager_.__init__(self)
def registe(self, py_pass):
if not isfunction(py_pass):
@@ -55,16 +50,14 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern):
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_)
super().registe(pass_name, pattern, target, self.run_only_once_)
def unregiste(self, py_pass, pipeline_phase=phase.opt):
if not isinstance(pipeline_phase, phase):
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
def unregiste(self, py_pass):
if isinstance(py_pass, str):
super().unregiste(py_pass, pipeline_phase)
super().unregiste(py_pass)
return
if isfunction(py_pass):
super().unregiste(py_pass.__name__, pipeline_phase)
super().unregiste(py_pass.__name__)
return
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
@@ -82,13 +75,11 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
super().set_renorm(should_renorm)
def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
def registe_pass(run_only_once=False):
"""
Register a python pass which will be used in compilation.
Args:
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
registed. Support phase.resolve and phase.opt. Default: phase.opt.
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until convergence. Default: False.
Returns:
@@ -102,19 +93,17 @@ def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
>>> target = Prim("ReLU6")
>>> return pattern, target
"""
return PyPassManager(pipeline_phase, run_only_once)
return PyPassManager(run_only_once)
def unregiste_pass(py_pass, pipeline_phase=phase.opt):
def unregiste_pass(py_pass):
"""
Unregister a python pass.
Args:
py_pass(Union(str, function)): target python pass to unregister.
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
unregisted. Support phase.resolve and phase.opt. Default: phase.opt.
"""
ppm = PyPassManager()
ppm.unregiste(py_pass, pipeline_phase)
ppm.unregiste(py_pass)
def gen_new_parameter(pattern):
"""
@@ -164,7 +153,14 @@ def cancel_new_parameter(pattern):
def set_renorm(should_renorm):
"""
Set whether or not to do renorm after modified graph in python pass(es).
Set whether or not to do renormalization after the graph is modified in python pass(es).
Args:
should_renorm(bool): whether or not to do renormalization after the graph is modified in python pass(es).
NOTE:
This interface is mainly intended for testing graph modification without worrying about its validity. Turning off
renormalization may BREAK the network.
"""
ppm = PyPassManager()
ppm.set_renorm(should_renorm)
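As the expanded docstring warns, renormalization validates the modified graph, so disabling it is only meant for experiments. A usage sketch:

```python
from mindspore.graph_utils.python_pass import set_renorm

# Skip renormalization after pass-driven graph edits. Useful when testing
# the pattern machinery, but turning it off may BREAK the network.
set_renorm(False)
```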

View File

@@ -23,8 +23,7 @@ from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_
cancel_new_parameter
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_
from mindspore.graph_utils.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor,\
NewParameter, Imm
from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
context.set_context(mode=context.GRAPH_MODE)
@@ -50,11 +49,9 @@ def test_softmax_relu():
@registe_pass(run_only_once=True)
def softmax_relu_pass():
x = AnyPattern()
softmax_pattern = IsPrimTypeOf(P.Softmax())
pattern = CallWith(softmax_pattern, inputs=[x])
relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
target = CallWith(relu_pattern, inputs=[x])
x = Any()
pattern = Call(P.Softmax(), inputs=[x])
target = Call(P.ReLU(), inputs=[x])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
@@ -74,13 +71,13 @@ def test_softmax_relu_sigmoid():
@registe_pass(run_only_once=True)
def softmax_relu_pass():
x = AnyPattern()
softmax_pattern = IsPrimTypeOf(P.Softmax())
pattern = CallWith(softmax_pattern, inputs=[x])
sigmoid_pattern = IsPrimTypeOf(P.Sigmoid(), should_replace=False)
call_sigmoid = CallWith(sigmoid_pattern, [x])
relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
target = CallWith(relu_pattern, inputs=[call_sigmoid])
x = Any()
softmax_pattern = Prim(P.Softmax())
pattern = Call(softmax_pattern, inputs=[x])
sigmoid_pattern = Prim(P.Sigmoid())
call_sigmoid = Call(sigmoid_pattern, [x])
relu_pattern = Prim(P.ReLU())
target = Call(relu_pattern, inputs=[call_sigmoid])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
@@ -99,15 +96,15 @@ def test_isin_pattern_0():
@registe_pass(run_only_once=True)
def softmax_relu_pass():
x = AnyPattern()
softmax_pattern = IsPrimTypeOf(P.Softmax())
call_softmax = CallWith(softmax_pattern, inputs=[x])
relu_pattern = IsPrimTypeOf(P.ReLU())
call_relu = CallWith(relu_pattern, inputs=[x])
x = Any()
softmax_pattern = Prim(P.Softmax())
call_softmax = Call(softmax_pattern, inputs=[x])
relu_pattern = Prim(P.ReLU())
call_relu = Call(relu_pattern, inputs=[x])
pattern = IsIn([call_softmax, call_relu])
relu6_pattern = IsPrimTypeOf(P.ReLU6(), should_replace=False)
target = CallWith(relu6_pattern, inputs=[x])
pattern = OneOf([call_softmax, call_relu])
relu6_pattern = Prim(P.ReLU6())
target = Call(relu6_pattern, inputs=[x])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
unregiste_pass(softmax_relu_pass)
@@ -123,18 +120,17 @@ def test_isin_pattern_1():
@registe_pass(run_only_once=True)
def softmax_neg_pass():
x = AnyPattern()
softmax_pattern = IsPrimTypeOf(P.Softmax())
call_softmax = CallWith(softmax_pattern, inputs=[x])
relu_pattern = IsPrimTypeOf(P.ReLU())
call_relu = CallWith(relu_pattern, inputs=[x])
x = Any()
softmax_pattern = Prim(P.Softmax())
call_softmax = Call(softmax_pattern, inputs=[x])
relu_pattern = Prim(P.ReLU())
call_relu = Call(relu_pattern, inputs=[x])
pattern = IsIn([call_softmax, call_relu])
neg_ops = IsPrimTypeOf(P.Neg(), should_replace=False)
target = CallWith(neg_ops, inputs=[pattern])
pattern = OneOf([call_softmax, call_relu])
neg_ops = Prim(P.Neg())
target = Call(neg_ops, inputs=[pattern])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
print(transformed_repr)
unregiste_pass(softmax_neg_pass)
assert "Neg" in transformed_repr
assert "Softmax" in transformed_repr
@@ -167,11 +163,11 @@ def test_isnot_pattern_0():
"""
Sub a BN which does NOT take Conv as inputs to ReLU6.
"""
conv2d_prim = IsPrimTypeOf("Conv2D")
conv2d = CallWith(conv2d_prim)
pattern_0 = IsNot(conv2d)
pattern = CallWith(P.BatchNorm(), inputs=[pattern_0])
target = CallWith(P.ReLU6(), inputs=[pattern_0])
conv2d_prim = Prim("Conv2D")
conv2d = Call(conv2d_prim)
pattern_0 = NoneOf(conv2d)
pattern = Call(P.BatchNorm(), inputs=[pattern_0])
target = Call(P.ReLU6(), inputs=[pattern_0])
return pattern, target
@registe_pass(run_only_once=True)
@@ -179,10 +175,8 @@ def test_isnot_pattern_0():
"""
Sub a BN to Softmax.
"""
bn = P.BatchNorm()
pattern = CallWith(bn)
softmax = P.Softmax()
target = CallWith(softmax, should_replace=False)
pattern = Call(P.BatchNorm())
target = Call(P.Softmax())
return pattern, target
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
@@ -205,12 +199,12 @@ def test_isnot_pattern_1():
"""
Sub a BN which does NOT take MatMul as inputs to ReLU6.
"""
matmul = IsPrimTypeOf("MatMul")
pattern_0 = IsNot(matmul)
matmul = Prim("MatMul")
pattern_0 = NoneOf(matmul)
softmax = P.Softmax()
pattern = CallWith(softmax, inputs=[pattern_0])
pattern = Call(softmax, inputs=[pattern_0])
relu6 = P.ReLU6()
target = CallWith(relu6, inputs=[pattern_0], should_replace=False)
target = Call(relu6, inputs=[pattern_0])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
@@ -228,14 +222,12 @@ def test_newtensor_pattern():
@registe_pass(run_only_once=True)
def softmax_addn_pass():
x = AnyPattern()
softmax = P.Softmax()
pattern = CallWith(softmax, inputs=[x])
x = Any()
pattern = Call(P.Softmax(), inputs=[x])
weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
new_weight = NewTensor(weight_tensor)
addn_ops = P.AddN()
target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False)
target = Call(P.AddN(), inputs=[x, new_weight])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
unregiste_pass(softmax_addn_pass)
@@ -252,25 +244,23 @@ def test_newparameter_pattern():
@registe_pass(run_only_once=True)
def softmax_addn_pass():
x = AnyPattern()
softmax = P.Softmax()
pattern = CallWith(softmax, inputs=[x])
x = Any()
pattern = Call(P.Softmax(), inputs=[x])
default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
new_para_0 = NewParameter("Merlin", default_tensor0)
new_para_1 = NewParameter("Arthur", default_tensor1)
target_0 = CallWith(P.MatMul(), inputs=[new_para_0, new_para_1], should_replace=False)
target = CallWith("make_tuple", inputs=[target_0], should_replace=False)
target_0 = Call(P.MatMul(), inputs=[new_para_0, new_para_1])
target = Call("make_tuple", inputs=[target_0])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
print(transformed_repr)
unregiste_pass(softmax_addn_pass)
assert "MatMul" in transformed_repr
assert "make_tuple" in transformed_repr
assert "Softmax" not in transformed_repr
def test_imm_pattern():
def test_imm_target():
"""
Test Imm pattern in the target
"""
@@ -278,17 +268,15 @@ def test_imm_pattern():
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
def softmax_addn_pass():
x = AnyPattern()
softmax = P.Softmax()
pattern = CallWith(softmax, inputs=[x])
def softmax_pass():
x = Any()
pattern = Call(P.Softmax(), inputs=[x])
imm = Imm(0)
target_0 = CallWith("make_tuple", inputs=[pattern], should_replace=False)
target = CallWith("tuple_getitem", inputs=[target_0, imm], should_replace=False)
target_0 = Call("make_tuple", inputs=[pattern])
target = Call("tuple_getitem", inputs=[target_0, imm])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
print(transformed_repr)
unregiste_pass(softmax_addn_pass)
unregiste_pass(softmax_pass)
assert "make_tuple" in transformed_repr
assert "tuple_getitem" in transformed_repr
assert "Softmax" in transformed_repr
@@ -301,21 +289,19 @@ def test_gen_new_parameter():
softmax_model = nn.Softmax()
default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
new_para = NewParameter("Merlin", default_tensor, should_replace=True)
new_para = NewParameter("Merlin", default_tensor)
gen_new_parameter(new_para)
@registe_pass(run_only_once=True)
def softmax_make_tuple_pass():
x = AnyPattern()
x = Any()
softmax = P.Softmax()
pattern = CallWith(softmax, inputs=[x])
pattern = Call(softmax, inputs=[x])
target = CallWith("make_tuple", inputs=[pattern, new_para], should_replace=False)
target = Call("make_tuple", inputs=[pattern, new_para])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
print(transformed_repr)
assert "Merlin" in transformed_repr
unregiste_pass(softmax_make_tuple_pass)
cancel_new_parameter(new_para)
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
print(transformed_repr)
assert "Merlin" not in transformed_repr