fix Split
This commit is contained in:
parent
0d22e64fa8
commit
f3badea5bc
|
@ -643,8 +643,10 @@ class Split(PrimitiveWithInfer):
|
|||
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
|
||||
validator.check_integer("output_num", self.output_num, 0, Rel.GT, self.name)
|
||||
output_valid_check = x_shape[self.axis] % self.output_num
|
||||
validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ,
|
||||
self.name)
|
||||
if output_valid_check != 0:
|
||||
raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by"
|
||||
f" output_num {self.output_num}")
|
||||
|
||||
x_shape[self.axis] = int(x_shape[self.axis] / self.output_num)
|
||||
out_shapes = []
|
||||
out_dtypes = []
|
||||
|
|
|
@ -4951,8 +4951,7 @@ class LRN(PrimitiveWithInfer):
|
|||
bias (float): An offset (usually positive to avoid dividing by 0).
|
||||
alpha (float): A scale factor, usually positive.
|
||||
beta (float): An exponent.
|
||||
norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS", "WITHIN_CHANNEL".
|
||||
Default: "ACROSS_CHANNELS".
|
||||
norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS". Default: "ACROSS_CHANNELS".
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - A 4D Tensor with float16 or float32 data type.
|
||||
|
@ -4974,6 +4973,7 @@ class LRN(PrimitiveWithInfer):
|
|||
validator.check_value_type("alpha", alpha, [float], self.name)
|
||||
validator.check_value_type("beta", beta, [float], self.name)
|
||||
validator.check_value_type("norm_region", norm_region, [str], self.name)
|
||||
validator.check_string('norm_region', norm_region, ['ACROSS_CHANNELS'], self.name)
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name)
|
||||
|
|
Loading…
Reference in New Issue