forked from mindspore-Ecosystem/mindspore
!5994 [MSLITE][Develop] reverse_seqence seq_lengths support int64
Merge pull request !5994 from sunsuodong/fix_reverse_seqence
This commit is contained in:
commit
7913f8361c
|
@ -18,7 +18,7 @@
|
|||
#include <string.h>
|
||||
#include "nnacl/arithmetic_common.h"
|
||||
|
||||
void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceParameter *para) {
|
||||
void ReverseSequence(float *input0, void *input1, float *output, ReverseSequenceParameter *para) {
|
||||
(void)memcpy(output, input0, para->total_data_size_);
|
||||
ComputeStrides(para->input_shape0_, para->input_stride_, para->ndim_);
|
||||
ComputeStrides(para->output_shape_, para->output_stride_, para->ndim_);
|
||||
|
@ -28,8 +28,9 @@ void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceP
|
|||
for (int batch = 0; batch < para->input_shape0_[para->batch_axis_]; batch++) {
|
||||
float *in_batch = in + batch * para->input_stride_[para->batch_axis_];
|
||||
float *out_batch = out + batch * para->output_stride_[para->batch_axis_];
|
||||
for (int n = 0; n < input1[batch]; ++n) {
|
||||
float *in_seq = in_batch + (input1[batch] - 1 - n) * para->input_stride_[para->seq_axis_];
|
||||
int32_t seq_length = para->is_seq_length_int32_ ? *((int32_t *)input1 + batch) : *((int64_t *)input1 + batch);
|
||||
for (int n = 0; n < seq_length; ++n) {
|
||||
float *in_seq = in_batch + (seq_length - 1 - n) * para->input_stride_[para->seq_axis_];
|
||||
float *out_seq = out_batch + n * para->output_stride_[para->seq_axis_];
|
||||
for (int j = 0; j < para->inner_count_; ++j) {
|
||||
(void)memcpy(out_seq + j * para->inner_stride_, in_seq + j * para->inner_stride_, para->copy_byte_size_);
|
||||
|
|
|
@ -34,12 +34,13 @@ typedef struct ReverseSequenceParameter {
|
|||
int inner_stride_;
|
||||
int copy_byte_size_;
|
||||
int total_data_size_;
|
||||
bool is_seq_length_int32_;
|
||||
} ReverseSequenceParameter;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceParameter *para);
|
||||
void ReverseSequence(float *input0, void *input1, float *output, ReverseSequenceParameter *para);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -93,9 +93,11 @@ int ReverseSequenceCPUKernel::Run() {
|
|||
return ret;
|
||||
}
|
||||
float *input0 = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
|
||||
int *input1 = reinterpret_cast<int *>(in_tensors_.at(1)->MutableData());
|
||||
void *input1 = in_tensors_.at(1)->MutableData();
|
||||
float *output = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
|
||||
ReverseSequence(input0, input1, output, reinterpret_cast<ReverseSequenceParameter *>(op_parameter_));
|
||||
ReverseSequenceParameter *param = reinterpret_cast<ReverseSequenceParameter *>(op_parameter_);
|
||||
param->is_seq_length_int32_ = in_tensors_.at(1)->data_type() == kNumberTypeInt32;
|
||||
ReverseSequence(input0, input1, output, param);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue