add config info for python api

This commit is contained in:
yefeng 2022-06-22 09:23:59 +08:00
parent 6cb2f6f8c1
commit a336e2b6b6
3 changed files with 34 additions and 11 deletions

View File

@ -44,7 +44,7 @@ class RunnerConfig {
/// \brief Set the config before runtime. Only valid for ModelParallelRunner.
///
/// \param[in] config store environment variables before runtime.
void SetConfigInfo(const std::string &key, const std::map<std::string, std::string> &config);
void SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config);
/// \brief Get the current config setting. Only valid for ModelParallelRunner.
///

View File

@ -354,6 +354,7 @@ class RunnerConfig:
Args:
context (Context): Define the context used to store options during execution.
workers_num (int): the num of workers.
config_info (dict): {key:{key:value}}, Nested map for passing model weight paths.
Raises:
TypeError: type of input parameters are invalid.
@ -363,12 +364,15 @@ class RunnerConfig:
>>> import mindspore_lite as mslite
>>> context = mslite.Context()
>>> context.append_device_info(mslite.CPUDeviceInfo())
>>> runner_config = mslite.RunnerConfig(context=context, workers_num=4)
>>> config_info = {"weight": {"weight_path": "path of model weight"}}
>>> runner_config = mslite.RunnerConfig(context=context, workers_num=0, config_info=config_info)
>>> print(runner_config)
workers num: 4, context: 0, .
workers num: 4,
context: 0,
config info: weight: weight_path: path of model weight
"""
def __init__(self, context=None, workers_num=None):
def __init__(self, context=None, workers_num=None, config_info=None):
if context is not None:
check_isinstance("context", context, Context)
if workers_num is not None:
@ -380,9 +384,13 @@ class RunnerConfig:
self._runner_config.set_context(context._context)
if workers_num is not None:
self._runner_config.set_workers_num(workers_num)
if config_info is not None:
for k, v in config_info.items():
self._runner_config.set_config_info(k, v)
def __str__(self):
res = f"workers num: {self._runner_config.get_workers_num()}, " \
res = f"workers num: {self._runner_config.get_workers_num()}, \n" \
f"config info: {self._runner_config.get_config_info_string()}, \n" \
f"context: {self._runner_config.get_context_info()}."
return res

View File

@ -81,16 +81,31 @@ void ModelPyBind(const py::module &m) {
#ifdef PARALLEL_INFERENCE
py::class_<RunnerConfig, std::shared_ptr<RunnerConfig>>(m, "RunnerConfigBind")
.def(py::init<>())
.def("set_config_info", &RunnerConfig::SetConfigInfo)
.def("get_config_info", &RunnerConfig::GetConfigInfo)
.def("set_workers_num", &RunnerConfig::SetWorkersNum)
.def("get_workers_num", &RunnerConfig::GetWorkersNum)
.def("set_context", &RunnerConfig::SetContext)
.def("get_context", &RunnerConfig::GetContext)
.def("get_context_info", [](RunnerConfig &runner_config) {
std::string result = "thread num: ";
const auto &context = runner_config.GetContext();
result += std::to_string(context->GetThreadNum());
result += ", bind mode: ";
result += std::to_string(context->GetThreadAffinityMode());
.def("get_context_info",
[](RunnerConfig &runner_config) {
std::string result = "thread num: ";
const auto &context = runner_config.GetContext();
result += std::to_string(context->GetThreadNum());
result += ", bind mode: ";
result += std::to_string(context->GetThreadAffinityMode());
return result;
})
.def("get_config_info_string", [](RunnerConfig &runner_config) {
std::string result = "";
const auto &config_info = runner_config.GetConfigInfo();
for (auto &section : config_info) {
result += section.first + ": ";
for (auto &config : section.second) {
auto temp = config.first + " " + config.second + "\n";
result += temp;
}
}
return result;
});