forked from mindspore-Ecosystem/mindspore
!25356 Add gpu kernel dynamicstitch
Merge pull request !25356 from VectorSL/dynmic_stitch
This commit is contained in:
commit
5f9ae3ec34
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dynamic_stitch_impl.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
__global__ void StitchKernel(const int *index_addr, const unsigned char *data_addr, unsigned char *output_addr,
|
||||
const size_t index_num, const size_t data_size, int *max_index_dev) {
|
||||
for (size_t i = 0; i < index_num; i++) {
|
||||
int index = index_addr[i];
|
||||
if (max_index_dev[0] < index) {
|
||||
max_index_dev[0] = index;
|
||||
}
|
||||
for (size_t j = blockIdx.x * blockDim.x + threadIdx.x; j < data_size; j += gridDim.x * blockDim.x) {
|
||||
output_addr[index * data_size + j] = data_addr[i * data_size + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CallStitch(const int *index_addr, const unsigned char *data_addr, unsigned char *output_addr,
|
||||
const size_t index_num, const size_t data_size, int *max_index_dev, cudaStream_t cuda_stream) {
|
||||
StitchKernel<<<GET_BLOCKS(data_size), GET_THREADS, 0, cuda_stream>>>(index_addr, data_addr, output_addr, index_num,
|
||||
data_size, max_index_dev);
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_STITCH_CUH_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_STITCH_CUH_
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
void CallStitch(const int *index_addr, const unsigned char *data_addr, unsigned char *output_addr,
|
||||
const size_t index_num, const size_t data_size, int *max_index_dev, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_STITCH_CUH_
|
|
@ -0,0 +1,110 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/gpu/other/dynamic_stitch_gpu_kernel.h"
|
||||
#include <functional>
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/dynamic_stitch_impl.cuh"
|
||||
#include "runtime/device/gpu/gpu_common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
DynamicStitchKernel::DynamicStitchKernel()
|
||||
: n_(0), real_ele_num_(0), max_index_(0), one_data_ele_num_(0), data_type_size_(0) {
|
||||
ResetResource();
|
||||
}
|
||||
|
||||
DynamicStitchKernel::~DynamicStitchKernel() {}
|
||||
|
||||
const std::vector<size_t> &DynamicStitchKernel::GetInputSizeList() const { return input_size_list_; }
|
||||
|
||||
const std::vector<size_t> &DynamicStitchKernel::GetOutputSizeList() const { return output_size_list_; }
|
||||
|
||||
const std::vector<size_t> &DynamicStitchKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||
|
||||
bool DynamicStitchKernel::Init(const CNodePtr &kernel_node) {
|
||||
kernel_node_ = kernel_node;
|
||||
// Inputs: (indexlist, datalist)
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
n_ = input_num / kDivNum2;
|
||||
|
||||
auto output_shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node, n_);
|
||||
// Index type is restricted to int32 by kernel prim.
|
||||
size_t index_type_size = sizeof(int);
|
||||
data_type_size_ = GetDtypeNbyte(TypeIdToString(data_type, false));
|
||||
auto first_data_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, n_);
|
||||
one_data_ele_num_ = first_data_shape[first_data_shape.size() - 1];
|
||||
for (size_t i = 0; i < n_; i++) {
|
||||
auto data_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, n_ + i);
|
||||
size_t data_size = std::accumulate(data_shape.begin(), data_shape.end(), 1, std::multiplies<size_t>());
|
||||
// Data size
|
||||
input_size_list_.push_back(data_size * data_type_size_);
|
||||
// Index size
|
||||
input_size_list_.insert(input_size_list_.begin() + i, data_size / one_data_ele_num_ * index_type_size);
|
||||
}
|
||||
size_t output_size =
|
||||
std::accumulate(output_shape.begin(), output_shape.end(), data_type_size_, std::multiplies<size_t>());
|
||||
|
||||
// For max_index
|
||||
workspace_size_list_.push_back(index_type_size);
|
||||
// One output
|
||||
output_size_list_.push_back(output_size);
|
||||
return true;
|
||||
}
|
||||
|
||||
void DynamicStitchKernel::PostExecute() {
|
||||
auto output_shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node_.lock(), 0);
|
||||
output_shape[0] = max_index_ + 1;
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node_.lock(), n_);
|
||||
AnfAlgo::SetOutputInferTypeAndShape({data_type}, {output_shape}, kernel_node_.lock().get());
|
||||
MS_LOG(DEBUG) << "Run PostExecute for dynamicstitch, real output shape is " << output_shape;
|
||||
}
|
||||
|
||||
void DynamicStitchKernel::ResetResource() noexcept {
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
void DynamicStitchKernel::InitSizeLists() { return; }
|
||||
|
||||
bool DynamicStitchKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream) {
|
||||
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
|
||||
auto max_index_dev = GetDeviceAddress<int>(workspace, 0);
|
||||
auto output_addr = GetDeviceAddress<unsigned char>(outputs, 0);
|
||||
// Init output and max_index with 0
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMemsetAsync(output_addr, 0, output_size_list_[0], cuda_stream),
|
||||
"Init output with cudamemset failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMemsetAsync(max_index_dev, 0, sizeof(int), cuda_stream),
|
||||
"Init max_index_dev with cudamemset failed");
|
||||
|
||||
for (size_t i = 0; i < n_; i++) {
|
||||
auto index_addr = GetDeviceAddress<int>(inputs, i);
|
||||
auto data_addr = GetDeviceAddress<unsigned char>(inputs, n_ + i);
|
||||
size_t index_num = input_size_list_[i] / sizeof(int);
|
||||
CallStitch(index_addr, data_addr, output_addr, index_num, one_data_ele_num_ * data_type_size_, max_index_dev,
|
||||
cuda_stream);
|
||||
}
|
||||
int temp = 0;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(&temp, max_index_dev, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream),
|
||||
"Copy max_index failed")
|
||||
max_index_ = IntToSize(temp);
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_STITCH_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_STITCH_GPU_KERNEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class DynamicStitchKernel : public GpuKernel {
|
||||
public:
|
||||
DynamicStitchKernel();
|
||||
~DynamicStitchKernel();
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override;
|
||||
const std::vector<size_t> &GetOutputSizeList() const override;
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||
bool Init(const CNodePtr &kernel_node) override;
|
||||
void ResetResource() noexcept override;
|
||||
void PostExecute() override;
|
||||
constexpr static size_t kDivNum2 = 2;
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override;
|
||||
|
||||
private:
|
||||
size_t n_;
|
||||
size_t real_ele_num_;
|
||||
size_t max_index_;
|
||||
size_t one_data_ele_num_;
|
||||
size_t data_type_size_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
|
||||
MS_REG_GPU_KERNEL(DynamicStitch, DynamicStitchKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_STITCH_GPU_KERNEL_H_
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import _inner_ops as ops
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.stitch = ops.DynamicStitch()
|
||||
|
||||
def construct(self, indices, data):
|
||||
return self.stitch(indices, data)
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_int32():
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for dynamicstitch.
|
||||
Expectation: the result match expected array.
|
||||
"""
|
||||
x1 = Tensor([6], mindspore.int32)
|
||||
x2 = Tensor(np.array([4, 1]), mindspore.int32)
|
||||
x3 = Tensor(np.array([[5, 2], [0, 3]]), mindspore.int32)
|
||||
y1 = Tensor(np.array([[61, 62]]), mindspore.int32)
|
||||
y2 = Tensor(np.array([[41, 42], [11, 12]]), mindspore.int32)
|
||||
y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mindspore.int32)
|
||||
expected = np.array([[1, 2], [11, 12], [21, 22],
|
||||
[31, 32], [41, 42], [51, 52], [61, 62]]).astype(np.int32)
|
||||
|
||||
indices = [x1, x2, x3]
|
||||
data = [y1, y2, y3]
|
||||
net = Net()
|
||||
output = net(indices, data)
|
||||
assert np.array_equal(output.asnumpy(), expected)
|
Loading…
Reference in New Issue