forked from mindspore-Ecosystem/mindspore
use Reverse and ReverseSequence operator on CPU
This commit is contained in:
parent
8876c073bb
commit
9cde823c82
|
@ -1,81 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Utils for RNNs CPU version, like Reverse operators."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.ops as P
|
||||
from mindspore.nn.cell import Cell
|
||||
|
||||
|
||||
class _Reverse(Cell):
|
||||
"""Reverse operator, like Reverse in mindspore"""
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def construct(self, input_x):
|
||||
dim_size = input_x.shape[self.dim]
|
||||
reversed_indexes = P.arange(dim_size-1, -1, -1)
|
||||
output = P.Gather()(input_x, reversed_indexes, self.dim)
|
||||
return output
|
||||
|
||||
|
||||
class _ReverseSequence(Cell):
|
||||
"""Reverse sequence operator, like ReverseSequenceV2 in mindspore"""
|
||||
def __init__(self, seq_dim, batch_dim=0):
|
||||
super().__init__()
|
||||
self.seq_dim = seq_dim
|
||||
self.batch_dim = batch_dim
|
||||
|
||||
|
||||
@staticmethod
|
||||
def make_shape(shape, dtype, range_dim):
|
||||
"""Calculates the shape according by the inputs."""
|
||||
output = P.Ones()(shape, mstype.float32)
|
||||
output = P.CumSum()(output, range_dim)
|
||||
output = P.Cast()(output, dtype)
|
||||
output = output - 1
|
||||
return output
|
||||
|
||||
def construct(self, x, seq_lengths):
|
||||
"""Defines the ReverseSequence operator computation performed."""
|
||||
batch_size = x.shape[self.batch_dim]
|
||||
max_seq_len = x.shape[self.seq_dim]
|
||||
seq_lens_type = seq_lengths.dtype
|
||||
|
||||
back = P.Sub()(seq_lengths, P.OnesLike()(seq_lengths))
|
||||
|
||||
batch_idx = self.make_shape((batch_size, max_seq_len), seq_lens_type, 0)
|
||||
forward_idx = self.make_shape((batch_size, max_seq_len), seq_lens_type, 1)
|
||||
|
||||
back = back.view(-1, 1)
|
||||
reverse_idx = P.Sub()(back, forward_idx)
|
||||
|
||||
condition = P.Less()(reverse_idx, P.ZerosLike()(reverse_idx))
|
||||
reverse_idx = P.Select()(condition, forward_idx, reverse_idx)
|
||||
|
||||
reverse_idx = P.ExpandDims()(reverse_idx, 2)
|
||||
batch_idx = P.ExpandDims()(batch_idx, 2)
|
||||
|
||||
if self.batch_dim > self.seq_dim:
|
||||
batch_idx = P.Transpose()(batch_idx, (1, 0, 2))
|
||||
reverse_idx = P.Transpose()(reverse_idx, (1, 0, 2))
|
||||
x = P.Transpose()(x, (1, 0, 2))
|
||||
start_indices = P.Concat(2)((batch_idx, reverse_idx))
|
||||
|
||||
output = P.GatherNd()(x, start_indices)
|
||||
|
||||
return output
|
|
@ -31,7 +31,6 @@ from mindspore import log as logger
|
|||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.ops.operations._rl_inner_ops import CudnnGRU
|
||||
from mindspore.nn.layer.rnn_cells import _rnn_relu_cell, _rnn_tanh_cell, _gru_cell, _lstm_cell
|
||||
from mindspore.nn.layer.rnn_utils import _Reverse, _ReverseSequence
|
||||
|
||||
__all__ = ['LSTM', 'GRU', 'RNN']
|
||||
|
||||
|
@ -403,12 +402,8 @@ class _RNNBase(Cell):
|
|||
raise ValueError(f"For '{self.cls_name}', the 'mode' must be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], "
|
||||
f"but got {mode}.")
|
||||
|
||||
if context.get_context("device_target") == "CPU":
|
||||
self.reverse = _Reverse(0)
|
||||
self.reverse_sequence = _ReverseSequence(0, 1)
|
||||
else:
|
||||
self.reverse = P.ReverseV2([0])
|
||||
self.reverse_sequence = P.ReverseSequence(0, 1)
|
||||
self.reverse = P.ReverseV2([0])
|
||||
self.reverse_sequence = P.ReverseSequence(0, 1)
|
||||
self.hidden_size = hidden_size
|
||||
self.batch_first = batch_first
|
||||
self.num_layers = num_layers
|
||||
|
|
Loading…
Reference in New Issue