forked from mindspore-Ecosystem/mindspore
!8204 add supports to op pack on gpu
From: @yuan_shen_zhou Reviewed-by: @liangchenghui Signed-off-by:
This commit is contained in:
commit
370e7ab95f
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/pack_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Pack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
PackGpuFwdKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Pack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
PackGpuFwdKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(Pack,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
PackGpuFwdKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(Pack,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
PackGpuFwdKernel, int16_t)
|
||||
MS_REG_GPU_KERNEL_ONE(Pack,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
PackGpuFwdKernel, uchar)
|
||||
MS_REG_GPU_KERNEL_ONE(Pack,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
PackGpuFwdKernel, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,113 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_PACK_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_PACK_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/pack.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class PackGpuFwdKernel : public GpuKernel {
|
||||
public:
|
||||
PackGpuFwdKernel() : axis_(0), input_num_(1), output_size_(0), dims_behind_axis_(1), inputs_host_(nullptr) {}
|
||||
~PackGpuFwdKernel() override = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
T **inputs_array = GetDeviceAddress<T *>(workspace, 0);
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
inputs_host_[i] = GetDeviceAddress<T>(inputs, i);
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(inputs_array, inputs_host_.get(), sizeof(T *) * input_num_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Pack opt cudaMemcpyAsync inputs failed");
|
||||
PackKernel(SizeToInt(output_size_), input_num_, dims_behind_axis_, inputs_array, output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
if (!CheckParam(kernel_node)) {
|
||||
return false;
|
||||
}
|
||||
axis_ = static_cast<int32_t>(GetAttr<int64_t>(kernel_node, "axis"));
|
||||
if (axis_ < 0) {
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
axis_ += SizeToInt(input_shape.size());
|
||||
}
|
||||
auto origin_data_format = AnfAlgo::GetOriginDataFormat(kernel_node);
|
||||
auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0);
|
||||
axis_ = AxisTransform(origin_data_format, input_format, axis_);
|
||||
|
||||
input_num_ = SizeToInt(AnfAlgo::GetInputTensorNum(kernel_node));
|
||||
inputs_host_ = std::make_unique<T *[]>(input_num_);
|
||||
for (int i = 0; i < input_num_; i++) {
|
||||
size_t input_size = 1;
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
|
||||
for (size_t j = 0; j < input_shape.size(); j++) {
|
||||
input_size *= input_shape[j];
|
||||
}
|
||||
input_size_list_.push_back(input_size * sizeof(T));
|
||||
}
|
||||
workspace_size_list_.push_back(sizeof(T *) * input_num_);
|
||||
|
||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
output_size_ = 1;
|
||||
for (int i = 0; i < SizeToInt(output_shape.size()); i++) {
|
||||
output_size_ *= output_shape[i];
|
||||
if (i > axis_ + 1) {
|
||||
dims_behind_axis_ *= output_shape[i];
|
||||
}
|
||||
}
|
||||
output_size_list_.push_back(output_size_ * sizeof(T));
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {}
|
||||
|
||||
private:
|
||||
bool CheckParam(const CNodePtr &kernel_node) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but PackGpuFwdKernel needs 1 output.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
int axis_;
|
||||
int input_num_;
|
||||
size_t output_size_;
|
||||
int dims_behind_axis_;
|
||||
std::unique_ptr<T *[]> inputs_host_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_PACK_GPU_KERNEL_H
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/pack.cuh"
|
||||
template <typename T>
|
||||
__global__ void Pack(const int size, const int input_num, const int dims_behind_axis, T** inputs, T* output) {
|
||||
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
int cycle = pos / (input_num * dims_behind_axis);
|
||||
int cur_input_index = pos % (input_num * dims_behind_axis) / dims_behind_axis;
|
||||
int local_index = pos % (input_num * dims_behind_axis) % dims_behind_axis;
|
||||
output[pos] = inputs[cur_input_index][cycle * dims_behind_axis + local_index];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PackKernel(const int size, const int input_num,
|
||||
const int dims_behind_axis, T** inputs, T* output,
|
||||
cudaStream_t cuda_stream) {
|
||||
Pack<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num, dims_behind_axis, inputs, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template void PackKernel(const int size, const int input_num,
|
||||
const int dims_behind_axis, float** inputs, float* output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void PackKernel(const int size, const int input_num,
|
||||
const int dims_behind_axis, int** inputs, int* output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void PackKernel(const int size, const int input_num,
|
||||
const int dims_behind_axis, half** inputs, half* output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void PackKernel(const int size, const int input_num,
|
||||
const int dims_behind_axis, short** inputs, short* output, // NOLINT
|
||||
cudaStream_t cuda_stream);
|
||||
template void PackKernel(const int size, const int input_num,
|
||||
const int dims_behind_axis, unsigned char** inputs, unsigned char* output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void PackKernel(const int size, const int input_num,
|
||||
const int dims_behind_axis, bool** inputs, bool* output,
|
||||
cudaStream_t cuda_stream);
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PACK_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PACK_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void PackKernel(const int size,
|
||||
const int input_num,
|
||||
const int dims_behind_axis,
|
||||
T** inputs,
|
||||
T* output,
|
||||
cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PACK_H_
|
|
@ -0,0 +1,162 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations.array_ops as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
|
||||
class PackNet(nn.Cell):
|
||||
def __init__(self, nptype):
|
||||
super(PackNet, self).__init__()
|
||||
|
||||
self.pack = P.Pack(axis=2)
|
||||
self.data_np = np.array([0] * 16).astype(nptype)
|
||||
self.data_np = np.reshape(self.data_np, (2, 2, 2, 2))
|
||||
self.x1 = Parameter(initializer(
|
||||
Tensor(self.data_np), [2, 2, 2, 2]), name='x1')
|
||||
self.x2 = Parameter(initializer(
|
||||
Tensor(np.arange(16).reshape(2, 2, 2, 2).astype(nptype)), [2, 2, 2, 2]), name='x2')
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.pack((self.x1, self.x2))
|
||||
|
||||
|
||||
def pack(nptype):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
pack_ = PackNet(nptype)
|
||||
output = pack_()
|
||||
expect = np.array([[[[[0, 0],
|
||||
[0, 1]],
|
||||
[[0, 0],
|
||||
[2, 3]]],
|
||||
[[[0, 0],
|
||||
[4, 5]],
|
||||
[[0, 0],
|
||||
[6, 7]]]],
|
||||
[[[[0, 0],
|
||||
[8, 9]],
|
||||
[[0, 0],
|
||||
[10, 11]]],
|
||||
[[[0, 0],
|
||||
[12, 13]],
|
||||
[[0, 0],
|
||||
[14, 15]]]]]).astype(nptype)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
def pack_pynative(nptype):
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
x1 = np.array([0] * 16).astype(nptype)
|
||||
x1 = np.reshape(x1, (2, 2, 2, 2))
|
||||
x1 = Tensor(x1)
|
||||
x2 = Tensor(np.arange(16).reshape(2, 2, 2, 2).astype(nptype))
|
||||
expect = np.array([[[[[0, 0],
|
||||
[0, 1]],
|
||||
[[0, 0],
|
||||
[2, 3]]],
|
||||
[[[0, 0],
|
||||
[4, 5]],
|
||||
[[0, 0],
|
||||
[6, 7]]]],
|
||||
[[[[0, 0],
|
||||
[8, 9]],
|
||||
[[0, 0],
|
||||
[10, 11]]],
|
||||
[[[0, 0],
|
||||
[12, 13]],
|
||||
[[0, 0],
|
||||
[14, 15]]]]]).astype(nptype)
|
||||
output = P.Pack(axis=2)((x1, x2))
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pack_graph_float32():
|
||||
pack(np.float32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pack_graph_float16():
|
||||
pack(np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pack_graph_int32():
|
||||
pack(np.int32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pack_graph_int16():
|
||||
pack(np.int16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pack_graph_uint8():
|
||||
pack(np.uint8)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pack_graph_bool():
|
||||
pack(np.bool)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pack_pynative_float32():
|
||||
pack_pynative(np.float32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pack_pynative_float16():
|
||||
pack_pynative(np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pack_pynative_int32():
|
||||
pack_pynative(np.int32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pack_pynative_int16():
|
||||
pack_pynative(np.int16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pack_pynative_uint8():
|
||||
pack_pynative(np.uint8)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pack_pynative_bool():
|
||||
pack_pynative(np.bool)
|
Loading…
Reference in New Issue