Fix repeate logging
Fix allreduce execute order in auto parallel mode
This commit is contained in:
parent
761fbe5155
commit
eaaf11f8e0
|
@ -276,8 +276,10 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
|
|||
SetLoopSink();
|
||||
GetMaxStreamTaskNum();
|
||||
ReorderIndependentOrders(graph_ptr);
|
||||
TrailingTimeOptimizationByReorder(graph_ptr);
|
||||
|
||||
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
|
||||
if (parallel_mode != parallel::kSemiAutoParallel && parallel_mode != parallel::kAutoParallel) {
|
||||
TrailingTimeOptimizationByReorder(graph_ptr);
|
||||
}
|
||||
AssignAllNodesStream(graph_ptr);
|
||||
UpdateAtomicAddrCleanStreamId(graph_ptr);
|
||||
InsertStreamActive(graph_ptr);
|
||||
|
|
|
@ -18,6 +18,7 @@ from copy import copy
|
|||
import numbers
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
from mindspore.log import _LogActionOnce
|
||||
from .._c_expression import ParamInfo
|
||||
from . import dtype as mstype
|
||||
from .. import context
|
||||
|
@ -514,6 +515,7 @@ class Parameter(Tensor_):
|
|||
new_param.param_info = self.param_info
|
||||
return new_param
|
||||
|
||||
@_LogActionOnce(logger=logger, key='add_pipeline_stage')
|
||||
def add_pipeline_stage(self, stage):
|
||||
logger.warning(f"This interface may be deleted in the future.")
|
||||
if not isinstance(stage, int) or stage < 0:
|
||||
|
|
|
@ -612,10 +612,11 @@ class _LogActionOnce:
|
|||
logger (logging): The logger object.
|
||||
|
||||
"""
|
||||
__is_logged__ = False
|
||||
__is_logged__ = dict()
|
||||
|
||||
def __init__(self, logger):
|
||||
def __init__(self, logger, key):
|
||||
self.logger = logger
|
||||
self.key = key
|
||||
|
||||
def __call__(self, func):
|
||||
def wrapper(*args, **kwargs):
|
||||
|
@ -623,10 +624,10 @@ class _LogActionOnce:
|
|||
return func(*args, **kwargs)
|
||||
|
||||
_old_ = self.logger.warning
|
||||
if _LogActionOnce.__is_logged__:
|
||||
if self.key in _LogActionOnce.__is_logged__:
|
||||
self.logger.warning = lambda x: x
|
||||
else:
|
||||
_LogActionOnce.__is_logged__ = True
|
||||
_LogActionOnce.__is_logged__[self.key] = True
|
||||
res = func(*args, **kwargs)
|
||||
if hasattr(self.logger, 'warning'):
|
||||
self.logger.warning = _old_
|
||||
|
|
|
@ -65,7 +65,7 @@ class CrossEntropyLoss(Cell):
|
|||
|
||||
def __init__(self, parallel_config=default_dpmp_config):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
if not isinstance(parallel_config, OpParallelConfig) and not isinstance(parallel_config):
|
||||
if not isinstance(parallel_config, OpParallelConfig):
|
||||
raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig"
|
||||
", but got the type: {}.".format(type(parallel_config)))
|
||||
dp = parallel_config.data_parallel
|
||||
|
|
|
@ -155,7 +155,7 @@ class Primitive(Primitive_):
|
|||
self.add_prim_attr("stage", stage)
|
||||
return self
|
||||
|
||||
@_LogActionOnce(logger=logger)
|
||||
@_LogActionOnce(logger=logger, key='Primitive')
|
||||
def shard(self, in_strategy=None, out_strategy=None):
|
||||
"""
|
||||
Add strategies to primitive attribute.
|
||||
|
|
Loading…
Reference in New Issue