forked from mindspore-Ecosystem/mindspore

!8423 nn.Dense support ND*2D

From: @wanyiming
Reviewed-by:
Signed-off-by:

commit eb5ae1a0fc
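For context, "ND*2D" means nn.Dense can now take an input with more than two dimensions and multiply over its last axis. A minimal usage sketch, with illustrative shapes that are not taken from this PR:

    # Illustrative only: a 3-D input through nn.Dense after this change.
    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    dense = nn.Dense(256, 128)                       # weight: (128, 256), bias: (128,)
    x = Tensor(np.ones((32, 64, 256), np.float32))   # ND input: (batch, seq, in_channels)
    y = dense(x)                                     # expected output shape: (32, 64, 128)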
@@ -152,22 +152,12 @@ class Flatten(Cell):
     def construct(self, x):
         return F.reshape(x, (F.shape(x)[0], -1))
 
 
-def matmul_bias_select(x_shape, in_channel, out_channel):
-    """matmul and bias_add selection for different input"""
-    x_dim = len(x_shape)
+@constexpr
+def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel):
+    """get broadcast_weight_bias shape"""
     broad_weight_shape = x_shape[:-2] + (out_channel, in_channel)
     broad_bias_shape = x_shape[:-1] + (out_channel,)
-    weight_broadcast_to = P.BroadcastTo(broad_weight_shape)
-    bias_broadcast_to = P.BroadcastTo(broad_bias_shape)
-    if x_dim == 2:
-        matmul = P.MatMul(False, True)
-        bias_add = P.BiasAdd()
-    else:
-        matmul = P.BatchMatMul(False, True)
-        bias_add = P.TensorAdd()
-    return matmul, bias_add, weight_broadcast_to, bias_broadcast_to
+    return broad_weight_shape, broad_bias_shape
 
 
 class Dense(Cell):
     r"""
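A plain-Python sketch of what the new @constexpr helper computes, using an assumed input shape; since it is @constexpr, this tuple arithmetic runs once at graph-compile time rather than on tensors:

    # Example shapes are assumptions, chosen only to show the tuple arithmetic.
    x_shape = (32, 64, 256)                # ND input: (..., in_channel)
    out_channel, in_channel = 128, 256

    broad_weight_shape = x_shape[:-2] + (out_channel, in_channel)   # (32, 128, 256)
    broad_bias_shape = x_shape[:-1] + (out_channel,)                # (32, 64, 128)
    # The (128, 256) weight is broadcast to (32, 128, 256) so that
    # BatchMatMul(transpose_b=True) maps (32, 64, 256) -> (32, 64, 128).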
@@ -236,7 +226,11 @@ class Dense(Cell):
                 if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("Bias init shape error.")
             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
+            self.bias_add = P.BiasAdd()
+            self.tensor_add = P.TensorAdd()
 
+        self.matmul = P.MatMul(transpose_b=True)
+        self.batch_matmul = P.BatchMatMul(transpose_b=True)
         self.activation = get_activation(activation) if isinstance(activation, str) else activation
         if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
             raise TypeError("The activation must be str or Cell or Primitive," " but got {}.".format(activation))
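__init__ now creates both primitive pairs up front, and construct() selects between them by input rank. The 2-D path keeps BiasAdd; the N-D path instead broadcasts the bias to the full output shape and uses TensorAdd, presumably because BiasAdd does not accept that input/bias combination. A NumPy stand-in with assumed shapes:

    import numpy as np

    out = np.zeros((32, 64, 128), np.float32)     # result of the batched matmul
    bias = np.zeros((128,), np.float32)           # Dense bias parameter
    broad_bias = np.broadcast_to(bias, out.shape)
    y = out + broad_bias                          # what TensorAdd does on the N-D path
    print(y.shape)                                # (32, 64, 128)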
@@ -245,12 +239,23 @@ class Dense(Cell):
     def construct(self, x):
         x_shape = self.shape_op(x)
         x_dim = len(x_shape)
-        matmul, bias_add, weight_broadcast_to, bias_broadcast_to = matmul_bias_select(x_shape, self.in_channels,
-                                                                                      self.out_channels)
-        weight = self.weight if x_dim == 2 else weight_broadcast_to(self.weight)
+        if x_dim == 2:
+            matmul = self.matmul
+            bias_add = self.bias_add if self.has_bias else None
+            weight = self.weight
+            bias = self.bias
+        else:
+            broad_weight_shape, broad_bias_shape = get_broadcast_weight_bias_shape(x_shape, self.out_channels,
+                                                                                   self.in_channels)
+            weight_broadcast_to = P.BroadcastTo(broad_weight_shape)
+            bias_broadcast_to = P.BroadcastTo(broad_bias_shape)
+            matmul = self.batch_matmul
+            bias_add = self.tensor_add if self.has_bias else None
+            weight = weight_broadcast_to(self.weight)
+            bias = bias_broadcast_to(self.bias) if self.has_bias else self.bias
 
         x = matmul(x, weight)
         if self.has_bias:
-            bias = self.bias if x_dim == 2 else bias_broadcast_to(self.bias)
             x = bias_add(x, bias)
         if self.activation_flag:
             x = self.activation(x)
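To make the N-D branch concrete, a hedged NumPy sketch (shapes assumed) checking that broadcasting the weight and batch-multiplying matches applying the same 2-D dense transform to every leading slice:

    import numpy as np

    x = np.random.randn(4, 7, 16).astype(np.float32)    # (batch, seq, in_channels)
    w = np.random.randn(24, 16).astype(np.float32)       # (out_channels, in_channels)
    b = np.random.randn(24).astype(np.float32)

    # N-D path: broadcast weight, batch matmul with the weight transposed, add bias.
    broad_w = np.broadcast_to(w, (4, 24, 16))
    nd_out = np.matmul(x, broad_w.transpose(0, 2, 1)) + b

    # Reference: flatten the leading dims and use the ordinary 2-D dense computation.
    ref = (x.reshape(-1, 16) @ w.T + b).reshape(4, 7, 24)
    print(np.allclose(nd_out, ref, atol=1e-5))            # True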
@@ -21,34 +21,10 @@ from mindspore import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Momentum
 from mindspore.ops import operations as P
-from mindspore.common.initializer import initializer
-from mindspore.common.parameter import Parameter
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 
 
-class DenseDevice(nn.Cell):
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 device_target='Ascend',
-                 weight_init='normal',
-                 bias_init='zeros'):
-        super(DenseDevice, self).__init__()
-        self.device_target = device_target
-        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
-        self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
-        self.bias_add = P.BiasAdd()
-        self.matmul = P.MatMul(transpose_b=True)
-        self.matmul.add_prim_attr("primitive_target", self.device_target)
-        self.bias_add.add_prim_attr("primitive_target", self.device_target)
-
-    def construct(self, x):
-        x = self.matmul(x, self.weight)
-        x = self.bias_add(x, self.bias)
-        return x
-
-
 class LeNet(nn.Cell):
     def __init__(self):
         super(LeNet, self).__init__()
@@ -59,9 +35,15 @@ class LeNet(nn.Cell):
         self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
         self.reshape = P.Reshape()
-        self.fc1 = DenseDevice(400, 120, device_target='CPU')
-        self.fc2 = DenseDevice(120, 84, device_target='CPU')
-        self.fc3 = DenseDevice(84, 10, device_target='CPU')
+        self.fc1 = nn.Dense(400, 120)
+        self.fc1.matmul.add_prim_attr("primitive_target", "CPU")
+        self.fc1.bias_add.add_prim_attr("primitive_target", "CPU")
+        self.fc2 = nn.Dense(120, 84)
+        self.fc2.matmul.add_prim_attr("primitive_target", "CPU")
+        self.fc2.bias_add.add_prim_attr("primitive_target", "CPU")
+        self.fc3 = nn.Dense(84, 10)
+        self.fc3.matmul.add_prim_attr("primitive_target", "CPU")
+        self.fc3.bias_add.add_prim_attr("primitive_target", "CPU")
 
     def construct(self, input_x):
         output = self.conv1(input_x)
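The test no longer needs the custom DenseDevice cell: nn.Dense exposes its matmul and bias_add primitives, so they can be pinned to a device directly. A minimal sketch of that idiom outside LeNet (input shape assumed):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    fc = nn.Dense(400, 120)
    fc.matmul.add_prim_attr("primitive_target", "CPU")     # run the matmul on CPU
    fc.bias_add.add_prim_attr("primitive_target", "CPU")   # run the bias add on CPU

    out = fc(Tensor(np.ones((2, 400), np.float32)))        # output shape: (2, 120)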