diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 8a361ae7e5d..73060e5f939 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -74,7 +74,7 @@ class Tensor(Tensor_): >>> assert t3.dtype == ms.float32 """ - def __init__(self, input_data=None, dtype=None, shape=None, init=None, check_zero_dims=True): + def __init__(self, input_data=None, dtype=None, shape=None, init=None): self.init_finished = False # If input data is numpy number, convert it to np array if isinstance(input_data, np_types): @@ -92,13 +92,12 @@ class Tensor(Tensor_): if isinstance(shape, numbers.Number): shape = (shape,) - if check_zero_dims: - if input_data is not None and isinstance(input_data, (tuple, list, np.ndarray)) \ - and np.array(input_data).ndim > 1 and np.array(input_data).size == 0: - raise ValueError("input_data can not contain zero dimension.") - if shape is not None: - if 0 in shape: - raise ValueError("Shape can not contain zero value.") + if input_data is not None and isinstance(input_data, (tuple, list, np.ndarray)) \ + and np.array(input_data).ndim > 1 and np.array(input_data).size == 0: + raise ValueError("input_data can not contain zero dimension.") + if shape is not None and not (hasattr(init, "__enable_zero_dim__") and init.__enable_zero_dim__): + if 0 in shape: + raise ValueError("Shape can not contain zero value.") # If input_data is tuple/list/numpy.ndarray, it's support in check_type method. if init is None: diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 09cb4ee7d60..ff6b92fc473 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -26,6 +26,7 @@ import numbers import numpy as np from mindspore import log as logger +from mindspore.common.initializer import Zero from .._utils import get_concat_offset from ..operations.math_ops import _infer_shape_reduce from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op @@ -38,6 +39,7 @@ from ...common.parameter import Parameter from ...common.tensor import Tensor + class _ScatterOp(PrimitiveWithInfer): """ Defines Scatter operators @@ -3015,8 +3017,13 @@ class StridedSlice(PrimitiveWithInfer): ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v) - value = None if all(ret_shape) else Tensor(np.array([]).reshape(ret_shape), x['dtype'].element_type(), - check_zero_dims=False) + if all(ret_shape): + value = None + else: + init_func = Zero() + init_func.__enable_zero_dim__ = True + value = Tensor(dtype=x['dtype'].element_type(), shape=ret_shape, init=init_func) + if "max_value" in x and "min_value" in x: validator.check_value_type("min_value", x["min_value"], [tuple, list], self.name) validator.check_value_type("max_value", x["max_value"], [tuple, list], self.name)