fix Split

This commit is contained in:
jiangjinsheng 2020-07-07 18:07:34 +08:00
parent 0d22e64fa8
commit f3badea5bc
2 changed files with 6 additions and 4 deletions

View File

@ -643,8 +643,10 @@ class Split(PrimitiveWithInfer):
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
validator.check_integer("output_num", self.output_num, 0, Rel.GT, self.name)
output_valid_check = x_shape[self.axis] % self.output_num
validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ,
self.name)
if output_valid_check != 0:
raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by"
f" output_num {self.output_num}")
x_shape[self.axis] = int(x_shape[self.axis] / self.output_num)
out_shapes = []
out_dtypes = []

View File

@ -4951,8 +4951,7 @@ class LRN(PrimitiveWithInfer):
bias (float): An offset (usually positive to avoid dividing by 0).
alpha (float): A scale factor, usually positive.
beta (float): An exponent.
norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS", "WITHIN_CHANNEL".
Default: "ACROSS_CHANNELS".
norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS". Default: "ACROSS_CHANNELS".
Inputs:
- **x** (Tensor) - A 4D Tensor with float16 or float32 data type.
@ -4974,6 +4973,7 @@ class LRN(PrimitiveWithInfer):
validator.check_value_type("alpha", alpha, [float], self.name)
validator.check_value_type("beta", beta, [float], self.name)
validator.check_value_type("norm_region", norm_region, [str], self.name)
validator.check_string('norm_region', norm_region, ['ACROSS_CHANNELS'], self.name)
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name)