simplify dynamic model using the new feature: dynamic-shaped Tensor index operations

This commit is contained in:
zhengzuohe 2022-08-09 17:46:09 +08:00
parent 6eb72535ec
commit fdfb033d1c
1 changed files with 11 additions and 42 deletions

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -576,13 +576,8 @@ class CustomDense(nn.Dense):
"""Initialize CustomDense."""
super(CustomDense, self).__init__(in_channels, out_channels,
weight_init, bias_init, has_bias, activation)
self.scatterupdate = ops.TensorScatterUpdate()
self.dyn_shape = ops.TensorShape()
self.mul = ops.Mul()
self.cast = ops.Cast()
self.indices_0 = Tensor(np.array([[0]]), mstype.int32)
self.indices_1 = Tensor(np.array([[-1]]), mstype.int32)
self.indices_2 = Tensor(np.array([[2]]), mstype.int32)
def construct(self, x):
x_shape = self.shape_op(x)
@ -591,9 +586,7 @@ class CustomDense(nn.Dense):
x_dyn_shape = self.cast(x_dyn_shape, mstype.float32)
if len(x_dyn_shape) != 2:
new_shape = x_dyn_shape[1:]
updates = self.mul(x_dyn_shape[0:1], x_dyn_shape[1:2])
new_shape = self.scatterupdate(
new_shape, self.indices_0, updates)
new_shape[0] = x_dyn_shape[0:1] * x_dyn_shape[1:2]
new_shape = self.cast(new_shape, mstype.int64)
x = self.reshape(x, new_shape)
x = self.matmul(x, self.weight)
@ -604,9 +597,7 @@ class CustomDense(nn.Dense):
if len(x_dyn_shape) != 2:
out_shape = self.dyn_shape(x)
out_shape = self.cast(out_shape, mstype.float32)
updates = out_shape[1:2]
x_dyn_shape = self.scatterupdate(
x_dyn_shape, self.indices_2, updates)
x_dyn_shape[2] = out_shape[1:2]
x_dyn_shape = self.cast(x_dyn_shape, mstype.int64)
x = self.reshape(x, x_dyn_shape)
else:
@ -1342,25 +1333,12 @@ class PositionalEncoding(nn.Cell):
self.pe = Tensor(np.expand_dims(self.pe, 0), mstype.float32)
self.get_shape = ops.Shape()
self.dyn_shape = ops.TensorShape()
self.stridedslice = ops.StridedSlice()
self.scatterupdate = ops.TensorScatterUpdate()
self.indices_1 = Tensor(([[1]]), mstype.int32)
self.end = Tensor(
(self.pe.shape[0], 0, self.pe.shape[2]), mstype.float32)
def construct(self, x, offset=0) -> Tuple[mindspore.Tensor, mindspore.Tensor]:
x_shape = self.get_shape(x)
if -1 not in x_shape:
pos_emb = self.pe[:, offset: offset + x_shape[1]]
else:
x_dyn_shape = self.dyn_shape(x)
x_dyn_shape = self.cast(x_dyn_shape, mstype.float32)
begin = (0, offset, 0)
end = self.scatterupdate(
self.end, self.indices_1, offset + x_dyn_shape[1:2])
end = self.cast(end, mstype.int64)
step = (1, 1, 1)
pos_emb = self.stridedslice(self.pe, begin, end, step)
if -1 in x_shape:
x_shape = self.dyn_shape(x)
pos_emb = self.pe[:, offset: offset + x_shape[1]]
x = x * self.xscale + pos_emb
return self.dropout(x), self.dropout(pos_emb)
@ -1387,17 +1365,9 @@ class RelPositionalEncoding(PositionalEncoding):
"""
x = x * self.xscale
x_shape = self.get_shape(x)
if -1 not in x_shape:
pos_emb = self.pe[:, offset: offset + x_shape[1]]
else:
x_dyn_shape = self.dyn_shape(x)
x_dyn_shape = self.cast(x_dyn_shape, mstype.float32)
begin = (0, offset, 0)
end = self.scatterupdate(
self.end, self.indices_1, offset + x_dyn_shape[1:2])
end = self.cast(end, mstype.int64)
step = (1, 1, 1)
pos_emb = self.stridedslice(self.pe, begin, end, step)
if -1 in x_shape:
x_shape = self.dyn_shape(x)
pos_emb = self.pe[:, offset: offset + x_shape[1]]
return self.dropout(x), self.dropout(pos_emb)
@ -1840,11 +1810,10 @@ def test_train():
logging.info("Training start.")
model.train(
MAX_EPOCH * steps_size,
MAX_EPOCH,
train_dataset,
callbacks=callback,
dataset_sink_mode=True,
sink_size=1,
dataset_sink_mode=True
)
train_loss = callback.loss