modify_dense

This commit is contained in:
wanyiming 2020-11-03 17:34:21 +08:00
parent 27d4c8f5fb
commit 3c7a3b6693
5 changed files with 153 additions and 14 deletions

View File

@ -153,6 +153,22 @@ class Flatten(Cell):
return F.reshape(x, (F.shape(x)[0], -1))
def matmul_bias_select(x_shape, in_channel, out_channel):
"""matmul and bias_add selection for different input"""
x_dim = len(x_shape)
broad_weight_shape = x_shape[:-2] + (out_channel, in_channel)
broad_bias_shape = x_shape[:-1] + (out_channel,)
weight_broadcast_to = P.BroadcastTo(broad_weight_shape)
bias_broadcast_to = P.BroadcastTo(broad_bias_shape)
if x_dim == 2:
matmul = P.MatMul(False, True)
bias_add = P.BiasAdd()
else:
matmul = P.BatchMatMul(False, True)
bias_add = P.TensorAdd()
return matmul, bias_add, weight_broadcast_to, bias_broadcast_to
class Dense(Cell):
r"""
The dense connected layer.
@ -206,6 +222,7 @@ class Dense(Cell):
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
self.has_bias = Validator.check_bool(has_bias)
self.shape_op = P.Shape()
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@ -219,28 +236,33 @@ class Dense(Cell):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
raise ValueError("Bias init shape error.")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
self.bias_add = P.BiasAdd()
self.matmul = P.MatMul(transpose_b=True)
self.activation = get_activation(activation) if isinstance(activation, str) else activation
if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
self.activation_flag = self.activation is not None
def construct(self, x):
x = self.matmul(x, self.weight)
x_shape = self.shape_op(x)
x_dim = len(x_shape)
matmul, bias_add, weight_broadcast_to, bias_broadcast_to = matmul_bias_select(x_shape, self.in_channels,
self.out_channels)
weight = self.weight if x_dim == 2 else weight_broadcast_to(self.weight)
x = matmul(x, weight)
if self.has_bias:
x = self.bias_add(x, self.bias)
bias = self.bias if x_dim == 2 else bias_broadcast_to(self.bias)
x = bias_add(x, bias)
if self.activation_flag:
x = self.activation(x)
return x
def extend_repr(self):
s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
if self.has_bias:
s += ', has_bias={}'.format(self.has_bias)
if self.activation_flag:
s += ', activation={}'.fomat(self.activation)
s += ', activation={}'.format(self.activation)
return s

View File

@ -21,10 +21,34 @@ from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class DenseDevice(nn.Cell):
def __init__(self,
in_channels,
out_channels,
device_target='Ascend',
weight_init='normal',
bias_init='zeros'):
super(DenseDevice, self).__init__()
self.device_target = device_target
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
self.bias_add = P.BiasAdd()
self.matmul = P.MatMul(transpose_b=True)
self.matmul.add_prim_attr("primitive_target", self.device_target)
self.bias_add.add_prim_attr("primitive_target", self.device_target)
def construct(self, x):
x = self.matmul(x, self.weight)
x = self.bias_add(x, self.bias)
return x
class LeNet(nn.Cell):
def __init__(self):
super(LeNet, self).__init__()
@ -35,15 +59,9 @@ class LeNet(nn.Cell):
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.fc1 = nn.Dense(400, 120)
self.fc1.matmul.add_prim_attr("primitive_target", "CPU")
self.fc1.bias_add.add_prim_attr("primitive_target", "CPU")
self.fc2 = nn.Dense(120, 84)
self.fc2.matmul.add_prim_attr("primitive_target", "CPU")
self.fc2.bias_add.add_prim_attr("primitive_target", "CPU")
self.fc3 = nn.Dense(84, 10)
self.fc3.matmul.add_prim_attr("primitive_target", "CPU")
self.fc3.bias_add.add_prim_attr("primitive_target", "CPU")
self.fc1 = DenseDevice(400, 120, device_target='CPU')
self.fc2 = DenseDevice(120, 84, device_target='CPU')
self.fc3 = DenseDevice(84, 10, device_target='CPU')
def construct(self, input_x):
output = self.conv1(input_x)

View File

@ -38,3 +38,11 @@ def test_net():
output = net(Tensor(x))
print(x)
print(output.asnumpy())
def test_net_ND():
x = np.random.randn(2, 332, 2048).astype(np.float32)
net = Net()
output = net(Tensor(x))
print(x)
print(output.asnumpy())

View File

@ -49,3 +49,10 @@ def test_net():
net = Grad(Net())
output = net(Tensor(x), Tensor(sens))
print(output.asnumpy())
def test_net_ND():
x = np.random.randn(2, 32, 2048).astype(np.float32)
sens = np.random.randn(2, 32, 1001).astype(np.float32)
net = Grad(Net())
output = net(Tensor(x), Tensor(sens))
print(output.asnumpy())

View File

@ -128,6 +128,47 @@ def test_dx():
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dx_ND():
x = np.array([[[0.1, 0.2, 0.3, 0.4],
[0.1, 0.2, 0.3, 0.4],
[0.1, 0.2, 0.3, 0.4]],
[[0.1, 0.2, 0.3, 0.4],
[0.1, 0.2, 0.3, 0.4],
[0.1, 0.2, 0.3, 0.4]]
]).astype(np.float32)
dy = np.array([[[1, 1],
[1, 1],
[1, 1]],
[[1, 1],
[1, 1],
[1, 1]]]).astype(np.float32)
dx_expect = np.array([[[1.1, 1.8, 1.1, 1.1],
[1.1, 1.8, 1.1, 1.1],
[1.1, 1.8, 1.1, 1.1]],
[[1.1, 1.8, 1.1, 1.1],
[1.1, 1.8, 1.1, 1.1],
[1.1, 1.8, 1.1, 1.1]]
]).astype(np.float32)
error = np.ones(shape=[2, 3, 4]) * 1.0e-6
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
net = GradData(DenseNet())
dx = net(Tensor(x), Tensor(dy))
diff = dx[0].asnumpy() - dx_expect
assert np.all(diff < error)
assert np.all(-diff < error)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = GradData(DenseNet())
dx = net(Tensor(x), Tensor(dy))
diff = dx[0].asnumpy() - dx_expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -165,6 +206,49 @@ def test_dw():
assert np.all(-diff < db_error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dw_ND():
x = np.array([[[0.1, 0.2, 0.3, 0.4],
[0.1, 0.2, 0.3, 0.4],
[0.1, 0.2, 0.3, 0.4]],
[[0.1, 0.2, 0.3, 0.4],
[0.1, 0.2, 0.3, 0.4],
[0.1, 0.2, 0.3, 0.4]]]).astype(np.float32)
dy = np.array([[[1, 1],
[1, 1],
[1, 1]],
[[1, 1],
[1, 1],
[1, 1]]]).astype(np.float32)
dw_expect = 2 * np.array([[0.3, 0.6, 0.9, 1.2],
[0.3, 0.6, 0.9, 1.2]]).astype(np.float32)
dw_error = np.ones(shape=[2, 4]) * 1.0e-6
db_expect = 2 * np.array([3, 3]).astype(np.float32)
db_error = np.ones(shape=[2]) * 1.0e-6
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
net = GradWeight(DenseNet())
dw, db = net(Tensor(x), Tensor(dy))
diff = dw.asnumpy() - dw_expect
assert np.all(diff < dw_error)
assert np.all(-diff < dw_error)
diff = db.asnumpy() - db_expect
assert np.all(diff < db_error)
assert np.all(-diff < db_error)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = GradWeight(DenseNet())
dw, db = net(Tensor(x), Tensor(dy))
diff = dw.asnumpy() - dw_expect
assert np.all(diff < dw_error)
assert np.all(-diff < dw_error)
diff = db.asnumpy() - db_expect
assert np.all(diff < db_error)
assert np.all(-diff < db_error)
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()