forked from mindspore-Ecosystem/mindspore
add tensor.ndim and rename tensor.size() to tensor.size
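In short, this commit turns Tensor.size into a property and adds a Tensor.ndim property, so call sites drop the parentheses and tensor.dim() call sites move to tensor.ndim. A minimal usage sketch of the API after this change (values follow the docstring example in the diff; the import style is an assumption):

    import numpy as np
    import mindspore as ms

    t = ms.Tensor(np.ones((2, 3)))
    t.shape   # (2, 3), unchanged
    t.size    # 6, total number of elements; previously t.size()
    t.ndim    # 2, number of dimensions; previously t.dim()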
parent f3aee78bde
commit 918679daa3
@@ -635,7 +635,7 @@ def check_input_data(*data, data_class):
                              f' either a single'
                              f' or a list of {data_class.__name__},'
                              f' but got part data type is {str(type(item))}.')
-        if hasattr(item, "size") and item.size() == 0:
+        if hasattr(item, "size") and item.size == 0:
             msg = "Please provide non-empty data."
             logger.error(msg)
             raise ValueError(msg)

@@ -31,6 +31,8 @@ shape_ = P.Shape()
 reshape_ = P.Reshape()
 dtype_ = P.DType()
 abs_ = P.Abs()
+ndim_ = P.Rank()
+size_ = P.Size()

 def mean(x, axis=(), keep_dims=False):
     """

@@ -192,6 +192,8 @@ BuiltInTypeMap &GetAttrMap() {
      {
        {"shape", std::string("shape_")},  // C.shape_
        {"dtype", std::string("dtype_")},  // C.dtype_
+      {"size", std::string("size_")},    // C.size_
+      {"ndim", std::string("ndim_")},    // C.ndim_
      }},
     {kObjectTypeRowTensorType,
      {

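Registering "size" and "ndim" in GetAttrMap pairs with the ndim_ = P.Rank() and size_ = P.Size() composites added above, so the attributes also resolve inside compiled graphs. A hedged sketch of that effect (the Cell below is illustrative and not part of this commit):

    import numpy as np
    import mindspore as ms
    from mindspore import nn

    class ShapeInfo(nn.Cell):  # hypothetical example Cell
        def construct(self, x):
            # x.ndim and x.size resolve through the attribute map
            # to C.ndim_ (P.Rank) and C.size_ (P.Size) in graph mode.
            return x.ndim, x.size

    net = ShapeInfo()
    print(net(ms.Tensor(np.ones((2, 3), np.float32))))  # expected: (2, 6)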
@@ -370,6 +370,17 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                              >>> data.shape()
                              (3, 3)
                          )mydelimiter")
+                         .def_property_readonly("_size", &Tensor::DataSize, R"mydelimiter(
+                             Get tensor's data size.
+
+                             Returns:
+                                 int, the size of tensor.
+
+                             Examples:
+                                 >>> data = mindspore.Tensor(np.ones((2, 3)))
+                                 >>> data.size
+                                 6
+                             )mydelimiter")
                          .def("from_numpy", TensorPy::MakeTensorNoCopy, R"mydelimiter(
                              Creates a Tensor from a numpy.ndarray without copy.

@@ -396,17 +407,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                              array([[1., 1., 1.],
                                     [1., 1., 1.]])
                          )mydelimiter")
-                         .def("size", &Tensor::DataSize, R"mydelimiter(
-                             Get tensor's data size.
-
-                             Returns:
-                                 int, the size of tensor.
-
-                             Examples:
-                                 >>> data = mindspore.Tensor(np.ones((2, 3)))
-                                 >>> data.size()
-                                 6
-                             )mydelimiter")
                          .def("is_init", &Tensor::is_init, R"mydelimiter(
                              Get tensor init_flag.

@@ -237,6 +237,16 @@ class Tensor(Tensor_):
         """The dtype of tensor is a mindspore type."""
         return self._dtype

+    @property
+    def size(self):
+        """The size reflects the total number of elements in tensor."""
+        return self._size
+
+    @property
+    def ndim(self):
+        """The ndim of tensor is an integer."""
+        return len(self._shape)
+
     @property
     def virtual_flag(self):
         """Mark tensor is virtual."""

@@ -277,7 +277,7 @@ class Dense(Cell):
         self.shape_op = P.Shape()

         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("Weight init shape error.")
         self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

@@ -285,7 +285,7 @@ class Dense(Cell):
         self.bias = None
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("Bias init shape error.")
             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
             self.bias_add = P.BiasAdd()

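The Dense checks above now read the ndim property of user-supplied init tensors. As a concrete sketch (channel sizes are assumed for illustration), a Dense layer with in_channels=3 and out_channels=4 accepts a 2-D weight_init of shape (4, 3) and a 1-D bias_init of shape (4,):

    import numpy as np
    import mindspore as ms
    from mindspore import nn

    w = ms.Tensor(np.ones((4, 3), np.float32))  # w.ndim == 2, w.shape[0] == out_channels
    b = ms.Tensor(np.ones((4,), np.float32))    # b.ndim == 1, b.shape[0] == out_channels
    dense = nn.Dense(3, 4, weight_init=w, bias_init=b)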
@@ -318,7 +318,7 @@ class BatchNorm1d(_BatchNorm):
                                          input_dims='1d')

     def _check_data_dim(self, x):
-        if x.dim() != 2:
+        if x.ndim != 2:
             pass

@@ -415,7 +415,7 @@ class BatchNorm2d(_BatchNorm):
                                          data_format=data_format)

     def _check_data_dim(self, x):
-        if x.dim() != 4:
+        if x.ndim != 4:
             pass

@@ -1008,7 +1008,7 @@ class DenseQuant(Cell):
         self.has_bias = Validator.check_bool(has_bias)

         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")

@@ -1017,7 +1017,7 @@ class DenseQuant(Cell):

         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")

             self.bias = Parameter(initializer(

@@ -276,17 +276,17 @@ class Optimizer(Cell):
             learning_rate = float(learning_rate)
             validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
             return learning_rate
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
+        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
             return learning_rate

         self.dynamic_lr = True
         if isinstance(learning_rate, Iterable):
             return Tensor(np.array(list(learning_rate)).astype(np.float32))
         if isinstance(learning_rate, Tensor):
-            if learning_rate.dim() > 1:
+            if learning_rate.ndim > 1:
                 raise ValueError("The dim of `Tensor` type Learning rate should be a 0 or 1,"
-                                 f"but got {learning_rate.dim()}.")
-            if learning_rate.dim() == 1 and learning_rate.size() < 2:
+                                 f"but got {learning_rate.ndim}.")
+            if learning_rate.ndim == 1 and learning_rate.size < 2:
                 logger.warning("If use `Tensor` type dynamic learning rate, please make sure that the number"
                                "of elements in the tensor passed is greater than 1.")
             return learning_rate

@@ -301,12 +301,12 @@ class Optimizer(Cell):
             if self.is_group_lr and self.dynamic_lr:
                 learning_rate = _ConvertToCell(learning_rate)
             return learning_rate
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
+        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
             learning_rate = Parameter(learning_rate, name)
             if self.is_group_lr and self.dynamic_lr:
                 learning_rate = _ConvertToCell(learning_rate)
             return learning_rate
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
+        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
             return _IteratorLearningRate(learning_rate, name)
         return learning_rate

@@ -336,8 +336,8 @@ class Optimizer(Cell):
     def _parse_group_params(self, parameters, learning_rate):
         """Parse group params."""
         self._check_group_params(parameters)
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
-            tensor_lr_length = learning_rate.size()
+        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
+            tensor_lr_length = learning_rate.size
         else:
             tensor_lr_length = 0

@@ -355,8 +355,8 @@ class Optimizer(Cell):
                 self.is_group_lr = True
                 group_lr = self._preprocess_single_lr(group_param['lr'])

-                if isinstance(group_lr, Tensor) and group_lr.dim() == 1:
-                    group_lr_length = group_lr.size()
+                if isinstance(group_lr, Tensor) and group_lr.ndim == 1:
+                    group_lr_length = group_lr.size
                 if tensor_lr_length == 0:
                     tensor_lr_length = group_lr_length
                 elif group_lr_length != tensor_lr_length:

@@ -615,9 +615,9 @@ class _IteratorLearningRate(LearningRateSchedule):
     def __init__(self, learning_rate, name):
         super(_IteratorLearningRate, self).__init__()
         if isinstance(learning_rate, Tensor):
-            if learning_rate.dim() != 1:
+            if learning_rate.ndim != 1:
                 raise ValueError("The dim of `Tensor` type dynamic learning rate should be a 1,"
-                                 f"but got {learning_rate.dim()}.")
+                                 f"but got {learning_rate.ndim}.")
         else:
             raise TypeError("Learning rate should be Tensor.")

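Taken together, the Optimizer changes above classify a Tensor learning rate by the new properties: a 0-D Tensor (ndim == 0) acts as a fixed scalar rate, while a 1-D Tensor (ndim == 1) is treated as a per-step dynamic schedule, with a warning if it holds fewer than two elements (size < 2). A small sketch of the two accepted forms (values are assumptions):

    import numpy as np
    import mindspore as ms

    fixed_lr = ms.Tensor(0.01, ms.float32)                    # ndim == 0, static rate
    schedule = ms.Tensor(np.array([0.1, 0.01], np.float32))   # ndim == 1, size == 2, dynamic rate
    assert fixed_lr.ndim == 0
    assert schedule.ndim == 1 and schedule.size == 2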
@@ -143,7 +143,7 @@ class ExpandDims(PrimitiveWithInfer):
         - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
         - **axis** (int) - Specifies the dimension index at which to expand
           the shape of `input_x`. The value of axis must be in the range
-          `[-input_x.dim()-1, input_x.dim()]`. Only constant value is allowed.
+          `[-input_x.ndim-1, input_x.ndim]`. Only constant value is allowed.

     Outputs:
         Tensor, the shape of tensor is :math:`(1, x_1, x_2, ..., x_R)` if the

@@ -597,7 +597,7 @@ class Squeeze(PrimitiveWithInfer):
     Returns a tensor with the same type but dimensions of 1 are removed based on `axis`.

     Note:
-        The dimension index starts at 0 and must be in the range `[-input.dim(), input.dim())`.
+        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.

     Raises:
         ValueError: If the corresponding dimension of the specified axis does not equal to 1.

@@ -238,7 +238,7 @@ def _load_tensor_by_layout(tensor, layout):
     group = layout[5]
     if uniform_split == 0:
         raise RuntimeError("The load tensor only support uniform split now")
-    if tensor.size() == 1:
+    if tensor.size == 1:
         return tensor
     tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
     if group:

@@ -294,7 +294,7 @@ class Dense_Thor_GPU(Cell):
         self.has_bias = Validator.check_bool(has_bias)
         self.thor = True
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")

@@ -302,7 +302,7 @@ class Dense_Thor_GPU(Cell):

         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")

             self.bias = Parameter(initializer(bias_init, [out_channels]))

@@ -639,7 +639,7 @@ class Dense_Thor(Cell):
         self.thor = True
         self.batch_size = batch_size
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")

@@ -647,7 +647,7 @@ class Dense_Thor(Cell):

         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")

             self.bias = Parameter(initializer(bias_init, [out_channels]))

@@ -50,16 +50,16 @@ def convert(weights_file, output_file):
             var = params[i+2]
             gamma = params[i+3]
             beta = params[i+4]
-            beta_data = weights[index: index+beta.size()].reshape(beta.shape)
-            index += beta.size()
-            gamma_data = weights[index: index+gamma.size()].reshape(gamma.shape)
-            index += gamma.size()
-            mean_data = weights[index: index+mean.size()].reshape(mean.shape)
-            index += mean.size()
-            var_data = weights[index: index + var.size()].reshape(var.shape)
-            index += var.size()
-            weight_data = weights[index: index+weight.size()].reshape(weight.shape)
-            index += weight.size()
+            beta_data = weights[index: index+beta.size].reshape(beta.shape)
+            index += beta.size
+            gamma_data = weights[index: index+gamma.size].reshape(gamma.shape)
+            index += gamma.size
+            mean_data = weights[index: index+mean.size].reshape(mean.shape)
+            index += mean.size
+            var_data = weights[index: index + var.size].reshape(var.shape)
+            index += var.size
+            weight_data = weights[index: index+weight.size].reshape(weight.shape)
+            index += weight.size

             param_list.append({'name': weight.name, 'type': weight.dtype, 'shape': weight.shape,
                                'data': Tensor(weight_data)})

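The converter above walks a flat weights buffer, consuming param.size values per parameter and reshaping them to param.shape; with size now a property, the arithmetic is unchanged. A generic sketch of that slicing pattern (the helper name and sample shapes are hypothetical, not part of the commit):

    import numpy as np

    def take_param(flat, offset, shape):
        # Slice the next prod(shape) values from the flat buffer and reshape them.
        count = int(np.prod(shape))
        return flat[offset:offset + count].reshape(shape), offset + count

    flat = np.arange(12, dtype=np.float32)
    w, offset = take_param(flat, 0, (2, 3))     # consumes 6 values
    b, offset = take_param(flat, offset, (6,))  # consumes the remaining 6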
@@ -76,7 +76,7 @@ class GNNFeatureTransform(nn.Cell):
         self.has_bias = has_bias

         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")

@@ -84,7 +84,7 @@ class GNNFeatureTransform(nn.Cell):

         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")

             self.bias = Parameter(initializer(bias_init, [out_channels]))

@@ -162,7 +162,7 @@ class Dense_Thor(Cell):
         self.has_bias = Validator.check_bool(has_bias)
         self.thor = True
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape()[0] != out_channels or \
                     weight_init.shape()[1] != in_channels:
                 raise ValueError("weight_init shape error")

@@ -170,7 +170,7 @@ class Dense_Thor(Cell):

         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape()[0] != out_channels:
                     raise ValueError("bias_init shape error")

             self.bias = Parameter(initializer(bias_init, [out_channels]))

@@ -43,7 +43,7 @@ class DenseLayer(nn.Cell):
         self.has_bias = validator.check_bool(has_bias)

         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape()[0] != out_channels or \
                     weight_init.shape()[1] != in_channels:
                 raise ValueError("weight_init shape error")

@@ -51,7 +51,7 @@ class DenseLayer(nn.Cell):

         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape()[0] != out_channels:
                     raise ValueError("bias_init shape error")

             self.bias = Parameter(initializer(bias_init, [out_channels]))

@@ -78,7 +78,7 @@ class GNNFeatureTransform(nn.Cell):
         self.has_bias = Validator.check_bool(has_bias)

         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")

@@ -86,7 +86,7 @@ class GNNFeatureTransform(nn.Cell):

         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")

             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

@@ -342,7 +342,7 @@ class Dense_Thor(Cell):
         self.has_bias = Validator.check_bool(has_bias)
         self.thor = True
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")

@@ -350,7 +350,7 @@ class Dense_Thor(Cell):

         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")

             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

@@ -35,7 +35,7 @@ class Net(nn.Cell):
         self.biasAdd = P.BiasAdd()

         if isinstance(bias_init, Tensor):
-            if bias_init.dim() != 1 or bias_init.shape[0] != output_channels:
+            if bias_init.ndim != 1 or bias_init.shape[0] != output_channels:
                 raise ValueError("bias_init shape error")

         self.bias = Parameter(initializer(

@@ -33,7 +33,7 @@ class Net(nn.Cell):
         self.biasAdd = P.BiasAdd()

         if isinstance(bias_init, Tensor):
-            if bias_init.dim() != 1 or bias_init.shape[0] != output_channels:
+            if bias_init.ndim != 1 or bias_init.shape[0] != output_channels:
                 raise ValueError("bias_init shape error")

         self.bias = Parameter(initializer(

@@ -65,7 +65,7 @@ def test_bias_add(test_with_simu):
             self.biasAdd = P.BiasAdd()

             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != output_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != output_channels:
                     raise ValueError("bias_init shape error")

             self.bias = Parameter(initializer(

@@ -24,7 +24,7 @@ from ..ut_filter import non_graph_engine

 def _attribute(tensor, shape_, size_, dtype_):
     result = (tensor.shape == shape_) and \
-             (tensor.size() == size_) and \
+             (tensor.size == size_) and \
              (tensor.dtype == dtype_)
     return result

@@ -60,13 +60,13 @@ def test_tensor_mul():
 def test_tensor_dim():
     arr = np.ones((1, 6))
     b = ms.Tensor(arr)
-    assert b.dim() == 2
+    assert b.ndim == 2


 def test_tensor_size():
     arr = np.ones((1, 6))
     b = ms.Tensor(arr)
-    assert arr.size == b.size()
+    assert arr.size == b.size


 def test_dtype():