forked from mindspore-Ecosystem/mindspore
fix operation
This commit is contained in:
parent
b9ee076347
commit
3390ae3629
|
@ -352,6 +352,13 @@ def _is_float_dtype(dtype):
|
|||
return False
|
||||
|
||||
|
||||
@constexpr
|
||||
def _need_reduce_all(axis):
|
||||
if axis == ():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ClipByNorm(Cell):
|
||||
r"""
|
||||
Clips tensor values to a maximum :math:`L_2`-norm.
|
||||
|
@ -424,7 +431,7 @@ class ClipByNorm(Cell):
|
|||
intermediate = x * clip_norm
|
||||
|
||||
max_norm = self.max_op(l2norm, clip_norm)
|
||||
if self.axis is None:
|
||||
if _need_reduce_all(self.axis):
|
||||
max_norm = self.expand_dims(max_norm, -1)
|
||||
values_clip = self.cast(intermediate, mstype.float32) / max_norm
|
||||
values_clip = self.reshape(values_clip, self.shape(x))
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"""embedding"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import log as logger
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.tensor import Tensor, MetaTensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
@ -101,8 +101,11 @@ class Embedding(Cell):
|
|||
if padding_idx is not None:
|
||||
self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
|
||||
"padding_idx", self.cls_name)
|
||||
self.init_tensor = self.init_tensor.to_tensor().asnumpy()
|
||||
if isinstance(self.init_tensor, MetaTensor):
|
||||
self.init_tensor = self.init_tensor.to_tensor()
|
||||
self.init_tensor = self.init_tensor.asnumpy()
|
||||
self.init_tensor[self.padding_idx] = 0
|
||||
self.init_tensor = Tensor(self.init_tensor)
|
||||
self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
|
||||
self.expand = P.ExpandDims()
|
||||
self.reshape_flat = P.Reshape()
|
||||
|
|
|
@ -77,7 +77,9 @@ def _get_square_sum(x):
|
|||
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
|
||||
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
|
||||
def _apply_global_norm(clip_norm, global_norm, x):
|
||||
x_dtype = F.dtype(x)
|
||||
x = x * clip_norm / global_norm
|
||||
x = F.cast(x, x_dtype)
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
@ -6638,7 +6638,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
|
|||
args = {'init_h': h_dtype, 'bias_input': binput_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
|
||||
b_dtype = binput_dtype
|
||||
elif bhidden_dtype is not None:
|
||||
if bhidden_dtype is not None:
|
||||
args = {'init_h': h_dtype, 'bias_hidden': bhidden_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
|
||||
b_dtype = bhidden_dtype
|
||||
|
|
Loading…
Reference in New Issue