forked from mindspore-Ecosystem/mindspore
!14013 Enable zero dimension using attribute not parameter
From: @liangzhibo Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
c1a802627b
|
@ -74,7 +74,7 @@ class Tensor(Tensor_):
|
||||||
>>> assert t3.dtype == ms.float32
|
>>> assert t3.dtype == ms.float32
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input_data=None, dtype=None, shape=None, init=None, check_zero_dims=True):
|
def __init__(self, input_data=None, dtype=None, shape=None, init=None):
|
||||||
self.init_finished = False
|
self.init_finished = False
|
||||||
# If input data is numpy number, convert it to np array
|
# If input data is numpy number, convert it to np array
|
||||||
if isinstance(input_data, np_types):
|
if isinstance(input_data, np_types):
|
||||||
|
@ -92,11 +92,10 @@ class Tensor(Tensor_):
|
||||||
if isinstance(shape, numbers.Number):
|
if isinstance(shape, numbers.Number):
|
||||||
shape = (shape,)
|
shape = (shape,)
|
||||||
|
|
||||||
if check_zero_dims:
|
|
||||||
if input_data is not None and isinstance(input_data, (tuple, list, np.ndarray)) \
|
if input_data is not None and isinstance(input_data, (tuple, list, np.ndarray)) \
|
||||||
and np.array(input_data).ndim > 1 and np.array(input_data).size == 0:
|
and np.array(input_data).ndim > 1 and np.array(input_data).size == 0:
|
||||||
raise ValueError("input_data can not contain zero dimension.")
|
raise ValueError("input_data can not contain zero dimension.")
|
||||||
if shape is not None:
|
if shape is not None and not (hasattr(init, "__enable_zero_dim__") and init.__enable_zero_dim__):
|
||||||
if 0 in shape:
|
if 0 in shape:
|
||||||
raise ValueError("Shape can not contain zero value.")
|
raise ValueError("Shape can not contain zero value.")
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ import numbers
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
|
from mindspore.common.initializer import Zero
|
||||||
from .._utils import get_concat_offset
|
from .._utils import get_concat_offset
|
||||||
from ..operations.math_ops import _infer_shape_reduce
|
from ..operations.math_ops import _infer_shape_reduce
|
||||||
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
||||||
|
@ -38,6 +39,7 @@ from ...common.parameter import Parameter
|
||||||
from ...common.tensor import Tensor
|
from ...common.tensor import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class _ScatterOp(PrimitiveWithInfer):
|
class _ScatterOp(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Defines Scatter operators
|
Defines Scatter operators
|
||||||
|
@ -3015,8 +3017,13 @@ class StridedSlice(PrimitiveWithInfer):
|
||||||
|
|
||||||
ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v)
|
ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v)
|
||||||
|
|
||||||
value = None if all(ret_shape) else Tensor(np.array([]).reshape(ret_shape), x['dtype'].element_type(),
|
if all(ret_shape):
|
||||||
check_zero_dims=False)
|
value = None
|
||||||
|
else:
|
||||||
|
init_func = Zero()
|
||||||
|
init_func.__enable_zero_dim__ = True
|
||||||
|
value = Tensor(dtype=x['dtype'].element_type(), shape=ret_shape, init=init_func)
|
||||||
|
|
||||||
if "max_value" in x and "min_value" in x:
|
if "max_value" in x and "min_value" in x:
|
||||||
validator.check_value_type("min_value", x["min_value"], [tuple, list], self.name)
|
validator.check_value_type("min_value", x["min_value"], [tuple, list], self.name)
|
||||||
validator.check_value_type("max_value", x["max_value"], [tuple, list], self.name)
|
validator.check_value_type("max_value", x["max_value"], [tuple, list], self.name)
|
||||||
|
|
Loading…
Reference in New Issue