simplify dynamic model using the new feature: dynamic-shaped Tensor index operations
This commit is contained in:
parent
6eb72535ec
commit
fdfb033d1c
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -576,13 +576,8 @@ class CustomDense(nn.Dense):
|
|||
"""Initialize CustomDense."""
|
||||
super(CustomDense, self).__init__(in_channels, out_channels,
|
||||
weight_init, bias_init, has_bias, activation)
|
||||
self.scatterupdate = ops.TensorScatterUpdate()
|
||||
self.dyn_shape = ops.TensorShape()
|
||||
self.mul = ops.Mul()
|
||||
self.cast = ops.Cast()
|
||||
self.indices_0 = Tensor(np.array([[0]]), mstype.int32)
|
||||
self.indices_1 = Tensor(np.array([[-1]]), mstype.int32)
|
||||
self.indices_2 = Tensor(np.array([[2]]), mstype.int32)
|
||||
|
||||
def construct(self, x):
|
||||
x_shape = self.shape_op(x)
|
||||
|
@ -591,9 +586,7 @@ class CustomDense(nn.Dense):
|
|||
x_dyn_shape = self.cast(x_dyn_shape, mstype.float32)
|
||||
if len(x_dyn_shape) != 2:
|
||||
new_shape = x_dyn_shape[1:]
|
||||
updates = self.mul(x_dyn_shape[0:1], x_dyn_shape[1:2])
|
||||
new_shape = self.scatterupdate(
|
||||
new_shape, self.indices_0, updates)
|
||||
new_shape[0] = x_dyn_shape[0:1] * x_dyn_shape[1:2]
|
||||
new_shape = self.cast(new_shape, mstype.int64)
|
||||
x = self.reshape(x, new_shape)
|
||||
x = self.matmul(x, self.weight)
|
||||
|
@ -604,9 +597,7 @@ class CustomDense(nn.Dense):
|
|||
if len(x_dyn_shape) != 2:
|
||||
out_shape = self.dyn_shape(x)
|
||||
out_shape = self.cast(out_shape, mstype.float32)
|
||||
updates = out_shape[1:2]
|
||||
x_dyn_shape = self.scatterupdate(
|
||||
x_dyn_shape, self.indices_2, updates)
|
||||
x_dyn_shape[2] = out_shape[1:2]
|
||||
x_dyn_shape = self.cast(x_dyn_shape, mstype.int64)
|
||||
x = self.reshape(x, x_dyn_shape)
|
||||
else:
|
||||
|
@ -1342,25 +1333,12 @@ class PositionalEncoding(nn.Cell):
|
|||
self.pe = Tensor(np.expand_dims(self.pe, 0), mstype.float32)
|
||||
self.get_shape = ops.Shape()
|
||||
self.dyn_shape = ops.TensorShape()
|
||||
self.stridedslice = ops.StridedSlice()
|
||||
self.scatterupdate = ops.TensorScatterUpdate()
|
||||
self.indices_1 = Tensor(([[1]]), mstype.int32)
|
||||
self.end = Tensor(
|
||||
(self.pe.shape[0], 0, self.pe.shape[2]), mstype.float32)
|
||||
|
||||
def construct(self, x, offset=0) -> Tuple[mindspore.Tensor, mindspore.Tensor]:
|
||||
x_shape = self.get_shape(x)
|
||||
if -1 not in x_shape:
|
||||
pos_emb = self.pe[:, offset: offset + x_shape[1]]
|
||||
else:
|
||||
x_dyn_shape = self.dyn_shape(x)
|
||||
x_dyn_shape = self.cast(x_dyn_shape, mstype.float32)
|
||||
begin = (0, offset, 0)
|
||||
end = self.scatterupdate(
|
||||
self.end, self.indices_1, offset + x_dyn_shape[1:2])
|
||||
end = self.cast(end, mstype.int64)
|
||||
step = (1, 1, 1)
|
||||
pos_emb = self.stridedslice(self.pe, begin, end, step)
|
||||
if -1 in x_shape:
|
||||
x_shape = self.dyn_shape(x)
|
||||
pos_emb = self.pe[:, offset: offset + x_shape[1]]
|
||||
x = x * self.xscale + pos_emb
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
|
@ -1387,17 +1365,9 @@ class RelPositionalEncoding(PositionalEncoding):
|
|||
"""
|
||||
x = x * self.xscale
|
||||
x_shape = self.get_shape(x)
|
||||
if -1 not in x_shape:
|
||||
pos_emb = self.pe[:, offset: offset + x_shape[1]]
|
||||
else:
|
||||
x_dyn_shape = self.dyn_shape(x)
|
||||
x_dyn_shape = self.cast(x_dyn_shape, mstype.float32)
|
||||
begin = (0, offset, 0)
|
||||
end = self.scatterupdate(
|
||||
self.end, self.indices_1, offset + x_dyn_shape[1:2])
|
||||
end = self.cast(end, mstype.int64)
|
||||
step = (1, 1, 1)
|
||||
pos_emb = self.stridedslice(self.pe, begin, end, step)
|
||||
if -1 in x_shape:
|
||||
x_shape = self.dyn_shape(x)
|
||||
pos_emb = self.pe[:, offset: offset + x_shape[1]]
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
|
||||
|
@ -1840,11 +1810,10 @@ def test_train():
|
|||
logging.info("Training start.")
|
||||
|
||||
model.train(
|
||||
MAX_EPOCH * steps_size,
|
||||
MAX_EPOCH,
|
||||
train_dataset,
|
||||
callbacks=callback,
|
||||
dataset_sink_mode=True,
|
||||
sink_size=1,
|
||||
dataset_sink_mode=True
|
||||
)
|
||||
|
||||
train_loss = callback.loss
|
||||
|
|
Loading…
Reference in New Issue