forked from mindspore-Ecosystem/mindspore
!6976 bug fix for nn.Dense and P.Matmul
Merge pull request !6976 from chenzhongming/zomi_master
This commit is contained in:
commit
8f55187492
|
@ -12,7 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""basic"""
|
"""basic"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore.common.seed import get_seed
|
from mindspore.common.seed import get_seed
|
||||||
|
@ -28,7 +30,6 @@ from mindspore.common.parameter import Parameter
|
||||||
from mindspore._extends import cell_attr_register
|
from mindspore._extends import cell_attr_register
|
||||||
from mindspore.common.api import ms_function
|
from mindspore.common.api import ms_function
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.ops import _selected_ops
|
|
||||||
from ..cell import Cell
|
from ..cell import Cell
|
||||||
from .activation import get_activation
|
from .activation import get_activation
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
|
@ -139,10 +140,8 @@ class Flatten(Cell):
|
||||||
the product of the remaining dimensions.
|
the product of the remaining dimensions.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> net = nn.Flatten()
|
|
||||||
>>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
|
>>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
|
||||||
>>> input.shape
|
>>> net = nn.Flatten()
|
||||||
(2, 2, 2)
|
|
||||||
>>> net(input)
|
>>> net(input)
|
||||||
[[1.2 1.2 2.1 2.1]
|
[[1.2 1.2 2.1 2.1]
|
||||||
[2.2 2.2 3.2 3.2]]
|
[2.2 2.2 3.2 3.2]]
|
||||||
|
@ -157,9 +156,9 @@ class Flatten(Cell):
|
||||||
|
|
||||||
class Dense(Cell):
|
class Dense(Cell):
|
||||||
r"""
|
r"""
|
||||||
The fully connected layer.
|
The dense connected layer.
|
||||||
|
|
||||||
Applies dense-connected layer for the input. This layer implements the operation as:
|
Applies dense connected layer for the input. This layer implements the operation as:
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),
|
\text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),
|
||||||
|
@ -190,8 +189,8 @@ class Dense(Cell):
|
||||||
Tensor of shape :math:`(N, out\_channels)`.
|
Tensor of shape :math:`(N, out\_channels)`.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> net = nn.Dense(3, 4)
|
|
||||||
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
|
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
|
||||||
|
>>> net = nn.Dense(3, 4)
|
||||||
>>> net(input)
|
>>> net(input)
|
||||||
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
|
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
|
||||||
[ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
|
[ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
|
||||||
|
@ -212,41 +211,36 @@ class Dense(Cell):
|
||||||
if isinstance(weight_init, Tensor):
|
if isinstance(weight_init, Tensor):
|
||||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||||
weight_init.shape[1] != in_channels:
|
weight_init.shape[1] != in_channels:
|
||||||
raise ValueError("weight_init shape error")
|
raise ValueError("Weight init shape error.")
|
||||||
|
|
||||||
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
||||||
|
|
||||||
|
self.bias = None
|
||||||
if self.has_bias:
|
if self.has_bias:
|
||||||
if isinstance(bias_init, Tensor):
|
if isinstance(bias_init, Tensor):
|
||||||
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
|
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
|
||||||
raise ValueError("bias_init shape error")
|
raise ValueError("Bias init shape error.")
|
||||||
|
|
||||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
|
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
|
||||||
|
self.bias_add = P.BiasAdd()
|
||||||
|
|
||||||
self.matmul = P.MatMul(transpose_b=True)
|
self.matmul = P.MatMul(transpose_b=True)
|
||||||
self.bias_add = _selected_ops.BiasAdd()
|
|
||||||
|
|
||||||
self.activation = get_activation(activation)
|
self.activation = get_activation(activation)
|
||||||
self.activation_flag = self.activation is not None
|
self.activation_flag = self.activation is not None
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
output = self.matmul(x, self.weight)
|
x = self.matmul(x, self.weight)
|
||||||
if self.has_bias:
|
if self.has_bias:
|
||||||
output = self.bias_add(output, self.bias)
|
x = self.bias_add(x, self.bias)
|
||||||
if self.activation_flag:
|
if self.activation_flag:
|
||||||
return self.activation(output)
|
x = self.activation(x)
|
||||||
return output
|
return x
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
|
s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
|
||||||
.format(self.in_channels, self.out_channels, self.weight, self.has_bias)
|
|
||||||
if self.has_bias:
|
if self.has_bias:
|
||||||
str_info = str_info + ', bias={}'.format(self.bias)
|
s += ', has_bias={}'.format(self.has_bias)
|
||||||
|
|
||||||
if self.activation_flag:
|
if self.activation_flag:
|
||||||
str_info = str_info + ', activation={}'.format(self.activation)
|
s += ', activation={}'.fomat(self.activation)
|
||||||
|
return s
|
||||||
return str_info
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
|
|
@ -611,9 +611,9 @@ class CumProd(PrimitiveWithInfer):
|
||||||
|
|
||||||
class MatMul(PrimitiveWithInfer):
|
class MatMul(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Multiplies matrix `a` by matrix `b`.
|
Multiplies matrix `a` and matrix `b`.
|
||||||
|
|
||||||
The rank of input tensors must be `2`.
|
The rank of input tensors must equal to `2`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
transpose_a (bool): If true, `a` is transposed before multiplication. Default: False.
|
transpose_a (bool): If true, `a` is transposed before multiplication. Default: False.
|
||||||
|
@ -629,10 +629,10 @@ class MatMul(PrimitiveWithInfer):
|
||||||
Tensor, the shape of the output tensor is :math:`(N, M)`.
|
Tensor, the shape of the output tensor is :math:`(N, M)`.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
|
>>> input_x1 = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
|
||||||
>>> input_y = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
|
>>> input_x2 = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
|
||||||
>>> matmul = P.MatMul()
|
>>> matmul = P.MatMul()
|
||||||
>>> output = matmul(input_x, input_y)
|
>>> output = matmul(input_x1, input_x2)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
|
@ -643,42 +643,44 @@ class MatMul(PrimitiveWithInfer):
|
||||||
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
||||||
self.add_prim_attr("io_format", "ND")
|
self.add_prim_attr("io_format", "ND")
|
||||||
|
|
||||||
def check_shape_size(self, x, y):
|
def check_shape_size(self, x1, x2):
|
||||||
if len(x) != 2 or len(y) != 2:
|
if len(x1) != 2 or len(x2) != 2:
|
||||||
raise ValueError('MatMul input x, y should be the same dimension size and should be '
|
raise ValueError('P.MatMul inputs x1, x2 should has the same dimension size and '
|
||||||
+ f'equal to 2, while x size = {len(x)}, y size= {len(y)}')
|
+ f'equal to 2, while x1 size is ({len(x1)}) and x2 size is ({len(x2)}).')
|
||||||
|
|
||||||
def infer_shape(self, x, y):
|
def infer_shape(self, x1, x2):
|
||||||
self.check_shape_size(x, y)
|
self.check_shape_size(x1, x2)
|
||||||
cls_name = self.name
|
cls_name = self.name
|
||||||
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
|
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
|
||||||
for i in range(len(x) - 2):
|
for i in range(len(x1) - 2):
|
||||||
if x[i] != y[i]:
|
if x1[i] != x2[i]:
|
||||||
raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}')
|
raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, '
|
||||||
|
+ f'while x1 is {x1[i]}, x2 is {x2[i]}')
|
||||||
|
|
||||||
# validate whether last two dims satifing matrix multiply
|
# validate whether last two dims satisfying matrix multiply
|
||||||
x_last = x[-2:]
|
x1_last = x1[-2:]
|
||||||
y_last = y[-2:]
|
x2_last = x2[-2:]
|
||||||
|
# x1_col = x1_last[1] if (not transpose_a) else x1_last[0]
|
||||||
x_col = x_last[not self.transpose_a] # x_col = x_last[1] if (not transpose_a) else x_last[0]
|
x1_col = x1_last[not self.transpose_a]
|
||||||
y_row = y_last[self.transpose_b] # y_row = y_last[0] if (not transpose_b) else y_last[1]
|
# x2_row = x2_last[0] if (not transpose_b) else x2_last[1]
|
||||||
if x_col != y_row:
|
x2_row = x2_last[self.transpose_b]
|
||||||
|
if x1_col != x2_row:
|
||||||
raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
|
raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
|
||||||
+ f' got {x_col} and {y_row}, with x shape {x}(transpose_a={self.transpose_a})'
|
+ f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
|
||||||
+ f', y shape {y}(transpose_b={self.transpose_b}).')
|
+ f', x2 shape {x2}(transpose_b={self.transpose_b}).')
|
||||||
# set attribute
|
# set attribute
|
||||||
self.add_prim_attr('transpose_x1', self.transpose_a)
|
self.add_prim_attr('transpose_x1', self.transpose_a)
|
||||||
self.add_prim_attr('transpose_x2', self.transpose_b)
|
self.add_prim_attr('transpose_x2', self.transpose_b)
|
||||||
|
|
||||||
ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]]
|
ret_dims = x1[: -2] + [x1_last[self.transpose_a], x2_last[not self.transpose_b]]
|
||||||
return ret_dims
|
return ret_dims
|
||||||
|
|
||||||
def infer_dtype(self, x, y):
|
def infer_dtype(self, x1, x2):
|
||||||
args = {"x": x, "y": y}
|
args = {"x1": x1, "x2": x2}
|
||||||
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
|
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
|
||||||
if x.element_type() == mstype.int8:
|
if x1.element_type() == mstype.int8:
|
||||||
return mstype.tensor_type(mstype.int32)
|
return mstype.tensor_type(mstype.int32)
|
||||||
return x
|
return x1
|
||||||
|
|
||||||
|
|
||||||
class BatchMatMul(MatMul):
|
class BatchMatMul(MatMul):
|
||||||
|
|
|
@ -18,9 +18,7 @@
|
||||||
import math
|
import math
|
||||||
import operator
|
import operator
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ... import context
|
from ... import context
|
||||||
from .. import signature as sig
|
from .. import signature as sig
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
|
|
|
@ -58,7 +58,7 @@ class GNNFeatureTransform(nn.Cell):
|
||||||
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
|
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> net = nn.Dense(3, 4)
|
>>> net = nn.GNNFeatureTransform(3, 4)
|
||||||
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
|
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
|
||||||
>>> net(input)
|
>>> net(input)
|
||||||
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
|
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
|
||||||
|
|
Loading…
Reference in New Issue