diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py
index 5cb4209698a..1622b8ac466 100644
--- a/mindspore/_checkparam.py
+++ b/mindspore/_checkparam.py
@@ -635,7 +635,7 @@ def check_input_data(*data, data_class):
                                  f' either a single'
                                  f' or a list of {data_class.__name__},'
                                  f' but got part data type is {str(type(item))}.')
-            if hasattr(item, "size") and item.size() == 0:
+            if hasattr(item, "size") and item.size == 0:
                 msg = "Please provide non-empty data."
                 logger.error(msg)
                 raise ValueError(msg)
diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py
index 8fb3d7f2f1a..8ab5da65458 100644
--- a/mindspore/_extends/parse/standard_method.py
+++ b/mindspore/_extends/parse/standard_method.py
@@ -31,6 +31,8 @@
 shape_ = P.Shape()
 reshape_ = P.Reshape()
 dtype_ = P.DType()
 abs_ = P.Abs()
+ndim_ = P.Rank()
+size_ = P.Size()
 
 def mean(x, axis=(), keep_dims=False):
     """
diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc
index c6ddb73fb6e..ff3168a3b83 100644
--- a/mindspore/ccsrc/pipeline/jit/resource.cc
+++ b/mindspore/ccsrc/pipeline/jit/resource.cc
@@ -192,6 +192,8 @@ BuiltInTypeMap &GetAttrMap() {
                     {
                       {"shape", std::string("shape_")},  // C.shape_
                       {"dtype", std::string("dtype_")},  // C.dtype_
+                      {"size", std::string("size_")},    // C.size_
+                      {"ndim", std::string("ndim_")},    // C.ndim_
                     }},
                    {kObjectTypeRowTensorType,
                     {
diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc
index 034cc99858e..1f51e6b51c3 100644
--- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc
+++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc
@@ -370,6 +370,17 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                              >>> data.shape()
                              (3, 3)
                          )mydelimiter")
+                         .def_property_readonly("_size", &Tensor::DataSize, R"mydelimiter(
+                             Get tensor's data size.
+
+                             Returns:
+                                 int, the size of tensor.
+
+                             Examples:
+                                 >>> data = mindspore.Tensor(np.ones((2, 3)))
+                                 >>> data.size
+                                 6
+                             )mydelimiter")
                          .def("from_numpy", TensorPy::MakeTensorNoCopy, R"mydelimiter(
                              Creates a Tensor from a numpy.ndarray without copy.
@@ -396,17 +407,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                              array([[1., 1., 1.],
                                     [1., 1., 1.]])
                          )mydelimiter")
-                         .def("size", &Tensor::DataSize, R"mydelimiter(
-                             Get tensor's data size.
-
-                             Returns:
-                                 int, the size of tensor.
-
-                             Examples:
-                                 >>> data = mindspore.Tensor(np.ones((2, 3)))
-                                 >>> data.size()
-                                 6
-                             )mydelimiter")
                          .def("is_init", &Tensor::is_init, R"mydelimiter(
                              Get tensor init_flag.
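For reviewers, this is the user-visible effect of the hunks above: `size` and `ndim` become read-only attributes on `Tensor`, in PyNative mode via the Python properties and in graph mode via the new GetAttrMap entries. A minimal sketch, assuming a build with this patch applied (`DoubleIf2D` is a hypothetical cell, not part of the patch):

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn

x = ms.Tensor(np.ones((2, 3), dtype=np.float32))

# PyNative mode: read-only properties replace the old method calls.
print(x.size)   # 6 -- was x.size()
print(x.ndim)   # 2 -- was x.dim()

class DoubleIf2D(nn.Cell):
    """Hypothetical cell: `x.ndim` is also usable inside construct,
    where the GetAttrMap entries above resolve it to P.Rank()."""
    def construct(self, x):
        if x.ndim == 2:
            return x * 2
        return x

print(DoubleIf2D()(x))  # input is 2-D, so it is doubled
```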
diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py
index 0fe5aea7485..c21b614bea6 100644
--- a/mindspore/common/tensor.py
+++ b/mindspore/common/tensor.py
@@ -237,6 +237,16 @@ class Tensor(Tensor_):
         """The dtype of tensor is a mindspore type."""
         return self._dtype
 
+    @property
+    def size(self):
+        """The size reflects the total number of elements in tensor."""
+        return self._size
+
+    @property
+    def ndim(self):
+        """The ndim of tensor is an integer."""
+        return len(self._shape)
+
     @property
     def virtual_flag(self):
         """Mark tensor is virtual."""
diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index c412c6c4b3d..31acf3bc7ca 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -277,7 +277,7 @@ class Dense(Cell):
         self.shape_op = P.Shape()
 
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("Weight init shape error.")
         self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
@@ -285,7 +285,7 @@ class Dense(Cell):
         self.bias = None
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("Bias init shape error.")
             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
             self.bias_add = P.BiasAdd()
diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 6cf489754a8..bcecbdc30c2 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -318,7 +318,7 @@ class BatchNorm1d(_BatchNorm):
                                           input_dims='1d')
 
     def _check_data_dim(self, x):
-        if x.dim() != 2:
+        if x.ndim != 2:
             pass
 
@@ -415,7 +415,7 @@ class BatchNorm2d(_BatchNorm):
                                           data_format=data_format)
 
     def _check_data_dim(self, x):
-        if x.dim() != 4:
+        if x.ndim != 4:
             pass
 
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 5843a53cd6c..608d6a24e60 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -1013,7 +1013,7 @@ class DenseQuant(Cell):
         self.has_bias = Validator.check_bool(has_bias)
 
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")
 
@@ -1022,7 +1022,7 @@ class DenseQuant(Cell):
 
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")
 
             self.bias = Parameter(initializer(
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index 02c2962faa7..60aa4ad31c1 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -278,17 +278,17 @@ class Optimizer(Cell):
             learning_rate = float(learning_rate)
             validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
             return learning_rate
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
+        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
             return learning_rate
 
         self.dynamic_lr = True
         if isinstance(learning_rate, Iterable):
             return Tensor(np.array(list(learning_rate)).astype(np.float32))
         if isinstance(learning_rate, Tensor):
-            if learning_rate.dim() > 1:
+            if learning_rate.ndim > 1:
                 raise ValueError("The dim of `Tensor` type Learning rate should be a 0 or 1,"
-                                 f"but got {learning_rate.dim()}.")
-            if learning_rate.dim() == 1 and learning_rate.size() < 2:
+                                 f"but got {learning_rate.ndim}.")
+            if learning_rate.ndim == 1 and learning_rate.size < 2:
                 logger.warning("If use `Tensor` type dynamic learning rate, please make sure that the number"
                                "of elements in the tensor passed is greater than 1.")
             return learning_rate
@@ -303,12 +303,12 @@ class Optimizer(Cell):
             if self.is_group_lr and self.dynamic_lr:
                 learning_rate = _ConvertToCell(learning_rate)
             return learning_rate
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
+        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
             learning_rate = Parameter(learning_rate, name)
             if self.is_group_lr and self.dynamic_lr:
                 learning_rate = _ConvertToCell(learning_rate)
             return learning_rate
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
+        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
             return _IteratorLearningRate(learning_rate, name)
         return learning_rate
@@ -338,8 +338,8 @@ class Optimizer(Cell):
     def _parse_group_params(self, parameters, learning_rate):
         """Parse group params."""
         self._check_group_params(parameters)
-        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
-            tensor_lr_length = learning_rate.size()
+        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
+            tensor_lr_length = learning_rate.size
         else:
             tensor_lr_length = 0
 
@@ -357,8 +357,8 @@ class Optimizer(Cell):
                 self.is_group_lr = True
                 group_lr = self._preprocess_single_lr(group_param['lr'])
-                if isinstance(group_lr, Tensor) and group_lr.dim() == 1:
-                    group_lr_length = group_lr.size()
+                if isinstance(group_lr, Tensor) and group_lr.ndim == 1:
+                    group_lr_length = group_lr.size
                 if tensor_lr_length == 0:
                     tensor_lr_length = group_lr_length
                 elif group_lr_length != tensor_lr_length:
@@ -617,9 +617,9 @@ class _IteratorLearningRate(LearningRateSchedule):
     def __init__(self, learning_rate, name):
         super(_IteratorLearningRate, self).__init__()
         if isinstance(learning_rate, Tensor):
-            if learning_rate.dim() != 1:
+            if learning_rate.ndim != 1:
                 raise ValueError("The dim of `Tensor` type dynamic learning rate should be a 1,"
-                                 f"but got {learning_rate.dim()}.")
+                                 f"but got {learning_rate.ndim}.")
         else:
             raise TypeError("Learning rate should be Tensor.")
 
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 83adb3e2b95..e6c45ab9e12 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -143,7 +143,7 @@ class ExpandDims(PrimitiveWithInfer):
         - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
         - **axis** (int) - Specifies the dimension index at which to expand
           the shape of `input_x`. The value of axis must be in the range
-          `[-input_x.dim()-1, input_x.dim()]`. Only constant value is allowed.
+          `[-input_x.ndim-1, input_x.ndim]`. Only constant value is allowed.
 
     Outputs:
         Tensor, the shape of tensor is :math:`(1, x_1, x_2, ..., x_R)` if the
@@ -597,7 +597,7 @@ class Squeeze(PrimitiveWithInfer):
     Returns a tensor with the same type but dimensions of 1 are removed based on `axis`.
 
     Note:
-        The dimension index starts at 0 and must be in the range `[-input.dim(), input.dim())`.
+        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.
 
     Raises:
         ValueError: If the corresponding dimension of the specified axis does not equal to 1.
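The optimizer hunks above encode a convention worth spelling out: a 0-dimensional tensor is a fixed learning rate, while a 1-dimensional tensor is a per-step schedule and should hold more than one element. A sketch with hypothetical values, assuming the patched properties:

```python
import numpy as np
import mindspore as ms

fixed_lr = ms.Tensor(0.1, ms.float32)  # 0-d tensor: used as a static learning rate
schedule = ms.Tensor(np.array([0.1, 0.01, 0.001], dtype=np.float32))  # 1-d: dynamic

assert fixed_lr.ndim == 0       # takes the ndim == 0 early-return branch
assert schedule.ndim == 1
assert schedule.size == 3       # size >= 2, so no warning is logged
```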
diff --git a/mindspore/parallel/_tensor.py b/mindspore/parallel/_tensor.py
index 3179ab392f6..0782e5211f6 100644
--- a/mindspore/parallel/_tensor.py
+++ b/mindspore/parallel/_tensor.py
@@ -238,7 +238,7 @@ def _load_tensor_by_layout(tensor, layout):
     group = layout[5]
     if uniform_split == 0:
         raise RuntimeError("The load tensor only support uniform split now")
-    if tensor.size() == 1:
+    if tensor.size == 1:
         return tensor
     tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
     if group:
diff --git a/model_zoo/official/cv/resnet_thor/src/thor_layer.py b/model_zoo/official/cv/resnet_thor/src/thor_layer.py
index 27e90bfe0e0..3858d5a6b5f 100644
--- a/model_zoo/official/cv/resnet_thor/src/thor_layer.py
+++ b/model_zoo/official/cv/resnet_thor/src/thor_layer.py
@@ -294,7 +294,7 @@ class Dense_Thor_GPU(Cell):
         self.has_bias = Validator.check_bool(has_bias)
         self.thor = True
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")
 
@@ -302,7 +302,7 @@ class Dense_Thor_GPU(Cell):
 
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")
 
             self.bias = Parameter(initializer(bias_init, [out_channels]))
@@ -639,7 +639,7 @@ class Dense_Thor(Cell):
         self.thor = True
         self.batch_size = batch_size
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")
 
@@ -647,7 +647,7 @@ class Dense_Thor(Cell):
 
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")
 
             self.bias = Parameter(initializer(bias_init, [out_channels]))
diff --git a/model_zoo/official/cv/yolov3_darknet53/src/convert_weight.py b/model_zoo/official/cv/yolov3_darknet53/src/convert_weight.py
index e5d10e313ba..e63a5bc32e3 100644
--- a/model_zoo/official/cv/yolov3_darknet53/src/convert_weight.py
+++ b/model_zoo/official/cv/yolov3_darknet53/src/convert_weight.py
@@ -50,16 +50,16 @@ def convert(weights_file, output_file):
             var = params[i+2]
             gamma = params[i+3]
             beta = params[i+4]
-            beta_data = weights[index: index+beta.size()].reshape(beta.shape)
-            index += beta.size()
-            gamma_data = weights[index: index+gamma.size()].reshape(gamma.shape)
-            index += gamma.size()
-            mean_data = weights[index: index+mean.size()].reshape(mean.shape)
-            index += mean.size()
-            var_data = weights[index: index + var.size()].reshape(var.shape)
-            index += var.size()
-            weight_data = weights[index: index+weight.size()].reshape(weight.shape)
-            index += weight.size()
+            beta_data = weights[index: index+beta.size].reshape(beta.shape)
+            index += beta.size
+            gamma_data = weights[index: index+gamma.size].reshape(gamma.shape)
+            index += gamma.size
+            mean_data = weights[index: index+mean.size].reshape(mean.shape)
+            index += mean.size
+            var_data = weights[index: index + var.size].reshape(var.shape)
+            index += var.size
+            weight_data = weights[index: index+weight.size].reshape(weight.shape)
+            index += weight.size
             param_list.append({'name': weight.name, 'type': weight.dtype, 'shape': weight.shape,
                                'data': Tensor(weight_data)})
diff --git a/model_zoo/official/gnn/gat/src/gat.py b/model_zoo/official/gnn/gat/src/gat.py
index 1bb1402576a..e6c5faeaf11 100644
--- a/model_zoo/official/gnn/gat/src/gat.py
+++ b/model_zoo/official/gnn/gat/src/gat.py
@@ -76,7 +76,7 @@ class GNNFeatureTransform(nn.Cell):
         self.has_bias = has_bias
 
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")
 
@@ -84,7 +84,7 @@ class GNNFeatureTransform(nn.Cell):
 
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")
 
             self.bias = Parameter(initializer(bias_init, [out_channels]))
diff --git a/model_zoo/official/nlp/bert_thor/src/thor_layer.py b/model_zoo/official/nlp/bert_thor/src/thor_layer.py
index 950e038f28a..3fee08a91af 100644
--- a/model_zoo/official/nlp/bert_thor/src/thor_layer.py
+++ b/model_zoo/official/nlp/bert_thor/src/thor_layer.py
@@ -162,7 +162,7 @@ class Dense_Thor(Cell):
         self.has_bias = Validator.check_bool(has_bias)
         self.thor = True
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape()[0] != out_channels or \
                     weight_init.shape()[1] != in_channels:
                 raise ValueError("weight_init shape error")
 
@@ -170,7 +170,7 @@ class Dense_Thor(Cell):
 
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape()[0] != out_channels:
                     raise ValueError("bias_init shape error")
 
             self.bias = Parameter(initializer(bias_init, [out_channels]))
diff --git a/model_zoo/official/recommend/ncf/src/ncf.py b/model_zoo/official/recommend/ncf/src/ncf.py
index bae0607fbfb..6a1d5a3961e 100644
--- a/model_zoo/official/recommend/ncf/src/ncf.py
+++ b/model_zoo/official/recommend/ncf/src/ncf.py
@@ -45,7 +45,7 @@ class DenseLayer(nn.Cell):
         self.has_bias = validator.check_bool(has_bias)
 
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape()[0] != out_channels or \
                     weight_init.shape()[1] != in_channels:
                 raise ValueError("weight_init shape error")
 
@@ -53,7 +53,7 @@ class DenseLayer(nn.Cell):
 
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape()[0] != out_channels:
                     raise ValueError("bias_init shape error")
 
             self.bias = Parameter(initializer(bias_init, [out_channels]))
diff --git a/tests/st/gnn/aggregator.py b/tests/st/gnn/aggregator.py
index 829901cdbcb..9f81c616fa2 100644
--- a/tests/st/gnn/aggregator.py
+++ b/tests/st/gnn/aggregator.py
@@ -78,7 +78,7 @@ class GNNFeatureTransform(nn.Cell):
         self.has_bias = Validator.check_bool(has_bias)
 
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")
 
@@ -86,7 +86,7 @@ class GNNFeatureTransform(nn.Cell):
 
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")
 
             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
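The convert_weight.py hunk works because `size`, like `numpy.ndarray.size`, now counts elements as a plain attribute rather than a method. The slicing pattern in isolation, sketched with plain numpy and hypothetical shapes:

```python
import numpy as np

weights = np.arange(12, dtype=np.float32)  # flat buffer read from the darknet file
beta = np.zeros((4,), dtype=np.float32)    # stands in for a (4,)-shaped Parameter

index = 0
# `size` is a property: beta.size, not beta.size().
beta_data = weights[index: index + beta.size].reshape(beta.shape)
index += beta.size
assert beta_data.shape == beta.shape and index == 4
```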
diff --git a/tests/st/networks/models/resnet50/src_thor/thor_layer.py b/tests/st/networks/models/resnet50/src_thor/thor_layer.py
index 5442927ea47..b644a834ff7 100644
--- a/tests/st/networks/models/resnet50/src_thor/thor_layer.py
+++ b/tests/st/networks/models/resnet50/src_thor/thor_layer.py
@@ -342,7 +342,7 @@ class Dense_Thor(Cell):
         self.has_bias = Validator.check_bool(has_bias)
         self.thor = True
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")
 
@@ -350,7 +350,7 @@ class Dense_Thor(Cell):
 
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")
 
             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
diff --git a/tests/st/ops/ascend/test_tbe_ops/test_bias_add.py b/tests/st/ops/ascend/test_tbe_ops/test_bias_add.py
index 9455b7902b1..87dd5c93aee 100644
--- a/tests/st/ops/ascend/test_tbe_ops/test_bias_add.py
+++ b/tests/st/ops/ascend/test_tbe_ops/test_bias_add.py
@@ -35,7 +35,7 @@ class Net(nn.Cell):
         self.biasAdd = P.BiasAdd()
 
         if isinstance(bias_init, Tensor):
-            if bias_init.dim() != 1 or bias_init.shape[0] != output_channels:
+            if bias_init.ndim != 1 or bias_init.shape[0] != output_channels:
                 raise ValueError("bias_init shape error")
 
         self.bias = Parameter(initializer(
diff --git a/tests/ut/python/exec/test_bias_add.py b/tests/ut/python/exec/test_bias_add.py
index 2349c61a9b9..c0f9d31491d 100644
--- a/tests/ut/python/exec/test_bias_add.py
+++ b/tests/ut/python/exec/test_bias_add.py
@@ -33,7 +33,7 @@ class Net(nn.Cell):
         self.biasAdd = P.BiasAdd()
 
         if isinstance(bias_init, Tensor):
-            if bias_init.dim() != 1 or bias_init.shape[0] != output_channels:
+            if bias_init.ndim != 1 or bias_init.shape[0] != output_channels:
                 raise ValueError("bias_init shape error")
 
         self.bias = Parameter(initializer(
diff --git a/tests/ut/python/exec/test_train.py b/tests/ut/python/exec/test_train.py
index 618ad3c0341..06bcea061bd 100644
--- a/tests/ut/python/exec/test_train.py
+++ b/tests/ut/python/exec/test_train.py
@@ -65,7 +65,7 @@ def test_bias_add(test_with_simu):
             self.biasAdd = P.BiasAdd()
 
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape[0] != output_channels:
+                if bias_init.ndim != 1 or bias_init.shape[0] != output_channels:
                     raise ValueError("bias_init shape error")
 
             self.bias = Parameter(initializer(
diff --git a/tests/ut/python/ir/test_tensor_py.py b/tests/ut/python/ir/test_tensor_py.py
index 2aa08aa4efe..138c52ce872 100644
--- a/tests/ut/python/ir/test_tensor_py.py
+++ b/tests/ut/python/ir/test_tensor_py.py
@@ -24,7 +24,7 @@ from ..ut_filter import non_graph_engine
 
 def _attribute(tensor, shape_, size_, dtype_):
     result = (tensor.shape == shape_) and \
-             (tensor.size() == size_) and \
+             (tensor.size == size_) and \
              (tensor.dtype == dtype_)
     return result
 
@@ -60,13 +60,13 @@ def test_tensor_mul():
 def test_tensor_dim():
     arr = np.ones((1, 6))
     b = ms.Tensor(arr)
-    assert b.dim() == 2
+    assert b.ndim == 2
 
 
 def test_tensor_size():
     arr = np.ones((1, 6))
     b = ms.Tensor(arr)
-    assert arr.size == b.size()
+    assert arr.size == b.size
 
 
 def test_dtype():
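Taken together, the updated unit tests pin the Tensor attributes to their numpy counterparts. A compact parity check, assuming the patched build:

```python
import numpy as np
import mindspore as ms

arr = np.ones((1, 6))
t = ms.Tensor(arr)

# Attribute access now mirrors numpy attribute-for-attribute.
assert t.ndim == arr.ndim == 2
assert t.size == arr.size == 6
assert t.shape == arr.shape == (1, 6)
```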