forked from mindspore-Ecosystem/mindspore

add jit config param for compile and fix summary graph

parent d1e4e674ab
commit c9a0edd24e
@@ -1117,7 +1117,7 @@ static std::vector<ActionItem> CommonPipeline() {
   (void)actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
 
   auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
-  if (!multi_graphs) {
+  if (!multi_graphs && pipeline::GraphExecutorPy::GetInstance()->jit_config()["jit_level"] != "o0") {
     (void)actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
   }
 
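With this change, a jit_level of "o0" skips the combine_like_graphs action for single-subgraph compilations. A minimal Python sketch of the gate (illustrative only; the real check is the C++ line above):

    # Sketch: mirrors the C++ condition. An unset jit_level compares unequal
    # to "o0", so combine_like_graphs still runs by default.
    def run_combine_like_graphs(is_multi_subgraphs, jit_config):
        return not is_multi_subgraphs and jit_config.get("jit_level", "") != "o0"
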
@@ -112,7 +112,10 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_compile_cache_dep_files", &GraphExecutorPy::set_compile_cache_dep_files,
          py::arg("compile_cache_dep_files") = py::list(), "Set the compilation cache dependent files.")
     .def("set_weights_values", &GraphExecutorPy::set_weights_values, py::arg("weights") = py::dict(),
-         "Set values of weights.");
+         "Set values of weights.")
+    .def("get_optimize_graph_proto", &GraphExecutorPy::GetOptimizeGraphProto, py::arg("phase") = py::str(""),
+         "Get the optimize graph proto string.")
+    .def("set_jit_config", &GraphExecutorPy::SetJitConfig, py::arg("jit_config") = py::dict(), "Set the jit config.");
 
   (void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_").def(py::init());
 
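The two new bindings sit next to the existing GraphExecutorPy methods. A hedged sketch of calling them from Python (assuming GraphExecutor_ is the name under which this class is exposed and that it offers a get_instance static method, as other _c_expression singletons do):

    from mindspore._c_expression import GraphExecutor_

    executor = GraphExecutor_.get_instance()
    executor.set_jit_config({"jit_level": "o0"})             # new binding
    proto = executor.get_optimize_graph_proto("some_phase")  # new binding; "" when absent
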
@@ -292,6 +292,16 @@ std::map<string, ValuePtr> GenerateWeightsValueMap(const py::dict &weights) {
   return ret;
 }
 
+std::map<string, string> GenerateJitConfigMap(const py::dict &jit_config) {
+  std::map<string, string> ret{};
+  for (auto jit_param = jit_config.begin(); jit_param != jit_config.end(); ++jit_param) {
+    auto param_name = py::cast<std::string>(jit_param->first);
+    auto param_value = py::cast<std::string>(jit_param->second);
+    ret[param_name] = param_value;
+  }
+  return ret;
+}
+
 FuncGraphPtr LoadFuncGraphFromMindIR(size_t idx, const py::dict &weights) {
   std::string compile_cache_path = GetCompileCachePath(idx);
   auto realpath = Common::CreatePrefixPath(compile_cache_path, true);
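GenerateJitConfigMap copies the Python dict into a std::map<string, string>; py::cast<std::string> throws for non-string keys or values rather than coercing them. Roughly the Python equivalent (a sketch, not part of the commit):

    # Sketch: the C++ helper yields {str: str}; non-string entries fail the
    # py::cast instead of being converted the way str() would convert them.
    def generate_jit_config_map(jit_config):
        return {str(k): str(v) for k, v in jit_config.items()}
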
@@ -614,6 +624,24 @@ py::bytes GraphExecutorPy::GetFuncGraphProto(const std::string &phase, const std
   MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type;
 }
 
+py::bytes GraphExecutorPy::GetOptimizeGraphProto(const std::string &phase) {
+  if (info_.count(phase) == 0) {
+    MS_LOG(EXCEPTION) << "No phase in executor: " << phase;
+  }
+  FuncGraphPtr fg_ptr = info_[phase]->resource->optimize_graph();
+  if (fg_ptr == nullptr) {
+    MS_LOG(WARNING) << "Can not find optimize graph.";
+    return "";
+  }
+  std::string proto_str = GetFuncGraphProtoString(fg_ptr);
+  if (proto_str.empty()) {
+    MS_LOG(EXCEPTION) << "Export optimize graph proto string failed.";
+  }
+  return proto_str;
+}
+
+void GraphExecutorPy::SetJitConfig(const py::dict &jit_config) { jit_config_ = GenerateJitConfigMap(jit_config); }
+
 py::dict GraphExecutorPy::GetParameterLayout(const std::string &phase) {
   MS_LOG(DEBUG) << "GetParameterLayout!";
   std::string layout_graph = phase + kStepParallelGraph;
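GetOptimizeGraphProto returns an empty string when no optimize graph was stored for the phase, and throws when the phase is unknown or serialization fails; SetJitConfig simply snapshots the dict into jit_config_. A sketch of the empty-result contract from the Python side (using the executor handle assumed in the binding sketch above):

    proto = executor.get_optimize_graph_proto(phase)
    if not proto:      # "" means no optimize graph was stored for this phase
        proto = None   # callers in api.py then fall back to get_func_graph_proto()
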
@@ -1101,6 +1129,41 @@ void RDRRecordGraph(const size_t action_index, const size_t action_size, const s
 }
 #endif
 
+#ifdef ENABLE_DUMP_IR
+void RecordIR(const size_t action_index, const size_t action_size, const std::string &action_name,
+              const FuncGraphPtr graph, const std::string &phase, FuncGraphPtr *user_graph) {
+  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && graph != nullptr) {
+    *user_graph = graph;
+    std::string base_name = GetBaseNameForIR(SizeToLong(action_index), action_name);
+
+    // generate IR file in dot format, which can be converted to svg file using graphviz dot command
+    draw::Draw(base_name + ".dot", graph);
+    // generate IR file in human readable format
+    if (action_index == action_size - 1) {
+      DumpIR(base_name + ".ir", graph, false, kWholeStack);
+    } else {
+      DumpIR(base_name + ".ir", graph, false, kTopStack);
+    }
+    // generate IR file in a heavily commented format, which can also be reloaded
+    ExportIR(base_name + ".dat", graph);
+  }
+}
+#endif
+
+#ifndef ENABLE_SECURITY
+void SaveGraphForReadability(const std::string &action_name, const FuncGraphPtr graph, const std::string &phase,
+                             const ResourcePtr resource) {
+  if (graph != nullptr && action_name.find("optimize") != string::npos) {
+#ifdef ENABLE_DUMP_IR
+    if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
+      DumpIRProto(graph, action_name);
+    }
+#endif
+    resource->set_optimize_graph(graph);
+  }
+}
+#endif
+
 void Pipeline::Run(const std::string &phase) {
   MS_LOG(INFO) << "Pipeline run";
   MS_EXCEPTION_IF_NULL(resource_);
@@ -1139,22 +1202,10 @@ void Pipeline::Run(const std::string &phase) {
 #ifdef ENABLE_DUMP_IR
       std::string filename = GetBaseNameForIR(SizeToLong(i), action.first);
       RDRRecordGraph(i, actions_.size(), filename, graph);
-
-      if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && graph != nullptr) {
-        user_graph = graph;
-        std::string base_name = GetBaseNameForIR(SizeToLong(i), action.first);
-
-        // generate IR file in dot format, which can be converted to svg file using graphviz dot command
-        draw::Draw(base_name + ".dot", graph);
-        // generate IR file in human readable format
-        if (i == actions_.size() - 1) {
-          DumpIR(base_name + ".ir", graph, false, kWholeStack);
-        } else {
-          DumpIR(base_name + ".ir", graph, false, kTopStack);
-        }
-        // generate IR file in a heavily commented format, which can also be reloaded
-        ExportIR(base_name + ".dat", graph);
-      }
+      RecordIR(i, actions_.size(), action.first, graph, phase, &user_graph);
 #endif
+#ifndef ENABLE_SECURITY
+      SaveGraphForReadability(action.first, graph, phase, resource_);
+#endif
       i++;
 #ifdef ENABLE_TIMELINE
@@ -72,6 +72,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
   ~GraphExecutorPy();
 
   const std::string &phase() const { return phase_; }
+  std::map<std::string, std::string> &jit_config() { return jit_config_; }
   void SaveCompiledGraph(const std::string &phase);
   bool CompileInner(const py::object &source_obj, const py::tuple &args, const py::object &phase_obj, bool use_vm);
   bool Compile(const py::object &source_obj, const py::tuple &args, const py::object &phase_obj, bool use_vm);
@@ -85,6 +86,10 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
   FuncGraphPtr GetGradGraph(const std::string &phase);
   void SetGradGraph(const FuncGraphPtr &grad_graph, const std::string &phase);
   py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type);
+#ifndef ENABLE_SECURITY
+  py::bytes GetOptimizeGraphProto(const std::string &phase);
+#endif
+  void SetJitConfig(const py::dict &jit_config);
   compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase);
   bool HasCompiled(const std::string &phase) const;
@@ -142,6 +147,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
 #endif
   std::map<std::string, py::dict> stra_dict_;
   std::string phase_ = "";
+  std::map<std::string, std::string> jit_config_;
   std::map<std::string, size_t> phase_to_num_op_info_;
   std::string queue_name_;
   bool enable_tuple_broaden_{false};
@@ -72,6 +72,9 @@ class Resource : public ResourceBase {
   FuncGraphPtr func_graph() const { return func_graph_; }
   void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
 
+  FuncGraphPtr optimize_graph() const { return optimize_graph_; }
+  void set_optimize_graph(const FuncGraphPtr &optimize_graph) { optimize_graph_ = optimize_graph; }
+
   const abstract::AbstractBasePtrList &args_spec() const { return args_spec_; }
   void set_args_spec(const abstract::AbstractBasePtrList &args_spec) { args_spec_ = args_spec; }
@@ -103,6 +106,7 @@ class Resource : public ResourceBase {
  private:
   abstract::AnalysisEnginePtr engine_;
   FuncGraphPtr func_graph_;
+  FuncGraphPtr optimize_graph_;
   abstract::AbstractBasePtrList args_spec_;
   // The source obj to compile, usually a `Cell` or `ms_function` decorated function.
   py::object source_input_;
@@ -35,6 +35,7 @@ from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dat
 from ..parallel._ps_context import _is_role_pserver
 from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
     _get_parameter_broadcast, _get_pipeline_stages
+from .._checkparam import Validator
 
 # store ms_function class compiled pipeline cache
 ms_compile_cache = {}
@@ -547,6 +548,11 @@ class _CellGraphExecutor:
         Graph, return the result of pipeline running.
     """
 
+    VALID_JIT_CONFIG_PARAM = ["jit_level"]
+    VALID_JIT_CONFIG_PARAM_VALUE = {
+        "jit_level": ["o0", "o1"]
+    }
+
     def __init__(self):
         # create needed graph by lazy mode
        self.is_init = False
@@ -774,6 +780,17 @@ class _CellGraphExecutor:
             return None
         return self._graph_executor.get_func_graph_proto(exec_id, ir_type)
 
+    def get_optimize_graph_proto(self, obj):
+        """Return optimize graph binary proto."""
+        exec_id = obj.phase + "." + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+        if self._graph_executor.has_compiled(exec_id) is False:
+            return None
+        graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id)
+        if isinstance(graph_proto, str) and graph_proto == "":
+            logger.warning("Can not get optimize graph proto. Instead, try to find function graph.")
+            graph_proto = obj.get_func_graph_proto()
+        return graph_proto
+
     def export(self, file_name, graph_id):
         """
         Export graph.
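The wrapper builds the same dotted exec_id used elsewhere in this class and degrades gracefully: None when the phase was never compiled, and the plain func-graph proto when the optimize graph is missing. A usage sketch (assuming a Cell `net` that has already been compiled):

    proto = _cell_graph_executor.get_optimize_graph_proto(net)
    if proto is None:
        print("net has not been compiled yet")  # has_compiled(exec_id) was False
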
@@ -791,6 +808,20 @@ class _CellGraphExecutor:
             return None
         return self._graph_executor.fetch_info_for_quant_export(exec_id)
 
+    def set_jit_config(self, jit_config):
+        """Set jit config."""
+        self._check_jit_config(jit_config)
+        self._graph_executor.set_jit_config(jit_config)
+
+    def _check_jit_config(self, jit_config):
+        """Check the value of jit config."""
+        if not isinstance(jit_config, dict):
+            raise ValueError("The jit_config should be a dict.")
+        for param_name, param_value in jit_config.items():
+            Validator.check_string(param_name, self.VALID_JIT_CONFIG_PARAM, "jit_config")
+            Validator.check_string(param_value, self.VALID_JIT_CONFIG_PARAM_VALUE.get(param_name), param_name,
+                                   "jit_config")
+
 
 _cell_graph_executor = _CellGraphExecutor()
 _pynative_executor = _PynativeExecutor()
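Given the whitelist defined on the class, validation behaves roughly as below (a sketch; the ValueError comes from the isinstance check or from Validator.check_string):

    _cell_graph_executor.set_jit_config({"jit_level": "o1"})  # accepted
    _cell_graph_executor.set_jit_config("o1")                 # ValueError: not a dict
    _cell_graph_executor.set_jit_config({"jit_level": "o2"})  # ValueError: "o2" not in ["o0", "o1"]
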
@@ -38,6 +38,7 @@ from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.nn.loss.loss import LossBase
 from mindspore.train._utils import check_value_type, _make_directory
 from ..._c_expression import security
+from ...common.api import _cell_graph_executor
 
 HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
 HYPER_CONFIG_LEN_LIMIT = 100000
@@ -685,12 +686,11 @@ class SummaryCollector(Callback):
             return
 
         network = cb_params.train_network if cb_params.mode == ModeEnum.TRAIN.value else cb_params.eval_network
-        graph_proto = network.get_func_graph_proto()
+        graph_proto = _cell_graph_executor.get_optimize_graph_proto(network)
         if graph_proto is None:
             logger.warning("Can not get graph proto, it may not be 'GRAPH_MODE' in context currently, "
                            "so SummaryCollector will not collect graph.")
             return
 
         self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto)
 
     def _collect_metric(self, cb_params):
@@ -37,7 +37,7 @@ from ..context import ParallelMode
 from ..parallel._cost_model_context import _set_multi_subgraphs
 from .dataset_helper import DatasetHelper, connect_network_with_dataset
 from . import amp
-from ..common.api import _pynative_executor
+from ..common.api import _pynative_executor, _cell_graph_executor
 
 
 def _transfer_tensor_to_tuple(inputs):
@@ -732,7 +732,7 @@ class Model:
                           dataset_sink_mode=dataset_sink_mode,
                           sink_size=sink_size)
 
-    def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
+    def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, jit_config=None):
         """
         Build computational graphs and data graphs with the sink mode.
 
@@ -752,6 +752,15 @@ class Model:
                 will be initialized, and `metrics` in `Model` can not be None. Default: None.
             sink_size (int): Control the amount of data in each sink. Default: -1.
             epoch (int): Control the training epochs. Default: 1.
+            jit_config (dict): Control the jit config. If set to None, the graph is
+                compiled with the default behavior. You can customize the compile
+                config with a dictionary, e.g. set {'jit_level': 'o0'} to control
+                the jit level. The supported config items are shown below.
+                Default: None.
+
+                - jit_level (str): Control the graph compile optimize level.
+                  Optional: o0/o1. Default: o1. If set to o0, graph compilation
+                  skips the combine-like-graphs phase.
 
         Examples:
             >>> from mindspore import Model, nn, FixedLossScaleManager
@@ -767,6 +776,8 @@ class Model:
             >>> model.build(dataset, epoch=2)
             >>> model.train(2, dataset)
         """
+        if jit_config is not None:
+            _cell_graph_executor.set_jit_config(jit_config)
         self._init(train_dataset, valid_dataset, sink_size, epoch)
 
     def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
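End to end, the new parameter is passed to the executor before compilation. A usage sketch in the docstring's style (assuming `net`, `loss`, `opt`, and `dataset` as in the docstring example):

    >>> model = Model(net, loss_fn=loss, optimizer=opt)
    >>> model.build(dataset, epoch=2, jit_config={'jit_level': 'o0'})
    >>> model.train(2, dataset)
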
@@ -25,6 +25,7 @@ from mindspore.nn import Cell
 
 from ..._c_expression import Tensor, security
 from ..._checkparam import Validator
+from ...common.api import _cell_graph_executor
 from .._utils import _check_lineage_value, _check_to_numpy, _make_directory, check_value_type
 from ._summary_adapter import get_event_file_name, package_graph_event
 from ._writer_pool import WriterPool
@@ -324,9 +325,9 @@ class SummaryRecord:
             return False
         # Set the current summary of train step
         if self.network is not None and not self._status.get('has_graph'):
-            graph_proto = self.network.get_func_graph_proto()
+            graph_proto = _cell_graph_executor.get_optimize_graph_proto(self.network)
             if graph_proto is None and train_network is not None:
-                graph_proto = train_network.get_func_graph_proto()
+                graph_proto = _cell_graph_executor.get_optimize_graph_proto(train_network)
             if graph_proto is None:
                 logger.error("Failed to get proto for graph")
             else:
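Both summary paths now prefer the optimize-graph proto and fall back to the original func-graph proto when it is unavailable, so recorded graphs stay consistent with what actually ran (including 'o0' compilations that skip combine_like_graphs). A hedged sketch of the consumer side, assuming the usual SummaryRecord usage with a compiled `net`:

    # record() resolves the graph proto through the fallback chain above
    with SummaryRecord('./summary_dir', network=net) as summary_record:
        summary_record.record(step=1)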