Fix some bugs about API.
commit 0a1155f938 (parent a8478839c9)
@@ -53,7 +53,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                         NPUAllocFloatStatus, NPUClearFloatStatus,
                         NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
                         Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
-                        Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, IFMR,
+                        Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
                         Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, TensorDot)

 from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
@@ -98,7 +98,6 @@ __all__ = [
     'EditDistance',
     'CropAndResize',
     'TensorAdd',
-    'IFMR',
     'Argmax',
     'Argmin',
     'ArgMaxWithValue',
@@ -43,7 +43,8 @@ __all__ = ["MinMaxUpdatePerLayer",
            "BatchNormFoldGradD",
            "BatchNormFold2_D",
            "BatchNormFold2GradD",
-           "BatchNormFold2GradReduce"
+           "BatchNormFold2GradReduce",
+           "IFMR"
            ]


@@ -1384,3 +1385,66 @@ class WtsARQ(PrimitiveWithInfer):
         validator.check_tensor_type_same({"w_min": w_min_dtype}, valid_types, self.name)
         validator.check_tensor_type_same({"w_max": w_max_dtype}, valid_types, self.name)
         return w_dtype
+
+
+class IFMR(PrimitiveWithInfer):
+    """
+    The IFMR (Input Feature Map Reconstruction).
+
+    Args:
+        min_percentile (float): Min init percentile. Default: 0.999999.
+        max_percentile (float): Max init percentile. Default: 0.999999.
+        search_range (Union[list(float), tuple(float)]): Range of searching. Default: [0.7, 1.3].
+        search_step (float): Step size of searching. Default: 0.01.
+        with_offset (bool): Whether to use offset. Default: True.
+
+    Inputs:
+        - **data** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
+        - **data_min** (Tensor) - A Tensor of min value of feature map, the shape is :math:`(1)`.
+          With float16 or float32 data type.
+        - **data_max** (Tensor) - A Tensor of max value of feature map, the shape is :math:`(1)`.
+          With float16 or float32 data type.
+        - **cumsum** (Tensor) - A `1-D` Tensor of cumsum bin of data. With int32 data type.
+
+    Outputs:
+        - **scale** (Tensor) - A tensor of optimal scale, the shape is :math:`(1)`. Data dtype is float32.
+        - **offset** (Tensor) - A tensor of optimal offset, the shape is :math:`(1)`. Data dtype is float32.
+
+    Examples:
+        >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
+        >>> data_min = Tensor([0.1], mstype.float32)
+        >>> data_max = Tensor([0.5], mstype.float32)
+        >>> cumsum = Tensor(np.random.rand(4).astype(np.int32))
+        >>> ifmr = Q.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0),
+        ...               search_step=1.0, with_offset=False)
+        >>> output = ifmr(data, data_min, data_max, cumsum)
+        ([7.87401572e-03], [0.00000000e+00])
+    """
+
+    @prim_attr_register
+    def __init__(self, min_percentile=0.999999, max_percentile=0.999999, search_range=(0.7, 1.3), search_step=0.01,
+                 with_offset=True):
+        validator.check_value_type("min_percentile", min_percentile, [float], self.name)
+        validator.check_value_type("max_percentile", max_percentile, [float], self.name)
+        validator.check_value_type("search_range", search_range, [list, tuple], self.name)
+        for item in search_range:
+            validator.check_positive_float(item, "item of search_range", self.name)
+        validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name)
+        validator.check_value_type("search_step", search_step, [float], self.name)
+        validator.check_value_type("offset_flag", with_offset, [bool], self.name)
+
+    def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape):
+        validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name)
+        validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name)
+        validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name)
+        validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name)
+        validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name)
+        return (1,), (1,)
+
+    def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):
+        tuple(map(partial(validator.check_tensor_dtype_valid,
+                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
+                  ("input_value", "input_min", "input_max"),
+                  (data_dtype, data_min_dtype, data_max_dtype)))
+        validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name)
+        return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32)
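Note: the hunk above only adds the primitive; the import path callers should use is not shown in this diff. A minimal usage sketch, assuming IFMR is exposed through the quantization ops module (the module path and the `Q` alias are assumptions inferred from the test hunk at the end of this commit):

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops.operations import _quant_ops as Q   # assumed location of the relocated IFMR

# Mirror the docstring example: a random feature map plus its min/max and a cumsum histogram.
data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
data_min = Tensor([0.1], mstype.float32)
data_max = Tensor([0.5], mstype.float32)
cumsum = Tensor(np.random.rand(4).astype(np.int32))

ifmr = Q.IFMR(min_percentile=0.2, max_percentile=0.9,
              search_range=(1.0, 2.0), search_step=1.0, with_offset=False)
scale, offset = ifmr(data, data_min, data_max, cumsum)   # both outputs have shape (1,)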
@@ -16,7 +16,6 @@
 """Operators for math."""

 import copy
-from functools import partial

 import numpy as np
 from ... import context
@@ -3679,66 +3678,3 @@ class Eps(PrimitiveWithInfer):
             'dtype': input_x['dtype'],
         }
         return out
-
-
-class IFMR(PrimitiveWithInfer):
-    """
-    The TFMR(Input Feature Map Reconstruction).
-
-    Args:
-        min_percentile (float): Min init percentile. Default: 0.999999.
-        max_percentile (float): Max init percentile. Default: 0.999999.
-        search_range Union[list(float), tuple(float)]: Range of searching. Default: [0.7, 1.3].
-        search_step (float): Step size of searching. Default: 0.01.
-        with_offset (bool): Whether using offset. Default: True.
-
-    Inputs:
-        - **data** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
-        - **data_min** (Tensor) - A Tensor of min value of feature map, the shape is :math:`(1)`.
-          With float16 or float32 data type.
-        - **data_max** (Tensor) - A Tensor of max value of feature map, the shape is :math:`(1)`.
-          With float16 or float32 data type.
-        - **cumsum** (Tensor) - A `1-D` Tensor of cumsum bin of data. With int32 data type.
-
-    Outputs:
-        - **scale** (Tensor) - A tensor of optimal scale, the shape is :math:`(1)`. Data dtype is float32.
-        - **offset** (Tensor) - A tensor of optimal offset, the shape is :math:`(1)`. Data dtype is float32.
-
-    Examples:
-        >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
-        >>> data_min = Tensor([0.1], mstype.float32)
-        >>> data_max = Tensor([0.5], mstype.float32)
-        >>> cumsum = Tensor(np.random.rand(4).astype(np.int32))
-        >>> ifmr = P.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0),
-                          search_step=1.0, with_offset=False)
-        >>> output = ifmr(data, data_min, data_max, cumsum)
-        ([7.87401572e-03], [0.00000000e+00])
-    """
-
-    @prim_attr_register
-    def __init__(self, min_percentile=0.999999, max_percentile=0.999999, search_range=(0.7, 1.3), search_step=0.01,
-                 with_offset=True):
-        validator.check_value_type("min_percentile", min_percentile, [float], self.name)
-        validator.check_value_type("max_percentile", max_percentile, [float], self.name)
-        validator.check_value_type("search_range", search_range, [list, tuple], self.name)
-        for item in search_range:
-            validator.check_positive_float(item, "item of search_range", self.name)
-        validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name)
-        validator.check_value_type("search_step", search_step, [float], self.name)
-        validator.check_value_type("offset_flag", with_offset, [bool], self.name)
-
-    def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape):
-        validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name)
-        validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name)
-        validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name)
-        validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name)
-        validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name)
-        return (1,), (1,)
-
-    def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):
-        tuple(map(partial(validator.check_tensor_dtype_valid,
-                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
-                  ("input_value", "input_min", "input_max"),
-                  (data_dtype, data_min_dtype, data_max_dtype)))
-        validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name)
-        return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32)
@@ -601,10 +601,10 @@ class FusedBatchNorm(Primitive):

     Inputs:
         - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`.
-        - **scale** (Tensor) - Tensor of shape :math:`(C,)`.
-        - **bias** (Tensor) - Tensor of shape :math:`(C,)`.
-        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
-        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **scale** (Parameter) - Tensor of shape :math:`(C,)`.
+        - **bias** (Parameter) - Tensor of shape :math:`(C,)`.
+        - **mean** (Parameter) - Tensor of shape :math:`(C,)`.
+        - **variance** (Parameter) - Tensor of shape :math:`(C,)`.

     Outputs:
         Tuple of 5 Tensor, the normalized input and the updated parameters.
@@ -616,13 +616,30 @@ class FusedBatchNorm(Primitive):
         - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.

     Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Parameter
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> class FusedBatchNormNet(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(FusedBatchNormNet, self).__init__()
+        >>>         self.fused_batch_norm = P.FusedBatchNorm()
+        >>>         self.scale = Parameter(Tensor(np.ones([64]), mindspore.float32), name="scale")
+        >>>         self.bias = Parameter(Tensor(np.ones([64]), mindspore.float32), name="bias")
+        >>>         self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean")
+        >>>         self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance")
+        >>>
+        >>>     def construct(self, input_x):
+        >>>         out = self.fused_batch_norm(input_x, self.scale, self.bias, self.mean, self.variance)
+        >>>         return out
+        >>>
         >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
-        >>> scale = Tensor(np.ones([64]), mindspore.float32)
-        >>> bias = Tensor(np.ones([64]), mindspore.float32)
-        >>> mean = Tensor(np.ones([64]), mindspore.float32)
-        >>> variance = Tensor(np.ones([64]), mindspore.float32)
-        >>> op = P.FusedBatchNorm()
-        >>> output = op(input_x, scale, bias, mean, variance)
+        >>> net = FusedBatchNormNet()
+        >>> output = net(input_x)
+        >>> output[0].shape
+        (128, 64, 32, 64)
     """
     __mindspore_signature__ = (
         sig.make_sig('input_x', dtype=sig.sig_dtype.T2),
@@ -673,12 +690,12 @@ class FusedBatchNormEx(PrimitiveWithInfer):
     Inputs:
         - **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C)`,
           data type: float16 or float32.
-        - **scale** (Tensor) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`,
+        - **scale** (Parameter) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`,
           data type: float32.
-        - **bias** (Tensor) - Parameter bias, same with beta above-mentioned, Tensor of shape :math:`(C,)`,
+        - **bias** (Parameter) - Parameter bias, same with beta above-mentioned, Tensor of shape :math:`(C,)`,
           data type: float32.
-        - **mean** (Tensor) - mean value, Tensor of shape :math:`(C,)`, data type: float32.
-        - **variance** (Tensor) - variance value, Tensor of shape :math:`(C,)`, data type: float32.
+        - **mean** (Parameter) - mean value, Tensor of shape :math:`(C,)`, data type: float32.
+        - **variance** (Parameter) - variance value, Tensor of shape :math:`(C,)`, data type: float32.

     Outputs:
         Tuple of 6 Tensors, the normalized input, the updated parameters and reserve.
@@ -692,13 +709,30 @@ class FusedBatchNormEx(PrimitiveWithInfer):
         - **reserve** (Tensor) - reserve space, Tensor of shape :math:`(C,)`, data type: float32.

     Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Parameter
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> class FusedBatchNormExNet(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(FusedBatchNormExNet, self).__init__()
+        >>>         self.fused_batch_norm_ex = P.FusedBatchNormEx()
+        >>>         self.scale = Parameter(Tensor(np.ones([64]), mindspore.float32), name="scale")
+        >>>         self.bias = Parameter(Tensor(np.ones([64]), mindspore.float32), name="bias")
+        >>>         self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean")
+        >>>         self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance")
+        >>>
+        >>>     def construct(self, input_x):
+        >>>         out = self.fused_batch_norm_ex(input_x, self.scale, self.bias, self.mean, self.variance)
+        >>>         return out
+        >>>
         >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
-        >>> scale = Tensor(np.ones([64]), mindspore.float32)
-        >>> bias = Tensor(np.ones([64]), mindspore.float32)
-        >>> mean = Tensor(np.ones([64]), mindspore.float32)
-        >>> variance = Tensor(np.ones([64]), mindspore.float32)
-        >>> op = P.FusedBatchNormEx()
-        >>> output = op(input_x, scale, bias, mean, variance)
+        >>> net = FusedBatchNormExNet()
+        >>> output = net(input_x)
+        >>> output[0].shape
+        (128, 64, 32, 64)
     """
     __mindspore_signature__ = (
         sig.make_sig('input_x', dtype=sig.sig_dtype.T2),
@@ -756,7 +790,7 @@ class BNTrainingReduce(PrimitiveWithInfer):

     Examples:
         >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
-        >>> bn_training_reduce = P.BNTrainingReduce(input_x)
+        >>> bn_training_reduce = P.BNTrainingReduce()
         >>> output = bn_training_reduce(input_x)
     """

@@ -5657,13 +5691,30 @@ class DynamicRNN(PrimitiveWithInfer):
           Has the same type with input `b`.

     Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Parameter
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> import mindspore.context as context
+        >>> context.set_context(mode=context.GRAPH_MODE)
+        >>> class DynamicRNNNet(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(DynamicRNNNet, self).__init__()
+        >>>         self.dynamic_rnn = P.DynamicRNN()
+        >>>
+        >>>     def construct(self, x, w, b, init_h, init_c):
+        >>>         out = self.dynamic_rnn(x, w, b, None, init_h, init_c)
+        >>>         return out
+        >>>
         >>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
         >>> w = Tensor(np.random.rand(96, 128).astype(np.float16))
         >>> b = Tensor(np.random.rand(128).astype(np.float16))
         >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
         >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
-        >>> dynamic_rnn = P.DynamicRNN()
-        >>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
+        >>> net = DynamicRNNNet()
+        >>> output = net(x, w, b, init_h, init_c)
         >>> output[0].shape
         (2, 16, 32)
     """
@@ -1446,7 +1446,7 @@ test_case_math_ops = [
         'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]],
         'desc_bprop': [[2, 3, 4, 5]]}),
     ('IFMR', {
-        'block': P.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0),
+        'block': Q.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0),
                         search_step=1.0, with_offset=False),
        'desc_inputs': [[3, 4, 5], Tensor([0.1], mstype.float32), Tensor([0.9], mstype.float32),
                        Tensor(np.random.rand(4).astype(np.int32))],
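For completeness, the test hunk above only swaps `P.IFMR` for `Q.IFMR`; the imports that define those aliases sit outside this diff. A plausible sketch of what the test module assumes (the quant-ops path for `Q` is an assumption, not shown in this commit):

from mindspore.ops import operations as P              # existing alias for the standard ops
from mindspore.ops.operations import _quant_ops as Q   # assumed path; IFMR now lives with the quant ops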