forked from mindspore-Ecosystem/mindspore
!3083 add gpu split and restructure gpu concat
Merge pull request !3083 from zhaoting/master
This commit is contained in:
commit
ad09bf3e87
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
|
||||
|
@ -27,40 +28,35 @@ namespace kernel {
|
|||
template <typename T>
|
||||
class ConcatV2GpuFwdKernel : public GpuKernel {
|
||||
public:
|
||||
ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {}
|
||||
ConcatV2GpuFwdKernel()
|
||||
: axis_(0),
|
||||
input_num_(1),
|
||||
output_size_(0),
|
||||
all_size_before_axis_(1),
|
||||
all_size_axis_(1),
|
||||
inputs_host_(nullptr),
|
||||
len_axis_(nullptr) {}
|
||||
~ConcatV2GpuFwdKernel() override = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (inputs.size() == 2) {
|
||||
T *input_0 = GetDeviceAddress<T>(inputs, 0);
|
||||
T *input_1 = GetDeviceAddress<T>(inputs, 1);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
|
||||
if (inputs.size() == 3) {
|
||||
T *input_0 = GetDeviceAddress<T>(inputs, 0);
|
||||
T *input_1 = GetDeviceAddress<T>(inputs, 1);
|
||||
T *input_2 = GetDeviceAddress<T>(inputs, 2);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
|
||||
if (inputs.size() == 4) {
|
||||
T *input_0 = GetDeviceAddress<T>(inputs, 0);
|
||||
T *input_1 = GetDeviceAddress<T>(inputs, 1);
|
||||
T *input_2 = GetDeviceAddress<T>(inputs, 2);
|
||||
T *input_3 = GetDeviceAddress<T>(inputs, 3);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
T **inputs_device = GetDeviceAddress<T *>(workspace, 0);
|
||||
int *len_axis_device = GetDeviceAddress<int>(workspace, 1);
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
inputs_host_[i] = GetDeviceAddress<T>(inputs, i);
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(inputs_device, inputs_host_.get(), sizeof(T *) * input_num_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"ConcatV2 opt cudaMemcpyAsync inputs failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(len_axis_device, len_axis_.get(), sizeof(int) * input_num_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"ConcatV2 opt cudaMemcpyAsync length on axis failed");
|
||||
ConcatKernel(output_size_, input_num_, all_size_before_axis_, all_size_axis_, len_axis_device, inputs_device,
|
||||
output, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
|
@ -74,25 +70,34 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
|
|||
axis_ += SizeToInt(input_shape.size());
|
||||
}
|
||||
|
||||
auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
auto input_size = sizeof(T);
|
||||
input_num_ = SizeToInt(AnfAlgo::GetInputTensorNum(kernel_node));
|
||||
inputs_host_ = std::make_unique<T *[]>(input_num_);
|
||||
len_axis_ = std::make_unique<int[]>(input_num_);
|
||||
for (int i = 0; i < input_num_; i++) {
|
||||
int input_size = 1;
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
|
||||
for (size_t j = 0; j < input_shape.size(); j++) {
|
||||
input_size *= SizeToInt(input_shape[j]);
|
||||
if (j >= IntToSize(axis_)) {
|
||||
w_[i] *= SizeToInt(input_shape[j]);
|
||||
}
|
||||
input_size_list_.push_back(input_size);
|
||||
}
|
||||
input_size_list_.push_back(IntToSize(input_size * sizeof(T)));
|
||||
len_axis_[i] = SizeToInt(input_shape[axis_]);
|
||||
}
|
||||
workspace_size_list_.push_back(sizeof(T *) * input_num_);
|
||||
workspace_size_list_.push_back(sizeof(int) * input_num_);
|
||||
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
output_size_ = sizeof(T);
|
||||
for (size_t i = 0; i < output_shape.size(); i++) {
|
||||
output_size_ = 1;
|
||||
for (int i = 0; i < SizeToInt(output_shape.size()); i++) {
|
||||
output_size_ *= output_shape[i];
|
||||
if (i > axis_) {
|
||||
all_size_before_axis_ *= output_shape[i];
|
||||
all_size_axis_ *= output_shape[i];
|
||||
}
|
||||
if (i == axis_) {
|
||||
all_size_before_axis_ *= output_shape[i];
|
||||
}
|
||||
}
|
||||
output_size_list_.push_back(output_size_);
|
||||
output_size_list_.push_back(IntToSize(output_size_ * sizeof(T)));
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
|
@ -103,11 +108,6 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
|
|||
|
||||
private:
|
||||
bool CheckParam(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num < 2 || input_num > 4) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but ConcatV2GpuFwdKernel needs 1 output.";
|
||||
|
@ -115,9 +115,13 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
int w_[4] = {1, 1, 1, 1};
|
||||
int axis_;
|
||||
size_t output_size_;
|
||||
int input_num_;
|
||||
int output_size_;
|
||||
int all_size_before_axis_;
|
||||
int all_size_axis_;
|
||||
std::unique_ptr<T *[]> inputs_host_;
|
||||
std::unique_ptr<int[]> len_axis_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SplitGpuFwdKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Split,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SplitGpuFwdKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SplitGpuFwdKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,153 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class SplitGpuFwdKernel : public GpuKernel {
|
||||
public:
|
||||
SplitGpuFwdKernel()
|
||||
: axis_(0),
|
||||
output_num_(1),
|
||||
input_size_(1),
|
||||
axis_step_(1),
|
||||
all_size_before_axis_(1),
|
||||
all_size_axis_(1),
|
||||
outputs_host_(nullptr) {}
|
||||
~SplitGpuFwdKernel() override = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
T **outputs_device = GetDeviceAddress<T *>(workspace, 0);
|
||||
for (size_t i = 0; i < outputs.size(); i++) {
|
||||
outputs_host_[i] = GetDeviceAddress<T>(outputs, i);
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(outputs_device, outputs_host_.get(), sizeof(T *) * output_num_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Split opt cudaMemcpyAsync outputs failed");
|
||||
SplitKernel(input_size_, axis_step_, all_size_before_axis_, all_size_axis_, input, outputs_device,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
axis_ = GetAttr<int>(kernel_node, "axis");
|
||||
if (axis_ < 0) {
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
axis_ += SizeToInt(input_shape.size());
|
||||
}
|
||||
output_num_ = GetAttr<int>(kernel_node, "output_num");
|
||||
|
||||
if (!CheckParam(kernel_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
input_size_ = 1;
|
||||
all_size_before_axis_ = 1;
|
||||
all_size_axis_ = 1;
|
||||
|
||||
for (int i = 0; i < SizeToInt(input_shape.size()); i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
if (i > axis_) {
|
||||
all_size_before_axis_ *= input_shape[i];
|
||||
all_size_axis_ *= input_shape[i];
|
||||
}
|
||||
if (i == axis_) {
|
||||
all_size_before_axis_ *= input_shape[i];
|
||||
}
|
||||
}
|
||||
input_size_list_.push_back(IntToSize(input_size_ * sizeof(T)));
|
||||
axis_step_ = input_shape[axis_] / output_num_;
|
||||
|
||||
for (int i = 0; i < output_num_; i++) {
|
||||
size_t output_size = 1;
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
|
||||
for (size_t j = 0; j < output_shape.size(); j++) {
|
||||
output_size *= output_shape[j];
|
||||
}
|
||||
output_size_list_.push_back(output_size * sizeof(T));
|
||||
}
|
||||
workspace_size_list_.push_back(sizeof(T *) * output_num_);
|
||||
InitSizeLists();
|
||||
outputs_host_ = std::make_unique<T *[]>(output_num_);
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {}
|
||||
|
||||
private:
|
||||
bool CheckParam(const CNodePtr &kernel_node) {
|
||||
auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
int dims = SizeToInt(input_shape.size());
|
||||
int output_num = SizeToInt(AnfAlgo::GetOutputTensorNum(kernel_node));
|
||||
|
||||
if (input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but Split needs 1 input.";
|
||||
return false;
|
||||
}
|
||||
if (dims == 0) {
|
||||
MS_LOG(ERROR) << "Input dims is " << dims << ", scalar is not supported.";
|
||||
return false;
|
||||
}
|
||||
if (axis_ < -dims || axis_ >= dims) {
|
||||
MS_LOG(ERROR) << "Attr axis " << axis_ << " must be in " << -dims << "~" << dims;
|
||||
return false;
|
||||
}
|
||||
if (output_num_ > SizeToInt(input_shape[axis_])) {
|
||||
MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must less than" << input_shape[axis_];
|
||||
return false;
|
||||
}
|
||||
if (input_shape[axis_] % output_num_ != 0) {
|
||||
MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must be divided by" << input_shape[axis_];
|
||||
return false;
|
||||
}
|
||||
if (output_num_ != output_num) {
|
||||
MS_LOG(ERROR) << "Output num is " << output_num << ", but need " << output_num_;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
int axis_;
|
||||
int output_num_;
|
||||
int input_size_;
|
||||
int axis_step_;
|
||||
int all_size_before_axis_;
|
||||
int all_size_axis_;
|
||||
std::unique_ptr<T *[]> outputs_host_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
|
|
@ -19,90 +19,51 @@
|
|||
#include <cuda_runtime.h>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
|
||||
template <typename T>
|
||||
__global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
int n = pos / (w1 + w2);
|
||||
int m = pos % (w1 + w2);
|
||||
output[pos] = m >= w1 ? input_2[n * w2 + m - w1] : input_1[n * w1 + m];
|
||||
__global__ void Concat(const int size, const int input_num,
|
||||
const int all_size_before_axis, const int all_size_axis,
|
||||
int* len_axis, T** inputs, T* output) {
|
||||
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
int num = pos % all_size_before_axis / all_size_axis;
|
||||
int block = -1;
|
||||
int axis_inc = 0;
|
||||
int block_len = 0;
|
||||
for (int i = 0; i < input_num; i++) {
|
||||
if (axis_inc <= num) {
|
||||
block++;
|
||||
axis_inc += len_axis[i];
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
block_len = len_axis[block];
|
||||
axis_inc -= len_axis[block];
|
||||
int block_pos = pos / all_size_before_axis * block_len * all_size_axis +
|
||||
(num - axis_inc) * all_size_axis + pos % all_size_axis;;
|
||||
output[pos] = inputs[block][block_pos];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void Concat(const size_t size, const int w1, const int w2, const int w3,
|
||||
const T* input_1, const T* input_2, const T* input_3, T* output) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
int n = pos / (w1 + w2 + w3);
|
||||
int m = pos % (w1 + w2 + w3);
|
||||
output[pos] = m < w1 ? input_1[n * w1 + m] :
|
||||
m < w1 + w2 ? input_2[n * w2 + m - w1] :
|
||||
input_3[n * w3 + m - w1 - w2];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4,
|
||||
const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
int n = pos / (w1 + w2 + w3 + w4);
|
||||
int m = pos % (w1 + w2 + w3 + w4);
|
||||
output[pos] = m < w1 ? input_1[n * w1 + m] :
|
||||
m < w1 + w2 ? input_2[n * w2 + m - w1]:
|
||||
m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]:
|
||||
input_4[n * w4 + m - w1 - w2 - w3];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
|
||||
cudaStream_t cuda_stream) {
|
||||
Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
|
||||
const T* input_1, const T* input_2, const T* input_3, T* output,
|
||||
void ConcatKernel(const int size, const int input_num,
|
||||
const int all_size_before_axis, const int all_size_axis,
|
||||
int* len_axis, T** inputs, T* output,
|
||||
cudaStream_t cuda_stream) {
|
||||
Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, input_1, input_2, input_3, output);
|
||||
Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num,
|
||||
all_size_before_axis, all_size_axis,
|
||||
len_axis, inputs, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
|
||||
const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output,
|
||||
cudaStream_t cuda_stream) {
|
||||
Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, w4, input_1,
|
||||
input_2, input_3, input_4, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2,
|
||||
float* output, cudaStream_t cuda_stream);
|
||||
template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2,
|
||||
int* output, cudaStream_t cuda_stream);
|
||||
template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2,
|
||||
half* output, cudaStream_t cuda_stream);
|
||||
|
||||
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
|
||||
const float* input_1, const float* input_2, const float* input_3,
|
||||
float* output, cudaStream_t cuda_stream);
|
||||
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
|
||||
const int* input_1, const int* input_2, const int* input_3,
|
||||
int* output, cudaStream_t cuda_stream);
|
||||
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
|
||||
const half* input_1, const half* input_2, const half* input_3,
|
||||
half* output, cudaStream_t cuda_stream);
|
||||
|
||||
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
|
||||
const float* input_1, const float* input_2, const float* input_3, const float* input_4,
|
||||
float* output, cudaStream_t cuda_stream);
|
||||
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
|
||||
const int* input_1, const int* input_2, const int* input_3, const int* input_4,
|
||||
int* output, cudaStream_t cuda_stream);
|
||||
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
|
||||
const half* input_1, const half* input_2, const half* input_3, const half* input_4,
|
||||
half* output, cudaStream_t cuda_stream);
|
||||
|
||||
template void ConcatKernel(const int size, const int input_num,
|
||||
const int all_size_before_axis, const int all_size_axis,
|
||||
int* len_axis, float** inputs, float* output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void ConcatKernel(const int size, const int input_num,
|
||||
const int all_size_before_axis, const int all_size_axis,
|
||||
int* len_axis, int** inputs, int* output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void ConcatKernel(const int size, const int input_num,
|
||||
const int all_size_before_axis, const int all_size_axis,
|
||||
int* len_axis, half** inputs, half* output,
|
||||
cudaStream_t cuda_stream);
|
||||
|
|
|
@ -19,13 +19,8 @@
|
|||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
|
||||
cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
|
||||
const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
|
||||
const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output,
|
||||
void ConcatKernel(const int size, const int input_num,
|
||||
const int all_size_before_axis, const int all_size_axis,
|
||||
int* len_axis, T** inputs, T* output,
|
||||
cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh"
|
||||
template <typename T>
|
||||
__global__ void Split(const int size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const T* input, T** outputs) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
int num = pos % all_size_before_axis / all_size_axis;
|
||||
int block = num / axis_step;
|
||||
int block_pos = pos / all_size_before_axis * axis_step * all_size_axis +
|
||||
num % axis_step * all_size_axis + pos % all_size_axis;
|
||||
outputs[block][block_pos] = input[pos];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream) {
|
||||
Split<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, axis_step, all_size_before_axis,
|
||||
all_size_axis, input, outputs);
|
||||
return;
|
||||
}
|
||||
|
||||
template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const float* input, float** outputs,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const int* input, int** outputs,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const half* input, half** outputs,
|
||||
cudaStream_t cuda_stream);
|
|
@ -0,0 +1,24 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, axis=0, out_nums=1):
|
||||
super(Net, self).__init__()
|
||||
self.split = P.Split(axis, out_nums)
|
||||
|
||||
def construct(self, x):
|
||||
return self.split(x)
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_split():
|
||||
x = np.array([[[1, -1, 1], [2, -2, 2]],
|
||||
[[3, -3, 3], [4, -4, 4]],
|
||||
[[5, -5, 5], [6, -6, 6]]]).astype(np.float32)
|
||||
|
||||
split_op = Net(0, 3)
|
||||
outputs = split_op(Tensor(x))
|
||||
for i, out in enumerate(outputs):
|
||||
assert (out.asnumpy() == x[i]).all()
|
||||
|
||||
|
||||
def test_split_4d():
|
||||
x_np = np.random.randn(2, 6, 4, 4).astype(np.float32)
|
||||
y = np.split(x_np, 3, axis=1)
|
||||
|
||||
split_op = Net(1, 3)
|
||||
outputs = split_op(Tensor(x_np))
|
||||
|
||||
for i, out in enumerate(outputs):
|
||||
assert (out.asnumpy() == y[i]).all()
|
Loading…
Reference in New Issue