add config info for python api
This commit is contained in:
parent
6cb2f6f8c1
commit
a336e2b6b6
|
@ -44,7 +44,7 @@ class RunnerConfig {
|
|||
/// \brief Set the config before runtime. Only valid for ModelParallelRunner.
|
||||
///
|
||||
/// \param[in] config store environment variables before runtime.
|
||||
void SetConfigInfo(const std::string &key, const std::map<std::string, std::string> &config);
|
||||
void SetConfigInfo(const std::string §ion, const std::map<std::string, std::string> &config);
|
||||
|
||||
/// \brief Get the current config setting. Only valid for ModelParallelRunner.
|
||||
///
|
||||
|
|
|
@ -354,6 +354,7 @@ class RunnerConfig:
|
|||
Args:
|
||||
context (Context): Define the context used to store options during execution.
|
||||
workers_num (int): the num of workers.
|
||||
config_info (dict): {key:{key:value}}, Nested map for passing model weight paths.
|
||||
|
||||
Raises:
|
||||
TypeError: type of input parameters are invalid.
|
||||
|
@ -363,12 +364,15 @@ class RunnerConfig:
|
|||
>>> import mindspore_lite as mslite
|
||||
>>> context = mslite.Context()
|
||||
>>> context.append_device_info(mslite.CPUDeviceInfo())
|
||||
>>> runner_config = mslite.RunnerConfig(context=context, workers_num=4)
|
||||
>>> config_info = {"weight": {"weight_path": "path of model weight"}}
|
||||
>>> runner_config = mslite.RunnerConfig(context=context, workers_num=0, config_info=config_info)
|
||||
>>> print(runner_config)
|
||||
workers num: 4, context: 0, .
|
||||
workers num: 4,
|
||||
context: 0,
|
||||
config info: weight: weight_path: path of model weight
|
||||
"""
|
||||
|
||||
def __init__(self, context=None, workers_num=None):
|
||||
def __init__(self, context=None, workers_num=None, config_info=None):
|
||||
if context is not None:
|
||||
check_isinstance("context", context, Context)
|
||||
if workers_num is not None:
|
||||
|
@ -380,9 +384,13 @@ class RunnerConfig:
|
|||
self._runner_config.set_context(context._context)
|
||||
if workers_num is not None:
|
||||
self._runner_config.set_workers_num(workers_num)
|
||||
if config_info is not None:
|
||||
for k, v in config_info.items():
|
||||
self._runner_config.set_config_info(k, v)
|
||||
|
||||
def __str__(self):
|
||||
res = f"workers num: {self._runner_config.get_workers_num()}, " \
|
||||
res = f"workers num: {self._runner_config.get_workers_num()}, \n" \
|
||||
f"config info: {self._runner_config.get_config_info_string()}, \n" \
|
||||
f"context: {self._runner_config.get_context_info()}."
|
||||
return res
|
||||
|
||||
|
|
|
@ -81,16 +81,31 @@ void ModelPyBind(const py::module &m) {
|
|||
#ifdef PARALLEL_INFERENCE
|
||||
py::class_<RunnerConfig, std::shared_ptr<RunnerConfig>>(m, "RunnerConfigBind")
|
||||
.def(py::init<>())
|
||||
.def("set_config_info", &RunnerConfig::SetConfigInfo)
|
||||
.def("get_config_info", &RunnerConfig::GetConfigInfo)
|
||||
.def("set_workers_num", &RunnerConfig::SetWorkersNum)
|
||||
.def("get_workers_num", &RunnerConfig::GetWorkersNum)
|
||||
.def("set_context", &RunnerConfig::SetContext)
|
||||
.def("get_context", &RunnerConfig::GetContext)
|
||||
.def("get_context_info", [](RunnerConfig &runner_config) {
|
||||
std::string result = "thread num: ";
|
||||
const auto &context = runner_config.GetContext();
|
||||
result += std::to_string(context->GetThreadNum());
|
||||
result += ", bind mode: ";
|
||||
result += std::to_string(context->GetThreadAffinityMode());
|
||||
.def("get_context_info",
|
||||
[](RunnerConfig &runner_config) {
|
||||
std::string result = "thread num: ";
|
||||
const auto &context = runner_config.GetContext();
|
||||
result += std::to_string(context->GetThreadNum());
|
||||
result += ", bind mode: ";
|
||||
result += std::to_string(context->GetThreadAffinityMode());
|
||||
return result;
|
||||
})
|
||||
.def("get_config_info_string", [](RunnerConfig &runner_config) {
|
||||
std::string result = "";
|
||||
const auto &config_info = runner_config.GetConfigInfo();
|
||||
for (auto §ion : config_info) {
|
||||
result += section.first + ": ";
|
||||
for (auto &config : section.second) {
|
||||
auto temp = config.first + " " + config.second + "\n";
|
||||
result += temp;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
|
|
Loading…
Reference in New Issue