!31930 clean static alarm for operator module with python.

Merge pull request !31930 from wangshuide/wsd_master_warn
commit 2b3f109f08 by i-robot, 2022-03-29 13:07:25 +00:00, committed by Gitee
6 changed files with 12 additions and 6 deletions

View File

@@ -191,18 +191,18 @@ class _BatchNorm(Cell):
                 self.is_global = True
                 global SYNC_BN_GROUP_NAME
                 if SYNC_BN_GROUP_NAME == "":
-                    SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
+                    SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
                     management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])

     def _create_sync_groups(self):
         for i in range(len(self.process_groups)):
-            validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list)
+            validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list)
             self.group_device_num = len(self.process_groups[i])
             if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
                 self.is_global = True
                 global SYNC_BN_GROUP_NAME
                 if SYNC_BN_GROUP_NAME == "":
-                    SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
+                    SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
                     management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])

     def construct(self, x):
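
The two replacements above are the pattern this PR applies throughout: %-interpolation instead of string concatenation. A minimal standalone sketch of the equivalence (names are illustrative, not taken from MindSpore):

    # Both forms build the same group name; the %-style keeps the template in
    # one piece instead of splicing str(i) into it, which is what the static
    # check flagged.
    for i in range(3):
        concatenated = "sync_bn_group" + str(i)   # old form, flagged
        formatted = "sync_bn_group%d" % i         # new form adopted by the PR
        assert concatenated == formatted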

View File

@@ -136,6 +136,7 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
     return hy

+
 class RNNCellBase(Cell):
     '''Basic class for RNN Cells'''

     def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int):
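
For orientation, `num_chunks` here most likely follows the usual stacked-gate convention for RNN cell weights; this is an assumption drawn from the common RNNCellBase design, not something this diff states. A NumPy sketch of that convention:

    import numpy as np

    # Hypothetical sizes. Gate weights are stacked along the first axis, so an
    # LSTM-style cell would pass num_chunks=4, a GRU 3, and a plain RNN cell 1.
    input_size, hidden_size, num_chunks = 8, 16, 3  # GRU-style cell
    w_ih = np.zeros((num_chunks * hidden_size, input_size), dtype=np.float32)
    w_hh = np.zeros((num_chunks * hidden_size, hidden_size), dtype=np.float32)
    print(w_ih.shape, w_hh.shape)  # (48, 8) (48, 16)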

View File

@@ -79,6 +79,7 @@ class _ReverseSequence(Cell):

     @staticmethod
     def make_shape(shape, dtype, range_dim):
+        """Calculates the shape according by the inputs."""
         output = P.Ones()(shape, mstype.float32)
         output = P.CumSum()(output, range_dim)
         output = P.Cast()(output, dtype)
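
The added docstring is brief; concretely, `make_shape` builds a tensor that counts 1, 2, 3, ... along `range_dim`. The same Ones, CumSum, Cast pipeline in NumPy (shapes chosen for illustration):

    import numpy as np

    # ones -> cumulative sum along range_dim -> cast: the three ops used above.
    shape, range_dim = (2, 4), 1
    out = np.ones(shape, dtype=np.float32).cumsum(axis=range_dim).astype(np.int32)
    print(out)
    # [[1 2 3 4]
    #  [1 2 3 4]]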

View File

@@ -19,13 +19,15 @@ import mindspore.common.dtype as mstype
 import mindspore.log as logger
 from mindspore.common.tensor import Tensor
 from mindspore.common.initializer import initializer, Initializer
+from mindspore.communication.management import get_group_size, get_rank
 from mindspore.ops import operations as P
 from mindspore.common.parameter import Parameter
 from mindspore._checkparam import Validator, Rel, twice
 from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
-from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
+from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context, \
+    _set_rank_id, _insert_hash_table_size, _set_cache_enable
 from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
 from mindspore.context import ParallelMode
 from mindspore.ops.primitive import constexpr
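
The widened `_ps_context` import uses a backslash continuation. An equivalent parenthesized form, which PEP 8 generally recommends, would be (shown as an alternative, not what the commit does):

    # Same names as the hunk above, wrapped with parentheses instead of a
    # backslash continuation.
    from mindspore.parallel._ps_context import (
        _is_role_worker,
        _get_ps_context,
        _set_rank_id,
        _insert_hash_table_size,
        _set_cache_enable,
    )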
@@ -40,7 +42,8 @@ class DenseThor(Cell):
     The dense connected layer and saving the information needed for THOR.

     Applies dense connected layer for the input and saves the information A and G in the dense connected layer
-    needed for THOR, the detail can be seen in paper: https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf
+    needed for THOR, the detail can be seen in `THOR, Trace-based Hardware-driven layer-ORiented Natural
+    Gradient Descent Computation <https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_.

     This layer implements the operation as:
     .. math::
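
The docstring change swaps a bare URL for Sphinx reStructuredText link markup, which renders as a titled hyperlink in the generated docs. The general form, in a minimal hypothetical docstring:

    def thor_reference_stub():
        """Hypothetical stub showing the rST external-link form used above.

        The details can be seen in `THOR, Trace-based Hardware-driven
        layer-ORiented Natural Gradient Descent Computation
        <https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_.
        """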

View File

@@ -466,6 +466,7 @@ def get_print_vmap_rule(prim, axis_size):
     """VmapRule for `Print` operation."""
     if isinstance(prim, str):
         prim = Primitive(prim)
+
     def vmap_rule(*args):
         vals = ()
         args_len = len(args)
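
Beyond the added blank line, the surviving context shows a normalization idiom worth noting: a vmap rule may receive either a primitive name or a `Primitive` instance, so a bare string is wrapped before use. A minimal sketch, assuming a MindSpore install (`Primitive` is the real `mindspore.ops.Primitive` class):

    from mindspore.ops import Primitive

    prim = "Print"  # rules can be looked up by name as well as by instance
    if isinstance(prim, str):
        prim = Primitive(prim)
    print(prim.name)  # Print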

View File

@@ -574,7 +574,7 @@ class Custom(ops.PrimitiveWithInfer):
             if isinstance(item, dict) and item.get("value") is None:
                 reg_info["attr"][i]["value"] = "all"
         reg_info["async_flag"] = reg_info.get("async_flag", False)
-        reg_info["binfile_name"] = self.func_name + ".so"
+        reg_info["binfile_name"] = "%s.so" % self.func_name
         reg_info["compute_cost"] = reg_info.get("compute_cost", 10)
         reg_info["kernel_name"] = self.func_name
         reg_info["partial_flag"] = reg_info.get("partial_flag", True)