From f3badea5bcaa1e53a00bb2c0d67acbfc4299270c Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Tue, 7 Jul 2020 18:07:34 +0800 Subject: [PATCH] fix Split --- mindspore/ops/operations/array_ops.py | 6 ++++-- mindspore/ops/operations/nn_ops.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index b30a03d604a..7801ecace06 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -643,8 +643,10 @@ class Split(PrimitiveWithInfer): validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) validator.check_integer("output_num", self.output_num, 0, Rel.GT, self.name) output_valid_check = x_shape[self.axis] % self.output_num - validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ, - self.name) + if output_valid_check != 0: + raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by" + f" output_num {self.output_num}") + x_shape[self.axis] = int(x_shape[self.axis] / self.output_num) out_shapes = [] out_dtypes = [] diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 3c7615ce6e9..b19224efb0c 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -4951,8 +4951,7 @@ class LRN(PrimitiveWithInfer): bias (float): An offset (usually positive to avoid dividing by 0). alpha (float): A scale factor, usually positive. beta (float): An exponent. - norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS", "WITHIN_CHANNEL". - Default: "ACROSS_CHANNELS". + norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS". Default: "ACROSS_CHANNELS". Inputs: - **x** (Tensor) - A 4D Tensor with float16 or float32 data type. @@ -4974,6 +4973,7 @@ class LRN(PrimitiveWithInfer): validator.check_value_type("alpha", alpha, [float], self.name) validator.check_value_type("beta", beta, [float], self.name) validator.check_value_type("norm_region", norm_region, [str], self.name) + validator.check_string('norm_region', norm_region, ['ACROSS_CHANNELS'], self.name) def infer_dtype(self, x_dtype): validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name)