forked from mindspore-Ecosystem/mindspore
Add-description-in-API-about-BNTraniningReduce-and-BNTrainingUpdate.
This commit is contained in:
parent
97c3127944
commit
59dc05a29f
|
@ -2360,6 +2360,9 @@ class DiagPart(PrimitiveWithInfer):
|
|||
raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, "
|
||||
f"with shapes {x_shape}")
|
||||
length = len(x_shape) // 2
|
||||
for i in range(length):
|
||||
validator.check('input_shape[i + len(input_shape)/2]', x_shape[i + length],
|
||||
'input_shape[i]', x_shape[i], Rel.EQ, self.name)
|
||||
ret_shape = x_shape[0:length]
|
||||
return ret_shape
|
||||
|
||||
|
|
|
@ -714,15 +714,20 @@ class FusedBatchNormEx(PrimitiveWithInfer):
|
|||
|
||||
class BNTrainingReduce(PrimitiveWithInfer):
|
||||
"""
|
||||
reduce sum at axis [0, 2, 3].
|
||||
For BatchNorm operator, this operator update the moving averages for training and is used in conjunction with
|
||||
BNTrainingUpdate.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
|
||||
- **x** (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape :math:`(N, C, A, B)`.
|
||||
|
||||
Outputs:
|
||||
- **sum** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **square_sum** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **sum** (Tensor) - A 1-D Tensor with float32 data type. Tensor of shape :math:`(C,)`.
|
||||
- **square_sum** (Tensor) - A 1-D Tensor with float32 data type. Tensor of shape :math:`(C,)`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
|
||||
>>> bn_training_reduce = P.BNTrainingReduce(input_x)
|
||||
>>> output = bn_training_reduce(input_x)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -734,24 +739,90 @@ class BNTrainingReduce(PrimitiveWithInfer):
|
|||
return ([x_shape[1]], [x_shape[1]])
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_tensor_type_same({"x_type": x_type}, [mstype.float16, mstype.float32], self.name)
|
||||
return (x_type, x_type)
|
||||
|
||||
|
||||
class BNTrainingUpdate(PrimitiveWithInfer):
|
||||
"""
|
||||
The primitive operator of the register and info descriptor in bn_training_update.
|
||||
For BatchNorm operator, this operator update the moving averages for training and is used in conjunction with
|
||||
BNTrainingReduce.
|
||||
|
||||
Args:
|
||||
isRef (bool): If a ref. Default: True.
|
||||
epsilon (float): A small value added to variance avoid dividing by zero. Default: 1e-5.
|
||||
factor (float): A weight for updating the mean and variance. Default: 0.1.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape :math:`(N, C, A, B)`.
|
||||
- **sum** (Tensor) - A 1-D Tensor with float16 or float32 data type for the output of operator BNTrainingReduce.
|
||||
Tensor of shape :math:`(C,)`.
|
||||
- **square_sum** (Tensor) - A 1-D Tensor with float16 or float32 data type for the output of operator
|
||||
BNTrainingReduce. Tensor of shape :math:`(C,)`.
|
||||
- **scale** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling factor.
|
||||
Tensor of shape :math:`(C,)`.
|
||||
- **offset** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling offset.
|
||||
Tensor of shape :math:`(C,)`.
|
||||
- **mean** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling mean. Tensor of shape :math:`(C,)`.
|
||||
- **variance** (Tensor) - A 1-D Tensor with float16 or float32, for the update variance.
|
||||
Tensor of shape :math:`(C,)`.
|
||||
|
||||
Outputs:
|
||||
- **y** (Tensor) - Tensor, has the same shape data type as `x`.
|
||||
- **mean** (Tensor) - Tensor for the updated mean, with float32 data type.
|
||||
Has the same shape as `variance`.
|
||||
- **variance** (Tensor) - Tensor for the updated variance, with float32 data type.
|
||||
Has the same shape as `variance`.
|
||||
- **batch_mean** (Tensor) - Tensor for the mean of `x`, with float32 data type.
|
||||
Has the same shape as `variance`.
|
||||
- **batch_variance** (Tensor) - Tensor for the mean of `variance`, with float32 data type.
|
||||
Has the same shape as `variance`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
|
||||
>>> sum = Tensor(np.ones([64]), mindspore.float32)
|
||||
>>> square_sum = Tensor(np.ones([64]), mindspore.float32)
|
||||
>>> scale = Tensor(np.ones([64]), mindspore.float32)
|
||||
>>> offset = Tensor(np.ones([64]), mindspore.float32)
|
||||
>>> mean = Tensor(np.ones([64]), mindspore.float32)
|
||||
>>> variance = Tensor(np.ones([64]), mindspore.float32)
|
||||
>>> bn_training_update = P.BNTrainingUpdate()
|
||||
>>> output = bn_training_update(input_x, sum, square_sum, scale, offset, mean, variance)
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, isRef=True, epsilon=1e-5, factor=0.1):
|
||||
self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
|
||||
outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
|
||||
validator.check_value_type("isRef", isRef, [bool], self.name)
|
||||
validator.check_value_type("epsilon", epsilon, [float], self.name)
|
||||
validator.check_value_type("factor", factor, [float], self.name)
|
||||
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, 'BNTrainingUpdate')
|
||||
self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, 'BNTrainingUpdate')
|
||||
|
||||
def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
|
||||
validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name)
|
||||
validator.check_integer("sum rank", len(sum), 1, Rel.EQ, self.name)
|
||||
validator.check_integer("square_sum rank", len(square_sum), 1, Rel.EQ, self.name)
|
||||
validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
|
||||
validator.check_integer("b rank", len(b), 1, Rel.EQ, self.name)
|
||||
validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
|
||||
validator.check_integer("variance rank", len(variance), 1, Rel.EQ, self.name)
|
||||
validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name)
|
||||
validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("offset shape", b, "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("mean shape", mean, "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("variance shape", variance, "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
return (x, variance, variance, variance, variance)
|
||||
|
||||
def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
|
||||
validator.check_tensor_type_same({"x_type": x}, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_tensor_type_same({"sum_type": sum}, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_tensor_type_same({"square_sum_type": square_sum}, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_tensor_type_same({"scale_type": scale}, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_tensor_type_same({"b_type": b}, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_tensor_type_same({"mean_type": mean}, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_tensor_type_same({"variance_type": variance}, [mstype.float16, mstype.float32], self.name)
|
||||
return (x, variance, variance, variance, variance)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue