add jit config param for compile and fix summary graph

jiangshuqiang 2021-11-26 01:45:22 -08:00
parent d1e4e674ab
commit c9a0edd24e
9 changed files with 131 additions and 24 deletions

View File

@@ -1117,7 +1117,7 @@ static std::vector<ActionItem> CommonPipeline() {
(void)actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
-if (!multi_graphs) {
+if (!multi_graphs && pipeline::GraphExecutorPy::GetInstance()->jit_config()["jit_level"] != "o0") {
(void)actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
}
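For context, the jit_config map checked above is populated from Python. A minimal sketch of driving this branch (using the set_jit_config wrapper added later in this commit; any value other than 'o0' keeps combine_like_graphs enabled):

# Sketch: setting jit_level to "o0" makes the condition above false,
# so the combine_like_graphs action is not added to the pipeline.
from mindspore.common.api import _cell_graph_executor

_cell_graph_executor.set_jit_config({'jit_level': 'o0'})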

View File

@@ -112,7 +112,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_compile_cache_dep_files", &GraphExecutorPy::set_compile_cache_dep_files,
py::arg("compile_cache_dep_files") = py::list(), "Set the compilation cache dependent files.")
.def("set_weights_values", &GraphExecutorPy::set_weights_values, py::arg("weights") = py::dict(),
"Set values of weights.");
"Set values of weights.")
.def("get_optimize_graph_proto", &GraphExecutorPy::GetOptimizeGraphProto, py::arg("phase") = py::str(""),
"Get the optimize graph proto string.")
.def("set_jit_config", &GraphExecutorPy::SetJitConfig, py::arg("jit_config") = py::dict(), "Set the jit config.");
(void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_").def(py::init());

View File

@@ -292,6 +292,16 @@ std::map<string, ValuePtr> GenerateWeightsValueMap(const py::dict &weights) {
return ret;
}
std::map<string, string> GenerateJitConfigMap(const py::dict &jit_config) {
std::map<string, string> ret{};
for (auto jit_param = jit_config.begin(); jit_param != jit_config.end(); ++jit_param) {
auto param_name = py::cast<std::string>(jit_param->first);
auto param_value = py::cast<std::string>(jit_param->second);
ret[param_name] = param_value;
}
return ret;
}
FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx, const py::dict &weights) {
std::string compile_cache_path = GetCompileCachePath(idx);
auto realpath = Common::CreatePrefixPath(compile_cache_path, true);
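GenerateJitConfigMap above simply stringifies each entry of the Python dict; a rough Python-side equivalent (a sketch; in practice _check_jit_config in api.py guarantees keys and values are already strings):

# Rough equivalent of GenerateJitConfigMap: a straight str->str copy.
def generate_jit_config_map(jit_config):
    return {str(key): str(value) for key, value in jit_config.items()}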
@@ -614,6 +624,24 @@ py::bytes GraphExecutorPy::GetFuncGraphProto(const std::string &phase, const std
MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type;
}
py::bytes GraphExecutorPy::GetOptimizeGraphProto(const std::string &phase) {
if (info_.count(phase) == 0) {
MS_LOG(EXCEPTION) << "No phase in executor: " << phase;
}
FuncGraphPtr fg_ptr = info_[phase]->resource->optimize_graph();
if (fg_ptr == nullptr) {
MS_LOG(WARNING) << "Can not find optimize graph.";
return "";
}
std::string proto_str = GetFuncGraphProtoString(fg_ptr);
if (proto_str.empty()) {
MS_LOG(EXCEPTION) << "Export optimize graph proto string failed.";
}
return proto_str;
}
void GraphExecutorPy::SetJitConfig(const py::dict &jit_config) { jit_config_ = GenerateJitConfigMap(jit_config); }
py::dict GraphExecutorPy::GetParameterLayout(const std::string &phase) {
MS_LOG(DEBUG) << "GetParameterLayout!";
std::string layout_graph = phase + kStepParallelGraph;
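A sketch of exercising the two new methods directly through their bindings (normally this is wrapped by _CellGraphExecutor in api.py; GraphExecutor_ and the phase-id format follow the existing executor conventions):

# Direct use of the new bindings; an unknown phase raises an error from
# MS_LOG(EXCEPTION), and a missing optimize graph returns an empty proto.
from mindspore._c_expression import GraphExecutor_

executor = GraphExecutor_.get_instance()
executor.set_jit_config({'jit_level': 'o0'})
proto = executor.get_optimize_graph_proto(phase)  # phase: an already-compiled phase id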
@@ -1101,6 +1129,41 @@ void RDRRecordGraph(const size_t action_index, const size_t action_size, const s
}
#endif
#ifdef ENABLE_DUMP_IR
void RecordIR(const size_t action_index, const size_t action_size, const std::string &action_name,
const FuncGraphPtr graph, const std::string &phase, FuncGraphPtr *user_graph) {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && graph != nullptr) {
*user_graph = graph;
std::string base_name = GetBaseNameForIR(SizeToLong(action_index), action_name);
// generate IR file in dot format, which can be converted to svg file using graphviz dot command
draw::Draw(base_name + ".dot", graph);
// generate IR file in human readable format
if (action_index == action_size - 1) {
DumpIR(base_name + ".ir", graph, false, kWholeStack);
} else {
DumpIR(base_name + ".ir", graph, false, kTopStack);
}
// generate IR file in a heavily commented format, which can also be reloaded
ExportIR(base_name + ".dat", graph);
}
}
#endif
#ifndef ENABLE_SECURITY
void SaveGraphForReadability(const std::string &action_name, const FuncGraphPtr graph, const std::string &phase,
const ResourcePtr resource) {
if (graph != nullptr && action_name.find("optimize") != string::npos) {
#ifdef ENABLE_DUMP_IR
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIRProto(graph, action_name);
}
#endif
resource->set_optimize_graph(graph);
}
}
#endif
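Both helpers are no-ops unless graph saving is turned on; a sketch of enabling it (standard MindSpore context API; the path is an example):

# With save_graphs enabled, RecordIR dumps <n>_<action>.dot/.ir/.dat per
# pipeline action, and SaveGraphForReadability stores the optimize-stage
# graph on the resource so it can later be exported as a proto.
import mindspore.context as context

context.set_context(save_graphs=True, save_graphs_path='./ir_dump')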
void Pipeline::Run(const std::string &phase) {
MS_LOG(INFO) << "Pipeline run";
MS_EXCEPTION_IF_NULL(resource_);
@@ -1139,22 +1202,10 @@ void Pipeline::Run(const std::string &phase) {
#ifdef ENABLE_DUMP_IR
std::string filename = GetBaseNameForIR(SizeToLong(i), action.first);
RDRRecordGraph(i, actions_.size(), filename, graph);
-if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && graph != nullptr) {
-user_graph = graph;
-std::string base_name = GetBaseNameForIR(SizeToLong(i), action.first);
-// generate IR file in dot format, which can be converted to svg file using graphviz dot command
-draw::Draw(base_name + ".dot", graph);
-// generate IR file in human readable format
-if (i == actions_.size() - 1) {
-DumpIR(base_name + ".ir", graph, false, kWholeStack);
-} else {
-DumpIR(base_name + ".ir", graph, false, kTopStack);
-}
-// generate IR file in a heavily commented format, which can also be reloaded
-ExportIR(base_name + ".dat", graph);
-}
+RecordIR(i, actions_.size(), action.first, graph, phase, &user_graph);
#endif
+#ifndef ENABLE_SECURITY
+SaveGraphForReadability(action.first, graph, phase, resource_);
+#endif
i++;
#ifdef ENABLE_TIMELINE

View File

@@ -72,6 +72,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
~GraphExecutorPy();
const std::string &phase() const { return phase_; }
std::map<std::string, std::string> &jit_config() { return jit_config_; }
void SaveCompiledGraph(const std::string &phase);
bool CompileInner(const py::object &source_obj, const py::tuple &args, const py::object &phase_obj, bool use_vm);
bool Compile(const py::object &source_obj, const py::tuple &args, const py::object &phase_obj, bool use_vm);
@@ -85,6 +86,10 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
FuncGraphPtr GetGradGraph(const std::string &phase);
void SetGradGraph(const FuncGraphPtr &grad_graph, const std::string &phase);
py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type);
#ifndef ENABLE_SECURITY
py::bytes GetOptimizeGraphProto(const std::string &phase);
#endif
void SetJitConfig(const py::dict &jit_config);
compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase);
bool HasCompiled(const std::string &phase) const;
@@ -142,6 +147,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
#endif
std::map<std::string, py::dict> stra_dict_;
std::string phase_ = "";
std::map<std::string, std::string> jit_config_;
std::map<std::string, size_t> phase_to_num_op_info_;
std::string queue_name_;
bool enable_tuple_broaden_{false};

View File

@@ -72,6 +72,9 @@ class Resource : public ResourceBase {
FuncGraphPtr func_graph() const { return func_graph_; }
void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
FuncGraphPtr optimize_graph() const { return optimize_graph_; }
void set_optimize_graph(const FuncGraphPtr &optimize_graph) { optimize_graph_ = optimize_graph; }
const abstract::AbstractBasePtrList &args_spec() const { return args_spec_; }
void set_args_spec(const abstract::AbstractBasePtrList &args_spec) { args_spec_ = args_spec; }
@@ -103,6 +106,7 @@ class Resource : public ResourceBase {
private:
abstract::AnalysisEnginePtr engine_;
FuncGraphPtr func_graph_;
FuncGraphPtr optimize_graph_;
abstract::AbstractBasePtrList args_spec_;
// The source obj to compile, usually a `Cell` or `ms_function` decorated function.
py::object source_input_;

View File

@@ -35,6 +35,7 @@ from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dat
from ..parallel._ps_context import _is_role_pserver
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
_get_parameter_broadcast, _get_pipeline_stages
from .._checkparam import Validator
# store ms_function class compiled pipeline cache
ms_compile_cache = {}
@@ -547,6 +548,11 @@ class _CellGraphExecutor:
Graph, return the result of pipeline running.
"""
VALID_JIT_CONFIG_PARAM = ["jit_level"]
VALID_JIT_CONFIG_PARAM_VALUE = {
"jit_level": ["o0", "o1"]
}
def __init__(self):
# create needed graph by lazy mode
self.is_init = False
@@ -774,6 +780,17 @@ class _CellGraphExecutor:
return None
return self._graph_executor.get_func_graph_proto(exec_id, ir_type)
def get_optimize_graph_proto(self, obj):
"""Return optimize graph binary proto."""
exec_id = obj.phase + "." + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
if self._graph_executor.has_compiled(exec_id) is False:
return None
graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id)
if isinstance(graph_proto, str) and graph_proto == "":
logger.warning("Can not get optimize graph proto. Instead, try to find function graph.")
graph_proto = obj.get_func_graph_proto()
return graph_proto
def export(self, file_name, graph_id):
"""
Export graph.
@@ -791,6 +808,20 @@ class _CellGraphExecutor:
return None
return self._graph_executor.fetch_info_for_quant_export(exec_id)
def set_jit_config(self, jit_config):
"""Set jit config."""
self._check_jit_config(jit_config)
self._graph_executor.set_jit_config(jit_config)
def _check_jit_config(self, jit_config):
"""Check the value of jit config."""
if not isinstance(jit_config, dict):
raise ValueError("The jit_config should be a string.")
for param_name, param_value in jit_config.items():
Validator.check_string(param_name, self.VALID_JIT_CONFIG_PARAM, "jit_config")
Validator.check_string(param_value, self.VALID_JIT_CONFIG_PARAM_VALUE.get(param_name), param_name,
"jit_config")
_cell_graph_executor = _CellGraphExecutor()
_pynative_executor = _PynativeExecutor()
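A sketch of the validation behavior of _check_jit_config above (hypothetical inputs; anything outside the VALID_JIT_CONFIG_PARAM allow-lists is rejected by Validator.check_string):

# Accepted: the only recognized key is "jit_level", valued "o0" or "o1".
_cell_graph_executor.set_jit_config({'jit_level': 'o1'})

# Rejected examples -- each raises an error in _check_jit_config:
# _cell_graph_executor.set_jit_config({'jit_level': 'o2'})  # invalid value
# _cell_graph_executor.set_jit_config({'backend': 'ge'})    # unknown key
# _cell_graph_executor.set_jit_config('o0')                 # not a dict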

View File

@@ -38,6 +38,7 @@ from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.loss.loss import LossBase
from mindspore.train._utils import check_value_type, _make_directory
from ..._c_expression import security
from ...common.api import _cell_graph_executor
HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
HYPER_CONFIG_LEN_LIMIT = 100000
@@ -685,12 +686,11 @@ class SummaryCollector(Callback):
return
network = cb_params.train_network if cb_params.mode == ModeEnum.TRAIN.value else cb_params.eval_network
-graph_proto = network.get_func_graph_proto()
+graph_proto = _cell_graph_executor.get_optimize_graph_proto(network)
if graph_proto is None:
logger.warning("Can not get graph proto, it may not be 'GRAPH_MODE' in context currently, "
"so SummaryCollector will not collect graph.")
return
self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto)
def _collect_metric(self, cb_params):
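With this change SummaryCollector records the optimized graph instead of the raw func graph. A hedged usage sketch (summary_dir, model, and dataset are placeholders):

# SummaryCollector now fetches the graph through
# _cell_graph_executor.get_optimize_graph_proto(network).
from mindspore.train.callback import SummaryCollector

summary_cb = SummaryCollector(summary_dir='./summary')
model.train(1, train_dataset, callbacks=[summary_cb])  # model/train_dataset assumed defined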

View File

@@ -37,7 +37,7 @@ from ..context import ParallelMode
from ..parallel._cost_model_context import _set_multi_subgraphs
from .dataset_helper import DatasetHelper, connect_network_with_dataset
from . import amp
-from ..common.api import _pynative_executor
+from ..common.api import _pynative_executor, _cell_graph_executor
def _transfer_tensor_to_tuple(inputs):
@@ -732,7 +732,7 @@ class Model:
dataset_sink_mode=dataset_sink_mode,
sink_size=sink_size)
-def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
+def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, jit_config=None):
"""
Build computational graphs and data graphs with the sink mode.
@@ -752,6 +752,15 @@
will be initialized, and `metrics` in `Model` can not be None. Default: None.
sink_size (int): Control the amount of data in each sink. Default: -1.
epoch (int): Control the training epochs. Default: 1.
jit_config (dict): Control the jit config. If set to None (the default), the graph
compiles with the default behavior. You can customize the compilation with a
dictionary; for example, set {'jit_level': 'o0'} to control the jit optimization
level. The supported keys are listed below. Default: None.
- jit_level (str): Control the graph compilation optimization level.
Optional values: "o0"/"o1". Default: "o1". If set to "o0", graph compilation
skips the combine-like-graphs phase.
Examples:
>>> from mindspore import Model, nn, FixedLossScaleManager
@@ -767,6 +776,8 @@
>>> model.build(dataset, epoch=2)
>>> model.train(2, dataset)
"""
if jit_config is not None:
_cell_graph_executor.set_jit_config(jit_config)
self._init(train_dataset, valid_dataset, sink_size, epoch)
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
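A minimal end-to-end sketch of the new keyword (MyNet and train_dataset are placeholders):

# build() forwards jit_config to _cell_graph_executor.set_jit_config
# before compiling, so "o0" here skips the combine-like-graphs phase.
from mindspore import Model, nn

net = MyNet()  # placeholder Cell
model = Model(net, loss_fn=nn.MSELoss(), optimizer=nn.Adam(net.trainable_params()))
model.build(train_dataset, epoch=2, jit_config={'jit_level': 'o0'})
model.train(2, train_dataset)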

View File

@@ -25,6 +25,7 @@ from mindspore.nn import Cell
from ..._c_expression import Tensor, security
from ..._checkparam import Validator
from ...common.api import _cell_graph_executor
from .._utils import _check_lineage_value, _check_to_numpy, _make_directory, check_value_type
from ._summary_adapter import get_event_file_name, package_graph_event
from ._writer_pool import WriterPool
@@ -324,9 +325,9 @@ class SummaryRecord:
return False
# Set the current summary of train step
if self.network is not None and not self._status.get('has_graph'):
-graph_proto = self.network.get_func_graph_proto()
+graph_proto = _cell_graph_executor.get_optimize_graph_proto(self.network)
if graph_proto is None and train_network is not None:
-graph_proto = train_network.get_func_graph_proto()
+graph_proto = _cell_graph_executor.get_optimize_graph_proto(train_network)
if graph_proto is None:
logger.error("Failed to get proto for graph")
else: