diff --git a/tests/st/dynamic_shape/test_dynamic_wenet_ascend.py b/tests/st/dynamic_shape/test_dynamic_wenet_ascend.py index b36ef0335b8..982fc39bd1e 100644 --- a/tests/st/dynamic_shape/test_dynamic_wenet_ascend.py +++ b/tests/st/dynamic_shape/test_dynamic_wenet_ascend.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -576,13 +576,8 @@ class CustomDense(nn.Dense): """Initialize CustomDense.""" super(CustomDense, self).__init__(in_channels, out_channels, weight_init, bias_init, has_bias, activation) - self.scatterupdate = ops.TensorScatterUpdate() self.dyn_shape = ops.TensorShape() - self.mul = ops.Mul() self.cast = ops.Cast() - self.indices_0 = Tensor(np.array([[0]]), mstype.int32) - self.indices_1 = Tensor(np.array([[-1]]), mstype.int32) - self.indices_2 = Tensor(np.array([[2]]), mstype.int32) def construct(self, x): x_shape = self.shape_op(x) @@ -591,9 +586,7 @@ class CustomDense(nn.Dense): x_dyn_shape = self.cast(x_dyn_shape, mstype.float32) if len(x_dyn_shape) != 2: new_shape = x_dyn_shape[1:] - updates = self.mul(x_dyn_shape[0:1], x_dyn_shape[1:2]) - new_shape = self.scatterupdate( - new_shape, self.indices_0, updates) + new_shape[0] = x_dyn_shape[0:1] * x_dyn_shape[1:2] new_shape = self.cast(new_shape, mstype.int64) x = self.reshape(x, new_shape) x = self.matmul(x, self.weight) @@ -604,9 +597,7 @@ class CustomDense(nn.Dense): if len(x_dyn_shape) != 2: out_shape = self.dyn_shape(x) out_shape = self.cast(out_shape, mstype.float32) - updates = out_shape[1:2] - x_dyn_shape = self.scatterupdate( - x_dyn_shape, self.indices_2, updates) + x_dyn_shape[2] = out_shape[1:2] x_dyn_shape = self.cast(x_dyn_shape, mstype.int64) x = self.reshape(x, x_dyn_shape) else: @@ -1342,25 +1333,12 @@ class PositionalEncoding(nn.Cell): self.pe = Tensor(np.expand_dims(self.pe, 0), mstype.float32) self.get_shape = ops.Shape() self.dyn_shape = ops.TensorShape() - self.stridedslice = ops.StridedSlice() - self.scatterupdate = ops.TensorScatterUpdate() - self.indices_1 = Tensor(([[1]]), mstype.int32) - self.end = Tensor( - (self.pe.shape[0], 0, self.pe.shape[2]), mstype.float32) def construct(self, x, offset=0) -> Tuple[mindspore.Tensor, mindspore.Tensor]: x_shape = self.get_shape(x) - if -1 not in x_shape: - pos_emb = self.pe[:, offset: offset + x_shape[1]] - else: - x_dyn_shape = self.dyn_shape(x) - x_dyn_shape = self.cast(x_dyn_shape, mstype.float32) - begin = (0, offset, 0) - end = self.scatterupdate( - self.end, self.indices_1, offset + x_dyn_shape[1:2]) - end = self.cast(end, mstype.int64) - step = (1, 1, 1) - pos_emb = self.stridedslice(self.pe, begin, end, step) + if -1 in x_shape: + x_shape = self.dyn_shape(x) + pos_emb = self.pe[:, offset: offset + x_shape[1]] x = x * self.xscale + pos_emb return self.dropout(x), self.dropout(pos_emb) @@ -1387,17 +1365,9 @@ class RelPositionalEncoding(PositionalEncoding): """ x = x * self.xscale x_shape = self.get_shape(x) - if -1 not in x_shape: - pos_emb = self.pe[:, offset: offset + x_shape[1]] - else: - x_dyn_shape = self.dyn_shape(x) - x_dyn_shape = self.cast(x_dyn_shape, mstype.float32) - begin = (0, offset, 0) - end = self.scatterupdate( - self.end, self.indices_1, offset + x_dyn_shape[1:2]) - end = self.cast(end, mstype.int64) - step = (1, 1, 1) - pos_emb = self.stridedslice(self.pe, begin, end, step) + if -1 in x_shape: + x_shape = self.dyn_shape(x) + pos_emb = self.pe[:, offset: offset + x_shape[1]] return self.dropout(x), self.dropout(pos_emb) @@ -1840,11 +1810,10 @@ def test_train(): logging.info("Training start.") model.train( - MAX_EPOCH * steps_size, + MAX_EPOCH, train_dataset, callbacks=callback, - dataset_sink_mode=True, - sink_size=1, + dataset_sink_mode=True ) train_loss = callback.loss