add gpu oneslike kernel

This commit is contained in:
qujianwei 2020-07-17 11:46:01 +08:00
parent bbcefa731d
commit fb2ac74d9a
5 changed files with 256 additions and 0 deletions

View File

@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
OnesLikeGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
OnesLikeGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
OnesLikeGpuKernel, int)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ONESLIKE_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_ONESLIKE_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class OnesLikeGpuKernel : public GpuKernel {
public:
OnesLikeGpuKernel() : input_size_(0), output_size_(0) {}
~OnesLikeGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
int size = SizeToInt(input_size_ / sizeof(T));
CalOnesLike(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but oneslike needs 1 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but oneslike needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
size_t shape_size = input_shape.size();
input_size_ = 1;
for (size_t i = 0; i < shape_size; i++) {
input_size_ *= input_shape[i];
}
input_size_ *= sizeof(T);
output_size_ = input_size_;
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
return;
}
private:
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
size_t input_size_;
size_t output_size_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_ONESLIKE_H_

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime.h>
#include "oneslike_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__global__ void OnesLike(const int size, const T* input, T* output) {
int one = 1;
T val = static_cast<T>(one);
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
output[pos] = val;
}
return;
}
template <typename T>
void CalOnesLike(const int size, const T* input, T* output, cudaStream_t cuda_stream) {
OnesLike<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
return;
}
template void CalOnesLike<float>(const int size, const float* input, float* output, cudaStream_t cuda_stream);
template void CalOnesLike<half>(const int size, const half* input, half* output, cudaStream_t cuda_stream);
template void CalOnesLike<int>(const int size, const int* input, int* output, cudaStream_t cuda_stream);

View File

@ -0,0 +1,23 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_
template <typename T>
void CalOnesLike(const int size, const T* input, T* output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_

View File

@ -0,0 +1,85 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
class NetOnesLike(nn.Cell):
def __init__(self):
super(NetOnesLike, self).__init__()
self.ones_like = P.OnesLike()
def construct(self, x):
return self.ones_like(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_OnesLike():
x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
x1_np = np.random.uniform(-2, 2, 1).astype(np.float16)
x2_np = np.zeros([3, 3, 3], dtype=np.int32)
x0 = Tensor(x0_np)
x1 = Tensor(x1_np)
x2 = Tensor(x2_np)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
ones_like = NetOnesLike()
output0 = ones_like(x0)
expect0 = np.ones_like(x0_np)
diff0 = output0.asnumpy() - expect0
error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0)
assert output0.shape == expect0.shape
output1 = ones_like(x1)
expect1 = np.ones_like(x1_np)
diff1 = output1.asnumpy() - expect1
error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1)
assert output1.shape == expect1.shape
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ones_like = NetOnesLike()
output0 = ones_like(x0)
expect0 = np.ones_like(x0_np)
diff0 = output0.asnumpy() - expect0
error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0)
assert output0.shape == expect0.shape
output1 = ones_like(x1)
expect1 = np.ones_like(x1_np)
diff1 = output1.asnumpy() - expect1
error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1)
assert output1.shape == expect1.shape
output2 = ones_like(x2)
expect2 = np.ones_like(x2_np)
diff2 = output2.asnumpy() - expect2
error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2)
assert output2.shape == expect2.shape