Fix repeated logging

Fix allreduce execution order in auto parallel mode
huangxinjing 2022-04-11 15:53:26 +08:00
parent 761fbe5155
commit eaaf11f8e0
5 changed files with 13 additions and 8 deletions


@@ -276,8 +276,10 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
   SetLoopSink();
   GetMaxStreamTaskNum();
   ReorderIndependentOrders(graph_ptr);
-  TrailingTimeOptimizationByReorder(graph_ptr);
+  auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
+  if (parallel_mode != parallel::kSemiAutoParallel && parallel_mode != parallel::kAutoParallel) {
+    TrailingTimeOptimizationByReorder(graph_ptr);
+  }
   AssignAllNodesStream(graph_ptr);
   UpdateAtomicAddrCleanStreamId(graph_ptr);
   InsertStreamActive(graph_ptr);
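The guard above skips TrailingTimeOptimizationByReorder when the graph was partitioned under semi-auto or auto parallel, since reordering kernels there could change the relative execution order of communication operators such as AllReduce (the ordering problem named in the commit title). For reference, a minimal Python-side sketch of the same mode check, using the public auto-parallel context API; the gating shown here is only an illustration, as the real check happens in the C++ stream assigner:

import mindspore.context as context

# Mirror of the C++ guard: only apply the trailing-time reorder when the
# graph was not built under (semi-)auto parallel partitioning.
parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode not in ("semi_auto_parallel", "auto_parallel"):
    run_trailing_time_reorder()  # hypothetical stand-in for the C++ pass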


@@ -18,6 +18,7 @@ from copy import copy
 import numbers
 import numpy as np
 from mindspore import log as logger
+from mindspore.log import _LogActionOnce
 from .._c_expression import ParamInfo
 from . import dtype as mstype
 from .. import context

@@ -514,6 +515,7 @@ class Parameter(Tensor_):
         new_param.param_info = self.param_info
         return new_param
 
+    @_LogActionOnce(logger=logger, key='add_pipeline_stage')
     def add_pipeline_stage(self, stage):
         logger.warning(f"This interface may be deleted in the future.")
         if not isinstance(stage, int) or stage < 0:
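With the decorator applied, the deprecation warning inside add_pipeline_stage is emitted only once per process under the key 'add_pipeline_stage', no matter how many parameters are staged. A usage sketch (the parameter shapes and names are arbitrary; only add_pipeline_stage itself comes from this diff):

import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor

w = Parameter(Tensor(np.ones((2, 2)), ms.float32), name="w")
b = Parameter(Tensor(np.zeros((2,)), ms.float32), name="b")
w.add_pipeline_stage(0)  # logs "This interface may be deleted..." once
b.add_pipeline_stage(0)  # same key: the repeated warning is suppressed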


@@ -612,10 +612,11 @@ class _LogActionOnce:
         logger (logging): The logger object.
     """
 
-    __is_logged__ = False
+    __is_logged__ = dict()
 
-    def __init__(self, logger):
+    def __init__(self, logger, key):
         self.logger = logger
+        self.key = key
 
     def __call__(self, func):
         def wrapper(*args, **kwargs):

@@ -623,10 +624,10 @@
                 return func(*args, **kwargs)
 
             _old_ = self.logger.warning
-            if _LogActionOnce.__is_logged__:
+            if self.key in _LogActionOnce.__is_logged__:
                 self.logger.warning = lambda x: x
             else:
-                _LogActionOnce.__is_logged__ = True
+                _LogActionOnce.__is_logged__[self.key] = True
             res = func(*args, **kwargs)
             if hasattr(self.logger, 'warning'):
                 self.logger.warning = _old_
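This is the core of the repeated-logging fix: the class-level flag becomes a dict keyed per call site, so the first decorated function to log no longer silences every other decorated function for the rest of the process. A self-contained sketch of the same pattern, simplified to the standard logging module (not a verbatim copy of _LogActionOnce):

import logging

logging.basicConfig()

class LogOnce:
    """Let warnings from a decorated function through only the first time per key."""
    _logged = {}  # shared across all instances, one entry per key

    def __init__(self, logger, key):
        self.logger = logger
        self.key = key

    def __call__(self, func):
        def wrapper(*args, **kwargs):
            old_warning = self.logger.warning
            if self.key in LogOnce._logged:
                self.logger.warning = lambda *a, **k: None  # swallow repeats
            else:
                LogOnce._logged[self.key] = True
            try:
                return func(*args, **kwargs)
            finally:
                self.logger.warning = old_warning  # always restore the real method
        return wrapper

logger = logging.getLogger("demo")

@LogOnce(logger, key="demo")
def noisy():
    logger.warning("deprecated!")

noisy()  # warning printed
noisy()  # warning suppressed: the key is already recorded

Keying on a string rather than on the function object lets several related functions share one suppression slot, which is exactly how the shared 'Primitive' key is used in the last hunk below.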


@@ -65,7 +65,7 @@ class CrossEntropyLoss(Cell):
     def __init__(self, parallel_config=default_dpmp_config):
         super(CrossEntropyLoss, self).__init__()
-        if not isinstance(parallel_config, OpParallelConfig) and not isinstance(parallel_config):
+        if not isinstance(parallel_config, OpParallelConfig):
             raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig"
                             ", but got the type: {}.".format(type(parallel_config)))
         dp = parallel_config.data_parallel
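The deleted clause was itself broken: isinstance() requires two arguments, so whenever parallel_config failed the first check, evaluating `not isinstance(parallel_config)` raised a bare TypeError from isinstance itself instead of reaching the descriptive error message below it. A one-line reproduction of the removed bug:

isinstance(object())  # TypeError: isinstance expected 2 arguments, got 1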


@@ -155,7 +155,7 @@ class Primitive(Primitive_):
         self.add_prim_attr("stage", stage)
         return self
 
-    @_LogActionOnce(logger=logger)
+    @_LogActionOnce(logger=logger, key='Primitive')
     def shard(self, in_strategy=None, out_strategy=None):
         """
         Add strategies to primitive attribute.
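Because the key here is the shared string 'Primitive' rather than something per-operator, any warning emitted inside shard appears once per process, not once per primitive instance. An illustrative sketch (operator choice and strategies are arbitrary examples; shard itself is the API shown above):

import mindspore.ops as ops

matmul = ops.MatMul().shard(((2, 1), (1, 2)))  # a warning inside shard logs once
add = ops.Add().shard(((2, 2), (2, 2)))        # same key 'Primitive': repeats suppressed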