[PARALLEL LOGING]Less the logging cout

This commit is contained in:
huangxinjing 2022-01-17 10:05:48 +08:00
parent 8903571af3
commit 8395d75eab
4 changed files with 49 additions and 10 deletions

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_ #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
#define PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_ #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
#include <map> #include <map>
#include <memory> #include <memory>
@ -189,4 +189,4 @@ class Edge {
}; };
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_ #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef PARALLEL_STEP_AUTO_PARALLEL_H_ #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_AUTO_PARALLEL_H_
#define PARALLEL_STEP_AUTO_PARALLEL_H_ #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_AUTO_PARALLEL_H_
#include <map> #include <map>
#include <memory> #include <memory>
@ -72,4 +72,4 @@ std::vector<std::vector<size_t>> GetSharedTensorsOps(
const std::vector<std::vector<std::string>> &input_tensor_names); const std::vector<std::vector<std::string>> &input_tensor_names);
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
#endif // PARALLEL_STEP_AUTO_PARALLEL_H_ #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_AUTO_PARALLEL_H_

View File

@ -601,3 +601,34 @@ def _setup_logger(kwargs):
finally: finally:
_setup_logger_lock.release() _setup_logger_lock.release()
return _global_logger return _global_logger
class _LogActionOnce:
"""
A wrapper for modify the warning logging to an empty function. This is used when we want to only log
once to avoid the repeated logging.
Args:
logger (logging): The logger object.
"""
__is_logged__ = False
def __init__(self, logger):
self.logger = logger
def __call__(self, func):
def wrapper(*args, **kwargs):
if not hasattr(self.logger, 'warning'):
return func(*args, **kwargs)
_old_ = self.logger.warning
if _LogActionOnce.__is_logged__:
self.logger.warning = lambda x: x
else:
_LogActionOnce.__is_logged__ = True
res = func(*args, **kwargs)
if hasattr(self.logger, 'warning'):
self.logger.warning = _old_
return res
return wrapper

View File

@ -18,6 +18,7 @@ import functools
import inspect import inspect
import copy import copy
from mindspore.common.api import _wrap_func from mindspore.common.api import _wrap_func
from mindspore.log import _LogActionOnce
from mindspore import context, log as logger from mindspore import context, log as logger
from mindspore.parallel._utils import _is_in_auto_parallel_mode from mindspore.parallel._utils import _is_in_auto_parallel_mode
from .._c_expression import Primitive_, real_run_op, prim_type from .._c_expression import Primitive_, real_run_op, prim_type
@ -153,6 +154,7 @@ class Primitive(Primitive_):
self.add_prim_attr("stage", stage) self.add_prim_attr("stage", stage)
return self return self
@_LogActionOnce(logger=logger)
def shard(self, in_strategy=None, out_strategy=None): def shard(self, in_strategy=None, out_strategy=None):
""" """
Add strategies to primitive attribute. Add strategies to primitive attribute.
@ -200,11 +202,17 @@ class Primitive(Primitive_):
if not _is_in_auto_parallel_mode(): if not _is_in_auto_parallel_mode():
if in_strategy is not None: if in_strategy is not None:
logger.warning(f"The in_strategy: {in_strategy} of {self.name} is not valid in {mode}. " logger.warning(f"The in_strategy of the operator in your network will not take effect in {mode} mode. "
f"Please use semi auto or auto parallel mode.") f"This means the the shard function called in the network is ignored. "
f"If you want to enable it, please use semi auto or auto parallel mode by "
f"context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL "
f"or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL")
if out_strategy is not None: if out_strategy is not None:
logger.warning(f"The out_strategy: {out_strategy} of {self.name} is not valid in {mode}. " logger.warning(f"The out_strategy of the operator in your network will not take effect in {mode} mode."
f"Please use semi auto or auto parallel mode.") f" This means the the shard function called in the network is ignored. "
f"If you want to enable it, please use semi auto or auto parallel mode by "
f"context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL "
f"or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL")
self.add_prim_attr("in_strategy", in_strategy) self.add_prim_attr("in_strategy", in_strategy)
self.add_prim_attr("out_strategy", out_strategy) self.add_prim_attr("out_strategy", out_strategy)