forked from mindspore-Ecosystem/mindspore
use Reverse and ReverseSequence operator on CPU
This commit is contained in:
parent
8876c073bb
commit
9cde823c82
|
@ -1,81 +0,0 @@
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""Utils for RNNs CPU version, like Reverse operators."""
|
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import mindspore.common.dtype as mstype
|
|
||||||
import mindspore.ops as P
|
|
||||||
from mindspore.nn.cell import Cell
|
|
||||||
|
|
||||||
|
|
||||||
class _Reverse(Cell):
|
|
||||||
"""Reverse operator, like Reverse in mindspore"""
|
|
||||||
def __init__(self, dim):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
|
|
||||||
def construct(self, input_x):
|
|
||||||
dim_size = input_x.shape[self.dim]
|
|
||||||
reversed_indexes = P.arange(dim_size-1, -1, -1)
|
|
||||||
output = P.Gather()(input_x, reversed_indexes, self.dim)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class _ReverseSequence(Cell):
|
|
||||||
"""Reverse sequence operator, like ReverseSequenceV2 in mindspore"""
|
|
||||||
def __init__(self, seq_dim, batch_dim=0):
|
|
||||||
super().__init__()
|
|
||||||
self.seq_dim = seq_dim
|
|
||||||
self.batch_dim = batch_dim
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def make_shape(shape, dtype, range_dim):
|
|
||||||
"""Calculates the shape according by the inputs."""
|
|
||||||
output = P.Ones()(shape, mstype.float32)
|
|
||||||
output = P.CumSum()(output, range_dim)
|
|
||||||
output = P.Cast()(output, dtype)
|
|
||||||
output = output - 1
|
|
||||||
return output
|
|
||||||
|
|
||||||
def construct(self, x, seq_lengths):
|
|
||||||
"""Defines the ReverseSequence operator computation performed."""
|
|
||||||
batch_size = x.shape[self.batch_dim]
|
|
||||||
max_seq_len = x.shape[self.seq_dim]
|
|
||||||
seq_lens_type = seq_lengths.dtype
|
|
||||||
|
|
||||||
back = P.Sub()(seq_lengths, P.OnesLike()(seq_lengths))
|
|
||||||
|
|
||||||
batch_idx = self.make_shape((batch_size, max_seq_len), seq_lens_type, 0)
|
|
||||||
forward_idx = self.make_shape((batch_size, max_seq_len), seq_lens_type, 1)
|
|
||||||
|
|
||||||
back = back.view(-1, 1)
|
|
||||||
reverse_idx = P.Sub()(back, forward_idx)
|
|
||||||
|
|
||||||
condition = P.Less()(reverse_idx, P.ZerosLike()(reverse_idx))
|
|
||||||
reverse_idx = P.Select()(condition, forward_idx, reverse_idx)
|
|
||||||
|
|
||||||
reverse_idx = P.ExpandDims()(reverse_idx, 2)
|
|
||||||
batch_idx = P.ExpandDims()(batch_idx, 2)
|
|
||||||
|
|
||||||
if self.batch_dim > self.seq_dim:
|
|
||||||
batch_idx = P.Transpose()(batch_idx, (1, 0, 2))
|
|
||||||
reverse_idx = P.Transpose()(reverse_idx, (1, 0, 2))
|
|
||||||
x = P.Transpose()(x, (1, 0, 2))
|
|
||||||
start_indices = P.Concat(2)((batch_idx, reverse_idx))
|
|
||||||
|
|
||||||
output = P.GatherNd()(x, start_indices)
|
|
||||||
|
|
||||||
return output
|
|
|
@ -31,7 +31,6 @@ from mindspore import log as logger
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from mindspore.ops.operations._rl_inner_ops import CudnnGRU
|
from mindspore.ops.operations._rl_inner_ops import CudnnGRU
|
||||||
from mindspore.nn.layer.rnn_cells import _rnn_relu_cell, _rnn_tanh_cell, _gru_cell, _lstm_cell
|
from mindspore.nn.layer.rnn_cells import _rnn_relu_cell, _rnn_tanh_cell, _gru_cell, _lstm_cell
|
||||||
from mindspore.nn.layer.rnn_utils import _Reverse, _ReverseSequence
|
|
||||||
|
|
||||||
__all__ = ['LSTM', 'GRU', 'RNN']
|
__all__ = ['LSTM', 'GRU', 'RNN']
|
||||||
|
|
||||||
|
@ -403,12 +402,8 @@ class _RNNBase(Cell):
|
||||||
raise ValueError(f"For '{self.cls_name}', the 'mode' must be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], "
|
raise ValueError(f"For '{self.cls_name}', the 'mode' must be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], "
|
||||||
f"but got {mode}.")
|
f"but got {mode}.")
|
||||||
|
|
||||||
if context.get_context("device_target") == "CPU":
|
self.reverse = P.ReverseV2([0])
|
||||||
self.reverse = _Reverse(0)
|
self.reverse_sequence = P.ReverseSequence(0, 1)
|
||||||
self.reverse_sequence = _ReverseSequence(0, 1)
|
|
||||||
else:
|
|
||||||
self.reverse = P.ReverseV2([0])
|
|
||||||
self.reverse_sequence = P.ReverseSequence(0, 1)
|
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.batch_first = batch_first
|
self.batch_first = batch_first
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
|
Loading…
Reference in New Issue