diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index bb8fc2e02cf..cbff5f00910 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ + """basic""" + import numpy as np import mindspore.common.dtype as mstype from mindspore.common.seed import get_seed @@ -28,7 +30,6 @@ from mindspore.common.parameter import Parameter from mindspore._extends import cell_attr_register from mindspore.common.api import ms_function from mindspore import context -from mindspore.ops import _selected_ops from ..cell import Cell from .activation import get_activation from ..._checkparam import Validator as validator @@ -139,10 +140,8 @@ class Flatten(Cell): the product of the remaining dimensions. Examples: - >>> net = nn.Flatten() >>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32) - >>> input.shape - (2, 2, 2) + >>> net = nn.Flatten() >>> net(input) [[1.2 1.2 2.1 2.1] [2.2 2.2 3.2 3.2]] @@ -157,9 +156,9 @@ class Flatten(Cell): class Dense(Cell): r""" - The fully connected layer. + The dense connected layer. - Applies dense-connected layer for the input. This layer implements the operation as: + Applies dense connected layer for the input. This layer implements the operation as: .. math:: \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}), @@ -190,8 +189,8 @@ class Dense(Cell): Tensor of shape :math:`(N, out\_channels)`. Examples: - >>> net = nn.Dense(3, 4) >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) + >>> net = nn.Dense(3, 4) >>> net(input) [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ] [ 1.0739875 4.0155234 0.94188046 -5.459526 ]] @@ -212,41 +211,36 @@ class Dense(Cell): if isinstance(weight_init, Tensor): if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ weight_init.shape[1] != in_channels: - raise ValueError("weight_init shape error") - + raise ValueError("Weight init shape error.") self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") + self.bias = None if self.has_bias: if isinstance(bias_init, Tensor): if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: - raise ValueError("bias_init shape error") - + raise ValueError("Bias init shape error.") self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") + self.bias_add = P.BiasAdd() self.matmul = P.MatMul(transpose_b=True) - self.bias_add = _selected_ops.BiasAdd() - self.activation = get_activation(activation) self.activation_flag = self.activation is not None def construct(self, x): - output = self.matmul(x, self.weight) + x = self.matmul(x, self.weight) if self.has_bias: - output = self.bias_add(output, self.bias) + x = self.bias_add(x, self.bias) if self.activation_flag: - return self.activation(output) - return output + x = self.activation(x) + return x def extend_repr(self): - str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \ - .format(self.in_channels, self.out_channels, self.weight, self.has_bias) + s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels) if self.has_bias: - str_info = str_info + ', bias={}'.format(self.bias) - + s += ', has_bias={}'.format(self.has_bias) if self.activation_flag: - str_info = str_info + ', activation={}'.format(self.activation) - - return str_info + s += ', activation={}'.fomat(self.activation) + return s @constexpr diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 4f7664003c3..27b0541e779 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -611,9 +611,9 @@ class CumProd(PrimitiveWithInfer): class MatMul(PrimitiveWithInfer): """ - Multiplies matrix `a` by matrix `b`. + Multiplies matrix `a` and matrix `b`. - The rank of input tensors must be `2`. + The rank of input tensors must equal to `2`. Args: transpose_a (bool): If true, `a` is transposed before multiplication. Default: False. @@ -629,10 +629,10 @@ class MatMul(PrimitiveWithInfer): Tensor, the shape of the output tensor is :math:`(N, M)`. Examples: - >>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32) - >>> input_y = Tensor(np.ones(shape=[3, 4]), mindspore.float32) + >>> input_x1 = Tensor(np.ones(shape=[1, 3]), mindspore.float32) + >>> input_x2 = Tensor(np.ones(shape=[3, 4]), mindspore.float32) >>> matmul = P.MatMul() - >>> output = matmul(input_x, input_y) + >>> output = matmul(input_x1, input_x2) """ @prim_attr_register @@ -643,42 +643,44 @@ class MatMul(PrimitiveWithInfer): validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) self.add_prim_attr("io_format", "ND") - def check_shape_size(self, x, y): - if len(x) != 2 or len(y) != 2: - raise ValueError('MatMul input x, y should be the same dimension size and should be ' - + f'equal to 2, while x size = {len(x)}, y size= {len(y)}') + def check_shape_size(self, x1, x2): + if len(x1) != 2 or len(x2) != 2: + raise ValueError('P.MatMul inputs x1, x2 should has the same dimension size and ' + + f'equal to 2, while x1 size is ({len(x1)}) and x2 size is ({len(x2)}).') - def infer_shape(self, x, y): - self.check_shape_size(x, y) + def infer_shape(self, x1, x2): + self.check_shape_size(x1, x2) cls_name = self.name # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two - for i in range(len(x) - 2): - if x[i] != y[i]: - raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}') + for i in range(len(x1) - 2): + if x1[i] != x2[i]: + raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, ' + + f'while x1 is {x1[i]}, x2 is {x2[i]}') - # validate whether last two dims satifing matrix multiply - x_last = x[-2:] - y_last = y[-2:] - - x_col = x_last[not self.transpose_a] # x_col = x_last[1] if (not transpose_a) else x_last[0] - y_row = y_last[self.transpose_b] # y_row = y_last[0] if (not transpose_b) else y_last[1] - if x_col != y_row: + # validate whether last two dims satisfying matrix multiply + x1_last = x1[-2:] + x2_last = x2[-2:] + # x1_col = x1_last[1] if (not transpose_a) else x1_last[0] + x1_col = x1_last[not self.transpose_a] + # x2_row = x2_last[0] if (not transpose_b) else x2_last[1] + x2_row = x2_last[self.transpose_b] + if x1_col != x2_row: raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,' - + f' got {x_col} and {y_row}, with x shape {x}(transpose_a={self.transpose_a})' - + f', y shape {y}(transpose_b={self.transpose_b}).') + + f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})' + + f', x2 shape {x2}(transpose_b={self.transpose_b}).') # set attribute self.add_prim_attr('transpose_x1', self.transpose_a) self.add_prim_attr('transpose_x2', self.transpose_b) - ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]] + ret_dims = x1[: -2] + [x1_last[self.transpose_a], x2_last[not self.transpose_b]] return ret_dims - def infer_dtype(self, x, y): - args = {"x": x, "y": y} + def infer_dtype(self, x1, x2): + args = {"x1": x1, "x2": x2} validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name) - if x.element_type() == mstype.int8: + if x1.element_type() == mstype.int8: return mstype.tensor_type(mstype.int32) - return x + return x1 class BatchMatMul(MatMul): diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index e85f2797522..bfbbe6349ac 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -18,9 +18,7 @@ import math import operator from functools import reduce - import numpy as np - from ... import context from .. import signature as sig from ..._checkparam import Validator as validator diff --git a/tests/st/gnn/aggregator.py b/tests/st/gnn/aggregator.py index 373df5f9618..584bdc27d90 100644 --- a/tests/st/gnn/aggregator.py +++ b/tests/st/gnn/aggregator.py @@ -58,7 +58,7 @@ class GNNFeatureTransform(nn.Cell): Tensor, the shape of the output tensor is :math:`(*B, N, M)`. Examples: - >>> net = nn.Dense(3, 4) + >>> net = nn.GNNFeatureTransform(3, 4) >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) >>> net(input) [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]