bug fix fot nn.Dense and P.Matmul

This commit is contained in:
chenzomi 2020-09-28 16:19:37 +08:00
parent 960f5c5250
commit b22cb38dab
4 changed files with 49 additions and 55 deletions

View File

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""basic"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.seed import get_seed
@ -28,7 +30,6 @@ from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register
from mindspore.common.api import ms_function
from mindspore import context
from mindspore.ops import _selected_ops
from ..cell import Cell
from .activation import get_activation
from ..._checkparam import Validator as validator
@ -139,10 +140,8 @@ class Flatten(Cell):
the product of the remaining dimensions.
Examples:
>>> net = nn.Flatten()
>>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
>>> input.shape
(2, 2, 2)
>>> net = nn.Flatten()
>>> net(input)
[[1.2 1.2 2.1 2.1]
[2.2 2.2 3.2 3.2]]
@ -157,9 +156,9 @@ class Flatten(Cell):
class Dense(Cell):
r"""
The fully connected layer.
The dense connected layer.
Applies dense-connected layer for the input. This layer implements the operation as:
Applies dense connected layer for the input. This layer implements the operation as:
.. math::
\text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),
@ -190,8 +189,8 @@ class Dense(Cell):
Tensor of shape :math:`(N, out\_channels)`.
Examples:
>>> net = nn.Dense(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net = nn.Dense(3, 4)
>>> net(input)
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
[ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
@ -212,41 +211,36 @@ class Dense(Cell):
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")
raise ValueError("Weight init shape error.")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
self.bias = None
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")
raise ValueError("Bias init shape error.")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
self.bias_add = P.BiasAdd()
self.matmul = P.MatMul(transpose_b=True)
self.bias_add = _selected_ops.BiasAdd()
self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
def construct(self, x):
output = self.matmul(x, self.weight)
x = self.matmul(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
x = self.bias_add(x, self.bias)
if self.activation_flag:
return self.activation(output)
return output
x = self.activation(x)
return x
def extend_repr(self):
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
.format(self.in_channels, self.out_channels, self.weight, self.has_bias)
s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
if self.has_bias:
str_info = str_info + ', bias={}'.format(self.bias)
s += ', has_bias={}'.format(self.has_bias)
if self.activation_flag:
str_info = str_info + ', activation={}'.format(self.activation)
return str_info
s += ', activation={}'.fomat(self.activation)
return s
@constexpr

View File

@ -611,9 +611,9 @@ class CumProd(PrimitiveWithInfer):
class MatMul(PrimitiveWithInfer):
"""
Multiplies matrix `a` by matrix `b`.
Multiplies matrix `a` and matrix `b`.
The rank of input tensors must be `2`.
The rank of input tensors must equal to `2`.
Args:
transpose_a (bool): If True, `a` is transposed before multiplication. Default: False.
@ -629,10 +629,10 @@ class MatMul(PrimitiveWithInfer):
Tensor, the shape of the output tensor is :math:`(N, M)`.
Examples:
>>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
>>> input_y = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
>>> input_x1 = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
>>> input_x2 = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
>>> matmul = P.MatMul()
>>> output = matmul(input_x, input_y)
>>> output = matmul(input_x1, input_x2)
"""
@prim_attr_register
@ -643,42 +643,44 @@ class MatMul(PrimitiveWithInfer):
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
self.add_prim_attr("io_format", "ND")
def check_shape_size(self, x, y):
if len(x) != 2 or len(y) != 2:
raise ValueError('MatMul input x, y should be the same dimension size and should be '
+ f'equal to 2, while x size = {len(x)}, y size= {len(y)}')
def check_shape_size(self, x1, x2):
if len(x1) != 2 or len(x2) != 2:
raise ValueError('P.MatMul inputs x1, x2 should has the same dimension size and '
+ f'equal to 2, while x1 size is ({len(x1)}) and x2 size is ({len(x2)}).')
def infer_shape(self, x, y):
self.check_shape_size(x, y)
def infer_shape(self, x1, x2):
self.check_shape_size(x1, x2)
cls_name = self.name
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
for i in range(len(x) - 2):
if x[i] != y[i]:
raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}')
for i in range(len(x1) - 2):
if x1[i] != x2[i]:
raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, '
+ f'while x1 is {x1[i]}, x2 is {x2[i]}')
# validate whether last two dims satifing matrix multiply
x_last = x[-2:]
y_last = y[-2:]
x_col = x_last[not self.transpose_a] # x_col = x_last[1] if (not transpose_a) else x_last[0]
y_row = y_last[self.transpose_b] # y_row = y_last[0] if (not transpose_b) else y_last[1]
if x_col != y_row:
# validate whether last two dims satisfying matrix multiply
x1_last = x1[-2:]
x2_last = x2[-2:]
# x1_col = x1_last[1] if (not transpose_a) else x1_last[0]
x1_col = x1_last[not self.transpose_a]
# x2_row = x2_last[0] if (not transpose_b) else x2_last[1]
x2_row = x2_last[self.transpose_b]
if x1_col != x2_row:
raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
+ f' got {x_col} and {y_row}, with x shape {x}(transpose_a={self.transpose_a})'
+ f', y shape {y}(transpose_b={self.transpose_b}).')
+ f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
+ f', x2 shape {x2}(transpose_b={self.transpose_b}).')
# set attribute
self.add_prim_attr('transpose_x1', self.transpose_a)
self.add_prim_attr('transpose_x2', self.transpose_b)
ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]]
ret_dims = x1[: -2] + [x1_last[self.transpose_a], x2_last[not self.transpose_b]]
return ret_dims
def infer_dtype(self, x, y):
args = {"x": x, "y": y}
def infer_dtype(self, x1, x2):
args = {"x1": x1, "x2": x2}
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
if x.element_type() == mstype.int8:
if x1.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32)
return x
return x1
class BatchMatMul(MatMul):

View File

@ -18,9 +18,7 @@
import math
import operator
from functools import reduce
import numpy as np
from ... import context
from .. import signature as sig
from ..._checkparam import Validator as validator

View File

@ -58,7 +58,7 @@ class GNNFeatureTransform(nn.Cell):
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
Examples:
>>> net = nn.Dense(3, 4)
>>> net = nn.GNNFeatureTransform(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input)
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]