forked from mindspore-Ecosystem/mindspore
!6976 bug fix for nn.Dense and P.Matmul
Merge pull request !6976 from chenzhongming/zomi_master
This commit is contained in:
commit
8f55187492
|
@ -12,7 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""basic"""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.seed import get_seed
|
||||
|
@ -28,7 +30,6 @@ from mindspore.common.parameter import Parameter
|
|||
from mindspore._extends import cell_attr_register
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore import context
|
||||
from mindspore.ops import _selected_ops
|
||||
from ..cell import Cell
|
||||
from .activation import get_activation
|
||||
from ..._checkparam import Validator as validator
|
||||
|
@ -139,10 +140,8 @@ class Flatten(Cell):
|
|||
the product of the remaining dimensions.
|
||||
|
||||
Examples:
|
||||
>>> net = nn.Flatten()
|
||||
>>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
|
||||
>>> input.shape
|
||||
(2, 2, 2)
|
||||
>>> net = nn.Flatten()
|
||||
>>> net(input)
|
||||
[[1.2 1.2 2.1 2.1]
|
||||
[2.2 2.2 3.2 3.2]]
|
||||
|
@ -157,9 +156,9 @@ class Flatten(Cell):
|
|||
|
||||
class Dense(Cell):
|
||||
r"""
|
||||
The fully connected layer.
|
||||
The dense connected layer.
|
||||
|
||||
Applies dense-connected layer for the input. This layer implements the operation as:
|
||||
Applies dense connected layer for the input. This layer implements the operation as:
|
||||
|
||||
.. math::
|
||||
\text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),
|
||||
|
@ -190,8 +189,8 @@ class Dense(Cell):
|
|||
Tensor of shape :math:`(N, out\_channels)`.
|
||||
|
||||
Examples:
|
||||
>>> net = nn.Dense(3, 4)
|
||||
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
|
||||
>>> net = nn.Dense(3, 4)
|
||||
>>> net(input)
|
||||
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
|
||||
[ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
|
||||
|
@ -212,41 +211,36 @@ class Dense(Cell):
|
|||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||
weight_init.shape[1] != in_channels:
|
||||
raise ValueError("weight_init shape error")
|
||||
|
||||
raise ValueError("Weight init shape error.")
|
||||
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
||||
|
||||
self.bias = None
|
||||
if self.has_bias:
|
||||
if isinstance(bias_init, Tensor):
|
||||
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
|
||||
raise ValueError("bias_init shape error")
|
||||
|
||||
raise ValueError("Bias init shape error.")
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.bias_add = _selected_ops.BiasAdd()
|
||||
|
||||
self.activation = get_activation(activation)
|
||||
self.activation_flag = self.activation is not None
|
||||
|
||||
def construct(self, x):
|
||||
output = self.matmul(x, self.weight)
|
||||
x = self.matmul(x, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
x = self.bias_add(x, self.bias)
|
||||
if self.activation_flag:
|
||||
return self.activation(output)
|
||||
return output
|
||||
x = self.activation(x)
|
||||
return x
|
||||
|
||||
def extend_repr(self):
|
||||
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
|
||||
.format(self.in_channels, self.out_channels, self.weight, self.has_bias)
|
||||
s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
|
||||
if self.has_bias:
|
||||
str_info = str_info + ', bias={}'.format(self.bias)
|
||||
|
||||
s += ', has_bias={}'.format(self.has_bias)
|
||||
if self.activation_flag:
|
||||
str_info = str_info + ', activation={}'.format(self.activation)
|
||||
|
||||
return str_info
|
||||
s += ', activation={}'.fomat(self.activation)
|
||||
return s
|
||||
|
||||
|
||||
@constexpr
|
||||
|
|
|
@ -611,9 +611,9 @@ class CumProd(PrimitiveWithInfer):
|
|||
|
||||
class MatMul(PrimitiveWithInfer):
|
||||
"""
|
||||
Multiplies matrix `a` by matrix `b`.
|
||||
Multiplies matrix `a` and matrix `b`.
|
||||
|
||||
The rank of input tensors must be `2`.
|
||||
The rank of input tensors must equal to `2`.
|
||||
|
||||
Args:
|
||||
transpose_a (bool): If true, `a` is transposed before multiplication. Default: False.
|
||||
|
@ -629,10 +629,10 @@ class MatMul(PrimitiveWithInfer):
|
|||
Tensor, the shape of the output tensor is :math:`(N, M)`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
|
||||
>>> input_y = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
|
||||
>>> input_x1 = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
|
||||
>>> input_x2 = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
|
||||
>>> matmul = P.MatMul()
|
||||
>>> output = matmul(input_x, input_y)
|
||||
>>> output = matmul(input_x1, input_x2)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -643,42 +643,44 @@ class MatMul(PrimitiveWithInfer):
|
|||
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
||||
self.add_prim_attr("io_format", "ND")
|
||||
|
||||
def check_shape_size(self, x, y):
|
||||
if len(x) != 2 or len(y) != 2:
|
||||
raise ValueError('MatMul input x, y should be the same dimension size and should be '
|
||||
+ f'equal to 2, while x size = {len(x)}, y size= {len(y)}')
|
||||
def check_shape_size(self, x1, x2):
|
||||
if len(x1) != 2 or len(x2) != 2:
|
||||
raise ValueError('P.MatMul inputs x1, x2 should has the same dimension size and '
|
||||
+ f'equal to 2, while x1 size is ({len(x1)}) and x2 size is ({len(x2)}).')
|
||||
|
||||
def infer_shape(self, x, y):
|
||||
self.check_shape_size(x, y)
|
||||
def infer_shape(self, x1, x2):
|
||||
self.check_shape_size(x1, x2)
|
||||
cls_name = self.name
|
||||
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
|
||||
for i in range(len(x) - 2):
|
||||
if x[i] != y[i]:
|
||||
raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}')
|
||||
for i in range(len(x1) - 2):
|
||||
if x1[i] != x2[i]:
|
||||
raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, '
|
||||
+ f'while x1 is {x1[i]}, x2 is {x2[i]}')
|
||||
|
||||
# validate whether last two dims satifing matrix multiply
|
||||
x_last = x[-2:]
|
||||
y_last = y[-2:]
|
||||
|
||||
x_col = x_last[not self.transpose_a] # x_col = x_last[1] if (not transpose_a) else x_last[0]
|
||||
y_row = y_last[self.transpose_b] # y_row = y_last[0] if (not transpose_b) else y_last[1]
|
||||
if x_col != y_row:
|
||||
# validate whether last two dims satisfying matrix multiply
|
||||
x1_last = x1[-2:]
|
||||
x2_last = x2[-2:]
|
||||
# x1_col = x1_last[1] if (not transpose_a) else x1_last[0]
|
||||
x1_col = x1_last[not self.transpose_a]
|
||||
# x2_row = x2_last[0] if (not transpose_b) else x2_last[1]
|
||||
x2_row = x2_last[self.transpose_b]
|
||||
if x1_col != x2_row:
|
||||
raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
|
||||
+ f' got {x_col} and {y_row}, with x shape {x}(transpose_a={self.transpose_a})'
|
||||
+ f', y shape {y}(transpose_b={self.transpose_b}).')
|
||||
+ f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
|
||||
+ f', x2 shape {x2}(transpose_b={self.transpose_b}).')
|
||||
# set attribute
|
||||
self.add_prim_attr('transpose_x1', self.transpose_a)
|
||||
self.add_prim_attr('transpose_x2', self.transpose_b)
|
||||
|
||||
ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]]
|
||||
ret_dims = x1[: -2] + [x1_last[self.transpose_a], x2_last[not self.transpose_b]]
|
||||
return ret_dims
|
||||
|
||||
def infer_dtype(self, x, y):
|
||||
args = {"x": x, "y": y}
|
||||
def infer_dtype(self, x1, x2):
|
||||
args = {"x1": x1, "x2": x2}
|
||||
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
|
||||
if x.element_type() == mstype.int8:
|
||||
if x1.element_type() == mstype.int8:
|
||||
return mstype.tensor_type(mstype.int32)
|
||||
return x
|
||||
return x1
|
||||
|
||||
|
||||
class BatchMatMul(MatMul):
|
||||
|
|
|
@ -18,9 +18,7 @@
|
|||
import math
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ... import context
|
||||
from .. import signature as sig
|
||||
from ..._checkparam import Validator as validator
|
||||
|
|
|
@ -58,7 +58,7 @@ class GNNFeatureTransform(nn.Cell):
|
|||
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
|
||||
|
||||
Examples:
|
||||
>>> net = nn.Dense(3, 4)
|
||||
>>> net = nn.GNNFeatureTransform(3, 4)
|
||||
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
|
||||
>>> net(input)
|
||||
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
|
||||
|
|
Loading…
Reference in New Issue