forked from OSSInnovation/mindspore

!5586 Boost python pass compile and train performance
Merge pull request !5586 from BowenK/performance

commit 1a4d3e351e
@@ -59,6 +59,7 @@ class Pattern : public Base {
   string unique_name() const { return unique_name_; }
   vector<PatternPtr> inputs() { return inputs_; }
   virtual void reset() {}
+  static void reset_gid() { g_id_ = 0; }
 
  protected:
   static int g_id_;
@@ -213,7 +214,6 @@ class NewParameter : public Pattern {
   explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel)
       : para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
     unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
-    // clone input tensor
     default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
     built_ = false;
   }
@@ -257,7 +257,7 @@ class MatchResult {
   MatchResult() {}
   ~MatchResult() = default;
   void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; }
-  PatternNodeMap _result() { return match_result_; }
+  PatternNodeMap &_result() { return match_result_; }
   AnfNodePtr get_node(const PatternPtr &pattern);
   void merge(const MatchResultPtr &other_result);
   void clear() { match_result_.clear(); }
@@ -27,8 +27,6 @@
 #include "pipeline/jit/resource.h"
 #include "frontend/optimizer/py_pass_manager.h"
 #include "utils/info.h"
-#include "debug/anf_ir_dump.h"
-#include "debug/draw.h"
 
 namespace mindspore {
 namespace opt {
@@ -42,29 +40,6 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
 void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor::TensorPtr default_input,
                               bool requires_grad, bool layerwise_parallel);
 
-std::string GetNodeRepr(AnfNodePtr node) {
-  if (node != nullptr) {
-    if (node->isa<CNode>()) {
-      std::string repr = "(";
-      auto const &inputs = node->cast<CNodePtr>()->inputs();
-      for (auto &input : inputs) {
-        repr += " ";
-        repr += GetNodeRepr(input);
-        repr += " ";
-      }
-      repr += ")";
-      return repr;
-    }
-    if (node->isa<Parameter>()) {
-      return "[Parameter]" + node->ToString();
-    } else if (node->isa<ValueNode>()) {
-      return "[Value]" + GetValueNode(node)->ToString();
-    }
-    return node->ToString();
-  }
-  return "";
-}
-
 bool IsTraversable(const AnfNodePtr &node) {
   if (node == nullptr) {
     return false;
@@ -215,23 +190,6 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
   return new_node;
 }
 
-void DrawNode(string name, AnfNodePtr node) {
-  auto context_ptr = MsContext::GetInstance();
-  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
-  auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
-  if (save_graphs_path.empty()) {
-    save_graphs_path = ".";
-  }
-  auto new_func_graph = std::make_shared<FuncGraph>();
-  new_func_graph->set_output(node, true);
-  if (save_graphs) {
-    auto ir_dump_path = save_graphs_path + "/" + name + ".ir";
-    auto dot_dump_path = save_graphs_path + "/" + name + ".dot";
-    DumpIR(ir_dump_path, new_func_graph);
-    draw::Draw(dot_dump_path, new_func_graph);
-  }
-}
-
 void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor::TensorPtr default_input,
                               bool requires_grad, bool layerwise_parallel) {
   // 1. Get current cell object
@@ -241,12 +199,15 @@ void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor
   if (py::isinstance<py::none>(top_cell)) {
     MS_LOG(EXCEPTION) << "Failed to get top cell from resource.";
   }
-  // 2. New a Parameter object with the above-specified args
+  // 2. Clone default_input tensor
+  auto default_tensor = std::make_shared<tensor::Tensor>(default_input->data_type(), default_input->shape_c(),
+                                                         default_input->data_c(), (size_t)default_input->Size());
+  // 3. New a Parameter object with the above-specified args
   py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS);
-  py::object new_parameter = parameter_class(default_input, param_name, requires_grad, layerwise_parallel);
-  // 3. Add the new python Parameter object to Cell's _params atttributes
+  py::object new_parameter = parameter_class(default_tensor, param_name, requires_grad, layerwise_parallel);
+  // 4. Add the new python Parameter object to Cell's _params atttributes
   top_cell.attr(SET_PARAM)(param_name, new_parameter);
-  // 4. Set default_param for param_node
+  // 5. Set default_param for param_node
   ValuePtr param_value = nullptr;
   bool converted = parse::ConvertData(new_parameter, &param_value, false);
   if (!converted) {
@@ -282,11 +243,9 @@ void Reset(PatternPtr pattern) {
 AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) {
   auto match_res = src_pattern_->match(node);
   if (match_res != nullptr) {
-    MS_LOG(DEBUG) << "Matched pattern: " + src_pattern_->unique_name() + " node : " + internal::GetNodeRepr(node);
     res->merge(match_res);
     auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res);
     internal::Reset(dst_pattern());
-    MS_LOG(WARNING) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n";
     return new_node;
   }
   internal::Reset(src_pattern());
@@ -303,7 +262,6 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
       MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null.";
     }
     auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name();
-    MS_LOG(DEBUG) << "Adding New parameter : " + para_name;
    auto para_node = std::make_shared<Parameter>(func_graph);
     MS_EXCEPTION_IF_NULL(para_node);
     para_node->set_name(para_name);
@@ -321,7 +279,7 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
     // Reflect back to Cell._params
     internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
                                        new_para_pattern->layerwise_parallel());
-    MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
+    MS_LOG(WARNING) << "[Gen]Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
     return true;
   }
   FuncGraphManagerPtr manager = func_graph->manager();
@@ -334,7 +292,6 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
   for (auto &node : graph_nodes_sorted) {
     AnfNodePtr new_node = Run(func_graph, node, res);
     if (new_node != nullptr && new_node != node) {
-      internal::DrawNode(dst_pattern_->unique_name(), new_node);
       (void)manager->Replace(node, new_node);
       changes = true;
     }
@@ -98,7 +98,8 @@ REGISTER_PYBIND_DEFINE(
       .def("registe", &PyPassManager::Registe, "Registe python pass")
       .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass")
       .def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
-      .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph");
+      .def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph")
+      .def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph");
   }));
 }  // namespace python_pass
 }  // namespace opt
@@ -60,13 +60,19 @@ class PyPassManager {
   MatchResultPtr GetMatchResult() { return res_; }
   void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
   bool ShouldRenorm() { return should_renorm_; }
+  void SetReOpt(bool should_reopt) { should_reopt_ = should_reopt; }
+  bool ShouldReOpt() { return should_reopt_; }
   void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
   pipeline::ResourcePtr GetResource() { return resource_; }
   void ClearRes();
-  void ClearPipelineRes() { resource_ = nullptr; }
+  void ClearPipelineRes() {
+    resource_ = nullptr;
+    Pattern::reset_gid();
+  }
 
  private:
   bool should_renorm_ = true;
+  bool should_reopt_ = true;
   MatchResultPtr res_;
   pipeline::ResourcePtr resource_;
   static std::unordered_map<Phase, PassGroupPtr> phase_to_group_;
@@ -451,15 +451,19 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
 
 bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }
 
-void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
+bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
   MS_EXCEPTION_IF_NULL(res->manager());
   MS_EXCEPTION_IF_NULL(res->func_graph());
   auto ppm = opt::python_pass::PyPassManager::GetInstance();
   ppm->SetResource(res);
-  if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) {
-    MS_LOG(DEBUG) << "No match.\n";
-  } else if (phase == opt::python_pass::Phase::OPT && opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
-    MS_LOG(DEBUG) << "Entered PyStub Renorm";
+  return ppm->GetPassGroup(phase)->Run(res->func_graph());
+}
+
+bool ResolveActionPyStub(const ResourcePtr &res) { return true || ActionPyStub(res, opt::python_pass::Phase::RESOLVE); }
+
+bool OptActionVmPyStub(const ResourcePtr &res) {
+  if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
+    if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
       // Renomalize
       MS_EXCEPTION_IF_NULL(res->func_graph());
       FuncGraphPtr func_graph = res->func_graph();
@@ -471,15 +475,31 @@ void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
       res->set_func_graph(new_fg);
       res->set_args_spec(args_spec);
     }
-}
-
-bool ResolveActionPyStub(const ResourcePtr &res) {
-  ActionPyStub(res, opt::python_pass::Phase::RESOLVE);
+    if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) {
+      return VmOptimizeAction(res);
+    }
+  }
   return true;
 }
 
-bool OptActionPyStub(const ResourcePtr &res) {
-  ActionPyStub(res, opt::python_pass::Phase::OPT);
+bool OptActionGePyStub(const ResourcePtr &res) {
+  if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
+    if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
+      // Renomalize
+      MS_EXCEPTION_IF_NULL(res->func_graph());
+      FuncGraphPtr func_graph = res->func_graph();
+      abstract::AbstractBasePtrList args_spec;
+      auto parameters = func_graph->parameters();
+      (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
+                           [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
+      FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
+      res->set_func_graph(new_fg);
+      res->set_args_spec(args_spec);
+    }
+    if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) {
+      return GeOptimizeAction(res);
+    }
+  }
   return true;
 }
@@ -510,7 +530,7 @@ std::vector<ActionItem> GePipeline() {
   // optimize
   actions.emplace_back(std::make_pair("optimize", GeOptimizeAction));
   // Add opt-stage python pass stub
-  actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
+  actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub));
   actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
   actions.emplace_back(std::make_pair("validate", ValidateAction));
   return actions;
@@ -523,7 +543,7 @@ std::vector<ActionItem> VmPipeline() {
   actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
 
   // Add opt-stage python pass stub
-  actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
+  actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub));
 
   actions.emplace_back(std::make_pair("validate", ValidateAction));
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Top-level reference to python pass."""
-from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm
+"""Reference for python pass registration."""
+from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
+    set_reopt
 
 __all__ = [
     "registe_pass",
     "unregiste_pass",
     "gen_new_parameter",
     "cancel_new_parameter",
-    "set_renorm"
+    "set_renorm",
+    "set_reopt"
 ]
@@ -23,7 +23,8 @@ __all__ = [
     "unregiste_pass",
     "gen_new_parameter",
     "cancel_new_parameter",
-    "set_renorm"
+    "set_renorm",
+    "set_reopt"
 ]
 class PyPassManager(PyPassManager_):
     r"""
@@ -75,6 +76,11 @@ class PyPassManager(PyPassManager_):
             raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
         super().set_renorm(should_renorm)
 
+    def set_reopt(self, do_reopt):
+        if not isinstance(do_reopt, bool):
+            raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
+        super().set_reopt(do_reopt)
+
 def registe_pass(run_only_once=False):
     """
     Registe python pass to specified pipeline phase which would be used in compilation.
@@ -164,3 +170,17 @@ def set_renorm(should_renorm):
     """
     ppm = PyPassManager()
     ppm.set_renorm(should_renorm)
+
+def set_reopt(do_reopt):
+    """
+    Set whether or not to do optimization after modified graph in python pass(es).
+
+    Args:
+        do_reopt(bool): whether or not to do optimization after modified graph in python pass(es).
+
+    NOTE:
+        This interface is mainly intended for testing modifying graph without worrying about its validity. Turn off
+        renormalization may BREAK the network.
+    """
+    ppm = PyPassManager()
+    ppm.set_reopt(do_reopt)
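Together with the set_reopt binding and the PyPassManager.set_reopt method above, this helper gives user passes a switch for skipping the post-pass optimization step, mirroring set_renorm. A minimal usage sketch (the pass body and the clean-up calls are illustrative; only the helper names come from this diff):

    from mindspore.ops import operations as P
    from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, set_reopt
    from mindspore.graph_utils.graph_pattern import Any, Call

    # Skip renormalization and re-optimization while experimenting with graph rewrites;
    # the modified graph is then used as-is and may not be valid for execution.
    set_renorm(False)
    set_reopt(False)

    @registe_pass(run_only_once=True)
    def softmax_to_relu_pass():
        x = Any()
        pattern = Call(P.Softmax(), [x])   # match Softmax(x)
        target = Call(P.ReLU(), [x])       # rewrite it to ReLU(x)
        return pattern, target

    # ... compile a network here; the registered pass runs during compilation ...

    unregiste_pass(softmax_to_relu_pass)
    set_renorm(True)
    set_reopt(True)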
@@ -20,7 +20,7 @@ from mindspore import context
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
-    cancel_new_parameter
+    cancel_new_parameter, set_reopt
 from mindspore.common.api import _generate_pip_args
 from mindspore._c_expression import generate_key, Executor_
 from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
@@ -50,8 +50,8 @@ def test_softmax_relu():
     @registe_pass(run_only_once=True)
     def softmax_relu_pass():
         x = Any()
-        pattern = Call(P.Softmax(), inputs=[x])
-        target = Call(P.ReLU(), inputs=[x])
+        pattern = Call(P.Softmax(), [x])
+        target = Call(P.ReLU(), [x])
         return pattern, target
 
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
@@ -59,6 +59,23 @@ def test_softmax_relu():
     assert "ReLU" in transformed_repr
     assert "Softmax" not in transformed_repr
 
+def test_prim():
+    inputs = Tensor(np.ones([42]), mindspore.float16)
+    softmax_model = nn.Softmax()
+
+    @registe_pass(run_only_once=True)
+    def softmax_relu_pass():
+        x = Any()
+        sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()])
+        pattern = Call(sigmoid_softmax_pattern, [x])
+        target = Call(P.ReLU(), [x])
+        return pattern, target
+
+    transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
+    unregiste_pass(softmax_relu_pass)
+    assert "ReLU" in transformed_repr
+    assert "Softmax" not in transformed_repr
+
 def test_softmax_relu_sigmoid():
     """
     Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)).
@@ -73,11 +90,11 @@ def test_softmax_relu_sigmoid():
     def softmax_relu_pass():
         x = Any()
         softmax_pattern = Prim(P.Softmax())
-        pattern = Call(softmax_pattern, inputs=[x])
+        pattern = Call(softmax_pattern, [x])
         sigmoid_pattern = Prim(P.Sigmoid())
         call_sigmoid = Call(sigmoid_pattern, [x])
         relu_pattern = Prim(P.ReLU())
-        target = Call(relu_pattern, inputs=[call_sigmoid])
+        target = Call(relu_pattern, [call_sigmoid])
         return pattern, target
 
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
@@ -98,13 +115,13 @@ def test_isin_pattern_0():
     def softmax_relu_pass():
         x = Any()
         softmax_pattern = Prim(P.Softmax())
-        call_softmax = Call(softmax_pattern, inputs=[x])
+        call_softmax = Call(softmax_pattern, [x])
         relu_pattern = Prim(P.ReLU())
-        call_relu = Call(relu_pattern, inputs=[x])
+        call_relu = Call(relu_pattern, [x])
 
         pattern = OneOf([call_softmax, call_relu])
         relu6_pattern = Prim(P.ReLU6())
-        target = Call(relu6_pattern, inputs=[x])
+        target = Call(relu6_pattern, [x])
         return pattern, target
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
     unregiste_pass(softmax_relu_pass)
@@ -122,13 +139,13 @@ def test_isin_pattern_1():
     def softmax_neg_pass():
         x = Any()
         softmax_pattern = Prim(P.Softmax())
-        call_softmax = Call(softmax_pattern, inputs=[x])
+        call_softmax = Call(softmax_pattern, [x])
         relu_pattern = Prim(P.ReLU())
-        call_relu = Call(relu_pattern, inputs=[x])
+        call_relu = Call(relu_pattern, [x])
 
         pattern = OneOf([call_softmax, call_relu])
         neg_ops = Prim(P.Neg())
-        target = Call(neg_ops, inputs=[pattern])
+        target = Call(neg_ops, [pattern])
         return pattern, target
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
     unregiste_pass(softmax_neg_pass)
@@ -141,6 +158,7 @@ def test_isnot_pattern_0():
     Case: IsNot pass failed to match
     """
     set_renorm(False)
+    set_reopt(False)
     class ConvBN(nn.Cell):
         def __init__(self):
             super(ConvBN, self).__init__()
|
||||||
conv2d_prim = Prim("Conv2D")
|
conv2d_prim = Prim("Conv2D")
|
||||||
conv2d = Call(conv2d_prim)
|
conv2d = Call(conv2d_prim)
|
||||||
pattern_0 = NoneOf(conv2d)
|
pattern_0 = NoneOf(conv2d)
|
||||||
pattern = Call(P.BatchNorm(), inputs=[pattern_0])
|
pattern = Call(P.BatchNorm(), [pattern_0])
|
||||||
target = Call(P.ReLU6(), inputs=[pattern_0])
|
target = Call(P.ReLU6(), [pattern_0])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@registe_pass(run_only_once=True)
|
||||||
|
@@ -202,9 +220,9 @@ def test_isnot_pattern_1():
         matmul = Prim("MatMul")
         pattern_0 = NoneOf(matmul)
         softmax = P.Softmax()
-        pattern = Call(softmax, inputs=[pattern_0])
+        pattern = Call(softmax, [pattern_0])
         relu6 = P.ReLU6()
-        target = Call(relu6, inputs=[pattern_0])
+        target = Call(relu6, [pattern_0])
         return pattern, target
 
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
|
||||||
Test NewTensor pattern in the target
|
Test NewTensor pattern in the target
|
||||||
"""
|
"""
|
||||||
set_renorm(False)
|
set_renorm(False)
|
||||||
|
set_reopt(False)
|
||||||
inputs = Tensor(np.ones([42]), mindspore.float16)
|
inputs = Tensor(np.ones([42]), mindspore.float16)
|
||||||
softmax_model = nn.Softmax()
|
softmax_model = nn.Softmax()
|
||||||
|
|
||||||
@registe_pass(run_only_once=True)
|
@registe_pass(run_only_once=True)
|
||||||
def softmax_addn_pass():
|
def softmax_addn_pass():
|
||||||
x = Any()
|
x = Any()
|
||||||
pattern = Call(P.Softmax(), inputs=[x])
|
pattern = Call(P.Softmax(), [x])
|
||||||
|
|
||||||
weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
|
weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
|
||||||
new_weight = NewTensor(weight_tensor)
|
new_weight = NewTensor(weight_tensor)
|
||||||
target = Call(P.AddN(), inputs=[x, new_weight])
|
target = Call(P.AddN(), [x, new_weight])
|
||||||
return pattern, target
|
return pattern, target
|
||||||
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
|
||||||
unregiste_pass(softmax_addn_pass)
|
unregiste_pass(softmax_addn_pass)
|
||||||
|
@@ -242,17 +261,19 @@ def test_newparameter_pattern():
     inputs = Tensor(np.ones([42]), mindspore.float16)
     softmax_model = nn.Softmax()
 
+    set_renorm(False)
+    set_reopt(False)
     @registe_pass(run_only_once=True)
     def softmax_addn_pass():
         x = Any()
-        pattern = Call(P.Softmax(), inputs=[x])
+        pattern = Call(P.Softmax(), [x])
 
         default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
         default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
         new_para_0 = NewParameter("Merlin", default_tensor0)
         new_para_1 = NewParameter("Arthur", default_tensor1)
-        target_0 = Call(P.MatMul(), inputs=[new_para_0, new_para_1])
-        target = Call("make_tuple", inputs=[target_0])
+        target_0 = Call(P.MatMul(), [new_para_0, new_para_1])
+        target = Call("make_tuple", [target_0])
         return pattern, target
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
     unregiste_pass(softmax_addn_pass)
@@ -267,13 +288,15 @@ def test_imm_target():
     inputs = Tensor(np.ones([42]), mindspore.float16)
     softmax_model = nn.Softmax()
 
+    set_renorm(False)
+    set_reopt(False)
     @registe_pass(run_only_once=True)
     def softmax_pass():
         x = Any()
-        pattern = Call(P.Softmax(), inputs=[x])
+        pattern = Call(P.Softmax(), [x])
         imm = Imm(0)
-        target_0 = Call("make_tuple", inputs=[pattern])
-        target = Call("tuple_getitem", inputs=[target_0, imm])
+        target_0 = Call("make_tuple", [pattern])
+        target = Call("tuple_getitem", [target_0, imm])
         return pattern, target
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
     unregiste_pass(softmax_pass)
@@ -290,14 +313,16 @@ def test_gen_new_parameter():
 
     default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
     new_para = NewParameter("Merlin", default_tensor)
+    set_renorm(False)
+    set_reopt(False)
     gen_new_parameter(new_para)
     @registe_pass(run_only_once=True)
     def softmax_make_tuple_pass():
        x = Any()
         softmax = P.Softmax()
-        pattern = Call(softmax, inputs=[x])
+        pattern = Call(softmax, [x])
 
-        target = Call("make_tuple", inputs=[pattern, new_para])
+        target = Call("make_tuple", [pattern, new_para])
         return pattern, target
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
     assert "Merlin" in transformed_repr
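The test above registers the new "Merlin" parameter through gen_new_parameter and switches renorm and re-opt off so the modified graph is kept as-is. A hedged clean-up sketch using the helpers imported in this test file (calling cancel_new_parameter here is illustrative, not part of this diff):

    # Undo global python-pass state so later compilations are unaffected.
    cancel_new_parameter(new_para)  # drop the injected "Merlin" parameter pattern
    set_renorm(True)                # re-enable renormalization
    set_reopt(True)                 # re-enable post-pass optimization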