forked from mindspore-Ecosystem/mindspore
reshape_dense
commit 61793da1d9
parent f679fcf075
@@ -195,15 +195,6 @@ class Flatten(Cell):
     def construct(self, x):
         return F.reshape(x, (F.shape(x)[0], -1))
 
-
-@constexpr
-def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel):
-    """get broadcast_weight_bias shape"""
-    broad_weight_shape = x_shape[:-2] + (out_channel, in_channel)
-    broad_bias_shape = x_shape[:-1] + (out_channel,)
-    return broad_weight_shape, broad_bias_shape
-
-
 class Dense(Cell):
     r"""
     The dense connected layer.
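(For reference, the helper deleted above only did tuple arithmetic: with, say, x_shape = (8, 16, 32), out_channel = 10 and in_channel = 32 it would return (8, 10, 32) and (8, 16, 10), the shapes the old construct broadcast the weight and bias to. The sample numbers are illustrative, not from the commit.)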
@@ -249,7 +240,7 @@ class Dense(Cell):
         (2, 4)
     """
 
-    @cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels'])
+    @cell_attr_register(attrs=['has_bias', 'activation'])
     def __init__(self,
                  in_channels,
                  out_channels,
@@ -261,8 +252,10 @@ class Dense(Cell):
         self.in_channels = Validator.check_positive_int(in_channels)
         self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
+        self.reshape = P.Reshape()
+        self.shape_op = P.Shape()
 
 
         if isinstance(weight_init, Tensor):
             if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
@@ -276,10 +269,8 @@ class Dense(Cell):
                     raise ValueError("Bias init shape error.")
             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
             self.bias_add = P.BiasAdd()
-            self.tensor_add = P.TensorAdd()
 
         self.matmul = P.MatMul(transpose_b=True)
-        self.batch_matmul = P.BatchMatMul(transpose_b=True)
         self.activation = get_activation(activation) if isinstance(activation, str) else activation
         if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
             raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
@@ -287,27 +278,16 @@ class Dense(Cell):
 
     def construct(self, x):
         x_shape = self.shape_op(x)
-        x_dim = len(x_shape)
-        if x_dim == 2:
-            matmul = self.matmul
-            bias_add = self.bias_add if self.has_bias else None
-            weight = self.weight
-            bias = self.bias
-        else:
-            broad_weight_shape, broad_bias_shape = get_broadcast_weight_bias_shape(x_shape, self.out_channels,
-                                                                                   self.in_channels)
-            weight_broadcast_to = P.BroadcastTo(broad_weight_shape)
-            bias_broadcast_to = P.BroadcastTo(broad_bias_shape)
-            matmul = self.batch_matmul
-            bias_add = self.tensor_add if self.has_bias else None
-            weight = weight_broadcast_to(self.weight)
-            bias = bias_broadcast_to(self.bias) if self.has_bias else self.bias
-
-        x = matmul(x, weight)
+        if len(x_shape) != 2:
+            x = self.reshape(x, (-1, x_shape[-1]))
+        x = self.matmul(x, self.weight)
         if self.has_bias:
-            x = bias_add(x, bias)
+            x = self.bias_add(x, self.bias)
         if self.activation_flag:
             x = self.activation(x)
+        if len(x_shape) != 2:
+            out_shape = x_shape[:-1] + (-1,)
+            x = self.reshape(x, out_shape)
         return x
 
     def extend_repr(self):
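For context on what the diff does end to end: the deleted path broadcast the (out_channels, in_channels) weight across every leading batch dimension and ran BatchMatMul, while the new path collapses all leading dimensions into one, runs the plain 2-D MatMul (transpose_b=True), and reshapes the result back to x_shape[:-1] + (-1,). A minimal NumPy sketch of the equivalence the rewrite relies on (plain NumPy standing in for the MindSpore ops; the shapes and names below are illustrative, not from the commit):

import numpy as np

rng = np.random.default_rng(0)
in_channels, out_channels = 3, 4
x = rng.standard_normal((2, 5, in_channels))               # N-D input: (batch, seq, in_channels)
weight = rng.standard_normal((out_channels, in_channels))  # stored transposed, as in Dense
bias = rng.standard_normal(out_channels)

# Old path: broadcast weight/bias to the shapes the deleted helper computed,
# then a batched matmul (transpose_b=True becomes swapaxes here).
broad_weight = np.broadcast_to(weight, x.shape[:-2] + (out_channels, in_channels))
broad_bias = np.broadcast_to(bias, x.shape[:-1] + (out_channels,))
old = np.matmul(x, broad_weight.swapaxes(-1, -2)) + broad_bias

# New path: flatten to 2-D, one plain matmul, reshape back.
flat = x.reshape(-1, x.shape[-1])                          # (-1, in_channels)
new = (flat @ weight.T + bias).reshape(x.shape[:-1] + (-1,))

assert np.allclose(old, new)

Both routes yield x.shape[:-1] + (out_channels,); the reshape route simply funnels everything through the single 2-D MatMul kernel instead of materializing a broadcast copy of the weight for every leading batch index.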