forked from mindspore-Ecosystem/mindspore
first commit - ReverseSequence - GPU
Added additional comments + added test case + lint linting + fixed dyn cast to static cast comment add MindTester case fix - 0 edge case + improved comments addressing some comments + malloc remove from kernel Added test case for 0 edge case testing
This commit is contained in:
parent
05bbd1d8fa
commit
e4d33f1195
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/reverse_sequence_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// int8
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
|
||||
ReverseSequenceGpuFwdKernel, int8_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
|
||||
ReverseSequenceGpuFwdKernel, int8_t, int64_t)
|
||||
// int16
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
|
||||
ReverseSequenceGpuFwdKernel, int16_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
|
||||
ReverseSequenceGpuFwdKernel, int16_t, int64_t)
|
||||
// int32
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ReverseSequenceGpuFwdKernel, int, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
ReverseSequenceGpuFwdKernel, int, int64_t)
|
||||
// int64
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
ReverseSequenceGpuFwdKernel, int64_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ReverseSequenceGpuFwdKernel, int64_t, int64_t)
|
||||
// float16
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
ReverseSequenceGpuFwdKernel, half, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
ReverseSequenceGpuFwdKernel, half, int64_t)
|
||||
// float32
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReverseSequenceGpuFwdKernel, float, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReverseSequenceGpuFwdKernel, float, int64_t)
|
||||
// float64
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
ReverseSequenceGpuFwdKernel, double, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ReverseSequenceGpuFwdKernel, double, int64_t)
|
||||
// bool
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
ReverseSequenceGpuFwdKernel, bool, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ReverseSequence,
|
||||
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
ReverseSequenceGpuFwdKernel, bool, int64_t)
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,130 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_REVERSE_SEQUENCE_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_REVERSE_SEQUENCE_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/reverse_sequence_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
class ReverseSequenceGpuFwdKernel : public GpuKernel {
|
||||
public:
|
||||
ReverseSequenceGpuFwdKernel()
|
||||
: shape_size_(0),
|
||||
input_size_(0),
|
||||
batch_dim_(0),
|
||||
seq_dim_(0),
|
||||
seq_len_size_(0),
|
||||
total_index_dim_(0),
|
||||
output_size_(0),
|
||||
workspace_size_(0) {}
|
||||
~ReverseSequenceGpuFwdKernel() override = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
S *seq_len = GetDeviceAddress<S>(inputs, 1);
|
||||
size_t *input_shape_ptr = GetDeviceAddress<size_t>(workspace, 0);
|
||||
size_t *input_cum_shape_ptr = GetDeviceAddress<size_t>(workspace, 1);
|
||||
size_t *cur_pos_arr = GetDeviceAddress<size_t>(workspace, 2);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(input_shape_ptr, &input_shape_[0], input_shape_.size() * sizeof(size_t),
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync input_shape_ failed");
|
||||
CalReverseSequence(input_size_, input, seq_len, batch_dim_, seq_dim_, cur_pos_arr, input_shape_ptr,
|
||||
input_cum_shape_ptr, shape_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
batch_dim_ = GetAttr<int64_t>(kernel_node, "batch_dim");
|
||||
seq_dim_ = GetAttr<int64_t>(kernel_node, "seq_dim");
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but ReverseSequence needs 2 input.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but ReverseSequence needs 1 output.";
|
||||
return false;
|
||||
}
|
||||
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (CHECK_NULL_INPUT(input_shape_)) {
|
||||
MS_LOG(WARNING) << "ReverseSequence input is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
input_size_ = 1;
|
||||
shape_size_ = input_shape_.size(); // required for calls
|
||||
for (size_t i = 0; i < shape_size_; i++) {
|
||||
input_size_ *= input_shape_[i];
|
||||
}
|
||||
// get seq len shape
|
||||
auto seq_len_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
if (CHECK_NULL_INPUT(seq_len_shape)) {
|
||||
MS_LOG(WARNING) << "ReverseSequence seq lengths input is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
seq_len_size_ = seq_len_shape.size();
|
||||
output_size_ = input_size_; // size does not change
|
||||
// Allocate workspace memory to use for storing indices for each thread to compute with
|
||||
size_t total_threads = GET_BLOCKS(input_size_) * GET_THREADS;
|
||||
total_index_dim_ = total_threads * shape_size_;
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_ * sizeof(T));
|
||||
input_size_list_.push_back(seq_len_size_ * sizeof(S));
|
||||
workspace_size_list_.push_back(shape_size_ * sizeof(size_t)); // input_shape
|
||||
workspace_size_list_.push_back(shape_size_ * sizeof(size_t)); // cumulative shape
|
||||
workspace_size_list_.push_back(total_index_dim_ * sizeof(size_t)); // scratch memory for holding indices per thread
|
||||
output_size_list_.push_back(output_size_ * sizeof(T));
|
||||
}
|
||||
|
||||
private:
|
||||
size_t shape_size_;
|
||||
size_t input_size_;
|
||||
int64_t batch_dim_;
|
||||
int64_t seq_dim_;
|
||||
size_t seq_len_size_;
|
||||
size_t total_index_dim_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_REVERSE_SEQUENCE_GPU_KERNEL_H_
|
|
@ -0,0 +1,159 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/reverse_sequence_impl.cuh"
|
||||
|
||||
// Util function to convert a 1D input array index to an N-D positional index
|
||||
// Required since GPU iterates over all values in an ND array as a 1D array
|
||||
__inline__ __device__ void IdxToPos(size_t idx, size_t *pos, size_t cur_thread_idx, size_t *cum_shape,
|
||||
size_t shape_size) {
|
||||
size_t rem_val = idx;
|
||||
for (int i = 0; i < shape_size; i++) {
|
||||
pos[cur_thread_idx + i] = rem_val / cum_shape[i];
|
||||
rem_val = rem_val % cum_shape[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Util function to convert a N-D positonal index to a 1D index
|
||||
__inline__ __device__ size_t PosToIdx(size_t *pos, size_t cur_thread_idx, size_t *cum_shape, size_t shape_size) {
|
||||
size_t idx = 0;
|
||||
for (int i = 0; i < shape_size; i++) {
|
||||
idx = idx + (pos[cur_thread_idx + i] * cum_shape[i]);
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
// CumShape takes Shape: (2,2,5) => cumShape (10,5,1) which informs how many values
|
||||
// each dimension will represent. Required for converting 1d index to positional vector.
|
||||
// In this example 10 in dim 0 means, an increase of 1 in this dim leads to another 10 values
|
||||
// in the overall array
|
||||
__global__ void ComputeCumShape(const size_t *input_shape_ptr, size_t *input_shape_cum_ptr, size_t shape_size) {
|
||||
int cur_val = 1;
|
||||
for (int i = shape_size - 1; i >= 0; i--) {
|
||||
// iterate list in reverse and cummulatively build shape
|
||||
input_shape_cum_ptr[i] = cur_val;
|
||||
cur_val = cur_val * input_shape_ptr[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T, typename S>
|
||||
__global__ void ReverseSequence(const size_t size, const T *input, const S *seq_len, const int64_t batch_dim,
|
||||
const int64_t seq_dim, size_t *cur_pos_arr, const size_t *input_shape_ptr,
|
||||
size_t *input_shape_cum_ptr, size_t shape_size, T *output) {
|
||||
// calculate which thread this is out of total across all blocks for accessing respective cur_pos_arr memory
|
||||
size_t cur_thread_idx = (blockIdx.x * blockDim.x) + threadIdx.x;
|
||||
cur_thread_idx = cur_thread_idx * shape_size;
|
||||
size_t cur_slice = 0; // current slice as split by the batch_dim
|
||||
size_t cur_slice_seq_len = 0; // reverse seq length for this slice as provided by user
|
||||
size_t new_idx = 0; // calculate corresponding reverse element from input
|
||||
for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += blockDim.x * gridDim.x) {
|
||||
IdxToPos(idx, cur_pos_arr, cur_thread_idx, input_shape_cum_ptr, shape_size);
|
||||
cur_slice = cur_pos_arr[cur_thread_idx + batch_dim]; // all accesses to cur_pos_arr have to be adjusted per thread
|
||||
cur_slice_seq_len = seq_len[cur_slice];
|
||||
if (cur_slice_seq_len == 0) { // adjust length to 1 if 0 provided, same result in both cases
|
||||
cur_slice_seq_len = 1;
|
||||
}
|
||||
if (cur_pos_arr[cur_thread_idx + seq_dim] > (cur_slice_seq_len - 1)) { // check if within range
|
||||
// copy value directly and continue - outside of reversal range
|
||||
output[idx] = input[idx];
|
||||
continue;
|
||||
}
|
||||
// find corresponding reverse element in input
|
||||
cur_pos_arr[cur_thread_idx + seq_dim] =
|
||||
(cur_slice_seq_len - 1) - cur_pos_arr[cur_thread_idx + seq_dim]; // adjust position to target
|
||||
new_idx = PosToIdx(cur_pos_arr, cur_thread_idx, input_shape_cum_ptr, shape_size); // get the updated index
|
||||
output[idx] = input[new_idx];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void CalReverseSequence(const size_t size, const T *input, const S *seq_len, const int64_t batch_dim,
|
||||
const int64_t seq_dim, size_t *cur_pos_arr, const size_t *input_shape_ptr,
|
||||
size_t *input_shape_cum_ptr, size_t shape_size, T *output, cudaStream_t cuda_stream) {
|
||||
ComputeCumShape<<<1, 1, 0, cuda_stream>>>(input_shape_ptr, input_shape_cum_ptr, shape_size);
|
||||
ReverseSequence<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
|
||||
size, input, seq_len, batch_dim, seq_dim, cur_pos_arr, input_shape_ptr, input_shape_cum_ptr, shape_size, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CalReverseSequence<int8_t, int>(const size_t size, const int8_t *input, const int *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, int8_t *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<int8_t, int64_t>(const size_t size, const int8_t *input, const int64_t *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, int8_t *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<int16_t, int>(const size_t size, const int16_t *input, const int *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, int16_t *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<int16_t, int64_t>(const size_t size, const int16_t *input, const int64_t *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, int16_t *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<int, int>(const size_t size, const int *input, const int *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, int *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<int, int64_t>(const size_t size, const int *input, const int64_t *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, int *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<int64_t, int>(const size_t size, const int64_t *input, const int *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, int64_t *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<int64_t, int64_t>(const size_t size, const int64_t *input, const int64_t *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, int64_t *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<half, int>(const size_t size, const half *input, const int *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, half *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<half, int64_t>(const size_t size, const half *input, const int64_t *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, half *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<float, int>(const size_t size, const float *input, const int *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, float *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<float, int64_t>(const size_t size, const float *input, const int64_t *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, float *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<double, int>(const size_t size, const double *input, const int *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, double *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<double, int64_t>(const size_t size, const double *input, const int64_t *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, double *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<bool, int>(const size_t size, const bool *input, const int *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, bool *output, cudaStream_t cuda_stream);
|
||||
template void CalReverseSequence<bool, int64_t>(const size_t size, const bool *input, const int64_t *seq_len,
|
||||
const int64_t batch_dim, const int64_t seq_dim, size_t *cur_pos_arr,
|
||||
const size_t *input_shape_ptr, size_t *intput_shape_cum_ptr,
|
||||
size_t shape_size, bool *output, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REVERSE_SEQUENCE_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REVERSE_SEQUENCE_IMPL_H_
|
||||
#include <cuda_runtime.h>
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T, typename S>
|
||||
void CalReverseSequence(const size_t size, const T *input, const S *seq_len, const int64_t batch_dim,
|
||||
const int64_t seq_dim, size_t *cur_pos_arr, const size_t *input_shape_ptr,
|
||||
size_t *intput_shape_cum_ptr, size_t shape_size, T *output, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REVERSE_SEQUENCE_IMPL_H_
|
|
@ -4814,7 +4814,7 @@ class ReverseSequence(PrimitiveWithInfer):
|
|||
TypeError: If `seq_dim` or `batch_dim` is not an int.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, seq_dim, batch_dim):
|
||||
super(Net, self).__init__()
|
||||
self.reverse_sequence = P.ReverseSequence(
|
||||
seq_dim=seq_dim, batch_dim=batch_dim)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x, seq_lengths):
|
||||
return self.reverse_sequence(x, seq_lengths)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_int8():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.int8)
|
||||
seq_lengths = np.array([1, 2, 3]).astype(np.int32)
|
||||
seq_dim = 0
|
||||
batch_dim = 1
|
||||
net = Net(seq_dim, batch_dim)
|
||||
output = net(Tensor(x), Tensor(seq_lengths))
|
||||
expected = np.array([[1, 5, 9], [4, 2, 6], [7, 8, 3]]).astype(np.int8)
|
||||
assert np.array_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_int32():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.int32)
|
||||
seq_lengths = np.array([1, 2, 3]).astype(np.int64)
|
||||
seq_dim = 1
|
||||
batch_dim = 0
|
||||
net = Net(seq_dim, batch_dim)
|
||||
output = net(Tensor(x), Tensor(seq_lengths))
|
||||
expected = np.array([[1, 2, 3], [5, 4, 6], [9, 8, 7]]).astype(np.int32)
|
||||
assert np.array_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_float32():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x = np.array([[[[1, 2], [3, 4]],
|
||||
[[5, 6], [7, 8]],
|
||||
[[9, 10], [11, 12]],
|
||||
[[13, 14], [15, 16]]],
|
||||
[[[17, 18], [19, 20]],
|
||||
[[21, 22], [23, 24]],
|
||||
[[25, 26], [27, 28]],
|
||||
[[29, 30], [31, 21]]]]).astype(np.float32)
|
||||
seq_lengths = np.array([2, 2, 2, 2]).astype(np.int64)
|
||||
seq_dim = 0
|
||||
batch_dim = 1
|
||||
net = Net(seq_dim, batch_dim)
|
||||
output = net(Tensor(x), Tensor(seq_lengths))
|
||||
expected = np.array([[[[17., 18.], [19., 20.]],
|
||||
[[21., 22.], [23., 24.]],
|
||||
[[25., 26.], [27., 28.]],
|
||||
[[29., 30.], [31., 21.]]],
|
||||
[[[1., 2.], [3., 4.]],
|
||||
[[5., 6.], [7., 8.]],
|
||||
[[9., 10.], [11., 12.]],
|
||||
[[13., 14.], [15., 16.]]]]).astype(np.float32)
|
||||
assert np.array_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_float64_0_dim():
|
||||
"""
|
||||
Test added to test for 0 seq len edge case
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x = np.array([[[[1, 2], [3, 4]],
|
||||
[[5, 6], [7, 8]],
|
||||
[[9, 10], [11, 12]],
|
||||
[[13, 14], [15, 16]]],
|
||||
[[[17, 18], [19, 20]],
|
||||
[[21, 22], [23, 24]],
|
||||
[[25, 26], [27, 28]],
|
||||
[[29, 30], [31, 21]]]]).astype(np.float32)
|
||||
seq_lengths = np.array([2, 2, 0, 0]).astype(np.int64)
|
||||
seq_dim = 2
|
||||
batch_dim = 1
|
||||
net = Net(seq_dim, batch_dim)
|
||||
output = net(Tensor(x), Tensor(seq_lengths))
|
||||
expected = np.array([[[[3., 4.], [1., 2.]],
|
||||
[[7., 8.], [5., 6.]],
|
||||
[[9., 10.], [11., 12.]],
|
||||
[[13., 14.], [15., 16.]]],
|
||||
[[[19., 20.], [17., 18.]],
|
||||
[[23., 24.], [21., 22.]],
|
||||
[[25., 26.], [27., 28.]],
|
||||
[[29., 30.], [31., 21.]]]]).astype(np.float32)
|
||||
assert np.array_equal(output.asnumpy(), expected)
|
Loading…
Reference in New Issue