forked from mindspore-Ecosystem/mindspore
!16615 InTopK gpu kernel
From: @peilin-wang Reviewed-by: @robingrosman,@tom__chen Signed-off-by: @robingrosman
This commit is contained in:
commit
83fb4233ab
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/in_top_k_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
InTopK, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
InTopKGpuKernel, half)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
InTopK, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
InTopKGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,137 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IN_TOP_K_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IN_TOP_K_GPU_KERNEL_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class InTopKGpuKernel : public GpuKernel {
|
||||
public:
|
||||
InTopKGpuKernel() { ResetResource(); }
|
||||
~InTopKGpuKernel() = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *predictions_device = GetDeviceAddress<T>(inputs, 0);
|
||||
int32_t *targets_device = GetDeviceAddress<int32_t>(inputs, 1);
|
||||
|
||||
bool *output_device = GetDeviceAddress<bool>(outputs, 0);
|
||||
|
||||
T *top_k_output_device = GetDeviceAddress<T>(workspace, 0);
|
||||
int32_t *top_k_indices_device = GetDeviceAddress<int32_t>(workspace, 1);
|
||||
|
||||
// topk sorts the input along the last dimension
|
||||
FastTopK(outer_size_, inner_size_, predictions_device, static_cast<int32_t>(k_), top_k_output_device,
|
||||
top_k_indices_device, top_k_init_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
CalInTopK(predictions_device, targets_device, output_device, top_k_output_device, input_shape_[0], k_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_count != 2) {
|
||||
MS_LOG(ERROR) << input_count << " inputs were provided, but InTopKGpuKernel expects 2.";
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t output_count = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_count != 1) {
|
||||
MS_LOG(ERROR) << "Number of outputs is " << output_count << ", but should be 1 for InTopKGpuKernel.";
|
||||
return false;
|
||||
}
|
||||
|
||||
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
input_rank_ = input_shape_.size();
|
||||
input_size_ = 1;
|
||||
for (size_t i = 0; i < input_rank_; i++) {
|
||||
input_size_ *= input_shape_[i];
|
||||
}
|
||||
|
||||
k_ = GetAttr<int64_t>(kernel_node, "k");
|
||||
|
||||
inner_size_ = input_shape_[1];
|
||||
outer_size_ = input_shape_[0];
|
||||
|
||||
if (std::is_same<T, half>::value) {
|
||||
// min value representable by float16, std::numeric_limits doesn't support half
|
||||
top_k_init_ = static_cast<half>(-65504.);
|
||||
} else {
|
||||
top_k_init_ = std::numeric_limits<T>::lowest();
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
input_size_ = 0;
|
||||
k_ = 0;
|
||||
input_shape_.clear();
|
||||
input_rank_ = 0;
|
||||
outer_size_ = 0;
|
||||
inner_size_ = 0;
|
||||
top_k_init_ = static_cast<T>(0.);
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_ * sizeof(T));
|
||||
input_size_list_.push_back(input_shape_[0] * sizeof(int32_t));
|
||||
output_size_list_.push_back(input_shape_[0] * sizeof(bool));
|
||||
workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(T));
|
||||
workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(int32_t));
|
||||
}
|
||||
|
||||
private:
|
||||
size_t input_size_;
|
||||
T top_k_init_;
|
||||
int64_t k_;
|
||||
std::vector<size_t> input_shape_;
|
||||
size_t input_rank_;
|
||||
|
||||
// for topk
|
||||
size_t outer_size_;
|
||||
size_t inner_size_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IN_TOP_K_GPU_KERNEL_H_
|
|
@ -71,7 +71,7 @@ class SortGpuKernel : public GpuKernel {
|
|||
T *intermediate_input_device = input_device;
|
||||
T *intermediate_output_device = output_device;
|
||||
|
||||
// if sort in descending order, negate input and negate back after sorting
|
||||
// if sort not in descending order, negate input and negate back after sorting
|
||||
if (!descending_) {
|
||||
Negative(intermediate_input_device, intermediate_output_device, input_size_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
@ -116,7 +116,7 @@ class SortGpuKernel : public GpuKernel {
|
|||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_count != 1) {
|
||||
MS_LOG(ERROR) << input_count << " inputs were provided, but SortGpuKernel expects 2.";
|
||||
MS_LOG(ERROR) << input_count << " inputs were provided, but SortGpuKernel expects 1.";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "in_top_k_impl.cuh"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void InTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output,
|
||||
size_t class_id_count, int64_t k) {
|
||||
size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (; gt_id < class_id_count; gt_id += blockDim.x * gridDim.x) {
|
||||
int32_t target_index = targets[gt_id];
|
||||
T predicted_value = predictions[gt_id * class_id_count + target_index];
|
||||
T top_k_smallest_value = top_k_output[k - 1];
|
||||
|
||||
output[gt_id] = predicted_value >= top_k_smallest_value;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalInTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output, size_t class_id_count,
|
||||
int64_t k, cudaStream_t cuda_stream) {
|
||||
InTopK<<<GET_BLOCKS(class_id_count), GET_THREADS, 0, cuda_stream>>>(predictions, targets, output, top_k_output,
|
||||
class_id_count, k);
|
||||
}
|
||||
|
||||
template void CalInTopK<half>(const half *predictions, const int32_t *targets, bool *output, const half *top_k_output,
|
||||
size_t class_id_count, int64_t k, cudaStream_t cuda_stream);
|
||||
|
||||
template void CalInTopK<float>(const float *predictions, const int32_t *targets, bool *output,
|
||||
const float *top_k_output, size_t class_id_count, int64_t k, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_IN_TOP_K_CUH_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_IN_TOP_K_CUH_
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
template <typename T>
|
||||
void CalInTopK(const T *predictions, const int32_t *targets, bool *output, const T *top_k_output, size_t class_id_count,
|
||||
int64_t k, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_IN_TOP_K_CUH_
|
|
@ -5002,7 +5002,7 @@ class Sort(PrimitiveWithInfer):
|
|||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), mindspore.float16)
|
||||
|
|
|
@ -7547,7 +7547,7 @@ class InTopK(PrimitiveWithInfer):
|
|||
TypeError: If dtype of `x1` is neither float16 nor float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x1 = Tensor(np.array([[1, 8, 5, 2, 7], [4, 9, 1, 3, 5]]), mindspore.float32)
|
||||
|
@ -7563,6 +7563,7 @@ class InTopK(PrimitiveWithInfer):
|
|||
"""Initialize InTopK"""
|
||||
self.init_prim_io_names(inputs=['x1', 'x2', 'k'], outputs=['y'])
|
||||
validator.check_value_type("k", k, [int], self.name)
|
||||
validator.check("k", k, "", 0, Rel.GT, self.name)
|
||||
|
||||
def infer_dtype(self, x1_dtype, x2_dtype):
|
||||
validator.check_tensor_dtype_valid("x1", x1_dtype, (mstype.float16, mstype.float32,), self.name)
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
class InTopKNet(nn.Cell):
|
||||
def __init__(self, k):
|
||||
super(InTopKNet, self).__init__()
|
||||
self.in_top_k = P.InTopK(k)
|
||||
|
||||
def construct(self, predictions, targets):
|
||||
return self.in_top_k(predictions, targets)
|
||||
|
||||
|
||||
def in_top_k(nptype):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
predictions = Tensor(np.array([[9, 3, 8],
|
||||
[7, 9, 9],
|
||||
[9, 9, 9]]).astype(nptype))
|
||||
|
||||
k = 1
|
||||
in_top_k_net = InTopKNet(k)
|
||||
targets = Tensor(np.array([0, 1, 0]).astype(np.int32))
|
||||
output = in_top_k_net(predictions, targets)
|
||||
expected_output = np.array([True, True, True])
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
|
||||
targets = Tensor(np.array([1, 0, 2]).astype(np.int32))
|
||||
output = in_top_k_net(predictions, targets)
|
||||
expected_output = np.array([False, False, True])
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
|
||||
targets = Tensor(np.array([2, 2, 1]).astype(np.int32))
|
||||
output = in_top_k_net(predictions, targets)
|
||||
expected_output = np.array([False, True, True])
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
|
||||
k = 2
|
||||
in_top_k_net = InTopKNet(k)
|
||||
targets = Tensor(np.array([0, 1, 2]).astype(np.int32))
|
||||
output = in_top_k_net(predictions, targets)
|
||||
expected_output = np.array([True, True, True])
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
|
||||
targets = Tensor(np.array([2, 2, 0]).astype(np.int32))
|
||||
output = in_top_k_net(predictions, targets)
|
||||
expected_output = np.array([True, True, True])
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
|
||||
targets = Tensor(np.array([1, 0, 1]).astype(np.int32))
|
||||
output = in_top_k_net(predictions, targets)
|
||||
expected_output = np.array([False, False, True])
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
|
||||
k = 3
|
||||
in_top_k_net = InTopKNet(k)
|
||||
targets = Tensor(np.array([2, 2, 2]).astype(np.int32))
|
||||
output = in_top_k_net(predictions, targets)
|
||||
expected_output = np.array([True, True, True])
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
|
||||
targets = Tensor(np.array([1, 1, 0]).astype(np.int32))
|
||||
output = in_top_k_net(predictions, targets)
|
||||
expected_output = np.array([True, True, True])
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
|
||||
targets = Tensor(np.array([0, 0, 1]).astype(np.int32))
|
||||
output = in_top_k_net(predictions, targets)
|
||||
expected_output = np.array([True, True, True])
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_in_top_k_float16():
|
||||
in_top_k(np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_in_top_k_float32():
|
||||
in_top_k(np.float32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_in_top_k_invalid_input():
|
||||
# k must be > 0
|
||||
with pytest.raises(ValueError):
|
||||
in_top_k_net = InTopKNet(0)
|
||||
|
||||
# predictions must be 2d
|
||||
with pytest.raises(ValueError):
|
||||
in_top_k_net = InTopKNet(1)
|
||||
predictions = Tensor(np.zeros(4).astype(np.float32))
|
||||
targets = Tensor(np.zeros(4).astype(np.int32))
|
||||
_ = in_top_k_net(predictions, targets)
|
||||
|
||||
# targets must be 1d
|
||||
with pytest.raises(ValueError):
|
||||
in_top_k_net = InTopKNet(1)
|
||||
predictions = Tensor(np.zeros(4).astype(np.float32))
|
||||
targets = Tensor(np.zeros(4).reshape(2, 2).astype(np.int32))
|
||||
_ = in_top_k_net(predictions, targets)
|
||||
|
||||
# predictions.shape[1] must be equal to targets.shape[0]
|
||||
with pytest.raises(ValueError):
|
||||
in_top_k_net = InTopKNet(1)
|
||||
predictions = Tensor(np.zeros(4).reshape(2, 2).astype(np.float32))
|
||||
targets = Tensor(np.zeros(4).astype(np.int32))
|
||||
_ = in_top_k_net(predictions, targets)
|
|
@ -32,7 +32,7 @@ class SortNet(nn.Cell):
|
|||
def sort_1d(descending, nptype):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
x_numpy = np.array([1, -2, 3, 4]).astype(np.float16)
|
||||
x_numpy = np.array([1, -2, 3, 4]).astype(nptype)
|
||||
x = Tensor(x_numpy)
|
||||
sort_net = SortNet(0, descending)
|
||||
output, indices = sort_net(x)
|
||||
|
|
Loading…
Reference in New Issue