diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index 3168d827c15..b5bc2cbb997 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -152,22 +152,12 @@ class Flatten(Cell):
     def construct(self, x):
         return F.reshape(x, (F.shape(x)[0], -1))
 
-
-def matmul_bias_select(x_shape, in_channel, out_channel):
-    """matmul and bias_add selection for different input"""
-    x_dim = len(x_shape)
+@constexpr
+def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel):
+    """get broadcast_weight_bias shape"""
     broad_weight_shape = x_shape[:-2] + (out_channel, in_channel)
     broad_bias_shape = x_shape[:-1] + (out_channel,)
-    weight_broadcast_to = P.BroadcastTo(broad_weight_shape)
-    bias_broadcast_to = P.BroadcastTo(broad_bias_shape)
-    if x_dim == 2:
-        matmul = P.MatMul(False, True)
-        bias_add = P.BiasAdd()
-    else:
-        matmul = P.BatchMatMul(False, True)
-        bias_add = P.TensorAdd()
-    return matmul, bias_add, weight_broadcast_to, bias_broadcast_to
-
+    return broad_weight_shape, broad_bias_shape
 
 class Dense(Cell):
     r"""
@@ -236,7 +226,11 @@ class Dense(Cell):
                 if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("Bias init shape error.")
             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
+            self.bias_add = P.BiasAdd()
+            self.tensor_add = P.TensorAdd()
 
+        self.matmul = P.MatMul(transpose_b=True)
+        self.batch_matmul = P.BatchMatMul(transpose_b=True)
         self.activation = get_activation(activation) if isinstance(activation, str) else activation
         if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
             raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
@@ -245,12 +239,23 @@ class Dense(Cell):
     def construct(self, x):
         x_shape = self.shape_op(x)
         x_dim = len(x_shape)
-        matmul, bias_add, weight_broadcast_to, bias_broadcast_to = matmul_bias_select(x_shape, self.in_channels,
-                                                                                      self.out_channels)
-        weight = self.weight if x_dim == 2 else weight_broadcast_to(self.weight)
+        if x_dim == 2:
+            matmul = self.matmul
+            bias_add = self.bias_add if self.has_bias else None
+            weight = self.weight
+            bias = self.bias
+        else:
+            broad_weight_shape, broad_bias_shape = get_broadcast_weight_bias_shape(x_shape, self.out_channels,
+                                                                                   self.in_channels)
+            weight_broadcast_to = P.BroadcastTo(broad_weight_shape)
+            bias_broadcast_to = P.BroadcastTo(broad_bias_shape)
+            matmul = self.batch_matmul
+            bias_add = self.tensor_add if self.has_bias else None
+            weight = weight_broadcast_to(self.weight)
+            bias = bias_broadcast_to(self.bias) if self.has_bias else self.bias
+
         x = matmul(x, weight)
         if self.has_bias:
-            bias = self.bias if x_dim == 2 else bias_broadcast_to(self.bias)
             x = bias_add(x, bias)
         if self.activation_flag:
             x = self.activation(x)
diff --git a/tests/st/host_device/test_host_device_lenet.py b/tests/st/host_device/test_host_device_lenet.py
index 243f3facafa..80bf7b578a4 100644
--- a/tests/st/host_device/test_host_device_lenet.py
+++ b/tests/st/host_device/test_host_device_lenet.py
@@ -21,34 +21,10 @@ from mindspore import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Momentum
 from mindspore.ops import operations as P
-from mindspore.common.initializer import initializer
-from mindspore.common.parameter import Parameter
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 
 
-class DenseDevice(nn.Cell):
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 device_target='Ascend',
-                 weight_init='normal',
-                 bias_init='zeros'):
-        super(DenseDevice, self).__init__()
-        self.device_target = device_target
-        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
-        self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
-        self.bias_add = P.BiasAdd()
-        self.matmul = P.MatMul(transpose_b=True)
-        self.matmul.add_prim_attr("primitive_target", self.device_target)
-        self.bias_add.add_prim_attr("primitive_target", self.device_target)
-
-    def construct(self, x):
-        x = self.matmul(x, self.weight)
-        x = self.bias_add(x, self.bias)
-        return x
-
-
 class LeNet(nn.Cell):
     def __init__(self):
         super(LeNet, self).__init__()
@@ -59,9 +35,15 @@ class LeNet(nn.Cell):
         self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
         self.reshape = P.Reshape()
-        self.fc1 = DenseDevice(400, 120, device_target='CPU')
-        self.fc2 = DenseDevice(120, 84, device_target='CPU')
-        self.fc3 = DenseDevice(84, 10, device_target='CPU')
+        self.fc1 = nn.Dense(400, 120)
+        self.fc1.matmul.add_prim_attr("primitive_target", "CPU")
+        self.fc1.bias_add.add_prim_attr("primitive_target", "CPU")
+        self.fc2 = nn.Dense(120, 84)
+        self.fc2.matmul.add_prim_attr("primitive_target", "CPU")
+        self.fc2.bias_add.add_prim_attr("primitive_target", "CPU")
+        self.fc3 = nn.Dense(84, 10)
+        self.fc3.matmul.add_prim_attr("primitive_target", "CPU")
+        self.fc3.bias_add.add_prim_attr("primitive_target", "CPU")
 
     def construct(self, input_x):
         output = self.conv1(input_x)
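
A minimal usage sketch (not part of the patch) of how the refactored nn.Dense behaves after this change. The module imports are standard MindSpore; the shapes and the CPU pinning are illustrative assumptions that mirror the updated LeNet test above.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

dense = nn.Dense(400, 120)

# 2-D input keeps the original MatMul + BiasAdd path.
x_2d = Tensor(np.random.randn(32, 400).astype(np.float32))
out_2d = dense(x_2d)    # shape (32, 120)

# Higher-rank input broadcasts weight/bias and goes through BatchMatMul + TensorAdd.
x_3d = Tensor(np.random.randn(8, 32, 400).astype(np.float32))
out_3d = dense(x_3d)    # shape (8, 32, 120)

# Because matmul and bias_add are now attributes of nn.Dense, callers can pin them to a
# device, as the updated test does; this replaces the deleted DenseDevice helper class.
dense.matmul.add_prim_attr("primitive_target", "CPU")
dense.bias_add.add_prim_attr("primitive_target", "CPU")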