add tensor.ndim and rename tensor.size() to tensor.size

This commit is contained in:
yanglf1121 2020-12-18 16:17:45 +08:00
parent f3aee78bde
commit 918679daa3
22 changed files with 77 additions and 63 deletions

View File

@ -635,7 +635,7 @@ def check_input_data(*data, data_class):
f' either a single'
f' or a list of {data_class.__name__},'
f' but got part data type is {str(type(item))}.')
if hasattr(item, "size") and item.size() == 0:
if hasattr(item, "size") and item.size == 0:
msg = "Please provide non-empty data."
logger.error(msg)
raise ValueError(msg)

View File

@ -31,6 +31,8 @@ shape_ = P.Shape()
reshape_ = P.Reshape()
dtype_ = P.DType()
abs_ = P.Abs()
ndim_ = P.Rank()
size_ = P.Size()
def mean(x, axis=(), keep_dims=False):
"""

View File

@ -192,6 +192,8 @@ BuiltInTypeMap &GetAttrMap() {
{
{"shape", std::string("shape_")}, // C.shape_
{"dtype", std::string("dtype_")}, // C.dtype_
{"size", std::string("size_")}, // C.size_
{"ndim", std::string("ndim_")}, // C.ndim_
}},
{kObjectTypeRowTensorType,
{

View File

@ -370,6 +370,17 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.shape()
(3, 3)
)mydelimiter")
.def_property_readonly("_size", &Tensor::DataSize, R"mydelimiter(
Get tensor's data size.
Returns:
int, the size of tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> data.size
6
)mydelimiter")
.def("from_numpy", TensorPy::MakeTensorNoCopy, R"mydelimiter(
Creates a Tensor from a numpy.ndarray without copy.
@ -396,17 +407,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
array([[1., 1., 1.],
[1., 1., 1.]])
)mydelimiter")
.def("size", &Tensor::DataSize, R"mydelimiter(
Get tensor's data size.
Returns:
int, the size of tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> data.size()
6
)mydelimiter")
.def("is_init", &Tensor::is_init, R"mydelimiter(
Get tensor init_flag.

View File

@ -237,6 +237,16 @@ class Tensor(Tensor_):
"""The dtype of tensor is a mindspore type."""
return self._dtype
@property
def size(self):
"""The size reflects the total number of elements in tensor."""
return self._size
@property
def ndim(self):
"""The ndim of tensor is an integer."""
return len(self._shape)
@property
def virtual_flag(self):
"""Mark tensor is virtual."""

View File

@ -277,7 +277,7 @@ class Dense(Cell):
self.shape_op = P.Shape()
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("Weight init shape error.")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
@ -285,7 +285,7 @@ class Dense(Cell):
self.bias = None
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("Bias init shape error.")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
self.bias_add = P.BiasAdd()

View File

@ -318,7 +318,7 @@ class BatchNorm1d(_BatchNorm):
input_dims='1d')
def _check_data_dim(self, x):
if x.dim() != 2:
if x.ndim != 2:
pass
@ -415,7 +415,7 @@ class BatchNorm2d(_BatchNorm):
data_format=data_format)
def _check_data_dim(self, x):
if x.dim() != 4:
if x.ndim != 4:
pass

View File

@ -1008,7 +1008,7 @@ class DenseQuant(Cell):
self.has_bias = Validator.check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")
@ -1017,7 +1017,7 @@ class DenseQuant(Cell):
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(

View File

@ -276,17 +276,17 @@ class Optimizer(Cell):
learning_rate = float(learning_rate)
validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
return learning_rate
if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
return learning_rate
self.dynamic_lr = True
if isinstance(learning_rate, Iterable):
return Tensor(np.array(list(learning_rate)).astype(np.float32))
if isinstance(learning_rate, Tensor):
if learning_rate.dim() > 1:
if learning_rate.ndim > 1:
raise ValueError("The dim of `Tensor` type Learning rate should be a 0 or 1,"
f"but got {learning_rate.dim()}.")
if learning_rate.dim() == 1 and learning_rate.size() < 2:
f"but got {learning_rate.ndim}.")
if learning_rate.ndim == 1 and learning_rate.size < 2:
logger.warning("If use `Tensor` type dynamic learning rate, please make sure that the number"
"of elements in the tensor passed is greater than 1.")
return learning_rate
@ -301,12 +301,12 @@ class Optimizer(Cell):
if self.is_group_lr and self.dynamic_lr:
learning_rate = _ConvertToCell(learning_rate)
return learning_rate
if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
learning_rate = Parameter(learning_rate, name)
if self.is_group_lr and self.dynamic_lr:
learning_rate = _ConvertToCell(learning_rate)
return learning_rate
if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
return _IteratorLearningRate(learning_rate, name)
return learning_rate
@ -336,8 +336,8 @@ class Optimizer(Cell):
def _parse_group_params(self, parameters, learning_rate):
"""Parse group params."""
self._check_group_params(parameters)
if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
tensor_lr_length = learning_rate.size()
if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
tensor_lr_length = learning_rate.size
else:
tensor_lr_length = 0
@ -355,8 +355,8 @@ class Optimizer(Cell):
self.is_group_lr = True
group_lr = self._preprocess_single_lr(group_param['lr'])
if isinstance(group_lr, Tensor) and group_lr.dim() == 1:
group_lr_length = group_lr.size()
if isinstance(group_lr, Tensor) and group_lr.ndim == 1:
group_lr_length = group_lr.size
if tensor_lr_length == 0:
tensor_lr_length = group_lr_length
elif group_lr_length != tensor_lr_length:
@ -615,9 +615,9 @@ class _IteratorLearningRate(LearningRateSchedule):
def __init__(self, learning_rate, name):
super(_IteratorLearningRate, self).__init__()
if isinstance(learning_rate, Tensor):
if learning_rate.dim() != 1:
if learning_rate.ndim != 1:
raise ValueError("The dim of `Tensor` type dynamic learning rate should be a 1,"
f"but got {learning_rate.dim()}.")
f"but got {learning_rate.ndim}.")
else:
raise TypeError("Learning rate should be Tensor.")

View File

@ -143,7 +143,7 @@ class ExpandDims(PrimitiveWithInfer):
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
- **axis** (int) - Specifies the dimension index at which to expand
the shape of `input_x`. The value of axis must be in the range
`[-input_x.dim()-1, input_x.dim()]`. Only constant value is allowed.
`[-input_x.ndim-1, input_x.ndim]`. Only constant value is allowed.
Outputs:
Tensor, the shape of tensor is :math:`(1, x_1, x_2, ..., x_R)` if the
@ -597,7 +597,7 @@ class Squeeze(PrimitiveWithInfer):
Returns a tensor with the same type but dimensions of 1 are removed based on `axis`.
Note:
The dimension index starts at 0 and must be in the range `[-input.dim(), input.dim())`.
The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim`.
Raises:
ValueError: If the corresponding dimension of the specified axis does not equal to 1.

View File

@ -238,7 +238,7 @@ def _load_tensor_by_layout(tensor, layout):
group = layout[5]
if uniform_split == 0:
raise RuntimeError("The load tensor only support uniform split now")
if tensor.size() == 1:
if tensor.size == 1:
return tensor
tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
if group:

View File

@ -294,7 +294,7 @@ class Dense_Thor_GPU(Cell):
self.has_bias = Validator.check_bool(has_bias)
self.thor = True
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")
@ -302,7 +302,7 @@ class Dense_Thor_GPU(Cell):
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]))
@ -639,7 +639,7 @@ class Dense_Thor(Cell):
self.thor = True
self.batch_size = batch_size
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")
@ -647,7 +647,7 @@ class Dense_Thor(Cell):
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]))

View File

@ -50,16 +50,16 @@ def convert(weights_file, output_file):
var = params[i+2]
gamma = params[i+3]
beta = params[i+4]
beta_data = weights[index: index+beta.size()].reshape(beta.shape)
index += beta.size()
gamma_data = weights[index: index+gamma.size()].reshape(gamma.shape)
index += gamma.size()
mean_data = weights[index: index+mean.size()].reshape(mean.shape)
index += mean.size()
var_data = weights[index: index + var.size()].reshape(var.shape)
index += var.size()
weight_data = weights[index: index+weight.size()].reshape(weight.shape)
index += weight.size()
beta_data = weights[index: index+beta.size].reshape(beta.shape)
index += beta.size
gamma_data = weights[index: index+gamma.size].reshape(gamma.shape)
index += gamma.size
mean_data = weights[index: index+mean.size].reshape(mean.shape)
index += mean.size
var_data = weights[index: index + var.size].reshape(var.shape)
index += var.size
weight_data = weights[index: index+weight.size].reshape(weight.shape)
index += weight.size
param_list.append({'name': weight.name, 'type': weight.dtype, 'shape': weight.shape,
'data': Tensor(weight_data)})

View File

@ -76,7 +76,7 @@ class GNNFeatureTransform(nn.Cell):
self.has_bias = has_bias
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")
@ -84,7 +84,7 @@ class GNNFeatureTransform(nn.Cell):
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]))

View File

@ -162,7 +162,7 @@ class Dense_Thor(Cell):
self.has_bias = Validator.check_bool(has_bias)
self.thor = True
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
if weight_init.ndim != 2 or weight_init.shape()[0] != out_channels or \
weight_init.shape()[1] != in_channels:
raise ValueError("weight_init shape error")
@ -170,7 +170,7 @@ class Dense_Thor(Cell):
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
if bias_init.ndim != 1 or bias_init.shape()[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]))

View File

@ -43,7 +43,7 @@ class DenseLayer(nn.Cell):
self.has_bias = validator.check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
if weight_init.ndim != 2 or weight_init.shape()[0] != out_channels or \
weight_init.shape()[1] != in_channels:
raise ValueError("weight_init shape error")
@ -51,7 +51,7 @@ class DenseLayer(nn.Cell):
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
if bias_init.ndim != 1 or bias_init.shape()[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]))

View File

@ -78,7 +78,7 @@ class GNNFeatureTransform(nn.Cell):
self.has_bias = Validator.check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")
@ -86,7 +86,7 @@ class GNNFeatureTransform(nn.Cell):
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

View File

@ -342,7 +342,7 @@ class Dense_Thor(Cell):
self.has_bias = Validator.check_bool(has_bias)
self.thor = True
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")
@ -350,7 +350,7 @@ class Dense_Thor(Cell):
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

View File

@ -35,7 +35,7 @@ class Net(nn.Cell):
self.biasAdd = P.BiasAdd()
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != output_channels:
if bias_init.ndim != 1 or bias_init.shape[0] != output_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(

View File

@ -33,7 +33,7 @@ class Net(nn.Cell):
self.biasAdd = P.BiasAdd()
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != output_channels:
if bias_init.ndim != 1 or bias_init.shape[0] != output_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(

View File

@ -65,7 +65,7 @@ def test_bias_add(test_with_simu):
self.biasAdd = P.BiasAdd()
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != output_channels:
if bias_init.ndim != 1 or bias_init.shape[0] != output_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(

View File

@ -24,7 +24,7 @@ from ..ut_filter import non_graph_engine
def _attribute(tensor, shape_, size_, dtype_):
result = (tensor.shape == shape_) and \
(tensor.size() == size_) and \
(tensor.size == size_) and \
(tensor.dtype == dtype_)
return result
@ -60,13 +60,13 @@ def test_tensor_mul():
def test_tensor_dim():
arr = np.ones((1, 6))
b = ms.Tensor(arr)
assert b.dim() == 2
assert b.ndim == 2
def test_tensor_size():
arr = np.ones((1, 6))
b = ms.Tensor(arr)
assert arr.size == b.size()
assert arr.size == b.size
def test_dtype():