!5586 Boost python pass compile and train performance

Merge pull request !5586 from BowenK/performance
mindspore-ci-bot 2020-09-01 09:19:02 +08:00 committed by Gitee
commit 1a4d3e351e
8 changed files with 137 additions and 106 deletions

View File

@ -59,6 +59,7 @@ class Pattern : public Base {
string unique_name() const { return unique_name_; }
vector<PatternPtr> inputs() { return inputs_; }
virtual void reset() {}
static void reset_gid() { g_id_ = 0; }
protected:
static int g_id_;
@ -213,7 +214,6 @@ class NewParameter : public Pattern {
explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel)
: para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
// clone input tensor
default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
built_ = false;
}
@ -257,7 +257,7 @@ class MatchResult {
MatchResult() {}
~MatchResult() = default;
void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; }
PatternNodeMap _result() { return match_result_; }
PatternNodeMap &_result() { return match_result_; }
AnfNodePtr get_node(const PatternPtr &pattern);
void merge(const MatchResultPtr &other_result);
void clear() { match_result_.clear(); }

View File

@ -27,8 +27,6 @@
#include "pipeline/jit/resource.h"
#include "frontend/optimizer/py_pass_manager.h"
#include "utils/info.h"
#include "debug/anf_ir_dump.h"
#include "debug/draw.h"
namespace mindspore {
namespace opt {
@ -42,29 +40,6 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor::TensorPtr default_input,
bool requires_grad, bool layerwise_parallel);
std::string GetNodeRepr(AnfNodePtr node) {
if (node != nullptr) {
if (node->isa<CNode>()) {
std::string repr = "(";
auto const &inputs = node->cast<CNodePtr>()->inputs();
for (auto &input : inputs) {
repr += " ";
repr += GetNodeRepr(input);
repr += " ";
}
repr += ")";
return repr;
}
if (node->isa<Parameter>()) {
return "[Parameter]" + node->ToString();
} else if (node->isa<ValueNode>()) {
return "[Value]" + GetValueNode(node)->ToString();
}
return node->ToString();
}
return "";
}
bool IsTraversable(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
@ -215,23 +190,6 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
return new_node;
}
void DrawNode(string name, AnfNodePtr node) {
auto context_ptr = MsContext::GetInstance();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
auto new_func_graph = std::make_shared<FuncGraph>();
new_func_graph->set_output(node, true);
if (save_graphs) {
auto ir_dump_path = save_graphs_path + "/" + name + ".ir";
auto dot_dump_path = save_graphs_path + "/" + name + ".dot";
DumpIR(ir_dump_path, new_func_graph);
draw::Draw(dot_dump_path, new_func_graph);
}
}
void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor::TensorPtr default_input,
bool requires_grad, bool layerwise_parallel) {
// 1. Get current cell object
@ -241,12 +199,15 @@ void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor
if (py::isinstance<py::none>(top_cell)) {
MS_LOG(EXCEPTION) << "Failed to get top cell from resource.";
}
// 2. New a Parameter object with the above-specified args
// 2. Clone default_input tensor
auto default_tensor = std::make_shared<tensor::Tensor>(default_input->data_type(), default_input->shape_c(),
default_input->data_c(), (size_t)default_input->Size());
// 3. New a Parameter object with the above-specified args
py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS);
py::object new_parameter = parameter_class(default_input, param_name, requires_grad, layerwise_parallel);
// 3. Add the new python Parameter object to Cell's _params attributes
py::object new_parameter = parameter_class(default_tensor, param_name, requires_grad, layerwise_parallel);
// 4. Add the new python Parameter object to Cell's _params attributes
top_cell.attr(SET_PARAM)(param_name, new_parameter);
// 4. Set default_param for param_node
// 5. Set default_param for param_node
ValuePtr param_value = nullptr;
bool converted = parse::ConvertData(new_parameter, &param_value, false);
if (!converted) {
@ -282,11 +243,9 @@ void Reset(PatternPtr pattern) {
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) {
auto match_res = src_pattern_->match(node);
if (match_res != nullptr) {
MS_LOG(DEBUG) << "Matched pattern: " + src_pattern_->unique_name() + " node : " + internal::GetNodeRepr(node);
res->merge(match_res);
auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res);
internal::Reset(dst_pattern());
MS_LOG(WARNING) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n";
return new_node;
}
internal::Reset(src_pattern());
@ -303,7 +262,6 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null.";
}
auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name();
MS_LOG(DEBUG) << "Adding New parameter : " + para_name;
auto para_node = std::make_shared<Parameter>(func_graph);
MS_EXCEPTION_IF_NULL(para_node);
para_node->set_name(para_name);
@ -321,7 +279,7 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
// Reflect back to Cell._params
internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
new_para_pattern->layerwise_parallel());
MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
MS_LOG(WARNING) << "[Gen]Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
return true;
}
FuncGraphManagerPtr manager = func_graph->manager();
@ -334,7 +292,6 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
for (auto &node : graph_nodes_sorted) {
AnfNodePtr new_node = Run(func_graph, node, res);
if (new_node != nullptr && new_node != node) {
internal::DrawNode(dst_pattern_->unique_name(), new_node);
(void)manager->Replace(node, new_node);
changes = true;
}

View File

@ -98,7 +98,8 @@ REGISTER_PYBIND_DEFINE(
.def("registe", &PyPassManager::Registe, "Registe python pass")
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass")
.def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
.def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph");
.def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph")
.def("set_reopt", &PyPassManager::SetReOpt, "Set whether or not to do optimization after modified graph");
}));
} // namespace python_pass
} // namespace opt

View File

@ -60,13 +60,19 @@ class PyPassManager {
MatchResultPtr GetMatchResult() { return res_; }
void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
bool ShouldRenorm() { return should_renorm_; }
void SetReOpt(bool should_reopt) { should_reopt_ = should_reopt; }
bool ShouldReOpt() { return should_reopt_; }
void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
pipeline::ResourcePtr GetResource() { return resource_; }
void ClearRes();
void ClearPipelineRes() { resource_ = nullptr; }
void ClearPipelineRes() {
resource_ = nullptr;
Pattern::reset_gid();
}
private:
bool should_renorm_ = true;
bool should_reopt_ = true;
MatchResultPtr res_;
pipeline::ResourcePtr resource_;
static std::unordered_map<Phase, PassGroupPtr> phase_to_group_;

View File

@ -451,35 +451,55 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }
void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
MS_EXCEPTION_IF_NULL(res->manager());
MS_EXCEPTION_IF_NULL(res->func_graph());
auto ppm = opt::python_pass::PyPassManager::GetInstance();
ppm->SetResource(res);
if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) {
MS_LOG(DEBUG) << "No match.\n";
} else if (phase == opt::python_pass::Phase::OPT && opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
MS_LOG(DEBUG) << "Entered PyStub Renorm";
// Renormalize
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
abstract::AbstractBasePtrList args_spec;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
res->set_func_graph(new_fg);
res->set_args_spec(args_spec);
}
return ppm->GetPassGroup(phase)->Run(res->func_graph());
}
bool ResolveActionPyStub(const ResourcePtr &res) {
ActionPyStub(res, opt::python_pass::Phase::RESOLVE);
bool ResolveActionPyStub(const ResourcePtr &res) { return true || ActionPyStub(res, opt::python_pass::Phase::RESOLVE); }
bool OptActionVmPyStub(const ResourcePtr &res) {
if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
// Renormalize
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
abstract::AbstractBasePtrList args_spec;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
res->set_func_graph(new_fg);
res->set_args_spec(args_spec);
}
if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) {
return VmOptimizeAction(res);
}
}
return true;
}
bool OptActionPyStub(const ResourcePtr &res) {
ActionPyStub(res, opt::python_pass::Phase::OPT);
bool OptActionGePyStub(const ResourcePtr &res) {
if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
// Renormalize
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
abstract::AbstractBasePtrList args_spec;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
res->set_func_graph(new_fg);
res->set_args_spec(args_spec);
}
if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) {
return GeOptimizeAction(res);
}
}
return true;
}
@ -510,7 +530,7 @@ std::vector<ActionItem> GePipeline() {
// optimize
actions.emplace_back(std::make_pair("optimize", GeOptimizeAction));
// Add opt-stage python pass stub
actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub));
actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
actions.emplace_back(std::make_pair("validate", ValidateAction));
return actions;
@ -523,7 +543,7 @@ std::vector<ActionItem> VmPipeline() {
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
// Add opt-stage python pass stub
actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub));
actions.emplace_back(std::make_pair("validate", ValidateAction));
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))

View File

@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Top-level reference to python pass."""
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm
"""Reference for python pass registration."""
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
set_reopt
__all__ = [
"registe_pass",
"unregiste_pass",
"gen_new_parameter",
"cancel_new_parameter",
"set_renorm"
"set_renorm",
"set_reopt"
]
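
For reference, a minimal sketch of pulling in the newly exported switch from this package (both switches default to True on the C++ PyPassManager; turning them off is mainly useful when testing graph rewrites):

    from mindspore.graph_utils.python_pass import set_renorm, set_reopt
    set_renorm(False)  # skip renormalization after a python pass modifies the graph
    set_reopt(False)   # skip re-running the optimize stage after a python pass modifies the graph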

View File

@ -23,7 +23,8 @@ __all__ = [
"unregiste_pass",
"gen_new_parameter",
"cancel_new_parameter",
"set_renorm"
"set_renorm",
"set_reopt"
]
class PyPassManager(PyPassManager_):
r"""
@ -75,6 +76,11 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
super().set_renorm(should_renorm)
def set_reopt(self, do_reopt):
if not isinstance(do_reopt, bool):
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
super().set_reopt(do_reopt)
def registe_pass(run_only_once=False):
"""
Registe python pass to specified pipeline phase which would be used in compilation.
@ -164,3 +170,17 @@ def set_renorm(should_renorm):
"""
ppm = PyPassManager()
ppm.set_renorm(should_renorm)
def set_reopt(do_reopt):
"""
Set whether or not to re-run optimization after the graph is modified in python pass(es).
Args:
do_reopt(bool): whether or not to re-run optimization after the graph is modified in python pass(es).
NOTE:
This interface is mainly intended for testing graph modifications without worrying about their validity.
Turning off re-optimization may BREAK the network.
"""
ppm = PyPassManager()
ppm.set_reopt(do_reopt)
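
A hedged end-to-end sketch of the new switch next to the pass-registration helpers; the pass body is illustrative only and mirrors the Softmax-to-ReLU pattern used in the tests below:

    from mindspore.ops import operations as P
    from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, set_reopt
    from mindspore.graph_utils.graph_pattern import Any, Call

    set_renorm(False)  # keep the rewritten graph as-is, no renormalization
    set_reopt(False)   # and do not re-run the optimize action on it

    @registe_pass(run_only_once=True)
    def softmax_relu_pass():
        x = Any()
        pattern = Call(P.Softmax(), [x])  # match Softmax(x)
        target = Call(P.ReLU(), [x])      # rewrite it to ReLU(x)
        return pattern, target

    # ... compile a network containing Softmax here; the pass fires at the "py_opt" stage ...
    unregiste_pass(softmax_relu_pass)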

View File

@ -20,7 +20,7 @@ from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
cancel_new_parameter
cancel_new_parameter, set_reopt
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_
from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
@ -50,8 +50,8 @@ def test_softmax_relu():
@registe_pass(run_only_once=True)
def softmax_relu_pass():
x = Any()
pattern = Call(P.Softmax(), inputs=[x])
target = Call(P.ReLU(), inputs=[x])
pattern = Call(P.Softmax(), [x])
target = Call(P.ReLU(), [x])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
@ -59,6 +59,23 @@ def test_softmax_relu():
assert "ReLU" in transformed_repr
assert "Softmax" not in transformed_repr
def test_prim():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
def softmax_relu_pass():
x = Any()
sigmoid_softmax_pattern = Prim([P.Sigmoid(), P.Softmax()])
pattern = Call(sigmoid_softmax_pattern, [x])
target = Call(P.ReLU(), [x])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
unregiste_pass(softmax_relu_pass)
assert "ReLU" in transformed_repr
assert "Softmax" not in transformed_repr
def test_softmax_relu_sigmoid():
"""
Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)).
@ -73,11 +90,11 @@ def test_softmax_relu_sigmoid():
def softmax_relu_pass():
x = Any()
softmax_pattern = Prim(P.Softmax())
pattern = Call(softmax_pattern, inputs=[x])
pattern = Call(softmax_pattern, [x])
sigmoid_pattern = Prim(P.Sigmoid())
call_sigmoid = Call(sigmoid_pattern, [x])
relu_pattern = Prim(P.ReLU())
target = Call(relu_pattern, inputs=[call_sigmoid])
target = Call(relu_pattern, [call_sigmoid])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
@ -98,13 +115,13 @@ def test_isin_pattern_0():
def softmax_relu_pass():
x = Any()
softmax_pattern = Prim(P.Softmax())
call_softmax = Call(softmax_pattern, inputs=[x])
call_softmax = Call(softmax_pattern, [x])
relu_pattern = Prim(P.ReLU())
call_relu = Call(relu_pattern, inputs=[x])
call_relu = Call(relu_pattern, [x])
pattern = OneOf([call_softmax, call_relu])
relu6_pattern = Prim(P.ReLU6())
target = Call(relu6_pattern, inputs=[x])
target = Call(relu6_pattern, [x])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
unregiste_pass(softmax_relu_pass)
@ -122,13 +139,13 @@ def test_isin_pattern_1():
def softmax_neg_pass():
x = Any()
softmax_pattern = Prim(P.Softmax())
call_softmax = Call(softmax_pattern, inputs=[x])
call_softmax = Call(softmax_pattern, [x])
relu_pattern = Prim(P.ReLU())
call_relu = Call(relu_pattern, inputs=[x])
call_relu = Call(relu_pattern, [x])
pattern = OneOf([call_softmax, call_relu])
neg_ops = Prim(P.Neg())
target = Call(neg_ops, inputs=[pattern])
target = Call(neg_ops, [pattern])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
unregiste_pass(softmax_neg_pass)
@ -141,6 +158,7 @@ def test_isnot_pattern_0():
Case: IsNot pass failed to match
"""
set_renorm(False)
set_reopt(False)
class ConvBN(nn.Cell):
def __init__(self):
super(ConvBN, self).__init__()
@ -166,8 +184,8 @@ def test_isnot_pattern_0():
conv2d_prim = Prim("Conv2D")
conv2d = Call(conv2d_prim)
pattern_0 = NoneOf(conv2d)
pattern = Call(P.BatchNorm(), inputs=[pattern_0])
target = Call(P.ReLU6(), inputs=[pattern_0])
pattern = Call(P.BatchNorm(), [pattern_0])
target = Call(P.ReLU6(), [pattern_0])
return pattern, target
@registe_pass(run_only_once=True)
@ -202,9 +220,9 @@ def test_isnot_pattern_1():
matmul = Prim("MatMul")
pattern_0 = NoneOf(matmul)
softmax = P.Softmax()
pattern = Call(softmax, inputs=[pattern_0])
pattern = Call(softmax, [pattern_0])
relu6 = P.ReLU6()
target = Call(relu6, inputs=[pattern_0])
target = Call(relu6, [pattern_0])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
@ -217,17 +235,18 @@ def test_newtensor_pattern():
Test NewTensor pattern in the target
"""
set_renorm(False)
set_reopt(False)
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
def softmax_addn_pass():
x = Any()
pattern = Call(P.Softmax(), inputs=[x])
pattern = Call(P.Softmax(), [x])
weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
new_weight = NewTensor(weight_tensor)
target = Call(P.AddN(), inputs=[x, new_weight])
target = Call(P.AddN(), [x, new_weight])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
unregiste_pass(softmax_addn_pass)
@ -242,17 +261,19 @@ def test_newparameter_pattern():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
set_renorm(False)
set_reopt(False)
@registe_pass(run_only_once=True)
def softmax_addn_pass():
x = Any()
pattern = Call(P.Softmax(), inputs=[x])
pattern = Call(P.Softmax(), [x])
default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
new_para_0 = NewParameter("Merlin", default_tensor0)
new_para_1 = NewParameter("Arthur", default_tensor1)
target_0 = Call(P.MatMul(), inputs=[new_para_0, new_para_1])
target = Call("make_tuple", inputs=[target_0])
target_0 = Call(P.MatMul(), [new_para_0, new_para_1])
target = Call("make_tuple", [target_0])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
unregiste_pass(softmax_addn_pass)
@ -267,13 +288,15 @@ def test_imm_target():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
set_renorm(False)
set_reopt(False)
@registe_pass(run_only_once=True)
def softmax_pass():
x = Any()
pattern = Call(P.Softmax(), inputs=[x])
pattern = Call(P.Softmax(), [x])
imm = Imm(0)
target_0 = Call("make_tuple", inputs=[pattern])
target = Call("tuple_getitem", inputs=[target_0, imm])
target_0 = Call("make_tuple", [pattern])
target = Call("tuple_getitem", [target_0, imm])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
unregiste_pass(softmax_pass)
@ -290,14 +313,16 @@ def test_gen_new_parameter():
default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
new_para = NewParameter("Merlin", default_tensor)
set_renorm(False)
set_reopt(False)
gen_new_parameter(new_para)
@registe_pass(run_only_once=True)
def softmax_make_tuple_pass():
x = Any()
softmax = P.Softmax()
pattern = Call(softmax, inputs=[x])
pattern = Call(softmax, [x])
target = Call("make_tuple", inputs=[pattern, new_para])
target = Call("make_tuple", [pattern, new_para])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
assert "Merlin" in transformed_repr