forked from mindspore-Ecosystem/mindspore
[PARALLEL LOGING]Less the logging cout
This commit is contained in:
parent
8903571af3
commit
8395d75eab
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
|
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
|
||||||
#define PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
|
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -189,4 +189,4 @@ class Edge {
|
||||||
};
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
|
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef PARALLEL_STEP_AUTO_PARALLEL_H_
|
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_AUTO_PARALLEL_H_
|
||||||
#define PARALLEL_STEP_AUTO_PARALLEL_H_
|
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_AUTO_PARALLEL_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -72,4 +72,4 @@ std::vector<std::vector<size_t>> GetSharedTensorsOps(
|
||||||
const std::vector<std::vector<std::string>> &input_tensor_names);
|
const std::vector<std::vector<std::string>> &input_tensor_names);
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // PARALLEL_STEP_AUTO_PARALLEL_H_
|
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_AUTO_PARALLEL_H_
|
||||||
|
|
|
@ -601,3 +601,34 @@ def _setup_logger(kwargs):
|
||||||
finally:
|
finally:
|
||||||
_setup_logger_lock.release()
|
_setup_logger_lock.release()
|
||||||
return _global_logger
|
return _global_logger
|
||||||
|
|
||||||
|
|
||||||
|
class _LogActionOnce:
|
||||||
|
"""
|
||||||
|
A wrapper for modify the warning logging to an empty function. This is used when we want to only log
|
||||||
|
once to avoid the repeated logging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger (logging): The logger object.
|
||||||
|
|
||||||
|
"""
|
||||||
|
__is_logged__ = False
|
||||||
|
|
||||||
|
def __init__(self, logger):
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
def __call__(self, func):
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if not hasattr(self.logger, 'warning'):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
_old_ = self.logger.warning
|
||||||
|
if _LogActionOnce.__is_logged__:
|
||||||
|
self.logger.warning = lambda x: x
|
||||||
|
else:
|
||||||
|
_LogActionOnce.__is_logged__ = True
|
||||||
|
res = func(*args, **kwargs)
|
||||||
|
if hasattr(self.logger, 'warning'):
|
||||||
|
self.logger.warning = _old_
|
||||||
|
return res
|
||||||
|
return wrapper
|
||||||
|
|
|
@ -18,6 +18,7 @@ import functools
|
||||||
import inspect
|
import inspect
|
||||||
import copy
|
import copy
|
||||||
from mindspore.common.api import _wrap_func
|
from mindspore.common.api import _wrap_func
|
||||||
|
from mindspore.log import _LogActionOnce
|
||||||
from mindspore import context, log as logger
|
from mindspore import context, log as logger
|
||||||
from mindspore.parallel._utils import _is_in_auto_parallel_mode
|
from mindspore.parallel._utils import _is_in_auto_parallel_mode
|
||||||
from .._c_expression import Primitive_, real_run_op, prim_type
|
from .._c_expression import Primitive_, real_run_op, prim_type
|
||||||
|
@ -153,6 +154,7 @@ class Primitive(Primitive_):
|
||||||
self.add_prim_attr("stage", stage)
|
self.add_prim_attr("stage", stage)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@_LogActionOnce(logger=logger)
|
||||||
def shard(self, in_strategy=None, out_strategy=None):
|
def shard(self, in_strategy=None, out_strategy=None):
|
||||||
"""
|
"""
|
||||||
Add strategies to primitive attribute.
|
Add strategies to primitive attribute.
|
||||||
|
@ -200,11 +202,17 @@ class Primitive(Primitive_):
|
||||||
|
|
||||||
if not _is_in_auto_parallel_mode():
|
if not _is_in_auto_parallel_mode():
|
||||||
if in_strategy is not None:
|
if in_strategy is not None:
|
||||||
logger.warning(f"The in_strategy: {in_strategy} of {self.name} is not valid in {mode}. "
|
logger.warning(f"The in_strategy of the operator in your network will not take effect in {mode} mode. "
|
||||||
f"Please use semi auto or auto parallel mode.")
|
f"This means the the shard function called in the network is ignored. "
|
||||||
|
f"If you want to enable it, please use semi auto or auto parallel mode by "
|
||||||
|
f"context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL "
|
||||||
|
f"or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL")
|
||||||
if out_strategy is not None:
|
if out_strategy is not None:
|
||||||
logger.warning(f"The out_strategy: {out_strategy} of {self.name} is not valid in {mode}. "
|
logger.warning(f"The out_strategy of the operator in your network will not take effect in {mode} mode."
|
||||||
f"Please use semi auto or auto parallel mode.")
|
f" This means the the shard function called in the network is ignored. "
|
||||||
|
f"If you want to enable it, please use semi auto or auto parallel mode by "
|
||||||
|
f"context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL "
|
||||||
|
f"or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL")
|
||||||
|
|
||||||
self.add_prim_attr("in_strategy", in_strategy)
|
self.add_prim_attr("in_strategy", in_strategy)
|
||||||
self.add_prim_attr("out_strategy", out_strategy)
|
self.add_prim_attr("out_strategy", out_strategy)
|
||||||
|
|
Loading…
Reference in New Issue