Add-description-in-API-about-BNTraniningReduce-and-BNTrainingUpdate.

This commit is contained in:
liuxiao93 2020-09-19 19:30:16 +08:00
parent 97c3127944
commit 59dc05a29f
2 changed files with 79 additions and 5 deletions

View File

@ -2360,6 +2360,9 @@ class DiagPart(PrimitiveWithInfer):
raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, "
f"with shapes {x_shape}")
length = len(x_shape) // 2
for i in range(length):
validator.check('input_shape[i + len(input_shape)/2]', x_shape[i + length],
'input_shape[i]', x_shape[i], Rel.EQ, self.name)
ret_shape = x_shape[0:length]
return ret_shape

View File

@ -714,15 +714,20 @@ class FusedBatchNormEx(PrimitiveWithInfer):
class BNTrainingReduce(PrimitiveWithInfer):
"""
reduce sum at axis [0, 2, 3].
For BatchNorm operator, this operator update the moving averages for training and is used in conjunction with
BNTrainingUpdate.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
- **x** (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape :math:`(N, C, A, B)`.
Outputs:
- **sum** (Tensor) - Tensor of shape :math:`(C,)`.
- **square_sum** (Tensor) - Tensor of shape :math:`(C,)`.
- **sum** (Tensor) - A 1-D Tensor with float32 data type. Tensor of shape :math:`(C,)`.
- **square_sum** (Tensor) - A 1-D Tensor with float32 data type. Tensor of shape :math:`(C,)`.
Examples:
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
>>> bn_training_reduce = P.BNTrainingReduce(input_x)
>>> output = bn_training_reduce(input_x)
"""
@prim_attr_register
@ -734,24 +739,90 @@ class BNTrainingReduce(PrimitiveWithInfer):
return ([x_shape[1]], [x_shape[1]])
def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x_type": x_type}, [mstype.float16, mstype.float32], self.name)
return (x_type, x_type)
class BNTrainingUpdate(PrimitiveWithInfer):
"""
The primitive operator of the register and info descriptor in bn_training_update.
For BatchNorm operator, this operator update the moving averages for training and is used in conjunction with
BNTrainingReduce.
Args:
isRef (bool): If a ref. Default: True.
epsilon (float): A small value added to variance avoid dividing by zero. Default: 1e-5.
factor (float): A weight for updating the mean and variance. Default: 0.1.
Inputs:
- **x** (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape :math:`(N, C, A, B)`.
- **sum** (Tensor) - A 1-D Tensor with float16 or float32 data type for the output of operator BNTrainingReduce.
Tensor of shape :math:`(C,)`.
- **square_sum** (Tensor) - A 1-D Tensor with float16 or float32 data type for the output of operator
BNTrainingReduce. Tensor of shape :math:`(C,)`.
- **scale** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling factor.
Tensor of shape :math:`(C,)`.
- **offset** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling offset.
Tensor of shape :math:`(C,)`.
- **mean** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling mean. Tensor of shape :math:`(C,)`.
- **variance** (Tensor) - A 1-D Tensor with float16 or float32, for the update variance.
Tensor of shape :math:`(C,)`.
Outputs:
- **y** (Tensor) - Tensor, has the same shape data type as `x`.
- **mean** (Tensor) - Tensor for the updated mean, with float32 data type.
Has the same shape as `variance`.
- **variance** (Tensor) - Tensor for the updated variance, with float32 data type.
Has the same shape as `variance`.
- **batch_mean** (Tensor) - Tensor for the mean of `x`, with float32 data type.
Has the same shape as `variance`.
- **batch_variance** (Tensor) - Tensor for the mean of `variance`, with float32 data type.
Has the same shape as `variance`.
Examples:
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
>>> sum = Tensor(np.ones([64]), mindspore.float32)
>>> square_sum = Tensor(np.ones([64]), mindspore.float32)
>>> scale = Tensor(np.ones([64]), mindspore.float32)
>>> offset = Tensor(np.ones([64]), mindspore.float32)
>>> mean = Tensor(np.ones([64]), mindspore.float32)
>>> variance = Tensor(np.ones([64]), mindspore.float32)
>>> bn_training_update = P.BNTrainingUpdate()
>>> output = bn_training_update(input_x, sum, square_sum, scale, offset, mean, variance)
"""
@prim_attr_register
def __init__(self, isRef=True, epsilon=1e-5, factor=0.1):
self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
validator.check_value_type("isRef", isRef, [bool], self.name)
validator.check_value_type("epsilon", epsilon, [float], self.name)
validator.check_value_type("factor", factor, [float], self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, 'BNTrainingUpdate')
self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, 'BNTrainingUpdate')
def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name)
validator.check_integer("sum rank", len(sum), 1, Rel.EQ, self.name)
validator.check_integer("square_sum rank", len(square_sum), 1, Rel.EQ, self.name)
validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
validator.check_integer("b rank", len(b), 1, Rel.EQ, self.name)
validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
validator.check_integer("variance rank", len(variance), 1, Rel.EQ, self.name)
validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name)
validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("offset shape", b, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("mean shape", mean, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("variance shape", variance, "x_shape[1]", x[1], Rel.EQ, self.name)
return (x, variance, variance, variance, variance)
def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
validator.check_tensor_type_same({"x_type": x}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"sum_type": sum}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"square_sum_type": square_sum}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"scale_type": scale}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"b_type": b}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"mean_type": mean}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"variance_type": variance}, [mstype.float16, mstype.float32], self.name)
return (x, variance, variance, variance, variance)