forked from mindspore-Ecosystem/mindspore
!31930 clean static alarm for operator module with python.
Merge pull request !31930 from wangshuide/wsd_master_warn
This commit is contained in:
commit
2b3f109f08
|
@ -191,18 +191,18 @@ class _BatchNorm(Cell):
|
|||
self.is_global = True
|
||||
global SYNC_BN_GROUP_NAME
|
||||
if SYNC_BN_GROUP_NAME == "":
|
||||
SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
|
||||
SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
|
||||
management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])
|
||||
|
||||
def _create_sync_groups(self):
|
||||
for i in range(len(self.process_groups)):
|
||||
validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list)
|
||||
validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list)
|
||||
self.group_device_num = len(self.process_groups[i])
|
||||
if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
|
||||
self.is_global = True
|
||||
global SYNC_BN_GROUP_NAME
|
||||
if SYNC_BN_GROUP_NAME == "":
|
||||
SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
|
||||
SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
|
||||
management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
|
||||
|
||||
def construct(self, x):
|
||||
|
|
|
@ -136,6 +136,7 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
|||
|
||||
return hy
|
||||
|
||||
|
||||
class RNNCellBase(Cell):
|
||||
'''Basic class for RNN Cells'''
|
||||
def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int):
|
||||
|
|
|
@ -79,6 +79,7 @@ class _ReverseSequence(Cell):
|
|||
|
||||
@staticmethod
|
||||
def make_shape(shape, dtype, range_dim):
|
||||
"""Calculates the shape according by the inputs."""
|
||||
output = P.Ones()(shape, mstype.float32)
|
||||
output = P.CumSum()(output, range_dim)
|
||||
output = P.Cast()(output, dtype)
|
||||
|
|
|
@ -19,13 +19,15 @@ import mindspore.common.dtype as mstype
|
|||
import mindspore.log as logger
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.initializer import initializer, Initializer
|
||||
from mindspore.communication.management import get_group_size, get_rank
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore._checkparam import Validator, Rel, twice
|
||||
from mindspore import context
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
|
||||
from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context, \
|
||||
_set_rank_id, _insert_hash_table_size, _set_cache_enable
|
||||
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.ops.primitive import constexpr
|
||||
|
@ -40,7 +42,8 @@ class DenseThor(Cell):
|
|||
The dense connected layer and saving the information needed for THOR.
|
||||
|
||||
Applies dense connected layer for the input and saves the information A and G in the dense connected layer
|
||||
needed for THOR, the detail can be seen in paper: https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf
|
||||
needed for THOR, the detail can be seen in `THOR, Trace-based Hardware-driven layer-ORiented Natural
|
||||
Gradient Descent Computation <https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_.
|
||||
This layer implements the operation as:
|
||||
|
||||
.. math::
|
||||
|
|
|
@ -466,6 +466,7 @@ def get_print_vmap_rule(prim, axis_size):
|
|||
"""VmapRule for `Print` operation."""
|
||||
if isinstance(prim, str):
|
||||
prim = Primitive(prim)
|
||||
|
||||
def vmap_rule(*args):
|
||||
vals = ()
|
||||
args_len = len(args)
|
||||
|
|
|
@ -574,7 +574,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
if isinstance(item, dict) and item.get("value") is None:
|
||||
reg_info["attr"][i]["value"] = "all"
|
||||
reg_info["async_flag"] = reg_info.get("async_flag", False)
|
||||
reg_info["binfile_name"] = self.func_name + ".so"
|
||||
reg_info["binfile_name"] = "%s.so" % self.func_name
|
||||
reg_info["compute_cost"] = reg_info.get("compute_cost", 10)
|
||||
reg_info["kernel_name"] = self.func_name
|
||||
reg_info["partial_flag"] = reg_info.get("partial_flag", True)
|
||||
|
|
Loading…
Reference in New Issue